mirror of
https://github.com/m1k1o/neko.git
synced 2024-07-24 14:40:50 +12:00
refactor WS authentication.
This commit is contained in:
parent
16d762b6ae
commit
311ed987d8
@ -3,35 +3,54 @@ package session
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"demodesk/neko/internal/types"
|
||||
"demodesk/neko/internal/utils"
|
||||
)
|
||||
|
||||
// TODO: Refactor
|
||||
func (manager *SessionManagerCtx) Authenticate(r *http.Request) (string, string, bool, error) {
|
||||
ip := r.RemoteAddr
|
||||
const (
|
||||
token_name = "password"
|
||||
)
|
||||
|
||||
//if ws.conf.Proxy {
|
||||
// ip = utils.ReadUserIP(r)
|
||||
//}
|
||||
func (manager *SessionManagerCtx) Authenticate(r *http.Request) (types.Session, error) {
|
||||
token := getToken(r)
|
||||
if token == "" {
|
||||
return nil, fmt.Errorf("no password provided")
|
||||
}
|
||||
|
||||
isAdmin := (token == manager.config.AdminPassword)
|
||||
isUser := (token == manager.config.Password)
|
||||
|
||||
if !isAdmin && !isUser {
|
||||
return nil, fmt.Errorf("invalid password")
|
||||
}
|
||||
|
||||
id, err := utils.NewUID(32)
|
||||
if err != nil {
|
||||
return "", ip, false, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
passwords, ok := r.URL.Query()["password"]
|
||||
if !ok || len(passwords[0]) < 1 {
|
||||
return "", ip, false, fmt.Errorf("no password provided")
|
||||
}
|
||||
|
||||
if passwords[0] == manager.config.AdminPassword {
|
||||
return id, ip, true, nil
|
||||
}
|
||||
|
||||
if passwords[0] == manager.config.Password {
|
||||
return id, ip, false, nil
|
||||
}
|
||||
|
||||
return "", ip, false, fmt.Errorf("invalid password: %s", passwords[0])
|
||||
return manager.New(id, isAdmin), nil
|
||||
}
|
||||
|
||||
func getToken(r *http.Request) string {
|
||||
// Get token from query
|
||||
if token := r.URL.Query().Get(token_name); token != "" {
|
||||
return token
|
||||
}
|
||||
|
||||
// Get token from authorization header
|
||||
bearer := r.Header.Get("Authorization")
|
||||
if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" {
|
||||
return bearer[7:]
|
||||
}
|
||||
|
||||
// Get token from cookie
|
||||
cookie, err := r.Cookie(token_name)
|
||||
if err == nil {
|
||||
return cookie.Value
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
@ -30,23 +30,16 @@ type SessionManagerCtx struct {
|
||||
emmiter events.EventEmmiter
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) New(id string, admin bool, socket types.WebSocket) types.Session {
|
||||
func (manager *SessionManagerCtx) New(id string, admin bool) types.Session {
|
||||
session := &SessionCtx{
|
||||
id: id,
|
||||
admin: admin,
|
||||
manager: manager,
|
||||
socket: socket,
|
||||
logger: manager.logger.With().Str("id", id).Logger(),
|
||||
connected: false,
|
||||
}
|
||||
|
||||
manager.members[id] = session
|
||||
manager.emmiter.Emit("created", session)
|
||||
|
||||
if !manager.capture.Streaming() && len(manager.members) > 0 {
|
||||
manager.capture.StartStream()
|
||||
}
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
@ -66,10 +59,6 @@ func (manager *SessionManagerCtx) Destroy(id string) error {
|
||||
delete(manager.members, id)
|
||||
err := session.destroy()
|
||||
|
||||
if !manager.capture.Streaming() && len(manager.members) <= 0 {
|
||||
manager.capture.StopStream()
|
||||
}
|
||||
|
||||
manager.emmiter.Emit("destroy", id)
|
||||
return err
|
||||
}
|
||||
@ -164,12 +153,22 @@ func (manager *SessionManagerCtx) OnHostCleared(listener func(session types.Sess
|
||||
|
||||
func (manager *SessionManagerCtx) OnDestroy(listener func(id string)) {
|
||||
manager.emmiter.On("destroy", func(payload ...interface{}) {
|
||||
// Stop streaming, if everyone left
|
||||
if manager.capture.Streaming() && len(manager.members) == 0 {
|
||||
manager.capture.StopStream()
|
||||
}
|
||||
|
||||
listener(payload[0].(string))
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) OnCreated(listener func(session types.Session)) {
|
||||
manager.emmiter.On("created", func(payload ...interface{}) {
|
||||
// Start streaming, when first joins
|
||||
if !manager.capture.Streaming() {
|
||||
manager.capture.StartStream()
|
||||
}
|
||||
|
||||
listener(payload[0].(*SessionCtx))
|
||||
})
|
||||
}
|
||||
|
@ -62,6 +62,7 @@ func (session *SessionCtx) SetName(name string) {
|
||||
|
||||
func (session *SessionCtx) SetSocket(socket types.WebSocket) {
|
||||
session.socket = socket
|
||||
session.manager.emmiter.Emit("created", session)
|
||||
}
|
||||
|
||||
func (session *SessionCtx) SetPeer(peer types.Peer) {
|
||||
|
@ -21,7 +21,7 @@ type Session interface {
|
||||
}
|
||||
|
||||
type SessionManager interface {
|
||||
New(id string, admin bool, socket WebSocket) Session
|
||||
New(id string, admin bool) Session
|
||||
Get(id string) (Session, bool)
|
||||
Has(id string) bool
|
||||
Destroy(id string) error
|
||||
@ -42,5 +42,5 @@ type SessionManager interface {
|
||||
OnConnected(listener func(session Session))
|
||||
|
||||
// auth
|
||||
Authenticate(r *http.Request) (string, string, bool, error)
|
||||
Authenticate(r *http.Request) (Session, error)
|
||||
}
|
||||
|
@ -127,14 +127,14 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e
|
||||
return err
|
||||
}
|
||||
|
||||
id, ip, admin, err := ws.sessions.Authenticate(r)
|
||||
session, err := ws.sessions.Authenticate(r)
|
||||
if err != nil {
|
||||
ws.logger.Warn().Err(err).Msg("authentication failed")
|
||||
|
||||
// TODO: Refactor
|
||||
if err = connection.WriteJSON(message.Disconnect{
|
||||
Event: event.SYSTEM_DISCONNECT,
|
||||
Message: "invalid_password",
|
||||
Message: "authentication failed",
|
||||
}); err != nil {
|
||||
ws.logger.Error().Err(err).Msg("failed to send disconnect")
|
||||
}
|
||||
@ -142,14 +142,20 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e
|
||||
return connection.Close()
|
||||
}
|
||||
|
||||
// TODO: Refactor.
|
||||
ip := r.RemoteAddr
|
||||
// if allow poxy {
|
||||
// ip = utils.ReadUserIP(r)
|
||||
// }
|
||||
|
||||
socket := &WebSocketCtx{
|
||||
id: id,
|
||||
id: session.ID(),
|
||||
ws: ws,
|
||||
address: ip,
|
||||
connection: connection,
|
||||
}
|
||||
|
||||
ok, reason := ws.handler.Connected(id, socket)
|
||||
ok, reason := ws.handler.Connected(session.ID(), socket)
|
||||
if !ok {
|
||||
// TODO: Refactor
|
||||
if err = connection.WriteJSON(message.Disconnect{
|
||||
@ -162,23 +168,23 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e
|
||||
return connection.Close()
|
||||
}
|
||||
|
||||
ws.sessions.New(id, admin, socket)
|
||||
session.SetSocket(socket)
|
||||
|
||||
ws.logger.
|
||||
Debug().
|
||||
Str("session", id).
|
||||
Str("session", session.ID()).
|
||||
Str("address", connection.RemoteAddr().String()).
|
||||
Msg("new connection created")
|
||||
|
||||
defer func() {
|
||||
ws.logger.
|
||||
Debug().
|
||||
Str("session", id).
|
||||
Str("session", session.ID()).
|
||||
Str("address", connection.RemoteAddr().String()).
|
||||
Msg("session ended")
|
||||
}()
|
||||
|
||||
ws.handle(connection, id)
|
||||
ws.handle(connection, session.ID())
|
||||
return nil
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user