should resolve #46

This commit is contained in:
Craig 2020-04-05 03:49:43 +00:00
parent 19466b5625
commit 362cf6c254
8 changed files with 39 additions and 24 deletions

View File

@ -42,9 +42,9 @@ func (session *Session) Connected() bool {
return session.connected return session.connected
} }
func (session *Session) Address() *string { func (session *Session) Address() string {
if session.socket == nil { if session.socket == nil {
return nil return ""
} }
return session.socket.Address() return session.socket.Address()
} }

View File

@ -19,7 +19,7 @@ type Session interface {
SetConnected(connected bool) error SetConnected(connected bool) error
SetSocket(socket WebSocket) error SetSocket(socket WebSocket) error
SetPeer(peer Peer) error SetPeer(peer Peer) error
Address() *string Address() string
Kick(message string) error Kick(message string) error
Write(v interface{}) error Write(v interface{}) error
Send(v interface{}) error Send(v interface{}) error

View File

@ -3,7 +3,7 @@ package types
import "net/http" import "net/http"
type WebSocket interface { type WebSocket interface {
Address() *string Address() string
Send(v interface{}) error Send(v interface{}) error
Destroy() error Destroy() error
} }

View File

@ -22,3 +22,14 @@ func GetIP() (string, error) {
return string(bytes.TrimSpace(buf)), nil return string(bytes.TrimSpace(buf)), nil
} }
func ReadUserIP(r *http.Request) string {
IPAddress := r.Header.Get("X-Real-Ip")
if IPAddress == "" {
IPAddress = r.Header.Get("X-Forwarded-For")
}
if IPAddress == "" {
IPAddress = r.RemoteAddr
}
return IPAddress
}

View File

@ -265,18 +265,18 @@ func (h *MessageHandler) adminBan(id string, session types.Session, payload *mes
} }
remote := target.Address() remote := target.Address()
if remote == nil { if remote == "" {
h.logger.Debug().Msg("no remote address, baling") h.logger.Debug().Msg("no remote address, baling")
return nil return nil
} }
address := strings.SplitN(*remote, ":", -1) address := strings.SplitN(remote, ":", -1)
if len(address[0]) < 1 { if len(address[0]) < 1 {
h.logger.Debug().Str("address", *remote).Msg("no remote address, baling") h.logger.Debug().Str("address", remote).Msg("no remote address, baling")
return nil return nil
} }
h.logger.Debug().Str("address", *remote).Msg("adding address to banned") h.logger.Debug().Str("address", remote).Msg("adding address to banned")
h.banned[address[0]] = true h.banned[address[0]] = true

View File

@ -22,12 +22,12 @@ type MessageHandler struct {
func (h *MessageHandler) Connected(id string, socket *WebSocket) (bool, string, error) { func (h *MessageHandler) Connected(id string, socket *WebSocket) (bool, string, error) {
address := socket.Address() address := socket.Address()
if address == nil { if address == "" {
h.logger.Debug().Msg("no remote address, baling") h.logger.Debug().Msg("no remote address")
} else { } else {
ok, banned := h.banned[*address] ok, banned := h.banned[address]
if ok && banned { if ok && banned {
h.logger.Debug().Str("address", *address).Msg("banned") h.logger.Debug().Str("address", address).Msg("banned")
return false, "This IP has been banned", nil return false, "This IP has been banned", nil
} }
} }

View File

@ -10,18 +10,19 @@ import (
type WebSocket struct { type WebSocket struct {
id string id string
address string
ws *WebSocketHandler ws *WebSocketHandler
connection *websocket.Conn connection *websocket.Conn
mu sync.Mutex mu sync.Mutex
} }
func (socket *WebSocket) Address() *string { func (socket *WebSocket) Address() string {
remote := socket.connection.RemoteAddr() //remote := socket.connection.RemoteAddr()
address := strings.SplitN(remote.String(), ":", -1) address := strings.SplitN(socket.address, ":", -1)
if len(address[0]) < 1 { if len(address[0]) < 1 {
return nil return socket.address
} }
return &address[0] return address[0]
} }
func (socket *WebSocket) Send(v interface{}) error { func (socket *WebSocket) Send(v interface{}) error {

View File

@ -123,7 +123,7 @@ func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) erro
return err return err
} }
id, admin, err := ws.authenticate(r) id, ip, admin, err := ws.authenticate(r)
if err != nil { if err != nil {
ws.logger.Warn().Err(err).Msg("authentication failed") ws.logger.Warn().Err(err).Msg("authentication failed")
@ -143,6 +143,7 @@ func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) erro
socket := &WebSocket{ socket := &WebSocket{
id: id, id: id,
ws: ws, ws: ws,
address: ip,
connection: connection, connection: connection,
} }
@ -187,26 +188,28 @@ func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) erro
return nil return nil
} }
func (ws *WebSocketHandler) authenticate(r *http.Request) (string, bool, error) { func (ws *WebSocketHandler) authenticate(r *http.Request) (string, string, bool, error) {
ip := utils.ReadUserIP(r)
id, err := utils.NewUID(32) id, err := utils.NewUID(32)
if err != nil { if err != nil {
return "", false, err return "", ip, false, err
} }
passwords, ok := r.URL.Query()["password"] passwords, ok := r.URL.Query()["password"]
if !ok || len(passwords[0]) < 1 { if !ok || len(passwords[0]) < 1 {
return "", false, fmt.Errorf("no password provided") return "", ip, false, fmt.Errorf("no password provided")
} }
if passwords[0] == ws.conf.AdminPassword { if passwords[0] == ws.conf.AdminPassword {
return id, true, nil return id, ip, true, nil
} }
if passwords[0] == ws.conf.Password { if passwords[0] == ws.conf.Password {
return id, false, nil return id, ip, false, nil
} }
return "", false, fmt.Errorf("invalid password: %s", passwords[0]) return "", ip, false, fmt.Errorf("invalid password: %s", passwords[0])
} }
func (ws *WebSocketHandler) handle(connection *websocket.Conn, id string) { func (ws *WebSocketHandler) handle(connection *websocket.Conn, id string) {