From 3e8d686c0f37de117f745da34d01e09a8dfeeddb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= Date: Mon, 15 May 2023 19:29:39 +0200 Subject: [PATCH] Bandwidth estimator refactor (#46) * rewrite to use stream selector. * WIP. * add nacks to metrics. * add estimate trend. * estimator based on trend detector. * add estimator unstable duration. * add estimator debug. * add stalled duration. * estimator move values to config. * change default estimator values. * minor style changes. * fix websocket video messages. * replace video track with ivdeo id. --- internal/capture/buckets/buckets.go | 145 ------------ internal/capture/buckets/buckets_test.go | 83 ------- internal/capture/buckets/queue.go | 88 ------- internal/capture/buckets/queue_test.go | 99 -------- internal/capture/manager.go | 31 +-- internal/capture/streamselector.go | 206 +++++++++++++++++ internal/capture/streamsink.go | 59 ++--- internal/config/webrtc.go | 81 ++++++- internal/webrtc/manager.go | 91 ++++---- internal/webrtc/metrics.go | 25 +- internal/webrtc/peer.go | 279 +++++++++++++++++------ internal/webrtc/track.go | 73 ++---- internal/websocket/handler/signal.go | 48 ++-- pkg/types/capture.go | 141 ++++++------ pkg/types/message/messages.go | 9 +- pkg/types/webrtc.go | 9 +- pkg/utils/trenddetector.go | 153 +++++++++++++ 17 files changed, 845 insertions(+), 775 deletions(-) delete mode 100644 internal/capture/buckets/buckets.go delete mode 100644 internal/capture/buckets/buckets_test.go delete mode 100644 internal/capture/buckets/queue.go delete mode 100644 internal/capture/buckets/queue_test.go create mode 100644 internal/capture/streamselector.go create mode 100644 pkg/utils/trenddetector.go diff --git a/internal/capture/buckets/buckets.go b/internal/capture/buckets/buckets.go deleted file mode 100644 index 6d0c9870..00000000 --- a/internal/capture/buckets/buckets.go +++ /dev/null @@ -1,145 +0,0 @@ -package buckets - -import ( - "errors" - "sort" - - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" - - "github.com/demodesk/neko/pkg/types" - "github.com/demodesk/neko/pkg/types/codec" -) - -type BucketsManagerCtx struct { - logger zerolog.Logger - codec codec.RTPCodec - streams map[string]types.StreamSinkManager - streamIDs []string -} - -func BucketsNew(codec codec.RTPCodec, streams map[string]types.StreamSinkManager, streamIDs []string) *BucketsManagerCtx { - logger := log.With(). - Str("module", "capture"). - Str("submodule", "buckets"). - Logger() - - return &BucketsManagerCtx{ - logger: logger, - codec: codec, - streams: streams, - streamIDs: streamIDs, - } -} - -func (manager *BucketsManagerCtx) Shutdown() { - manager.logger.Info().Msgf("shutdown") - - manager.DestroyAll() -} - -func (manager *BucketsManagerCtx) DestroyAll() { - for _, stream := range manager.streams { - if stream.Started() { - stream.DestroyPipeline() - } - } -} - -func (manager *BucketsManagerCtx) RecreateAll() error { - for _, stream := range manager.streams { - if stream.Started() { - err := stream.CreatePipeline() - if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) { - return err - } - } - } - return nil -} - -func (manager *BucketsManagerCtx) IDs() []string { - return manager.streamIDs -} - -func (manager *BucketsManagerCtx) Codec() codec.RTPCodec { - return manager.codec -} - -func (manager *BucketsManagerCtx) SetReceiver(receiver types.Receiver) { - // bitrate history is per receiver - bitrateHistory := &queue{} - - receiver.OnBitrateChange(func(peerBitrate int) (bool, error) { - bitrate := peerBitrate - if receiver.VideoAuto() { - bitrate = bitrateHistory.normaliseBitrate(bitrate) - } - - stream := manager.findNearestStream(bitrate) - streamID := stream.ID() - - // TODO: make this less noisy in logs - manager.logger.Debug(). - Str("video_id", streamID). - Int("len", bitrateHistory.len()). - Int("peer_bitrate", peerBitrate). - Int("bitrate", bitrate). - Msg("change video bitrate") - - return receiver.SetStream(stream) - }) - - receiver.OnVideoChange(func(videoID string) (bool, error) { - stream := manager.streams[videoID] - manager.logger.Info(). - Str("video_id", videoID). - Msg("video change") - - return receiver.SetStream(stream) - }) -} - -func (manager *BucketsManagerCtx) findNearestStream(peerBitrate int) types.StreamSinkManager { - type streamDiff struct { - id string - bitrateDiff int - } - - sortDiff := func(a, b int) bool { - switch { - case a < 0 && b < 0: - return a > b - case a >= 0: - if b >= 0 { - return a <= b - } - return true - } - return false - } - - var diffs []streamDiff - - for _, stream := range manager.streams { - diffs = append(diffs, streamDiff{ - id: stream.ID(), - bitrateDiff: peerBitrate - stream.Bitrate(), - }) - } - - sort.Slice(diffs, func(i, j int) bool { - return sortDiff(diffs[i].bitrateDiff, diffs[j].bitrateDiff) - }) - - bestDiff := diffs[0] - - return manager.streams[bestDiff.id] -} - -func (manager *BucketsManagerCtx) RemoveReceiver(receiver types.Receiver) error { - receiver.OnBitrateChange(nil) - receiver.OnVideoChange(nil) - receiver.RemoveStream() - return nil -} diff --git a/internal/capture/buckets/buckets_test.go b/internal/capture/buckets/buckets_test.go deleted file mode 100644 index 9fedcdd2..00000000 --- a/internal/capture/buckets/buckets_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package buckets - -import ( - "reflect" - "testing" - - "github.com/demodesk/neko/pkg/types" - "github.com/demodesk/neko/pkg/types/codec" -) - -func TestBucketsManagerCtx_FindNearestStream(t *testing.T) { - type fields struct { - codec codec.RTPCodec - streams map[string]types.StreamSinkManager - } - type args struct { - peerBitrate int - } - tests := []struct { - name string - fields fields - args args - want types.StreamSinkManager - }{ - { - name: "findNearestStream", - fields: fields{ - streams: map[string]types.StreamSinkManager{ - "1": mockStreamSink{ - id: "1", - bitrate: 500, - }, - "2": mockStreamSink{ - id: "2", - bitrate: 750, - }, - "3": mockStreamSink{ - id: "3", - bitrate: 1000, - }, - "4": mockStreamSink{ - id: "4", - bitrate: 1250, - }, - "5": mockStreamSink{ - id: "5", - bitrate: 1700, - }, - }, - }, - args: args{ - peerBitrate: 950, - }, - want: mockStreamSink{ - id: "2", - bitrate: 750, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m := BucketsNew(tt.fields.codec, tt.fields.streams, []string{}) - - if got := m.findNearestStream(tt.args.peerBitrate); !reflect.DeepEqual(got, tt.want) { - t.Errorf("findNearestStream() = %v, want %v", got, tt.want) - } - }) - } -} - -type mockStreamSink struct { - id string - bitrate int - types.StreamSinkManager -} - -func (m mockStreamSink) ID() string { - return m.id -} - -func (m mockStreamSink) Bitrate() int { - return m.bitrate -} diff --git a/internal/capture/buckets/queue.go b/internal/capture/buckets/queue.go deleted file mode 100644 index e79eb635..00000000 --- a/internal/capture/buckets/queue.go +++ /dev/null @@ -1,88 +0,0 @@ -package buckets - -import ( - "math" - "sync" - "time" -) - -type queue struct { - sync.Mutex - q []elem -} - -type elem struct { - created time.Time - bitrate int -} - -func (q *queue) push(v elem) { - q.Lock() - defer q.Unlock() - - // if the first element is older than 10 seconds, remove it - if len(q.q) > 0 && time.Since(q.q[0].created) > 10*time.Second { - q.q = q.q[1:] - } - q.q = append(q.q, v) -} - -func (q *queue) len() int { - q.Lock() - defer q.Unlock() - return len(q.q) -} - -func (q *queue) avg() int { - q.Lock() - defer q.Unlock() - if len(q.q) == 0 { - return 0 - } - sum := 0 - for _, v := range q.q { - sum += v.bitrate - } - return sum / len(q.q) -} - -func (q *queue) avgLastN(n int) int { - if n <= 0 { - return q.avg() - } - q.Lock() - defer q.Unlock() - if len(q.q) == 0 { - return 0 - } - sum := 0 - for _, v := range q.q[len(q.q)-n:] { - sum += v.bitrate - } - return sum / n -} - -func (q *queue) normaliseBitrate(currentBitrate int) int { - avgBitrate := float64(q.avg()) - histLen := float64(q.len()) - - q.push(elem{ - bitrate: currentBitrate, - created: time.Now(), - }) - - if avgBitrate == 0 || histLen == 0 || currentBitrate == 0 { - return currentBitrate - } - - lastN := int(math.Floor(float64(currentBitrate) / avgBitrate * histLen)) - if lastN > q.len() { - lastN = q.len() - } - - if lastN == 0 { - return currentBitrate - } - - return q.avgLastN(lastN) -} diff --git a/internal/capture/buckets/queue_test.go b/internal/capture/buckets/queue_test.go deleted file mode 100644 index 4deda5f8..00000000 --- a/internal/capture/buckets/queue_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package buckets - -import "testing" - -func Queue_normaliseBitrate(t *testing.T) { - type fields struct { - queue *queue - } - type args struct { - currentBitrate int - } - tests := []struct { - name string - fields fields - args args - want []int - }{ - { - name: "normaliseBitrate: big drop", - fields: fields{ - queue: &queue{ - q: []elem{ - {bitrate: 900}, - {bitrate: 750}, - {bitrate: 780}, - {bitrate: 1100}, - {bitrate: 950}, - {bitrate: 700}, - {bitrate: 800}, - {bitrate: 900}, - {bitrate: 1000}, - {bitrate: 1100}, - // avg = 898 - }, - }, - }, - args: args{ - currentBitrate: 350, - }, - want: []int{816, 700, 537, 350, 350}, - }, { - name: "normaliseBitrate: small drop", - fields: fields{ - queue: &queue{ - q: []elem{ - {bitrate: 900}, - {bitrate: 750}, - {bitrate: 780}, - {bitrate: 1100}, - {bitrate: 950}, - {bitrate: 700}, - {bitrate: 800}, - {bitrate: 900}, - {bitrate: 1000}, - {bitrate: 1100}, - // avg = 898 - }, - }, - }, - args: args{ - currentBitrate: 700, - }, - want: []int{878, 842, 825, 825, 812, 787, 750, 700}, - }, { - name: "normaliseBitrate", - fields: fields{ - queue: &queue{ - q: []elem{ - {bitrate: 900}, - {bitrate: 750}, - {bitrate: 780}, - {bitrate: 1100}, - {bitrate: 950}, - {bitrate: 700}, - {bitrate: 800}, - {bitrate: 900}, - {bitrate: 1000}, - {bitrate: 1100}, - // avg = 898 - }, - }, - }, - args: args{ - currentBitrate: 1350, - }, - want: []int{943, 1003, 1060, 1085}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m := tt.fields.queue - for i := 0; i < len(tt.want); i++ { - if got := m.normaliseBitrate(tt.args.currentBitrate); got != tt.want[i] { - t.Errorf("normaliseBitrate() [%d] = %v, want %v", i, got, tt.want[i]) - } - } - }) - } -} diff --git a/internal/capture/manager.go b/internal/capture/manager.go index 3b8288f3..2a71485e 100644 --- a/internal/capture/manager.go +++ b/internal/capture/manager.go @@ -8,7 +8,6 @@ import ( "github.com/rs/zerolog" "github.com/rs/zerolog/log" - "github.com/demodesk/neko/internal/capture/buckets" "github.com/demodesk/neko/internal/config" "github.com/demodesk/neko/pkg/types" "github.com/demodesk/neko/pkg/types/codec" @@ -23,7 +22,7 @@ type CaptureManagerCtx struct { broadcast *BroacastManagerCtx screencast *ScreencastManagerCtx audio *StreamSinkManagerCtx - video types.BucketsManager + video *StreamSelectorManagerCtx // sources webcam *StreamSrcManagerCtx @@ -68,13 +67,8 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt Str("pipeline", pipeline). Msg("syntax check for video stream pipeline passed") - getVideoBitrate := pipelineConf.GetBitrateFn(desktop.GetScreenSize) - if _, err = getVideoBitrate(); err != nil { - logger.Panic().Err(err).Msg("unable to get video bitrate") - } - // append to videos - videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id, getVideoBitrate) + videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id) } return &CaptureManagerCtx{ @@ -140,8 +134,8 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt "! %s "+ "! appsink name=appsink", config.AudioDevice, config.AudioCodec.Pipeline, ), nil - }, "audio", nil), - video: buckets.BucketsNew(config.VideoCodec, videos, config.VideoIDs), + }, "audio"), + video: streamSelectorNew(config.VideoCodec, videos, config.VideoIDs), // sources webcam: streamSrcNew(config.WebcamEnabled, map[string]string{ @@ -202,7 +196,7 @@ func (manager *CaptureManagerCtx) Start() { } manager.desktop.OnBeforeScreenSizeChange(func() { - manager.video.DestroyAll() + manager.video.destroyPipelines() if manager.broadcast.Started() { manager.broadcast.destroyPipeline() @@ -214,7 +208,7 @@ func (manager *CaptureManagerCtx) Start() { }) manager.desktop.OnAfterScreenSizeChange(func() { - err := manager.video.RecreateAll() + err := manager.video.recreatePipelines() if err != nil { manager.logger.Panic().Err(err).Msg("unable to recreate video pipelines") } @@ -242,7 +236,7 @@ func (manager *CaptureManagerCtx) Shutdown() error { manager.screencast.shutdown() manager.audio.shutdown() - manager.video.Shutdown() + manager.video.shutdown() manager.webcam.shutdown() manager.microphone.shutdown() @@ -250,15 +244,6 @@ func (manager *CaptureManagerCtx) Shutdown() error { return nil } -func (manager *CaptureManagerCtx) GetBitrateFromVideoID(videoID string) (int, error) { - cfg, ok := manager.config.VideoPipelines[videoID] - if !ok { - return 0, fmt.Errorf("video config not found for %s", videoID) - } - - return cfg.GetBitrateFn(manager.desktop.GetScreenSize)() -} - func (manager *CaptureManagerCtx) Broadcast() types.BroadcastManager { return manager.broadcast } @@ -271,7 +256,7 @@ func (manager *CaptureManagerCtx) Audio() types.StreamSinkManager { return manager.audio } -func (manager *CaptureManagerCtx) Video() types.BucketsManager { +func (manager *CaptureManagerCtx) Video() types.StreamSelectorManager { return manager.video } diff --git a/internal/capture/streamselector.go b/internal/capture/streamselector.go new file mode 100644 index 00000000..fd139032 --- /dev/null +++ b/internal/capture/streamselector.go @@ -0,0 +1,206 @@ +package capture + +import ( + "errors" + "sort" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + + "github.com/demodesk/neko/pkg/types" + "github.com/demodesk/neko/pkg/types/codec" +) + +type StreamSelectorManagerCtx struct { + logger zerolog.Logger + codec codec.RTPCodec + streams map[string]types.StreamSinkManager + streamIDs []string +} + +func streamSelectorNew(codec codec.RTPCodec, streams map[string]types.StreamSinkManager, streamIDs []string) *StreamSelectorManagerCtx { + logger := log.With(). + Str("module", "capture"). + Str("submodule", "stream-selector"). + Logger() + + return &StreamSelectorManagerCtx{ + logger: logger, + codec: codec, + streams: streams, + streamIDs: streamIDs, + } +} + +func (manager *StreamSelectorManagerCtx) shutdown() { + manager.logger.Info().Msgf("shutdown") + + manager.destroyPipelines() +} + +func (manager *StreamSelectorManagerCtx) destroyPipelines() { + for _, stream := range manager.streams { + if stream.Started() { + stream.DestroyPipeline() + } + } +} + +func (manager *StreamSelectorManagerCtx) recreatePipelines() error { + for _, stream := range manager.streams { + if stream.Started() { + err := stream.CreatePipeline() + if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) { + return err + } + } + } + return nil +} + +func (manager *StreamSelectorManagerCtx) IDs() []string { + return manager.streamIDs +} + +func (manager *StreamSelectorManagerCtx) Codec() codec.RTPCodec { + return manager.codec +} + +func (manager *StreamSelectorManagerCtx) GetStream(selector types.StreamSelector) (types.StreamSinkManager, bool) { + // select stream by ID + if selector.ID != "" { + // select lower stream + if selector.Type == types.StreamSelectorTypeLower { + var lastStream types.StreamSinkManager + for i := len(manager.streamIDs) - 1; i >= 0; i-- { + streamID := manager.streamIDs[i] + if streamID == selector.ID { + return lastStream, lastStream != nil + } + stream, ok := manager.streams[streamID] + if ok { + lastStream = stream + } + } + // we couldn't find a lower stream + return nil, false + } + + // select higher stream + if selector.Type == types.StreamSelectorTypeHigher { + var lastStream types.StreamSinkManager + for _, streamID := range manager.streamIDs { + if streamID == selector.ID { + return lastStream, lastStream != nil + } + stream, ok := manager.streams[streamID] + if ok { + lastStream = stream + } + } + // we couldn't find a higher stream + return nil, false + } + + // select exact stream + stream, ok := manager.streams[selector.ID] + return stream, ok + } + + // select stream by bitrate + if selector.Bitrate != 0 { + // select stream by nearest bitrate + if selector.Type == types.StreamSelectorTypeNearest { + return manager.nearestBitrate(selector.Bitrate), true + } + + // select lower stream + if selector.Type == types.StreamSelectorTypeLower { + // start from the highest stream, and go down, until we find a lower stream + for i := len(manager.streamIDs) - 1; i >= 0; i-- { + streamID := manager.streamIDs[i] + stream := manager.streams[streamID] + // if stream should be considered in calculation + considered := stream.Bitrate() != 0 && stream.Started() + if considered && stream.Bitrate() < selector.Bitrate { + return stream, true + } + } + // we couldn't find a lower stream + return nil, false + } + + // select higher stream + if selector.Type == types.StreamSelectorTypeHigher { + // start from the lowest stream, and go up, until we find a higher stream + for _, streamID := range manager.streamIDs { + stream := manager.streams[streamID] + // if stream should be considered in calculation + considered := stream.Bitrate() != 0 && stream.Started() + if considered && stream.Bitrate() > selector.Bitrate { + return stream, true + } + } + // we couldn't find a higher stream + return nil, false + } + + // select stream by exact bitrate + for _, stream := range manager.streams { + if stream.Bitrate() == selector.Bitrate { + return stream, true + } + } + } + + // we couldn't find a stream + return nil, false +} + +// TODO: This is a very naive implementation, we should use a binary search instead. +func (manager *StreamSelectorManagerCtx) nearestBitrate(bitrate uint64) types.StreamSinkManager { + type streamDiff struct { + id string + bitrateDiff int + } + + sortDiff := func(a, b int) bool { + switch { + case a < 0 && b < 0: + return a > b + case a >= 0: + if b >= 0 { + return a <= b + } + return true + } + return false + } + + var diffs []streamDiff + + for _, stream := range manager.streams { + // if stream should be considered in calculation + considered := stream.Bitrate() != 0 && stream.Started() + if !considered { + continue + } + diffs = append(diffs, streamDiff{ + id: stream.ID(), + bitrateDiff: int(bitrate) - int(stream.Bitrate()), + }) + } + + // no streams available + if len(diffs) == 0 { + // return first (lowest) stream + return manager.streams[manager.streamIDs[0]] + } + + sort.Slice(diffs, func(i, j int) bool { + return sortDiff(diffs[i].bitrateDiff, diffs[j].bitrateDiff) + }) + + bestDiff := diffs[0] + return manager.streams[bestDiff.id] +} diff --git a/internal/capture/streamsink.go b/internal/capture/streamsink.go index 8df41a54..99703bba 100644 --- a/internal/capture/streamsink.go +++ b/internal/capture/streamsink.go @@ -21,9 +21,10 @@ import ( var moveSinkListenerMu = sync.Mutex{} type StreamSinkManagerCtx struct { - id string - getBitrate func() (int, error) - waitForKf bool // wait for a keyframe before sending samples + id string + + // wait for a keyframe before sending samples + waitForKf bool bitrate uint64 // atomic brBuckets map[int]float64 @@ -48,22 +49,23 @@ type StreamSinkManagerCtx struct { pipelinesActive prometheus.Gauge } -func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id string, getBitrate func() (int, error)) *StreamSinkManagerCtx { +func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), id string) *StreamSinkManagerCtx { logger := log.With(). Str("module", "capture"). Str("submodule", "stream-sink"). Str("id", id).Logger() manager := &StreamSinkManagerCtx{ - id: id, - getBitrate: getBitrate, - // only wait for keyframes if the codec is video - waitForKf: c.IsVideo(), + id: id, + // only wait for keyframes if the codec is video + waitForKf: codec.IsVideo(), + + bitrate: 0, brBuckets: map[int]float64{}, logger: logger, - codec: c, + codec: codec, pipelineFn: pipelineFn, listeners: map[uintptr]types.SampleListener{}, @@ -77,8 +79,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin Help: "Current number of listeners for a pipeline.", ConstLabels: map[string]string{ "video_id": id, - "codec_name": c.Name, - "codec_type": c.Type.String(), + "codec_name": codec.Name, + "codec_type": codec.Type.String(), }, }), totalBytes: promauto.NewCounter(prometheus.CounterOpts{ @@ -88,8 +90,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin Help: "Total number of bytes created by the pipeline.", ConstLabels: map[string]string{ "video_id": id, - "codec_name": c.Name, - "codec_type": c.Type.String(), + "codec_name": codec.Name, + "codec_type": codec.Type.String(), }, }), pipelinesCounter: promauto.NewCounter(prometheus.CounterOpts{ @@ -100,8 +102,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin ConstLabels: map[string]string{ "submodule": "streamsink", "video_id": id, - "codec_name": c.Name, - "codec_type": c.Type.String(), + "codec_name": codec.Name, + "codec_type": codec.Type.String(), }, }), pipelinesActive: promauto.NewGauge(prometheus.GaugeOpts{ @@ -112,8 +114,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin ConstLabels: map[string]string{ "submodule": "streamsink", "video_id": id, - "codec_name": c.Name, - "codec_type": c.Type.String(), + "codec_name": codec.Name, + "codec_type": codec.Type.String(), }, }), } @@ -141,27 +143,8 @@ func (manager *StreamSinkManagerCtx) ID() string { return manager.id } -func (manager *StreamSinkManagerCtx) Bitrate() int { - // TODO: fix bitrate switching calculation - // return real bitrate if available - //realBitrate := atomic.LoadUint64(&manager.bitrate) - //if realBitrate != 0 { - // return int(realBitrate) - //} - - // if we do not have function to estimate bitrate, return 0 - if manager.getBitrate == nil { - return 0 - } - - // recalculate bitrate every time, take screen resolution (and fps) into account - // we called this function during startup, so it shouldn't error here - bitrate, err := manager.getBitrate() - if err != nil { - manager.logger.Err(err).Msg("unexpected error while getting bitrate") - } - - return bitrate +func (manager *StreamSinkManagerCtx) Bitrate() uint64 { + return atomic.LoadUint64(&manager.bitrate) } func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec { diff --git a/internal/config/webrtc.go b/internal/config/webrtc.go index 54a40b5b..40fad4bd 100644 --- a/internal/config/webrtc.go +++ b/internal/config/webrtc.go @@ -3,6 +3,7 @@ package config import ( "strconv" "strings" + "time" "github.com/rs/zerolog/log" "github.com/spf13/cobra" @@ -15,6 +16,28 @@ import ( // default stun server const defStunSrv = "stun:stun.l.google.com:19302" +type WebRTCEstimator struct { + Enabled bool + Passive bool + Debug bool + InitialBitrate int + + // how often to read and process bandwidth estimation reports + ReadInterval time.Duration + // how long to wait for stable connection (only neutral or upward trend) before upgrading + StableDuration time.Duration + // how long to wait for unstable connection (downward trend) before downgrading + UnstableDuration time.Duration + // how long to wait for stalled connection (neutral trend with low bandwidth) before downgrading + StalledDuration time.Duration + // how long to wait before downgrading again after previous downgrade + DowngradeBackoff time.Duration + // how long to wait before upgrading again after previous upgrade + UpgradeBackoff time.Duration + // how bigger the difference between estimated and stream bitrate must be to trigger upgrade/downgrade + DiffThreshold float64 +} + type WebRTC struct { ICELite bool ICETrickle bool @@ -28,9 +51,7 @@ type WebRTC struct { NAT1To1IPs []string IpRetrievalUrl string - EstimatorEnabled bool - EstimatorPassive bool - EstimatorInitialBitrate int + Estimator WebRTCEstimator } func (WebRTC) Init(cmd *cobra.Command) error { @@ -96,11 +117,51 @@ func (WebRTC) Init(cmd *cobra.Command) error { return err } + cmd.PersistentFlags().Bool("webrtc.estimator.debug", false, "enables debug logging for the bandwidth estimator") + if err := viper.BindPFlag("webrtc.estimator.debug", cmd.PersistentFlags().Lookup("webrtc.estimator.debug")); err != nil { + return err + } + cmd.PersistentFlags().Int("webrtc.estimator.initial_bitrate", 1_000_000, "initial bitrate for the bandwidth estimator") if err := viper.BindPFlag("webrtc.estimator.initial_bitrate", cmd.PersistentFlags().Lookup("webrtc.estimator.initial_bitrate")); err != nil { return err } + cmd.PersistentFlags().Duration("webrtc.estimator.read_interval", 2*time.Second, "how often to read and process bandwidth estimation reports") + if err := viper.BindPFlag("webrtc.estimator.read_interval", cmd.PersistentFlags().Lookup("webrtc.estimator.read_interval")); err != nil { + return err + } + + cmd.PersistentFlags().Duration("webrtc.estimator.stable_duration", 12*time.Second, "how long to wait for stable connection (upward or neutral trend) before upgrading") + if err := viper.BindPFlag("webrtc.estimator.stable_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.stable_duration")); err != nil { + return err + } + + cmd.PersistentFlags().Duration("webrtc.estimator.unstable_duration", 6*time.Second, "how long to wait for stalled connection (neutral trend with low bandwidth) before downgrading") + if err := viper.BindPFlag("webrtc.estimator.unstable_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.unstable_duration")); err != nil { + return err + } + + cmd.PersistentFlags().Duration("webrtc.estimator.stalled_duration", 24*time.Second, "how long to wait for stalled bandwidth estimation before downgrading") + if err := viper.BindPFlag("webrtc.estimator.stalled_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.stalled_duration")); err != nil { + return err + } + + cmd.PersistentFlags().Duration("webrtc.estimator.downgrade_backoff", 10*time.Second, "how long to wait before downgrading again after previous downgrade") + if err := viper.BindPFlag("webrtc.estimator.downgrade_backoff", cmd.PersistentFlags().Lookup("webrtc.estimator.downgrade_backoff")); err != nil { + return err + } + + cmd.PersistentFlags().Duration("webrtc.estimator.upgrade_backoff", 5*time.Second, "how long to wait before upgrading again after previous upgrade") + if err := viper.BindPFlag("webrtc.estimator.upgrade_backoff", cmd.PersistentFlags().Lookup("webrtc.estimator.upgrade_backoff")); err != nil { + return err + } + + cmd.PersistentFlags().Float64("webrtc.estimator.diff_threshold", 0.15, "how bigger the difference between estimated and stream bitrate must be to trigger upgrade/downgrade") + if err := viper.BindPFlag("webrtc.estimator.diff_threshold", cmd.PersistentFlags().Lookup("webrtc.estimator.diff_threshold")); err != nil { + return err + } + return nil } @@ -197,7 +258,15 @@ func (s *WebRTC) Set() { // bandwidth estimator - s.EstimatorEnabled = viper.GetBool("webrtc.estimator.enabled") - s.EstimatorPassive = viper.GetBool("webrtc.estimator.passive") - s.EstimatorInitialBitrate = viper.GetInt("webrtc.estimator.initial_bitrate") + s.Estimator.Enabled = viper.GetBool("webrtc.estimator.enabled") + s.Estimator.Passive = viper.GetBool("webrtc.estimator.passive") + s.Estimator.Debug = viper.GetBool("webrtc.estimator.debug") + s.Estimator.InitialBitrate = viper.GetInt("webrtc.estimator.initial_bitrate") + s.Estimator.ReadInterval = viper.GetDuration("webrtc.estimator.read_interval") + s.Estimator.StableDuration = viper.GetDuration("webrtc.estimator.stable_duration") + s.Estimator.UnstableDuration = viper.GetDuration("webrtc.estimator.unstable_duration") + s.Estimator.StalledDuration = viper.GetDuration("webrtc.estimator.stalled_duration") + s.Estimator.DowngradeBackoff = viper.GetDuration("webrtc.estimator.downgrade_backoff") + s.Estimator.UpgradeBackoff = viper.GetDuration("webrtc.estimator.upgrade_backoff") + s.Estimator.DiffThreshold = viper.GetFloat64("webrtc.estimator.diff_threshold") } diff --git a/internal/webrtc/manager.go b/internal/webrtc/manager.go index 31c439aa..e98f0b8e 100644 --- a/internal/webrtc/manager.go +++ b/internal/webrtc/manager.go @@ -24,6 +24,7 @@ import ( "github.com/demodesk/neko/pkg/types/codec" "github.com/demodesk/neko/pkg/types/event" "github.com/demodesk/neko/pkg/types/message" + "github.com/demodesk/neko/pkg/utils" ) const ( @@ -167,7 +168,7 @@ func (manager *WebRTCManagerCtx) ICEServers() []types.ICEServer { return manager.config.ICEServersFrontend } -func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs []codec.RTPCodec, bitrate int) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) { +func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs []codec.RTPCodec) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) { // create media engine engine := &webrtc.MediaEngine{} for _, codec := range codecs { @@ -223,14 +224,10 @@ func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs // create bandwidth estimator estimatorChan := make(chan cc.BandwidthEstimator, 1) - if manager.config.EstimatorEnabled { + if manager.config.Estimator.Enabled { congestionController, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) { - if bitrate == 0 { - bitrate = manager.config.EstimatorInitialBitrate - } - return gcc.NewSendSideBWE( - gcc.SendSideBWEInitialBitrate(bitrate), + gcc.SendSideBWEInitialBitrate(manager.config.Estimator.InitialBitrate), gcc.SendSideBWEPacer(gcc.NewNoOpPacer()), ) }) @@ -268,7 +265,7 @@ func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs return connection, <-estimatorChan, err } -func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int, videoAuto bool) (*webrtc.SessionDescription, error) { +func (manager *WebRTCManagerCtx) CreatePeer(session types.Session) (*webrtc.SessionDescription, types.WebRTCPeer, error) { id := atomic.AddInt32(&manager.peerId, 1) // get metrics for session @@ -287,12 +284,10 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int, video := manager.capture.Video() videoCodec := video.Codec() - connection, estimator, err := manager.newPeerConnection(logger, []codec.RTPCodec{ - audioCodec, - videoCodec, - }, bitrate) + connection, estimator, err := manager.newPeerConnection( + logger, []codec.RTPCodec{audioCodec, videoCodec}) if err != nil { - return nil, err + return nil, nil, err } // asynchronously send local ICE Candidates @@ -311,47 +306,34 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int, }) } - // if bitrate is 0, and estimator is enabled, use estimator bitrate - if bitrate == 0 && estimator != nil { - bitrate = estimator.GetTargetBitrate() - } - // audio track audioTrack, err := NewTrack(logger, audioCodec, connection) if err != nil { - return nil, err + return nil, nil, err } // set stream for audio track _, err = audioTrack.SetStream(audio) if err != nil { - return nil, err + return nil, nil, err } - // if estimator is disabled, or in passive mode, disable video auto bitrate - if !manager.config.EstimatorEnabled || manager.config.EstimatorPassive { - videoAuto = false - } - - videoRtcp := make(chan []rtcp.Packet, 1) - // video track - videoTrack, err := NewTrack(logger, videoCodec, connection, - WithVideoAuto(videoAuto), - WithRtcpChan(videoRtcp), - ) + videoRtcp := make(chan []rtcp.Packet, 1) + videoTrack, err := NewTrack(logger, videoCodec, connection, WithRtcpChan(videoRtcp)) if err != nil { - return nil, err + return nil, nil, err } - // let video stream bucket manager handle stream subscriptions - video.SetReceiver(videoTrack) + // + // stream for video track will be set later + // // data channel dataChannel, err := connection.CreateDataChannel("data", nil) if err != nil { - return nil, err + return nil, nil, err } peer := &WebRTCPeerCtx{ @@ -359,24 +341,29 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int, session: session, metrics: metrics, connection: connection, - estimator: estimator, + // bandwidth estimator + estimator: estimator, + estimateTrend: utils.NewTrendDetector( + utils.TrendDetectorParams{ + // Probing + //RequiredSamples: 3, + //DownwardTrendThreshold: 0.0, + //CollapseValues: false, + // Non-Probing + RequiredSamples: 8, + DownwardTrendThreshold: -0.5, + CollapseValues: true, + }), + // stream selectors + videoSelector: manager.capture.Video(), // tracks & channels audioTrack: audioTrack, videoTrack: videoTrack, dataChannel: dataChannel, rtcpChannel: videoRtcp, // config - iceTrickle: manager.config.ICETrickle, - estimatorPassive: manager.config.EstimatorPassive, - } - - logger.Info(). - Int("target_bitrate", bitrate). - Msg("estimated initial peer bitrate") - - // set initial video bitrate - if err := peer.SetVideoBitrate(bitrate); err != nil { - return nil, err + iceTrickle: manager.config.ICETrickle, + estimatorConfig: manager.config.Estimator, } connection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { @@ -492,9 +479,9 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int, // ensure we only run this once once.Do(func() { session.SetWebRTCConnected(peer, false) - if err = video.RemoveReceiver(videoTrack); err != nil { - logger.Err(err).Msg("failed to remove video receiver") - } + // + // TODO: Shutdown peer? + // audioTrack.Shutdown() videoTrack.Shutdown() close(videoRtcp) @@ -542,7 +529,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int, offer, err := peer.CreateOffer(false) if err != nil { - return nil, err + return nil, nil, err } // on negotiation needed handler must be registered after creating initial @@ -576,7 +563,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int, // start estimator reader go peer.estimatorReader() - return offer, nil + return offer, peer, nil } func (manager *WebRTCManagerCtx) SetCursorPosition(x, y int) { diff --git a/internal/webrtc/metrics.go b/internal/webrtc/metrics.go index 79cbd620..cc212575 100644 --- a/internal/webrtc/metrics.go +++ b/internal/webrtc/metrics.go @@ -121,7 +121,7 @@ func (m *metricsManager) getBySession(session types.Session) *metrics { Name: "receiver_estimated_maximum_bitrate", Namespace: "neko", Subsystem: "webrtc", - Help: "Receiver Estimated Maximum Bitrate from SCTP.", + Help: "Receiver Estimated Maximum Bitrate from RTCP.", ConstLabels: map[string]string{ "session_id": sessionId, }, @@ -140,7 +140,7 @@ func (m *metricsManager) getBySession(session types.Session) *metrics { Name: "receiver_report_delay", Namespace: "neko", Subsystem: "webrtc", - Help: "Receiver Report Delay from SCTP, expressed in units of 1/65536 seconds.", + Help: "Receiver Report Delay from RTCP, expressed in units of 1/65536 seconds.", ConstLabels: map[string]string{ "session_id": sessionId, }, @@ -149,7 +149,7 @@ func (m *metricsManager) getBySession(session types.Session) *metrics { Name: "receiver_report_jitter", Namespace: "neko", Subsystem: "webrtc", - Help: "Receiver Report Jitter from SCTP.", + Help: "Receiver Report Jitter from RTCP.", ConstLabels: map[string]string{ "session_id": sessionId, }, @@ -158,7 +158,17 @@ func (m *metricsManager) getBySession(session types.Session) *metrics { Name: "receiver_report_total_lost", Namespace: "neko", Subsystem: "webrtc", - Help: "Receiver Report Total Lost from SCTP.", + Help: "Receiver Report Total Lost from RTCP.", + ConstLabels: map[string]string{ + "session_id": sessionId, + }, + }), + + transportLayerNacks: promauto.NewCounter(prometheus.CounterOpts{ + Name: "transport_layer_nacks", + Namespace: "neko", + Subsystem: "webrtc", + Help: "Transport Layer NACKs from RTCP.", ConstLabels: map[string]string{ "session_id": sessionId, }, @@ -236,6 +246,8 @@ type metrics struct { receiverReportJitter prometheus.Gauge receiverReportTotalLost prometheus.Gauge + transportLayerNacks prometheus.Counter + iceBytesSent prometheus.Gauge iceBytesReceived prometheus.Gauge sctpBytesSent prometheus.Gauge @@ -386,6 +398,11 @@ func (met *metrics) rtcpReceiver(rtcpCh chan []rtcp.Packet) { // use only last report met.SetReceiverReport(rtcpPacket.Reports[l-1]) } + case *rtcp.TransportLayerNack: + for _, pair := range rtcpPacket.Nacks { + packetList := pair.PacketList() + met.transportLayerNacks.Add(float64(len(packetList))) + } } } } diff --git a/internal/webrtc/peer.go b/internal/webrtc/peer.go index 91b405c9..590b598e 100644 --- a/internal/webrtc/peer.go +++ b/internal/webrtc/peer.go @@ -11,15 +11,12 @@ import ( "github.com/pion/webrtc/v3" "github.com/rs/zerolog" + "github.com/demodesk/neko/internal/config" "github.com/demodesk/neko/internal/webrtc/payload" "github.com/demodesk/neko/pkg/types" "github.com/demodesk/neko/pkg/types/event" "github.com/demodesk/neko/pkg/types/message" -) - -const ( - // how often to read and process bandwidth estimation reports - estimatorReadInterval = 250 * time.Millisecond + "github.com/demodesk/neko/pkg/utils" ) type WebRTCPeerCtx struct { @@ -28,15 +25,20 @@ type WebRTCPeerCtx struct { session types.Session metrics *metrics connection *webrtc.PeerConnection - estimator cc.BandwidthEstimator + // bandwidth estimator + estimator cc.BandwidthEstimator + estimateTrend *utils.TrendDetector + // stream selectors + videoSelector types.StreamSelectorManager // tracks & channels audioTrack *Track videoTrack *Track dataChannel *webrtc.DataChannel rtcpChannel chan []rtcp.Packet // config - iceTrickle bool - estimatorPassive bool + iceTrickle bool + estimatorConfig config.WebRTCEstimator + videoAuto bool } // @@ -102,6 +104,7 @@ func (peer *WebRTCPeerCtx) SetCandidate(candidate webrtc.ICECandidateInit) error return peer.connection.AddICECandidate(candidate) } +// TODO: Add shutdown function? func (peer *WebRTCPeerCtx) Destroy() { peer.mu.Lock() defer peer.mu.Unlock() @@ -111,32 +114,186 @@ func (peer *WebRTCPeerCtx) Destroy() { } func (peer *WebRTCPeerCtx) estimatorReader() { + conf := peer.estimatorConfig + + // if estimator is not in debug mode, use a nop logger + var debugLogger zerolog.Logger + if conf.Debug { + debugLogger = peer.logger.With().Str("component", "estimator").Logger().Level(zerolog.DebugLevel) + } else { + debugLogger = zerolog.Nop() + } + // if estimator is disabled, do nothing if peer.estimator == nil { return } // use a ticker to get current client target bitrate - ticker := time.NewTicker(estimatorReadInterval) + ticker := time.NewTicker(conf.ReadInterval) defer ticker.Stop() + // since when is the estimate stable/unstable + stableSince := time.Now() // we asume stable at start + unstableSince := time.Time{} + // since when are we neutral but cannot accomodate current bitrate + // we migt be stalled or estimator just reached zer (very bad connection) + stalledSince := time.Time{} + // when was the last upgrade/downgrade + lastUpgradeTime := time.Time{} + lastDowngradeTime := time.Time{} + for range ticker.C { targetBitrate := peer.estimator.GetTargetBitrate() peer.metrics.SetReceiverEstimatedTargetBitrate(float64(targetBitrate)) + // if peer connection is closed, stop reading if peer.connection.ConnectionState() == webrtc.PeerConnectionStateClosed { break } - if !peer.videoTrack.VideoAuto() { + // if estimation is disabled, do nothing + if !peer.videoAuto || conf.Passive { continue } - if !peer.estimatorPassive { - err := peer.SetVideoBitrate(targetBitrate) - if err != nil { - peer.logger.Warn().Err(err).Msg("failed to set video bitrate") + // get trend direction to decide if we should upgrade or downgrade + peer.estimateTrend.AddValue(int64(targetBitrate)) + direction := peer.estimateTrend.GetDirection() + + // get current stream bitrate + stream, ok := peer.videoTrack.Stream() + if !ok { + debugLogger.Warn().Msg("looks like we don't have a stream yet, skipping bitrate estimation") + continue + } + + // if stream bitrate is 0, we need to wait for some time until we get a valid value + streamId, streamBitrate := stream.ID(), stream.Bitrate() + if streamBitrate == 0 { + debugLogger.Warn().Msg("looks like stream bitrate is 0, we need to wait for some time") + continue + } + + // check whats the difference between target and stream bitrate + diff := float64(targetBitrate) / float64(streamBitrate) + + debugLogger.Info(). + Float64("diff", diff). + Int("target_bitrate", targetBitrate). + Uint64("stream_bitrate", streamBitrate). + Str("direction", direction.String()). + Msg("got bitrate from estimator") + + // if we can accomodate current stream or we are not netural anymore, + // we are not stalled so we reset the stalled time + if direction != utils.TrendDirectionNeutral || diff > 1+conf.DiffThreshold { + stalledSince = time.Now() + } + + // if we are neutral and stalled for too long, we might be congesting + stalled := direction == utils.TrendDirectionNeutral && time.Since(stalledSince) > conf.StalledDuration + if stalled { + debugLogger.Warn(). + Time("stalled_since", stalledSince). + Msgf("it looks like we are stalled") + } + + // if we have an downward trend or are stalled, we might be congesting + if direction == utils.TrendDirectionDownward || stalled { + // we reset the stable time because we are congesting + stableSince = time.Now() + + // if we downgraded recently, we wait for some more time + if time.Since(lastDowngradeTime) < conf.DowngradeBackoff { + debugLogger.Debug(). + Time("last_downgrade", lastDowngradeTime). + Msgf("downgraded recently, waiting for at least %v", conf.DowngradeBackoff) + continue } + + // if we are not unstable but we fluctuate we should wait for some more time + if time.Since(unstableSince) < conf.UnstableDuration { + debugLogger.Debug(). + Time("unstable_since", unstableSince). + Msgf("we are not unstable long enough, waiting for at least %v", conf.UnstableDuration) + continue + } + + // if we still have a big difference between target and stream bitrate, we wait for some more time + if conf.DiffThreshold >= 0 && diff > 1+conf.DiffThreshold { + debugLogger.Debug(). + Float64("diff", diff). + Float64("threshold", conf.DiffThreshold). + Msgf("we still have a big difference between target and stream bitrate, " + + "therefore we still should be able to accomodate current stream") + continue + } + + err := peer.SetVideo(types.StreamSelector{ + ID: streamId, + Type: types.StreamSelectorTypeLower, + }) + if err != nil && err != types.ErrWebRTCStreamNotFound { + peer.logger.Warn().Err(err).Msg("failed to downgrade video stream") + } + lastDowngradeTime = time.Now() + + if err == types.ErrWebRTCStreamNotFound { + debugLogger.Info().Msg("looks like we are already on the lowest stream") + } else { + debugLogger.Info().Msg("downgraded video stream") + } + continue + } + + // we reset the unstable time because we are not congesting + unstableSince = time.Now() + + // if we have a neutral or upward trend, that means our estimate is stable + // if we are on the highest stream, we don't need to do anything + // but if there is a higher stream, we should try to upgrade and see if it works + + // if we upgraded recently, we wait for some more time + if time.Since(lastUpgradeTime) < conf.UpgradeBackoff { + debugLogger.Debug(). + Time("last_upgrade", lastUpgradeTime). + Msgf("upgraded recently, waiting for at least %v", conf.UpgradeBackoff) + continue + } + + // if we are not stable for long enough, we wait for some more time + // because bandwidth estimation might fluctuate + if time.Since(stableSince) < conf.StableDuration { + debugLogger.Debug(). + Time("stable_since", stableSince). + Msgf("we are not stable long enough, waiting for at least %v", conf.StableDuration) + continue + } + + // upgrade only if estimated bitrate passed the threshold + if conf.DiffThreshold >= 0 && diff < 1+conf.DiffThreshold { + debugLogger.Debug(). + Float64("diff", diff). + Float64("threshold", conf.DiffThreshold). + Msgf("looks like we don't have enough bitrate to accomodate higher stream, " + + "therefore we should wait for some more time") + continue + } + + err := peer.SetVideo(types.StreamSelector{ + ID: streamId, + Type: types.StreamSelectorTypeHigher, + }) + if err != nil && err != types.ErrWebRTCStreamNotFound { + peer.logger.Warn().Err(err).Msg("failed to upgrade video stream") + } + lastUpgradeTime = time.Now() + + if err == types.ErrWebRTCStreamNotFound { + debugLogger.Info().Msg("looks like we are already on the highest stream") + } else { + debugLogger.Info().Msg("upgraded video stream") } } } @@ -145,88 +302,52 @@ func (peer *WebRTCPeerCtx) estimatorReader() { // video // -func (peer *WebRTCPeerCtx) SetVideoBitrate(peerBitrate int) error { +func (peer *WebRTCPeerCtx) SetVideo(selector types.StreamSelector) error { peer.mu.Lock() defer peer.mu.Unlock() - // when switching from manual to auto bitrate estimation, in case the estimator is - // idle (lastBitrate > maxBitrate), we want to go back to the previous estimated bitrate - if peerBitrate == 0 && peer.estimator != nil && !peer.estimatorPassive { - peerBitrate = peer.estimator.GetTargetBitrate() - peer.logger.Debug(). - Int("peer_bitrate", peerBitrate). - Msg("evaluated bitrate") + // get requested video stream from selector + stream, ok := peer.videoSelector.GetStream(selector) + if !ok { + return types.ErrWebRTCStreamNotFound } - changed, err := peer.videoTrack.SetBitrate(peerBitrate) + // set video stream to track + changed, err := peer.videoTrack.SetStream(stream) if err != nil { return err } + // if video stream was already set, do nothing if !changed { - // TODO: return error? return nil } - videoID := peer.videoTrack.stream.ID() - bitrate := peer.videoTrack.stream.Bitrate() - + videoID := stream.ID() peer.metrics.SetVideoID(videoID) - peer.logger.Debug(). - Int("peer_bitrate", peerBitrate). - Int("video_bitrate", bitrate). - Str("video_id", videoID). - Msg("peer bitrate triggered video stream change") + + peer.logger.Info().Str("video_id", videoID).Msg("set video") go peer.session.Send( event.SIGNAL_VIDEO, message.SignalVideo{ - Video: videoID, - Bitrate: bitrate, - VideoAuto: peer.videoTrack.VideoAuto(), + Video: videoID, + Auto: peer.videoAuto, }) return nil } -func (peer *WebRTCPeerCtx) SetVideoID(videoID string) error { +func (peer *WebRTCPeerCtx) VideoID() (string, bool) { peer.mu.Lock() defer peer.mu.Unlock() - changed, err := peer.videoTrack.SetVideoID(videoID) - if err != nil { - return err + stream, ok := peer.videoTrack.Stream() + if !ok { + return "", false } - if !changed { - // TODO: return error? - return nil - } - - bitrate := peer.videoTrack.stream.Bitrate() - - peer.logger.Debug(). - Str("video_id", videoID). - Int("video_bitrate", bitrate). - Msg("peer video id triggered video stream change") - - go peer.session.Send( - event.SIGNAL_VIDEO, - message.SignalVideo{ - Video: videoID, - Bitrate: bitrate, - VideoAuto: peer.videoTrack.VideoAuto(), - }) - - return nil -} - -func (peer *WebRTCPeerCtx) GetVideoID() string { - peer.mu.Lock() - defer peer.mu.Unlock() - - // TODO: Refactor. - return peer.videoTrack.stream.ID() + return stream.ID(), true } func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error { @@ -239,18 +360,32 @@ func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error { return nil } +func (peer *WebRTCPeerCtx) Paused() bool { + peer.mu.Lock() + defer peer.mu.Unlock() + + return peer.videoTrack.Paused() || peer.audioTrack.Paused() +} + func (peer *WebRTCPeerCtx) SetVideoAuto(videoAuto bool) { + peer.mu.Lock() + defer peer.mu.Unlock() + // if estimator is enabled and is not passive, enable video auto bitrate - if peer.estimator != nil && !peer.estimatorPassive { - peer.videoTrack.SetVideoAuto(videoAuto) + if peer.estimator != nil && !peer.estimatorConfig.Passive { + peer.logger.Info().Bool("video_auto", videoAuto).Msg("set video auto") + peer.videoAuto = videoAuto } else { peer.logger.Warn().Msg("estimator is disabled or in passive mode, cannot change video auto") - peer.videoTrack.SetVideoAuto(false) // ensure video auto is disabled + peer.videoAuto = false // ensure video auto is disabled } } func (peer *WebRTCPeerCtx) VideoAuto() bool { - return peer.videoTrack.VideoAuto() + peer.mu.Lock() + defer peer.mu.Unlock() + + return peer.videoAuto } // diff --git a/internal/webrtc/track.go b/internal/webrtc/track.go index dd528bda..f18122e2 100644 --- a/internal/webrtc/track.go +++ b/internal/webrtc/track.go @@ -2,7 +2,6 @@ package webrtc import ( "errors" - "fmt" "io" "sync" @@ -22,25 +21,13 @@ type Track struct { rtcpCh chan []rtcp.Packet sample chan types.Sample - videoAuto bool - videoAutoMu sync.RWMutex - paused bool stream types.StreamSinkManager streamMu sync.Mutex - - bitrateChange func(int) (bool, error) - videoChange func(string) (bool, error) } type trackOption func(*Track) -func WithVideoAuto(auto bool) trackOption { - return func(t *Track) { - t.videoAuto = auto - } -} - func WithRtcpChan(rtcp chan []rtcp.Packet) trackOption { return func(t *Track) { t.rtcpCh = rtcp @@ -100,6 +87,8 @@ func (t *Track) rtcpReader(sender *webrtc.RTPSender) { } } +// --- sample --- + func (t *Track) sampleReader() { for { sample, ok := <-t.sample @@ -120,6 +109,12 @@ func (t *Track) sampleReader() { } } +func (t *Track) WriteSample(sample types.Sample) { + t.sample <- sample +} + +// --- stream --- + func (t *Track) SetStream(stream types.StreamSinkManager) (bool, error) { t.streamMu.Lock() defer t.streamMu.Unlock() @@ -167,6 +162,15 @@ func (t *Track) RemoveStream() { t.stream = nil } +func (t *Track) Stream() (types.StreamSinkManager, bool) { + t.streamMu.Lock() + defer t.streamMu.Unlock() + + return t.stream, t.stream != nil +} + +// --- paused --- + func (t *Track) SetPaused(paused bool) { t.streamMu.Lock() defer t.streamMu.Unlock() @@ -190,42 +194,9 @@ func (t *Track) SetPaused(paused bool) { t.paused = paused } -func (t *Track) WriteSample(sample types.Sample) { - t.sample <- sample -} - -func (t *Track) SetBitrate(bitrate int) (bool, error) { - if t.bitrateChange == nil { - return false, fmt.Errorf("bitrate change not supported") - } - - return t.bitrateChange(bitrate) -} - -func (t *Track) SetVideoID(videoID string) (bool, error) { - if t.videoChange == nil { - return false, fmt.Errorf("video change not supported") - } - - return t.videoChange(videoID) -} - -func (t *Track) OnBitrateChange(f func(bitrate int) (bool, error)) { - t.bitrateChange = f -} - -func (t *Track) OnVideoChange(f func(string) (bool, error)) { - t.videoChange = f -} - -func (t *Track) SetVideoAuto(auto bool) { - t.videoAutoMu.Lock() - defer t.videoAutoMu.Unlock() - t.videoAuto = auto -} - -func (t *Track) VideoAuto() bool { - t.videoAutoMu.RLock() - defer t.videoAutoMu.RUnlock() - return t.videoAuto +func (t *Track) Paused() bool { + t.streamMu.Lock() + defer t.streamMu.Unlock() + + return t.paused } diff --git a/internal/websocket/handler/signal.go b/internal/websocket/handler/signal.go index f25f2bc1..0fe6918a 100644 --- a/internal/websocket/handler/signal.go +++ b/internal/websocket/handler/signal.go @@ -20,28 +20,26 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag payload.Video = videos[0] } - var err error - if payload.Bitrate == 0 { - // get bitrate from video id - payload.Bitrate, err = h.capture.GetBitrateFromVideoID(payload.Video) - if err != nil { - return err - } - } - - offer, err := h.webrtc.CreatePeer(session, payload.Bitrate, payload.VideoAuto) + offer, peer, err := h.webrtc.CreatePeer(session) if err != nil { return err } - if webrtcPeer := session.GetWebRTCPeer(); webrtcPeer != nil { - // set webrtc as paused if session has private mode enabled - if session.PrivateModeEnabled() { - webrtcPeer.SetPaused(true) - } + // set webrtc as paused if session has private mode enabled + if session.PrivateModeEnabled() { + peer.SetPaused(true) + } - payload.Video = webrtcPeer.GetVideoID() - payload.VideoAuto = webrtcPeer.VideoAuto() + // set video auto state + peer.SetVideoAuto(payload.Auto) + + // set video stream + err = peer.SetVideo(types.StreamSelector{ + ID: payload.Video, + Type: types.StreamSelectorTypeNearest, + }) + if err != nil { + return err } session.Send( @@ -49,9 +47,6 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag message.SignalProvide{ SDP: offer.SDP, ICEServers: h.webrtc.ICEServers(), - Video: payload.Video, // TODO: Refactor - Bitrate: payload.Bitrate, - VideoAuto: payload.VideoAuto, }) return nil @@ -133,16 +128,13 @@ func (h *MessageHandlerCtx) signalVideo(session types.Session, payload *message. return errors.New("webRTC peer does not exist") } - peer.SetVideoAuto(payload.VideoAuto) + peer.SetVideoAuto(payload.Auto) if payload.Video != "" { - if err := peer.SetVideoID(payload.Video); err != nil { - h.logger.Error().Err(err).Msg("failed to set video id") - } - } else { - if err := peer.SetVideoBitrate(payload.Bitrate); err != nil { - h.logger.Error().Err(err).Msg("failed to set video bitrate") - } + return peer.SetVideo(types.StreamSelector{ + ID: payload.Video, + Type: types.StreamSelectorTypeNearest, + }) } return nil diff --git a/pkg/types/capture.go b/pkg/types/capture.go index 485161a4..140cde44 100644 --- a/pkg/types/capture.go +++ b/pkg/types/capture.go @@ -31,26 +31,6 @@ type SampleListener interface { WriteSample(Sample) } -type Receiver interface { - SetStream(stream StreamSinkManager) (changed bool, err error) - RemoveStream() - OnBitrateChange(f func(bitrate int) (changed bool, err error)) - OnVideoChange(f func(videoID string) (changed bool, err error)) - VideoAuto() bool - SetVideoAuto(videoAuto bool) -} - -type BucketsManager interface { - IDs() []string - Codec() codec.RTPCodec - SetReceiver(receiver Receiver) - RemoveReceiver(receiver Receiver) error - - DestroyAll() - RecreateAll() error - Shutdown() -} - type BroadcastManager interface { Start(url string) error Stop() @@ -64,10 +44,74 @@ type ScreencastManager interface { Image() ([]byte, error) } +type StreamSelectorType int + +const ( + // select exact stream + StreamSelectorTypeExact StreamSelectorType = iota + // select nearest stream (in either direction) if exact stream is not available + StreamSelectorTypeNearest + // if exact stream is found select the next lower stream, otherwise select the nearest lower stream + StreamSelectorTypeLower + // if exact stream is found select the next higher stream, otherwise select the nearest higher stream + StreamSelectorTypeHigher +) + +func (s StreamSelectorType) String() string { + switch s { + case StreamSelectorTypeExact: + return "exact" + case StreamSelectorTypeNearest: + return "nearest" + case StreamSelectorTypeLower: + return "lower" + case StreamSelectorTypeHigher: + return "higher" + default: + return fmt.Sprintf("%d", int(s)) + } +} + +func (s *StreamSelectorType) UnmarshalText(text []byte) error { + switch strings.ToLower(string(text)) { + case "exact", "": + *s = StreamSelectorTypeExact + case "nearest": + *s = StreamSelectorTypeNearest + case "lower": + *s = StreamSelectorTypeLower + case "higher": + *s = StreamSelectorTypeHigher + default: + return fmt.Errorf("invalid stream selector type: %s", string(text)) + } + return nil +} + +func (s StreamSelectorType) MarshalText() ([]byte, error) { + return []byte(s.String()), nil +} + +type StreamSelector struct { + // type of stream selector + Type StreamSelectorType + // select stream by its ID + ID string + // select stream by its bitrate + Bitrate uint64 +} + +type StreamSelectorManager interface { + IDs() []string + Codec() codec.RTPCodec + + GetStream(selector StreamSelector) (StreamSinkManager, bool) +} + type StreamSinkManager interface { ID() string Codec() codec.RTPCodec - Bitrate() int + Bitrate() uint64 AddListener(listener SampleListener) error RemoveListener(listener SampleListener) error @@ -94,12 +138,10 @@ type CaptureManager interface { Start() Shutdown() error - GetBitrateFromVideoID(videoID string) (int, error) - Broadcast() BroadcastManager Screencast() ScreencastManager Audio() StreamSinkManager - Video() BucketsManager + Video() StreamSelectorManager Webcam() StreamSrcManager Microphone() StreamSrcManager @@ -201,54 +243,3 @@ func (config *VideoConfig) GetPipeline(screen ScreenSize) (string, error) { config.GstSuffix, }[:], " "), nil } - -func (config *VideoConfig) GetBitrateFn(getScreen func() ScreenSize) func() (int, error) { - return func() (int, error) { - if config.Bitrate > 0 { - return config.Bitrate, nil - } - - screen := getScreen() - - values := map[string]any{ - "width": screen.Width, - "height": screen.Height, - "fps": screen.Rate, - } - - language := []gval.Language{ - gval.Function("round", func(args ...any) (any, error) { - return (int)(math.Round(args[0].(float64))), nil - }), - } - - // TOOD: do not read target-bitrate from pipeline, but only from config. - - // TODO: This is only for vp8. - expr, ok := config.GstParams["target-bitrate"] - if !ok { - // TODO: This is only for h264. - expr, ok = config.GstParams["bitrate"] - if !ok { - return 0, fmt.Errorf("bitrate not found") - } - } - - bitrate, err := gval.Evaluate(expr, values, language...) - if err != nil { - return 0, fmt.Errorf("failed to evaluate bitrate: %w", err) - } - - var bitrateInt int - switch val := bitrate.(type) { - case int: - bitrateInt = val - case float64: - bitrateInt = (int)(val) - default: - return 0, fmt.Errorf("bitrate is not int or float64") - } - - return bitrateInt, nil - } -} diff --git a/pkg/types/message/messages.go b/pkg/types/message/messages.go index 832b24f0..3b536b1e 100644 --- a/pkg/types/message/messages.go +++ b/pkg/types/message/messages.go @@ -48,10 +48,6 @@ type SystemDisconnect struct { type SignalProvide struct { SDP string `json:"sdp"` ICEServers []types.ICEServer `json:"iceservers"` - // TODO: Use SignalVideo struct. - Video string `json:"video"` - Bitrate int `json:"bitrate"` - VideoAuto bool `json:"video_auto"` } type SignalCandidate struct { @@ -63,9 +59,8 @@ type SignalDescription struct { } type SignalVideo struct { - Video string `json:"video"` - Bitrate int `json:"bitrate"` - VideoAuto bool `json:"video_auto"` + Video string `json:"video"` + Auto bool `json:"auto"` } ///////////////////////////// diff --git a/pkg/types/webrtc.go b/pkg/types/webrtc.go index b255e6da..0b6f026c 100644 --- a/pkg/types/webrtc.go +++ b/pkg/types/webrtc.go @@ -9,6 +9,7 @@ import ( var ( ErrWebRTCDataChannelNotFound = errors.New("webrtc data channel not found") ErrWebRTCConnectionNotFound = errors.New("webrtc connection not found") + ErrWebRTCStreamNotFound = errors.New("webrtc stream not found") ) type ICEServer struct { @@ -23,10 +24,10 @@ type WebRTCPeer interface { SetRemoteDescription(webrtc.SessionDescription) error SetCandidate(webrtc.ICECandidateInit) error - SetVideoBitrate(bitrate int) error - SetVideoID(videoID string) error - GetVideoID() string + SetVideo(StreamSelector) error + VideoID() (string, bool) SetPaused(isPaused bool) error + Paused() bool SetVideoAuto(auto bool) VideoAuto() bool @@ -42,6 +43,6 @@ type WebRTCManager interface { ICEServers() []ICEServer - CreatePeer(session Session, bitrate int, videoAuto bool) (*webrtc.SessionDescription, error) + CreatePeer(session Session) (*webrtc.SessionDescription, WebRTCPeer, error) SetCursorPosition(x, y int) } diff --git a/pkg/utils/trenddetector.go b/pkg/utils/trenddetector.go new file mode 100644 index 00000000..a07f68be --- /dev/null +++ b/pkg/utils/trenddetector.go @@ -0,0 +1,153 @@ +// From https://github.com/livekit/livekit/blob/master/pkg/sfu/streamallocator/trenddetector.go +package utils + +import ( + "fmt" + "time" +) + +// ------------------------------------------------ + +type TrendDirection int + +const ( + TrendDirectionNeutral TrendDirection = iota + TrendDirectionUpward + TrendDirectionDownward +) + +func (t TrendDirection) String() string { + switch t { + case TrendDirectionNeutral: + return "NEUTRAL" + case TrendDirectionUpward: + return "UPWARD" + case TrendDirectionDownward: + return "DOWNWARD" + default: + return fmt.Sprintf("%d", int(t)) + } +} + +// ------------------------------------------------ + +type TrendDetectorParams struct { + RequiredSamples int + DownwardTrendThreshold float64 + CollapseValues bool +} + +type TrendDetector struct { + params TrendDetectorParams + + startTime time.Time + numSamples int + values []int64 + lowestValue int64 + highestValue int64 + + direction TrendDirection +} + +func NewTrendDetector(params TrendDetectorParams) *TrendDetector { + return &TrendDetector{ + params: params, + startTime: time.Now(), + direction: TrendDirectionNeutral, + } +} + +func (t *TrendDetector) Seed(value int64) { + if len(t.values) != 0 { + return + } + + t.values = append(t.values, value) +} + +func (t *TrendDetector) AddValue(value int64) { + t.numSamples++ + if t.lowestValue == 0 || value < t.lowestValue { + t.lowestValue = value + } + if value > t.highestValue { + t.highestValue = value + } + + // ignore duplicate values + if t.params.CollapseValues && len(t.values) != 0 && t.values[len(t.values)-1] == value { + return + } + + if len(t.values) == t.params.RequiredSamples { + t.values = t.values[1:] + } + t.values = append(t.values, value) + + t.updateDirection() +} + +func (t *TrendDetector) GetLowest() int64 { + return t.lowestValue +} + +func (t *TrendDetector) GetHighest() int64 { + return t.highestValue +} + +func (t *TrendDetector) GetValues() []int64 { + return t.values +} + +func (t *TrendDetector) GetDirection() TrendDirection { + return t.direction +} + +func (t *TrendDetector) ToString() string { + now := time.Now() + elapsed := now.Sub(t.startTime).Seconds() + str := fmt.Sprintf("t: %+v|%+v|%.2fs", t.startTime.Format(time.UnixDate), now.Format(time.UnixDate), elapsed) + str += fmt.Sprintf(", v: %d|%d|%d|%+v|%.2f", t.numSamples, t.lowestValue, t.highestValue, t.values, kendallsTau(t.values)) + return str +} + +func (t *TrendDetector) updateDirection() { + if len(t.values) < t.params.RequiredSamples { + t.direction = TrendDirectionNeutral + return + } + + // using Kendall's Tau to find trend + kt := kendallsTau(t.values) + + t.direction = TrendDirectionNeutral + switch { + case kt > 0: + t.direction = TrendDirectionUpward + case kt < t.params.DownwardTrendThreshold: + t.direction = TrendDirectionDownward + } +} + +// ------------------------------------------------ + +func kendallsTau(values []int64) float64 { + concordantPairs := 0 + discordantPairs := 0 + + for i := 0; i < len(values)-1; i++ { + for j := i + 1; j < len(values); j++ { + if values[i] < values[j] { + concordantPairs++ + } else if values[i] > values[j] { + discordantPairs++ + } + } + } + + if (concordantPairs + discordantPairs) == 0 { + return 0.0 + } + + return (float64(concordantPairs) - float64(discordantPairs)) / (float64(concordantPairs) + float64(discordantPairs)) +}