ws close connections on shutdown and add wg.

This commit is contained in:
Miroslav Šedivý 2021-09-09 23:55:53 +02:00
parent 51207c2b50
commit 4f7bd48bec

View File

@ -3,6 +3,7 @@ package websocket
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"sync"
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -28,6 +29,7 @@ func New(
return &WebSocketManagerCtx{ return &WebSocketManagerCtx{
logger: logger, logger: logger,
shutdown: make(chan interface{}),
sessions: sessions, sessions: sessions,
desktop: desktop, desktop: desktop,
handler: handler.New(sessions, desktop, capture, webrtc), handler: handler.New(sessions, desktop, capture, webrtc),
@ -37,6 +39,8 @@ func New(
type WebSocketManagerCtx struct { type WebSocketManagerCtx struct {
logger zerolog.Logger logger zerolog.Logger
wg sync.WaitGroup
shutdown chan interface{}
sessions types.SessionManager sessions types.SessionManager
desktop types.DesktopManager desktop types.DesktopManager
handler *handler.MessageHandlerCtx handler *handler.MessageHandlerCtx
@ -132,7 +136,8 @@ func (manager *WebSocketManagerCtx) Start() {
func (manager *WebSocketManagerCtx) Shutdown() error { func (manager *WebSocketManagerCtx) Shutdown() error {
manager.logger.Info().Msg("shutdown") manager.logger.Info().Msg("shutdown")
// TODO: Kill all connections and add waitgroup for gorutines. close(manager.shutdown)
manager.wg.Wait()
return nil return nil
} }
@ -218,7 +223,10 @@ func (manager *WebSocketManagerCtx) handle(connection *websocket.Conn, session t
ticker := time.NewTicker(pingPeriod) ticker := time.NewTicker(pingPeriod)
defer ticker.Stop() defer ticker.Stop()
manager.wg.Add(1)
go func() { go func() {
defer manager.wg.Done()
for { for {
_, raw, err := connection.ReadMessage() _, raw, err := connection.ReadMessage()
if err != nil { if err != nil {
@ -267,6 +275,9 @@ func (manager *WebSocketManagerCtx) handle(connection *websocket.Conn, session t
} }
case <-cancel: case <-cancel:
return return
case <-manager.shutdown:
connection.Close()
return
case <-ticker.C: case <-ticker.C:
if err := connection.WriteMessage(websocket.PingMessage, nil); err != nil { if err := connection.WriteMessage(websocket.PingMessage, nil); err != nil {
logger.Err(err).Msg("ping message has failed") logger.Err(err).Msg("ping message has failed")