diff --git a/internal/session/manager.go b/internal/session/manager.go index 6871e6e9..67095db9 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -40,7 +40,7 @@ func (manager *SessionManager) New(id string, admin bool, socket types.WebSocket } manager.members[id] = session - manager.emmiter.Emit("created", id, session) + manager.emmiter.Emit("created", session) if !manager.remote.Streaming() && len(manager.members) > 0 { manager.remote.StartStream() @@ -53,15 +53,11 @@ func (manager *SessionManager) HasHost() bool { return manager.host != "" } -func (manager *SessionManager) IsHost(id string) bool { - return manager.host == id -} - func (manager *SessionManager) SetHost(id string) error { - _, ok := manager.members[id] + host, ok := manager.GetHost() if ok { manager.host = id - manager.emmiter.Emit("host", id) + manager.emmiter.Emit("host", host) return nil } return fmt.Errorf("invalid session id %s", id) @@ -73,9 +69,11 @@ func (manager *SessionManager) GetHost() (types.Session, bool) { } func (manager *SessionManager) ClearHost() { - id := manager.host + host, ok := manager.GetHost() manager.host = "" - manager.emmiter.Emit("host_cleared", id) + if ok { + manager.emmiter.Emit("host_cleared", host) + } } func (manager *SessionManager) Has(id string) bool { @@ -122,13 +120,13 @@ func (manager *SessionManager) Destroy(id string) error { session, ok := manager.members[id] if ok { err := session.destroy() - delete(manager.members, id) if !manager.remote.Streaming() && len(manager.members) <= 0 { manager.remote.StopStream() } - manager.emmiter.Emit("destroyed", id, session) + manager.emmiter.Emit("before_destroy", session) + delete(manager.members, id) return err } @@ -154,32 +152,32 @@ func (manager *SessionManager) Broadcast(v interface{}, exclude interface{}) err return nil } -func (manager *SessionManager) OnHost(listener func(id string)) { +func (manager *SessionManager) OnHost(listener func(session types.Session)) { manager.emmiter.On("host", func(payload ...interface{}) { - listener(payload[0].(string)) + listener(payload[0].(*Session)) }) } -func (manager *SessionManager) OnHostCleared(listener func(id string)) { +func (manager *SessionManager) OnHostCleared(listener func(session types.Session)) { manager.emmiter.On("host_cleared", func(payload ...interface{}) { - listener(payload[0].(string)) + listener(payload[0].(*Session)) }) } -func (manager *SessionManager) OnDestroy(listener func(id string, session types.Session)) { - manager.emmiter.On("destroyed", func(payload ...interface{}) { - listener(payload[0].(string), payload[1].(*Session)) +func (manager *SessionManager) OnBeforeDestroy(listener func(session types.Session)) { + manager.emmiter.On("before_destroy", func(payload ...interface{}) { + listener(payload[0].(*Session)) }) } -func (manager *SessionManager) OnCreated(listener func(id string, session types.Session)) { +func (manager *SessionManager) OnCreated(listener func(session types.Session)) { manager.emmiter.On("created", func(payload ...interface{}) { - listener(payload[0].(string), payload[1].(*Session)) + listener(payload[0].(*Session)) }) } -func (manager *SessionManager) OnConnected(listener func(id string, session types.Session)) { +func (manager *SessionManager) OnConnected(listener func(session types.Session)) { manager.emmiter.On("connected", func(payload ...interface{}) { - listener(payload[0].(string), payload[1].(*Session)) + listener(payload[0].(*Session)) }) } diff --git a/internal/session/session.go b/internal/session/session.go index c1e2b429..4074f4f8 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -79,7 +79,7 @@ func (session *Session) SetPeer(peer types.Peer) { func (session *Session) SetConnected() { session.connected = true - session.manager.emmiter.Emit("connected", session.id, session) + session.manager.emmiter.Emit("connected", session) } func (session *Session) Disconnect(reason string) error { diff --git a/internal/types/session.go b/internal/types/session.go index e72ac5dc..f7a75e42 100644 --- a/internal/types/session.go +++ b/internal/types/session.go @@ -29,7 +29,6 @@ type Session interface { type SessionManager interface { New(id string, admin bool, socket WebSocket) Session HasHost() bool - IsHost(id string) bool SetHost(id string) error GetHost() (Session, bool) ClearHost() @@ -39,9 +38,9 @@ type SessionManager interface { Admins() []*Member Destroy(id string) error Broadcast(v interface{}, exclude interface{}) error - OnHost(listener func(id string)) - OnHostCleared(listener func(id string)) - OnDestroy(listener func(id string, session Session)) - OnCreated(listener func(id string, session Session)) - OnConnected(listener func(id string, session Session)) + OnHost(listener func(session Session)) + OnHostCleared(listener func(session Session)) + OnBeforeDestroy(listener func(session Session)) + OnCreated(listener func(session Session)) + OnConnected(listener func(session Session)) } diff --git a/internal/websocket/control.go b/internal/websocket/control.go index 9d3b39cd..828e3de1 100644 --- a/internal/websocket/control.go +++ b/internal/websocket/control.go @@ -8,7 +8,7 @@ import ( func (h *MessageHandler) controlRelease(session types.Session) error { // check if session is host - if !h.sessions.IsHost(session.ID()) { + if !session.IsHost() { h.logger.Debug().Str("id", session.ID()).Msg("is not the host") return nil } @@ -80,7 +80,7 @@ func (h *MessageHandler) controlRequest(session types.Session) error { func (h *MessageHandler) controlGive(session types.Session, payload *message.Control) error { // check if session is host - if !h.sessions.IsHost(session.ID()) { + if !session.IsHost() { h.logger.Debug().Str("id", session.ID()).Msg("is not the host") return nil } @@ -112,7 +112,7 @@ func (h *MessageHandler) controlGive(session types.Session, payload *message.Con func (h *MessageHandler) controlClipboard(session types.Session, payload *message.Clipboard) error { // check if session is host - if !h.sessions.IsHost(session.ID()) { + if !session.IsHost() { h.logger.Debug().Str("id", session.ID()).Msg("is not the host") return nil } @@ -123,7 +123,7 @@ func (h *MessageHandler) controlClipboard(session types.Session, payload *messag func (h *MessageHandler) controlKeyboard(session types.Session, payload *message.Keyboard) error { // check if session is host - if !h.sessions.IsHost(session.ID()) { + if !session.IsHost() { h.logger.Debug().Str("id", session.ID()).Msg("is not the host") return nil } diff --git a/internal/websocket/session.go b/internal/websocket/session.go index ea10e0bb..5f2e000e 100644 --- a/internal/websocket/session.go +++ b/internal/websocket/session.go @@ -67,13 +67,13 @@ func (h *MessageHandler) SessionConnected(session types.Session) error { return nil } -func (h *MessageHandler) SessionDestroyed(id string) error { +func (h *MessageHandler) SessionDestroyed(session types.Session) error { // clear host if exists - if h.sessions.IsHost(id) { + if session.IsHost() { h.sessions.ClearHost() if err := h.sessions.Broadcast(message.Control{ Event: event.CONTROL_RELEASE, - ID: id, + ID: session.ID(), }, nil); err != nil { h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.CONTROL_RELEASE) } @@ -83,7 +83,7 @@ func (h *MessageHandler) SessionDestroyed(id string) error { if err := h.sessions.Broadcast( message.MemberDisconnected{ Event: event.MEMBER_DISCONNECTED, - ID: id, + ID: session.ID(), }, nil); err != nil { h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.MEMBER_DISCONNECTED) return err diff --git a/internal/websocket/websocket.go b/internal/websocket/websocket.go index 3bad5c29..fef8a888 100644 --- a/internal/websocket/websocket.go +++ b/internal/websocket/websocket.go @@ -55,27 +55,27 @@ type WebSocketHandler struct { } func (ws *WebSocketHandler) Start() { - ws.sessions.OnCreated(func(id string, session types.Session) { + ws.sessions.OnCreated(func(session types.Session) { if err := ws.handler.SessionCreated(session); err != nil { - ws.logger.Warn().Str("id", id).Err(err).Msg("session created with and error") + ws.logger.Warn().Str("id", session.ID()).Err(err).Msg("session created with and error") } else { - ws.logger.Debug().Str("id", id).Msg("session created") + ws.logger.Debug().Str("id", session.ID()).Msg("session created") } }) - ws.sessions.OnConnected(func(id string, session types.Session) { + ws.sessions.OnConnected(func(session types.Session) { if err := ws.handler.SessionConnected(session); err != nil { - ws.logger.Warn().Str("id", id).Err(err).Msg("session connected with and error") + ws.logger.Warn().Str("id", session.ID()).Err(err).Msg("session connected with and error") } else { - ws.logger.Debug().Str("id", id).Msg("session connected") + ws.logger.Debug().Str("id", session.ID()).Msg("session connected") } }) - ws.sessions.OnDestroy(func(id string, session types.Session) { - if err := ws.handler.SessionDestroyed(id); err != nil { - ws.logger.Warn().Str("id", id).Err(err).Msg("session destroyed with and error") + ws.sessions.OnBeforeDestroy(func(session types.Session) { + if err := ws.handler.SessionDestroyed(session); err != nil { + ws.logger.Warn().Str("id", session.ID()).Err(err).Msg("session destroyed with and error") } else { - ws.logger.Debug().Str("id", id).Msg("session destroyed") + ws.logger.Debug().Str("id", session.ID()).Msg("session destroyed") } })