Bandwidth estimator refactor (#46)

* rewrite to use stream selector.

* WIP.

* add nacks to metrics.

* add estimate trend.

* estimator based on trend detector.

* add estimator unstable duration.

* add estimator debug.

* add stalled duration.

* estimator move values to config.

* change default estimator values.

* minor style changes.

* fix websocket video messages.

* replace video track with ivdeo id.
This commit is contained in:
Miroslav Šedivý 2023-05-15 19:29:39 +02:00 committed by GitHub
parent 8660c1a256
commit 3e8d686c0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 845 additions and 775 deletions

View File

@ -1,145 +0,0 @@
package buckets
import (
"errors"
"sort"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
)
type BucketsManagerCtx struct {
logger zerolog.Logger
codec codec.RTPCodec
streams map[string]types.StreamSinkManager
streamIDs []string
}
func BucketsNew(codec codec.RTPCodec, streams map[string]types.StreamSinkManager, streamIDs []string) *BucketsManagerCtx {
logger := log.With().
Str("module", "capture").
Str("submodule", "buckets").
Logger()
return &BucketsManagerCtx{
logger: logger,
codec: codec,
streams: streams,
streamIDs: streamIDs,
}
}
func (manager *BucketsManagerCtx) Shutdown() {
manager.logger.Info().Msgf("shutdown")
manager.DestroyAll()
}
func (manager *BucketsManagerCtx) DestroyAll() {
for _, stream := range manager.streams {
if stream.Started() {
stream.DestroyPipeline()
}
}
}
func (manager *BucketsManagerCtx) RecreateAll() error {
for _, stream := range manager.streams {
if stream.Started() {
err := stream.CreatePipeline()
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
return err
}
}
}
return nil
}
func (manager *BucketsManagerCtx) IDs() []string {
return manager.streamIDs
}
func (manager *BucketsManagerCtx) Codec() codec.RTPCodec {
return manager.codec
}
func (manager *BucketsManagerCtx) SetReceiver(receiver types.Receiver) {
// bitrate history is per receiver
bitrateHistory := &queue{}
receiver.OnBitrateChange(func(peerBitrate int) (bool, error) {
bitrate := peerBitrate
if receiver.VideoAuto() {
bitrate = bitrateHistory.normaliseBitrate(bitrate)
}
stream := manager.findNearestStream(bitrate)
streamID := stream.ID()
// TODO: make this less noisy in logs
manager.logger.Debug().
Str("video_id", streamID).
Int("len", bitrateHistory.len()).
Int("peer_bitrate", peerBitrate).
Int("bitrate", bitrate).
Msg("change video bitrate")
return receiver.SetStream(stream)
})
receiver.OnVideoChange(func(videoID string) (bool, error) {
stream := manager.streams[videoID]
manager.logger.Info().
Str("video_id", videoID).
Msg("video change")
return receiver.SetStream(stream)
})
}
func (manager *BucketsManagerCtx) findNearestStream(peerBitrate int) types.StreamSinkManager {
type streamDiff struct {
id string
bitrateDiff int
}
sortDiff := func(a, b int) bool {
switch {
case a < 0 && b < 0:
return a > b
case a >= 0:
if b >= 0 {
return a <= b
}
return true
}
return false
}
var diffs []streamDiff
for _, stream := range manager.streams {
diffs = append(diffs, streamDiff{
id: stream.ID(),
bitrateDiff: peerBitrate - stream.Bitrate(),
})
}
sort.Slice(diffs, func(i, j int) bool {
return sortDiff(diffs[i].bitrateDiff, diffs[j].bitrateDiff)
})
bestDiff := diffs[0]
return manager.streams[bestDiff.id]
}
func (manager *BucketsManagerCtx) RemoveReceiver(receiver types.Receiver) error {
receiver.OnBitrateChange(nil)
receiver.OnVideoChange(nil)
receiver.RemoveStream()
return nil
}

View File

@ -1,83 +0,0 @@
package buckets
import (
"reflect"
"testing"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
)
func TestBucketsManagerCtx_FindNearestStream(t *testing.T) {
type fields struct {
codec codec.RTPCodec
streams map[string]types.StreamSinkManager
}
type args struct {
peerBitrate int
}
tests := []struct {
name string
fields fields
args args
want types.StreamSinkManager
}{
{
name: "findNearestStream",
fields: fields{
streams: map[string]types.StreamSinkManager{
"1": mockStreamSink{
id: "1",
bitrate: 500,
},
"2": mockStreamSink{
id: "2",
bitrate: 750,
},
"3": mockStreamSink{
id: "3",
bitrate: 1000,
},
"4": mockStreamSink{
id: "4",
bitrate: 1250,
},
"5": mockStreamSink{
id: "5",
bitrate: 1700,
},
},
},
args: args{
peerBitrate: 950,
},
want: mockStreamSink{
id: "2",
bitrate: 750,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := BucketsNew(tt.fields.codec, tt.fields.streams, []string{})
if got := m.findNearestStream(tt.args.peerBitrate); !reflect.DeepEqual(got, tt.want) {
t.Errorf("findNearestStream() = %v, want %v", got, tt.want)
}
})
}
}
type mockStreamSink struct {
id string
bitrate int
types.StreamSinkManager
}
func (m mockStreamSink) ID() string {
return m.id
}
func (m mockStreamSink) Bitrate() int {
return m.bitrate
}

View File

@ -1,88 +0,0 @@
package buckets
import (
"math"
"sync"
"time"
)
type queue struct {
sync.Mutex
q []elem
}
type elem struct {
created time.Time
bitrate int
}
func (q *queue) push(v elem) {
q.Lock()
defer q.Unlock()
// if the first element is older than 10 seconds, remove it
if len(q.q) > 0 && time.Since(q.q[0].created) > 10*time.Second {
q.q = q.q[1:]
}
q.q = append(q.q, v)
}
func (q *queue) len() int {
q.Lock()
defer q.Unlock()
return len(q.q)
}
func (q *queue) avg() int {
q.Lock()
defer q.Unlock()
if len(q.q) == 0 {
return 0
}
sum := 0
for _, v := range q.q {
sum += v.bitrate
}
return sum / len(q.q)
}
func (q *queue) avgLastN(n int) int {
if n <= 0 {
return q.avg()
}
q.Lock()
defer q.Unlock()
if len(q.q) == 0 {
return 0
}
sum := 0
for _, v := range q.q[len(q.q)-n:] {
sum += v.bitrate
}
return sum / n
}
func (q *queue) normaliseBitrate(currentBitrate int) int {
avgBitrate := float64(q.avg())
histLen := float64(q.len())
q.push(elem{
bitrate: currentBitrate,
created: time.Now(),
})
if avgBitrate == 0 || histLen == 0 || currentBitrate == 0 {
return currentBitrate
}
lastN := int(math.Floor(float64(currentBitrate) / avgBitrate * histLen))
if lastN > q.len() {
lastN = q.len()
}
if lastN == 0 {
return currentBitrate
}
return q.avgLastN(lastN)
}

View File

