diff --git a/internal/session/manager.go b/internal/session/manager.go index 2d1ff7ba..2da614c4 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -142,7 +142,7 @@ func (manager *SessionManagerCtx) Delete(id string) error { manager.sessionsMu.Unlock() if session.State().IsConnected { - session.GetWebSocketPeer().Destroy("session deleted") + session.DestroyWebSocketPeer("session deleted") } if session.State().IsWatching { diff --git a/internal/session/session.go b/internal/session/session.go index ec1dff77..f5118407 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -51,7 +51,7 @@ func (session *SessionCtx) profileChanged() { } if (!session.profile.CanConnect || !session.profile.CanLogin) && session.state.IsConnected { - session.GetWebSocketPeer().Destroy("profile changed") + session.DestroyWebSocketPeer("profile changed") } // update webrtc paused state @@ -82,30 +82,49 @@ func (session *SessionCtx) SetCursor(cursor types.Cursor) { // websocket // --- -func (session *SessionCtx) SetWebSocketPeer(websocketPeer types.WebSocketPeer) { +// +// Connect WebSocket peer sets current peer and emits connected event. It also destroys the +// previous peer, if there was one. If the peer is already set, it will be ignored. +// +func (session *SessionCtx) ConnectWebSocketPeer(websocketPeer types.WebSocketPeer) { session.websocketMu.Lock() + isCurrentPeer := websocketPeer == session.websocketPeer session.websocketPeer, websocketPeer = websocketPeer, session.websocketPeer session.websocketMu.Unlock() - if websocketPeer != nil && websocketPeer != session.websocketPeer { + // ignore if already set + if isCurrentPeer { + return + } + + session.logger.Info().Msg("set websocket connected") + session.state.IsConnected = true + session.manager.emmiter.Emit("connected", session) + + // if there is a previous peer, destroy it + if websocketPeer != nil { websocketPeer.Destroy("connection replaced") } } -func (session *SessionCtx) SetWebSocketConnected(websocketPeer types.WebSocketPeer, connected bool, delayed bool) { +// +// Disconnect WebSocket peer sets current peer to nil and emits disconnected event. It also +// allows for a delayed disconnect. That means, the peer will not be disconnected immediately, +// but after a delay. If the peer is connected again before the delay, the disconnect will be +// cancelled. +// +// If the peer is not the current peer or the peer is nil, it will be ignored. +// +func (session *SessionCtx) DisconnectWebSocketPeer(websocketPeer types.WebSocketPeer, delayed bool) { session.websocketMu.Lock() - isCurrentPeer := websocketPeer == session.websocketPeer + isCurrentPeer := websocketPeer == session.websocketPeer && websocketPeer != nil session.websocketMu.Unlock() + // ignore if not current peer if !isCurrentPeer { return } - session.logger.Info(). - Bool("connected", connected). - Bool("delayed", delayed). - Msg("set websocket connected") - // // ws delayed // @@ -114,7 +133,7 @@ func (session *SessionCtx) SetWebSocketConnected(websocketPeer types.WebSocketPe if delayed { wsDelayedTimer = time.AfterFunc(WS_DELAYED_DURATION, func() { - session.SetWebSocketConnected(websocketPeer, connected, false) + session.DisconnectWebSocketPeer(websocketPeer, false) }) } @@ -126,6 +145,7 @@ func (session *SessionCtx) SetWebSocketConnected(websocketPeer types.WebSocketPe session.wsDelayedMu.Unlock() if delayed { + session.logger.Info().Msg("delayed websocket disconnected") return } @@ -133,13 +153,8 @@ func (session *SessionCtx) SetWebSocketConnected(websocketPeer types.WebSocketPe // not delayed // - session.state.IsConnected = connected - - if connected { - session.manager.emmiter.Emit("connected", session) - return - } - + session.logger.Info().Msg("set websocket disconnected") + session.state.IsConnected = false session.manager.emmiter.Emit("disconnected", session) session.websocketMu.Lock() @@ -149,15 +164,34 @@ func (session *SessionCtx) SetWebSocketConnected(websocketPeer types.WebSocketPe session.websocketMu.Unlock() } -func (session *SessionCtx) GetWebSocketPeer() types.WebSocketPeer { +// +// Destroy WebSocket peer disconnects the peer and destroys it. It ensures that the peer is +// disconnected immediately even though normal flow would be to disconnect it delayed. +// +func (session *SessionCtx) DestroyWebSocketPeer(reason string) { session.websocketMu.Lock() - defer session.websocketMu.Unlock() + peer := session.websocketPeer + session.websocketMu.Unlock() - return session.websocketPeer + if peer == nil { + return + } + + // disconnect peer first, so that it is not used anymore + session.DisconnectWebSocketPeer(peer, false) + + // destroy it afterwards + peer.Destroy(reason) } +// +// Send event to websocket peer. +// func (session *SessionCtx) Send(event string, payload any) { - peer := session.GetWebSocketPeer() + session.websocketMu.Lock() + peer := session.websocketPeer + session.websocketMu.Unlock() + if peer != nil { peer.Send(event, payload) } @@ -167,6 +201,9 @@ func (session *SessionCtx) Send(event string, payload any) { // webrtc // --- +// +// Set webrtc peer and destroy the old one, if there is old one. +// func (session *SessionCtx) SetWebRTCPeer(webrtcPeer types.WebRTCPeer) { session.webrtcMu.Lock() session.webrtcPeer, webrtcPeer = webrtcPeer, session.webrtcPeer @@ -177,6 +214,14 @@ func (session *SessionCtx) SetWebRTCPeer(webrtcPeer types.WebRTCPeer) { } } +// +// Set if current webrtc peer is connected or not. Since there might be lefover calls from +// webrtc peer, that are not used anymore, we need to check if the webrtc peer is still the +// same as the one we are setting the connected state for. +// +// If webrtc peer is disconnected, we don't expect it to be reconnected, so we set it to nil +// and send a signal close to the client. New connection is expected to use a new webrtc peer. +// func (session *SessionCtx) SetWebRTCConnected(webrtcPeer types.WebRTCPeer, connected bool) { session.webrtcMu.Lock() isCurrentPeer := webrtcPeer == session.webrtcPeer @@ -209,6 +254,9 @@ func (session *SessionCtx) SetWebRTCConnected(webrtcPeer types.WebRTCPeer, conne } } +// +// Get current WebRTC peer. Nil if not connected. +// func (session *SessionCtx) GetWebRTCPeer() types.WebRTCPeer { session.webrtcMu.Lock() defer session.webrtcMu.Unlock() diff --git a/internal/websocket/manager.go b/internal/websocket/manager.go index 282d14c1..5ad9a0f8 100644 --- a/internal/websocket/manager.go +++ b/internal/websocket/manager.go @@ -207,19 +207,18 @@ func (manager *WebSocketManagerCtx) Upgrade(checkOrigin types.CheckOrigin) types } func (manager *WebSocketManagerCtx) connect(connection *websocket.Conn, r *http.Request) { - // create new peer - peer := newPeer(connection) - session, err := manager.sessions.Authenticate(r) if err != nil { manager.logger.Warn().Err(err).Msg("authentication failed") - peer.Destroy(err.Error()) + newPeer(manager.logger, connection).Destroy(err.Error()) return } // add session id to all log messages logger := manager.logger.With().Str("session_id", session.ID()).Logger() - peer.setSessionID(session.ID()) + + // create new peer + peer := newPeer(logger, connection) if !session.Profile().CanConnect { logger.Warn().Msg("connection disabled") @@ -238,14 +237,12 @@ func (manager *WebSocketManagerCtx) connect(connection *websocket.Conn, r *http. logger.Info().Msg("replacing peer connection") } - session.SetWebSocketPeer(peer) - logger.Info(). Str("address", connection.RemoteAddr().String()). Str("agent", r.UserAgent()). Msg("connection started") - session.SetWebSocketConnected(peer, true, false) + session.ConnectWebSocketPeer(peer) // this is a blocking function that lives // throughout whole websocket connection @@ -277,7 +274,7 @@ func (manager *WebSocketManagerCtx) connect(connection *websocket.Conn, r *http. } } - session.SetWebSocketConnected(peer, false, delayedDisconnect) + session.DisconnectWebSocketPeer(peer, delayedDisconnect) } func (manager *WebSocketManagerCtx) handle(connection *websocket.Conn, peer types.WebSocketPeer, session types.Session) error { diff --git a/internal/websocket/peer.go b/internal/websocket/peer.go index e3d29e08..203caa83 100644 --- a/internal/websocket/peer.go +++ b/internal/websocket/peer.go @@ -2,12 +2,10 @@ package websocket import ( "encoding/json" - "errors" "sync" "github.com/gorilla/websocket" "github.com/rs/zerolog" - "github.com/rs/zerolog/log" "github.com/demodesk/neko/pkg/types" "github.com/demodesk/neko/pkg/types/event" @@ -21,30 +19,17 @@ type WebSocketPeerCtx struct { connection *websocket.Conn } -func newPeer(connection *websocket.Conn) *WebSocketPeerCtx { - logger := log.With(). - Str("module", "websocket"). - Str("submodule", "peer"). - Logger() - +func newPeer(logger zerolog.Logger, connection *websocket.Conn) *WebSocketPeerCtx { return &WebSocketPeerCtx{ - logger: logger, + logger: logger.With().Str("submodule", "peer").Logger(), connection: connection, } } -func (peer *WebSocketPeerCtx) setSessionID(sessionId string) { - peer.logger = peer.logger.With().Str("session_id", sessionId).Logger() -} - func (peer *WebSocketPeerCtx) Send(event string, payload any) { peer.mu.Lock() defer peer.mu.Unlock() - if peer.connection == nil { - return - } - raw, err := json.Marshal(payload) if err != nil { peer.logger.Err(err).Str("event", event).Msg("message marshalling has failed") @@ -79,10 +64,6 @@ func (peer *WebSocketPeerCtx) Ping() error { peer.mu.Lock() defer peer.mu.Unlock() - if peer.connection == nil { - return errors.New("peer connection not found") - } - // application level heartbeat if err := peer.connection.WriteJSON(types.WebSocketMessage{ Event: event.SYSTEM_HEARTBEAT, @@ -103,9 +84,6 @@ func (peer *WebSocketPeerCtx) Destroy(reason string) { peer.mu.Lock() defer peer.mu.Unlock() - if peer.connection != nil { - err := peer.connection.Close() - peer.logger.Err(err).Msg("peer connection destroyed") - peer.connection = nil - } + err := peer.connection.Close() + peer.logger.Err(err).Msg("peer connection destroyed") } diff --git a/pkg/types/session.go b/pkg/types/session.go index 4527d687..45d5ae92 100644 --- a/pkg/types/session.go +++ b/pkg/types/session.go @@ -46,9 +46,9 @@ type Session interface { SetCursor(cursor Cursor) // websocket - SetWebSocketPeer(websocketPeer WebSocketPeer) - SetWebSocketConnected(websocketPeer WebSocketPeer, connected bool, delayed bool) - GetWebSocketPeer() WebSocketPeer + ConnectWebSocketPeer(websocketPeer WebSocketPeer) + DisconnectWebSocketPeer(websocketPeer WebSocketPeer, delayed bool) + DestroyWebSocketPeer(reason string) Send(event string, payload any) // webrtc