Bandwidth estimator refactor (#46)

* rewrite to use stream selector.

* WIP.

* add nacks to metrics.

* add estimate trend.

* estimator based on trend detector.

* add estimator unstable duration.

* add estimator debug.

* add stalled duration.

* estimator move values to config.

* change default estimator values.

* minor style changes.

* fix websocket video messages.

* replace video track with video id.
This commit is contained in:
Miroslav Šedivý 2023-05-15 19:29:39 +02:00 committed by GitHub
parent 8660c1a256
commit 3e8d686c0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 845 additions and 775 deletions

View File

@ -1,145 +0,0 @@
package buckets
import (
"errors"
"sort"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
)
// BucketsManagerCtx groups a set of video stream sinks so that a receiver can
// be pointed at one of them, either by stream ID or by nearest bitrate.
type BucketsManagerCtx struct {
// logger is tagged with module "capture" and submodule "buckets" by BucketsNew.
logger zerolog.Logger
// codec is the value returned by Codec().
codec codec.RTPCodec
// streams maps a stream ID to its sink.
streams map[string]types.StreamSinkManager
// streamIDs is the list returned by IDs(); presumably the keys of streams in
// configuration order — TODO confirm against the caller.
streamIDs []string
}
// BucketsNew builds a BucketsManagerCtx for the given codec and stream sinks.
//
// streams maps stream IDs to their sinks and streamIDs is the list later
// returned by IDs(). Both are stored as passed in, without copying.
func BucketsNew(codec codec.RTPCodec, streams map[string]types.StreamSinkManager, streamIDs []string) *BucketsManagerCtx {
	manager := &BucketsManagerCtx{
		codec:     codec,
		streams:   streams,
		streamIDs: streamIDs,
	}

	manager.logger = log.With().
		Str("module", "capture").
		Str("submodule", "buckets").
		Logger()

	return manager
}
// Shutdown logs the shutdown and then destroys the pipelines of all started
// streams via DestroyAll.
func (manager *BucketsManagerCtx) Shutdown() {
	// no format arguments are involved, so the plain message call is enough
	manager.logger.Info().Msg("shutdown")

	manager.DestroyAll()
}
// DestroyAll destroys the pipeline of every stream that is currently started.
// Streams that are not started are left untouched.
func (manager *BucketsManagerCtx) DestroyAll() {
	for _, sink := range manager.streams {
		if !sink.Started() {
			continue
		}

		sink.DestroyPipeline()
	}
}
// RecreateAll calls CreatePipeline on every started stream.
//
// A types.ErrCapturePipelineAlreadyExists error is tolerated. Any other error
// stops the loop and is returned unchanged, so the remaining streams are not
// visited.
func (manager *BucketsManagerCtx) RecreateAll() error {
	for _, sink := range manager.streams {
		if !sink.Started() {
			continue
		}

		if err := sink.CreatePipeline(); err != nil {
			// an already existing pipeline is not a failure here
			if errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
				continue
			}

			return err
		}
	}

	return nil
}
// IDs returns the stream ID list the manager was created with. The internal
// slice is returned directly, not a copy.
func (manager *BucketsManagerCtx) IDs() []string {
return manager.streamIDs
}
// Codec returns the codec the manager was created with.
func (manager *BucketsManagerCtx) Codec() codec.RTPCodec {
return manager.codec
}
// SetReceiver registers the bitrate-change and video-change callbacks on
// receiver so that it is switched between the managed streams.
//
// On a bitrate change the stream nearest to the reported bitrate is selected;
// when receiver.VideoAuto() is true the reported bitrate is first smoothed
// against a history that is private to this receiver. On a video change the
// stream registered under the requested ID is selected.
//
// Both callbacks return an error, instead of handing a nil stream to
// receiver.SetStream, when no matching stream exists.
func (manager *BucketsManagerCtx) SetReceiver(receiver types.Receiver) {
	// bitrate history is per receiver
	bitrateHistory := &queue{}

	receiver.OnBitrateChange(func(peerBitrate int) (bool, error) {
		bitrate := peerBitrate
		if receiver.VideoAuto() {
			bitrate = bitrateHistory.normaliseBitrate(bitrate)
		}

		stream := manager.findNearestStream(bitrate)
		if stream == nil {
			// nothing to switch to; calling stream.ID() below would panic
			return false, errors.New("no video stream available")
		}

		// TODO: make this less noisy in logs
		manager.logger.Debug().
			Str("video_id", stream.ID()).
			Int("len", bitrateHistory.len()).
			Int("peer_bitrate", peerBitrate).
			Int("bitrate", bitrate).
			Msg("change video bitrate")

		return receiver.SetStream(stream)
	})

	receiver.OnVideoChange(func(videoID string) (bool, error) {
		manager.logger.Info().
			Str("video_id", videoID).
			Msg("video change")

		// a missing key yields a nil interface value, which must not reach SetStream
		stream, ok := manager.streams[videoID]
		if !ok || stream == nil {
			return false, errors.New("video stream not found")
		}

		return receiver.SetStream(stream)
	})
}
// findNearestStream picks the stream whose bitrate is closest to peerBitrate,
// preferring streams that do not exceed it.
//
// Streams at or below peerBitrate always win over streams above it; within
// each group the one with the smallest distance wins. Equal distances are
// ordered by stream ID so the result does not depend on map iteration order.
// It returns nil when the manager has no streams.
func (manager *BucketsManagerCtx) findNearestStream(peerBitrate int) types.StreamSinkManager {
	type streamDiff struct {
		id          string
		bitrateDiff int
	}

	// nothing to choose from; indexing the sorted slice below would panic
	if len(manager.streams) == 0 {
		return nil
	}

	diffs := make([]streamDiff, 0, len(manager.streams))
	for _, stream := range manager.streams {
		diffs = append(diffs, streamDiff{
			id:          stream.ID(),
			bitrateDiff: peerBitrate - stream.Bitrate(),
		})
	}

	// strict ordering: non-negative diffs ascending, then negative diffs
	// descending (closest to zero first)
	sort.Slice(diffs, func(i, j int) bool {
		a, b := diffs[i], diffs[j]

		switch {
		case a.bitrateDiff == b.bitrateDiff:
			return a.id < b.id
		case a.bitrateDiff >= 0 && b.bitrateDiff >= 0:
			return a.bitrateDiff < b.bitrateDiff
		case a.bitrateDiff < 0 && b.bitrateDiff < 0:
			return a.bitrateDiff > b.bitrateDiff
		}

		// opposite signs: the non-negative diff goes first
		return a.bitrateDiff >= 0
	})

	return manager.streams[diffs[0].id]
}
// RemoveReceiver detaches receiver from the manager: both callbacks registered
// by SetReceiver are reset to nil and the receiver's stream is removed.
// It always returns nil.
func (manager *BucketsManagerCtx) RemoveReceiver(receiver types.Receiver) error {
// unregister the callbacks before the stream itself is removed
receiver.OnBitrateChange(nil)
receiver.OnVideoChange(nil)
receiver.RemoveStream()
return nil
}

View File

@ -1,83 +0,0 @@
package buckets
import (
"reflect"
"testing"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
)
// TestBucketsManagerCtx_FindNearestStream checks that findNearestStream picks
// the closest stream that does not exceed the peer bitrate.
func TestBucketsManagerCtx_FindNearestStream(t *testing.T) {
	cases := []struct {
		name        string
		codec       codec.RTPCodec
		streams     map[string]types.StreamSinkManager
		peerBitrate int
		want        types.StreamSinkManager
	}{
		{
			name: "findNearestStream",
			streams: map[string]types.StreamSinkManager{
				"1": mockStreamSink{id: "1", bitrate: 500},
				"2": mockStreamSink{id: "2", bitrate: 750},
				"3": mockStreamSink{id: "3", bitrate: 1000},
				"4": mockStreamSink{id: "4", bitrate: 1250},
				"5": mockStreamSink{id: "5", bitrate: 1700},
			},
			peerBitrate: 950,
			// 750 is the highest bitrate that still fits into 950
			want: mockStreamSink{id: "2", bitrate: 750},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			manager := BucketsNew(tc.codec, tc.streams, []string{})

			got := manager.findNearestStream(tc.peerBitrate)
			if !reflect.DeepEqual(got, tc.want) {
				t.Errorf("findNearestStream() = %v, want %v", got, tc.want)
			}
		})
	}
}
// mockStreamSink is a test double for types.StreamSinkManager that only
// implements ID and Bitrate.
type mockStreamSink struct {
id string
bitrate int
// the embedded interface is left nil; it only makes the mock satisfy
// types.StreamSinkManager, so calling any other method would panic
types.StreamSinkManager
}
// ID returns the configured stream ID.
func (m mockStreamSink) ID() string {
return m.id
}
// Bitrate returns the configured bitrate.
func (m mockStreamSink) Bitrate() int {
return m.bitrate
}

View File

@ -1,88 +0,0 @@
package buckets
import (
"math"
"sync"
"time"
)
// queue is a mutex-guarded history of bitrate samples, oldest first.
type queue struct {
	sync.Mutex
	q []elem
}

// elem is one bitrate sample together with the time it was recorded.
type elem struct {
	created time.Time
	bitrate int
}

// push appends v to the history. Before appending, the oldest sample is
// dropped if it is older than 10 seconds; at most one sample is dropped per
// call.
func (q *queue) push(v elem) {
	q.Lock()
	defer q.Unlock()

	// if the first element is older than 10 seconds, remove it
	if len(q.q) > 0 && time.Since(q.q[0].created) > 10*time.Second {
		q.q = q.q[1:]
	}

	q.q = append(q.q, v)
}

// len returns the number of samples currently in the history.
func (q *queue) len() int {
	q.Lock()
	defer q.Unlock()

	return len(q.q)
}

// avg returns the integer average of all samples, or 0 for an empty history.
func (q *queue) avg() int {
	q.Lock()
	defer q.Unlock()

	if len(q.q) == 0 {
		return 0
	}

	sum := 0
	for _, v := range q.q {
		sum += v.bitrate
	}

	return sum / len(q.q)
}

