diff --git a/server/internal/session/session.go b/server/internal/session/session.go index 78f225e0..b9733a78 100644 --- a/server/internal/session/session.go +++ b/server/internal/session/session.go @@ -42,9 +42,9 @@ func (session *Session) Connected() bool { return session.connected } -func (session *Session) Address() *string { +func (session *Session) Address() string { if session.socket == nil { - return nil + return "" } return session.socket.Address() } diff --git a/server/internal/types/session.go b/server/internal/types/session.go index fb6da7a0..1c8acabd 100644 --- a/server/internal/types/session.go +++ b/server/internal/types/session.go @@ -19,7 +19,7 @@ type Session interface { SetConnected(connected bool) error SetSocket(socket WebSocket) error SetPeer(peer Peer) error - Address() *string + Address() string Kick(message string) error Write(v interface{}) error Send(v interface{}) error diff --git a/server/internal/types/webscoket.go b/server/internal/types/webscoket.go index cbfc738d..d5b63360 100644 --- a/server/internal/types/webscoket.go +++ b/server/internal/types/webscoket.go @@ -3,7 +3,7 @@ package types import "net/http" type WebSocket interface { - Address() *string + Address() string Send(v interface{}) error Destroy() error } diff --git a/server/internal/utils/ip.go b/server/internal/utils/ip.go index 69798a86..67c7d38d 100644 --- a/server/internal/utils/ip.go +++ b/server/internal/utils/ip.go @@ -22,3 +22,14 @@ func GetIP() (string, error) { return string(bytes.TrimSpace(buf)), nil } + +func ReadUserIP(r *http.Request) string { + IPAddress := r.Header.Get("X-Real-Ip") + if IPAddress == "" { + IPAddress = r.Header.Get("X-Forwarded-For") + } + if IPAddress == "" { + IPAddress = r.RemoteAddr + } + return IPAddress +} diff --git a/server/internal/websocket/admin.go b/server/internal/websocket/admin.go index c63ee0b4..192d4c1c 100644 --- a/server/internal/websocket/admin.go +++ b/server/internal/websocket/admin.go @@ -265,18 +265,18 @@ func (h *MessageHandler) adminBan(id string, session types.Session, payload *mes } remote := target.Address() - if remote == nil { + if remote == "" { h.logger.Debug().Msg("no remote address, baling") return nil } - address := strings.SplitN(*remote, ":", -1) + address := strings.SplitN(remote, ":", -1) if len(address[0]) < 1 { - h.logger.Debug().Str("address", *remote).Msg("no remote address, baling") + h.logger.Debug().Str("address", remote).Msg("no remote address, baling") return nil } - h.logger.Debug().Str("address", *remote).Msg("adding address to banned") + h.logger.Debug().Str("address", remote).Msg("adding address to banned") h.banned[address[0]] = true diff --git a/server/internal/websocket/handler.go b/server/internal/websocket/handler.go index 1895ade3..61f6dca1 100644 --- a/server/internal/websocket/handler.go +++ b/server/internal/websocket/handler.go @@ -22,12 +22,12 @@ type MessageHandler struct { func (h *MessageHandler) Connected(id string, socket *WebSocket) (bool, string, error) { address := socket.Address() - if address == nil { - h.logger.Debug().Msg("no remote address, baling") + if address == "" { + h.logger.Debug().Msg("no remote address") } else { - ok, banned := h.banned[*address] + ok, banned := h.banned[address] if ok && banned { - h.logger.Debug().Str("address", *address).Msg("banned") + h.logger.Debug().Str("address", address).Msg("banned") return false, "This IP has been banned", nil } } diff --git a/server/internal/websocket/socket.go b/server/internal/websocket/socket.go index 90dbbe77..c9875bfd 100644 --- a/server/internal/websocket/socket.go +++ b/server/internal/websocket/socket.go @@ -10,18 +10,19 @@ import ( type WebSocket struct { id string + address string ws *WebSocketHandler connection *websocket.Conn mu sync.Mutex } -func (socket *WebSocket) Address() *string { - remote := socket.connection.RemoteAddr() - address := strings.SplitN(remote.String(), ":", -1) +func (socket *WebSocket) Address() string { + //remote := socket.connection.RemoteAddr() + address := strings.SplitN(socket.address, ":", -1) if len(address[0]) < 1 { - return nil + return socket.address } - return &address[0] + return address[0] } func (socket *WebSocket) Send(v interface{}) error { diff --git a/server/internal/websocket/websocket.go b/server/internal/websocket/websocket.go index 414d0f30..0ff594f3 100644 --- a/server/internal/websocket/websocket.go +++ b/server/internal/websocket/websocket.go @@ -123,7 +123,7 @@ func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) erro return err } - id, admin, err := ws.authenticate(r) + id, ip, admin, err := ws.authenticate(r) if err != nil { ws.logger.Warn().Err(err).Msg("authentication failed") @@ -143,6 +143,7 @@ func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) erro socket := &WebSocket{ id: id, ws: ws, + address: ip, connection: connection, } @@ -187,26 +188,28 @@ func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) erro return nil } -func (ws *WebSocketHandler) authenticate(r *http.Request) (string, bool, error) { +func (ws *WebSocketHandler) authenticate(r *http.Request) (string, string, bool, error) { + ip := utils.ReadUserIP(r) + id, err := utils.NewUID(32) if err != nil { - return "", false, err + return "", ip, false, err } passwords, ok := r.URL.Query()["password"] if !ok || len(passwords[0]) < 1 { - return "", false, fmt.Errorf("no password provided") + return "", ip, false, fmt.Errorf("no password provided") } if passwords[0] == ws.conf.AdminPassword { - return id, true, nil + return id, ip, true, nil } if passwords[0] == ws.conf.Password { - return id, false, nil + return id, ip, false, nil } - return "", false, fmt.Errorf("invalid password: %s", passwords[0]) + return "", ip, false, fmt.Errorf("invalid password: %s", passwords[0]) } func (ws *WebSocketHandler) handle(connection *websocket.Conn, id string) {