diff --git a/internal/webrtc/manager.go b/internal/webrtc/manager.go index 0a538b8a..49a982a6 100644 --- a/internal/webrtc/manager.go +++ b/internal/webrtc/manager.go @@ -50,6 +50,8 @@ type WebRTCManagerCtx struct { capture types.CaptureManager curImage *cursor.ImageCtx curPosition *cursor.PositionCtx + + camStop, micStop *func() } func (manager *WebRTCManagerCtx) Start() { @@ -160,8 +162,6 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin } connection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { - defer receiver.Stop() - logger := logger.With(). Str("kind", track.Kind().String()). Str("mime", track.Codec().RTPCodecCapability.MimeType). @@ -171,6 +171,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin if !session.Profile().CanShareMedia { logger.Warn().Msg("media sharing is disabled for this session") + receiver.Stop() return } @@ -178,16 +179,46 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin codec, ok := codec.ParseRTC(track.Codec()) if !ok { logger.Warn().Msg("remote track with unknown codec") + receiver.Stop() return } var srcManager types.StreamSrcManager + + stopped := false + stopFn := func() { + if stopped { + return + } + + stopped = true + receiver.Stop() + srcManager.Stop() + logger.Info().Msg("remote track stopped") + } + if track.Kind() == webrtc.RTPCodecTypeAudio { // audio -> microphone srcManager = manager.capture.Microphone() + defer stopFn() + + if manager.micStop != nil { + (*manager.micStop)() + } + manager.micStop = &stopFn } else if track.Kind() == webrtc.RTPCodecTypeVideo { // video -> webcam srcManager = manager.capture.Webcam() + defer stopFn() + + if manager.camStop != nil { + (*manager.camStop)() + } + manager.camStop = &stopFn + } else { + logger.Warn().Msg("remote track with unsupported codec type") + receiver.Stop() + return } err := srcManager.Start(codec) @@ -195,7 +226,6 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin logger.Err(err).Msg("failed to start pipeline") return } - defer srcManager.Stop() ticker := time.NewTicker(rtcpPLIInterval) defer ticker.Stop()