move locks and bans to state.
This commit is contained in:
parent
06e25df962
commit
e045bd8a1e
@ -14,8 +14,7 @@ func (h *MessageHandler) adminLock(id string, session types.Session, payload *me
|
||||
return nil
|
||||
}
|
||||
|
||||
_, ok := h.Locked[payload.Resource]
|
||||
if ok {
|
||||
if h.state.IsLocked(payload.Resource) {
|
||||
h.logger.Debug().Str("resource", payload.Resource).Msg("resource already locked...")
|
||||
return nil
|
||||
}
|
||||
@ -30,7 +29,7 @@ func (h *MessageHandler) adminLock(id string, session types.Session, payload *me
|
||||
h.sessions.SetControlLocked(true)
|
||||
}
|
||||
|
||||
h.Locked[payload.Resource] = id
|
||||
h.state.Lock(payload.Resource, id)
|
||||
|
||||
if err := h.sessions.Broadcast(
|
||||
message.AdminLock{
|
||||
@ -51,8 +50,7 @@ func (h *MessageHandler) adminUnlock(id string, session types.Session, payload *
|
||||
return nil
|
||||
}
|
||||
|
||||
_, ok := h.Locked[payload.Resource]
|
||||
if !ok {
|
||||
if !h.state.IsLocked(payload.Resource) {
|
||||
h.logger.Debug().Str("resource", payload.Resource).Msg("resource not locked...")
|
||||
return nil
|
||||
}
|
||||
@ -62,7 +60,7 @@ func (h *MessageHandler) adminUnlock(id string, session types.Session, payload *
|
||||
h.sessions.SetControlLocked(false)
|
||||
}
|
||||
|
||||
delete(h.Locked, payload.Resource)
|
||||
h.state.Unlock(payload.Resource)
|
||||
|
||||
if err := h.sessions.Broadcast(
|
||||
message.AdminLock{
|
||||
@ -302,7 +300,7 @@ func (h *MessageHandler) adminBan(id string, session types.Session, payload *mes
|
||||
}
|
||||
|
||||
h.logger.Debug().Str("address", remote).Msg("adding address to banned")
|
||||
h.Banned[address[0]] = id
|
||||
h.state.Ban(address[0], id)
|
||||
|
||||
if err := target.Kick("banned"); err != nil {
|
||||
return err
|
||||
|
@ -34,8 +34,7 @@ func (h *MessageHandler) controlRequest(id string, session types.Session) error
|
||||
// check for host
|
||||
if !h.sessions.HasHost() {
|
||||
// check if control is locked or user is admin
|
||||
_, ok := h.Locked["control"]
|
||||
if ok && !session.Admin() {
|
||||
if h.state.IsLocked("control") && !session.Admin() {
|
||||
h.logger.Debug().Msg("control is locked")
|
||||
return nil
|
||||
}
|
||||
@ -98,8 +97,7 @@ func (h *MessageHandler) controlGive(id string, session types.Session, payload *
|
||||
}
|
||||
|
||||
// check if control is locked or giver is admin
|
||||
_, ok := h.Locked["control"]
|
||||
if ok && !session.Admin() {
|
||||
if h.state.IsLocked("control") && !session.Admin() {
|
||||
h.logger.Debug().Msg("control is locked")
|
||||
return nil
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"m1k1o/neko/internal/types/event"
|
||||
"m1k1o/neko/internal/types/message"
|
||||
"m1k1o/neko/internal/utils"
|
||||
"m1k1o/neko/internal/websocket/state"
|
||||
)
|
||||
|
||||
type MessageHandler struct {
|
||||
@ -20,9 +21,7 @@ type MessageHandler struct {
|
||||
capture types.CaptureManager
|
||||
webrtc types.WebRTCManager
|
||||
broadcast types.BroadcastManager
|
||||
|
||||
Banned map[string]string // IP -> session ID (that banned it)
|
||||
Locked map[string]string // resource name -> session ID (that locked it)
|
||||
state *state.State
|
||||
}
|
||||
|
||||
func New(
|
||||
@ -31,6 +30,7 @@ func New(
|
||||
capture types.CaptureManager,
|
||||
webrtc types.WebRTCManager,
|
||||
broadcast types.BroadcastManager,
|
||||
state *state.State,
|
||||
) *MessageHandler {
|
||||
return &MessageHandler{
|
||||
logger: log.With().Str("module", "websocket").Str("submodule", "handler").Logger(),
|
||||
@ -39,8 +39,7 @@ func New(
|
||||
capture: capture,
|
||||
webrtc: webrtc,
|
||||
broadcast: broadcast,
|
||||
Banned: make(map[string]string),
|
||||
Locked: make(map[string]string),
|
||||
state: state,
|
||||
}
|
||||
}
|
||||
|
||||
@ -48,15 +47,13 @@ func (h *MessageHandler) Connected(admin bool, address string) (bool, string) {
|
||||
if address == "" {
|
||||
h.logger.Debug().Msg("no remote address")
|
||||
} else {
|
||||
_, ok := h.Banned[address]
|
||||
if ok {
|
||||
if h.state.IsBanned(address) {
|
||||
h.logger.Debug().Str("address", address).Msg("banned")
|
||||
return false, "banned"
|
||||
}
|
||||
}
|
||||
|
||||
_, ok := h.Locked["login"]
|
||||
if ok && !admin {
|
||||
if h.state.IsLocked("login") && !admin {
|
||||
h.logger.Debug().Msg("server locked")
|
||||
return false, "locked"
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ func (h *MessageHandler) SessionCreated(id string, session types.Session) error
|
||||
if err := session.Send(message.SystemInit{
|
||||
Event: event.SYSTEM_INIT,
|
||||
ImplicitHosting: h.webrtc.ImplicitControl(),
|
||||
Locks: h.Locked,
|
||||
Locks: h.state.AllLocked(),
|
||||
}); err != nil {
|
||||
h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.SYSTEM_INIT)
|
||||
return err
|
||||
|
61
server/internal/websocket/state/state.go
Normal file
61
server/internal/websocket/state/state.go
Normal file
@ -0,0 +1,61 @@
|
||||
package state
|
||||
|
||||
type State struct {
|
||||
banned map[string]string // IP -> session ID (that banned it)
|
||||
locked map[string]string // resource name -> session ID (that locked it)
|
||||
}
|
||||
|
||||
func New() *State {
|
||||
return &State{
|
||||
banned: make(map[string]string),
|
||||
locked: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Ban
|
||||
|
||||
func (s *State) Ban(ip, id string) {
|
||||
s.banned[ip] = id
|
||||
}
|
||||
|
||||
func (s *State) Unban(ip string) {
|
||||
delete(s.banned, ip)
|
||||
}
|
||||
|
||||
func (s *State) IsBanned(ip string) bool {
|
||||
_, ok := s.banned[ip]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *State) GetBanned(ip string) (string, bool) {
|
||||
id, ok := s.banned[ip]
|
||||
return id, ok
|
||||
}
|
||||
|
||||
func (s *State) AllBanned() map[string]string {
|
||||
return s.banned
|
||||
}
|
||||
|
||||
// Lock
|
||||
|
||||
func (s *State) Lock(resource, id string) {
|
||||
s.locked[resource] = id
|
||||
}
|
||||
|
||||
func (s *State) Unlock(resource string) {
|
||||
delete(s.locked, resource)
|
||||
}
|
||||
|
||||
func (s *State) IsLocked(resource string) bool {
|
||||
_, ok := s.locked[resource]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *State) GetLocked(resource string) (string, bool) {
|
||||
id, ok := s.locked[resource]
|
||||
return id, ok
|
||||
}
|
||||
|
||||
func (s *State) AllLocked() map[string]string {
|
||||
return s.locked
|
||||
}
|
@ -17,6 +17,7 @@ import (
|
||||
"m1k1o/neko/internal/types/message"
|
||||
"m1k1o/neko/internal/utils"
|
||||
"m1k1o/neko/internal/websocket/handler"
|
||||
"m1k1o/neko/internal/websocket/state"
|
||||
)
|
||||
|
||||
const CONTROL_PROTECTION_SESSION = "by_control_protection"
|
||||
@ -24,17 +25,17 @@ const CONTROL_PROTECTION_SESSION = "by_control_protection"
|
||||
func New(sessions types.SessionManager, desktop types.DesktopManager, capture types.CaptureManager, broadcast types.BroadcastManager, webrtc types.WebRTCManager, conf *config.WebSocket) *WebSocketHandler {
|
||||
logger := log.With().Str("module", "websocket").Logger()
|
||||
|
||||
locks := make(map[string]string)
|
||||
state := state.New()
|
||||
|
||||
// if control protection is enabled
|
||||
if conf.ControlProtection {
|
||||
locks["control"] = CONTROL_PROTECTION_SESSION
|
||||
state.Lock("control", CONTROL_PROTECTION_SESSION)
|
||||
logger.Info().Msgf("control locked on behalf of control protection")
|
||||
}
|
||||
|
||||
// apply default locks
|
||||
for _, lock := range conf.Locks {
|
||||
locks[lock] = "" // empty session ID
|
||||
state.Lock(lock, "") // empty session ID
|
||||
}
|
||||
|
||||
if len(conf.Locks) > 0 {
|
||||
@ -47,11 +48,9 @@ func New(sessions types.SessionManager, desktop types.DesktopManager, capture ty
|
||||
capture,
|
||||
webrtc,
|
||||
broadcast,
|
||||
state,
|
||||
)
|
||||
|
||||
// set inital locks
|
||||
handler.Locked = locks
|
||||
|
||||
return &WebSocketHandler{
|
||||
logger: logger,
|
||||
shutdown: make(chan interface{}),
|
||||
@ -59,6 +58,7 @@ func New(sessions types.SessionManager, desktop types.DesktopManager, capture ty
|
||||
sessions: sessions,
|
||||
desktop: desktop,
|
||||
webrtc: webrtc,
|
||||
state: state,
|
||||
upgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
@ -80,6 +80,7 @@ type WebSocketHandler struct {
|
||||
sessions types.SessionManager
|
||||
desktop types.DesktopManager
|
||||
webrtc types.WebRTCManager
|
||||
state *state.State
|
||||
conf *config.WebSocket
|
||||
handler *handler.MessageHandler
|
||||
|
||||
@ -108,9 +109,9 @@ func (ws *WebSocketHandler) Start() {
|
||||
|
||||
// if control protection is enabled and at least one admin
|
||||
// and if room was locked on behalf control protection, unlock
|
||||
sess, ok := ws.handler.Locked["control"]
|
||||
sess, ok := ws.state.GetLocked("control")
|
||||
if ok && ws.conf.ControlProtection && sess == CONTROL_PROTECTION_SESSION && len(ws.sessions.Admins()) > 0 {
|
||||
delete(ws.handler.Locked, "control")
|
||||
ws.state.Unlock("control")
|
||||
ws.sessions.SetControlLocked(false) // TODO: Handle locks in sessions as flags.
|
||||
ws.logger.Info().Msgf("control unlocked on behalf of control protection")
|
||||
|
||||
@ -144,9 +145,9 @@ func (ws *WebSocketHandler) Start() {
|
||||
|
||||
// if control protection is enabled and no admin
|
||||
// and room is not locked, lock
|
||||
_, ok := ws.handler.Locked["control"]
|
||||
ok := ws.state.IsLocked("control")
|
||||
if !ok && ws.conf.ControlProtection && adminCount == 0 {
|
||||
ws.handler.Locked["control"] = CONTROL_PROTECTION_SESSION
|
||||
ws.state.Lock("control", CONTROL_PROTECTION_SESSION)
|
||||
ws.sessions.SetControlLocked(true) // TODO: Handle locks in sessions as flags.
|
||||
ws.logger.Info().Msgf("control locked and released on behalf of control protection")
|
||||
ws.handler.AdminRelease(id, session)
|
||||
@ -314,8 +315,8 @@ func (ws *WebSocketHandler) Stats() types.Stats {
|
||||
Host: host,
|
||||
Members: ws.sessions.Members(),
|
||||
|
||||
Banned: ws.handler.Banned,
|
||||
Locked: ws.handler.Locked,
|
||||
Banned: ws.state.AllBanned(),
|
||||
Locked: ws.state.AllLocked(),
|
||||
|
||||
ServerStartedAt: ws.serverStartedAt,
|
||||
LastAdminLeftAt: ws.lastAdminLeftAt,
|
||||
|
Reference in New Issue
Block a user