diff --git a/internal/capture/stream.go b/internal/capture/stream.go index 085078b9..ddba3eda 100644 --- a/internal/capture/stream.go +++ b/internal/capture/stream.go @@ -4,7 +4,6 @@ import ( "errors" "reflect" "sync" - "time" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -14,8 +13,6 @@ import ( "demodesk/neko/internal/types/codec" ) -const newListenerTimeout = 500 * time.Millisecond - type StreamManagerCtx struct { logger zerolog.Logger mu sync.Mutex @@ -30,9 +27,8 @@ type StreamManagerCtx struct { sampleStop chan interface{} sampleUpdate chan interface{} - listeners map[uintptr]*func(sample types.Sample) - listenersMu sync.Mutex - listenersCount int + listeners map[uintptr]*func(sample types.Sample) + listenersMu sync.Mutex } func streamNew(codec codec.RTPCodec, pipelineStr func() string, video_id string) *StreamManagerCtx { @@ -95,51 +91,37 @@ func (manager *StreamManagerCtx) Codec() codec.RTPCodec { return manager.codec } -func (manager *StreamManagerCtx) NewListener(listener *func(sample types.Sample)) (dispatcher chan interface{}, err error) { - if listener == nil { - return dispatcher, errors.New("listener cannot be nil") - } - - manager.mu.Lock() - defer manager.mu.Unlock() - - manager.listenersCount++ - if manager.listenersCount == 1 { +func (manager *StreamManagerCtx) start() error { + if len(manager.listeners) == 0 { err := manager.createPipeline() if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) { - return dispatcher, err + return err } manager.logger.Info().Msgf("first listener, starting") } - dispatcher = make(chan interface{}, 1) - go func() { - select { - case <-time.After(newListenerTimeout): - manager.logger.Warn().Msgf("add listener channel was not called, timeouted") - break - case <-dispatcher: - break - } - - ptr := reflect.ValueOf(listener).Pointer() - - manager.listenersMu.Lock() - manager.listeners[ptr] = listener - manager.listenersMu.Unlock() - - manager.logger.Debug().Interface("ptr", ptr).Msgf("adding listener") - }() - - return dispatcher, nil + return nil } -func (manager *StreamManagerCtx) RemoveListener(listener *func(sample types.Sample)) (dispatcher chan interface{}) { - if listener == nil { - return dispatcher +func (manager *StreamManagerCtx) stop() { + if len(manager.listeners) == 0 { + manager.destroyPipeline() + manager.logger.Info().Msgf("last listener, stopping") } +} +func (manager *StreamManagerCtx) addListener(listener *func(sample types.Sample)) { + ptr := reflect.ValueOf(listener).Pointer() + + manager.listenersMu.Lock() + manager.listeners[ptr] = listener + manager.listenersMu.Unlock() + + manager.logger.Debug().Interface("ptr", ptr).Msgf("adding listener") +} + +func (manager *StreamManagerCtx) removeListener(listener *func(sample types.Sample)) { ptr := reflect.ValueOf(listener).Pointer() manager.listenersMu.Lock() @@ -147,36 +129,70 @@ func (manager *StreamManagerCtx) RemoveListener(listener *func(sample types.Samp manager.listenersMu.Unlock() manager.logger.Debug().Interface("ptr", ptr).Msgf("removing listener") +} +func (manager *StreamManagerCtx) AddListener(listener *func(sample types.Sample)) error { manager.mu.Lock() - manager.listenersCount-- - manager.mu.Unlock() + defer manager.mu.Unlock() - dispatcher = make(chan interface{}, 1) - go func() { - select { - case <-time.After(newListenerTimeout): - manager.logger.Warn().Msgf("remote listener channel was not called, timeouted") - break - case <-dispatcher: - break - } + if listener == nil { + return errors.New("listener cannot be nil") + } - manager.mu.Lock() - defer manager.mu.Unlock() + // start if stopped + if err := manager.start(); err != nil { + return err + } - if manager.listenersCount <= 0 { - manager.destroyPipeline() - manager.logger.Info().Msgf("last listener, stopping") - } + // add listener + manager.addListener(listener) - if manager.listenersCount < 0 { - manager.listenersCount = 0 - manager.logger.Error().Int("listeners-count", manager.listenersCount).Msgf("listener counter is < 0, something is wrong") - } - }() + return nil +} - return dispatcher +func (manager *StreamManagerCtx) RemoveListener(listener *func(sample types.Sample)) error { + manager.mu.Lock() + defer manager.mu.Unlock() + + if listener == nil { + return errors.New("listener cannot be nil") + } + + // remove listener + manager.removeListener(listener) + + // stop if started + manager.stop() + + return nil +} + +func (manager *StreamManagerCtx) MoveListenerTo(listener *func(sample types.Sample), stream types.StreamManager) error { + manager.mu.Lock() + defer manager.mu.Unlock() + + targetStream, ok := stream.(*StreamManagerCtx) + if !ok { + return errors.New("stream manager does not support moving listeners") + } + + if listener == nil { + return errors.New("listener cannot be nil") + } + + // start if stopped + if err := targetStream.start(); err != nil { + return err + } + + // swap listeners + manager.removeListener(listener) + targetStream.addListener(listener) + + // stop if started + manager.stop() + + return nil } func (manager *StreamManagerCtx) ListenersCount() int { @@ -187,10 +203,7 @@ func (manager *StreamManagerCtx) ListenersCount() int { } func (manager *StreamManagerCtx) Started() bool { - manager.mu.Lock() - defer manager.mu.Unlock() - - return manager.listenersCount > 0 + return manager.ListenersCount() > 0 } func (manager *StreamManagerCtx) createPipeline() error { diff --git a/internal/types/capture.go b/internal/types/capture.go index 1644f81b..ed00b999 100644 --- a/internal/types/capture.go +++ b/internal/types/capture.go @@ -35,11 +35,9 @@ type ScreencastManager interface { type StreamManager interface { Codec() codec.RTPCodec - // starts pipeline if was not running before - // and returns dispatcher channel - NewListener(listener *func(sample Sample)) (dispatcher chan interface{}, err error) - // stops pipeline if it was last listener - RemoveListener(listener *func(sample Sample)) (dispatcher chan interface{}) + AddListener(listener *func(sample Sample)) error + RemoveListener(listener *func(sample Sample)) error + MoveListenerTo(listener *func(sample Sample), targetStream StreamManager) error ListenersCount() int Started() bool diff --git a/internal/webrtc/peerstreamtrack.go b/internal/webrtc/peerstreamtrack.go index 7e9e7296..946d87ad 100644 --- a/internal/webrtc/peerstreamtrack.go +++ b/internal/webrtc/peerstreamtrack.go @@ -33,8 +33,8 @@ func (manager *WebRTCManagerCtx) newPeerStreamTrack(stream types.StreamManager, }, } - peer.SetStream(stream) - return peer, nil + err = peer.SetStream(stream) + return peer, err } type PeerStreamTrack struct { @@ -50,28 +50,18 @@ func (peer *PeerStreamTrack) SetStream(stream types.StreamManager) error { peer.streamMu.Lock() defer peer.streamMu.Unlock() - // prepare new listener - addDispatcher, err := stream.NewListener(&peer.listener) - if err != nil { - return err - } - - // remove previous listener (in case it existed) - var stopDispatcher chan interface{} + var err error if peer.stream != nil { - stopDispatcher = peer.stream.RemoveListener(&peer.listener) + err = peer.stream.MoveListenerTo(&peer.listener, stream) + } else { + err = peer.stream.AddListener(&peer.listener) } - // add new listener - close(addDispatcher) - - // stop old pipeline (in case it existed) - if stopDispatcher != nil { - close(stopDispatcher) + if err != nil { + peer.stream = stream } - peer.stream = stream - return nil + return err } func (peer *PeerStreamTrack) RemoveStream() { @@ -79,8 +69,8 @@ func (peer *PeerStreamTrack) RemoveStream() { defer peer.streamMu.Unlock() if peer.stream != nil { - dispatcher := peer.stream.RemoveListener(&peer.listener) - close(dispatcher) + peer.stream.RemoveListener(&peer.listener) + peer.stream = nil } }