mirror of
https://github.com/m1k1o/neko.git
synced 2024-07-24 14:40:50 +12:00
WebRTC congestion control (#26)
* Add congestion control * Improve stream matching, add manual stream selection, add metrics * Use a ticker for bitrate estimation and make bandwidth drops switch to lower streams more aggressively * Missing signal response, fix video auto bug * Remove redundant mutex * Bitrate history queue * Get bitrate fn support h264 & float64 --------- Co-authored-by: Aleksandar Sukovic <aleksandar.sukovic@gmail.com>
This commit is contained in:
145
internal/capture/buckets/buckets.go
Normal file
145
internal/capture/buckets/buckets.go
Normal file
@ -0,0 +1,145 @@
|
||||
package buckets
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/codec"
|
||||
)
|
||||
|
||||
type BucketsManagerCtx struct {
|
||||
logger zerolog.Logger
|
||||
codec codec.RTPCodec
|
||||
streams map[string]types.StreamSinkManager
|
||||
streamIDs []string
|
||||
}
|
||||
|
||||
func BucketsNew(codec codec.RTPCodec, streams map[string]types.StreamSinkManager, streamIDs []string) *BucketsManagerCtx {
|
||||
logger := log.With().
|
||||
Str("module", "capture").
|
||||
Str("submodule", "buckets").
|
||||
Logger()
|
||||
|
||||
return &BucketsManagerCtx{
|
||||
logger: logger,
|
||||
codec: codec,
|
||||
streams: streams,
|
||||
streamIDs: streamIDs,
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *BucketsManagerCtx) Shutdown() {
|
||||
manager.logger.Info().Msgf("shutdown")
|
||||
|
||||
manager.DestroyAll()
|
||||
}
|
||||
|
||||
func (manager *BucketsManagerCtx) DestroyAll() {
|
||||
for _, stream := range manager.streams {
|
||||
if stream.Started() {
|
||||
stream.DestroyPipeline()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *BucketsManagerCtx) RecreateAll() error {
|
||||
for _, stream := range manager.streams {
|
||||
if stream.Started() {
|
||||
err := stream.CreatePipeline()
|
||||
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *BucketsManagerCtx) IDs() []string {
|
||||
return manager.streamIDs
|
||||
}
|
||||
|
||||
func (manager *BucketsManagerCtx) Codec() codec.RTPCodec {
|
||||
return manager.codec
|
||||
}
|
||||
|
||||
func (manager *BucketsManagerCtx) SetReceiver(receiver types.Receiver) {
|
||||
// bitrate history is per receiver
|
||||
bitrateHistory := &queue{}
|
||||
|
||||
receiver.OnBitrateChange(func(peerBitrate int) (bool, error) {
|
||||
bitrate := peerBitrate
|
||||
if receiver.VideoAuto() {
|
||||
bitrate = bitrateHistory.normaliseBitrate(bitrate)
|
||||
}
|
||||
|
||||
stream := manager.findNearestStream(bitrate)
|
||||
streamID := stream.ID()
|
||||
|
||||
// TODO: make this less noisy in logs
|
||||
manager.logger.Debug().
|
||||
Str("video_id", streamID).
|
||||
Int("len", bitrateHistory.len()).
|
||||
Int("peer_bitrate", peerBitrate).
|
||||
Int("bitrate", bitrate).
|
||||
Msg("change video bitrate")
|
||||
|
||||
return receiver.SetStream(stream)
|
||||
})
|
||||
|
||||
receiver.OnVideoChange(func(videoID string) (bool, error) {
|
||||
stream := manager.streams[videoID]
|
||||
manager.logger.Info().
|
||||
Str("video_id", videoID).
|
||||
Msg("video change")
|
||||
|
||||
return receiver.SetStream(stream)
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *BucketsManagerCtx) findNearestStream(peerBitrate int) types.StreamSinkManager {
|
||||
type streamDiff struct {
|
||||
id string
|
||||
bitrateDiff int
|
||||
}
|
||||
|
||||
sortDiff := func(a, b int) bool {
|
||||
switch {
|
||||
case a < 0 && b < 0:
|
||||
return a > b
|
||||
case a >= 0:
|
||||
if b >= 0 {
|
||||
return a <= b
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var diffs []streamDiff
|
||||
|
||||
for _, stream := range manager.streams {
|
||||
diffs = append(diffs, streamDiff{
|
||||
id: stream.ID(),
|
||||
bitrateDiff: peerBitrate - stream.Bitrate(),
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(diffs, func(i, j int) bool {
|
||||
return sortDiff(diffs[i].bitrateDiff, diffs[j].bitrateDiff)
|
||||
})
|
||||
|
||||
bestDiff := diffs[0]
|
||||
|
||||
return manager.streams[bestDiff.id]
|
||||
}
|
||||
|
||||
func (manager *BucketsManagerCtx) RemoveReceiver(receiver types.Receiver) error {
|
||||
receiver.OnBitrateChange(nil)
|
||||
receiver.OnVideoChange(nil)
|
||||
receiver.RemoveStream()
|
||||
return nil
|
||||
}
|
83
internal/capture/buckets/buckets_test.go
Normal file
83
internal/capture/buckets/buckets_test.go
Normal file
@ -0,0 +1,83 @@
|
||||
package buckets
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/codec"
|
||||
)
|
||||
|
||||
func TestBucketsManagerCtx_FindNearestStream(t *testing.T) {
|
||||
type fields struct {
|
||||
codec codec.RTPCodec
|
||||
streams map[string]types.StreamSinkManager
|
||||
}
|
||||
type args struct {
|
||||
peerBitrate int
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want types.StreamSinkManager
|
||||
}{
|
||||
{
|
||||
name: "findNearestStream",
|
||||
fields: fields{
|
||||
streams: map[string]types.StreamSinkManager{
|
||||
"1": mockStreamSink{
|
||||
id: "1",
|
||||
bitrate: 500,
|
||||
},
|
||||
"2": mockStreamSink{
|
||||
id: "2",
|
||||
bitrate: 750,
|
||||
},
|
||||
"3": mockStreamSink{
|
||||
id: "3",
|
||||
bitrate: 1000,
|
||||
},
|
||||
"4": mockStreamSink{
|
||||
id: "4",
|
||||
bitrate: 1250,
|
||||
},
|
||||
"5": mockStreamSink{
|
||||
id: "5",
|
||||
bitrate: 1700,
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
peerBitrate: 950,
|
||||
},
|
||||
want: mockStreamSink{
|
||||
id: "2",
|
||||
bitrate: 750,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := BucketsNew(tt.fields.codec, tt.fields.streams, []string{})
|
||||
|
||||
if got := m.findNearestStream(tt.args.peerBitrate); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("findNearestStream() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockStreamSink struct {
|
||||
id string
|
||||
bitrate int
|
||||
types.StreamSinkManager
|
||||
}
|
||||
|
||||
func (m mockStreamSink) ID() string {
|
||||
return m.id
|
||||
}
|
||||
|
||||
func (m mockStreamSink) Bitrate() int {
|
||||
return m.bitrate
|
||||
}
|
88
internal/capture/buckets/queue.go
Normal file
88
internal/capture/buckets/queue.go
Normal file
@ -0,0 +1,88 @@
|
||||
package buckets
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type queue struct {
|
||||
sync.Mutex
|
||||
q []elem
|
||||
}
|
||||
|
||||
type elem struct {
|
||||
created time.Time
|
||||
bitrate int
|
||||
}
|
||||
|
||||
func (q *queue) push(v elem) {
|
||||
q.Lock()
|
||||
defer q.Unlock()
|
||||
|
||||
// if the first element is older than 10 seconds, remove it
|
||||
if len(q.q) > 0 && time.Since(q.q[0].created) > 10*time.Second {
|
||||
q.q = q.q[1:]
|
||||
}
|
||||
q.q = append(q.q, v)
|
||||
}
|
||||
|
||||
func (q *queue) len() int {
|
||||
q.Lock()
|
||||
defer q.Unlock()
|
||||
return len(q.q)
|
||||
}
|
||||
|
||||
func (q *queue) avg() int {
|
||||
q.Lock()
|
||||
defer q.Unlock()
|
||||
if len(q.q) == 0 {
|
||||
return 0
|
||||
}
|
||||
sum := 0
|
||||
for _, v := range q.q {
|
||||
sum += v.bitrate
|
||||
}
|
||||
return sum / len(q.q)
|
||||
}
|
||||
|
||||
func (q *queue) avgLastN(n int) int {
|
||||
if n <= 0 {
|
||||
return q.avg()
|
||||
}
|
||||
q.Lock()
|
||||
defer q.Unlock()
|
||||
if len(q.q) == 0 {
|
||||
return 0
|
||||
}
|
||||
sum := 0
|
||||
for _, v := range q.q[len(q.q)-n:] {
|
||||
sum += v.bitrate
|
||||
}
|
||||
return sum / n
|
||||
}
|
||||
|
||||
func (q *queue) normaliseBitrate(currentBitrate int) int {
|
||||
avgBitrate := float64(q.avg())
|
||||
histLen := float64(q.len())
|
||||
|
||||
q.push(elem{
|
||||
bitrate: currentBitrate,
|
||||
created: time.Now(),
|
||||
})
|
||||
|
||||
if avgBitrate == 0 || histLen == 0 || currentBitrate == 0 {
|
||||
return currentBitrate
|
||||
}
|
||||
|
||||
lastN := int(math.Floor(float64(currentBitrate) / avgBitrate * histLen))
|
||||
if lastN > q.len() {
|
||||
lastN = q.len()
|
||||
}
|
||||
|
||||
if lastN == 0 {
|
||||
return currentBitrate
|
||||
}
|
||||
|
||||
return q.avgLastN(lastN)
|
||||
}
|
99
internal/capture/buckets/queue_test.go
Normal file
99
internal/capture/buckets/queue_test.go
Normal file
@ -0,0 +1,99 @@
|
||||
package buckets
|
||||
|
||||
import "testing"
|
||||
|
||||
func Queue_normaliseBitrate(t *testing.T) {
|
||||
type fields struct {
|
||||
queue *queue
|
||||
}
|
||||
type args struct {
|
||||
currentBitrate int
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want []int
|
||||
}{
|
||||
{
|
||||
name: "normaliseBitrate: big drop",
|
||||
fields: fields{
|
||||
queue: &queue{
|
||||
q: []elem{
|
||||
{bitrate: 900},
|
||||
{bitrate: 750},
|
||||
{bitrate: 780},
|
||||
{bitrate: 1100},
|
||||
{bitrate: 950},
|
||||
{bitrate: 700},
|
||||
{bitrate: 800},
|
||||
{bitrate: 900},
|
||||
{bitrate: 1000},
|
||||
{bitrate: 1100},
|
||||
// avg = 898
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
currentBitrate: 350,
|
||||
},
|
||||
want: []int{816, 700, 537, 350, 350},
|
||||
}, {
|
||||
name: "normaliseBitrate: small drop",
|
||||
fields: fields{
|
||||
queue: &queue{
|
||||
q: []elem{
|
||||
{bitrate: 900},
|
||||
{bitrate: 750},
|
||||
{bitrate: 780},
|
||||
{bitrate: 1100},
|
||||
{bitrate: 950},
|
||||
{bitrate: 700},
|
||||
{bitrate: 800},
|
||||
{bitrate: 900},
|
||||
{bitrate: 1000},
|
||||
{bitrate: 1100},
|
||||
// avg = 898
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
currentBitrate: 700,
|
||||
},
|
||||
want: []int{878, 842, 825, 825, 812, 787, 750, 700},
|
||||
}, {
|
||||
name: "normaliseBitrate",
|
||||
fields: fields{
|
||||
queue: &queue{
|
||||
q: []elem{
|
||||
{bitrate: 900},
|
||||
{bitrate: 750},
|
||||
{bitrate: 780},
|
||||
{bitrate: 1100},
|
||||
{bitrate: 950},
|
||||
{bitrate: 700},
|
||||
{bitrate: 800},
|
||||
{bitrate: 900},
|
||||
{bitrate: 1000},
|
||||
{bitrate: 1100},
|
||||
// avg = 898
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
currentBitrate: 1350,
|
||||
},
|
||||
want: []int{943, 1003, 1060, 1085},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := tt.fields.queue
|
||||
for i := 0; i < len(tt.want); i++ {
|
||||
if got := m.normaliseBitrate(tt.args.currentBitrate); got != tt.want[i] {
|
||||
t.Errorf("normaliseBitrate() [%d] = %v, want %v", i, got, tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user