// avgLastN returns the integer average of the n most recent samples.
//
// n <= 0 averages the whole history. An n larger than the history is clamped
// to the history length. An empty history yields 0.
func (q *queue) avgLastN(n int) int {
	if n <= 0 {
		return q.avg()
	}

	q.Lock()
	defer q.Unlock()

	if len(q.q) == 0 {
		return 0
	}

	// the caller's n comes from a separate, unlocked len() call and may exceed
	// the current history; slicing with it unclamped would panic
	if n > len(q.q) {
		n = len(q.q)
	}

	sum := 0
	for _, v := range q.q[len(q.q)-n:] {
		sum += v.bitrate
	}

	return sum / n
}

// normaliseBitrate records currentBitrate in the history and returns a
// smoothed bitrate.
//
// The average and length of the history are taken before the new sample is
// pushed. The number of recent samples to average is
// floor(currentBitrate / average * length), capped at the new history length,
// so the further currentBitrate falls below the average, the fewer samples are
// averaged. currentBitrate is returned unchanged when the previous history was
// empty, when the previous average or currentBitrate is 0, or when the
// computed sample count is 0.
func (q *queue) normaliseBitrate(currentBitrate int) int {
	avgBitrate := float64(q.avg())
	histLen := float64(q.len())

	q.push(elem{
		bitrate: currentBitrate,
		created: time.Now(),
	})

	if avgBitrate == 0 || histLen == 0 || currentBitrate == 0 {
		return currentBitrate
	}

	lastN := int(math.Floor(float64(currentBitrate) / avgBitrate * histLen))
	if lastN > q.len() {
		lastN = q.len()
	}

	if lastN == 0 {
		return currentBitrate
	}

	return q.avgLastN(lastN)
}

View File

@ -1,99 +0,0 @@
package buckets
import "testing"
// TestQueue_normaliseBitrate is the entry point discovered by `go test`.
// Queue_normaliseBitrate itself lacks the Test prefix and would otherwise
// never be executed.
func TestQueue_normaliseBitrate(t *testing.T) {
	Queue_normaliseBitrate(t)
}

// Queue_normaliseBitrate feeds a constant bitrate repeatedly into a queue
// seeded with ten samples and checks every value returned by
// normaliseBitrate.
//
// The seed samples carry a zero timestamp, so each push evicts the oldest of
// them and the history stays at ten entries for the duration of a case.
func Queue_normaliseBitrate(t *testing.T) {
	// every case starts from its own copy of the same history, because
	// normaliseBitrate mutates the queue
	seed := func() *queue {
		return &queue{
			q: []elem{
				{bitrate: 900},
				{bitrate: 750},
				{bitrate: 780},
				{bitrate: 1100},
				{bitrate: 950},
				{bitrate: 700},
				{bitrate: 800},
				{bitrate: 900},
				{bitrate: 1000},
				{bitrate: 1100},
				// avg = 898
			},
		}
	}

	tests := []struct {
		name           string
		currentBitrate int
		want           []int
	}{
		{
			name:           "normaliseBitrate: big drop",
			currentBitrate: 350,
			want:           []int{816, 700, 537, 350, 350},
		},
		{
			name:           "normaliseBitrate: small drop",
			currentBitrate: 700,
			want:           []int{878, 842, 825, 825, 812, 787, 750, 700},
		},
		{
			name:           "normaliseBitrate",
			currentBitrate: 1350,
			want:           []int{943, 1003, 1060, 1085},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			history := seed()

			for i, want := range tt.want {
				if got := history.normaliseBitrate(tt.currentBitrate); got != want {
					t.Errorf("normaliseBitrate() [%d] = %v, want %v", i, got, want)
				}
			}
		})
	}
}

View File

