From 2364facd60d054a6cd298ddcee0e539c9ee23cbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= Date: Mon, 6 Feb 2023 19:45:51 +0100 Subject: [PATCH] WebRTC congestion control (#26) * Add congestion control * Improve stream matching, add manual stream selection, add metrics * Use a ticker for bitrate estimation and make bandwidth drops switch to lower streams more aggressively * Missing signal response, fix video auto bug * Remove redundant mutex * Bitrate history queue * Get bitrate fn support h264 & float64 --------- Co-authored-by: Aleksandar Sukovic --- internal/capture/buckets.go | 105 ------------- internal/capture/buckets/buckets.go | 145 ++++++++++++++++++ internal/capture/buckets/buckets_test.go | 83 ++++++++++ internal/capture/buckets/queue.go | 88 +++++++++++ internal/capture/buckets/queue_test.go | 99 ++++++++++++ internal/capture/manager.go | 16 +- internal/capture/streamsink.go | 23 ++- internal/webrtc/manager.go | 183 ++++++++++++++++++----- internal/webrtc/metrics.go | 4 +- internal/webrtc/peer.go | 47 ++++-- internal/webrtc/track.go | 70 +++++++-- internal/websocket/handler/signal.go | 36 ++--- pkg/types/capture.go | 44 ++++-- pkg/types/message/messages.go | 10 +- pkg/types/webrtc.go | 7 +- 15 files changed, 738 insertions(+), 222 deletions(-) delete mode 100644 internal/capture/buckets.go create mode 100644 internal/capture/buckets/buckets.go create mode 100644 internal/capture/buckets/buckets_test.go create mode 100644 internal/capture/buckets/queue.go create mode 100644 internal/capture/buckets/queue_test.go diff --git a/internal/capture/buckets.go b/internal/capture/buckets.go deleted file mode 100644 index 04a87f09..00000000 --- a/internal/capture/buckets.go +++ /dev/null @@ -1,105 +0,0 @@ -package capture - -import ( - "errors" - "fmt" - "math" - - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" - - "github.com/demodesk/neko/pkg/types" - "github.com/demodesk/neko/pkg/types/codec" -) - -type BucketsManagerCtx struct 
{ - logger zerolog.Logger - codec codec.RTPCodec - streams map[string]*StreamSinkManagerCtx - streamIDs []string -} - -func bucketsNew(codec codec.RTPCodec, streams map[string]*StreamSinkManagerCtx, streamIDs []string) *BucketsManagerCtx { - logger := log.With(). - Str("module", "capture"). - Str("submodule", "buckets"). - Logger() - - return &BucketsManagerCtx{ - logger: logger, - codec: codec, - streams: streams, - streamIDs: streamIDs, - } -} - -func (m *BucketsManagerCtx) shutdown() { - m.logger.Info().Msgf("shutdown") -} - -func (m *BucketsManagerCtx) destroyAll() { - for _, stream := range m.streams { - if stream.Started() { - stream.destroyPipeline() - } - } -} - -func (m *BucketsManagerCtx) recreateAll() error { - for _, stream := range m.streams { - if stream.Started() { - err := stream.createPipeline() - if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) { - return err - } - } - } - - return nil -} - -func (m *BucketsManagerCtx) IDs() []string { - return m.streamIDs -} - -func (m *BucketsManagerCtx) Codec() codec.RTPCodec { - return m.codec -} - -func (m *BucketsManagerCtx) SetReceiver(receiver types.Receiver) error { - receiver.OnBitrateChange(func(bitrate int) error { - stream, ok := m.findNearestStream(bitrate) - if !ok { - return fmt.Errorf("no stream found for bitrate %d", bitrate) - } - - return receiver.SetStream(stream) - }) - - return nil -} - -func (m *BucketsManagerCtx) findNearestStream(bitrate int) (ss *StreamSinkManagerCtx, ok bool) { - minDiff := math.MaxInt - for _, s := range m.streams { - streamBitrate, err := s.Bitrate() - if err != nil { - m.logger.Error().Err(err).Msgf("failed to get bitrate for stream %s", s.ID()) - continue - } - - diffAbs := int(math.Abs(float64(bitrate - streamBitrate))) - - if diffAbs < minDiff { - minDiff, ss = diffAbs, s - } - } - ok = ss != nil - return -} - -func (m *BucketsManagerCtx) RemoveReceiver(receiver types.Receiver) error { - receiver.OnBitrateChange(nil) - 
receiver.RemoveStream() - return nil -} diff --git a/internal/capture/buckets/buckets.go b/internal/capture/buckets/buckets.go new file mode 100644 index 00000000..6d0c9870 --- /dev/null +++ b/internal/capture/buckets/buckets.go @@ -0,0 +1,145 @@ +package buckets + +import ( + "errors" + "sort" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + + "github.com/demodesk/neko/pkg/types" + "github.com/demodesk/neko/pkg/types/codec" +) + +type BucketsManagerCtx struct { + logger zerolog.Logger + codec codec.RTPCodec + streams map[string]types.StreamSinkManager + streamIDs []string +} + +func BucketsNew(codec codec.RTPCodec, streams map[string]types.StreamSinkManager, streamIDs []string) *BucketsManagerCtx { + logger := log.With(). + Str("module", "capture"). + Str("submodule", "buckets"). + Logger() + + return &BucketsManagerCtx{ + logger: logger, + codec: codec, + streams: streams, + streamIDs: streamIDs, + } +} + +func (manager *BucketsManagerCtx) Shutdown() { + manager.logger.Info().Msgf("shutdown") + + manager.DestroyAll() +} + +func (manager *BucketsManagerCtx) DestroyAll() { + for _, stream := range manager.streams { + if stream.Started() { + stream.DestroyPipeline() + } + } +} + +func (manager *BucketsManagerCtx) RecreateAll() error { + for _, stream := range manager.streams { + if stream.Started() { + err := stream.CreatePipeline() + if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) { + return err + } + } + } + return nil +} + +func (manager *BucketsManagerCtx) IDs() []string { + return manager.streamIDs +} + +func (manager *BucketsManagerCtx) Codec() codec.RTPCodec { + return manager.codec +} + +func (manager *BucketsManagerCtx) SetReceiver(receiver types.Receiver) { + // bitrate history is per receiver + bitrateHistory := &queue{} + + receiver.OnBitrateChange(func(peerBitrate int) (bool, error) { + bitrate := peerBitrate + if receiver.VideoAuto() { + bitrate = bitrateHistory.normaliseBitrate(bitrate) + } + + stream := 
manager.findNearestStream(bitrate) + streamID := stream.ID() + + // TODO: make this less noisy in logs + manager.logger.Debug(). + Str("video_id", streamID). + Int("len", bitrateHistory.len()). + Int("peer_bitrate", peerBitrate). + Int("bitrate", bitrate). + Msg("change video bitrate") + + return receiver.SetStream(stream) + }) + + receiver.OnVideoChange(func(videoID string) (bool, error) { + stream := manager.streams[videoID] + manager.logger.Info(). + Str("video_id", videoID). + Msg("video change") + + return receiver.SetStream(stream) + }) +} + +func (manager *BucketsManagerCtx) findNearestStream(peerBitrate int) types.StreamSinkManager { + type streamDiff struct { + id string + bitrateDiff int + } + + sortDiff := func(a, b int) bool { + switch { + case a < 0 && b < 0: + return a > b + case a >= 0: + if b >= 0 { + return a <= b + } + return true + } + return false + } + + var diffs []streamDiff + + for _, stream := range manager.streams { + diffs = append(diffs, streamDiff{ + id: stream.ID(), + bitrateDiff: peerBitrate - stream.Bitrate(), + }) + } + + sort.Slice(diffs, func(i, j int) bool { + return sortDiff(diffs[i].bitrateDiff, diffs[j].bitrateDiff) + }) + + bestDiff := diffs[0] + + return manager.streams[bestDiff.id] +} + +func (manager *BucketsManagerCtx) RemoveReceiver(receiver types.Receiver) error { + receiver.OnBitrateChange(nil) + receiver.OnVideoChange(nil) + receiver.RemoveStream() + return nil +} diff --git a/internal/capture/buckets/buckets_test.go b/internal/capture/buckets/buckets_test.go new file mode 100644 index 00000000..9fedcdd2 --- /dev/null +++ b/internal/capture/buckets/buckets_test.go @@ -0,0 +1,83 @@ +package buckets + +import ( + "reflect" + "testing" + + "github.com/demodesk/neko/pkg/types" + "github.com/demodesk/neko/pkg/types/codec" +) + +func TestBucketsManagerCtx_FindNearestStream(t *testing.T) { + type fields struct { + codec codec.RTPCodec + streams map[string]types.StreamSinkManager + } + type args struct { + peerBitrate int + } 
+ tests := []struct { + name string + fields fields + args args + want types.StreamSinkManager + }{ + { + name: "findNearestStream", + fields: fields{ + streams: map[string]types.StreamSinkManager{ + "1": mockStreamSink{ + id: "1", + bitrate: 500, + }, + "2": mockStreamSink{ + id: "2", + bitrate: 750, + }, + "3": mockStreamSink{ + id: "3", + bitrate: 1000, + }, + "4": mockStreamSink{ + id: "4", + bitrate: 1250, + }, + "5": mockStreamSink{ + id: "5", + bitrate: 1700, + }, + }, + }, + args: args{ + peerBitrate: 950, + }, + want: mockStreamSink{ + id: "2", + bitrate: 750, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := BucketsNew(tt.fields.codec, tt.fields.streams, []string{}) + + if got := m.findNearestStream(tt.args.peerBitrate); !reflect.DeepEqual(got, tt.want) { + t.Errorf("findNearestStream() = %v, want %v", got, tt.want) + } + }) + } +} + +type mockStreamSink struct { + id string + bitrate int + types.StreamSinkManager +} + +func (m mockStreamSink) ID() string { + return m.id +} + +func (m mockStreamSink) Bitrate() int { + return m.bitrate +} diff --git a/internal/capture/buckets/queue.go b/internal/capture/buckets/queue.go new file mode 100644 index 00000000..e79eb635 --- /dev/null +++ b/internal/capture/buckets/queue.go @@ -0,0 +1,88 @@ +package buckets + +import ( + "math" + "sync" + "time" +) + +type queue struct { + sync.Mutex + q []elem +} + +type elem struct { + created time.Time + bitrate int +} + +func (q *queue) push(v elem) { + q.Lock() + defer q.Unlock() + + // if the first element is older than 10 seconds, remove it + if len(q.q) > 0 && time.Since(q.q[0].created) > 10*time.Second { + q.q = q.q[1:] + } + q.q = append(q.q, v) +} + +func (q *queue) len() int { + q.Lock() + defer q.Unlock() + return len(q.q) +} + +func (q *queue) avg() int { + q.Lock() + defer q.Unlock() + if len(q.q) == 0 { + return 0 + } + sum := 0 + for _, v := range q.q { + sum += v.bitrate + } + return sum / len(q.q) +} + +func (q *queue) 
avgLastN(n int) int { + if n <= 0 { + return q.avg() + } + q.Lock() + defer q.Unlock() + if len(q.q) == 0 { + return 0 + } + sum := 0 + for _, v := range q.q[len(q.q)-n:] { + sum += v.bitrate + } + return sum / n +} + +func (q *queue) normaliseBitrate(currentBitrate int) int { + avgBitrate := float64(q.avg()) + histLen := float64(q.len()) + + q.push(elem{ + bitrate: currentBitrate, + created: time.Now(), + }) + + if avgBitrate == 0 || histLen == 0 || currentBitrate == 0 { + return currentBitrate + } + + lastN := int(math.Floor(float64(currentBitrate) / avgBitrate * histLen)) + if lastN > q.len() { + lastN = q.len() + } + + if lastN == 0 { + return currentBitrate + } + + return q.avgLastN(lastN) +} diff --git a/internal/capture/buckets/queue_test.go b/internal/capture/buckets/queue_test.go new file mode 100644 index 00000000..4deda5f8 --- /dev/null +++ b/internal/capture/buckets/queue_test.go @@ -0,0 +1,99 @@ +package buckets + +import "testing" + +func Queue_normaliseBitrate(t *testing.T) { + type fields struct { + queue *queue + } + type args struct { + currentBitrate int + } + tests := []struct { + name string + fields fields + args args + want []int + }{ + { + name: "normaliseBitrate: big drop", + fields: fields{ + queue: &queue{ + q: []elem{ + {bitrate: 900}, + {bitrate: 750}, + {bitrate: 780}, + {bitrate: 1100}, + {bitrate: 950}, + {bitrate: 700}, + {bitrate: 800}, + {bitrate: 900}, + {bitrate: 1000}, + {bitrate: 1100}, + // avg = 898 + }, + }, + }, + args: args{ + currentBitrate: 350, + }, + want: []int{816, 700, 537, 350, 350}, + }, { + name: "normaliseBitrate: small drop", + fields: fields{ + queue: &queue{ + q: []elem{ + {bitrate: 900}, + {bitrate: 750}, + {bitrate: 780}, + {bitrate: 1100}, + {bitrate: 950}, + {bitrate: 700}, + {bitrate: 800}, + {bitrate: 900}, + {bitrate: 1000}, + {bitrate: 1100}, + // avg = 898 + }, + }, + }, + args: args{ + currentBitrate: 700, + }, + want: []int{878, 842, 825, 825, 812, 787, 750, 700}, + }, { + name: 
"normaliseBitrate", + fields: fields{ + queue: &queue{ + q: []elem{ + {bitrate: 900}, + {bitrate: 750}, + {bitrate: 780}, + {bitrate: 1100}, + {bitrate: 950}, + {bitrate: 700}, + {bitrate: 800}, + {bitrate: 900}, + {bitrate: 1000}, + {bitrate: 1100}, + // avg = 898 + }, + }, + }, + args: args{ + currentBitrate: 1350, + }, + want: []int{943, 1003, 1060, 1085}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := tt.fields.queue + for i := 0; i < len(tt.want); i++ { + if got := m.normaliseBitrate(tt.args.currentBitrate); got != tt.want[i] { + t.Errorf("normaliseBitrate() [%d] = %v, want %v", i, got, tt.want[i]) + } + } + }) + } +} diff --git a/internal/capture/manager.go b/internal/capture/manager.go index 2f96c8e1..c22a6162 100644 --- a/internal/capture/manager.go +++ b/internal/capture/manager.go @@ -8,6 +8,7 @@ import ( "github.com/rs/zerolog" "github.com/rs/zerolog/log" + "github.com/demodesk/neko/internal/capture/buckets" "github.com/demodesk/neko/internal/config" "github.com/demodesk/neko/pkg/types" "github.com/demodesk/neko/pkg/types/codec" @@ -22,7 +23,7 @@ type CaptureManagerCtx struct { broadcast *BroacastManagerCtx screencast *ScreencastManagerCtx audio *StreamSinkManagerCtx - video *BucketsManagerCtx + video types.BucketsManager // sources webcam *StreamSrcManagerCtx @@ -32,7 +33,7 @@ type CaptureManagerCtx struct { func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCtx { logger := log.With().Str("module", "capture").Logger() - videos := map[string]*StreamSinkManagerCtx{} + videos := map[string]types.StreamSinkManager{} for video_id, cnf := range config.VideoPipelines { pipelineConf := cnf @@ -68,9 +69,10 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt Msg("syntax check for video stream pipeline passed") getVideoBitrate := pipelineConf.GetBitrateFn(desktop.GetScreenSize) - if err != nil { + if _, err = getVideoBitrate(); err != nil { 
logger.Panic().Err(err).Msg("unable to get video bitrate") } + // append to videos videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id, getVideoBitrate) } @@ -139,7 +141,7 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt "! appsink name=appsink", config.AudioDevice, config.AudioCodec.Pipeline, ), nil }, "audio", nil), - video: bucketsNew(config.VideoCodec, videos, config.VideoIDs), + video: buckets.BucketsNew(config.VideoCodec, videos, config.VideoIDs), // sources webcam: streamSrcNew(config.WebcamEnabled, map[string]string{ @@ -200,7 +202,7 @@ func (manager *CaptureManagerCtx) Start() { } manager.desktop.OnBeforeScreenSizeChange(func() { - manager.video.destroyAll() + manager.video.DestroyAll() if manager.broadcast.Started() { manager.broadcast.destroyPipeline() @@ -212,7 +214,7 @@ func (manager *CaptureManagerCtx) Start() { }) manager.desktop.OnAfterScreenSizeChange(func() { - err := manager.video.recreateAll() + err := manager.video.RecreateAll() if err != nil { manager.logger.Panic().Err(err).Msg("unable to recreate video pipelines") } @@ -240,7 +242,7 @@ func (manager *CaptureManagerCtx) Shutdown() error { manager.screencast.shutdown() manager.audio.shutdown() - manager.video.shutdown() + manager.video.Shutdown() manager.webcam.shutdown() manager.microphone.shutdown() diff --git a/internal/capture/streamsink.go b/internal/capture/streamsink.go index cb6225af..3400840e 100644 --- a/internal/capture/streamsink.go +++ b/internal/capture/streamsink.go @@ -105,7 +105,7 @@ func (manager *StreamSinkManagerCtx) shutdown() { } manager.listenersMu.Unlock() - manager.destroyPipeline() + manager.DestroyPipeline() manager.wg.Wait() } @@ -113,12 +113,19 @@ func (manager *StreamSinkManagerCtx) ID() string { return manager.id } -func (manager *StreamSinkManagerCtx) Bitrate() (int, error) { +func (manager *StreamSinkManagerCtx) Bitrate() int { if manager.getBitrate == nil { - return 0, nil + return 0 } + // 
recalculate bitrate every time, take screen resolution (and fps) into account - return manager.getBitrate() + // we called this function during startup, so it shouldn't error here + bitrate, err := manager.getBitrate() + if err != nil { + manager.logger.Err(err).Msg("unexpected error while getting bitrate") + } + + return bitrate } func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec { @@ -127,7 +134,7 @@ func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec { func (manager *StreamSinkManagerCtx) start() error { if len(manager.listeners) == 0 { - err := manager.createPipeline() + err := manager.CreatePipeline() if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) { return err } @@ -140,7 +147,7 @@ func (manager *StreamSinkManagerCtx) start() error { func (manager *StreamSinkManagerCtx) stop() { if len(manager.listeners) == 0 { - manager.destroyPipeline() + manager.DestroyPipeline() manager.logger.Info().Msgf("last listener, stopping") } } @@ -259,7 +266,7 @@ func (manager *StreamSinkManagerCtx) Started() bool { return manager.ListenersCount() > 0 } -func (manager *StreamSinkManagerCtx) createPipeline() error { +func (manager *StreamSinkManagerCtx) CreatePipeline() error { manager.pipelineMu.Lock() defer manager.pipelineMu.Unlock() @@ -313,7 +320,7 @@ func (manager *StreamSinkManagerCtx) createPipeline() error { return nil } -func (manager *StreamSinkManagerCtx) destroyPipeline() { +func (manager *StreamSinkManagerCtx) DestroyPipeline() { manager.pipelineMu.Lock() defer manager.pipelineMu.Unlock() diff --git a/internal/webrtc/manager.go b/internal/webrtc/manager.go index 4b348110..e220316c 100644 --- a/internal/webrtc/manager.go +++ b/internal/webrtc/manager.go @@ -9,6 +9,8 @@ import ( "github.com/pion/ice/v2" "github.com/pion/interceptor" + "github.com/pion/interceptor/pkg/cc" + "github.com/pion/interceptor/pkg/gcc" "github.com/pion/rtcp" "github.com/pion/webrtc/v3" "github.com/rs/zerolog" @@ -35,6 +37,9 @@ const keepAliveInterval = 
2 * time.Second // send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval const rtcpPLIInterval = 3 * time.Second +// how often we check the bitrate of each client. Default is 250ms +const bitrateCheckInterval = 250 * time.Millisecond + func New(desktop types.DesktopManager, capture types.CaptureManager, config *config.WebRTC) *WebRTCManagerCtx { configuration := webrtc.Configuration{ SDPSemantics: webrtc.SDPSemanticsUnifiedPlanWithFallback, @@ -153,12 +158,12 @@ func (manager *WebRTCManagerCtx) ICEServers() []types.ICEServer { return manager.config.ICEServers } -func (manager *WebRTCManagerCtx) newPeerConnection(codecs []codec.RTPCodec, logger zerolog.Logger) (*webrtc.PeerConnection, error) { +func (manager *WebRTCManagerCtx) newPeerConnection(bitrate int, codecs []codec.RTPCodec, logger zerolog.Logger) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) { // create media engine engine := &webrtc.MediaEngine{} for _, codec := range codecs { if err := codec.Register(engine); err != nil { - return nil, err + return nil, nil, err } } @@ -205,8 +210,29 @@ func (manager *WebRTCManagerCtx) newPeerConnection(codecs []codec.RTPCodec, logg // create interceptor registry registry := &interceptor.Registry{} + + congestionController, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) { + if bitrate == 0 { + bitrate = 1000000 + } + return gcc.NewSendSideBWE(gcc.SendSideBWEInitialBitrate(bitrate)) + }) + if err != nil { + return nil, nil, err + } + + estimatorChan := make(chan cc.BandwidthEstimator, 1) + congestionController.OnNewPeerConnection(func(id string, estimator cc.BandwidthEstimator) { + estimatorChan <- estimator + }) + + registry.Add(congestionController) + if err = webrtc.ConfigureTWCCHeaderExtensionSender(engine, registry); err != nil { + return nil, nil, err + } + if err := webrtc.RegisterDefaultInterceptors(engine, registry); err != nil { - return nil, err + return nil, nil, err } // create new API @@ -217,10 
+243,12 @@ func (manager *WebRTCManagerCtx) newPeerConnection(codecs []codec.RTPCodec, logg ) // create new peer connection - return api.NewPeerConnection(manager.webrtcConfiguration) + configuration := manager.webrtcConfiguration + connection, err := api.NewPeerConnection(configuration) + return connection, <-estimatorChan, err } -func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int) (*webrtc.SessionDescription, error) { +func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int, videoAuto bool) (*webrtc.SessionDescription, error) { id := atomic.AddInt32(&manager.peerId, 1) manager.metrics.NewConnection(session) @@ -236,7 +264,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int) video := manager.capture.Video() videoCodec := video.Codec() - connection, err := manager.newPeerConnection([]codec.RTPCodec{ + connection, estimator, err := manager.newPeerConnection(bitrate, []codec.RTPCodec{ audioCodec, videoCodec, }, logger) @@ -244,6 +272,10 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int) return nil, err } + if bitrate == 0 { + bitrate = estimator.GetTargetBitrate() + } + // asynchronously send local ICE Candidates if manager.config.ICETrickle { connection.OnICECandidate(func(candidate *webrtc.ICECandidate) { @@ -268,31 +300,117 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int) } // set stream for audio track - err = audioTrack.SetStream(audio) + _, err = audioTrack.SetStream(audio) if err != nil { return nil, err } // video track - - videoTrack, err := NewTrack(logger, videoCodec, connection) + videoTrack, err := NewTrack(logger, videoCodec, connection, WithVideoAuto(videoAuto)) if err != nil { return nil, err } // let video stream bucket manager handle stream subscriptions - err = video.SetReceiver(videoTrack) - if err != nil { - return nil, err + video.SetReceiver(videoTrack) + + changeVideoFromBitrate := 
func(peerBitrate int) { + // when switching from manual to auto bitrate estimation, in case the estimator is + // idle (lastBitrate > maxBitrate), we want to go back to the previous estimated bitrate + if peerBitrate == 0 { + peerBitrate = estimator.GetTargetBitrate() + manager.logger.Debug(). + Int("peer_bitrate", peerBitrate). + Msg("evaluated bitrate") + } + + ok, err := videoTrack.SetBitrate(peerBitrate) + if err != nil { + logger.Error().Err(err). + Int("peer_bitrate", peerBitrate). + Msg("unable to set video bitrate") + return + } + + if !ok { + return + } + + videoID := videoTrack.stream.ID() + bitrate := videoTrack.stream.Bitrate() + + manager.metrics.SetVideoID(session, videoID) + manager.logger.Debug(). + Int("peer_bitrate", peerBitrate). + Int("video_bitrate", bitrate). + Str("video_id", videoID). + Msg("peer bitrate triggered video stream change") + + go session.Send( + event.SIGNAL_VIDEO, + message.SignalVideo{ + Video: videoID, + Bitrate: bitrate, + VideoAuto: videoTrack.VideoAuto(), + }) } + changeVideoFromID := func(videoID string) (bitrate int) { + changed, err := videoTrack.SetVideoID(videoID) + if err != nil { + logger.Error().Err(err). + Str("video_id", videoID). + Msg("unable to set video stream") + return + } + + if !changed { + return + } + + bitrate = videoTrack.stream.Bitrate() + + manager.logger.Debug(). + Str("video_id", videoID). + Int("video_bitrate", bitrate). + Msg("peer video id triggered video stream change") + + go session.Send( + event.SIGNAL_VIDEO, + message.SignalVideo{ + Video: videoID, + Bitrate: bitrate, + VideoAuto: videoTrack.VideoAuto(), + }) + + return + } + + manager.logger.Info(). + Int("target_bitrate", bitrate). 
+ Msg("estimated initial peer bitrate") + // set initial video bitrate - if err = videoTrack.SetBitrate(bitrate); err != nil { - return nil, err - } + changeVideoFromBitrate(bitrate) - videoID := videoTrack.stream.ID() - manager.metrics.SetVideoID(session, videoID) + // use a ticker to get current client target bitrate + go func() { + ticker := time.NewTicker(bitrateCheckInterval) + defer ticker.Stop() + + for range ticker.C { + targetBitrate := estimator.GetTargetBitrate() + manager.metrics.SetReceiverEstimatedMaximumBitrate(session, float64(targetBitrate)) + + if connection.ConnectionState() == webrtc.PeerConnectionStateClosed { + break + } + if !videoTrack.VideoAuto() { + continue + } + changeVideoFromBitrate(targetBitrate) + } + }() // data channel @@ -302,27 +420,20 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int) } peer := &WebRTCPeerCtx{ - logger: logger, - connection: connection, - dataChannel: dataChannel, - changeVideo: func(bitrate int) error { - if err := videoTrack.SetBitrate(bitrate); err != nil { - return err - } - - videoID := videoTrack.stream.ID() - manager.metrics.SetVideoID(session, videoID) - return nil - }, + logger: logger, + connection: connection, + dataChannel: dataChannel, + changeVideoFromBitrate: changeVideoFromBitrate, + changeVideoFromID: changeVideoFromID, // TODO: Refactor. 
- videoId: func() string { - return videoTrack.stream.ID() - }, + videoId: videoTrack.stream.ID, setPaused: func(isPaused bool) { videoTrack.SetPaused(isPaused) audioTrack.SetPaused(isPaused) }, - iceTrickle: manager.config.ICETrickle, + iceTrickle: manager.config.ICETrickle, + setVideoAuto: videoTrack.SetVideoAuto, + getVideoAuto: videoTrack.VideoAuto, } connection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { @@ -515,11 +626,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int) }) videoTrack.OnRTCP(func(p rtcp.Packet) { - switch rtcpPacket := p.(type) { - case *rtcp.ReceiverEstimatedMaximumBitrate: // TODO: Deprecated. - manager.metrics.SetReceiverEstimatedMaximumBitrate(session, rtcpPacket.Bitrate) - - case *rtcp.ReceiverReport: + if rtcpPacket, ok := p.(*rtcp.ReceiverReport); ok { l := len(rtcpPacket.Reports) if l > 0 { // use only last report diff --git a/internal/webrtc/metrics.go b/internal/webrtc/metrics.go index dce22809..2da28e32 100644 --- a/internal/webrtc/metrics.go +++ b/internal/webrtc/metrics.go @@ -327,10 +327,10 @@ func (m *metricsCtx) SetVideoID(session types.Session, videoId string) { } } -func (m *metricsCtx) SetReceiverEstimatedMaximumBitrate(session types.Session, bitrate float32) { +func (m *metricsCtx) SetReceiverEstimatedMaximumBitrate(session types.Session, bitrate float64) { met := m.getBySession(session) - met.receiverEstimatedMaximumBitrate.Set(float64(bitrate)) + met.receiverEstimatedMaximumBitrate.Set(bitrate) } func (m *metricsCtx) SetReceiverReport(session types.Session, report rtcp.ReceptionReport) { diff --git a/internal/webrtc/peer.go b/internal/webrtc/peer.go index bee30a64..09e982fa 100644 --- a/internal/webrtc/peer.go +++ b/internal/webrtc/peer.go @@ -13,14 +13,17 @@ import ( ) type WebRTCPeerCtx struct { - mu sync.Mutex - logger zerolog.Logger - connection *webrtc.PeerConnection - dataChannel *webrtc.DataChannel - changeVideo func(bitrate int) error - videoId func() 
string - setPaused func(isPaused bool) - iceTrickle bool + mu sync.Mutex + logger zerolog.Logger + connection *webrtc.PeerConnection + dataChannel *webrtc.DataChannel + changeVideoFromBitrate func(bitrate int) + changeVideoFromID func(id string) int + videoId func() string + setPaused func(isPaused bool) + setVideoAuto func(auto bool) + getVideoAuto func() bool + iceTrickle bool } func (peer *WebRTCPeerCtx) CreateOffer(ICERestart bool) (*webrtc.SessionDescription, error) { @@ -115,7 +118,7 @@ func (peer *WebRTCPeerCtx) SetCandidate(candidate webrtc.ICECandidateInit) error return peer.connection.AddICECandidate(candidate) } -func (peer *WebRTCPeerCtx) SetVideoBitrate(bitrate int) error { +func (peer *WebRTCPeerCtx) SetVideoBitrate(peerBitrate int) error { peer.mu.Lock() defer peer.mu.Unlock() @@ -123,12 +126,24 @@ func (peer *WebRTCPeerCtx) SetVideoBitrate(bitrate int) error { return types.ErrWebRTCConnectionNotFound } - peer.logger.Info().Int("bitrate", bitrate).Msg("change video bitrate") - return peer.changeVideo(bitrate) + peer.changeVideoFromBitrate(peerBitrate) + return nil +} + +func (peer *WebRTCPeerCtx) SetVideoID(videoID string) error { + peer.mu.Lock() + defer peer.mu.Unlock() + + if peer.connection == nil { + return types.ErrWebRTCConnectionNotFound + } + + peer.changeVideoFromID(videoID) + return nil } // TODO: Refactor. 
-func (peer *WebRTCPeerCtx) GetVideoId() string { +func (peer *WebRTCPeerCtx) GetVideoID() string { peer.mu.Lock() defer peer.mu.Unlock() @@ -215,3 +230,11 @@ func (peer *WebRTCPeerCtx) Destroy() { peer.connection = nil } } + +func (peer *WebRTCPeerCtx) SetVideoAuto(auto bool) { + peer.setVideoAuto(auto) +} + +func (peer *WebRTCPeerCtx) VideoAuto() bool { + return peer.getVideoAuto() +} diff --git a/internal/webrtc/track.go b/internal/webrtc/track.go index 16cb12c9..6528d459 100644 --- a/internal/webrtc/track.go +++ b/internal/webrtc/track.go @@ -16,10 +16,12 @@ import ( ) type Track struct { - logger zerolog.Logger - track *webrtc.TrackLocalStaticSample - paused bool - listener func(sample types.Sample) + logger zerolog.Logger + track *webrtc.TrackLocalStaticSample + paused bool + videoAuto bool + videoAutoMu sync.RWMutex + listener func(sample types.Sample) stream types.StreamSinkManager streamMu sync.Mutex @@ -27,10 +29,19 @@ type Track struct { onRtcp func(rtcp.Packet) onRtcpMu sync.RWMutex - bitrateChange func(int) error + bitrateChange func(int) (bool, error) + videoChange func(string) (bool, error) } -func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.PeerConnection) (*Track, error) { +type option func(*Track) + +func WithVideoAuto(auto bool) option { + return func(t *Track) { + t.videoAuto = auto + } +} + +func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.PeerConnection, opts ...option) (*Track, error) { id := codec.Type.String() track, err := webrtc.NewTrackLocalStaticSample(codec.Capability, id, "stream") if err != nil { @@ -44,6 +55,10 @@ func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.Pe track: track, } + for _, opt := range opts { + opt(t) + } + t.listener = func(sample types.Sample) { if t.paused { return @@ -96,13 +111,13 @@ func (t *Track) rtcpReader(sender *webrtc.RTPSender) { } } -func (t *Track) SetStream(stream types.StreamSinkManager) error { +func (t *Track) 
SetStream(stream types.StreamSinkManager) (bool, error) { t.streamMu.Lock() defer t.streamMu.Unlock() // if we already listen to the stream, do nothing if t.stream == stream { - return nil + return false, nil } var err error @@ -111,12 +126,13 @@ func (t *Track) SetStream(stream types.StreamSinkManager) error { } else { err = stream.AddListener(&t.listener) } - - if err == nil { - t.stream = stream + if err != nil { + return false, err } - return err + t.stream = stream + + return true, nil } func (t *Track) RemoveStream() { @@ -140,14 +156,38 @@ func (t *Track) OnRTCP(f func(rtcp.Packet)) { t.onRtcp = f } -func (t *Track) SetBitrate(bitrate int) error { +func (t *Track) SetBitrate(bitrate int) (bool, error) { if t.bitrateChange == nil { - return fmt.Errorf("bitrate change not supported") + return false, fmt.Errorf("bitrate change not supported") } return t.bitrateChange(bitrate) } -func (t *Track) OnBitrateChange(f func(int) error) { +func (t *Track) SetVideoID(videoID string) (bool, error) { + if t.videoChange == nil { + return false, fmt.Errorf("video change not supported") + } + + return t.videoChange(videoID) +} + +func (t *Track) OnBitrateChange(f func(bitrate int) (bool, error)) { t.bitrateChange = f } + +func (t *Track) OnVideoChange(f func(string) (bool, error)) { + t.videoChange = f +} + +func (t *Track) SetVideoAuto(auto bool) { + t.videoAutoMu.Lock() + defer t.videoAutoMu.Unlock() + t.videoAuto = auto +} + +func (t *Track) VideoAuto() bool { + t.videoAutoMu.RLock() + defer t.videoAutoMu.RUnlock() + return t.videoAuto +} diff --git a/internal/websocket/handler/signal.go b/internal/websocket/handler/signal.go index 9da371e8..ec1955c0 100644 --- a/internal/websocket/handler/signal.go +++ b/internal/websocket/handler/signal.go @@ -28,7 +28,7 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag } } - offer, err := h.webrtc.CreatePeer(session, payload.Bitrate) + offer, err := h.webrtc.CreatePeer(session, payload.Bitrate, 
payload.VideoAuto) if err != nil { return err } @@ -39,7 +39,7 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag webrtcPeer.SetPaused(true) } - payload.Video = webrtcPeer.GetVideoId() + payload.Video = webrtcPeer.GetVideoID() } session.Send( @@ -47,7 +47,9 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag message.SignalProvide{ SDP: offer.SDP, ICEServers: h.webrtc.ICEServers(), - Video: payload.Video, // TODO: Refactor. + Video: payload.Video, // TODO: Refactor + Bitrate: payload.Bitrate, + VideoAuto: payload.VideoAuto, }) return nil @@ -64,7 +66,7 @@ func (h *MessageHandlerCtx) signalRestart(session types.Session) error { return err } - // TODO: Use offer event intead. + // TODO: Use offer event instead. session.Send( event.SIGNAL_RESTART, message.SignalDescription{ @@ -123,25 +125,17 @@ func (h *MessageHandlerCtx) signalVideo(session types.Session, payload *message. return errors.New("webRTC peer does not exist") } - var err error - if payload.Bitrate == 0 { - // get bitrate from video id - payload.Bitrate, err = h.capture.GetBitrateFromVideoID(payload.Video) - if err != nil { - return err + peer.SetVideoAuto(payload.VideoAuto) + + if payload.Video != "" { + if err := peer.SetVideoID(payload.Video); err != nil { + h.logger.Error().Err(err).Msg("failed to set video id") + } + } else { + if err := peer.SetVideoBitrate(payload.Bitrate); err != nil { + h.logger.Error().Err(err).Msg("failed to set video bitrate") } } - if err = peer.SetVideoBitrate(payload.Bitrate); err != nil { - return err - } - - session.Send( - event.SIGNAL_VIDEO, - message.SignalVideo{ - Video: peer.GetVideoId(), // TODO: Refactor. 
- Bitrate: payload.Bitrate, - }) - return nil } diff --git a/pkg/types/capture.go b/pkg/types/capture.go index a3a490ee..1f6b2cf6 100644 --- a/pkg/types/capture.go +++ b/pkg/types/capture.go @@ -8,9 +8,8 @@ import ( "strings" "github.com/PaesslerAG/gval" - "github.com/pion/webrtc/v3/pkg/media" - "github.com/demodesk/neko/pkg/types/codec" + "github.com/pion/webrtc/v3/pkg/media" ) var ( @@ -20,16 +19,23 @@ var ( type Sample media.Sample type Receiver interface { - SetStream(stream StreamSinkManager) error + SetStream(stream StreamSinkManager) (changed bool, err error) RemoveStream() - OnBitrateChange(f func(int) error) + OnBitrateChange(f func(bitrate int) (changed bool, err error)) + OnVideoChange(f func(videoID string) (changed bool, err error)) + VideoAuto() bool + SetVideoAuto(videoAuto bool) } type BucketsManager interface { IDs() []string Codec() codec.RTPCodec - SetReceiver(receiver Receiver) error + SetReceiver(receiver Receiver) RemoveReceiver(receiver Receiver) error + + DestroyAll() + RecreateAll() error + Shutdown() } type BroadcastManager interface { @@ -48,6 +54,7 @@ type ScreencastManager interface { type StreamSinkManager interface { ID() string Codec() codec.RTPCodec + Bitrate() int AddListener(listener *func(sample Sample)) error RemoveListener(listener *func(sample Sample)) error @@ -55,6 +62,9 @@ type StreamSinkManager interface { ListenersCount() int Started() bool + + CreatePipeline() error + DestroyPipeline() } type StreamSrcManager interface { @@ -201,17 +211,33 @@ func (config *VideoConfig) GetBitrateFn(getScreen func() *ScreenSize) func() (in }), } + // TODO: do not read target-bitrate from pipeline, but only from config. + // TODO: This is only for vp8. expr, ok := config.GstParams["target-bitrate"] if !ok { - return 0, fmt.Errorf("target-bitrate not found") + // TODO: This is only for h264.
+ expr, ok = config.GstParams["bitrate"] + if !ok { + return 0, fmt.Errorf("bitrate not found") + } } - targetBitrate, err := gval.Evaluate(expr, values, language...) + bitrate, err := gval.Evaluate(expr, values, language...) if err != nil { - return 0, err + return 0, fmt.Errorf("failed to evaluate bitrate: %w", err) } - return targetBitrate.(int), nil + var bitrateInt int + switch val := bitrate.(type) { + case int: + bitrateInt = val + case float64: + bitrateInt = (int)(val) + default: + return 0, fmt.Errorf("bitrate is not int or float64") + } + + return bitrateInt, nil } } diff --git a/pkg/types/message/messages.go b/pkg/types/message/messages.go index c263a006..832b24f0 100644 --- a/pkg/types/message/messages.go +++ b/pkg/types/message/messages.go @@ -48,7 +48,10 @@ type SystemDisconnect struct { type SignalProvide struct { SDP string `json:"sdp"` ICEServers []types.ICEServer `json:"iceservers"` - Video string `json:"video"` // TODO: Refactor. + // TODO: Use SignalVideo struct. + Video string `json:"video"` + Bitrate int `json:"bitrate"` + VideoAuto bool `json:"video_auto"` } type SignalCandidate struct { @@ -60,8 +63,9 @@ type SignalDescription struct { } type SignalVideo struct { - Video string `json:"video"` // TODO: Refactor. 
- Bitrate int `json:"bitrate"` + Video string `json:"video"` + Bitrate int `json:"bitrate"` + VideoAuto bool `json:"video_auto"` } ///////////////////////////// diff --git a/pkg/types/webrtc.go b/pkg/types/webrtc.go index c67f4857..ad5c503d 100644 --- a/pkg/types/webrtc.go +++ b/pkg/types/webrtc.go @@ -25,8 +25,11 @@ type WebRTCPeer interface { SetCandidate(candidate webrtc.ICECandidateInit) error SetVideoBitrate(bitrate int) error - GetVideoId() string + SetVideoID(videoID string) error + GetVideoID() string SetPaused(isPaused bool) error + SetVideoAuto(auto bool) + VideoAuto() bool SendCursorPosition(x, y int) error SendCursorImage(cur *CursorImage, img []byte) error @@ -40,6 +43,6 @@ type WebRTCManager interface { ICEServers() []ICEServer - CreatePeer(session Session, bitrate int) (*webrtc.SessionDescription, error) + CreatePeer(session Session, bitrate int, videoAuto bool) (*webrtc.SessionDescription, error) SetCursorPosition(x, y int) }