refactor WS authentication.

This commit is contained in:
Miroslav Šedivý 2020-11-01 20:23:09 +01:00
parent 16d762b6ae
commit 311ed987d8
5 changed files with 68 additions and 43 deletions

View File

@ -3,35 +3,54 @@ package session
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"strings"
"demodesk/neko/internal/types"
"demodesk/neko/internal/utils" "demodesk/neko/internal/utils"
) )
// TODO: Refactor const (
func (manager *SessionManagerCtx) Authenticate(r *http.Request) (string, string, bool, error) { token_name = "password"
ip := r.RemoteAddr )
//if ws.conf.Proxy { func (manager *SessionManagerCtx) Authenticate(r *http.Request) (types.Session, error) {
// ip = utils.ReadUserIP(r) token := getToken(r)
//} if token == "" {
return nil, fmt.Errorf("no password provided")
}
isAdmin := (token == manager.config.AdminPassword)
isUser := (token == manager.config.Password)
if !isAdmin && !isUser {
return nil, fmt.Errorf("invalid password")
}
id, err := utils.NewUID(32) id, err := utils.NewUID(32)
if err != nil { if err != nil {
return "", ip, false, err return nil, err
} }
passwords, ok := r.URL.Query()["password"] return manager.New(id, isAdmin), nil
if !ok || len(passwords[0]) < 1 { }
return "", ip, false, fmt.Errorf("no password provided")
} func getToken(r *http.Request) string {
// Get token from query
if passwords[0] == manager.config.AdminPassword { if token := r.URL.Query().Get(token_name); token != "" {
return id, ip, true, nil return token
} }
if passwords[0] == manager.config.Password { // Get token from authorization header
return id, ip, false, nil bearer := r.Header.Get("Authorization")
} if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" {
return bearer[7:]
return "", ip, false, fmt.Errorf("invalid password: %s", passwords[0]) }
// Get token from cookie
cookie, err := r.Cookie(token_name)
if err == nil {
return cookie.Value
}
return ""
} }

View File

