diff --git a/server/internal/utils/ip.go b/server/internal/utils/ip.go index 02011c00..8bde7b0e 100644 --- a/server/internal/utils/ip.go +++ b/server/internal/utils/ip.go @@ -39,12 +39,12 @@ func GetIP(serverUrl string) (string, error) { return string(bytes.TrimSpace(buf)), nil } -func ReadUserIP(r *http.Request) string { +func GetHttpRequestIP(r *http.Request, proxy bool) string { IPAddress := r.Header.Get("X-Real-Ip") if IPAddress == "" { IPAddress = r.Header.Get("X-Forwarded-For") } - if IPAddress == "" { + if IPAddress == "" || !proxy { IPAddress = r.RemoteAddr } return IPAddress diff --git a/server/internal/websocket/handler.go b/server/internal/websocket/handler.go index fb86731d..1d290390 100644 --- a/server/internal/websocket/handler.go +++ b/server/internal/websocket/handler.go @@ -22,7 +22,7 @@ type MessageHandler struct { locked map[string]string } -func (h *MessageHandler) Connected(admin bool, socket *WebSocket) (bool, string, error) { +func (h *MessageHandler) Connected(admin bool, socket *WebSocket) (bool, string) { address := socket.Address() if address == "" { h.logger.Debug().Msg("no remote address") @@ -30,17 +30,17 @@ func (h *MessageHandler) Connected(admin bool, socket *WebSocket) (bool, string, ok, banned := h.banned[address] if ok && banned { h.logger.Debug().Str("address", address).Msg("banned") - return false, "banned", nil + return false, "banned" } } _, ok := h.locked["login"] if ok && !admin { h.logger.Debug().Msg("server locked") - return false, "locked", nil + return false, "locked" } - return true, "", nil + return true, "" } func (h *MessageHandler) Disconnected(id string) { diff --git a/server/internal/websocket/websocket.go b/server/internal/websocket/websocket.go index c7643e37..2da6eb1d 100644 --- a/server/internal/websocket/websocket.go +++ b/server/internal/websocket/websocket.go @@ -187,13 +187,19 @@ func (ws *WebSocketHandler) Shutdown() error { func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) error { ws.logger.Debug().Msg("attempting to upgrade connection") + id, err := utils.NewUID(32) + if err != nil { + ws.logger.Error().Err(err).Msg("failed to generate user id") + return err + } + connection, err := ws.upgrader.Upgrade(w, r, nil) if err != nil { ws.logger.Error().Err(err).Msg("failed to upgrade connection") return err } - id, ip, admin, err := ws.authenticate(r) + admin, err := ws.authenticate(r) if err != nil { ws.logger.Warn().Err(err).Msg("authentication failed") @@ -213,16 +219,11 @@ func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) erro socket := &WebSocket{ id: id, ws: ws, - address: ip, + address: utils.GetHttpRequestIP(r, ws.conf.Proxy), connection: connection, } - ok, reason, err := ws.handler.Connected(admin, socket) - if err != nil { - ws.logger.Error().Err(err).Msg("connection failed") - return err - } - + ok, reason := ws.handler.Connected(admin, socket) if !ok { if err = connection.WriteJSON(message.SystemMessage{ Event: event.SYSTEM_DISCONNECT, @@ -288,25 +289,13 @@ func (ws *WebSocketHandler) IsAdmin(password string) (bool, error) { return false, fmt.Errorf("invalid password") } -func (ws *WebSocketHandler) authenticate(r *http.Request) (string, string, bool, error) { - ip := r.RemoteAddr - - if ws.conf.Proxy { - ip = utils.ReadUserIP(r) - } - - id, err := utils.NewUID(32) - if err != nil { - return "", ip, false, err - } - +func (ws *WebSocketHandler) authenticate(r *http.Request) (bool, error) { passwords, ok := r.URL.Query()["password"] if !ok || len(passwords[0]) < 1 { - return "", ip, false, fmt.Errorf("no password provided") + return false, fmt.Errorf("no password provided") } - isAdmin, err := ws.IsAdmin(passwords[0]) - return id, ip, isAdmin, err + return ws.IsAdmin(passwords[0]) } func (ws *WebSocketHandler) handle(connection *websocket.Conn, id string) {