@ -1,99 +0,0 @@
package buckets
import "testing"
func Queue_normaliseBitrate(t *testing.T) {
type fields struct {
queue *queue
}
type args struct {
currentBitrate int
}
tests := []struct {
name string
fields fields
args args
want []int
}{
{
name: "normaliseBitrate: big drop",
fields: fields{
queue: &queue{
q: []elem{
{bitrate: 900},
{bitrate: 750},
{bitrate: 780},
{bitrate: 1100},
{bitrate: 950},
{bitrate: 700},
{bitrate: 800},
{bitrate: 900},
{bitrate: 1000},
{bitrate: 1100},
// avg = 898
},
},
},
args: args{
currentBitrate: 350,
},
want: []int{816, 700, 537, 350, 350},
}, {
name: "normaliseBitrate: small drop",
fields: fields{
queue: &queue{
q: []elem{
{bitrate: 900},
{bitrate: 750},
{bitrate: 780},
{bitrate: 1100},
{bitrate: 950},
{bitrate: 700},
{bitrate: 800},
{bitrate: 900},
{bitrate: 1000},
{bitrate: 1100},
// avg = 898
},
},
},
args: args{
currentBitrate: 700,
},
want: []int{878, 842, 825, 825, 812, 787, 750, 700},
}, {
name: "normaliseBitrate",
fields: fields{
queue: &queue{
q: []elem{
{bitrate: 900},
{bitrate: 750},
{bitrate: 780},
{bitrate: 1100},
{bitrate: 950},
{bitrate: 700},
{bitrate: 800},
{bitrate: 900},
{bitrate: 1000},
{bitrate: 1100},
// avg = 898
},
},
},
args: args{
currentBitrate: 1350,
},
want: []int{943, 1003, 1060, 1085},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := tt.fields.queue
for i := 0; i < len(tt.want); i++ {
if got := m.normaliseBitrate(tt.args.currentBitrate); got != tt.want[i] {
t.Errorf("normaliseBitrate() [%d] = %v, want %v", i, got, tt.want[i])
}
}
})
}
}

View File

