auth moved from websockets to session.

This commit is contained in:
Miroslav Šedivý 2020-11-01 18:39:12 +01:00
parent 3ea979ed47
commit c10b2212d1
8 changed files with 65 additions and 53 deletions

View File

@ -18,9 +18,9 @@ func init() {
configs := []config.Config{ configs := []config.Config{
neko.Service.Configs.Capture, neko.Service.Configs.Capture,
neko.Service.Configs.Server,
neko.Service.Configs.WebRTC, neko.Service.Configs.WebRTC,
neko.Service.Configs.WebSocket, neko.Service.Configs.Session,
neko.Service.Configs.Server,
} }
cobra.OnInitialize(func() { cobra.OnInitialize(func() {

View File

@ -10,6 +10,7 @@ type Server struct {
Key string Key string
Bind string Bind string
Static string Static string
//Proxy bool
UserToken string UserToken string
AdminToken string AdminToken string
} }
@ -35,6 +36,11 @@ func (Server) Init(cmd *cobra.Command) error {
return err return err
} }
//cmd.PersistentFlags().Bool("proxy", false, "allow reverse proxies")
//if err := viper.BindPFlag("proxy", cmd.PersistentFlags().Lookup("proxy")); err != nil {
// return err
//}
cmd.PersistentFlags().String("user_token", "user_secret", "JWT token for users") cmd.PersistentFlags().String("user_token", "user_secret", "JWT token for users")
if err := viper.BindPFlag("user_token", cmd.PersistentFlags().Lookup("user_token")); err != nil { if err := viper.BindPFlag("user_token", cmd.PersistentFlags().Lookup("user_token")); err != nil {
return err return err
@ -53,6 +59,7 @@ func (s *Server) Set() {
s.Key = viper.GetString("key") s.Key = viper.GetString("key")
s.Bind = viper.GetString("bind") s.Bind = viper.GetString("bind")
s.Static = viper.GetString("static") s.Static = viper.GetString("static")
//s.Proxy = viper.GetBool("proxy")
s.UserToken = viper.GetString("user_token") s.UserToken = viper.GetString("user_token")
s.AdminToken = viper.GetString("admin_token") s.AdminToken = viper.GetString("admin_token")
} }

View File

@ -5,13 +5,12 @@ import (
"github.com/spf13/viper" "github.com/spf13/viper"
) )
type WebSocket struct { type Session struct {
Password string Password string
AdminPassword string AdminPassword string
Proxy bool
} }
func (WebSocket) Init(cmd *cobra.Command) error { func (Session) Init(cmd *cobra.Command) error {
cmd.PersistentFlags().String("password", "neko", "password for connecting to stream") cmd.PersistentFlags().String("password", "neko", "password for connecting to stream")
if err := viper.BindPFlag("password", cmd.PersistentFlags().Lookup("password")); err != nil { if err := viper.BindPFlag("password", cmd.PersistentFlags().Lookup("password")); err != nil {
return err return err
@ -22,16 +21,10 @@ func (WebSocket) Init(cmd *cobra.Command) error {
return err return err
} }
cmd.PersistentFlags().Bool("proxy", false, "enable reverse proxy mode")
if err := viper.BindPFlag("proxy", cmd.PersistentFlags().Lookup("proxy")); err != nil {
return err
}
return nil return nil
} }
func (s *WebSocket) Set() { func (s *Session) Set() {
s.Password = viper.GetString("password") s.Password = viper.GetString("password")
s.AdminPassword = viper.GetString("password_admin") s.AdminPassword = viper.GetString("password_admin")
s.Proxy = viper.GetBool("proxy")
} }

37
internal/session/auth.go Normal file
View File

@ -0,0 +1,37 @@
package session
import (
"fmt"
"net/http"
"demodesk/neko/internal/utils"
)
// TODO: Refactor
func (manager *SessionManagerCtx) Authenticate(r *http.Request) (string, string, bool, error) {
ip := r.RemoteAddr
//if ws.conf.Proxy {
// ip = utils.ReadUserIP(r)
//}
id, err := utils.NewUID(32)
if err != nil {
return "", ip, false, err
}
passwords, ok := r.URL.Query()["password"]
if !ok || len(passwords[0]) < 1 {
return "", ip, false, fmt.Errorf("no password provided")
}
if passwords[0] == manager.config.AdminPassword {
return id, ip, true, nil
}
if passwords[0] == manager.config.Password {
return id, ip, false, nil
}
return "", ip, false, fmt.Errorf("invalid password: %s", passwords[0])
}

View File

@ -6,14 +6,16 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"demodesk/neko/internal/types" "demodesk/neko/internal/types"
"demodesk/neko/internal/config"
"demodesk/neko/internal/utils" "demodesk/neko/internal/utils"
) )
func New(capture types.CaptureManager) *SessionManagerCtx { func New(capture types.CaptureManager, config *config.Session) *SessionManagerCtx {
return &SessionManagerCtx{ return &SessionManagerCtx{
logger: log.With().Str("module", "session").Logger(), logger: log.With().Str("module", "session").Logger(),
host: nil, host: nil,
capture: capture, capture: capture,
config: config,
members: make(map[string]*SessionCtx), members: make(map[string]*SessionCtx),
emmiter: events.New(), emmiter: events.New(),
} }
@ -23,6 +25,7 @@ type SessionManagerCtx struct {
logger zerolog.Logger logger zerolog.Logger
host types.Session host types.Session
capture types.CaptureManager capture types.CaptureManager
config *config.Session
members map[string]*SessionCtx members map[string]*SessionCtx
emmiter events.EventEmmiter emmiter events.EventEmmiter
} }

View File

@ -1,5 +1,7 @@
package types package types
import "net/http"
type Session interface { type Session interface {
ID() string ID() string
Name() string Name() string
@ -38,4 +40,7 @@ type SessionManager interface {
OnDestroy(listener func(id string)) OnDestroy(listener func(id string))
OnCreated(listener func(session Session)) OnCreated(listener func(session Session))
OnConnected(listener func(session Session)) OnConnected(listener func(session Session))
// auth
Authenticate(r *http.Request) (string, string, bool, error)
} }

View File

@ -1,7 +1,6 @@
package websocket package websocket
import ( import (
"fmt"
"net/http" "net/http"
"time" "time"
@ -14,8 +13,6 @@ import (
"demodesk/neko/internal/types/message" "demodesk/neko/internal/types/message"
"demodesk/neko/internal/types" "demodesk/neko/internal/types"
"demodesk/neko/internal/config"
"demodesk/neko/internal/utils"
) )
func New( func New(
@ -23,13 +20,11 @@ func New(
desktop types.DesktopManager, desktop types.DesktopManager,
capture types.CaptureManager, capture types.CaptureManager,
webrtc types.WebRTCManager, webrtc types.WebRTCManager,
conf *config.WebSocket,
) *WebSocketManagerCtx { ) *WebSocketManagerCtx {
logger := log.With().Str("module", "websocket").Logger() logger := log.With().Str("module", "websocket").Logger()
return &WebSocketManagerCtx{ return &WebSocketManagerCtx{
logger: logger, logger: logger,
conf: conf,
sessions: sessions, sessions: sessions,
desktop: desktop, desktop: desktop,
upgrader: websocket.Upgrader{ upgrader: websocket.Upgrader{
@ -49,7 +44,6 @@ type WebSocketManagerCtx struct {
upgrader websocket.Upgrader upgrader websocket.Upgrader
sessions types.SessionManager sessions types.SessionManager
desktop types.DesktopManager desktop types.DesktopManager
conf *config.WebSocket
handler *handler.MessageHandlerCtx handler *handler.MessageHandlerCtx
shutdown chan bool shutdown chan bool
} }
@ -80,8 +74,10 @@ func (ws *WebSocketManagerCtx) Start() {
}) })
go func() { go func() {
ws.logger.Info().Msg("clipboard loop started")
defer func() { defer func() {
ws.logger.Info().Msg("shutdown") ws.logger.Info().Msg("clipboard loop stopped")
}() }()
current := ws.desktop.ReadClipboard() current := ws.desktop.ReadClipboard()
@ -131,7 +127,7 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e
return err return err
} }
id, ip, admin, err := ws.authenticate(r) id, ip, admin, err := ws.sessions.Authenticate(r)
if err != nil { if err != nil {
ws.logger.Warn().Err(err).Msg("authentication failed") ws.logger.Warn().Err(err).Msg("authentication failed")
@ -186,35 +182,6 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e
return nil return nil
} }
// TODO: Refactor
func (ws *WebSocketManagerCtx) authenticate(r *http.Request) (string, string, bool, error) {
ip := r.RemoteAddr
if ws.conf.Proxy {
ip = utils.ReadUserIP(r)
}
id, err := utils.NewUID(32)
if err != nil {
return "", ip, false, err
}
passwords, ok := r.URL.Query()["password"]
if !ok || len(passwords[0]) < 1 {
return "", ip, false, fmt.Errorf("no password provided")
}
if passwords[0] == ws.conf.AdminPassword {
return id, ip, true, nil
}
if passwords[0] == ws.conf.Password {
return id, ip, false, nil
}
return "", ip, false, fmt.Errorf("invalid password: %s", passwords[0])
}
func (ws *WebSocketManagerCtx) handle(connection *websocket.Conn, id string) { func (ws *WebSocketManagerCtx) handle(connection *websocket.Conn, id string) {
bytes := make(chan []byte) bytes := make(chan []byte)
cancel := make(chan struct{}) cancel := make(chan struct{})

View File

@ -64,8 +64,8 @@ func init() {
Root: &config.Root{}, Root: &config.Root{},
Capture: &config.Capture{}, Capture: &config.Capture{},
WebRTC: &config.WebRTC{}, WebRTC: &config.WebRTC{},
Session: &config.Session{},
Server: &config.Server{}, Server: &config.Server{},
WebSocket: &config.WebSocket{},
}, },
} }
} }
@ -103,8 +103,8 @@ type Configs struct {
Root *config.Root Root *config.Root
Capture *config.Capture Capture *config.Capture
WebRTC *config.WebRTC WebRTC *config.WebRTC
Session *config.Session
Server *config.Server Server *config.Server
WebSocket *config.WebSocket
} }
type Neko struct { type Neko struct {
@ -146,6 +146,7 @@ func (neko *Neko) Start() {
neko.sessionManager = session.New( neko.sessionManager = session.New(
neko.captureManager, neko.captureManager,
neko.Configs.Session,
) )
neko.webSocketManager = websocket.New( neko.webSocketManager = websocket.New(
@ -153,7 +154,6 @@ func (neko *Neko) Start() {
neko.desktopManager, neko.desktopManager,
neko.captureManager, neko.captureManager,
neko.webRTCManager, neko.webRTCManager,
neko.Configs.WebSocket,
) )
neko.webSocketManager.Start() neko.webSocketManager.Start()