From e045bd8a1ebf2b7fcb198b7b26f3fab3e0ea13ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= Date: Tue, 13 Sep 2022 20:04:43 +0200 Subject: [PATCH] move locks and bans to state. --- server/internal/websocket/handler/admin.go | 12 ++-- server/internal/websocket/handler/control.go | 6 +- server/internal/websocket/handler/handler.go | 15 ++--- server/internal/websocket/handler/session.go | 2 +- server/internal/websocket/state/state.go | 61 ++++++++++++++++++++ server/internal/websocket/websocket.go | 25 ++++---- 6 files changed, 88 insertions(+), 33 deletions(-) create mode 100644 server/internal/websocket/state/state.go diff --git a/server/internal/websocket/handler/admin.go b/server/internal/websocket/handler/admin.go index 462cd33d..7b9f1eeb 100644 --- a/server/internal/websocket/handler/admin.go +++ b/server/internal/websocket/handler/admin.go @@ -14,8 +14,7 @@ func (h *MessageHandler) adminLock(id string, session types.Session, payload *me return nil } - _, ok := h.Locked[payload.Resource] - if ok { + if h.state.IsLocked(payload.Resource) { h.logger.Debug().Str("resource", payload.Resource).Msg("resource already locked...") return nil } @@ -30,7 +29,7 @@ func (h *MessageHandler) adminLock(id string, session types.Session, payload *me h.sessions.SetControlLocked(true) } - h.Locked[payload.Resource] = id + h.state.Lock(payload.Resource, id) if err := h.sessions.Broadcast( message.AdminLock{ @@ -51,8 +50,7 @@ func (h *MessageHandler) adminUnlock(id string, session types.Session, payload * return nil } - _, ok := h.Locked[payload.Resource] - if !ok { + if !h.state.IsLocked(payload.Resource) { h.logger.Debug().Str("resource", payload.Resource).Msg("resource not locked...") return nil } @@ -62,7 +60,7 @@ func (h *MessageHandler) adminUnlock(id string, session types.Session, payload * h.sessions.SetControlLocked(false) } - delete(h.Locked, payload.Resource) + h.state.Unlock(payload.Resource) if err := h.sessions.Broadcast( message.AdminLock{ @@ -302,7 +300,7 @@ func (h *MessageHandler) adminBan(id string, session types.Session, payload *mes } h.logger.Debug().Str("address", remote).Msg("adding address to banned") - h.Banned[address[0]] = id + h.state.Ban(address[0], id) if err := target.Kick("banned"); err != nil { return err diff --git a/server/internal/websocket/handler/control.go b/server/internal/websocket/handler/control.go index 4ac12eb3..a3be3f65 100644 --- a/server/internal/websocket/handler/control.go +++ b/server/internal/websocket/handler/control.go @@ -34,8 +34,7 @@ func (h *MessageHandler) controlRequest(id string, session types.Session) error // check for host if !h.sessions.HasHost() { // check if control is locked or user is admin - _, ok := h.Locked["control"] - if ok && !session.Admin() { + if h.state.IsLocked("control") && !session.Admin() { h.logger.Debug().Msg("control is locked") return nil } @@ -98,8 +97,7 @@ func (h *MessageHandler) controlGive(id string, session types.Session, payload * } // check if control is locked or giver is admin - _, ok := h.Locked["control"] - if ok && !session.Admin() { + if h.state.IsLocked("control") && !session.Admin() { h.logger.Debug().Msg("control is locked") return nil } diff --git a/server/internal/websocket/handler/handler.go b/server/internal/websocket/handler/handler.go index e45f4321..4699ef4a 100644 --- a/server/internal/websocket/handler/handler.go +++ b/server/internal/websocket/handler/handler.go @@ -11,6 +11,7 @@ import ( "m1k1o/neko/internal/types/event" "m1k1o/neko/internal/types/message" "m1k1o/neko/internal/utils" + "m1k1o/neko/internal/websocket/state" ) type MessageHandler struct { @@ -20,9 +21,7 @@ type MessageHandler struct { capture types.CaptureManager webrtc types.WebRTCManager broadcast types.BroadcastManager - - Banned map[string]string // IP -> session ID (that banned it) - Locked map[string]string // resource name -> session ID (that locked it) + state *state.State } func New( @@ -31,6 +30,7 @@ func New( capture types.CaptureManager, webrtc types.WebRTCManager, broadcast types.BroadcastManager, + state *state.State, ) *MessageHandler { return &MessageHandler{ logger: log.With().Str("module", "websocket").Str("submodule", "handler").Logger(), @@ -39,8 +39,7 @@ func New( capture: capture, webrtc: webrtc, broadcast: broadcast, - Banned: make(map[string]string), - Locked: make(map[string]string), + state: state, } } @@ -48,15 +47,13 @@ func (h *MessageHandler) Connected(admin bool, address string) (bool, string) { if address == "" { h.logger.Debug().Msg("no remote address") } else { - _, ok := h.Banned[address] - if ok { + if h.state.IsBanned(address) { h.logger.Debug().Str("address", address).Msg("banned") return false, "banned" } } - _, ok := h.Locked["login"] - if ok && !admin { + if h.state.IsLocked("login") && !admin { h.logger.Debug().Msg("server locked") return false, "locked" } diff --git a/server/internal/websocket/handler/session.go b/server/internal/websocket/handler/session.go index 100aff34..33a03a31 100644 --- a/server/internal/websocket/handler/session.go +++ b/server/internal/websocket/handler/session.go @@ -16,7 +16,7 @@ func (h *MessageHandler) SessionCreated(id string, session types.Session) error if err := session.Send(message.SystemInit{ Event: event.SYSTEM_INIT, ImplicitHosting: h.webrtc.ImplicitControl(), - Locks: h.Locked, + Locks: h.state.AllLocked(), }); err != nil { h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.SYSTEM_INIT) return err diff --git a/server/internal/websocket/state/state.go b/server/internal/websocket/state/state.go new file mode 100644 index 00000000..3b38797f --- /dev/null +++ b/server/internal/websocket/state/state.go @@ -0,0 +1,61 @@ +package state + +type State struct { + banned map[string]string // IP -> session ID (that banned it) + locked map[string]string // resource name -> session ID (that locked it) +} + +func New() *State { + return &State{ + banned: make(map[string]string), + locked: make(map[string]string), + } +} + +// Ban + +func (s *State) Ban(ip, id string) { + s.banned[ip] = id +} + +func (s *State) Unban(ip string) { + delete(s.banned, ip) +} + +func (s *State) IsBanned(ip string) bool { + _, ok := s.banned[ip] + return ok +} + +func (s *State) GetBanned(ip string) (string, bool) { + id, ok := s.banned[ip] + return id, ok +} + +func (s *State) AllBanned() map[string]string { + return s.banned +} + +// Lock + +func (s *State) Lock(resource, id string) { + s.locked[resource] = id +} + +func (s *State) Unlock(resource string) { + delete(s.locked, resource) +} + +func (s *State) IsLocked(resource string) bool { + _, ok := s.locked[resource] + return ok +} + +func (s *State) GetLocked(resource string) (string, bool) { + id, ok := s.locked[resource] + return id, ok +} + +func (s *State) AllLocked() map[string]string { + return s.locked +} diff --git a/server/internal/websocket/websocket.go b/server/internal/websocket/websocket.go index 11e550cc..0da01604 100644 --- a/server/internal/websocket/websocket.go +++ b/server/internal/websocket/websocket.go @@ -17,6 +17,7 @@ import ( "m1k1o/neko/internal/types/message" "m1k1o/neko/internal/utils" "m1k1o/neko/internal/websocket/handler" + "m1k1o/neko/internal/websocket/state" ) const CONTROL_PROTECTION_SESSION = "by_control_protection" @@ -24,17 +25,17 @@ const CONTROL_PROTECTION_SESSION = "by_control_protection" func New(sessions types.SessionManager, desktop types.DesktopManager, capture types.CaptureManager, broadcast types.BroadcastManager, webrtc types.WebRTCManager, conf *config.WebSocket) *WebSocketHandler { logger := log.With().Str("module", "websocket").Logger() - locks := make(map[string]string) + state := state.New() // if control protection is enabled if conf.ControlProtection { - locks["control"] = CONTROL_PROTECTION_SESSION + state.Lock("control", CONTROL_PROTECTION_SESSION) logger.Info().Msgf("control locked on behalf of control protection") } // apply default locks for _, lock := range conf.Locks { - locks[lock] = "" // empty session ID + state.Lock(lock, "") // empty session ID } if len(conf.Locks) > 0 { @@ -47,11 +48,9 @@ func New(sessions types.SessionManager, desktop types.DesktopManager, capture ty capture, webrtc, broadcast, + state, ) - // set inital locks - handler.Locked = locks - return &WebSocketHandler{ logger: logger, shutdown: make(chan interface{}), @@ -59,6 +58,7 @@ func New(sessions types.SessionManager, desktop types.DesktopManager, capture ty sessions: sessions, desktop: desktop, webrtc: webrtc, + state: state, upgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true @@ -80,6 +80,7 @@ type WebSocketHandler struct { sessions types.SessionManager desktop types.DesktopManager webrtc types.WebRTCManager + state *state.State conf *config.WebSocket handler *handler.MessageHandler @@ -108,9 +109,9 @@ func (ws *WebSocketHandler) Start() { // if control protection is enabled and at least one admin // and if room was locked on behalf control protection, unlock - sess, ok := ws.handler.Locked["control"] + sess, ok := ws.state.GetLocked("control") if ok && ws.conf.ControlProtection && sess == CONTROL_PROTECTION_SESSION && len(ws.sessions.Admins()) > 0 { - delete(ws.handler.Locked, "control") + ws.state.Unlock("control") ws.sessions.SetControlLocked(false) // TODO: Handle locks in sessions as flags. ws.logger.Info().Msgf("control unlocked on behalf of control protection") @@ -144,9 +145,9 @@ func (ws *WebSocketHandler) Start() { // if control protection is enabled and no admin // and room is not locked, lock - _, ok := ws.handler.Locked["control"] + ok := ws.state.IsLocked("control") if !ok && ws.conf.ControlProtection && adminCount == 0 { - ws.handler.Locked["control"] = CONTROL_PROTECTION_SESSION + ws.state.Lock("control", CONTROL_PROTECTION_SESSION) ws.sessions.SetControlLocked(true) // TODO: Handle locks in sessions as flags. ws.logger.Info().Msgf("control locked and released on behalf of control protection") ws.handler.AdminRelease(id, session) @@ -314,8 +315,8 @@ func (ws *WebSocketHandler) Stats() types.Stats { Host: host, Members: ws.sessions.Members(), - Banned: ws.handler.Banned, - Locked: ws.handler.Locked, + Banned: ws.state.AllBanned(), + Locked: ws.state.AllLocked(), ServerStartedAt: ws.serverStartedAt, LastAdminLeftAt: ws.lastAdminLeftAt,