mirror of
https://github.com/m1k1o/neko.git
synced 2024-07-24 14:40:50 +12:00
should resolve #46
This commit is contained in:
parent
19466b5625
commit
362cf6c254
@ -42,9 +42,9 @@ func (session *Session) Connected() bool {
|
||||
return session.connected
|
||||
}
|
||||
|
||||
func (session *Session) Address() *string {
|
||||
func (session *Session) Address() string {
|
||||
if session.socket == nil {
|
||||
return nil
|
||||
return ""
|
||||
}
|
||||
return session.socket.Address()
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ type Session interface {
|
||||
SetConnected(connected bool) error
|
||||
SetSocket(socket WebSocket) error
|
||||
SetPeer(peer Peer) error
|
||||
Address() *string
|
||||
Address() string
|
||||
Kick(message string) error
|
||||
Write(v interface{}) error
|
||||
Send(v interface{}) error
|
||||
|
@ -3,7 +3,7 @@ package types
|
||||
import "net/http"
|
||||
|
||||
type WebSocket interface {
|
||||
Address() *string
|
||||
Address() string
|
||||
Send(v interface{}) error
|
||||
Destroy() error
|
||||
}
|
||||
|
@ -22,3 +22,14 @@ func GetIP() (string, error) {
|
||||
|
||||
return string(bytes.TrimSpace(buf)), nil
|
||||
}
|
||||
|
||||
func ReadUserIP(r *http.Request) string {
|
||||
IPAddress := r.Header.Get("X-Real-Ip")
|
||||
if IPAddress == "" {
|
||||
IPAddress = r.Header.Get("X-Forwarded-For")
|
||||
}
|
||||
if IPAddress == "" {
|
||||
IPAddress = r.RemoteAddr
|
||||
}
|
||||
return IPAddress
|
||||
}
|
||||
|
@ -265,18 +265,18 @@ func (h *MessageHandler) adminBan(id string, session types.Session, payload *mes
|
||||
}
|
||||
|
||||
remote := target.Address()
|
||||
if remote == nil {
|
||||
if remote == "" {
|
||||
h.logger.Debug().Msg("no remote address, baling")
|
||||
return nil
|
||||
}
|
||||
|
||||
address := strings.SplitN(*remote, ":", -1)
|
||||
address := strings.SplitN(remote, ":", -1)
|
||||
if len(address[0]) < 1 {
|
||||
h.logger.Debug().Str("address", *remote).Msg("no remote address, baling")
|
||||
h.logger.Debug().Str("address", remote).Msg("no remote address, baling")
|
||||
return nil
|
||||
}
|
||||
|
||||
h.logger.Debug().Str("address", *remote).Msg("adding address to banned")
|
||||
h.logger.Debug().Str("address", remote).Msg("adding address to banned")
|
||||
|
||||
h.banned[address[0]] = true
|
||||
|
||||
|
@ -22,12 +22,12 @@ type MessageHandler struct {
|
||||
|
||||
func (h *MessageHandler) Connected(id string, socket *WebSocket) (bool, string, error) {
|
||||
address := socket.Address()
|
||||
if address == nil {
|
||||
h.logger.Debug().Msg("no remote address, baling")
|
||||
if address == "" {
|
||||
h.logger.Debug().Msg("no remote address")
|
||||
} else {
|
||||
ok, banned := h.banned[*address]
|
||||
ok, banned := h.banned[address]
|
||||
if ok && banned {
|
||||
h.logger.Debug().Str("address", *address).Msg("banned")
|
||||
h.logger.Debug().Str("address", address).Msg("banned")
|
||||
return false, "This IP has been banned", nil
|
||||
}
|
||||
}
|
||||
|
@ -10,18 +10,19 @@ import (
|
||||
|
||||
type WebSocket struct {
|
||||
id string
|
||||
address string
|
||||
ws *WebSocketHandler
|
||||
connection *websocket.Conn
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (socket *WebSocket) Address() *string {
|
||||
remote := socket.connection.RemoteAddr()
|
||||
address := strings.SplitN(remote.String(), ":", -1)
|
||||
func (socket *WebSocket) Address() string {
|
||||
//remote := socket.connection.RemoteAddr()
|
||||
address := strings.SplitN(socket.address, ":", -1)
|
||||
if len(address[0]) < 1 {
|
||||
return nil
|
||||
return socket.address
|
||||
}
|
||||
return &address[0]
|
||||
return address[0]
|
||||
}
|
||||
|
||||
func (socket *WebSocket) Send(v interface{}) error {
|
||||
|
@ -123,7 +123,7 @@ func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) erro
|
||||
return err
|
||||
}
|
||||
|
||||
id, admin, err := ws.authenticate(r)
|
||||
id, ip, admin, err := ws.authenticate(r)
|
||||
if err != nil {
|
||||
ws.logger.Warn().Err(err).Msg("authentication failed")
|
||||
|
||||
@ -143,6 +143,7 @@ func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) erro
|
||||
socket := &WebSocket{
|
||||
id: id,
|
||||
ws: ws,
|
||||
address: ip,
|
||||
connection: connection,
|
||||
}
|
||||
|
||||
@ -187,26 +188,28 @@ func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) authenticate(r *http.Request) (string, bool, error) {
|
||||
func (ws *WebSocketHandler) authenticate(r *http.Request) (string, string, bool, error) {
|
||||
ip := utils.ReadUserIP(r)
|
||||
|
||||
id, err := utils.NewUID(32)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
return "", ip, false, err
|
||||
}
|
||||
|
||||
passwords, ok := r.URL.Query()["password"]
|
||||
if !ok || len(passwords[0]) < 1 {
|
||||
return "", false, fmt.Errorf("no password provided")
|
||||
return "", ip, false, fmt.Errorf("no password provided")
|
||||
}
|
||||
|
||||
if passwords[0] == ws.conf.AdminPassword {
|
||||
return id, true, nil
|
||||
return id, ip, true, nil
|
||||
}
|
||||
|
||||
if passwords[0] == ws.conf.Password {
|
||||
return id, false, nil
|
||||
return id, ip, false, nil
|
||||
}
|
||||
|
||||
return "", false, fmt.Errorf("invalid password: %s", passwords[0])
|
||||
return "", ip, false, fmt.Errorf("invalid password: %s", passwords[0])
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) handle(connection *websocket.Conn, id string) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user