refactor authentication code.

This commit is contained in:
Miroslav Šedivý 2021-11-17 18:00:27 +01:00
parent bc961c5170
commit c4d67d416e
3 changed files with 18 additions and 29 deletions

View File

@ -39,12 +39,12 @@ func GetIP(serverUrl string) (string, error) {
return string(bytes.TrimSpace(buf)), nil return string(bytes.TrimSpace(buf)), nil
} }
func ReadUserIP(r *http.Request) string { func GetHttpRequestIP(r *http.Request, proxy bool) string {
IPAddress := r.Header.Get("X-Real-Ip") IPAddress := r.Header.Get("X-Real-Ip")
if IPAddress == "" { if IPAddress == "" {
IPAddress = r.Header.Get("X-Forwarded-For") IPAddress = r.Header.Get("X-Forwarded-For")
} }
if IPAddress == "" { if IPAddress == "" || !proxy {
IPAddress = r.RemoteAddr IPAddress = r.RemoteAddr
} }
return IPAddress return IPAddress

View File

@ -22,7 +22,7 @@ type MessageHandler struct {
locked map[string]string locked map[string]string
} }
func (h *MessageHandler) Connected(admin bool, socket *WebSocket) (bool, string, error) { func (h *MessageHandler) Connected(admin bool, socket *WebSocket) (bool, string) {
address := socket.Address() address := socket.Address()
if address == "" { if address == "" {
h.logger.Debug().Msg("no remote address") h.logger.Debug().Msg("no remote address")
@ -30,17 +30,17 @@ func (h *MessageHandler) Connected(admin bool, socket *WebSocket) (bool, string,
ok, banned := h.banned[address] ok, banned := h.banned[address]
if ok && banned { if ok && banned {
h.logger.Debug().Str("address", address).Msg("banned") h.logger.Debug().Str("address", address).Msg("banned")
return false, "banned", nil return false, "banned"
} }
} }
_, ok := h.locked["login"] _, ok := h.locked["login"]
if ok && !admin { if ok && !admin {
h.logger.Debug().Msg("server locked") h.logger.Debug().Msg("server locked")
return false, "locked", nil return false, "locked"
} }
return true, "", nil return true, ""
} }
func (h *MessageHandler) Disconnected(id string) { func (h *MessageHandler) Disconnected(id string) {

View File

@ -187,13 +187,19 @@ func (ws *WebSocketHandler) Shutdown() error {
func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) error { func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) error {
ws.logger.Debug().Msg("attempting to upgrade connection") ws.logger.Debug().Msg("attempting to upgrade connection")
id, err := utils.NewUID(32)
if err != nil {
ws.logger.Error().Err(err).Msg("failed to generate user id")
return err
}
connection, err := ws.upgrader.Upgrade(w, r, nil) connection, err := ws.upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
ws.logger.Error().Err(err).Msg("failed to upgrade connection") ws.logger.Error().Err(err).Msg("failed to upgrade connection")
return err return err
} }
id, ip, admin, err := ws.authenticate(r) admin, err := ws.authenticate(r)
if err != nil { if err != nil {
ws.logger.Warn().Err(err).Msg("authentication failed") ws.logger.Warn().Err(err).Msg("authentication failed")
@ -213,16 +219,11 @@ func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) erro
socket := &WebSocket{ socket := &WebSocket{
id: id, id: id,
ws: ws, ws: ws,
address: ip, address: utils.GetHttpRequestIP(r, ws.conf.Proxy),
connection: connection, connection: connection,
} }
ok, reason, err := ws.handler.Connected(admin, socket) ok, reason := ws.handler.Connected(admin, socket)
if err != nil {
ws.logger.Error().Err(err).Msg("connection failed")
return err
}
if !ok { if !ok {
if err = connection.WriteJSON(message.SystemMessage{ if err = connection.WriteJSON(message.SystemMessage{
Event: event.SYSTEM_DISCONNECT, Event: event.SYSTEM_DISCONNECT,
@ -288,25 +289,13 @@ func (ws *WebSocketHandler) IsAdmin(password string) (bool, error) {
return false, fmt.Errorf("invalid password") return false, fmt.Errorf("invalid password")
} }
func (ws *WebSocketHandler) authenticate(r *http.Request) (string, string, bool, error) { func (ws *WebSocketHandler) authenticate(r *http.Request) (bool, error) {
ip := r.RemoteAddr
if ws.conf.Proxy {
ip = utils.ReadUserIP(r)
}
id, err := utils.NewUID(32)
if err != nil {
return "", ip, false, err
}
passwords, ok := r.URL.Query()["password"] passwords, ok := r.URL.Query()["password"]
if !ok || len(passwords[0]) < 1 { if !ok || len(passwords[0]) < 1 {
return "", ip, false, fmt.Errorf("no password provided") return false, fmt.Errorf("no password provided")
} }
isAdmin, err := ws.IsAdmin(passwords[0]) return ws.IsAdmin(passwords[0])
return id, ip, isAdmin, err
} }
func (ws *WebSocketHandler) handle(connection *websocket.Conn, id string) { func (ws *WebSocketHandler) handle(connection *websocket.Conn, id string) {