mirror of
https://github.com/m1k1o/neko.git
synced 2024-07-24 14:40:50 +12:00
move locks and bans to state.
This commit is contained in:
parent
06e25df962
commit
e045bd8a1e
@ -14,8 +14,7 @@ func (h *MessageHandler) adminLock(id string, session types.Session, payload *me
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok := h.Locked[payload.Resource]
|
if h.state.IsLocked(payload.Resource) {
|
||||||
if ok {
|
|
||||||
h.logger.Debug().Str("resource", payload.Resource).Msg("resource already locked...")
|
h.logger.Debug().Str("resource", payload.Resource).Msg("resource already locked...")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -30,7 +29,7 @@ func (h *MessageHandler) adminLock(id string, session types.Session, payload *me
|
|||||||
h.sessions.SetControlLocked(true)
|
h.sessions.SetControlLocked(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
h.Locked[payload.Resource] = id
|
h.state.Lock(payload.Resource, id)
|
||||||
|
|
||||||
if err := h.sessions.Broadcast(
|
if err := h.sessions.Broadcast(
|
||||||
message.AdminLock{
|
message.AdminLock{
|
||||||
@ -51,8 +50,7 @@ func (h *MessageHandler) adminUnlock(id string, session types.Session, payload *
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok := h.Locked[payload.Resource]
|
if !h.state.IsLocked(payload.Resource) {
|
||||||
if !ok {
|
|
||||||
h.logger.Debug().Str("resource", payload.Resource).Msg("resource not locked...")
|
h.logger.Debug().Str("resource", payload.Resource).Msg("resource not locked...")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -62,7 +60,7 @@ func (h *MessageHandler) adminUnlock(id string, session types.Session, payload *
|
|||||||
h.sessions.SetControlLocked(false)
|
h.sessions.SetControlLocked(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(h.Locked, payload.Resource)
|
h.state.Unlock(payload.Resource)
|
||||||
|
|
||||||
if err := h.sessions.Broadcast(
|
if err := h.sessions.Broadcast(
|
||||||
message.AdminLock{
|
message.AdminLock{
|
||||||
@ -302,7 +300,7 @@ func (h *MessageHandler) adminBan(id string, session types.Session, payload *mes
|
|||||||
}
|
}
|
||||||
|
|
||||||
h.logger.Debug().Str("address", remote).Msg("adding address to banned")
|
h.logger.Debug().Str("address", remote).Msg("adding address to banned")
|
||||||
h.Banned[address[0]] = id
|
h.state.Ban(address[0], id)
|
||||||
|
|
||||||
if err := target.Kick("banned"); err != nil {
|
if err := target.Kick("banned"); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -34,8 +34,7 @@ func (h *MessageHandler) controlRequest(id string, session types.Session) error
|
|||||||
// check for host
|
// check for host
|
||||||
if !h.sessions.HasHost() {
|
if !h.sessions.HasHost() {
|
||||||
// check if control is locked or user is admin
|
// check if control is locked or user is admin
|
||||||
_, ok := h.Locked["control"]
|
if h.state.IsLocked("control") && !session.Admin() {
|
||||||
if ok && !session.Admin() {
|
|
||||||
h.logger.Debug().Msg("control is locked")
|
h.logger.Debug().Msg("control is locked")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -98,8 +97,7 @@ func (h *MessageHandler) controlGive(id string, session types.Session, payload *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check if control is locked or giver is admin
|
// check if control is locked or giver is admin
|
||||||
_, ok := h.Locked["control"]
|
if h.state.IsLocked("control") && !session.Admin() {
|
||||||
if ok && !session.Admin() {
|
|
||||||
h.logger.Debug().Msg("control is locked")
|
h.logger.Debug().Msg("control is locked")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"m1k1o/neko/internal/types/event"
|
"m1k1o/neko/internal/types/event"
|
||||||
"m1k1o/neko/internal/types/message"
|
"m1k1o/neko/internal/types/message"
|
||||||
"m1k1o/neko/internal/utils"
|
"m1k1o/neko/internal/utils"
|
||||||
|
"m1k1o/neko/internal/websocket/state"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MessageHandler struct {
|
type MessageHandler struct {
|
||||||
@ -20,9 +21,7 @@ type MessageHandler struct {
|
|||||||
capture types.CaptureManager
|
capture types.CaptureManager
|
||||||
webrtc types.WebRTCManager
|
webrtc types.WebRTCManager
|
||||||
broadcast types.BroadcastManager
|
broadcast types.BroadcastManager
|
||||||
|
state *state.State
|
||||||
Banned map[string]string // IP -> session ID (that banned it)
|
|
||||||
Locked map[string]string // resource name -> session ID (that locked it)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(
|
func New(
|
||||||
@ -31,6 +30,7 @@ func New(
|
|||||||
capture types.CaptureManager,
|
capture types.CaptureManager,
|
||||||
webrtc types.WebRTCManager,
|
webrtc types.WebRTCManager,
|
||||||
broadcast types.BroadcastManager,
|
broadcast types.BroadcastManager,
|
||||||
|
state *state.State,
|
||||||
) *MessageHandler {
|
) *MessageHandler {
|
||||||
return &MessageHandler{
|
return &MessageHandler{
|
||||||
logger: log.With().Str("module", "websocket").Str("submodule", "handler").Logger(),
|
logger: log.With().Str("module", "websocket").Str("submodule", "handler").Logger(),
|
||||||
@ -39,8 +39,7 @@ func New(
|
|||||||
capture: capture,
|
capture: capture,
|
||||||
webrtc: webrtc,
|
webrtc: webrtc,
|
||||||
broadcast: broadcast,
|
broadcast: broadcast,
|
||||||
Banned: make(map[string]string),
|
state: state,
|
||||||
Locked: make(map[string]string),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,15 +47,13 @@ func (h *MessageHandler) Connected(admin bool, address string) (bool, string) {
|
|||||||
if address == "" {
|
if address == "" {
|
||||||
h.logger.Debug().Msg("no remote address")
|
h.logger.Debug().Msg("no remote address")
|
||||||
} else {
|
} else {
|
||||||
_, ok := h.Banned[address]
|
if h.state.IsBanned(address) {
|
||||||
if ok {
|
|
||||||
h.logger.Debug().Str("address", address).Msg("banned")
|
h.logger.Debug().Str("address", address).Msg("banned")
|
||||||
return false, "banned"
|
return false, "banned"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok := h.Locked["login"]
|
if h.state.IsLocked("login") && !admin {
|
||||||
if ok && !admin {
|
|
||||||
h.logger.Debug().Msg("server locked")
|
h.logger.Debug().Msg("server locked")
|
||||||
return false, "locked"
|
return false, "locked"
|
||||||
}
|
}
|
||||||
|
@ -16,7 +16,7 @@ func (h *MessageHandler) SessionCreated(id string, session types.Session) error
|
|||||||
if err := session.Send(message.SystemInit{
|
if err := session.Send(message.SystemInit{
|
||||||
Event: event.SYSTEM_INIT,
|
Event: event.SYSTEM_INIT,
|
||||||
ImplicitHosting: h.webrtc.ImplicitControl(),
|
ImplicitHosting: h.webrtc.ImplicitControl(),
|
||||||
Locks: h.Locked,
|
Locks: h.state.AllLocked(),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.SYSTEM_INIT)
|
h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.SYSTEM_INIT)
|
||||||
return err
|
return err
|
||||||
|
61
server/internal/websocket/state/state.go
Normal file
61
server/internal/websocket/state/state.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
package state
|
||||||
|
|
||||||
|
type State struct {
|
||||||
|
banned map[string]string // IP -> session ID (that banned it)
|
||||||
|
locked map[string]string // resource name -> session ID (that locked it)
|
||||||
|
}
|
||||||
|
|
||||||
|
func New() *State {
|
||||||
|
return &State{
|
||||||
|
banned: make(map[string]string),
|
||||||
|
locked: make(map[string]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ban
|
||||||
|
|
||||||
|
func (s *State) Ban(ip, id string) {
|
||||||
|
s.banned[ip] = id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) Unban(ip string) {
|
||||||
|
delete(s.banned, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) IsBanned(ip string) bool {
|
||||||
|
_, ok := s.banned[ip]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) GetBanned(ip string) (string, bool) {
|
||||||
|
id, ok := s.banned[ip]
|
||||||
|
return id, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) AllBanned() map[string]string {
|
||||||
|
return s.banned
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lock
|
||||||
|
|
||||||
|
func (s *State) Lock(resource, id string) {
|
||||||
|
s.locked[resource] = id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) Unlock(resource string) {
|
||||||
|
delete(s.locked, resource)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) IsLocked(resource string) bool {
|
||||||
|
_, ok := s.locked[resource]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) GetLocked(resource string) (string, bool) {
|
||||||
|
id, ok := s.locked[resource]
|
||||||
|
return id, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) AllLocked() map[string]string {
|
||||||
|
return s.locked
|
||||||
|
}
|
@ -17,6 +17,7 @@ import (
|
|||||||
"m1k1o/neko/internal/types/message"
|
"m1k1o/neko/internal/types/message"
|
||||||
"m1k1o/neko/internal/utils"
|
"m1k1o/neko/internal/utils"
|
||||||
"m1k1o/neko/internal/websocket/handler"
|
"m1k1o/neko/internal/websocket/handler"
|
||||||
|
"m1k1o/neko/internal/websocket/state"
|
||||||
)
|
)
|
||||||
|
|
||||||
const CONTROL_PROTECTION_SESSION = "by_control_protection"
|
const CONTROL_PROTECTION_SESSION = "by_control_protection"
|
||||||
@ -24,17 +25,17 @@ const CONTROL_PROTECTION_SESSION = "by_control_protection"
|
|||||||
func New(sessions types.SessionManager, desktop types.DesktopManager, capture types.CaptureManager, broadcast types.BroadcastManager, webrtc types.WebRTCManager, conf *config.WebSocket) *WebSocketHandler {
|
func New(sessions types.SessionManager, desktop types.DesktopManager, capture types.CaptureManager, broadcast types.BroadcastManager, webrtc types.WebRTCManager, conf *config.WebSocket) *WebSocketHandler {
|
||||||
logger := log.With().Str("module", "websocket").Logger()
|
logger := log.With().Str("module", "websocket").Logger()
|
||||||
|
|
||||||
locks := make(map[string]string)
|
state := state.New()
|
||||||
|
|
||||||
// if control protection is enabled
|
// if control protection is enabled
|
||||||
if conf.ControlProtection {
|
if conf.ControlProtection {
|
||||||
locks["control"] = CONTROL_PROTECTION_SESSION
|
state.Lock("control", CONTROL_PROTECTION_SESSION)
|
||||||
logger.Info().Msgf("control locked on behalf of control protection")
|
logger.Info().Msgf("control locked on behalf of control protection")
|
||||||
}
|
}
|
||||||
|
|
||||||
// apply default locks
|
// apply default locks
|
||||||
for _, lock := range conf.Locks {
|
for _, lock := range conf.Locks {
|
||||||
locks[lock] = "" // empty session ID
|
state.Lock(lock, "") // empty session ID
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(conf.Locks) > 0 {
|
if len(conf.Locks) > 0 {
|
||||||
@ -47,11 +48,9 @@ func New(sessions types.SessionManager, desktop types.DesktopManager, capture ty
|
|||||||
capture,
|
capture,
|
||||||
webrtc,
|
webrtc,
|
||||||
broadcast,
|
broadcast,
|
||||||
|
state,
|
||||||
)
|
)
|
||||||
|
|
||||||
// set inital locks
|
|
||||||
handler.Locked = locks
|
|
||||||
|
|
||||||
return &WebSocketHandler{
|
return &WebSocketHandler{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
shutdown: make(chan interface{}),
|
shutdown: make(chan interface{}),
|
||||||
@ -59,6 +58,7 @@ func New(sessions types.SessionManager, desktop types.DesktopManager, capture ty
|
|||||||
sessions: sessions,
|
sessions: sessions,
|
||||||
desktop: desktop,
|
desktop: desktop,
|
||||||
webrtc: webrtc,
|
webrtc: webrtc,
|
||||||
|
state: state,
|
||||||
upgrader: websocket.Upgrader{
|
upgrader: websocket.Upgrader{
|
||||||
CheckOrigin: func(r *http.Request) bool {
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
return true
|
return true
|
||||||
@ -80,6 +80,7 @@ type WebSocketHandler struct {
|
|||||||
sessions types.SessionManager
|
sessions types.SessionManager
|
||||||
desktop types.DesktopManager
|
desktop types.DesktopManager
|
||||||
webrtc types.WebRTCManager
|
webrtc types.WebRTCManager
|
||||||
|
state *state.State
|
||||||
conf *config.WebSocket
|
conf *config.WebSocket
|
||||||
handler *handler.MessageHandler
|
handler *handler.MessageHandler
|
||||||
|
|
||||||
@ -108,9 +109,9 @@ func (ws *WebSocketHandler) Start() {
|
|||||||
|
|
||||||
// if control protection is enabled and at least one admin
|
// if control protection is enabled and at least one admin
|
||||||
// and if room was locked on behalf control protection, unlock
|
// and if room was locked on behalf control protection, unlock
|
||||||
sess, ok := ws.handler.Locked["control"]
|
sess, ok := ws.state.GetLocked("control")
|
||||||
if ok && ws.conf.ControlProtection && sess == CONTROL_PROTECTION_SESSION && len(ws.sessions.Admins()) > 0 {
|
if ok && ws.conf.ControlProtection && sess == CONTROL_PROTECTION_SESSION && len(ws.sessions.Admins()) > 0 {
|
||||||
delete(ws.handler.Locked, "control")
|
ws.state.Unlock("control")
|
||||||
ws.sessions.SetControlLocked(false) // TODO: Handle locks in sessions as flags.
|
ws.sessions.SetControlLocked(false) // TODO: Handle locks in sessions as flags.
|
||||||
ws.logger.Info().Msgf("control unlocked on behalf of control protection")
|
ws.logger.Info().Msgf("control unlocked on behalf of control protection")
|
||||||
|
|
||||||
@ -144,9 +145,9 @@ func (ws *WebSocketHandler) Start() {
|
|||||||
|
|
||||||
// if control protection is enabled and no admin
|
// if control protection is enabled and no admin
|
||||||
// and room is not locked, lock
|
// and room is not locked, lock
|
||||||
_, ok := ws.handler.Locked["control"]
|
ok := ws.state.IsLocked("control")
|
||||||
if !ok && ws.conf.ControlProtection && adminCount == 0 {
|
if !ok && ws.conf.ControlProtection && adminCount == 0 {
|
||||||
ws.handler.Locked["control"] = CONTROL_PROTECTION_SESSION
|
ws.state.Lock("control", CONTROL_PROTECTION_SESSION)
|
||||||
ws.sessions.SetControlLocked(true) // TODO: Handle locks in sessions as flags.
|
ws.sessions.SetControlLocked(true) // TODO: Handle locks in sessions as flags.
|
||||||
ws.logger.Info().Msgf("control locked and released on behalf of control protection")
|
ws.logger.Info().Msgf("control locked and released on behalf of control protection")
|
||||||
ws.handler.AdminRelease(id, session)
|
ws.handler.AdminRelease(id, session)
|
||||||
@ -314,8 +315,8 @@ func (ws *WebSocketHandler) Stats() types.Stats {
|
|||||||
Host: host,
|
Host: host,
|
||||||
Members: ws.sessions.Members(),
|
Members: ws.sessions.Members(),
|
||||||
|
|
||||||
Banned: ws.handler.Banned,
|
Banned: ws.state.AllBanned(),
|
||||||
Locked: ws.handler.Locked,
|
Locked: ws.state.AllLocked(),
|
||||||
|
|
||||||
ServerStartedAt: ws.serverStartedAt,
|
ServerStartedAt: ws.serverStartedAt,
|
||||||
LastAdminLeftAt: ws.lastAdminLeftAt,
|
LastAdminLeftAt: ws.lastAdminLeftAt,
|
||||||
|
Loading…
Reference in New Issue
Block a user