From 311ed987d815f9453f49ddc25b0e4235d0eb9836 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= Date: Sun, 1 Nov 2020 20:23:09 +0100 Subject: [PATCH] refactor WS authentication. --- internal/session/auth.go | 61 +++++++++++++++++++++++------------ internal/session/manager.go | 23 +++++++------ internal/session/session.go | 1 + internal/types/session.go | 4 +-- internal/websocket/manager.go | 22 ++++++++----- 5 files changed, 68 insertions(+), 43 deletions(-) diff --git a/internal/session/auth.go b/internal/session/auth.go index 9fbb7f95..7d9e8c51 100644 --- a/internal/session/auth.go +++ b/internal/session/auth.go @@ -3,35 +3,54 @@ package session import ( "fmt" "net/http" + "strings" + "demodesk/neko/internal/types" "demodesk/neko/internal/utils" ) -// TODO: Refactor -func (manager *SessionManagerCtx) Authenticate(r *http.Request) (string, string, bool, error) { - ip := r.RemoteAddr +const ( + token_name = "password" +) - //if ws.conf.Proxy { - // ip = utils.ReadUserIP(r) - //} +func (manager *SessionManagerCtx) Authenticate(r *http.Request) (types.Session, error) { + token := getToken(r) + if token == "" { + return nil, fmt.Errorf("no password provided") + } + + isAdmin := (token == manager.config.AdminPassword) + isUser := (token == manager.config.Password) + + if !isAdmin && !isUser { + return nil, fmt.Errorf("invalid password") + } id, err := utils.NewUID(32) if err != nil { - return "", ip, false, err + return nil, err } - passwords, ok := r.URL.Query()["password"] - if !ok || len(passwords[0]) < 1 { - return "", ip, false, fmt.Errorf("no password provided") - } - - if passwords[0] == manager.config.AdminPassword { - return id, ip, true, nil - } - - if passwords[0] == manager.config.Password { - return id, ip, false, nil - } - - return "", ip, false, fmt.Errorf("invalid password: %s", passwords[0]) + return manager.New(id, isAdmin), nil +} + +func getToken(r *http.Request) string { + // Get token from query + if token := r.URL.Query().Get(token_name); token != "" { + return token + } + + // Get token from authorization header + bearer := r.Header.Get("Authorization") + if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" { + return bearer[7:] + } + + // Get token from cookie + cookie, err := r.Cookie(token_name) + if err == nil { + return cookie.Value + } + + return "" } diff --git a/internal/session/manager.go b/internal/session/manager.go index 6bb113d9..53d5da5a 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -30,23 +30,16 @@ type SessionManagerCtx struct { emmiter events.EventEmmiter } -func (manager *SessionManagerCtx) New(id string, admin bool, socket types.WebSocket) types.Session { +func (manager *SessionManagerCtx) New(id string, admin bool) types.Session { session := &SessionCtx{ id: id, admin: admin, manager: manager, - socket: socket, logger: manager.logger.With().Str("id", id).Logger(), connected: false, } manager.members[id] = session - manager.emmiter.Emit("created", session) - - if !manager.capture.Streaming() && len(manager.members) > 0 { - manager.capture.StartStream() - } - return session } @@ -66,10 +59,6 @@ func (manager *SessionManagerCtx) Destroy(id string) error { delete(manager.members, id) err := session.destroy() - if !manager.capture.Streaming() && len(manager.members) <= 0 { - manager.capture.StopStream() - } - manager.emmiter.Emit("destroy", id) return err } @@ -164,12 +153,22 @@ func (manager *SessionManagerCtx) OnHostCleared(listener func(session types.Sess func (manager *SessionManagerCtx) OnDestroy(listener func(id string)) { manager.emmiter.On("destroy", func(payload ...interface{}) { + // Stop streaming, if everyone left + if manager.capture.Streaming() && len(manager.members) == 0 { + manager.capture.StopStream() + } + listener(payload[0].(string)) }) } func (manager *SessionManagerCtx) OnCreated(listener func(session types.Session)) { manager.emmiter.On("created", func(payload ...interface{}) { + // Start streaming, when first joins + if !manager.capture.Streaming() { + manager.capture.StartStream() + } + listener(payload[0].(*SessionCtx)) }) } diff --git a/internal/session/session.go b/internal/session/session.go index 574d79bb..80be221d 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -62,6 +62,7 @@ func (session *SessionCtx) SetName(name string) { func (session *SessionCtx) SetSocket(socket types.WebSocket) { session.socket = socket + session.manager.emmiter.Emit("created", session) } func (session *SessionCtx) SetPeer(peer types.Peer) { diff --git a/internal/types/session.go b/internal/types/session.go index e12e53bd..ae9b6b8b 100644 --- a/internal/types/session.go +++ b/internal/types/session.go @@ -21,7 +21,7 @@ type Session interface { } type SessionManager interface { - New(id string, admin bool, socket WebSocket) Session + New(id string, admin bool) Session Get(id string) (Session, bool) Has(id string) bool Destroy(id string) error @@ -42,5 +42,5 @@ type SessionManager interface { OnConnected(listener func(session Session)) // auth - Authenticate(r *http.Request) (string, string, bool, error) + Authenticate(r *http.Request) (Session, error) } diff --git a/internal/websocket/manager.go b/internal/websocket/manager.go index bef21bb4..1d5ba136 100644 --- a/internal/websocket/manager.go +++ b/internal/websocket/manager.go @@ -127,14 +127,14 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e return err } - id, ip, admin, err := ws.sessions.Authenticate(r) + session, err := ws.sessions.Authenticate(r) if err != nil { ws.logger.Warn().Err(err).Msg("authentication failed") // TODO: Refactor if err = connection.WriteJSON(message.Disconnect{ Event: event.SYSTEM_DISCONNECT, - Message: "invalid_password", + Message: "authentication failed", }); err != nil { ws.logger.Error().Err(err).Msg("failed to send disconnect") } @@ -142,14 +142,20 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e return connection.Close() } + // TODO: Refactor. + ip := r.RemoteAddr + // if allow poxy { + // ip = utils.ReadUserIP(r) + // } + socket := &WebSocketCtx{ - id: id, + id: session.ID(), ws: ws, address: ip, connection: connection, } - ok, reason := ws.handler.Connected(id, socket) + ok, reason := ws.handler.Connected(session.ID(), socket) if !ok { // TODO: Refactor if err = connection.WriteJSON(message.Disconnect{ @@ -162,23 +168,23 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e return connection.Close() } - ws.sessions.New(id, admin, socket) + session.SetSocket(socket) ws.logger. Debug(). - Str("session", id). + Str("session", session.ID()). Str("address", connection.RemoteAddr().String()). Msg("new connection created") defer func() { ws.logger. Debug(). - Str("session", id). + Str("session", session.ID()). Str("address", connection.RemoteAddr().String()). Msg("session ended") }() - ws.handle(connection, id) + ws.handle(connection, session.ID()) return nil }