WebRTC congestion control (#26)

* Add congestion control

* Improve stream matching, add manual stream selection, add metrics

* Use a ticker for bitrate estimation and make bandwidth drops switch to lower streams more aggressively

* Add missing signal response, fix video auto bug

* Remove redundant mutex

* Add bitrate history queue

* Make the get-bitrate function support h264 and float64

---------

Co-authored-by: Aleksandar Sukovic <aleksandar.sukovic@gmail.com>
Author: Miroslav Šedivý
Date: 2023-02-06 19:45:51 +01:00
Committed by: GitHub
Parent: e80ae8019e
Commit: 2364facd60
15 changed files with 738 additions and 222 deletions


@@ -0,0 +1,145 @@
package buckets

import (
	"errors"
	"sort"

	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"

	"github.com/demodesk/neko/pkg/types"
	"github.com/demodesk/neko/pkg/types/codec"
)
type BucketsManagerCtx struct {
	logger    zerolog.Logger
	codec     codec.RTPCodec
	streams   map[string]types.StreamSinkManager
	streamIDs []string
}

func BucketsNew(codec codec.RTPCodec, streams map[string]types.StreamSinkManager, streamIDs []string) *BucketsManagerCtx {
	logger := log.With().
		Str("module", "capture").
		Str("submodule", "buckets").
		Logger()

	return &BucketsManagerCtx{
		logger:    logger,
		codec:     codec,
		streams:   streams,
		streamIDs: streamIDs,
	}
}
func (manager *BucketsManagerCtx) Shutdown() {
	manager.logger.Info().Msgf("shutdown")

	manager.DestroyAll()
}

func (manager *BucketsManagerCtx) DestroyAll() {
	for _, stream := range manager.streams {
		if stream.Started() {
			stream.DestroyPipeline()
		}
	}
}

func (manager *BucketsManagerCtx) RecreateAll() error {
	for _, stream := range manager.streams {
		if stream.Started() {
			err := stream.CreatePipeline()
			if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
				return err
			}
		}
	}

	return nil
}

func (manager *BucketsManagerCtx) IDs() []string {
	return manager.streamIDs
}

func (manager *BucketsManagerCtx) Codec() codec.RTPCodec {
	return manager.codec
}
func (manager *BucketsManagerCtx) SetReceiver(receiver types.Receiver) {
	// bitrate history is per receiver
	bitrateHistory := &queue{}

	receiver.OnBitrateChange(func(peerBitrate int) (bool, error) {
		// smooth the estimated bitrate only when automatic video quality
		// selection is enabled; manual selection uses the raw estimate
		bitrate := peerBitrate
		if receiver.VideoAuto() {
			bitrate = bitrateHistory.normaliseBitrate(bitrate)
		}

		stream := manager.findNearestStream(bitrate)
		streamID := stream.ID()

		// TODO: make this less noisy in logs
		manager.logger.Debug().
			Str("video_id", streamID).
			Int("len", bitrateHistory.len()).
			Int("peer_bitrate", peerBitrate).
			Int("bitrate", bitrate).
			Msg("change video bitrate")

		return receiver.SetStream(stream)
	})

	receiver.OnVideoChange(func(videoID string) (bool, error) {
		// manual stream selection: switch directly to the requested stream
		stream := manager.streams[videoID]

		manager.logger.Info().
			Str("video_id", videoID).
			Msg("video change")

		return receiver.SetStream(stream)
	})
}
func (manager *BucketsManagerCtx) findNearestStream(peerBitrate int) types.StreamSinkManager {
	type streamDiff struct {
		id          string
		bitrateDiff int
	}

	// a diff is peerBitrate - stream bitrate: non-negative means the stream
	// fits within the peer bitrate, negative means it exceeds it
	sortDiff := func(a, b int) bool {
		switch {
		case a < 0 && b < 0:
			// both streams exceed the peer bitrate: prefer the smaller overshoot
			return a > b
		case a >= 0:
			if b >= 0 {
				// both streams fit: prefer the one closest to the peer bitrate
				return a <= b
			}
			// a fits, b exceeds: a comes first
			return true
		}
		return false
	}

	var diffs []streamDiff
	for _, stream := range manager.streams {
		diffs = append(diffs, streamDiff{
			id:          stream.ID(),
			bitrateDiff: peerBitrate - stream.Bitrate(),
		})
	}

	sort.Slice(diffs, func(i, j int) bool {
		return sortDiff(diffs[i].bitrateDiff, diffs[j].bitrateDiff)
	})

	bestDiff := diffs[0]
	return manager.streams[bestDiff.id]
}
func (manager *BucketsManagerCtx) RemoveReceiver(receiver types.Receiver) error {
	receiver.OnBitrateChange(nil)
	receiver.OnVideoChange(nil)
	receiver.RemoveStream()
	return nil
}
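
A minimal wiring sketch of how this manager is meant to be used from the capture/WebRTC layer; the videoCodec, streams and receiver values below are hypothetical placeholders for objects provided elsewhere in neko, not part of this file:

	// hypothetical inputs: an RTP codec, one StreamSinkManager per quality,
	// and a connected types.Receiver from the WebRTC peer
	manager := buckets.BucketsNew(videoCodec, streams, []string{"hd", "sd"})

	// wires OnBitrateChange/OnVideoChange so the receiver is switched between
	// the configured stream sinks as its estimated bandwidth changes
	manager.SetReceiver(receiver)

	// on peer disconnect, detach the callbacks and release the stream
	_ = manager.RemoveReceiver(receiver)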


@@ -0,0 +1,83 @@
package buckets

import (
	"reflect"
	"testing"

	"github.com/demodesk/neko/pkg/types"
	"github.com/demodesk/neko/pkg/types/codec"
)
func TestBucketsManagerCtx_FindNearestStream(t *testing.T) {
	type fields struct {
		codec   codec.RTPCodec
		streams map[string]types.StreamSinkManager
	}
	type args struct {
		peerBitrate int
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		want   types.StreamSinkManager
	}{
		{
			name: "findNearestStream",
			fields: fields{
				streams: map[string]types.StreamSinkManager{
					"1": mockStreamSink{
						id:      "1",
						bitrate: 500,
					},
					"2": mockStreamSink{
						id:      "2",
						bitrate: 750,
					},
					"3": mockStreamSink{
						id:      "3",
						bitrate: 1000,
					},
					"4": mockStreamSink{
						id:      "4",
						bitrate: 1250,
					},
					"5": mockStreamSink{
						id:      "5",
						bitrate: 1700,
					},
				},
			},
			args: args{
				peerBitrate: 950,
			},
			// stream "2" (750) is the closest stream that does not exceed
			// the peer bitrate of 950
			want: mockStreamSink{
				id:      "2",
				bitrate: 750,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			m := BucketsNew(tt.fields.codec, tt.fields.streams, []string{})
			if got := m.findNearestStream(tt.args.peerBitrate); !reflect.DeepEqual(got, tt.want) {
				t.Errorf("findNearestStream() = %v, want %v", got, tt.want)
			}
		})
	}
}
type mockStreamSink struct {
	id      string
	bitrate int
	types.StreamSinkManager
}

func (m mockStreamSink) ID() string {
	return m.id
}

func (m mockStreamSink) Bitrate() int {
	return m.bitrate
}


@@ -0,0 +1,88 @@
package buckets

import (
	"math"
	"sync"
	"time"
)
type queue struct {
	sync.Mutex
	q []elem
}

type elem struct {
	created time.Time
	bitrate int
}
func (q *queue) push(v elem) {
	q.Lock()
	defer q.Unlock()

	// drop the oldest sample if it is more than 10 seconds old
	// (at most one sample is dropped per push)
	if len(q.q) > 0 && time.Since(q.q[0].created) > 10*time.Second {
		q.q = q.q[1:]
	}

	q.q = append(q.q, v)
}
func (q *queue) len() int {
	q.Lock()
	defer q.Unlock()

	return len(q.q)
}

func (q *queue) avg() int {
	q.Lock()
	defer q.Unlock()

	if len(q.q) == 0 {
		return 0
	}

	sum := 0
	for _, v := range q.q {
		sum += v.bitrate
	}

	return sum / len(q.q)
}
// avgLastN returns the average bitrate of the newest n samples; callers must
// ensure n does not exceed the queue length, otherwise the slice expression
// below panics
func (q *queue) avgLastN(n int) int {
	if n <= 0 {
		return q.avg()
	}

	q.Lock()
	defer q.Unlock()

	if len(q.q) == 0 {
		return 0
	}

	sum := 0
	for _, v := range q.q[len(q.q)-n:] {
		sum += v.bitrate
	}

	return sum / n
}
// normaliseBitrate records the current bitrate estimate and returns a smoothed
// value: the closer the estimate is to the historical average, the more samples
// are averaged, so small fluctuations are flattened while sharp drops shrink
// the averaging window and show up quickly
func (q *queue) normaliseBitrate(currentBitrate int) int {
	// average and length of the history before the new sample is added
	avgBitrate := float64(q.avg())
	histLen := float64(q.len())

	q.push(elem{
		bitrate: currentBitrate,
		created: time.Now(),
	})

	if avgBitrate == 0 || histLen == 0 || currentBitrate == 0 {
		return currentBitrate
	}

	// scale the averaging window by how far the estimate sits below the average
	lastN := int(math.Floor(float64(currentBitrate) / avgBitrate * histLen))
	if lastN > q.len() {
		lastN = q.len()
	}

	if lastN == 0 {
		return currentBitrate
	}

	return q.avgLastN(lastN)
}
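
As a worked illustration of the smoothing (numbers match the "big drop" case in the test file below): with a history of [900, 750, 780, 1100, 950, 700, 800, 900, 1000, 1100] the average is 898, so a new estimate of 350 gives lastN = floor(350/898 × 10) = 3; push drops the oldest sample and appends 350, and the average of the last three samples (1000, 1100, 350) is 816. A minimal sketch of that calculation, assuming it sits in a _test.go file of the buckets package that imports "testing":

	func TestNormaliseBitrate_workedExample(t *testing.T) {
		q := &queue{q: []elem{
			{bitrate: 900}, {bitrate: 750}, {bitrate: 780}, {bitrate: 1100}, {bitrate: 950},
			{bitrate: 700}, {bitrate: 800}, {bitrate: 900}, {bitrate: 1000}, {bitrate: 1100},
		}}
		// avg = 898, lastN = floor(350/898*10) = 3; the zero created time means
		// push drops the oldest sample and appends 350, so the smoothed value is
		// (1000+1100+350)/3 = 816
		if got := q.normaliseBitrate(350); got != 816 {
			t.Errorf("normaliseBitrate(350) = %d, want 816", got)
		}
	}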


@@ -0,0 +1,99 @@
package buckets

import "testing"
func TestQueue_normaliseBitrate(t *testing.T) {
	type fields struct {
		queue *queue
	}
	type args struct {
		currentBitrate int
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		want   []int
	}{
		{
			name: "normaliseBitrate: big drop",
			fields: fields{
				queue: &queue{
					q: []elem{
						{bitrate: 900},
						{bitrate: 750},
						{bitrate: 780},
						{bitrate: 1100},
						{bitrate: 950},
						{bitrate: 700},
						{bitrate: 800},
						{bitrate: 900},
						{bitrate: 1000},
						{bitrate: 1100},
						// avg = 898
					},
				},
			},
			args: args{
				currentBitrate: 350,
			},
			want: []int{816, 700, 537, 350, 350},
		}, {
			name: "normaliseBitrate: small drop",
			fields: fields{
				queue: &queue{
					q: []elem{
						{bitrate: 900},
						{bitrate: 750},
						{bitrate: 780},
						{bitrate: 1100},
						{bitrate: 950},
						{bitrate: 700},
						{bitrate: 800},
						{bitrate: 900},
						{bitrate: 1000},
						{bitrate: 1100},
						// avg = 898
					},
				},
			},
			args: args{
				currentBitrate: 700,
			},
			want: []int{878, 842, 825, 825, 812, 787, 750, 700},
		}, {
			name: "normaliseBitrate",
			fields: fields{
				queue: &queue{
					q: []elem{
						{bitrate: 900},
						{bitrate: 750},
						{bitrate: 780},
						{bitrate: 1100},
						{bitrate: 950},
						{bitrate: 700},
						{bitrate: 800},
						{bitrate: 900},
						{bitrate: 1000},
						{bitrate: 1100},
						// avg = 898
					},
				},
			},
			args: args{
				currentBitrate: 1350,
			},
			want: []int{943, 1003, 1060, 1085},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			m := tt.fields.queue
			for i := 0; i < len(tt.want); i++ {
				if got := m.normaliseBitrate(tt.args.currentBitrate); got != tt.want[i] {
					t.Errorf("normaliseBitrate() [%d] = %v, want %v", i, got, tt.want[i])
				}
			}
		})
	}
}
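
Read together, the three cases show the intended asymmetry: on the big drop to 350 the smoothed value collapses toward the new estimate within a handful of calls (816, 700, 537, 350), on the small drop to 700 it decays gradually (878 down to 700), and on the rise to 1350 the averaging window is clamped to the full history so the value climbs slowly (943 to 1085). This is what makes bandwidth drops switch to lower streams more aggressively than increases switch to higher ones.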