mirror of
https://github.com/m1k1o/neko.git
synced 2024-07-24 14:40:50 +12:00
refactor WS authentication.
This commit is contained in:
parent
16d762b6ae
commit
311ed987d8
@ -3,35 +3,54 @@ package session
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"demodesk/neko/internal/types"
|
||||||
"demodesk/neko/internal/utils"
|
"demodesk/neko/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: Refactor
|
const (
|
||||||
func (manager *SessionManagerCtx) Authenticate(r *http.Request) (string, string, bool, error) {
|
token_name = "password"
|
||||||
ip := r.RemoteAddr
|
)
|
||||||
|
|
||||||
//if ws.conf.Proxy {
|
func (manager *SessionManagerCtx) Authenticate(r *http.Request) (types.Session, error) {
|
||||||
// ip = utils.ReadUserIP(r)
|
token := getToken(r)
|
||||||
//}
|
if token == "" {
|
||||||
|
return nil, fmt.Errorf("no password provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
isAdmin := (token == manager.config.AdminPassword)
|
||||||
|
isUser := (token == manager.config.Password)
|
||||||
|
|
||||||
|
if !isAdmin && !isUser {
|
||||||
|
return nil, fmt.Errorf("invalid password")
|
||||||
|
}
|
||||||
|
|
||||||
id, err := utils.NewUID(32)
|
id, err := utils.NewUID(32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", ip, false, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
passwords, ok := r.URL.Query()["password"]
|
return manager.New(id, isAdmin), nil
|
||||||
if !ok || len(passwords[0]) < 1 {
|
}
|
||||||
return "", ip, false, fmt.Errorf("no password provided")
|
|
||||||
}
|
func getToken(r *http.Request) string {
|
||||||
|
// Get token from query
|
||||||
if passwords[0] == manager.config.AdminPassword {
|
if token := r.URL.Query().Get(token_name); token != "" {
|
||||||
return id, ip, true, nil
|
return token
|
||||||
}
|
}
|
||||||
|
|
||||||
if passwords[0] == manager.config.Password {
|
// Get token from authorization header
|
||||||
return id, ip, false, nil
|
bearer := r.Header.Get("Authorization")
|
||||||
}
|
if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" {
|
||||||
|
return bearer[7:]
|
||||||
return "", ip, false, fmt.Errorf("invalid password: %s", passwords[0])
|
}
|
||||||
|
|
||||||
|
// Get token from cookie
|
||||||
|
cookie, err := r.Cookie(token_name)
|
||||||
|
if err == nil {
|
||||||
|
return cookie.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
|
@ -30,23 +30,16 @@ type SessionManagerCtx struct {
|
|||||||
emmiter events.EventEmmiter
|
emmiter events.EventEmmiter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (manager *SessionManagerCtx) New(id string, admin bool, socket types.WebSocket) types.Session {
|
func (manager *SessionManagerCtx) New(id string, admin bool) types.Session {
|
||||||
session := &SessionCtx{
|
session := &SessionCtx{
|
||||||
id: id,
|
id: id,
|
||||||
admin: admin,
|
admin: admin,
|
||||||
manager: manager,
|
manager: manager,
|
||||||
socket: socket,
|
|
||||||
logger: manager.logger.With().Str("id", id).Logger(),
|
logger: manager.logger.With().Str("id", id).Logger(),
|
||||||
connected: false,
|
connected: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
manager.members[id] = session
|
manager.members[id] = session
|
||||||
manager.emmiter.Emit("created", session)
|
|
||||||
|
|
||||||
if !manager.capture.Streaming() && len(manager.members) > 0 {
|
|
||||||
manager.capture.StartStream()
|
|
||||||
}
|
|
||||||
|
|
||||||
return session
|
return session
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -66,10 +59,6 @@ func (manager *SessionManagerCtx) Destroy(id string) error {
|
|||||||
delete(manager.members, id)
|
delete(manager.members, id)
|
||||||
err := session.destroy()
|
err := session.destroy()
|
||||||
|
|
||||||
if !manager.capture.Streaming() && len(manager.members) <= 0 {
|
|
||||||
manager.capture.StopStream()
|
|
||||||
}
|
|
||||||
|
|
||||||
manager.emmiter.Emit("destroy", id)
|
manager.emmiter.Emit("destroy", id)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -164,12 +153,22 @@ func (manager *SessionManagerCtx) OnHostCleared(listener func(session types.Sess
|
|||||||
|
|
||||||
func (manager *SessionManagerCtx) OnDestroy(listener func(id string)) {
|
func (manager *SessionManagerCtx) OnDestroy(listener func(id string)) {
|
||||||
manager.emmiter.On("destroy", func(payload ...interface{}) {
|
manager.emmiter.On("destroy", func(payload ...interface{}) {
|
||||||
|
// Stop streaming, if everyone left
|
||||||
|
if manager.capture.Streaming() && len(manager.members) == 0 {
|
||||||
|
manager.capture.StopStream()
|
||||||
|
}
|
||||||
|
|
||||||
listener(payload[0].(string))
|
listener(payload[0].(string))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (manager *SessionManagerCtx) OnCreated(listener func(session types.Session)) {
|
func (manager *SessionManagerCtx) OnCreated(listener func(session types.Session)) {
|
||||||
manager.emmiter.On("created", func(payload ...interface{}) {
|
manager.emmiter.On("created", func(payload ...interface{}) {
|
||||||
|
// Start streaming, when first joins
|
||||||
|
if !manager.capture.Streaming() {
|
||||||
|
manager.capture.StartStream()
|
||||||
|
}
|
||||||
|
|
||||||
listener(payload[0].(*SessionCtx))
|
listener(payload[0].(*SessionCtx))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -62,6 +62,7 @@ func (session *SessionCtx) SetName(name string) {
|
|||||||
|
|
||||||
func (session *SessionCtx) SetSocket(socket types.WebSocket) {
|
func (session *SessionCtx) SetSocket(socket types.WebSocket) {
|
||||||
session.socket = socket
|
session.socket = socket
|
||||||
|
session.manager.emmiter.Emit("created", session)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (session *SessionCtx) SetPeer(peer types.Peer) {
|
func (session *SessionCtx) SetPeer(peer types.Peer) {
|
||||||
|
@ -21,7 +21,7 @@ type Session interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SessionManager interface {
|
type SessionManager interface {
|
||||||
New(id string, admin bool, socket WebSocket) Session
|
New(id string, admin bool) Session
|
||||||
Get(id string) (Session, bool)
|
Get(id string) (Session, bool)
|
||||||
Has(id string) bool
|
Has(id string) bool
|
||||||
Destroy(id string) error
|
Destroy(id string) error
|
||||||
@ -42,5 +42,5 @@ type SessionManager interface {
|
|||||||
OnConnected(listener func(session Session))
|
OnConnected(listener func(session Session))
|
||||||
|
|
||||||
// auth
|
// auth
|
||||||
Authenticate(r *http.Request) (string, string, bool, error)
|
Authenticate(r *http.Request) (Session, error)
|
||||||
}
|
}
|
||||||
|
@ -127,14 +127,14 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
id, ip, admin, err := ws.sessions.Authenticate(r)
|
session, err := ws.sessions.Authenticate(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ws.logger.Warn().Err(err).Msg("authentication failed")
|
ws.logger.Warn().Err(err).Msg("authentication failed")
|
||||||
|
|
||||||
// TODO: Refactor
|
// TODO: Refactor
|
||||||
if err = connection.WriteJSON(message.Disconnect{
|
if err = connection.WriteJSON(message.Disconnect{
|
||||||
Event: event.SYSTEM_DISCONNECT,
|
Event: event.SYSTEM_DISCONNECT,
|
||||||
Message: "invalid_password",
|
Message: "authentication failed",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
ws.logger.Error().Err(err).Msg("failed to send disconnect")
|
ws.logger.Error().Err(err).Msg("failed to send disconnect")
|
||||||
}
|
}
|
||||||
@ -142,14 +142,20 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e
|
|||||||
return connection.Close()
|
return connection.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Refactor.
|
||||||
|
ip := r.RemoteAddr
|
||||||
|
// if allow poxy {
|
||||||
|
// ip = utils.ReadUserIP(r)
|
||||||
|
// }
|
||||||
|
|
||||||
socket := &WebSocketCtx{
|
socket := &WebSocketCtx{
|
||||||
id: id,
|
id: session.ID(),
|
||||||
ws: ws,
|
ws: ws,
|
||||||
address: ip,
|
address: ip,
|
||||||
connection: connection,
|
connection: connection,
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, reason := ws.handler.Connected(id, socket)
|
ok, reason := ws.handler.Connected(session.ID(), socket)
|
||||||
if !ok {
|
if !ok {
|
||||||
// TODO: Refactor
|
// TODO: Refactor
|
||||||
if err = connection.WriteJSON(message.Disconnect{
|
if err = connection.WriteJSON(message.Disconnect{
|
||||||
@ -162,23 +168,23 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e
|
|||||||
return connection.Close()
|
return connection.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
ws.sessions.New(id, admin, socket)
|
session.SetSocket(socket)
|
||||||
|
|
||||||
ws.logger.
|
ws.logger.
|
||||||
Debug().
|
Debug().
|
||||||
Str("session", id).
|
Str("session", session.ID()).
|
||||||
Str("address", connection.RemoteAddr().String()).
|
Str("address", connection.RemoteAddr().String()).
|
||||||
Msg("new connection created")
|
Msg("new connection created")
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
ws.logger.
|
ws.logger.
|
||||||
Debug().
|
Debug().
|
||||||
Str("session", id).
|
Str("session", session.ID()).
|
||||||
Str("address", connection.RemoteAddr().String()).
|
Str("address", connection.RemoteAddr().String()).
|
||||||
Msg("session ended")
|
Msg("session ended")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ws.handle(connection, id)
|
ws.handle(connection, session.ID())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user