From f96466b7b9ed205fc75b60920a75dac60b1599e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= Date: Sat, 18 Sep 2021 14:59:15 +0200 Subject: [PATCH] websocket upgrade as router handler. --- internal/http/manager.go | 8 +++----- internal/types/websocket.go | 2 +- internal/websocket/manager.go | 38 ++++++++++++++++++++--------------- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/internal/http/manager.go b/internal/http/manager.go index 39fa170c..9589114a 100644 --- a/internal/http/manager.go +++ b/internal/http/manager.go @@ -37,11 +37,9 @@ func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, c router.Route("/api", ApiManager.Route) - router.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) error { - return WebSocketManager.Upgrade(w, r, func(r *http.Request) bool { - return config.AllowOrigin(r.Header.Get("Origin")) - }) - }) + router.Get("/api/ws", WebSocketManager.Upgrade(func(r *http.Request) bool { + return config.AllowOrigin(r.Header.Get("Origin")) + })) if config.Static != "" { fs := http.FileServer(http.Dir(config.Static)) diff --git a/internal/types/websocket.go b/internal/types/websocket.go index e8c22d6f..f876b0fc 100644 --- a/internal/types/websocket.go +++ b/internal/types/websocket.go @@ -23,5 +23,5 @@ type WebSocketManager interface { Start() Shutdown() error AddHandler(handler WebSocketHandler) - Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin CheckOrigin) error + Upgrade(checkOrigin CheckOrigin) RouterHandler } diff --git a/internal/websocket/manager.go b/internal/websocket/manager.go index 6a0e5a72..ef9aed37 100644 --- a/internal/websocket/manager.go +++ b/internal/websocket/manager.go @@ -13,6 +13,7 @@ import ( "demodesk/neko/internal/types" "demodesk/neko/internal/types/event" "demodesk/neko/internal/types/message" + "demodesk/neko/internal/utils" "demodesk/neko/internal/websocket/handler" ) @@ -145,21 +146,26 @@ func (manager *WebSocketManagerCtx) AddHandler(handler types.WebSocketHandler) { manager.handlers = append(manager.handlers, handler) } -func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin types.CheckOrigin) error { - manager.logger.Debug(). - Str("address", r.RemoteAddr). - Str("agent", r.UserAgent()). - Msg("attempting to upgrade connection") +func (manager *WebSocketManagerCtx) Upgrade(checkOrigin types.CheckOrigin) types.RouterHandler { + return func(w http.ResponseWriter, r *http.Request) error { + upgrader := websocket.Upgrader{ + CheckOrigin: checkOrigin, + // Do not return any error while handshake + Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {}, + } - upgrader := websocket.Upgrader{ - CheckOrigin: checkOrigin, - } - - connection, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return err + connection, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return utils.HttpBadRequest().WithInternalErr(err) + } + + // Cannot write HTTP response after connection upgrade + manager.connect(connection, r) + return nil } +} +func (manager *WebSocketManagerCtx) connect(connection *websocket.Conn, r *http.Request) { // create new peer peer := newPeer(connection) @@ -167,7 +173,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque if err != nil { manager.logger.Warn().Err(err).Msg("authentication failed") peer.Destroy(err.Error()) - return nil + return } // add session id to all log messages @@ -177,7 +183,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque if !session.Profile().CanConnect { logger.Warn().Msg("connection disabled") peer.Destroy("connection disabled") - return nil + return } if session.State().IsConnected { @@ -185,7 +191,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque if !manager.sessions.MercifulReconnect() { peer.Destroy("already connected") - return nil + return } logger.Info().Msg("replacing peer connection") @@ -210,7 +216,6 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque }() manager.handle(connection, session) - return nil } func (manager *WebSocketManagerCtx) handle(connection *websocket.Conn, session types.Session) { @@ -278,6 +283,7 @@ func (manager *WebSocketManagerCtx) handle(connection *websocket.Conn, session t case <-manager.shutdown: err := connection.Close() manager.logger.Err(err).Msg("connection shutdown") + return case <-ticker.C: if err := connection.WriteMessage(websocket.PingMessage, nil); err != nil { logger.Err(err).Msg("ping message has failed")