should resolve #46

This commit is contained in:
Craig 2020-04-05 03:49:43 +00:00
parent 19466b5625
commit 362cf6c254
8 changed files with 39 additions and 24 deletions

View File

@ -42,9 +42,9 @@ func (session *Session) Connected() bool {
return session.connected
}
func (session *Session) Address() *string {
func (session *Session) Address() string {
if session.socket == nil {
return nil
return ""
}
return session.socket.Address()
}

View File

@ -19,7 +19,7 @@ type Session interface {
SetConnected(connected bool) error
SetSocket(socket WebSocket) error
SetPeer(peer Peer) error
Address() *string
Address() string
Kick(message string) error
Write(v interface{}) error
Send(v interface{}) error

View File

@ -3,7 +3,7 @@ package types
import "net/http"
type WebSocket interface {
Address() *string
Address() string
Send(v interface{}) error
Destroy() error
}

View File

@ -22,3 +22,14 @@ func GetIP() (string, error) {
return string(bytes.TrimSpace(buf)), nil
}
func ReadUserIP(r *http.Request) string {
IPAddress := r.Header.Get("X-Real-Ip")
if IPAddress == "" {
IPAddress = r.Header.Get("X-Forwarded-For")
}
if IPAddress == "" {
IPAddress = r.RemoteAddr
}
return IPAddress
}

View File

@ -265,18 +265,18 @@ func (h *MessageHandler) adminBan(id string, session types.Session, payload *mes
}
remote := target.Address()
if remote == nil {
if remote == "" {
h.logger.Debug().Msg("no remote address, baling")
return nil
}
address := strings.SplitN(*remote, ":", -1)
address := strings.SplitN(remote, ":", -1)
if len(address[0]) < 1 {
h.logger.Debug().Str("address", *remote).Msg("no remote address, baling")
h.logger.Debug().Str("address", remote).Msg("no remote address, baling")
return nil
}
h.logger.Debug().Str("address", *remote).Msg("adding address to banned")
h.logger.Debug().Str("address", remote).Msg("adding address to banned")
h.banned[address[0]] = true

View File

@ -22,12 +22,12 @@ type MessageHandler struct {
func (h *MessageHandler) Connected(id string, socket *WebSocket) (bool, string, error) {
address := socket.Address()
if address == nil {
h.logger.Debug().Msg("no remote address, baling")
if address == "" {
h.logger.Debug().Msg("no remote address")
} else {
ok, banned := h.banned[*address]
ok, banned := h.banned[address]
if ok && banned {
h.logger.Debug().Str("address", *address).Msg("banned")
h.logger.Debug().Str("address", address).Msg("banned")
return false, "This IP has been banned", nil
}
}

View File

@ -10,18 +10,19 @@ import (
type WebSocket struct {
id string
address string
ws *WebSocketHandler
connection *websocket.Conn
mu sync.Mutex
}
func (socket *WebSocket) Address() *string {
remote := socket.connection.RemoteAddr()
address := strings.SplitN(remote.String(), ":", -1)
func (socket *WebSocket) Address() string {
//remote := socket.connection.RemoteAddr()
address := strings.SplitN(socket.address, ":", -1)
if len(address[0]) < 1 {
return nil
return socket.address
}
return &address[0]
return address[0]
}
func (socket *WebSocket) Send(v interface{}) error {

View File

@ -123,7 +123,7 @@ func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) erro
return err
}
id, admin, err := ws.authenticate(r)
id, ip, admin, err := ws.authenticate(r)
if err != nil {
ws.logger.Warn().Err(err).Msg("authentication failed")
@ -143,6 +143,7 @@ func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) erro
socket := &WebSocket{
id: id,
ws: ws,
address: ip,
connection: connection,
}
@ -187,26 +188,28 @@ func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) erro
return nil
}
func (ws *WebSocketHandler) authenticate(r *http.Request) (string, bool, error) {
func (ws *WebSocketHandler) authenticate(r *http.Request) (string, string, bool, error) {
ip := utils.ReadUserIP(r)
id, err := utils.NewUID(32)
if err != nil {
return "", false, err
return "", ip, false, err
}
passwords, ok := r.URL.Query()["password"]
if !ok || len(passwords[0]) < 1 {
return "", false, fmt.Errorf("no password provided")
return "", ip, false, fmt.Errorf("no password provided")
}
if passwords[0] == ws.conf.AdminPassword {
return id, true, nil
return id, ip, true, nil
}
if passwords[0] == ws.conf.Password {
return id, false, nil
return id, ip, false, nil
}
return "", false, fmt.Errorf("invalid password: %s", passwords[0])
return "", ip, false, fmt.Errorf("invalid password: %s", passwords[0])
}
func (ws *WebSocketHandler) handle(connection *websocket.Conn, id string) {