move locks and bans to state.

This commit is contained in:
Miroslav Šedivý 2022-09-13 20:04:43 +02:00
parent 06e25df962
commit e045bd8a1e
6 changed files with 88 additions and 33 deletions

View File

@ -14,8 +14,7 @@ func (h *MessageHandler) adminLock(id string, session types.Session, payload *me
return nil return nil
} }
_, ok := h.Locked[payload.Resource] if h.state.IsLocked(payload.Resource) {
if ok {
h.logger.Debug().Str("resource", payload.Resource).Msg("resource already locked...") h.logger.Debug().Str("resource", payload.Resource).Msg("resource already locked...")
return nil return nil
} }
@ -30,7 +29,7 @@ func (h *MessageHandler) adminLock(id string, session types.Session, payload *me
h.sessions.SetControlLocked(true) h.sessions.SetControlLocked(true)
} }
h.Locked[payload.Resource] = id h.state.Lock(payload.Resource, id)
if err := h.sessions.Broadcast( if err := h.sessions.Broadcast(
message.AdminLock{ message.AdminLock{
@ -51,8 +50,7 @@ func (h *MessageHandler) adminUnlock(id string, session types.Session, payload *
return nil return nil
} }
_, ok := h.Locked[payload.Resource] if !h.state.IsLocked(payload.Resource) {
if !ok {
h.logger.Debug().Str("resource", payload.Resource).Msg("resource not locked...") h.logger.Debug().Str("resource", payload.Resource).Msg("resource not locked...")
return nil return nil
} }
@ -62,7 +60,7 @@ func (h *MessageHandler) adminUnlock(id string, session types.Session, payload *
h.sessions.SetControlLocked(false) h.sessions.SetControlLocked(false)
} }
delete(h.Locked, payload.Resource) h.state.Unlock(payload.Resource)
if err := h.sessions.Broadcast( if err := h.sessions.Broadcast(
message.AdminLock{ message.AdminLock{
@ -302,7 +300,7 @@ func (h *MessageHandler) adminBan(id string, session types.Session, payload *mes
} }
h.logger.Debug().Str("address", remote).Msg("adding address to banned") h.logger.Debug().Str("address", remote).Msg("adding address to banned")
h.Banned[address[0]] = id h.state.Ban(address[0], id)
if err := target.Kick("banned"); err != nil { if err := target.Kick("banned"); err != nil {
return err return err

View File

@ -34,8 +34,7 @@ func (h *MessageHandler) controlRequest(id string, session types.Session) error
// check for host // check for host
if !h.sessions.HasHost() { if !h.sessions.HasHost() {
// check if control is locked or user is admin // check if control is locked or user is admin
_, ok := h.Locked["control"] if h.state.IsLocked("control") && !session.Admin() {
if ok && !session.Admin() {
h.logger.Debug().Msg("control is locked") h.logger.Debug().Msg("control is locked")
return nil return nil
} }
@ -98,8 +97,7 @@ func (h *MessageHandler) controlGive(id string, session types.Session, payload *
} }
// check if control is locked or giver is admin // check if control is locked or giver is admin
_, ok := h.Locked["control"] if h.state.IsLocked("control") && !session.Admin() {
if ok && !session.Admin() {
h.logger.Debug().Msg("control is locked") h.logger.Debug().Msg("control is locked")
return nil return nil
} }

View File

@ -11,6 +11,7 @@ import (
"m1k1o/neko/internal/types/event" "m1k1o/neko/internal/types/event"
"m1k1o/neko/internal/types/message" "m1k1o/neko/internal/types/message"
"m1k1o/neko/internal/utils" "m1k1o/neko/internal/utils"
"m1k1o/neko/internal/websocket/state"
) )
type MessageHandler struct { type MessageHandler struct {
@ -20,9 +21,7 @@ type MessageHandler struct {
capture types.CaptureManager capture types.CaptureManager
webrtc types.WebRTCManager webrtc types.WebRTCManager
broadcast types.BroadcastManager broadcast types.BroadcastManager
state *state.State
Banned map[string]string // IP -> session ID (that banned it)
Locked map[string]string // resource name -> session ID (that locked it)
} }
func New( func New(
@ -31,6 +30,7 @@ func New(
capture types.CaptureManager, capture types.CaptureManager,
webrtc types.WebRTCManager, webrtc types.WebRTCManager,
broadcast types.BroadcastManager, broadcast types.BroadcastManager,
state *state.State,
) *MessageHandler { ) *MessageHandler {
return &MessageHandler{ return &MessageHandler{
logger: log.With().Str("module", "websocket").Str("submodule", "handler").Logger(), logger: log.With().Str("module", "websocket").Str("submodule", "handler").Logger(),
@ -39,8 +39,7 @@ func New(
capture: capture, capture: capture,
webrtc: webrtc, webrtc: webrtc,
broadcast: broadcast, broadcast: broadcast,
Banned: make(map[string]string), state: state,
Locked: make(map[string]string),
} }
} }
@ -48,15 +47,13 @@ func (h *MessageHandler) Connected(admin bool, address string) (bool, string) {
if address == "" { if address == "" {
h.logger.Debug().Msg("no remote address") h.logger.Debug().Msg("no remote address")
} else { } else {
_, ok := h.Banned[address] if h.state.IsBanned(address) {
if ok {
h.logger.Debug().Str("address", address).Msg("banned") h.logger.Debug().Str("address", address).Msg("banned")
return false, "banned" return false, "banned"
} }
} }
_, ok := h.Locked["login"] if h.state.IsLocked("login") && !admin {
if ok && !admin {
h.logger.Debug().Msg("server locked") h.logger.Debug().Msg("server locked")
return false, "locked" return false, "locked"
} }

View File

@ -16,7 +16,7 @@ func (h *MessageHandler) SessionCreated(id string, session types.Session) error
if err := session.Send(message.SystemInit{ if err := session.Send(message.SystemInit{
Event: event.SYSTEM_INIT, Event: event.SYSTEM_INIT,
ImplicitHosting: h.webrtc.ImplicitControl(), ImplicitHosting: h.webrtc.ImplicitControl(),
Locks: h.Locked, Locks: h.state.AllLocked(),
}); err != nil { }); err != nil {
h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.SYSTEM_INIT) h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.SYSTEM_INIT)
return err return err

View File

@ -0,0 +1,61 @@
package state
type State struct {
banned map[string]string // IP -> session ID (that banned it)
locked map[string]string // resource name -> session ID (that locked it)
}
func New() *State {
return &State{
banned: make(map[string]string),
locked: make(map[string]string),
}
}
// Ban
func (s *State) Ban(ip, id string) {
s.banned[ip] = id
}
func (s *State) Unban(ip string) {
delete(s.banned, ip)
}
func (s *State) IsBanned(ip string) bool {
_, ok := s.banned[ip]
return ok
}
func (s *State) GetBanned(ip string) (string, bool) {
id, ok := s.banned[ip]
return id, ok
}
func (s *State) AllBanned() map[string]string {
return s.banned
}
// Lock
func (s *State) Lock(resource, id string) {
s.locked[resource] = id
}
func (s *State) Unlock(resource string) {
delete(s.locked, resource)
}
func (s *State) IsLocked(resource string) bool {
_, ok := s.locked[resource]
return ok
}
func (s *State) GetLocked(resource string) (string, bool) {
id, ok := s.locked[resource]
return id, ok
}
func (s *State) AllLocked() map[string]string {
return s.locked
}

View File

@ -17,6 +17,7 @@ import (
"m1k1o/neko/internal/types/message" "m1k1o/neko/internal/types/message"
"m1k1o/neko/internal/utils" "m1k1o/neko/internal/utils"
"m1k1o/neko/internal/websocket/handler" "m1k1o/neko/internal/websocket/handler"
"m1k1o/neko/internal/websocket/state"
) )
const CONTROL_PROTECTION_SESSION = "by_control_protection" const CONTROL_PROTECTION_SESSION = "by_control_protection"
@ -24,17 +25,17 @@ const CONTROL_PROTECTION_SESSION = "by_control_protection"
func New(sessions types.SessionManager, desktop types.DesktopManager, capture types.CaptureManager, broadcast types.BroadcastManager, webrtc types.WebRTCManager, conf *config.WebSocket) *WebSocketHandler { func New(sessions types.SessionManager, desktop types.DesktopManager, capture types.CaptureManager, broadcast types.BroadcastManager, webrtc types.WebRTCManager, conf *config.WebSocket) *WebSocketHandler {
logger := log.With().Str("module", "websocket").Logger() logger := log.With().Str("module", "websocket").Logger()
locks := make(map[string]string) state := state.New()
// if control protection is enabled // if control protection is enabled
if conf.ControlProtection { if conf.ControlProtection {
locks["control"] = CONTROL_PROTECTION_SESSION state.Lock("control", CONTROL_PROTECTION_SESSION)
logger.Info().Msgf("control locked on behalf of control protection") logger.Info().Msgf("control locked on behalf of control protection")
} }
// apply default locks // apply default locks
for _, lock := range conf.Locks { for _, lock := range conf.Locks {
locks[lock] = "" // empty session ID state.Lock(lock, "") // empty session ID
} }
if len(conf.Locks) > 0 { if len(conf.Locks) > 0 {
@ -47,11 +48,9 @@ func New(sessions types.SessionManager, desktop types.DesktopManager, capture ty
capture, capture,
webrtc, webrtc,
broadcast, broadcast,
state,
) )
// set inital locks
handler.Locked = locks
return &WebSocketHandler{ return &WebSocketHandler{
logger: logger, logger: logger,
shutdown: make(chan interface{}), shutdown: make(chan interface{}),
@ -59,6 +58,7 @@ func New(sessions types.SessionManager, desktop types.DesktopManager, capture ty
sessions: sessions, sessions: sessions,
desktop: desktop, desktop: desktop,
webrtc: webrtc, webrtc: webrtc,
state: state,
upgrader: websocket.Upgrader{ upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { CheckOrigin: func(r *http.Request) bool {
return true return true
@ -80,6 +80,7 @@ type WebSocketHandler struct {
sessions types.SessionManager sessions types.SessionManager
desktop types.DesktopManager desktop types.DesktopManager
webrtc types.WebRTCManager webrtc types.WebRTCManager
state *state.State
conf *config.WebSocket conf *config.WebSocket
handler *handler.MessageHandler handler *handler.MessageHandler
@ -108,9 +109,9 @@ func (ws *WebSocketHandler) Start() {
// if control protection is enabled and at least one admin // if control protection is enabled and at least one admin
// and if room was locked on behalf control protection, unlock // and if room was locked on behalf control protection, unlock
sess, ok := ws.handler.Locked["control"] sess, ok := ws.state.GetLocked("control")
if ok && ws.conf.ControlProtection && sess == CONTROL_PROTECTION_SESSION && len(ws.sessions.Admins()) > 0 { if ok && ws.conf.ControlProtection && sess == CONTROL_PROTECTION_SESSION && len(ws.sessions.Admins()) > 0 {
delete(ws.handler.Locked, "control") ws.state.Unlock("control")
ws.sessions.SetControlLocked(false) // TODO: Handle locks in sessions as flags. ws.sessions.SetControlLocked(false) // TODO: Handle locks in sessions as flags.
ws.logger.Info().Msgf("control unlocked on behalf of control protection") ws.logger.Info().Msgf("control unlocked on behalf of control protection")
@ -144,9 +145,9 @@ func (ws *WebSocketHandler) Start() {
// if control protection is enabled and no admin // if control protection is enabled and no admin
// and room is not locked, lock // and room is not locked, lock
_, ok := ws.handler.Locked["control"] ok := ws.state.IsLocked("control")
if !ok && ws.conf.ControlProtection && adminCount == 0 { if !ok && ws.conf.ControlProtection && adminCount == 0 {
ws.handler.Locked["control"] = CONTROL_PROTECTION_SESSION ws.state.Lock("control", CONTROL_PROTECTION_SESSION)
ws.sessions.SetControlLocked(true) // TODO: Handle locks in sessions as flags. ws.sessions.SetControlLocked(true) // TODO: Handle locks in sessions as flags.
ws.logger.Info().Msgf("control locked and released on behalf of control protection") ws.logger.Info().Msgf("control locked and released on behalf of control protection")
ws.handler.AdminRelease(id, session) ws.handler.AdminRelease(id, session)
@ -314,8 +315,8 @@ func (ws *WebSocketHandler) Stats() types.Stats {
Host: host, Host: host,
Members: ws.sessions.Members(), Members: ws.sessions.Members(),
Banned: ws.handler.Banned, Banned: ws.state.AllBanned(),
Locked: ws.handler.Locked, Locked: ws.state.AllLocked(),
ServerStartedAt: ws.serverStartedAt, ServerStartedAt: ws.serverStartedAt,
LastAdminLeftAt: ws.lastAdminLeftAt, LastAdminLeftAt: ws.lastAdminLeftAt,