diff --git a/internal/session/session.go b/internal/session/session.go index b4359d4e..e7d8a847 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -1,20 +1,26 @@ package session import ( + "sync" + "github.com/rs/zerolog" "demodesk/neko/internal/types" ) type SessionCtx struct { - id string - token string - logger zerolog.Logger - manager *SessionManagerCtx - profile types.MemberProfile - state types.SessionState + id string + token string + logger zerolog.Logger + manager *SessionManagerCtx + profile types.MemberProfile + state types.SessionState + websocketPeer types.WebSocketPeer - webrtcPeer types.WebRTCPeer + websocketMu sync.Mutex + + webrtcPeer types.WebRTCPeer + webrtcMu sync.Mutex } func (session *SessionCtx) ID() string { @@ -31,11 +37,11 @@ func (session *SessionCtx) profileChanged() { } if (!session.profile.CanConnect || !session.profile.CanLogin || !session.profile.CanWatch) && session.state.IsWatching { - session.webrtcPeer.Destroy() + session.GetWebRTCPeer().Destroy() } if (!session.profile.CanConnect || !session.profile.CanLogin) && session.state.IsConnected { - session.websocketPeer.Destroy("profile changed") + session.GetWebSocketPeer().Destroy("profile changed") } } @@ -44,7 +50,7 @@ func (session *SessionCtx) State() types.SessionState { } func (session *SessionCtx) IsHost() bool { - return session.manager.host != nil && session.manager.host == session + return session.manager.GetHost() == session } // --- @@ -52,39 +58,50 @@ func (session *SessionCtx) IsHost() bool { // --- func (session *SessionCtx) SetWebSocketPeer(websocketPeer types.WebSocketPeer) { - if session.websocketPeer != nil { - session.websocketPeer.Destroy("connection replaced") - } + session.websocketMu.Lock() + session.websocketPeer, websocketPeer = websocketPeer, session.websocketPeer + session.websocketMu.Unlock() - session.websocketPeer = websocketPeer + if websocketPeer != nil && websocketPeer != session.websocketPeer { + websocketPeer.Destroy("connection replaced") + } } func (session *SessionCtx) SetWebSocketConnected(websocketPeer types.WebSocketPeer, connected bool) { + session.websocketMu.Lock() if websocketPeer != session.websocketPeer { + session.websocketMu.Unlock() return } + session.websocketMu.Unlock() session.state.IsConnected = connected if connected { session.manager.emmiter.Emit("connected", session) - return - } + } else { + session.manager.emmiter.Emit("disconnected", session) - session.manager.emmiter.Emit("disconnected", session) - session.websocketPeer = nil + session.websocketMu.Lock() + if websocketPeer == session.websocketPeer { + session.websocketPeer = nil + } + session.websocketMu.Unlock() + } } func (session *SessionCtx) GetWebSocketPeer() types.WebSocketPeer { + session.websocketMu.Lock() + defer session.websocketMu.Unlock() + return session.websocketPeer } func (session *SessionCtx) Send(event string, payload interface{}) { - if session.websocketPeer == nil { - return + peer := session.GetWebSocketPeer() + if peer != nil { + peer.Send(event, payload) } - - session.websocketPeer.Send(event, payload) } // --- @@ -92,26 +109,38 @@ func (session *SessionCtx) Send(event string, payload interface{}) { // --- func (session *SessionCtx) SetWebRTCPeer(webrtcPeer types.WebRTCPeer) { - if session.webrtcPeer != nil { - session.webrtcPeer.Destroy() - } + session.webrtcMu.Lock() + session.webrtcPeer, webrtcPeer = webrtcPeer, session.webrtcPeer + session.webrtcMu.Unlock() - session.webrtcPeer = webrtcPeer + if webrtcPeer != nil && webrtcPeer != session.webrtcPeer { + webrtcPeer.Destroy() + } } func (session *SessionCtx) SetWebRTCConnected(webrtcPeer types.WebRTCPeer, connected bool) { + session.webrtcMu.Lock() if webrtcPeer != session.webrtcPeer { + session.webrtcMu.Unlock() return } + session.webrtcMu.Unlock() session.state.IsWatching = connected session.manager.emmiter.Emit("state_changed", session) if !connected { - session.webrtcPeer = nil + session.webrtcMu.Lock() + if webrtcPeer == session.webrtcPeer { + session.webrtcPeer = nil + } + session.webrtcMu.Unlock() } } func (session *SessionCtx) GetWebRTCPeer() types.WebRTCPeer { + session.webrtcMu.Lock() + defer session.webrtcMu.Unlock() + return session.webrtcPeer }