diff --git a/internal/capture/manager.go b/internal/capture/manager.go index 40570d73..2af9144a 100644 --- a/internal/capture/manager.go +++ b/internal/capture/manager.go @@ -176,43 +176,3 @@ func (manager *CaptureManagerCtx) Video(videoID string) (types.StreamManager, bo func (manager *CaptureManagerCtx) VideoIDs() []string { return manager.videoIDs } - -func (manager *CaptureManagerCtx) StartStream() { - manager.mu.Lock() - defer manager.mu.Unlock() - - manager.logger.Info().Msgf("starting stream pipelines") - - for _, video := range manager.videos { - if err := video.Start(); err != nil { - manager.logger.Panic().Err(err).Msg("unable to start video pipeline") - } - } - - if err := manager.audio.Start(); err != nil { - manager.logger.Panic().Err(err).Msg("unable to start audio pipeline") - } - - manager.streaming = true -} - -func (manager *CaptureManagerCtx) StopStream() { - manager.mu.Lock() - defer manager.mu.Unlock() - - manager.logger.Info().Msgf("stopping stream pipelines") - - for _, video := range manager.videos { - video.Stop() - } - - manager.audio.Stop() - manager.streaming = false -} - -func (manager *CaptureManagerCtx) Streaming() bool { - manager.mu.Lock() - defer manager.mu.Unlock() - - return manager.streaming -} diff --git a/internal/capture/stream.go b/internal/capture/stream.go index 03cf3f0d..e4c366cf 100644 --- a/internal/capture/stream.go +++ b/internal/capture/stream.go @@ -92,6 +92,13 @@ func (manager *StreamManagerCtx) RemoveListener(listener *func(sample types.Samp } } +func (manager *StreamManagerCtx) ListenersCount() int { + manager.emitMu.Lock() + defer manager.emitMu.Unlock() + + return len(manager.listeners) +} + func (manager *StreamManagerCtx) Start() error { manager.mu.Lock() defer manager.mu.Unlock() diff --git a/internal/types/capture.go b/internal/types/capture.go index 71afa6f1..e8ea6877 100644 --- a/internal/types/capture.go +++ b/internal/types/capture.go @@ -23,8 +23,10 @@ type ScreencastManager interface { type StreamManager interface { Codec() codec.RTPCodec + AddListener(listener *func(sample Sample)) RemoveListener(listener *func(sample Sample)) + ListenersCount() int Start() error Stop() @@ -40,8 +42,4 @@ type CaptureManager interface { Audio() StreamManager Video(videoID string) (StreamManager, bool) VideoIDs() []string - - StartStream() - StopStream() - Streaming() bool } diff --git a/internal/webrtc/manager.go b/internal/webrtc/manager.go index decdc992..7c401daa 100644 --- a/internal/webrtc/manager.go +++ b/internal/webrtc/manager.go @@ -19,7 +19,7 @@ import ( func New(desktop types.DesktopManager, capture types.CaptureManager, config *config.WebRTC) *WebRTCManagerCtx { return &WebRTCManagerCtx{ - logger: log.With().Str("module", "webrtc").Logger(), + logger: log.With().Str("module", "webrtc").Logger(), defaultVideoID: capture.VideoIDs()[0], desktop: desktop, capture: capture, @@ -29,7 +29,6 @@ func New(desktop types.DesktopManager, capture types.CaptureManager, config *con type WebRTCManagerCtx struct { logger zerolog.Logger - videoTracks map[string]*webrtc.TrackLocalStaticSample audioTrack *webrtc.TrackLocalStaticSample unsubscribe []func() defaultVideoID string @@ -60,36 +59,6 @@ func (manager *WebRTCManagerCtx) Start() { audio.RemoveListener(&listener) }) - videoIDs := manager.capture.VideoIDs() - manager.videoTracks = map[string]*webrtc.TrackLocalStaticSample{} - for _, videoID := range videoIDs { - videoID := videoID - - video, ok := manager.capture.Video(videoID) - if !ok { - manager.logger.Warn().Str("videoID", videoID).Msg("video stream not found, skipping") - continue - } - - track, err := webrtc.NewTrackLocalStaticSample(video.Codec().Capability, "video", "stream") - if err != nil { - manager.logger.Panic().Err(err).Str("videoID", videoID).Msg("unable to create video track") - } - - listener := func(sample types.Sample) { - if err := track.WriteSample(media.Sample(sample)); err != nil && err != io.ErrClosedPipe { - manager.logger.Warn().Err(err).Str("videoID", videoID).Msg("vide pipeline failed to write") - } - } - - video.AddListener(&listener) - manager.unsubscribe = append(manager.unsubscribe, func(){ - video.RemoveListener(&listener) - }) - - manager.videoTracks[videoID] = track - } - manager.logger.Info(). Str("ice_lite", fmt.Sprintf("%t", manager.config.ICELite)). Str("ice_trickle", fmt.Sprintf("%t", manager.config.ICETrickle)). @@ -159,12 +128,67 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session) (*webrtc.Sess }) } + // create video track + videoStream, ok := manager.capture.Video(manager.defaultVideoID) + if !ok { + manager.logger.Warn().Str("videoID", manager.defaultVideoID).Msg("default video stream not found") + return nil, err + } + + videoTrack, err := webrtc.NewTrackLocalStaticSample(videoStream.Codec().Capability, "video", "stream") + if err != nil { + manager.logger.Warn().Err(err).Msg("unable to create video track") + return nil, err + } + + listener := func(sample types.Sample) { + if err := videoTrack.WriteSample(media.Sample(sample)); err != nil && err != io.ErrClosedPipe { + manager.logger.Warn().Err(err).Msg("video pipeline failed to write") + } + } + + // should be stream started + if videoStream.ListenersCount() == 0 { + if err := videoStream.Start(); err != nil { + manager.logger.Warn().Err(err).Msg("unable to start video pipeline") + return nil, err + } + } + + videoStream.AddListener(&listener) + + changeVideo := func(videoID string) error { + newVideoStream, ok := manager.capture.Video(videoID) + if !ok { + return fmt.Errorf("video stream not found") + } + + // should be new stream started + if newVideoStream.ListenersCount() == 0 { + if err := newVideoStream.Start(); err != nil { + return err + } + } + + // switch listeners + videoStream.RemoveListener(&listener) + newVideoStream.AddListener(&listener) + + // should be old stream stopped + if videoStream.ListenersCount() == 0 { + videoStream.Stop() + } + + videoStream = newVideoStream + return nil + } + _, err = connection.AddTrack(manager.audioTrack) if err != nil { return nil, err } - videoSender, err := connection.AddTrack(manager.videoTracks[manager.defaultVideoID]) + _, err = connection.AddTrack(videoTrack) if err != nil { return nil, err } @@ -208,6 +232,12 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session) (*webrtc.Sess connection.Close() case webrtc.PeerConnectionStateClosed: session.SetWebRTCConnected(false) + videoStream.RemoveListener(&listener) + + // should be stream stopped + if videoStream.ListenersCount() == 0 { + videoStream.Stop() + } } }) @@ -229,8 +259,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session) (*webrtc.Sess settings: settings, connection: connection, configuration: configuration, - videoTracks: manager.videoTracks, - videoSender: videoSender, + changeVideo: changeVideo, }) return connection.LocalDescription(), nil diff --git a/internal/webrtc/peer.go b/internal/webrtc/peer.go index 284f1a10..be46f150 100644 --- a/internal/webrtc/peer.go +++ b/internal/webrtc/peer.go @@ -1,10 +1,6 @@ package webrtc -import ( - "fmt" - - "github.com/pion/webrtc/v3" -) +import "github.com/pion/webrtc/v3" type WebRTCPeerCtx struct { api *webrtc.API @@ -12,8 +8,7 @@ type WebRTCPeerCtx struct { settings *webrtc.SettingEngine connection *webrtc.PeerConnection configuration *webrtc.Configuration - videoTracks map[string]*webrtc.TrackLocalStaticSample - videoSender *webrtc.RTPSender + changeVideo func(videoID string) error } func (webrtc_peer *WebRTCPeerCtx) SignalAnswer(sdp string) error { @@ -28,12 +23,7 @@ func (webrtc_peer *WebRTCPeerCtx) SignalCandidate(candidate webrtc.ICECandidateI } func (webrtc_peer *WebRTCPeerCtx) SetVideoID(videoID string) error { - track, ok := webrtc_peer.videoTracks[videoID] - if !ok { - return fmt.Errorf("videoID not found in available tracks") - } - - return webrtc_peer.videoSender.ReplaceTrack(track) + return webrtc_peer.changeVideo(videoID) } func (webrtc_peer *WebRTCPeerCtx) Destroy() error { diff --git a/internal/websocket/handler/session.go b/internal/websocket/handler/session.go index 664a7986..e430813e 100644 --- a/internal/websocket/handler/session.go +++ b/internal/websocket/handler/session.go @@ -29,9 +29,11 @@ func (h *MessageHandlerCtx) SessionDeleted(session types.Session) error { } func (h *MessageHandlerCtx) SessionConnected(session types.Session) error { - // start streaming, when first member connects - if !h.capture.Streaming() { - h.capture.StartStream() + // start audio, when first member connects + if !h.capture.Audio().Started() { + if err := h.capture.Audio().Start(); err != nil { + return err + } } if err := h.systemInit(session); err != nil { @@ -48,9 +50,9 @@ func (h *MessageHandlerCtx) SessionConnected(session types.Session) error { } func (h *MessageHandlerCtx) SessionDisconnected(session types.Session) error { - // Stop streaming, if last member disonnects - if h.capture.Streaming() && !h.sessions.HasConnectedMembers() { - h.capture.StopStream() + // stop audio, if last member disonnects + if h.capture.Audio().Started() && !h.sessions.HasConnectedMembers() { + h.capture.Audio().Stop() } // clear host if exists