@ -8,7 +8,6 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/demodesk/neko/internal/capture/buckets"
"github.com/demodesk/neko/internal/config" "github.com/demodesk/neko/internal/config"
"github.com/demodesk/neko/pkg/types" "github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec" "github.com/demodesk/neko/pkg/types/codec"
@ -23,7 +22,7 @@ type CaptureManagerCtx struct {
broadcast *BroacastManagerCtx broadcast *BroacastManagerCtx
screencast *ScreencastManagerCtx screencast *ScreencastManagerCtx
audio *StreamSinkManagerCtx audio *StreamSinkManagerCtx
video types.BucketsManager video *StreamSelectorManagerCtx
// sources // sources
webcam *StreamSrcManagerCtx webcam *StreamSrcManagerCtx
@ -68,13 +67,8 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt
Str("pipeline", pipeline). Str("pipeline", pipeline).
Msg("syntax check for video stream pipeline passed") Msg("syntax check for video stream pipeline passed")
getVideoBitrate := pipelineConf.GetBitrateFn(desktop.GetScreenSize)
if _, err = getVideoBitrate(); err != nil {
logger.Panic().Err(err).Msg("unable to get video bitrate")
}
// append to videos // append to videos
videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id, getVideoBitrate) videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id)
} }
return &CaptureManagerCtx{ return &CaptureManagerCtx{
@ -140,8 +134,8 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt
"! %s "+ "! %s "+
"! appsink name=appsink", config.AudioDevice, config.AudioCodec.Pipeline, "! appsink name=appsink", config.AudioDevice, config.AudioCodec.Pipeline,
), nil ), nil
}, "audio", nil), }, "audio"),
video: buckets.BucketsNew(config.VideoCodec, videos, config.VideoIDs), video: streamSelectorNew(config.VideoCodec, videos, config.VideoIDs),
// sources // sources
webcam: streamSrcNew(config.WebcamEnabled, map[string]string{ webcam: streamSrcNew(config.WebcamEnabled, map[string]string{
@ -202,7 +196,7 @@ func (manager *CaptureManagerCtx) Start() {
} }
manager.desktop.OnBeforeScreenSizeChange(func() { manager.desktop.OnBeforeScreenSizeChange(func() {
manager.video.DestroyAll() manager.video.destroyPipelines()
if manager.broadcast.Started() { if manager.broadcast.Started() {
manager.broadcast.destroyPipeline() manager.broadcast.destroyPipeline()
@ -214,7 +208,7 @@ func (manager *CaptureManagerCtx) Start() {
}) })
manager.desktop.OnAfterScreenSizeChange(func() { manager.desktop.OnAfterScreenSizeChange(func() {
err := manager.video.RecreateAll() err := manager.video.recreatePipelines()
if err != nil { if err != nil {
manager.logger.Panic().Err(err).Msg("unable to recreate video pipelines") manager.logger.Panic().Err(err).Msg("unable to recreate video pipelines")
} }
@ -242,7 +236,7 @@ func (manager *CaptureManagerCtx) Shutdown() error {
manager.screencast.shutdown() manager.screencast.shutdown()
manager.audio.shutdown() manager.audio.shutdown()
manager.video.Shutdown() manager.video.shutdown()
manager.webcam.shutdown() manager.webcam.shutdown()
manager.microphone.shutdown() manager.microphone.shutdown()
@ -250,15 +244,6 @@ func (manager *CaptureManagerCtx) Shutdown() error {
return nil return nil
} }
// GetBitrateFromVideoID looks up the pipeline configuration registered for
// videoID and evaluates the bitrate function it builds from the desktop's
// screen size getter. An error is returned when no configuration exists for
// videoID.
func (manager *CaptureManagerCtx) GetBitrateFromVideoID(videoID string) (int, error) {
	if pipelineConf, found := manager.config.VideoPipelines[videoID]; found {
		getBitrate := pipelineConf.GetBitrateFn(manager.desktop.GetScreenSize)
		return getBitrate()
	}

	return 0, fmt.Errorf("video config not found for %s", videoID)
}
func (manager *CaptureManagerCtx) Broadcast() types.BroadcastManager { func (manager *CaptureManagerCtx) Broadcast() types.BroadcastManager {
return manager.broadcast return manager.broadcast
} }
@ -271,7 +256,7 @@ func (manager *CaptureManagerCtx) Audio() types.StreamSinkManager {
return manager.audio return manager.audio
} }
func (manager *CaptureManagerCtx) Video() types.BucketsManager { func (manager *CaptureManagerCtx) Video() types.StreamSelectorManager {
return manager.video return manager.video
} }

View File

@ -0,0 +1,206 @@
package capture
import (
"errors"
"sort"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
)
// StreamSelectorManagerCtx holds a set of video stream sinks and resolves
// stream selectors (by ID or by bitrate) to one of them.
type StreamSelectorManagerCtx struct {
// logger is tagged with module "capture" and submodule "stream-selector".
logger zerolog.Logger
// codec is the value returned by Codec().
codec codec.RTPCodec
// streams maps a stream ID to its sink.
streams map[string]types.StreamSinkManager
// streamIDs lists the stream IDs; the selection code walks it as an ordering
// from the lowest stream (index 0) to the highest.
streamIDs []string
}
// streamSelectorNew builds a StreamSelectorManagerCtx for the given codec and
// stream sinks. streams and streamIDs are stored as passed in, without
// copying.
func streamSelectorNew(codec codec.RTPCodec, streams map[string]types.StreamSinkManager, streamIDs []string) *StreamSelectorManagerCtx {
	selector := &StreamSelectorManagerCtx{
		codec:     codec,
		streams:   streams,
		streamIDs: streamIDs,
	}

	selector.logger = log.With().
		Str("module", "capture").
		Str("submodule", "stream-selector").
		Logger()

	return selector
}
// shutdown logs the shutdown and then destroys the pipelines of all started
// streams.
func (manager *StreamSelectorManagerCtx) shutdown() {
	// no format arguments are involved, so the plain message call is enough
	manager.logger.Info().Msg("shutdown")

	manager.destroyPipelines()
}
// destroyPipelines destroys the pipeline of every stream that is currently
// started. Streams that are not started are left untouched.
func (manager *StreamSelectorManagerCtx) destroyPipelines() {
	for _, sink := range manager.streams {
		if !sink.Started() {
			continue
		}

		sink.DestroyPipeline()
	}
}
// recreatePipelines calls CreatePipeline on every started stream.
//
// A types.ErrCapturePipelineAlreadyExists error is tolerated. Any other error
// stops the loop and is returned unchanged, so the remaining streams are not
// visited.
func (manager *StreamSelectorManagerCtx) recreatePipelines() error {
	for _, sink := range manager.streams {
		if !sink.Started() {
			continue
		}

		if err := sink.CreatePipeline(); err != nil {
			// an already existing pipeline is not a failure here
			if errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
				continue
			}

			return err
		}
	}

	return nil
}
// IDs returns the stream ID list the manager was created with. The internal
// slice is returned directly, not a copy.
func (manager *StreamSelectorManagerCtx) IDs() []string {
return manager.streamIDs
}
// Codec returns the codec the manager was created with.
func (manager *StreamSelectorManagerCtx) Codec() codec.RTPCodec {
return manager.codec
}
// GetStream resolves selector to one of the managed streams.
//
// Selection is done by ID when selector.ID is set, otherwise by bitrate when
// selector.Bitrate is non-zero. manager.streamIDs is treated as ordered from
// the lowest stream to the highest.
//
//   - By ID: the stream registered under that ID. With the Lower or Higher
//     type, its nearest registered neighbour below or above it in
//     manager.streamIDs is returned instead.
//   - By bitrate: with the Nearest type the result of nearestBitrate; with
//     Lower/Higher the closest started stream with a non-zero bitrate strictly
//     below/above the requested one; otherwise any stream reporting exactly
//     that bitrate.
//
// The second return value reports whether a stream was found. The returned
// stream is nil whenever it is false.
func (manager *StreamSelectorManagerCtx) GetStream(selector types.StreamSelector) (types.StreamSinkManager, bool) {
	ids := manager.streamIDs

	// select stream by ID
	if selector.ID != "" {
		switch selector.Type {
		case types.StreamSelectorTypeLower:
			// walk up from the lowest stream: the last registered stream seen
			// before reaching the selected ID is its lower neighbour
			var lastStream types.StreamSinkManager
			for _, streamID := range ids {
				if streamID == selector.ID {
					return lastStream, lastStream != nil
				}

				if stream, ok := manager.streams[streamID]; ok {
					lastStream = stream
				}
			}

			// we couldn't find a lower stream
			return nil, false

		case types.StreamSelectorTypeHigher:
			// walk down from the highest stream: the last registered stream
			// seen before reaching the selected ID is its higher neighbour
			var lastStream types.StreamSinkManager
			for i := len(ids) - 1; i >= 0; i-- {
				streamID := ids[i]
				if streamID == selector.ID {
					return lastStream, lastStream != nil
				}

				if stream, ok := manager.streams[streamID]; ok {
					lastStream = stream
				}
			}

			// we couldn't find a higher stream
			return nil, false
		}

		// select exact stream
		stream, ok := manager.streams[selector.ID]
		return stream, ok
	}

	// select stream by bitrate
	if selector.Bitrate != 0 {
		switch selector.Type {
		case types.StreamSelectorTypeNearest:
			// select stream by nearest bitrate; report a miss if there is none
			stream := manager.nearestBitrate(selector.Bitrate)
			return stream, stream != nil

		case types.StreamSelectorTypeLower:
			// start from the highest stream, and go down, until we find a lower stream
			for i := len(ids) - 1; i >= 0; i-- {
				stream, ok := manager.streams[ids[i]]
				if !ok || stream == nil {
					// ID without a registered stream
					continue
				}

				// if stream should be considered in calculation
				considered := stream.Bitrate() != 0 && stream.Started()
				if considered && stream.Bitrate() < selector.Bitrate {
					return stream, true
				}
			}

			// we couldn't find a lower stream
			return nil, false

		case types.StreamSelectorTypeHigher:
			// start from the lowest stream, and go up, until we find a higher stream
			for _, streamID := range ids {
				stream, ok := manager.streams[streamID]
				if !ok || stream == nil {
					// ID without a registered stream
					continue
				}

				// if stream should be considered in calculation
				considered := stream.Bitrate() != 0 && stream.Started()
				if considered && stream.Bitrate() > selector.Bitrate {
					return stream, true
				}
			}

			// we couldn't find a higher stream
			return nil, false
		}

		// select stream by exact bitrate
		for _, stream := range manager.streams {
			if stream.Bitrate() == selector.Bitrate {
				return stream, true
			}
		}
	}

	// we couldn't find a stream
	return nil, false
}
// nearestBitrate returns the stream whose bitrate is closest to bitrate,
// preferring streams that do not exceed it.
//
// Only streams that are started and report a non-zero bitrate take part.
// Streams at or below bitrate always win over streams above it; within each
// group the smallest distance wins, and equal distances are ordered by stream
// ID so the result does not depend on map iteration order.
//
// When no stream qualifies, the stream registered under the first (lowest)
// entry of manager.streamIDs is returned. The result is nil if there are no
// IDs or that ID has no registered stream.
//
// TODO: This is a very naive implementation, we should use a binary search instead.
func (manager *StreamSelectorManagerCtx) nearestBitrate(bitrate uint64) types.StreamSinkManager {
	type streamDiff struct {
		id          string
		bitrateDiff int
	}

	diffs := make([]streamDiff, 0, len(manager.streams))
	for _, stream := range manager.streams {
		// if stream should be considered in calculation
		considered := stream.Bitrate() != 0 && stream.Started()
		if !considered {
			continue
		}

		diffs = append(diffs, streamDiff{
			id:          stream.ID(),
			bitrateDiff: int(bitrate) - int(stream.Bitrate()),
		})
	}

	// no streams available
	if len(diffs) == 0 {
		// guard the fallback: indexing an empty ID list would panic
		if len(manager.streamIDs) == 0 {
			return nil
		}

		// return first (lowest) stream
		return manager.streams[manager.streamIDs[0]]
	}

	// strict ordering: non-negative diffs ascending, then negative diffs
	// descending (closest to zero first)
	sort.Slice(diffs, func(i, j int) bool {
		a, b := diffs[i], diffs[j]

		switch {
		case a.bitrateDiff == b.bitrateDiff:
			return a.id < b.id
		case a.bitrateDiff >= 0 && b.bitrateDiff >= 0:
			return a.bitrateDiff < b.bitrateDiff
		case a.bitrateDiff < 0 && b.bitrateDiff < 0:
			return a.bitrateDiff > b.bitrateDiff
		}

		// opposite signs: the non-negative diff goes first
		return a.bitrateDiff >= 0
	})

	return manager.streams[diffs[0].id]
}

View File

@ -21,9 +21,10 @@ import (
var moveSinkListenerMu = sync.Mutex{} var moveSinkListenerMu = sync.Mutex{}
type StreamSinkManagerCtx struct { type StreamSinkManagerCtx struct {
id string id string
getBitrate func() (int, error)
waitForKf bool // wait for a keyframe before sending samples // wait for a keyframe before sending samples
waitForKf bool
bitrate uint64 // atomic bitrate uint64 // atomic
brBuckets map[int]float64 brBuckets map[int]float64
@ -48,22 +49,23 @@ type StreamSinkManagerCtx struct {
pipelinesActive prometheus.Gauge pipelinesActive prometheus.Gauge
} }
func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id string, getBitrate func() (int, error)) *StreamSinkManagerCtx { func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), id string) *StreamSinkManagerCtx {
logger := log.With(). logger := log.With().
Str("module", "capture"). Str("module", "capture").
Str("submodule", "stream-sink"). Str("submodule", "stream-sink").
Str("id", id).Logger() Str("id", id).Logger()
manager := &StreamSinkManagerCtx{ manager := &StreamSinkManagerCtx{
id: id, id: id,
getBitrate: getBitrate,
// only wait for keyframes if the codec is video
waitForKf: c.IsVideo(),
// only wait for keyframes if the codec is video
waitForKf: codec.IsVideo(),
bitrate: 0,
brBuckets: map[int]float64{}, brBuckets: map[int]float64{},
logger: logger, logger: logger,
codec: c, codec: codec,
pipelineFn: pipelineFn, pipelineFn: pipelineFn,
listeners: map[uintptr]types.SampleListener{}, listeners: map[uintptr]types.SampleListener{},
@ -77,8 +79,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin
Help: "Current number of listeners for a pipeline.", Help: "Current number of listeners for a pipeline.",
ConstLabels: map[string]string{ ConstLabels: map[string]string{
"video_id": id, "video_id": id,
"codec_name": c.Name, "codec_name": codec.Name,
"codec_type": c.Type.String(), "codec_type": codec.Type.String(),
}, },
}), }),
totalBytes: promauto.NewCounter(prometheus.CounterOpts{ totalBytes: promauto.NewCounter(prometheus.CounterOpts{
@ -88,8 +90,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin
Help: "Total number of bytes created by the pipeline.", Help: "Total number of bytes created by the pipeline.",
ConstLabels: map[string]string{ ConstLabels: map[string]string{
"video_id": id, "video_id": id,
"codec_name": c.Name, "codec_name": codec.Name,
"codec_type": c.Type.String(), "codec_type": codec.Type.String(),
}, },
}), }),
pipelinesCounter: promauto.NewCounter(prometheus.CounterOpts{ pipelinesCounter: promauto.NewCounter(prometheus.CounterOpts{
@ -100,8 +102,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin
ConstLabels: map[string]string{ ConstLabels: map[string]string{
"submodule": "streamsink", "submodule": "streamsink",
"video_id": id, "video_id": id,
"codec_name": c.Name, "codec_name": codec.Name,
"codec_type": c.Type.String(), "codec_type": codec.Type.String(),
}, },
}), }),
pipelinesActive: promauto.NewGauge(prometheus.GaugeOpts{ pipelinesActive: promauto.NewGauge(prometheus.GaugeOpts{
@ -112,8 +114,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin
ConstLabels: map[string]string{ ConstLabels: map[string]string{
"submodule": "streamsink", "submodule": "streamsink",
"video_id": id, "video_id": id,
"codec_name": c.Name, "codec_name": codec.Name,
"codec_type": c.Type.String(), "codec_type": codec.Type.String(),
}, },
}), }),
} }
@ -141,27 +143,8 @@ func (manager *StreamSinkManagerCtx) ID() string {
return manager.id return manager.id
} }
func (manager *StreamSinkManagerCtx) Bitrate() int { func (manager *StreamSinkManagerCtx) Bitrate() uint64 {
// TODO: fix bitrate switching calculation return atomic.LoadUint64(&manager.bitrate)
// return real bitrate if available
//realBitrate := atomic.LoadUint64(&manager.bitrate)
//if realBitrate != 0 {
// return int(realBitrate)
//}
// if we do not have function to estimate bitrate, return 0
if manager.getBitrate == nil {
return 0
}
// recalculate bitrate every time, take screen resolution (and fps) into account
// we called this function during startup, so it shouldn't error here
bitrate, err := manager.getBitrate()
if err != nil {
manager.logger.Err(err).Msg("unexpected error while getting bitrate")
}
return bitrate
} }
func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec { func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec {

View File

@ -3,6 +3,7 @@ package config
import ( import (
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -15,6 +16,28 @@ import (
// default stun server // default stun server
const defStunSrv = "stun:stun.l.google.com:19302" const defStunSrv = "stun:stun.l.google.com:19302"
// WebRTCEstimator holds the bandwidth estimator settings. The fields are
// filled from the webrtc.estimator.* configuration keys.
type WebRTCEstimator struct {
// whether a bandwidth estimator is created for new peer connections
Enabled bool
// value of webrtc.estimator.passive; its effect is not visible in this file
Passive bool
// enables debug logging for the bandwidth estimator
Debug bool
// initial bitrate handed to the bandwidth estimator
InitialBitrate int
// how often to read and process bandwidth estimation reports
ReadInterval time.Duration
// how long to wait for a stable connection (only neutral or upward trend) before upgrading
StableDuration time.Duration
// how long to wait for an unstable connection (downward trend) before downgrading
UnstableDuration time.Duration
// how long to wait for a stalled connection (neutral trend with low bandwidth) before downgrading
StalledDuration time.Duration
// how long to wait before downgrading again after a previous downgrade
DowngradeBackoff time.Duration
// how long to wait before upgrading again after a previous upgrade
UpgradeBackoff time.Duration
// how large the difference between the estimated and the stream bitrate must
// be to trigger an upgrade/downgrade; presumably a fraction of the stream
// bitrate (the default is 0.15) — TODO confirm against the estimator code
DiffThreshold float64
}
type WebRTC struct { type WebRTC struct {
ICELite bool ICELite bool
ICETrickle bool ICETrickle bool
@ -28,9 +51,7 @@ type WebRTC struct {
NAT1To1IPs []string NAT1To1IPs []string
IpRetrievalUrl string IpRetrievalUrl string
EstimatorEnabled bool Estimator WebRTCEstimator
EstimatorPassive bool
EstimatorInitialBitrate int
} }
func (WebRTC) Init(cmd *cobra.Command) error { func (WebRTC) Init(cmd *cobra.Command) error {
@ -96,11 +117,51 @@ func (WebRTC) Init(cmd *cobra.Command) error {
return err return err
} }
cmd.PersistentFlags().Bool("webrtc.estimator.debug", false, "enables debug logging for the bandwidth estimator")
if err := viper.BindPFlag("webrtc.estimator.debug", cmd.PersistentFlags().Lookup("webrtc.estimator.debug")); err != nil {
return err
}
cmd.PersistentFlags().Int("webrtc.estimator.initial_bitrate", 1_000_000, "initial bitrate for the bandwidth estimator") cmd.PersistentFlags().Int("webrtc.estimator.initial_bitrate", 1_000_000, "initial bitrate for the bandwidth estimator")
if err := viper.BindPFlag("webrtc.estimator.initial_bitrate", cmd.PersistentFlags().Lookup("webrtc.estimator.initial_bitrate")); err != nil { if err := viper.BindPFlag("webrtc.estimator.initial_bitrate", cmd.PersistentFlags().Lookup("webrtc.estimator.initial_bitrate")); err != nil {
return err return err
} }
cmd.PersistentFlags().Duration("webrtc.estimator.read_interval", 2*time.Second, "how often to read and process bandwidth estimation reports")
if err := viper.BindPFlag("webrtc.estimator.read_interval", cmd.PersistentFlags().Lookup("webrtc.estimator.read_interval")); err != nil {
return err
}
cmd.PersistentFlags().Duration("webrtc.estimator.stable_duration", 12*time.Second, "how long to wait for stable connection (upward or neutral trend) before upgrading")
if err := viper.BindPFlag("webrtc.estimator.stable_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.stable_duration")); err != nil {
return err
}
// unstable_duration covers the downward-trend case; the stalled case has its
// own flag (webrtc.estimator.stalled_duration), so the help text must not
// describe a stalled connection here
cmd.PersistentFlags().Duration("webrtc.estimator.unstable_duration", 6*time.Second, "how long to wait for unstable connection (downward trend) before downgrading")
if err := viper.BindPFlag("webrtc.estimator.unstable_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.unstable_duration")); err != nil {
	return err
}
cmd.PersistentFlags().Duration("webrtc.estimator.stalled_duration", 24*time.Second, "how long to wait for stalled bandwidth estimation before downgrading")
if err := viper.BindPFlag("webrtc.estimator.stalled_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.stalled_duration")); err != nil {
return err
}
cmd.PersistentFlags().Duration("webrtc.estimator.downgrade_backoff", 10*time.Second, "how long to wait before downgrading again after previous downgrade")
if err := viper.BindPFlag("webrtc.estimator.downgrade_backoff", cmd.PersistentFlags().Lookup("webrtc.estimator.downgrade_backoff")); err != nil {
return err
}
cmd.PersistentFlags().Duration("webrtc.estimator.upgrade_backoff", 5*time.Second, "how long to wait before upgrading again after previous upgrade")
if err := viper.BindPFlag("webrtc.estimator.upgrade_backoff", cmd.PersistentFlags().Lookup("webrtc.estimator.upgrade_backoff")); err != nil {
return err
}
cmd.PersistentFlags().Float64("webrtc.estimator.diff_threshold", 0.15, "how bigger the difference between estimated and stream bitrate must be to trigger upgrade/downgrade")
if err := viper.BindPFlag("webrtc.estimator.diff_threshold", cmd.PersistentFlags().Lookup("webrtc.estimator.diff_threshold")); err != nil {
return err
}
return nil return nil
} }
@ -197,7 +258,15 @@ func (s *WebRTC) Set() {
// bandwidth estimator // bandwidth estimator
s.EstimatorEnabled = viper.GetBool("webrtc.estimator.enabled") s.Estimator.Enabled = viper.GetBool("webrtc.estimator.enabled")
s.EstimatorPassive = viper.GetBool("webrtc.estimator.passive") s.Estimator.Passive = viper.GetBool("webrtc.estimator.passive")
s.EstimatorInitialBitrate = viper.GetInt("webrtc.estimator.initial_bitrate") s.Estimator.Debug = viper.GetBool("webrtc.estimator.debug")
s.Estimator.InitialBitrate = viper.GetInt("webrtc.estimator.initial_bitrate")
s.Estimator.ReadInterval = viper.GetDuration("webrtc.estimator.read_interval")
s.Estimator.StableDuration = viper.GetDuration("webrtc.estimator.stable_duration")
s.Estimator.UnstableDuration = viper.GetDuration("webrtc.estimator.unstable_duration")
s.Estimator.StalledDuration = viper.GetDuration("webrtc.estimator.stalled_duration")
s.Estimator.DowngradeBackoff = viper.GetDuration("webrtc.estimator.downgrade_backoff")
s.Estimator.UpgradeBackoff = viper.GetDuration("webrtc.estimator.upgrade_backoff")
s.Estimator.DiffThreshold = viper.GetFloat64("webrtc.estimator.diff_threshold")
} }

View File

@ -24,6 +24,7 @@ import (
"github.com/demodesk/neko/pkg/types/codec" "github.com/demodesk/neko/pkg/types/codec"
"github.com/demodesk/neko/pkg/types/event" "github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message" "github.com/demodesk/neko/pkg/types/message"
"github.com/demodesk/neko/pkg/utils"
) )
const ( const (
@ -167,7 +168,7 @@ func (manager *WebRTCManagerCtx) ICEServers() []types.ICEServer {
return manager.config.ICEServersFrontend return manager.config.ICEServersFrontend
} }
func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs []codec.RTPCodec, bitrate int) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) { func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs []codec.RTPCodec) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) {
// create media engine // create media engine
engine := &webrtc.MediaEngine{} engine := &webrtc.MediaEngine{}
for _, codec := range codecs { for _, codec := range codecs {
@ -223,14 +224,10 @@ func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs
// create bandwidth estimator // create bandwidth estimator
estimatorChan := make(chan cc.BandwidthEstimator, 1) estimatorChan := make(chan cc.BandwidthEstimator, 1)
if manager.config.EstimatorEnabled { if manager.config.Estimator.Enabled {
congestionController, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) { congestionController, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) {
if bitrate == 0 {
bitrate = manager.config.EstimatorInitialBitrate
}
return gcc.NewSendSideBWE( return gcc.NewSendSideBWE(
gcc.SendSideBWEInitialBitrate(bitrate), gcc.SendSideBWEInitialBitrate(manager.config.Estimator.InitialBitrate),
gcc.SendSideBWEPacer(gcc.NewNoOpPacer()), gcc.SendSideBWEPacer(gcc.NewNoOpPacer()),
) )
}) })
@ -268,7 +265,7 @@ func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs
return connection, <-estimatorChan, err return connection, <-estimatorChan, err
} }
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int, videoAuto bool) (*webrtc.SessionDescription, error) { func (manager *WebRTCManagerCtx) CreatePeer(session types.Session) (*webrtc.SessionDescription, types.WebRTCPeer, error) {
id := atomic.AddInt32(&manager.peerId, 1) id := atomic.AddInt32(&manager.peerId, 1)
// get metrics for session // get metrics for session
@ -287,12 +284,10 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
video := manager.capture.Video() video := manager.capture.Video()
videoCodec := video.Codec() videoCodec := video.Codec()
connection, estimator, err := manager.newPeerConnection(logger, []codec.RTPCodec{ connection, estimator, err := manager.newPeerConnection(
audioCodec, logger, []codec.RTPCodec{audioCodec, videoCodec})
videoCodec,
}, bitrate)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
// asynchronously send local ICE Candidates // asynchronously send local ICE Candidates
@ -311,47 +306,34 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
}) })
} }
// if bitrate is 0, and estimator is enabled, use estimator bitrate
if bitrate == 0 && estimator != nil {
bitrate = estimator.GetTargetBitrate()
}
// audio track // audio track
audioTrack, err := NewTrack(logger, audioCodec, connection) audioTrack, err := NewTrack(logger, audioCodec, connection)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
// set stream for audio track // set stream for audio track
_, err = audioTrack.SetStream(audio) _, err = audioTrack.SetStream(audio)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
// if estimator is disabled, or in passive mode, disable video auto bitrate
if !manager.config.EstimatorEnabled || manager.config.EstimatorPassive {
videoAuto = false
}
videoRtcp := make(chan []rtcp.Packet, 1)
// video track // video track
videoTrack, err := NewTrack(logger, videoCodec, connection, videoRtcp := make(chan []rtcp.Packet, 1)
WithVideoAuto(videoAuto), videoTrack, err := NewTrack(logger, videoCodec, connection, WithRtcpChan(videoRtcp))
WithRtcpChan(videoRtcp),
)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
// let video stream bucket manager handle stream subscriptions //
video.SetReceiver(videoTrack) // stream for video track will be set later
//
// data channel // data channel
dataChannel, err := connection.CreateDataChannel("data", nil) dataChannel, err := connection.CreateDataChannel("data", nil)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
peer := &WebRTCPeerCtx{ peer := &WebRTCPeerCtx{
@ -359,24 +341,29 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
session: session, session: session,
metrics: metrics, metrics: metrics,
connection: connection, connection: connection,
estimator: estimator, // bandwidth estimator
estimator: estimator,
estimateTrend: utils.NewTrendDetector(
utils.TrendDetectorParams{
// Probing
//RequiredSamples: 3,
//DownwardTrendThreshold: 0.0,
//CollapseValues: false,
// Non-Probing
RequiredSamples: 8,
DownwardTrendThreshold: -0.5,
CollapseValues: true,
}),
// stream selectors
videoSelector: manager.capture.Video(),
// tracks & channels // tracks & channels
audioTrack: audioTrack, audioTrack: audioTrack,
videoTrack: videoTrack, videoTrack: videoTrack,
dataChannel: dataChannel, dataChannel: dataChannel,
rtcpChannel: videoRtcp, rtcpChannel: videoRtcp,
// config // config
iceTrickle: manager.config.ICETrickle, iceTrickle: manager.config.ICETrickle,
estimatorPassive: manager.config.EstimatorPassive, estimatorConfig: manager.config.Estimator,
}
logger.Info().
Int("target_bitrate", bitrate).
Msg("estimated initial peer bitrate")
// set initial video bitrate
if err := peer.SetVideoBitrate(bitrate); err != nil {
return nil, err
} }
connection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { connection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
@ -492,9 +479,9 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
// ensure we only run this once // ensure we only run this once
once.Do(func() { once.Do(func() {
session.SetWebRTCConnected(peer, false) session.SetWebRTCConnected(peer, false)
if err = video.RemoveReceiver(videoTrack); err != nil { //
logger.Err(err).Msg("failed to remove video receiver") // TODO: Shutdown peer?
} //
audioTrack.Shutdown() audioTrack.Shutdown()
videoTrack.Shutdown() videoTrack.Shutdown()
close(videoRtcp) close(videoRtcp)
@ -542,7 +529,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
offer, err := peer.CreateOffer(false) offer, err := peer.CreateOffer(false)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
// on negotiation needed handler must be registered after creating initial // on negotiation needed handler must be registered after creating initial
@ -576,7 +563,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
// start estimator reader // start estimator reader
go peer.estimatorReader() go peer.estimatorReader()
return offer, nil return offer, peer, nil
} }
func (manager *WebRTCManagerCtx) SetCursorPosition(x, y int) { func (manager *WebRTCManagerCtx) SetCursorPosition(x, y int) {

View File

@ -121,7 +121,7 @@ func (m *metricsManager) getBySession(session types.Session) *metrics {
Name: "receiver_estimated_maximum_bitrate", Name: "receiver_estimated_maximum_bitrate",
Namespace: "neko", Namespace: "neko",
Subsystem: "webrtc", Subsystem: "webrtc",
Help: "Receiver Estimated Maximum Bitrate from SCTP.", Help: "Receiver Estimated Maximum Bitrate from RTCP.",
ConstLabels: map[string]string{ ConstLabels: map[string]string{
"session_id": sessionId, "session_id": sessionId,
}, },
@ -140,7 +140,7 @@ func (m *metricsManager) getBySession(session types.Session) *metrics {
Name: "receiver_report_delay", Name: "receiver_report_delay",
Namespace: "neko", Namespace: "neko",
Subsystem: "webrtc", Subsystem: "webrtc",
Help: "Receiver Report Delay from SCTP, expressed in units of 1/65536 seconds.", Help: "Receiver Report Delay from RTCP, expressed in units of 1/65536 seconds.",
ConstLabels: map[string]string{ ConstLabels: map[string]string{
"session_id": sessionId, "session_id": sessionId,
}, },
@ -149,7 +149,7 @@ func (m *metricsManager) getBySession(session types.Session) *metrics {
Name: "receiver_report_jitter", Name: "receiver_report_jitter",
Namespace: "neko", Namespace: "neko",
Subsystem: "webrtc", Subsystem: "webrtc",
Help: "Receiver Report Jitter from SCTP.", Help: "Receiver Report Jitter from RTCP.",
ConstLabels: map[string]string{ ConstLabels: map[string]string{
"session_id": sessionId, "session_id": sessionId,
}, },
@ -158,7 +158,17 @@ func (m *metricsManager) getBySession(session types.Session) *metrics {
Name: "receiver_report_total_lost", Name: "receiver_report_total_lost",
Namespace: "neko", Namespace: "neko",
Subsystem: "webrtc", Subsystem: "webrtc",
Help: "Receiver Report Total Lost from SCTP.", Help: "Receiver Report Total Lost from RTCP.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
transportLayerNacks: promauto.NewCounter(prometheus.CounterOpts{
Name: "transport_layer_nacks",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Transport Layer NACKs from RTCP.",
ConstLabels: map[string]string{ ConstLabels: map[string]string{
"session_id": sessionId, "session_id": sessionId,
}, },
@ -236,6 +246,8 @@ type metrics struct {
receiverReportJitter prometheus.Gauge receiverReportJitter prometheus.Gauge
receiverReportTotalLost prometheus.Gauge receiverReportTotalLost prometheus.Gauge
transportLayerNacks prometheus.Counter
iceBytesSent prometheus.Gauge iceBytesSent prometheus.Gauge
iceBytesReceived prometheus.Gauge iceBytesReceived prometheus.Gauge
sctpBytesSent prometheus.Gauge sctpBytesSent prometheus.Gauge
@ -386,6 +398,11 @@ func (met *metrics) rtcpReceiver(rtcpCh chan []rtcp.Packet) {
// use only last report // use only last report
met.SetReceiverReport(rtcpPacket.Reports[l-1]) met.SetReceiverReport(rtcpPacket.Reports[l-1])
} }
case *rtcp.TransportLayerNack:
for _, pair := range rtcpPacket.Nacks {
packetList := pair.PacketList()
met.transportLayerNacks.Add(float64(len(packetList)))
}
} }
} }
} }

View File

@ -11,15 +11,12 @@ import (
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/demodesk/neko/internal/config"
"github.com/demodesk/neko/internal/webrtc/payload" "github.com/demodesk/neko/internal/webrtc/payload"
"github.com/demodesk/neko/pkg/types" "github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/event" "github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message" "github.com/demodesk/neko/pkg/types/message"
) "github.com/demodesk/neko/pkg/utils"
const (
// how often to read and process bandwidth estimation reports
estimatorReadInterval = 250 * time.Millisecond
) )
type WebRTCPeerCtx struct { type WebRTCPeerCtx struct {
@ -28,15 +25,20 @@ type WebRTCPeerCtx struct {
session types.Session session types.Session
metrics *metrics metrics *metrics
connection *webrtc.PeerConnection connection *webrtc.PeerConnection
estimator cc.BandwidthEstimator // bandwidth estimator
estimator cc.BandwidthEstimator
estimateTrend *utils.TrendDetector
// stream selectors
videoSelector types.StreamSelectorManager
// tracks & channels // tracks & channels
audioTrack *Track audioTrack *Track
videoTrack *Track videoTrack *Track
dataChannel *webrtc.DataChannel dataChannel *webrtc.DataChannel
rtcpChannel chan []rtcp.Packet rtcpChannel chan []rtcp.Packet
// config // config
iceTrickle bool iceTrickle bool
estimatorPassive bool estimatorConfig config.WebRTCEstimator
videoAuto bool
} }
// //
@ -102,6 +104,7 @@ func (peer *WebRTCPeerCtx) SetCandidate(candidate webrtc.ICECandidateInit) error
return peer.connection.AddICECandidate(candidate) return peer.connection.AddICECandidate(candidate)
} }
// TODO: Add shutdown function?
func (peer *WebRTCPeerCtx) Destroy() { func (peer *WebRTCPeerCtx) Destroy() {
peer.mu.Lock() peer.mu.Lock()
defer peer.mu.Unlock() defer peer.mu.Unlock()
@ -111,32 +114,186 @@ func (peer *WebRTCPeerCtx) Destroy() {
} }
func (peer *WebRTCPeerCtx) estimatorReader() { func (peer *WebRTCPeerCtx) estimatorReader() {
conf := peer.estimatorConfig
// if estimator is not in debug mode, use a nop logger
var debugLogger zerolog.Logger
if conf.Debug {
debugLogger = peer.logger.With().Str("component", "estimator").Logger().Level(zerolog.DebugLevel)
} else {
debugLogger = zerolog.Nop()
}
// if estimator is disabled, do nothing // if estimator is disabled, do nothing
if peer.estimator == nil { if peer.estimator == nil {
return return
} }
// use a ticker to get current client target bitrate // use a ticker to get current client target bitrate
ticker := time.NewTicker(estimatorReadInterval) ticker := time.NewTicker(conf.ReadInterval)
defer ticker.Stop() defer ticker.Stop()
// since when is the estimate stable/unstable
stableSince := time.Now() // we asume stable at start
unstableSince := time.Time{}
// since when we have been neutral but unable to accommodate the current bitrate
// we might be stalled or the estimator just reached zero (very bad connection)
stalledSince := time.Time{}
// when was the last upgrade/downgrade
lastUpgradeTime := time.Time{}
lastDowngradeTime := time.Time{}
for range ticker.C { for range ticker.C {
targetBitrate := peer.estimator.GetTargetBitrate() targetBitrate := peer.estimator.GetTargetBitrate()
peer.metrics.SetReceiverEstimatedTargetBitrate(float64(targetBitrate)) peer.metrics.SetReceiverEstimatedTargetBitrate(float64(targetBitrate))
// if peer connection is closed, stop reading
if peer.connection.ConnectionState() == webrtc.PeerConnectionStateClosed { if peer.connection.ConnectionState() == webrtc.PeerConnectionStateClosed {
break break
} }
if !peer.videoTrack.VideoAuto() { // if estimation is disabled, do nothing
if !peer.videoAuto || conf.Passive {
continue continue
} }
if !peer.estimatorPassive { // get trend direction to decide if we should upgrade or downgrade
err := peer.SetVideoBitrate(targetBitrate) peer.estimateTrend.AddValue(int64(targetBitrate))
if err != nil { direction := peer.estimateTrend.GetDirection()
peer.logger.Warn().Err(err).Msg("failed to set video bitrate")
// get current stream bitrate
stream, ok := peer.videoTrack.Stream()
if !ok {
debugLogger.Warn().Msg("looks like we don't have a stream yet, skipping bitrate estimation")
continue
}
// if stream bitrate is 0, we need to wait for some time until we get a valid value
streamId, streamBitrate := stream.ID(), stream.Bitrate()
if streamBitrate == 0 {
debugLogger.Warn().Msg("looks like stream bitrate is 0, we need to wait for some time")
continue
}
// check whats the difference between target and stream bitrate
diff := float64(targetBitrate) / float64(streamBitrate)
debugLogger.Info().
Float64("diff", diff).
Int("target_bitrate", targetBitrate).
Uint64("stream_bitrate", streamBitrate).
Str("direction", direction.String()).
Msg("got bitrate from estimator")
// if we can accommodate the current stream or we are not neutral anymore,
// we are not stalled so we reset the stalled time
if direction != utils.TrendDirectionNeutral || diff > 1+conf.DiffThreshold {
stalledSince = time.Now()
}
// if we are neutral and stalled for too long, we might be congesting
stalled := direction == utils.TrendDirectionNeutral && time.Since(stalledSince) > conf.StalledDuration
if stalled {
debugLogger.Warn().
Time("stalled_since", stalledSince).
Msgf("it looks like we are stalled")
}
// if we have a downward trend or are stalled, we might be congesting
if direction == utils.TrendDirectionDownward || stalled {
// we reset the stable time because we are congesting
stableSince = time.Now()
// if we downgraded recently, we wait for some more time
if time.Since(lastDowngradeTime) < conf.DowngradeBackoff {
debugLogger.Debug().
Time("last_downgrade", lastDowngradeTime).
Msgf("downgraded recently, waiting for at least %v", conf.DowngradeBackoff)
continue
} }
// if we are not unstable but we fluctuate we should wait for some more time
if time.Since(unstableSince) < conf.UnstableDuration {
debugLogger.Debug().
Time("unstable_since", unstableSince).
Msgf("we are not unstable long enough, waiting for at least %v", conf.UnstableDuration)
continue
}
// if we still have a big difference between target and stream bitrate, we wait for some more time
if conf.DiffThreshold >= 0 && diff > 1+conf.DiffThreshold {
debugLogger.Debug().
Float64("diff", diff).
Float64("threshold", conf.DiffThreshold).
Msgf("we still have a big difference between target and stream bitrate, " +
"therefore we still should be able to accomodate current stream")
continue
}
err := peer.SetVideo(types.StreamSelector{
ID: streamId,
Type: types.StreamSelectorTypeLower,
})
if err != nil && err != types.ErrWebRTCStreamNotFound {
peer.logger.Warn().Err(err).Msg("failed to downgrade video stream")
}
lastDowngradeTime = time.Now()
if err == types.ErrWebRTCStreamNotFound {
debugLogger.Info().Msg("looks like we are already on the lowest stream")
} else {
debugLogger.Info().Msg("downgraded video stream")
}
continue
}
// we reset the unstable time because we are not congesting
unstableSince = time.Now()
// if we have a neutral or upward trend, that means our estimate is stable
// if we are on the highest stream, we don't need to do anything
// but if there is a higher stream, we should try to upgrade and see if it works
// if we upgraded recently, we wait for some more time
if time.Since(lastUpgradeTime) < conf.UpgradeBackoff {
debugLogger.Debug().
Time("last_upgrade", lastUpgradeTime).
Msgf("upgraded recently, waiting for at least %v", conf.UpgradeBackoff)
continue
}
// if we are not stable for long enough, we wait for some more time
// because bandwidth estimation might fluctuate
if time.Since(stableSince) < conf.StableDuration {
debugLogger.Debug().
Time("stable_since", stableSince).
Msgf("we are not stable long enough, waiting for at least %v", conf.StableDuration)
continue
}
// upgrade only if estimated bitrate passed the threshold
if conf.DiffThreshold >= 0 && diff < 1+conf.DiffThreshold {
debugLogger.Debug().
Float64("diff", diff).
Float64("threshold", conf.DiffThreshold).
Msgf("looks like we don't have enough bitrate to accomodate higher stream, " +
"therefore we should wait for some more time")
continue
}
err := peer.SetVideo(types.StreamSelector{
ID: streamId,
Type: types.StreamSelectorTypeHigher,
})
if err != nil && err != types.ErrWebRTCStreamNotFound {
peer.logger.Warn().Err(err).Msg("failed to upgrade video stream")
}
lastUpgradeTime = time.Now()
if err == types.ErrWebRTCStreamNotFound {
debugLogger.Info().Msg("looks like we are already on the highest stream")
} else {
debugLogger.Info().Msg("upgraded video stream")
} }
} }
} }
@ -145,88 +302,52 @@ func (peer *WebRTCPeerCtx) estimatorReader() {
// video // video
// //
func (peer *WebRTCPeerCtx) SetVideoBitrate(peerBitrate int) error { func (peer *WebRTCPeerCtx) SetVideo(selector types.StreamSelector) error {
peer.mu.Lock() peer.mu.Lock()
defer peer.mu.Unlock() defer peer.mu.Unlock()
// when switching from manual to auto bitrate estimation, in case the estimator is // get requested video stream from selector
// idle (lastBitrate > maxBitrate), we want to go back to the previous estimated bitrate stream, ok := peer.videoSelector.GetStream(selector)
if peerBitrate == 0 && peer.estimator != nil && !peer.estimatorPassive { if !ok {
peerBitrate = peer.estimator.GetTargetBitrate() return types.ErrWebRTCStreamNotFound
peer.logger.Debug().
Int("peer_bitrate", peerBitrate).
Msg("evaluated bitrate")
} }
changed, err := peer.videoTrack.SetBitrate(peerBitrate) // set video stream to track
changed, err := peer.videoTrack.SetStream(stream)
if err != nil { if err != nil {
return err return err
} }
// if video stream was already set, do nothing
if !changed { if !changed {
// TODO: return error?
return nil return nil
} }
videoID := peer.videoTrack.stream.ID() videoID := stream.ID()
bitrate := peer.videoTrack.stream.Bitrate()
peer.metrics.SetVideoID(videoID) peer.metrics.SetVideoID(videoID)
peer.logger.Debug().
Int("peer_bitrate", peerBitrate). peer.logger.Info().Str("video_id", videoID).Msg("set video")
Int("video_bitrate", bitrate).
Str("video_id", videoID).
Msg("peer bitrate triggered video stream change")
go peer.session.Send( go peer.session.Send(
event.SIGNAL_VIDEO, event.SIGNAL_VIDEO,
message.SignalVideo{ message.SignalVideo{
Video: videoID, Video: videoID,
Bitrate: bitrate, Auto: peer.videoAuto,
VideoAuto: peer.videoTrack.VideoAuto(),
}) })
return nil return nil
} }
func (peer *WebRTCPeerCtx) SetVideoID(videoID string) error { func (peer *WebRTCPeerCtx) VideoID() (string, bool) {
peer.mu.Lock() peer.mu.Lock()
defer peer.mu.Unlock() defer peer.mu.Unlock()
changed, err := peer.videoTrack.SetVideoID(videoID) stream, ok := peer.videoTrack.Stream()
if err != nil { if !ok {
return err return "", false
} }
if !changed { return stream.ID(), true
// TODO: return error?
return nil
}
bitrate := peer.videoTrack.stream.Bitrate()
peer.logger.Debug().
Str("video_id", videoID).
Int("video_bitrate", bitrate).
Msg("peer video id triggered video stream change")
go peer.session.Send(
event.SIGNAL_VIDEO,
message.SignalVideo{
Video: videoID,
Bitrate: bitrate,
VideoAuto: peer.videoTrack.VideoAuto(),
})
return nil
}
func (peer *WebRTCPeerCtx) GetVideoID() string {
peer.mu.Lock()
defer peer.mu.Unlock()
// TODO: Refactor.
return peer.videoTrack.stream.ID()
} }
func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error { func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error {
@ -239,18 +360,32 @@ func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error {
return nil return nil
} }
func (peer *WebRTCPeerCtx) Paused() bool {
peer.mu.Lock()
defer peer.mu.Unlock()
return peer.videoTrack.Paused() || peer.audioTrack.Paused()
}
func (peer *WebRTCPeerCtx) SetVideoAuto(videoAuto bool) { func (peer *WebRTCPeerCtx) SetVideoAuto(videoAuto bool) {
peer.mu.Lock()
defer peer.mu.Unlock()
// if estimator is enabled and is not passive, enable video auto bitrate // if estimator is enabled and is not passive, enable video auto bitrate
if peer.estimator != nil && !peer.estimatorPassive { if peer.estimator != nil && !peer.estimatorConfig.Passive {
peer.videoTrack.SetVideoAuto(videoAuto) peer.logger.Info().Bool("video_auto", videoAuto).Msg("set video auto")
peer.videoAuto = videoAuto
} else { } else {
peer.logger.Warn().Msg("estimator is disabled or in passive mode, cannot change video auto") peer.logger.Warn().Msg("estimator is disabled or in passive mode, cannot change video auto")
peer.videoTrack.SetVideoAuto(false) // ensure video auto is disabled peer.videoAuto = false // ensure video auto is disabled
} }
} }
func (peer *WebRTCPeerCtx) VideoAuto() bool { func (peer *WebRTCPeerCtx) VideoAuto() bool {
return peer.videoTrack.VideoAuto() peer.mu.Lock()
defer peer.mu.Unlock()
return peer.videoAuto
} }
// //

View File

@ -2,7 +2,6 @@ package webrtc
import ( import (
"errors" "errors"
"fmt"
"io" "io"
"sync" "sync"
@ -22,25 +21,13 @@ type Track struct {
rtcpCh chan []rtcp.Packet rtcpCh chan []rtcp.Packet
sample chan types.Sample sample chan types.Sample
videoAuto bool
videoAutoMu sync.RWMutex
paused bool paused bool
stream types.StreamSinkManager stream types.StreamSinkManager
streamMu sync.Mutex streamMu sync.Mutex
bitrateChange func(int) (bool, error)
videoChange func(string) (bool, error)
} }
type trackOption func(*Track) type trackOption func(*Track)
func WithVideoAuto(auto bool) trackOption {
return func(t *Track) {
t.videoAuto = auto
}
}
func WithRtcpChan(rtcp chan []rtcp.Packet) trackOption { func WithRtcpChan(rtcp chan []rtcp.Packet) trackOption {
return func(t *Track) { return func(t *Track) {
t.rtcpCh = rtcp t.rtcpCh = rtcp
@ -100,6 +87,8 @@ func (t *Track) rtcpReader(sender *webrtc.RTPSender) {
} }
} }
// --- sample ---
func (t *Track) sampleReader() { func (t *Track) sampleReader() {
for { for {
sample, ok := <-t.sample sample, ok := <-t.sample
@ -120,6 +109,12 @@ func (t *Track) sampleReader() {
} }
} }
func (t *Track) WriteSample(sample types.Sample) {
t.sample <- sample
}
// --- stream ---
func (t *Track) SetStream(stream types.StreamSinkManager) (bool, error) { func (t *Track) SetStream(stream types.StreamSinkManager) (bool, error) {
t.streamMu.Lock() t.streamMu.Lock()
defer t.streamMu.Unlock() defer t.streamMu.Unlock()
@ -167,6 +162,15 @@ func (t *Track) RemoveStream() {
t.stream = nil t.stream = nil
} }
func (t *Track) Stream() (types.StreamSinkManager, bool) {
t.streamMu.Lock()
defer t.streamMu.Unlock()
return t.stream, t.stream != nil
}
// --- paused ---
func (t *Track) SetPaused(paused bool) { func (t *Track) SetPaused(paused bool) {
t.streamMu.Lock() t.streamMu.Lock()
defer t.streamMu.Unlock() defer t.streamMu.Unlock()
@ -190,42 +194,9 @@ func (t *Track) SetPaused(paused bool) {
t.paused = paused t.paused = paused
} }
func (t *Track) WriteSample(sample types.Sample) { func (t *Track) Paused() bool {
t.sample <- sample t.streamMu.Lock()
} defer t.streamMu.Unlock()
func (t *Track) SetBitrate(bitrate int) (bool, error) { return t.paused
if t.bitrateChange == nil {
return false, fmt.Errorf("bitrate change not supported")
}
return t.bitrateChange(bitrate)
}
func (t *Track) SetVideoID(videoID string) (bool, error) {
if t.videoChange == nil {
return false, fmt.Errorf("video change not supported")
}
return t.videoChange(videoID)
}
func (t *Track) OnBitrateChange(f func(bitrate int) (bool, error)) {
t.bitrateChange = f
}
func (t *Track) OnVideoChange(f func(string) (bool, error)) {
t.videoChange = f
}
func (t *Track) SetVideoAuto(auto bool) {
t.videoAutoMu.Lock()
defer t.videoAutoMu.Unlock()
t.videoAuto = auto
}
func (t *Track) VideoAuto() bool {
t.videoAutoMu.RLock()
defer t.videoAutoMu.RUnlock()
return t.videoAuto
} }

View File

@ -20,28 +20,26 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
payload.Video = videos[0] payload.Video = videos[0]
} }
var err error offer, peer, err := h.webrtc.CreatePeer(session)
if payload.Bitrate == 0 {
// get bitrate from video id
payload.Bitrate, err = h.capture.GetBitrateFromVideoID(payload.Video)
if err != nil {
return err
}
}
offer, err := h.webrtc.CreatePeer(session, payload.Bitrate, payload.VideoAuto)
if err != nil { if err != nil {
return err return err
} }
if webrtcPeer := session.GetWebRTCPeer(); webrtcPeer != nil { // set webrtc as paused if session has private mode enabled
// set webrtc as paused if session has private mode enabled if session.PrivateModeEnabled() {
if session.PrivateModeEnabled() { peer.SetPaused(true)
webrtcPeer.SetPaused(true) }
}
payload.Video = webrtcPeer.GetVideoID() // set video auto state
payload.VideoAuto = webrtcPeer.VideoAuto() peer.SetVideoAuto(payload.Auto)
// set video stream
err = peer.SetVideo(types.StreamSelector{
ID: payload.Video,
Type: types.StreamSelectorTypeNearest,
})
if err != nil {
return err
} }
session.Send( session.Send(
@ -49,9 +47,6 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
message.SignalProvide{ message.SignalProvide{
SDP: offer.SDP, SDP: offer.SDP,
ICEServers: h.webrtc.ICEServers(), ICEServers: h.webrtc.ICEServers(),
Video: payload.Video, // TODO: Refactor
Bitrate: payload.Bitrate,
VideoAuto: payload.VideoAuto,
}) })
return nil return nil
@ -133,16 +128,13 @@ func (h *MessageHandlerCtx) signalVideo(session types.Session, payload *message.
return errors.New("webRTC peer does not exist") return errors.New("webRTC peer does not exist")
} }
peer.SetVideoAuto(payload.VideoAuto) peer.SetVideoAuto(payload.Auto)
if payload.Video != "" { if payload.Video != "" {
if err := peer.SetVideoID(payload.Video); err != nil { return peer.SetVideo(types.StreamSelector{
h.logger.Error().Err(err).Msg("failed to set video id") ID: payload.Video,
} Type: types.StreamSelectorTypeNearest,
} else { })
if err := peer.SetVideoBitrate(payload.Bitrate); err != nil {
h.logger.Error().Err(err).Msg("failed to set video bitrate")
}
} }
return nil return nil

View File

@ -31,26 +31,6 @@ type SampleListener interface {
WriteSample(Sample) WriteSample(Sample)
} }
type Receiver interface {
SetStream(stream StreamSinkManager) (changed bool, err error)
RemoveStream()
OnBitrateChange(f func(bitrate int) (changed bool, err error))
OnVideoChange(f func(videoID string) (changed bool, err error))
VideoAuto() bool
SetVideoAuto(videoAuto bool)
}
type BucketsManager interface {
IDs() []string
Codec() codec.RTPCodec
SetReceiver(receiver Receiver)
RemoveReceiver(receiver Receiver) error
DestroyAll()
RecreateAll() error
Shutdown()
}
type BroadcastManager interface { type BroadcastManager interface {
Start(url string) error Start(url string) error
Stop() Stop()
@ -64,10 +44,74 @@ type ScreencastManager interface {
Image() ([]byte, error) Image() ([]byte, error)
} }
// StreamSelectorType tells how a StreamSelector should be matched against
// the streams that are available.
type StreamSelectorType int

const (
	// StreamSelectorTypeExact selects the exact stream only.
	StreamSelectorTypeExact StreamSelectorType = iota
	// StreamSelectorTypeNearest selects the nearest stream (in either
	// direction) if the exact stream is not available.
	StreamSelectorTypeNearest
	// StreamSelectorTypeLower selects the next lower stream if the exact
	// stream is found, otherwise the nearest lower stream.
	StreamSelectorTypeLower
	// StreamSelectorTypeHigher selects the next higher stream if the exact
	// stream is found, otherwise the nearest higher stream.
	StreamSelectorTypeHigher
)

// String returns the lowercase name of the selector type. Values without a
// name are rendered as their decimal number.
func (s StreamSelectorType) String() string {
	names := [...]string{
		StreamSelectorTypeExact:   "exact",
		StreamSelectorTypeNearest: "nearest",
		StreamSelectorTypeLower:   "lower",
		StreamSelectorTypeHigher:  "higher",
	}
	if s >= 0 && int(s) < len(names) {
		return names[s]
	}
	return fmt.Sprintf("%d", int(s))
}

// UnmarshalText parses a selector type name case-insensitively. Empty text
// means StreamSelectorTypeExact. On an unknown name an error is returned and
// the receiver is left untouched.
func (s *StreamSelectorType) UnmarshalText(text []byte) error {
	name := strings.ToLower(string(text))
	if name == "" {
		*s = StreamSelectorTypeExact
		return nil
	}

	known := []StreamSelectorType{
		StreamSelectorTypeExact,
		StreamSelectorTypeNearest,
		StreamSelectorTypeLower,
		StreamSelectorTypeHigher,
	}
	for _, candidate := range known {
		if candidate.String() == name {
			*s = candidate
			return nil
		}
	}
	return fmt.Errorf("invalid stream selector type: %s", string(text))
}

// MarshalText encodes the selector type using the same text as String.
func (s StreamSelectorType) MarshalText() ([]byte, error) {
	name := s.String()
	return []byte(name), nil
}
// StreamSelector describes a request for a stream: Type says how the match
// is performed, while ID and Bitrate identify the stream to match against.
// NOTE(review): how ID and Bitrate are combined when both are set is decided
// by the StreamSelectorManager implementation, which is not visible here.
type StreamSelector struct {
// how the stream is matched (exact, nearest, lower or higher)
Type StreamSelectorType
// select stream by its ID
ID string
// select stream by its bitrate
Bitrate uint64
}
// StreamSelectorManager resolves a StreamSelector to one of the streams it
// manages.
type StreamSelectorManager interface {
// IDs presumably lists the identifiers of the managed streams — confirm
// against the implementation.
IDs() []string
// Codec returns the RTP codec associated with the manager.
Codec() codec.RTPCodec
// GetStream returns the stream matching the selector; the boolean is false
// when no stream could be found for it.
GetStream(selector StreamSelector) (StreamSinkManager, bool)
}
type StreamSinkManager interface { type StreamSinkManager interface {
ID() string ID() string
Codec() codec.RTPCodec Codec() codec.RTPCodec
Bitrate() int Bitrate() uint64
AddListener(listener SampleListener) error AddListener(listener SampleListener) error
RemoveListener(listener SampleListener) error RemoveListener(listener SampleListener) error
@ -94,12 +138,10 @@ type CaptureManager interface {
Start() Start()
Shutdown() error Shutdown() error
GetBitrateFromVideoID(videoID string) (int, error)
Broadcast() BroadcastManager Broadcast() BroadcastManager
Screencast() ScreencastManager Screencast() ScreencastManager
Audio() StreamSinkManager Audio() StreamSinkManager
Video() BucketsManager Video() StreamSelectorManager
Webcam() StreamSrcManager Webcam() StreamSrcManager
Microphone() StreamSrcManager Microphone() StreamSrcManager
@ -201,54 +243,3 @@ func (config *VideoConfig) GetPipeline(screen ScreenSize) (string, error) {
config.GstSuffix, config.GstSuffix,
}[:], " "), nil }[:], " "), nil
} }
func (config *VideoConfig) GetBitrateFn(getScreen func() ScreenSize) func() (int, error) {
return func() (int, error) {
if config.Bitrate > 0 {
return config.Bitrate, nil
}
screen := getScreen()
values := map[string]any{
"width": screen.Width,
"height": screen.Height,
"fps": screen.Rate,
}
language := []gval.Language{
gval.Function("round", func(args ...any) (any, error) {
return (int)(math.Round(args[0].(float64))), nil
}),
}
// TOOD: do not read target-bitrate from pipeline, but only from config.
// TODO: This is only for vp8.
expr, ok := config.GstParams["target-bitrate"]
if !ok {
// TODO: This is only for h264.
expr, ok = config.GstParams["bitrate"]
if !ok {
return 0, fmt.Errorf("bitrate not found")
}
}
bitrate, err := gval.Evaluate(expr, values, language...)
if err != nil {
return 0, fmt.Errorf("failed to evaluate bitrate: %w", err)
}
var bitrateInt int
switch val := bitrate.(type) {
case int:
bitrateInt = val
case float64:
bitrateInt = (int)(val)
default:
return 0, fmt.Errorf("bitrate is not int or float64")
}
return bitrateInt, nil
}
}

View File

@ -48,10 +48,6 @@ type SystemDisconnect struct {
type SignalProvide struct { type SignalProvide struct {
SDP string `json:"sdp"` SDP string `json:"sdp"`
ICEServers []types.ICEServer `json:"iceservers"` ICEServers []types.ICEServer `json:"iceservers"`
// TODO: Use SignalVideo struct.
Video string `json:"video"`
Bitrate int `json:"bitrate"`
VideoAuto bool `json:"video_auto"`
} }
type SignalCandidate struct { type SignalCandidate struct {
@ -63,9 +59,8 @@ type SignalDescription struct {
} }
type SignalVideo struct { type SignalVideo struct {
Video string `json:"video"` Video string `json:"video"`
Bitrate int `json:"bitrate"` Auto bool `json:"auto"`
VideoAuto bool `json:"video_auto"`
} }
///////////////////////////// /////////////////////////////

View File

@ -9,6 +9,7 @@ import (
var ( var (
ErrWebRTCDataChannelNotFound = errors.New("webrtc data channel not found") ErrWebRTCDataChannelNotFound = errors.New("webrtc data channel not found")
ErrWebRTCConnectionNotFound = errors.New("webrtc connection not found") ErrWebRTCConnectionNotFound = errors.New("webrtc connection not found")
ErrWebRTCStreamNotFound = errors.New("webrtc stream not found")
) )
type ICEServer struct { type ICEServer struct {
@ -23,10 +24,10 @@ type WebRTCPeer interface {
SetRemoteDescription(webrtc.SessionDescription) error SetRemoteDescription(webrtc.SessionDescription) error
SetCandidate(webrtc.ICECandidateInit) error SetCandidate(webrtc.ICECandidateInit) error
SetVideoBitrate(bitrate int) error SetVideo(StreamSelector) error
SetVideoID(videoID string) error VideoID() (string, bool)
GetVideoID() string
SetPaused(isPaused bool) error SetPaused(isPaused bool) error
Paused() bool
SetVideoAuto(auto bool) SetVideoAuto(auto bool)
VideoAuto() bool VideoAuto() bool
@ -42,6 +43,6 @@ type WebRTCManager interface {
ICEServers() []ICEServer ICEServers() []ICEServer
CreatePeer(session Session, bitrate int, videoAuto bool) (*webrtc.SessionDescription, error) CreatePeer(session Session) (*webrtc.SessionDescription, WebRTCPeer, error)
SetCursorPosition(x, y int) SetCursorPosition(x, y int)
} }

153
pkg/utils/trenddetector.go Normal file
View File

@ -0,0 +1,153 @@
// From https://github.com/livekit/livekit/blob/master/pkg/sfu/streamallocator/trenddetector.go
package utils
import (
"fmt"
"time"
)
// ------------------------------------------------
// TrendDirection is the direction of a trend as reported by TrendDetector.
type TrendDirection int

const (
	TrendDirectionNeutral TrendDirection = iota
	TrendDirectionUpward
	TrendDirectionDownward
)

// String returns the upper-case name of the direction. Values without a
// name are rendered as their decimal number.
func (t TrendDirection) String() string {
	labels := map[TrendDirection]string{
		TrendDirectionNeutral:  "NEUTRAL",
		TrendDirectionUpward:   "UPWARD",
		TrendDirectionDownward: "DOWNWARD",
	}
	if label, known := labels[t]; known {
		return label
	}
	return fmt.Sprintf("%d", int(t))
}
// ------------------------------------------------
// TrendDetectorParams configures a TrendDetector.
type TrendDetectorParams struct {
// RequiredSamples is the size of the value window: the direction stays
// neutral until the window holds this many values, and once it is full the
// oldest value is dropped for every new one.
RequiredSamples int
// DownwardTrendThreshold is the Kendall's Tau value below which the trend
// is reported as downward.
DownwardTrendThreshold float64
// CollapseValues, when true, keeps a value out of the window if it equals
// the most recently stored value.
CollapseValues bool
}
// TrendDetector keeps a sliding window of values and derives the trend
// direction from it using Kendall's Tau.
type TrendDetector struct {
// configuration given to NewTrendDetector
params TrendDetectorParams
// creation time, only used for the ToString output
startTime time.Time
// number of AddValue calls, including values that were collapsed
numSamples int
// sliding window of values the direction is computed from
values []int64
// lowest and highest value statistics maintained by AddValue
lowestValue int64
highestValue int64
// direction computed by the last updateDirection call
direction TrendDirection
}
// NewTrendDetector creates a detector with the given parameters. Its start
// time is the moment of creation and its direction starts out neutral.
func NewTrendDetector(params TrendDetectorParams) *TrendDetector {
	detector := new(TrendDetector)
	detector.params = params
	detector.startTime = time.Now()
	detector.direction = TrendDirectionNeutral
	return detector
}
// Seed stores value as the first entry of the window. It has no effect once
// the window already holds at least one value.
func (t *TrendDetector) Seed(value int64) {
	if len(t.values) == 0 {
		t.values = append(t.values, value)
	}
}
// AddValue records a new sample: it updates the sample counter and the
// lowest/highest statistics, pushes the value into the sliding window
// (dropping the oldest value once the window is full) and recomputes the
// trend direction.
//
// When CollapseValues is set, a value equal to the most recently stored one
// still counts towards the statistics but is not added to the window, and
// the direction is left as it was.
func (t *TrendDetector) AddValue(value int64) {
	t.numSamples++

	// The first sample initializes both extremes. Using the sample counter
	// instead of a zero sentinel keeps a genuine 0 (or negative) sample from
	// being overwritten by a later, larger value.
	if t.numSamples == 1 || value < t.lowestValue {
		t.lowestValue = value
	}
	if t.numSamples == 1 || value > t.highestValue {
		t.highestValue = value
	}

	// ignore duplicate values
	if t.params.CollapseValues && len(t.values) != 0 && t.values[len(t.values)-1] == value {
		return
	}

	// Drop the oldest value once the window is full. The emptiness check
	// prevents slicing an empty window, which would panic when
	// RequiredSamples is zero or negative.
	if n := len(t.values); n > 0 && n >= t.params.RequiredSamples {
		t.values = t.values[1:]
	}

	t.values = append(t.values, value)
	t.updateDirection()
}
// GetLowest returns the lowest-value statistic maintained by AddValue.
func (t *TrendDetector) GetLowest() int64 {
return t.lowestValue
}
// GetHighest returns the highest-value statistic maintained by AddValue.
func (t *TrendDetector) GetHighest() int64 {
return t.highestValue
}
// GetValues returns the current window of values. The internal slice is
// returned directly, without copying.
func (t *TrendDetector) GetValues() []int64 {
return t.values
}
// GetDirection returns the trend direction computed when the window was
// last updated.
func (t *TrendDetector) GetDirection() TrendDirection {
return t.direction
}
// ToString renders a debugging summary: start time, current time and the
// elapsed seconds, followed by the sample count, lowest and highest value,
// the window contents and the Kendall's Tau of the window.
func (t *TrendDetector) ToString() string {
	now := time.Now()

	timing := fmt.Sprintf(
		"t: %+v|%+v|%.2fs",
		t.startTime.Format(time.UnixDate),
		now.Format(time.UnixDate),
		now.Sub(t.startTime).Seconds(),
	)
	stats := fmt.Sprintf(
		", v: %d|%d|%d|%+v|%.2f",
		t.numSamples,
		t.lowestValue,
		t.highestValue,
		t.values,
		kendallsTau(t.values),
	)
	return timing + stats
}
// updateDirection recomputes the trend direction from the current window.
// The direction is neutral while the window holds fewer than RequiredSamples
// values; otherwise a positive Kendall's Tau means upward and a Tau below
// DownwardTrendThreshold means downward.
func (t *TrendDetector) updateDirection() {
	next := TrendDirectionNeutral

	if len(t.values) >= t.params.RequiredSamples {
		// using Kendall's Tau to find trend; the upward check takes precedence
		tau := kendallsTau(t.values)
		if tau > 0 {
			next = TrendDirectionUpward
		} else if tau < t.params.DownwardTrendThreshold {
			next = TrendDirectionDownward
		}
	}

	t.direction = next
}
// ------------------------------------------------
// kendallsTau compares every pair of values in order of appearance and
// returns (concordant - discordant) / (concordant + discordant), where a
// pair is concordant when the later value is larger and discordant when it
// is smaller. Equal pairs are ignored; 0 is returned when no pair differs.
func kendallsTau(values []int64) float64 {
	var rising, falling int

	for i, earlier := range values {
		for _, later := range values[i+1:] {
			switch {
			case earlier < later:
				rising++
			case earlier > later:
				falling++
			}
		}
	}

	if rising+falling == 0 {
		return 0.0
	}

	up, down := float64(rising), float64(falling)
	return (up - down) / (up + down)
}