@ -8,7 +8,6 @@ import (
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/internal/capture/buckets"
"github.com/demodesk/neko/internal/config"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
@ -23,7 +22,7 @@ type CaptureManagerCtx struct {
broadcast *BroacastManagerCtx
screencast *ScreencastManagerCtx
audio *StreamSinkManagerCtx
video types.BucketsManager
video *StreamSelectorManagerCtx
// sources
webcam *StreamSrcManagerCtx
@ -68,13 +67,8 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt
Str("pipeline", pipeline).
Msg("syntax check for video stream pipeline passed")
getVideoBitrate := pipelineConf.GetBitrateFn(desktop.GetScreenSize)
if _, err = getVideoBitrate(); err != nil {
logger.Panic().Err(err).Msg("unable to get video bitrate")
}
// append to videos
videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id, getVideoBitrate)
videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id)
}
return &CaptureManagerCtx{
@ -140,8 +134,8 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt
"! %s "+
"! appsink name=appsink", config.AudioDevice, config.AudioCodec.Pipeline,
), nil
}, "audio", nil),
video: buckets.BucketsNew(config.VideoCodec, videos, config.VideoIDs),
}, "audio"),
video: streamSelectorNew(config.VideoCodec, videos, config.VideoIDs),
// sources
webcam: streamSrcNew(config.WebcamEnabled, map[string]string{
@ -202,7 +196,7 @@ func (manager *CaptureManagerCtx) Start() {
}
manager.desktop.OnBeforeScreenSizeChange(func() {
manager.video.DestroyAll()
manager.video.destroyPipelines()
if manager.broadcast.Started() {
manager.broadcast.destroyPipeline()
@ -214,7 +208,7 @@ func (manager *CaptureManagerCtx) Start() {
})
manager.desktop.OnAfterScreenSizeChange(func() {
err := manager.video.RecreateAll()
err := manager.video.recreatePipelines()
if err != nil {
manager.logger.Panic().Err(err).Msg("unable to recreate video pipelines")
}
@ -242,7 +236,7 @@ func (manager *CaptureManagerCtx) Shutdown() error {
manager.screencast.shutdown()
manager.audio.shutdown()
manager.video.Shutdown()
manager.video.shutdown()
manager.webcam.shutdown()
manager.microphone.shutdown()
@ -250,15 +244,6 @@ func (manager *CaptureManagerCtx) Shutdown() error {
return nil
}
func (manager *CaptureManagerCtx) GetBitrateFromVideoID(videoID string) (int, error) {
cfg, ok := manager.config.VideoPipelines[videoID]
if !ok {
return 0, fmt.Errorf("video config not found for %s", videoID)
}
return cfg.GetBitrateFn(manager.desktop.GetScreenSize)()
}
func (manager *CaptureManagerCtx) Broadcast() types.BroadcastManager {
return manager.broadcast
}
@ -271,7 +256,7 @@ func (manager *CaptureManagerCtx) Audio() types.StreamSinkManager {
return manager.audio
}
func (manager *CaptureManagerCtx) Video() types.BucketsManager {
func (manager *CaptureManagerCtx) Video() types.StreamSelectorManager {
return manager.video
}

View File

@ -0,0 +1,206 @@
package capture
import (
"errors"
"sort"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
)
type StreamSelectorManagerCtx struct {
logger zerolog.Logger
codec codec.RTPCodec
streams map[string]types.StreamSinkManager
streamIDs []string
}
func streamSelectorNew(codec codec.RTPCodec, streams map[string]types.StreamSinkManager, streamIDs []string) *StreamSelectorManagerCtx {
logger := log.With().
Str("module", "capture").
Str("submodule", "stream-selector").
Logger()
return &StreamSelectorManagerCtx{
logger: logger,
codec: codec,
streams: streams,
streamIDs: streamIDs,
}
}
func (manager *StreamSelectorManagerCtx) shutdown() {
manager.logger.Info().Msgf("shutdown")
manager.destroyPipelines()
}
func (manager *StreamSelectorManagerCtx) destroyPipelines() {
for _, stream := range manager.streams {
if stream.Started() {
stream.DestroyPipeline()
}
}
}
func (manager *StreamSelectorManagerCtx) recreatePipelines() error {
for _, stream := range manager.streams {
if stream.Started() {
err := stream.CreatePipeline()
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
return err
}
}
}
return nil
}
func (manager *StreamSelectorManagerCtx) IDs() []string {
return manager.streamIDs
}
func (manager *StreamSelectorManagerCtx) Codec() codec.RTPCodec {
return manager.codec
}
func (manager *StreamSelectorManagerCtx) GetStream(selector types.StreamSelector) (types.StreamSinkManager, bool) {
// select stream by ID
if selector.ID != "" {
// select lower stream
if selector.Type == types.StreamSelectorTypeLower {
var lastStream types.StreamSinkManager
for i := len(manager.streamIDs) - 1; i >= 0; i-- {
streamID := manager.streamIDs[i]
if streamID == selector.ID {
return lastStream, lastStream != nil
}
stream, ok := manager.streams[streamID]
if ok {
lastStream = stream
}
}
// we couldn't find a lower stream
return nil, false
}
// select higher stream
if selector.Type == types.StreamSelectorTypeHigher {
var lastStream types.StreamSinkManager
for _, streamID := range manager.streamIDs {
if streamID == selector.ID {
return lastStream, lastStream != nil
}
stream, ok := manager.streams[streamID]
if ok {
lastStream = stream
}
}
// we couldn't find a higher stream
return nil, false
}
// select exact stream
stream, ok := manager.streams[selector.ID]
return stream, ok
}
// select stream by bitrate
if selector.Bitrate != 0 {
// select stream by nearest bitrate
if selector.Type == types.StreamSelectorTypeNearest {
return manager.nearestBitrate(selector.Bitrate), true
}
// select lower stream
if selector.Type == types.StreamSelectorTypeLower {
// start from the highest stream, and go down, until we find a lower stream
for i := len(manager.streamIDs) - 1; i >= 0; i-- {
streamID := manager.streamIDs[i]
stream := manager.streams[streamID]
// if stream should be considered in calculation
considered := stream.Bitrate() != 0 && stream.Started()
if considered && stream.Bitrate() < selector.Bitrate {
return stream, true
}
}
// we couldn't find a lower stream
return nil, false
}
// select higher stream
if selector.Type == types.StreamSelectorTypeHigher {
// start from the lowest stream, and go up, until we find a higher stream
for _, streamID := range manager.streamIDs {
stream := manager.streams[streamID]
// if stream should be considered in calculation
considered := stream.Bitrate() != 0 && stream.Started()
if considered && stream.Bitrate() > selector.Bitrate {
return stream, true
}
}
// we couldn't find a higher stream
return nil, false
}
// select stream by exact bitrate
for _, stream := range manager.streams {
if stream.Bitrate() == selector.Bitrate {
return stream, true
}
}
}
// we couldn't find a stream
return nil, false
}
// TODO: This is a very naive implementation, we should use a binary search instead.
func (manager *StreamSelectorManagerCtx) nearestBitrate(bitrate uint64) types.StreamSinkManager {
type streamDiff struct {
id string
bitrateDiff int
}
sortDiff := func(a, b int) bool {
switch {
case a < 0 && b < 0:
return a > b
case a >= 0:
if b >= 0 {
return a <= b
}
return true
}
return false
}
var diffs []streamDiff
for _, stream := range manager.streams {
// if stream should be considered in calculation
considered := stream.Bitrate() != 0 && stream.Started()
if !considered {
continue
}
diffs = append(diffs, streamDiff{
id: stream.ID(),
bitrateDiff: int(bitrate) - int(stream.Bitrate()),
})
}
// no streams available
if len(diffs) == 0 {
// return first (lowest) stream
return manager.streams[manager.streamIDs[0]]
}
sort.Slice(diffs, func(i, j int) bool {
return sortDiff(diffs[i].bitrateDiff, diffs[j].bitrateDiff)
})
bestDiff := diffs[0]
return manager.streams[bestDiff.id]
}

View File

@ -21,9 +21,10 @@ import (
var moveSinkListenerMu = sync.Mutex{}
type StreamSinkManagerCtx struct {
id string
getBitrate func() (int, error)
waitForKf bool // wait for a keyframe before sending samples
id string
// wait for a keyframe before sending samples
waitForKf bool
bitrate uint64 // atomic
brBuckets map[int]float64
@ -48,22 +49,23 @@ type StreamSinkManagerCtx struct {
pipelinesActive prometheus.Gauge
}
func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id string, getBitrate func() (int, error)) *StreamSinkManagerCtx {
func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), id string) *StreamSinkManagerCtx {
logger := log.With().
Str("module", "capture").
Str("submodule", "stream-sink").
Str("id", id).Logger()
manager := &StreamSinkManagerCtx{
id: id,
getBitrate: getBitrate,
// only wait for keyframes if the codec is video
waitForKf: c.IsVideo(),
id: id,
// only wait for keyframes if the codec is video
waitForKf: codec.IsVideo(),
bitrate: 0,
brBuckets: map[int]float64{},
logger: logger,
codec: c,
codec: codec,
pipelineFn: pipelineFn,
listeners: map[uintptr]types.SampleListener{},
@ -77,8 +79,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin
Help: "Current number of listeners for a pipeline.",
ConstLabels: map[string]string{
"video_id": id,
"codec_name": c.Name,
"codec_type": c.Type.String(),
"codec_name": codec.Name,
"codec_type": codec.Type.String(),
},
}),
totalBytes: promauto.NewCounter(prometheus.CounterOpts{
@ -88,8 +90,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin
Help: "Total number of bytes created by the pipeline.",
ConstLabels: map[string]string{
"video_id": id,
"codec_name": c.Name,
"codec_type": c.Type.String(),
"codec_name": codec.Name,
"codec_type": codec.Type.String(),
},
}),
pipelinesCounter: promauto.NewCounter(prometheus.CounterOpts{
@ -100,8 +102,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin
ConstLabels: map[string]string{
"submodule": "streamsink",
"video_id": id,
"codec_name": c.Name,
"codec_type": c.Type.String(),
"codec_name": codec.Name,
"codec_type": codec.Type.String(),
},
}),
pipelinesActive: promauto.NewGauge(prometheus.GaugeOpts{
@ -112,8 +114,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin
ConstLabels: map[string]string{
"submodule": "streamsink",
"video_id": id,
"codec_name": c.Name,
"codec_type": c.Type.String(),
"codec_name": codec.Name,
"codec_type": codec.Type.String(),
},
}),
}
@ -141,27 +143,8 @@ func (manager *StreamSinkManagerCtx) ID() string {
return manager.id
}
func (manager *StreamSinkManagerCtx) Bitrate() int {
// TODO: fix bitrate switching calculation
// return real bitrate if available
//realBitrate := atomic.LoadUint64(&manager.bitrate)
//if realBitrate != 0 {
// return int(realBitrate)
//}
// if we do not have function to estimate bitrate, return 0
if manager.getBitrate == nil {
return 0
}
// recalculate bitrate every time, take screen resolution (and fps) into account
// we called this function during startup, so it shouldn't error here
bitrate, err := manager.getBitrate()
if err != nil {
manager.logger.Err(err).Msg("unexpected error while getting bitrate")
}
return bitrate
func (manager *StreamSinkManagerCtx) Bitrate() uint64 {
return atomic.LoadUint64(&manager.bitrate)
}
func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec {

View File

@ -3,6 +3,7 @@ package config
import (
"strconv"
"strings"
"time"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
@ -15,6 +16,28 @@ import (
// default stun server
const defStunSrv = "stun:stun.l.google.com:19302"
type WebRTCEstimator struct {
Enabled bool
Passive bool
Debug bool
InitialBitrate int
// how often to read and process bandwidth estimation reports
ReadInterval time.Duration
// how long to wait for stable connection (only neutral or upward trend) before upgrading
StableDuration time.Duration
// how long to wait for unstable connection (downward trend) before downgrading
UnstableDuration time.Duration
// how long to wait for stalled connection (neutral trend with low bandwidth) before downgrading
StalledDuration time.Duration
// how long to wait before downgrading again after previous downgrade
DowngradeBackoff time.Duration
// how long to wait before upgrading again after previous upgrade
UpgradeBackoff time.Duration
// how bigger the difference between estimated and stream bitrate must be to trigger upgrade/downgrade
DiffThreshold float64
}
type WebRTC struct {
ICELite bool
ICETrickle bool
@ -28,9 +51,7 @@ type WebRTC struct {
NAT1To1IPs []string
IpRetrievalUrl string
EstimatorEnabled bool
EstimatorPassive bool
EstimatorInitialBitrate int
Estimator WebRTCEstimator
}
func (WebRTC) Init(cmd *cobra.Command) error {
@ -96,11 +117,51 @@ func (WebRTC) Init(cmd *cobra.Command) error {
return err
}
cmd.PersistentFlags().Bool("webrtc.estimator.debug", false, "enables debug logging for the bandwidth estimator")
if err := viper.BindPFlag("webrtc.estimator.debug", cmd.PersistentFlags().Lookup("webrtc.estimator.debug")); err != nil {
return err
}
cmd.PersistentFlags().Int("webrtc.estimator.initial_bitrate", 1_000_000, "initial bitrate for the bandwidth estimator")
if err := viper.BindPFlag("webrtc.estimator.initial_bitrate", cmd.PersistentFlags().Lookup("webrtc.estimator.initial_bitrate")); err != nil {
return err
}
cmd.PersistentFlags().Duration("webrtc.estimator.read_interval", 2*time.Second, "how often to read and process bandwidth estimation reports")
if err := viper.BindPFlag("webrtc.estimator.read_interval", cmd.PersistentFlags().Lookup("webrtc.estimator.read_interval")); err != nil {
return err
}
cmd.PersistentFlags().Duration("webrtc.estimator.stable_duration", 12*time.Second, "how long to wait for stable connection (upward or neutral trend) before upgrading")
if err := viper.BindPFlag("webrtc.estimator.stable_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.stable_duration")); err != nil {
return err
}
cmd.PersistentFlags().Duration("webrtc.estimator.unstable_duration", 6*time.Second, "how long to wait for stalled connection (neutral trend with low bandwidth) before downgrading")
if err := viper.BindPFlag("webrtc.estimator.unstable_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.unstable_duration")); err != nil {
return err
}
cmd.PersistentFlags().Duration("webrtc.estimator.stalled_duration", 24*time.Second, "how long to wait for stalled bandwidth estimation before downgrading")
if err := viper.BindPFlag("webrtc.estimator.stalled_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.stalled_duration")); err != nil {
return err
}
cmd.PersistentFlags().Duration("webrtc.estimator.downgrade_backoff", 10*time.Second, "how long to wait before downgrading again after previous downgrade")
if err := viper.BindPFlag("webrtc.estimator.downgrade_backoff", cmd.PersistentFlags().Lookup("webrtc.estimator.downgrade_backoff")); err != nil {
return err
}
cmd.PersistentFlags().Duration("webrtc.estimator.upgrade_backoff", 5*time.Second, "how long to wait before upgrading again after previous upgrade")
if err := viper.BindPFlag("webrtc.estimator.upgrade_backoff", cmd.PersistentFlags().Lookup("webrtc.estimator.upgrade_backoff")); err != nil {
return err
}
cmd.PersistentFlags().Float64("webrtc.estimator.diff_threshold", 0.15, "how bigger the difference between estimated and stream bitrate must be to trigger upgrade/downgrade")
if err := viper.BindPFlag("webrtc.estimator.diff_threshold", cmd.PersistentFlags().Lookup("webrtc.estimator.diff_threshold")); err != nil {
return err
}
return nil
}
@ -197,7 +258,15 @@ func (s *WebRTC) Set() {
// bandwidth estimator
s.EstimatorEnabled = viper.GetBool("webrtc.estimator.enabled")
s.EstimatorPassive = viper.GetBool("webrtc.estimator.passive")
s.EstimatorInitialBitrate = viper.GetInt("webrtc.estimator.initial_bitrate")
s.Estimator.Enabled = viper.GetBool("webrtc.estimator.enabled")
s.Estimator.Passive = viper.GetBool("webrtc.estimator.passive")
s.Estimator.Debug = viper.GetBool("webrtc.estimator.debug")
s.Estimator.InitialBitrate = viper.GetInt("webrtc.estimator.initial_bitrate")
s.Estimator.ReadInterval = viper.GetDuration("webrtc.estimator.read_interval")
s.Estimator.StableDuration = viper.GetDuration("webrtc.estimator.stable_duration")
s.Estimator.UnstableDuration = viper.GetDuration("webrtc.estimator.unstable_duration")
s.Estimator.StalledDuration = viper.GetDuration("webrtc.estimator.stalled_duration")
s.Estimator.DowngradeBackoff = viper.GetDuration("webrtc.estimator.downgrade_backoff")
s.Estimator.UpgradeBackoff = viper.GetDuration("webrtc.estimator.upgrade_backoff")
s.Estimator.DiffThreshold = viper.GetFloat64("webrtc.estimator.diff_threshold")
}

View File

@ -24,6 +24,7 @@ import (
"github.com/demodesk/neko/pkg/types/codec"
"github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message"
"github.com/demodesk/neko/pkg/utils"
)
const (
@ -167,7 +168,7 @@ func (manager *WebRTCManagerCtx) ICEServers() []types.ICEServer {
return manager.config.ICEServersFrontend
}
func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs []codec.RTPCodec, bitrate int) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) {
func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs []codec.RTPCodec) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) {
// create media engine
engine := &webrtc.MediaEngine{}
for _, codec := range codecs {
@ -223,14 +224,10 @@ func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs
// create bandwidth estimator
estimatorChan := make(chan cc.BandwidthEstimator, 1)
if manager.config.EstimatorEnabled {
if manager.config.Estimator.Enabled {
congestionController, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) {
if bitrate == 0 {
bitrate = manager.config.EstimatorInitialBitrate
}
return gcc.NewSendSideBWE(
gcc.SendSideBWEInitialBitrate(bitrate),
gcc.SendSideBWEInitialBitrate(manager.config.Estimator.InitialBitrate),
gcc.SendSideBWEPacer(gcc.NewNoOpPacer()),
)
})
@ -268,7 +265,7 @@ func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs
return connection, <-estimatorChan, err
}
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int, videoAuto bool) (*webrtc.SessionDescription, error) {
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session) (*webrtc.SessionDescription, types.WebRTCPeer, error) {
id := atomic.AddInt32(&manager.peerId, 1)
// get metrics for session
@ -287,12 +284,10 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
video := manager.capture.Video()
videoCodec := video.Codec()
connection, estimator, err := manager.newPeerConnection(logger, []codec.RTPCodec{
audioCodec,
videoCodec,
}, bitrate)
connection, estimator, err := manager.newPeerConnection(
logger, []codec.RTPCodec{audioCodec, videoCodec})
if err != nil {
return nil, err
return nil, nil, err
}
// asynchronously send local ICE Candidates
@ -311,47 +306,34 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
})
}
// if bitrate is 0, and estimator is enabled, use estimator bitrate
if bitrate == 0 && estimator != nil {
bitrate = estimator.GetTargetBitrate()
}
// audio track
audioTrack, err := NewTrack(logger, audioCodec, connection)
if err != nil {
return nil, err
return nil, nil, err
}
// set stream for audio track
_, err = audioTrack.SetStream(audio)
if err != nil {
return nil, err
return nil, nil, err
}
// if estimator is disabled, or in passive mode, disable video auto bitrate
if !manager.config.EstimatorEnabled || manager.config.EstimatorPassive {
videoAuto = false
}
videoRtcp := make(chan []rtcp.Packet, 1)
// video track
videoTrack, err := NewTrack(logger, videoCodec, connection,
WithVideoAuto(videoAuto),
WithRtcpChan(videoRtcp),
)
videoRtcp := make(chan []rtcp.Packet, 1)
videoTrack, err := NewTrack(logger, videoCodec, connection, WithRtcpChan(videoRtcp))
if err != nil {
return nil, err
return nil, nil, err
}
// let video stream bucket manager handle stream subscriptions
video.SetReceiver(videoTrack)
//
// stream for video track will be set later
//
// data channel
dataChannel, err := connection.CreateDataChannel("data", nil)
if err != nil {
return nil, err
return nil, nil, err
}
peer := &WebRTCPeerCtx{
@ -359,24 +341,29 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
session: session,
metrics: metrics,
connection: connection,
estimator: estimator,
// bandwidth estimator
estimator: estimator,
estimateTrend: utils.NewTrendDetector(
utils.TrendDetectorParams{
// Probing
//RequiredSamples: 3,
//DownwardTrendThreshold: 0.0,
//CollapseValues: false,
// Non-Probing
RequiredSamples: 8,
DownwardTrendThreshold: -0.5,
CollapseValues: true,
}),
// stream selectors
videoSelector: manager.capture.Video(),
// tracks & channels
audioTrack: audioTrack,
videoTrack: videoTrack,
dataChannel: dataChannel,
rtcpChannel: videoRtcp,
// config
iceTrickle: manager.config.ICETrickle,
estimatorPassive: manager.config.EstimatorPassive,
}
logger.Info().
Int("target_bitrate", bitrate).
Msg("estimated initial peer bitrate")
// set initial video bitrate
if err := peer.SetVideoBitrate(bitrate); err != nil {
return nil, err
iceTrickle: manager.config.ICETrickle,
estimatorConfig: manager.config.Estimator,
}
connection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
@ -492,9 +479,9 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
// ensure we only run this once
once.Do(func() {
session.SetWebRTCConnected(peer, false)
if err = video.RemoveReceiver(videoTrack); err != nil {
logger.Err(err).Msg("failed to remove video receiver")
}
//
// TODO: Shutdown peer?
//
audioTrack.Shutdown()
videoTrack.Shutdown()
close(videoRtcp)
@ -542,7 +529,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
offer, err := peer.CreateOffer(false)
if err != nil {
return nil, err
return nil, nil, err
}
// on negotiation needed handler must be registered after creating initial
@ -576,7 +563,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
// start estimator reader
go peer.estimatorReader()
return offer, nil
return offer, peer, nil
}
func (manager *WebRTCManagerCtx) SetCursorPosition(x, y int) {

View File

@ -121,7 +121,7 @@ func (m *metricsManager) getBySession(session types.Session) *metrics {
Name: "receiver_estimated_maximum_bitrate",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Receiver Estimated Maximum Bitrate from SCTP.",
Help: "Receiver Estimated Maximum Bitrate from RTCP.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
@ -140,7 +140,7 @@ func (m *metricsManager) getBySession(session types.Session) *metrics {
Name: "receiver_report_delay",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Receiver Report Delay from SCTP, expressed in units of 1/65536 seconds.",
Help: "Receiver Report Delay from RTCP, expressed in units of 1/65536 seconds.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
@ -149,7 +149,7 @@ func (m *metricsManager) getBySession(session types.Session) *metrics {
Name: "receiver_report_jitter",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Receiver Report Jitter from SCTP.",
Help: "Receiver Report Jitter from RTCP.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
@ -158,7 +158,17 @@ func (m *metricsManager) getBySession(session types.Session) *metrics {
Name: "receiver_report_total_lost",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Receiver Report Total Lost from SCTP.",
Help: "Receiver Report Total Lost from RTCP.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
transportLayerNacks: promauto.NewCounter(prometheus.CounterOpts{
Name: "transport_layer_nacks",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Transport Layer NACKs from RTCP.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
@ -236,6 +246,8 @@ type metrics struct {
receiverReportJitter prometheus.Gauge
receiverReportTotalLost prometheus.Gauge
transportLayerNacks prometheus.Counter
iceBytesSent prometheus.Gauge
iceBytesReceived prometheus.Gauge
sctpBytesSent prometheus.Gauge
@ -386,6 +398,11 @@ func (met *metrics) rtcpReceiver(rtcpCh chan []rtcp.Packet) {
// use only last report
met.SetReceiverReport(rtcpPacket.Reports[l-1])
}
case *rtcp.TransportLayerNack:
for _, pair := range rtcpPacket.Nacks {
packetList := pair.PacketList()
met.transportLayerNacks.Add(float64(len(packetList)))
}
}
}
}

View File

@ -11,15 +11,12 @@ import (
"github.com/pion/webrtc/v3"
"github.com/rs/zerolog"
"github.com/demodesk/neko/internal/config"
"github.com/demodesk/neko/internal/webrtc/payload"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message"
)
const (
// how often to read and process bandwidth estimation reports
estimatorReadInterval = 250 * time.Millisecond
"github.com/demodesk/neko/pkg/utils"
)
type WebRTCPeerCtx struct {
@ -28,15 +25,20 @@ type WebRTCPeerCtx struct {
session types.Session
metrics *metrics
connection *webrtc.PeerConnection
estimator cc.BandwidthEstimator
// bandwidth estimator
estimator cc.BandwidthEstimator
estimateTrend *utils.TrendDetector
// stream selectors
videoSelector types.StreamSelectorManager
// tracks & channels
audioTrack *Track
videoTrack *Track
dataChannel *webrtc.DataChannel
rtcpChannel chan []rtcp.Packet
// config
iceTrickle bool
estimatorPassive bool
iceTrickle bool
estimatorConfig config.WebRTCEstimator
videoAuto bool
}
//
@ -102,6 +104,7 @@ func (peer *WebRTCPeerCtx) SetCandidate(candidate webrtc.ICECandidateInit) error
return peer.connection.AddICECandidate(candidate)
}
// TODO: Add shutdown function?
func (peer *WebRTCPeerCtx) Destroy() {
peer.mu.Lock()
defer peer.mu.Unlock()
@ -111,32 +114,186 @@ func (peer *WebRTCPeerCtx) Destroy() {
}
func (peer *WebRTCPeerCtx) estimatorReader() {
conf := peer.estimatorConfig
// if estimator is not in debug mode, use a nop logger
var debugLogger zerolog.Logger
if conf.Debug {
debugLogger = peer.logger.With().Str("component", "estimator").Logger().Level(zerolog.DebugLevel)
} else {
debugLogger = zerolog.Nop()
}
// if estimator is disabled, do nothing
if peer.estimator == nil {
return
}
// use a ticker to get current client target bitrate
ticker := time.NewTicker(estimatorReadInterval)
ticker := time.NewTicker(conf.ReadInterval)
defer ticker.Stop()
// since when is the estimate stable/unstable
stableSince := time.Now() // we asume stable at start
unstableSince := time.Time{}
// since when are we neutral but cannot accomodate current bitrate
// we migt be stalled or estimator just reached zer (very bad connection)
stalledSince := time.Time{}
// when was the last upgrade/downgrade
lastUpgradeTime := time.Time{}
lastDowngradeTime := time.Time{}
for range ticker.C {
targetBitrate := peer.estimator.GetTargetBitrate()
peer.metrics.SetReceiverEstimatedTargetBitrate(float64(targetBitrate))
// if peer connection is closed, stop reading
if peer.connection.ConnectionState() == webrtc.PeerConnectionStateClosed {
break
}
if !peer.videoTrack.VideoAuto() {
// if estimation is disabled, do nothing
if !peer.videoAuto || conf.Passive {
continue
}
if !peer.estimatorPassive {
err := peer.SetVideoBitrate(targetBitrate)
if err != nil {
peer.logger.Warn().Err(err).Msg("failed to set video bitrate")
// get trend direction to decide if we should upgrade or downgrade
peer.estimateTrend.AddValue(int64(targetBitrate))
direction := peer.estimateTrend.GetDirection()
// get current stream bitrate
stream, ok := peer.videoTrack.Stream()
if !ok {
debugLogger.Warn().Msg("looks like we don't have a stream yet, skipping bitrate estimation")
continue
}
// if stream bitrate is 0, we need to wait for some time until we get a valid value
streamId, streamBitrate := stream.ID(), stream.Bitrate()
if streamBitrate == 0 {
debugLogger.Warn().Msg("looks like stream bitrate is 0, we need to wait for some time")
continue
}
// check whats the difference between target and stream bitrate
diff := float64(targetBitrate) / float64(streamBitrate)
debugLogger.Info().
Float64("diff", diff).
Int("target_bitrate", targetBitrate).
Uint64("stream_bitrate", streamBitrate).
Str("direction", direction.String()).
Msg("got bitrate from estimator")
// if we can accomodate current stream or we are not netural anymore,
// we are not stalled so we reset the stalled time
if direction != utils.TrendDirectionNeutral || diff > 1+conf.DiffThreshold {
stalledSince = time.Now()
}
// if we are neutral and stalled for too long, we might be congesting
stalled := direction == utils.TrendDirectionNeutral && time.Since(stalledSince) > conf.StalledDuration
if stalled {
debugLogger.Warn().
Time("stalled_since", stalledSince).
Msgf("it looks like we are stalled")
}
// if we have an downward trend or are stalled, we might be congesting
if direction == utils.TrendDirectionDownward || stalled {
// we reset the stable time because we are congesting
stableSince = time.Now()
// if we downgraded recently, we wait for some more time
if time.Since(lastDowngradeTime) < conf.DowngradeBackoff {
debugLogger.Debug().
Time("last_downgrade", lastDowngradeTime).
Msgf("downgraded recently, waiting for at least %v", conf.DowngradeBackoff)
continue
}
// if we are not unstable but we fluctuate we should wait for some more time
if time.Since(unstableSince) < conf.UnstableDuration {
debugLogger.Debug().
Time("unstable_since", unstableSince).
Msgf("we are not unstable long enough, waiting for at least %v", conf.UnstableDuration)
continue
}
// if we still have a big difference between target and stream bitrate, we wait for some more time
if conf.DiffThreshold >= 0 && diff > 1+conf.DiffThreshold {
debugLogger.Debug().
Float64("diff", diff).
Float64("threshold", conf.DiffThreshold).
Msgf("we still have a big difference between target and stream bitrate, " +
"therefore we still should be able to accomodate current stream")
continue
}
err := peer.SetVideo(types.StreamSelector{
ID: streamId,
Type: types.StreamSelectorTypeLower,
})
if err != nil && err != types.ErrWebRTCStreamNotFound {
peer.logger.Warn().Err(err).Msg("failed to downgrade video stream")
}
lastDowngradeTime = time.Now()
if err == types.ErrWebRTCStreamNotFound {
debugLogger.Info().Msg("looks like we are already on the lowest stream")
} else {
debugLogger.Info().Msg("downgraded video stream")
}
continue
}
// we reset the unstable time because we are not congesting
unstableSince = time.Now()
// if we have a neutral or upward trend, that means our estimate is stable
// if we are on the highest stream, we don't need to do anything
// but if there is a higher stream, we should try to upgrade and see if it works
// if we upgraded recently, we wait for some more time
if time.Since(lastUpgradeTime) < conf.UpgradeBackoff {
debugLogger.Debug().
Time("last_upgrade", lastUpgradeTime).
Msgf("upgraded recently, waiting for at least %v", conf.UpgradeBackoff)
continue
}
// if we are not stable for long enough, we wait for some more time
// because bandwidth estimation might fluctuate
if time.Since(stableSince) < conf.StableDuration {
debugLogger.Debug().
Time("stable_since", stableSince).
Msgf("we are not stable long enough, waiting for at least %v", conf.StableDuration)
continue
}
// upgrade only if estimated bitrate passed the threshold
if conf.DiffThreshold >= 0 && diff < 1+conf.DiffThreshold {
debugLogger.Debug().
Float64("diff", diff).
Float64("threshold", conf.DiffThreshold).
Msgf("looks like we don't have enough bitrate to accomodate higher stream, " +
"therefore we should wait for some more time")
continue
}
err := peer.SetVideo(types.StreamSelector{
ID: streamId,
Type: types.StreamSelectorTypeHigher,
})
if err != nil && err != types.ErrWebRTCStreamNotFound {
peer.logger.Warn().Err(err).Msg("failed to upgrade video stream")
}
lastUpgradeTime = time.Now()
if err == types.ErrWebRTCStreamNotFound {
debugLogger.Info().Msg("looks like we are already on the highest stream")
} else {
debugLogger.Info().Msg("upgraded video stream")
}
}
}
@ -145,88 +302,52 @@ func (peer *WebRTCPeerCtx) estimatorReader() {
// video
//
func (peer *WebRTCPeerCtx) SetVideoBitrate(peerBitrate int) error {
func (peer *WebRTCPeerCtx) SetVideo(selector types.StreamSelector) error {
peer.mu.Lock()
defer peer.mu.Unlock()
// when switching from manual to auto bitrate estimation, in case the estimator is
// idle (lastBitrate > maxBitrate), we want to go back to the previous estimated bitrate
if peerBitrate == 0 && peer.estimator != nil && !peer.estimatorPassive {
peerBitrate = peer.estimator.GetTargetBitrate()
peer.logger.Debug().
Int("peer_bitrate", peerBitrate).
Msg("evaluated bitrate")
// get requested video stream from selector
stream, ok := peer.videoSelector.GetStream(selector)
if !ok {
return types.ErrWebRTCStreamNotFound
}
changed, err := peer.videoTrack.SetBitrate(peerBitrate)
// set video stream to track
changed, err := peer.videoTrack.SetStream(stream)
if err != nil {
return err
}
// if video stream was already set, do nothing
if !changed {
// TODO: return error?
return nil
}
videoID := peer.videoTrack.stream.ID()
bitrate := peer.videoTrack.stream.Bitrate()
videoID := stream.ID()
peer.metrics.SetVideoID(videoID)
peer.logger.Debug().
Int("peer_bitrate", peerBitrate).
Int("video_bitrate", bitrate).
Str("video_id", videoID).
Msg("peer bitrate triggered video stream change")
peer.logger.Info().Str("video_id", videoID).Msg("set video")
go peer.session.Send(
event.SIGNAL_VIDEO,
message.SignalVideo{
Video: videoID,
Bitrate: bitrate,
VideoAuto: peer.videoTrack.VideoAuto(),
Video: videoID,
Auto: peer.videoAuto,
})
return nil
}
func (peer *WebRTCPeerCtx) SetVideoID(videoID string) error {
func (peer *WebRTCPeerCtx) VideoID() (string, bool) {
peer.mu.Lock()
defer peer.mu.Unlock()
changed, err := peer.videoTrack.SetVideoID(videoID)
if err != nil {
return err
stream, ok := peer.videoTrack.Stream()
if !ok {
return "", false
}
if !changed {
// TODO: return error?
return nil
}
bitrate := peer.videoTrack.stream.Bitrate()
peer.logger.Debug().
Str("video_id", videoID).
Int("video_bitrate", bitrate).
Msg("peer video id triggered video stream change")
go peer.session.Send(
event.SIGNAL_VIDEO,
message.SignalVideo{
Video: videoID,
Bitrate: bitrate,
VideoAuto: peer.videoTrack.VideoAuto(),
})
return nil
}
func (peer *WebRTCPeerCtx) GetVideoID() string {
peer.mu.Lock()
defer peer.mu.Unlock()
// TODO: Refactor.
return peer.videoTrack.stream.ID()
return stream.ID(), true
}
func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error {
@ -239,18 +360,32 @@ func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error {
return nil
}
func (peer *WebRTCPeerCtx) Paused() bool {
peer.mu.Lock()
defer peer.mu.Unlock()
return peer.videoTrack.Paused() || peer.audioTrack.Paused()
}
func (peer *WebRTCPeerCtx) SetVideoAuto(videoAuto bool) {
peer.mu.Lock()
defer peer.mu.Unlock()
// if estimator is enabled and is not passive, enable video auto bitrate
if peer.estimator != nil && !peer.estimatorPassive {
peer.videoTrack.SetVideoAuto(videoAuto)
if peer.estimator != nil && !peer.estimatorConfig.Passive {
peer.logger.Info().Bool("video_auto", videoAuto).Msg("set video auto")
peer.videoAuto = videoAuto
} else {
peer.logger.Warn().Msg("estimator is disabled or in passive mode, cannot change video auto")
peer.videoTrack.SetVideoAuto(false) // ensure video auto is disabled
peer.videoAuto = false // ensure video auto is disabled
}
}
func (peer *WebRTCPeerCtx) VideoAuto() bool {
return peer.videoTrack.VideoAuto()
peer.mu.Lock()
defer peer.mu.Unlock()
return peer.videoAuto
}
//

View File

@ -2,7 +2,6 @@ package webrtc
import (
"errors"
"fmt"
"io"
"sync"
@ -22,25 +21,13 @@ type Track struct {
rtcpCh chan []rtcp.Packet
sample chan types.Sample
videoAuto bool
videoAutoMu sync.RWMutex
paused bool
stream types.StreamSinkManager
streamMu sync.Mutex
bitrateChange func(int) (bool, error)
videoChange func(string) (bool, error)
}
type trackOption func(*Track)
func WithVideoAuto(auto bool) trackOption {
return func(t *Track) {
t.videoAuto = auto
}
}
func WithRtcpChan(rtcp chan []rtcp.Packet) trackOption {
return func(t *Track) {
t.rtcpCh = rtcp
@ -100,6 +87,8 @@ func (t *Track) rtcpReader(sender *webrtc.RTPSender) {
}
}
// --- sample ---
func (t *Track) sampleReader() {
for {
sample, ok := <-t.sample
@ -120,6 +109,12 @@ func (t *Track) sampleReader() {
}
}
func (t *Track) WriteSample(sample types.Sample) {
t.sample <- sample
}
// --- stream ---
func (t *Track) SetStream(stream types.StreamSinkManager) (bool, error) {
t.streamMu.Lock()
defer t.streamMu.Unlock()
@ -167,6 +162,15 @@ func (t *Track) RemoveStream() {
t.stream = nil
}
func (t *Track) Stream() (types.StreamSinkManager, bool) {
t.streamMu.Lock()
defer t.streamMu.Unlock()
return t.stream, t.stream != nil
}
// --- paused ---
func (t *Track) SetPaused(paused bool) {
t.streamMu.Lock()
defer t.streamMu.Unlock()
@ -190,42 +194,9 @@ func (t *Track) SetPaused(paused bool) {
t.paused = paused
}
func (t *Track) WriteSample(sample types.Sample) {
t.sample <- sample
}
func (t *Track) SetBitrate(bitrate int) (bool, error) {
if t.bitrateChange == nil {
return false, fmt.Errorf("bitrate change not supported")
}
return t.bitrateChange(bitrate)
}
func (t *Track) SetVideoID(videoID string) (bool, error) {
if t.videoChange == nil {
return false, fmt.Errorf("video change not supported")
}
return t.videoChange(videoID)
}
func (t *Track) OnBitrateChange(f func(bitrate int) (bool, error)) {
t.bitrateChange = f
}
func (t *Track) OnVideoChange(f func(string) (bool, error)) {
t.videoChange = f
}
func (t *Track) SetVideoAuto(auto bool) {
t.videoAutoMu.Lock()
defer t.videoAutoMu.Unlock()
t.videoAuto = auto
}
func (t *Track) VideoAuto() bool {
t.videoAutoMu.RLock()
defer t.videoAutoMu.RUnlock()
return t.videoAuto
func (t *Track) Paused() bool {
t.streamMu.Lock()
defer t.streamMu.Unlock()
return t.paused
}

View File

@ -20,28 +20,26 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
payload.Video = videos[0]
}
var err error
if payload.Bitrate == 0 {
// get bitrate from video id
payload.Bitrate, err = h.capture.GetBitrateFromVideoID(payload.Video)
if err != nil {
return err
}
}
offer, err := h.webrtc.CreatePeer(session, payload.Bitrate, payload.VideoAuto)
offer, peer, err := h.webrtc.CreatePeer(session)
if err != nil {
return err
}
if webrtcPeer := session.GetWebRTCPeer(); webrtcPeer != nil {
// set webrtc as paused if session has private mode enabled
if session.PrivateModeEnabled() {
webrtcPeer.SetPaused(true)
}
// set webrtc as paused if session has private mode enabled
if session.PrivateModeEnabled() {
peer.SetPaused(true)
}
payload.Video = webrtcPeer.GetVideoID()
payload.VideoAuto = webrtcPeer.VideoAuto()
// set video auto state
peer.SetVideoAuto(payload.Auto)
// set video stream
err = peer.SetVideo(types.StreamSelector{
ID: payload.Video,
Type: types.StreamSelectorTypeNearest,
})
if err != nil {
return err
}
session.Send(
@ -49,9 +47,6 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
message.SignalProvide{
SDP: offer.SDP,
ICEServers: h.webrtc.ICEServers(),
Video: payload.Video, // TODO: Refactor
Bitrate: payload.Bitrate,
VideoAuto: payload.VideoAuto,
})
return nil
@ -133,16 +128,13 @@ func (h *MessageHandlerCtx) signalVideo(session types.Session, payload *message.
return errors.New("webRTC peer does not exist")
}
peer.SetVideoAuto(payload.VideoAuto)
peer.SetVideoAuto(payload.Auto)
if payload.Video != "" {
if err := peer.SetVideoID(payload.Video); err != nil {
h.logger.Error().Err(err).Msg("failed to set video id")
}
} else {
if err := peer.SetVideoBitrate(payload.Bitrate); err != nil {
h.logger.Error().Err(err).Msg("failed to set video bitrate")
}
return peer.SetVideo(types.StreamSelector{
ID: payload.Video,
Type: types.StreamSelectorTypeNearest,
})
}
return nil

View File

@ -31,26 +31,6 @@ type SampleListener interface {
WriteSample(Sample)
}
type Receiver interface {
SetStream(stream StreamSinkManager) (changed bool, err error)
RemoveStream()
OnBitrateChange(f func(bitrate int) (changed bool, err error))
OnVideoChange(f func(videoID string) (changed bool, err error))
VideoAuto() bool
SetVideoAuto(videoAuto bool)
}
type BucketsManager interface {
IDs() []string
Codec() codec.RTPCodec
SetReceiver(receiver Receiver)
RemoveReceiver(receiver Receiver) error
DestroyAll()
RecreateAll() error
Shutdown()
}
type BroadcastManager interface {
Start(url string) error
Stop()
@ -64,10 +44,74 @@ type ScreencastManager interface {
Image() ([]byte, error)
}
type StreamSelectorType int
const (
// select exact stream
StreamSelectorTypeExact StreamSelectorType = iota
// select nearest stream (in either direction) if exact stream is not available
StreamSelectorTypeNearest
// if exact stream is found select the next lower stream, otherwise select the nearest lower stream
StreamSelectorTypeLower
// if exact stream is found select the next higher stream, otherwise select the nearest higher stream
StreamSelectorTypeHigher
)
func (s StreamSelectorType) String() string {
switch s {
case StreamSelectorTypeExact:
return "exact"
case StreamSelectorTypeNearest:
return "nearest"
case StreamSelectorTypeLower:
return "lower"
case StreamSelectorTypeHigher:
return "higher"
default:
return fmt.Sprintf("%d", int(s))
}
}
func (s *StreamSelectorType) UnmarshalText(text []byte) error {
switch strings.ToLower(string(text)) {
case "exact", "":
*s = StreamSelectorTypeExact
case "nearest":
*s = StreamSelectorTypeNearest
case "lower":
*s = StreamSelectorTypeLower
case "higher":
*s = StreamSelectorTypeHigher
default:
return fmt.Errorf("invalid stream selector type: %s", string(text))
}
return nil
}
func (s StreamSelectorType) MarshalText() ([]byte, error) {
return []byte(s.String()), nil
}
type StreamSelector struct {
// type of stream selector
Type StreamSelectorType
// select stream by its ID
ID string
// select stream by its bitrate
Bitrate uint64
}
type StreamSelectorManager interface {
IDs() []string
Codec() codec.RTPCodec
GetStream(selector StreamSelector) (StreamSinkManager, bool)
}
type StreamSinkManager interface {
ID() string
Codec() codec.RTPCodec
Bitrate() int
Bitrate() uint64
AddListener(listener SampleListener) error
RemoveListener(listener SampleListener) error
@ -94,12 +138,10 @@ type CaptureManager interface {
Start()
Shutdown() error
GetBitrateFromVideoID(videoID string) (int, error)
Broadcast() BroadcastManager
Screencast() ScreencastManager
Audio() StreamSinkManager
Video() BucketsManager
Video() StreamSelectorManager
Webcam() StreamSrcManager
Microphone() StreamSrcManager
@ -201,54 +243,3 @@ func (config *VideoConfig) GetPipeline(screen ScreenSize) (string, error) {
config.GstSuffix,
}[:], " "), nil
}
func (config *VideoConfig) GetBitrateFn(getScreen func() ScreenSize) func() (int, error) {
return func() (int, error) {
if config.Bitrate > 0 {
return config.Bitrate, nil
}
screen := getScreen()
values := map[string]any{
"width": screen.Width,
"height": screen.Height,
"fps": screen.Rate,
}
language := []gval.Language{
gval.Function("round", func(args ...any) (any, error) {
return (int)(math.Round(args[0].(float64))), nil
}),
}
// TOOD: do not read target-bitrate from pipeline, but only from config.
// TODO: This is only for vp8.
expr, ok := config.GstParams["target-bitrate"]
if !ok {
// TODO: This is only for h264.
expr, ok = config.GstParams["bitrate"]
if !ok {
return 0, fmt.Errorf("bitrate not found")
}
}
bitrate, err := gval.Evaluate(expr, values, language...)
if err != nil {
return 0, fmt.Errorf("failed to evaluate bitrate: %w", err)
}
var bitrateInt int
switch val := bitrate.(type) {
case int:
bitrateInt = val
case float64:
bitrateInt = (int)(val)
default:
return 0, fmt.Errorf("bitrate is not int or float64")
}
return bitrateInt, nil
}
}

View File

@ -48,10 +48,6 @@ type SystemDisconnect struct {
type SignalProvide struct {
SDP string `json:"sdp"`
ICEServers []types.ICEServer `json:"iceservers"`
// TODO: Use SignalVideo struct.
Video string `json:"video"`
Bitrate int `json:"bitrate"`
VideoAuto bool `json:"video_auto"`
}
type SignalCandidate struct {
@ -63,9 +59,8 @@ type SignalDescription struct {
}
type SignalVideo struct {
Video string `json:"video"`
Bitrate int `json:"bitrate"`
VideoAuto bool `json:"video_auto"`
Video string `json:"video"`
Auto bool `json:"auto"`
}
/////////////////////////////

View File

@ -9,6 +9,7 @@ import (
var (
ErrWebRTCDataChannelNotFound = errors.New("webrtc data channel not found")
ErrWebRTCConnectionNotFound = errors.New("webrtc connection not found")
ErrWebRTCStreamNotFound = errors.New("webrtc stream not found")
)
type ICEServer struct {
@ -23,10 +24,10 @@ type WebRTCPeer interface {
SetRemoteDescription(webrtc.SessionDescription) error
SetCandidate(webrtc.ICECandidateInit) error
SetVideoBitrate(bitrate int) error
SetVideoID(videoID string) error
GetVideoID() string
SetVideo(StreamSelector) error
VideoID() (string, bool)
SetPaused(isPaused bool) error
Paused() bool
SetVideoAuto(auto bool)
VideoAuto() bool
@ -42,6 +43,6 @@ type WebRTCManager interface {
ICEServers() []ICEServer
CreatePeer(session Session, bitrate int, videoAuto bool) (*webrtc.SessionDescription, error)
CreatePeer(session Session) (*webrtc.SessionDescription, WebRTCPeer, error)
SetCursorPosition(x, y int)
}

153
pkg/utils/trenddetector.go Normal file
View File

@ -0,0 +1,153 @@
// From https://github.com/livekit/livekit/blob/master/pkg/sfu/streamallocator/trenddetector.go
package utils
import (
"fmt"
"time"
)
// ------------------------------------------------
type TrendDirection int
const (
TrendDirectionNeutral TrendDirection = iota
TrendDirectionUpward
TrendDirectionDownward
)
func (t TrendDirection) String() string {
switch t {
case TrendDirectionNeutral:
return "NEUTRAL"
case TrendDirectionUpward:
return "UPWARD"
case TrendDirectionDownward:
return "DOWNWARD"
default:
return fmt.Sprintf("%d", int(t))
}
}
// ------------------------------------------------
type TrendDetectorParams struct {
RequiredSamples int
DownwardTrendThreshold float64
CollapseValues bool
}
type TrendDetector struct {
params TrendDetectorParams
startTime time.Time
numSamples int
values []int64
lowestValue int64
highestValue int64
direction TrendDirection
}
func NewTrendDetector(params TrendDetectorParams) *TrendDetector {
return &TrendDetector{
params: params,
startTime: time.Now(),
direction: TrendDirectionNeutral,
}
}
func (t *TrendDetector) Seed(value int64) {
if len(t.values) != 0 {
return
}
t.values = append(t.values, value)
}
func (t *TrendDetector) AddValue(value int64) {
t.numSamples++
if t.lowestValue == 0 || value < t.lowestValue {
t.lowestValue = value
}
if value > t.highestValue {
t.highestValue = value
}
// ignore duplicate values
if t.params.CollapseValues && len(t.values) != 0 && t.values[len(t.values)-1] == value {
return
}
if len(t.values) == t.params.RequiredSamples {
t.values = t.values[1:]
}
t.values = append(t.values, value)
t.updateDirection()
}
func (t *TrendDetector) GetLowest() int64 {
return t.lowestValue
}
func (t *TrendDetector) GetHighest() int64 {
return t.highestValue
}
func (t *TrendDetector) GetValues() []int64 {
return t.values
}
func (t *TrendDetector) GetDirection() TrendDirection {
return t.direction
}
func (t *TrendDetector) ToString() string {
now := time.Now()
elapsed := now.Sub(t.startTime).Seconds()
str := fmt.Sprintf("t: %+v|%+v|%.2fs", t.startTime.Format(time.UnixDate), now.Format(time.UnixDate), elapsed)
str += fmt.Sprintf(", v: %d|%d|%d|%+v|%.2f", t.numSamples, t.lowestValue, t.highestValue, t.values, kendallsTau(t.values))
return str
}
func (t *TrendDetector) updateDirection() {
if len(t.values) < t.params.RequiredSamples {
t.direction = TrendDirectionNeutral
return
}
// using Kendall's Tau to find trend
kt := kendallsTau(t.values)
t.direction = TrendDirectionNeutral
switch {
case kt > 0:
t.direction = TrendDirectionUpward
case kt < t.params.DownwardTrendThreshold:
t.direction = TrendDirectionDownward
}
}
// ------------------------------------------------
func kendallsTau(values []int64) float64 {
concordantPairs := 0
discordantPairs := 0
for i := 0; i < len(values)-1; i++ {
for j := i + 1; j < len(values); j++ {
if values[i] < values[j] {
concordantPairs++
} else if values[i] > values[j] {
discordantPairs++
}
}
}
if (concordantPairs + discordantPairs) == 0 {
return 0.0
}
return (float64(concordantPairs) - float64(discordantPairs)) / (float64(concordantPairs) + float64(discordantPairs))
}