diff --git a/internal/webrtc/manager.go b/internal/webrtc/manager.go index c94b651c..91ee6d7d 100644 --- a/internal/webrtc/manager.go +++ b/internal/webrtc/manager.go @@ -5,6 +5,7 @@ import ( "io" "reflect" "strings" + "sync" "github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3/pkg/media" @@ -19,10 +20,11 @@ import ( func New(desktop types.DesktopManager, capture types.CaptureManager, config *config.WebRTC) *WebRTCManagerCtx { return &WebRTCManagerCtx{ - logger: log.With().Str("module", "webrtc").Logger(), - desktop: desktop, - capture: capture, - config: config, + logger: log.With().Str("module", "webrtc").Logger(), + desktop: desktop, + capture: capture, + config: config, + participants: 0, // TODO: Refactor. curImgListeners: map[uintptr]*func(cur *types.CursorImage){}, curPosListeners: map[uintptr]*func(x, y int){}, @@ -30,12 +32,14 @@ func New(desktop types.DesktopManager, capture types.CaptureManager, config *con } type WebRTCManagerCtx struct { - logger zerolog.Logger - audioTrack *webrtc.TrackLocalStaticSample - audioStop func() - desktop types.DesktopManager - capture types.CaptureManager - config *config.WebRTC + mu sync.Mutex + logger zerolog.Logger + audioTrack *webrtc.TrackLocalStaticSample + audioStop func() + desktop types.DesktopManager + capture types.CaptureManager + config *config.WebRTC + participants uint32 // TODO: Refactor. curImgListeners map[uintptr]*func(cur *types.CursorImage) curPosListeners map[uintptr]*func(x, y int) @@ -51,15 +55,15 @@ func (manager *WebRTCManagerCtx) Start() { manager.logger.Panic().Err(err).Msg("unable to create audio track") } - listener := func(sample types.Sample) { + audioListener := func(sample types.Sample) { if err := manager.audioTrack.WriteSample(media.Sample(sample)); err != nil && err != io.ErrClosedPipe { manager.logger.Warn().Err(err).Msg("audio pipeline failed to write") } } - audio.AddListener(&listener) + audio.AddListener(&audioListener) manager.audioStop = func() { - audio.RemoveListener(&listener) + audio.RemoveListener(&audioListener) } manager.logger.Info(). @@ -156,12 +160,14 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin return nil, err } - listener := func(sample types.Sample) { + videoListener := func(sample types.Sample) { if err := videoTrack.WriteSample(media.Sample(sample)); err != nil && err != io.ErrClosedPipe { manager.logger.Warn().Err(err).Msg("video pipeline failed to write") } } + manager.mu.Lock() + // should be stream started if videoStream.ListenersCount() == 0 { if err := videoStream.Start(); err != nil { @@ -170,7 +176,17 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin } } - videoStream.AddListener(&listener) + videoStream.AddListener(&videoListener) + + // start audio, when first participant connects + if !manager.capture.Audio().Started() { + if err := manager.capture.Audio().Start(); err != nil { + manager.logger.Panic().Err(err).Msg("unable to start audio stream") + } + } + + manager.participants = manager.participants + 1 + manager.mu.Unlock() changeVideo := func(videoID string) error { newVideoStream, ok := manager.capture.Video(videoID) @@ -185,9 +201,9 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin } } - // switch listeners - videoStream.RemoveListener(&listener) - newVideoStream.AddListener(&listener) + // switch videoListeners + videoStream.RemoveListener(&videoListener) + newVideoStream.AddListener(&videoListener) // should be old stream stopped if videoStream.ListenersCount() == 0 { @@ -269,17 +285,33 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin case webrtc.PeerConnectionStateFailed: connection.Close() case webrtc.PeerConnectionStateClosed: + manager.mu.Lock() + session.SetWebRTCConnected(peer, false) - videoStream.RemoveListener(&listener) + videoStream.RemoveListener(&videoListener) // should be stream stopped if videoStream.ListenersCount() == 0 { videoStream.Stop() } + // decrease participants + manager.participants = manager.participants - 1 + + // stop audio, if last participant disonnects + if manager.participants <= 0 { + manager.participants = 0 + + if manager.capture.Audio().Started() { + manager.capture.Audio().Stop() + } + } + // TODO: Refactor. delete(manager.curImgListeners, cursorChangePtr) delete(manager.curPosListeners, cursorPositionPtr) + + manager.mu.Unlock() } }) diff --git a/internal/websocket/handler/session.go b/internal/websocket/handler/session.go index aebfb070..522b2908 100644 --- a/internal/websocket/handler/session.go +++ b/internal/websocket/handler/session.go @@ -29,13 +29,6 @@ func (h *MessageHandlerCtx) SessionDeleted(session types.Session) error { } func (h *MessageHandlerCtx) SessionConnected(session types.Session) error { - // start audio, when first member connects - if !h.capture.Audio().Started() { - if err := h.capture.Audio().Start(); err != nil { - return err - } - } - if err := h.systemInit(session); err != nil { return err } @@ -50,11 +43,6 @@ func (h *MessageHandlerCtx) SessionConnected(session types.Session) error { } func (h *MessageHandlerCtx) SessionDisconnected(session types.Session) error { - // stop audio, if last member disonnects - if h.capture.Audio().Started() && !h.sessions.HasConnectedMembers() { - h.capture.Audio().Stop() - } - // clear host if exists if session.IsHost() { h.desktop.ResetKeys()