websocket upgrade as router handler.

This commit is contained in:
Miroslav Šedivý 2021-09-18 14:59:15 +02:00
parent fd9d5ec6f8
commit f96466b7b9
3 changed files with 26 additions and 22 deletions

View File

@ -37,11 +37,9 @@ func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, c
router.Route("/api", ApiManager.Route) router.Route("/api", ApiManager.Route)
router.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) error { router.Get("/api/ws", WebSocketManager.Upgrade(func(r *http.Request) bool {
return WebSocketManager.Upgrade(w, r, func(r *http.Request) bool { return config.AllowOrigin(r.Header.Get("Origin"))
return config.AllowOrigin(r.Header.Get("Origin")) }))
})
})
if config.Static != "" { if config.Static != "" {
fs := http.FileServer(http.Dir(config.Static)) fs := http.FileServer(http.Dir(config.Static))

View File

@ -23,5 +23,5 @@ type WebSocketManager interface {
Start() Start()
Shutdown() error Shutdown() error
AddHandler(handler WebSocketHandler) AddHandler(handler WebSocketHandler)
Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin CheckOrigin) error Upgrade(checkOrigin CheckOrigin) RouterHandler
} }

View File

@ -13,6 +13,7 @@ import (
"demodesk/neko/internal/types" "demodesk/neko/internal/types"
"demodesk/neko/internal/types/event" "demodesk/neko/internal/types/event"
"demodesk/neko/internal/types/message" "demodesk/neko/internal/types/message"
"demodesk/neko/internal/utils"
"demodesk/neko/internal/websocket/handler" "demodesk/neko/internal/websocket/handler"
) )
@ -145,21 +146,26 @@ func (manager *WebSocketManagerCtx) AddHandler(handler types.WebSocketHandler) {
manager.handlers = append(manager.handlers, handler) manager.handlers = append(manager.handlers, handler)
} }
func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin types.CheckOrigin) error { func (manager *WebSocketManagerCtx) Upgrade(checkOrigin types.CheckOrigin) types.RouterHandler {
manager.logger.Debug(). return func(w http.ResponseWriter, r *http.Request) error {
Str("address", r.RemoteAddr). upgrader := websocket.Upgrader{
Str("agent", r.UserAgent()). CheckOrigin: checkOrigin,
Msg("attempting to upgrade connection") // Do not return any error while handshake
Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {},
}
upgrader := websocket.Upgrader{ connection, err := upgrader.Upgrade(w, r, nil)
CheckOrigin: checkOrigin, if err != nil {
} return utils.HttpBadRequest().WithInternalErr(err)
}
connection, err := upgrader.Upgrade(w, r, nil)
if err != nil { // Cannot write HTTP response after connection upgrade
return err manager.connect(connection, r)
return nil
} }
}
func (manager *WebSocketManagerCtx) connect(connection *websocket.Conn, r *http.Request) {
// create new peer // create new peer
peer := newPeer(connection) peer := newPeer(connection)
@ -167,7 +173,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque
if err != nil { if err != nil {
manager.logger.Warn().Err(err).Msg("authentication failed") manager.logger.Warn().Err(err).Msg("authentication failed")
peer.Destroy(err.Error()) peer.Destroy(err.Error())
return nil return
} }
// add session id to all log messages // add session id to all log messages
@ -177,7 +183,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque
if !session.Profile().CanConnect { if !session.Profile().CanConnect {
logger.Warn().Msg("connection disabled") logger.Warn().Msg("connection disabled")
peer.Destroy("connection disabled") peer.Destroy("connection disabled")
return nil return
} }
if session.State().IsConnected { if session.State().IsConnected {
@ -185,7 +191,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque
if !manager.sessions.MercifulReconnect() { if !manager.sessions.MercifulReconnect() {
peer.Destroy("already connected") peer.Destroy("already connected")
return nil return
} }
logger.Info().Msg("replacing peer connection") logger.Info().Msg("replacing peer connection")
@ -210,7 +216,6 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque
}() }()
manager.handle(connection, session) manager.handle(connection, session)
return nil
} }
func (manager *WebSocketManagerCtx) handle(connection *websocket.Conn, session types.Session) { func (manager *WebSocketManagerCtx) handle(connection *websocket.Conn, session types.Session) {
@ -278,6 +283,7 @@ func (manager *WebSocketManagerCtx) handle(connection *websocket.Conn, session t
case <-manager.shutdown: case <-manager.shutdown:
err := connection.Close() err := connection.Close()
manager.logger.Err(err).Msg("connection shutdown") manager.logger.Err(err).Msg("connection shutdown")
return
case <-ticker.C: case <-ticker.C:
if err := connection.WriteMessage(websocket.PingMessage, nil); err != nil { if err := connection.WriteMessage(websocket.PingMessage, nil); err != nil {
logger.Err(err).Msg("ping message has failed") logger.Err(err).Msg("ping message has failed")