@ -30,23 +30,16 @@ type SessionManagerCtx struct {
emmiter events.EventEmmiter emmiter events.EventEmmiter
} }
func (manager *SessionManagerCtx) New(id string, admin bool, socket types.WebSocket) types.Session { func (manager *SessionManagerCtx) New(id string, admin bool) types.Session {
session := &SessionCtx{ session := &SessionCtx{
id: id, id: id,
admin: admin, admin: admin,
manager: manager, manager: manager,
socket: socket,
logger: manager.logger.With().Str("id", id).Logger(), logger: manager.logger.With().Str("id", id).Logger(),
connected: false, connected: false,
} }
manager.members[id] = session manager.members[id] = session
manager.emmiter.Emit("created", session)
if !manager.capture.Streaming() && len(manager.members) > 0 {
manager.capture.StartStream()
}
return session return session
} }
@ -66,10 +59,6 @@ func (manager *SessionManagerCtx) Destroy(id string) error {
delete(manager.members, id) delete(manager.members, id)
err := session.destroy() err := session.destroy()
if !manager.capture.Streaming() && len(manager.members) <= 0 {
manager.capture.StopStream()
}
manager.emmiter.Emit("destroy", id) manager.emmiter.Emit("destroy", id)
return err return err
} }
@ -164,12 +153,22 @@ func (manager *SessionManagerCtx) OnHostCleared(listener func(session types.Sess
func (manager *SessionManagerCtx) OnDestroy(listener func(id string)) { func (manager *SessionManagerCtx) OnDestroy(listener func(id string)) {
manager.emmiter.On("destroy", func(payload ...interface{}) { manager.emmiter.On("destroy", func(payload ...interface{}) {
// Stop streaming, if everyone left
if manager.capture.Streaming() && len(manager.members) == 0 {
manager.capture.StopStream()
}
listener(payload[0].(string)) listener(payload[0].(string))
}) })
} }
func (manager *SessionManagerCtx) OnCreated(listener func(session types.Session)) { func (manager *SessionManagerCtx) OnCreated(listener func(session types.Session)) {
manager.emmiter.On("created", func(payload ...interface{}) { manager.emmiter.On("created", func(payload ...interface{}) {
// Start streaming, when first joins
if !manager.capture.Streaming() {
manager.capture.StartStream()
}
listener(payload[0].(*SessionCtx)) listener(payload[0].(*SessionCtx))
}) })
} }

View File

@ -62,6 +62,7 @@ func (session *SessionCtx) SetName(name string) {
func (session *SessionCtx) SetSocket(socket types.WebSocket) { func (session *SessionCtx) SetSocket(socket types.WebSocket) {
session.socket = socket session.socket = socket
session.manager.emmiter.Emit("created", session)
} }
func (session *SessionCtx) SetPeer(peer types.Peer) { func (session *SessionCtx) SetPeer(peer types.Peer) {

View File

@ -21,7 +21,7 @@ type Session interface {
} }
type SessionManager interface { type SessionManager interface {
New(id string, admin bool, socket WebSocket) Session New(id string, admin bool) Session
Get(id string) (Session, bool) Get(id string) (Session, bool)
Has(id string) bool Has(id string) bool
Destroy(id string) error Destroy(id string) error
@ -42,5 +42,5 @@ type SessionManager interface {
OnConnected(listener func(session Session)) OnConnected(listener func(session Session))
// auth // auth
Authenticate(r *http.Request) (string, string, bool, error) Authenticate(r *http.Request) (Session, error)
} }

View File

@ -127,14 +127,14 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e
return err return err
} }
id, ip, admin, err := ws.sessions.Authenticate(r) session, err := ws.sessions.Authenticate(r)
if err != nil { if err != nil {
ws.logger.Warn().Err(err).Msg("authentication failed") ws.logger.Warn().Err(err).Msg("authentication failed")
// TODO: Refactor // TODO: Refactor
if err = connection.WriteJSON(message.Disconnect{ if err = connection.WriteJSON(message.Disconnect{
Event: event.SYSTEM_DISCONNECT, Event: event.SYSTEM_DISCONNECT,
Message: "invalid_password", Message: "authentication failed",
}); err != nil { }); err != nil {
ws.logger.Error().Err(err).Msg("failed to send disconnect") ws.logger.Error().Err(err).Msg("failed to send disconnect")
} }
@ -142,14 +142,20 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e
return connection.Close() return connection.Close()
} }
// TODO: Refactor.
ip := r.RemoteAddr
// if allow poxy {
// ip = utils.ReadUserIP(r)
// }
socket := &WebSocketCtx{ socket := &WebSocketCtx{
id: id, id: session.ID(),
ws: ws, ws: ws,
address: ip, address: ip,
connection: connection, connection: connection,
} }
ok, reason := ws.handler.Connected(id, socket) ok, reason := ws.handler.Connected(session.ID(), socket)
if !ok { if !ok {
// TODO: Refactor // TODO: Refactor
if err = connection.WriteJSON(message.Disconnect{ if err = connection.WriteJSON(message.Disconnect{
@ -162,23 +168,23 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e
return connection.Close() return connection.Close()
} }
ws.sessions.New(id, admin, socket) session.SetSocket(socket)
ws.logger. ws.logger.
Debug(). Debug().
Str("session", id). Str("session", session.ID()).
Str("address", connection.RemoteAddr().String()). Str("address", connection.RemoteAddr().String()).
Msg("new connection created") Msg("new connection created")
defer func() { defer func() {
ws.logger. ws.logger.
Debug(). Debug().
Str("session", id). Str("session", session.ID()).
Str("address", connection.RemoteAddr().String()). Str("address", connection.RemoteAddr().String()).
Msg("session ended") Msg("session ended")
}() }()
ws.handle(connection, id) ws.handle(connection, session.ID())
return nil return nil
} }