mirror of
https://github.com/m1k1o/neko.git
synced 2024-07-24 14:40:50 +12:00
Bandwidth estimator refactor (#46)
* rewrite to use stream selector. * WIP. * add nacks to metrics. * add estimate trend. * estimator based on trend detector. * add estimator unstable duration. * add estimator debug. * add stalled duration. * estimator move values to config. * change default estimator values. * minor style changes. * fix websocket video messages. * replace video track with ivdeo id.
This commit is contained in:
parent
8660c1a256
commit
3e8d686c0f
@ -1,145 +0,0 @@
|
||||
package buckets
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/codec"
|
||||
)
|
||||
|
||||
type BucketsManagerCtx struct {
|
||||
logger zerolog.Logger
|
||||
codec codec.RTPCodec
|
||||
streams map[string]types.StreamSinkManager
|
||||
streamIDs []string
|
||||
}
|
||||
|
||||
func BucketsNew(codec codec.RTPCodec, streams map[string]types.StreamSinkManager, streamIDs []string) *BucketsManagerCtx {
|
||||
logger := log.With().
|
||||
Str("module", "capture").
|
||||
Str("submodule", "buckets").
|
||||
Logger()
|
||||
|
||||
return &BucketsManagerCtx{
|
||||
logger: logger,
|
||||
codec: codec,
|
||||
streams: streams,
|
||||
streamIDs: streamIDs,
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *BucketsManagerCtx) Shutdown() {
|
||||
manager.logger.Info().Msgf("shutdown")
|
||||
|
||||
manager.DestroyAll()
|
||||
}
|
||||
|
||||
func (manager *BucketsManagerCtx) DestroyAll() {
|
||||
for _, stream := range manager.streams {
|
||||
if stream.Started() {
|
||||
stream.DestroyPipeline()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *BucketsManagerCtx) RecreateAll() error {
|
||||
for _, stream := range manager.streams {
|
||||
if stream.Started() {
|
||||
err := stream.CreatePipeline()
|
||||
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *BucketsManagerCtx) IDs() []string {
|
||||
return manager.streamIDs
|
||||
}
|
||||
|
||||
func (manager *BucketsManagerCtx) Codec() codec.RTPCodec {
|
||||
return manager.codec
|
||||
}
|
||||
|
||||
func (manager *BucketsManagerCtx) SetReceiver(receiver types.Receiver) {
|
||||
// bitrate history is per receiver
|
||||
bitrateHistory := &queue{}
|
||||
|
||||
receiver.OnBitrateChange(func(peerBitrate int) (bool, error) {
|
||||
bitrate := peerBitrate
|
||||
if receiver.VideoAuto() {
|
||||
bitrate = bitrateHistory.normaliseBitrate(bitrate)
|
||||
}
|
||||
|
||||
stream := manager.findNearestStream(bitrate)
|
||||
streamID := stream.ID()
|
||||
|
||||
// TODO: make this less noisy in logs
|
||||
manager.logger.Debug().
|
||||
Str("video_id", streamID).
|
||||
Int("len", bitrateHistory.len()).
|
||||
Int("peer_bitrate", peerBitrate).
|
||||
Int("bitrate", bitrate).
|
||||
Msg("change video bitrate")
|
||||
|
||||
return receiver.SetStream(stream)
|
||||
})
|
||||
|
||||
receiver.OnVideoChange(func(videoID string) (bool, error) {
|
||||
stream := manager.streams[videoID]
|
||||
manager.logger.Info().
|
||||
Str("video_id", videoID).
|
||||
Msg("video change")
|
||||
|
||||
return receiver.SetStream(stream)
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *BucketsManagerCtx) findNearestStream(peerBitrate int) types.StreamSinkManager {
|
||||
type streamDiff struct {
|
||||
id string
|
||||
bitrateDiff int
|
||||
}
|
||||
|
||||
sortDiff := func(a, b int) bool {
|
||||
switch {
|
||||
case a < 0 && b < 0:
|
||||
return a > b
|
||||
case a >= 0:
|
||||
if b >= 0 {
|
||||
return a <= b
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var diffs []streamDiff
|
||||
|
||||
for _, stream := range manager.streams {
|
||||
diffs = append(diffs, streamDiff{
|
||||
id: stream.ID(),
|
||||
bitrateDiff: peerBitrate - stream.Bitrate(),
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(diffs, func(i, j int) bool {
|
||||
return sortDiff(diffs[i].bitrateDiff, diffs[j].bitrateDiff)
|
||||
})
|
||||
|
||||
bestDiff := diffs[0]
|
||||
|
||||
return manager.streams[bestDiff.id]
|
||||
}
|
||||
|
||||
func (manager *BucketsManagerCtx) RemoveReceiver(receiver types.Receiver) error {
|
||||
receiver.OnBitrateChange(nil)
|
||||
receiver.OnVideoChange(nil)
|
||||
receiver.RemoveStream()
|
||||
return nil
|
||||
}
|
@ -1,83 +0,0 @@
|
||||
package buckets
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/codec"
|
||||
)
|
||||
|
||||
func TestBucketsManagerCtx_FindNearestStream(t *testing.T) {
|
||||
type fields struct {
|
||||
codec codec.RTPCodec
|
||||
streams map[string]types.StreamSinkManager
|
||||
}
|
||||
type args struct {
|
||||
peerBitrate int
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want types.StreamSinkManager
|
||||
}{
|
||||
{
|
||||
name: "findNearestStream",
|
||||
fields: fields{
|
||||
streams: map[string]types.StreamSinkManager{
|
||||
"1": mockStreamSink{
|
||||
id: "1",
|
||||
bitrate: 500,
|
||||
},
|
||||
"2": mockStreamSink{
|
||||
id: "2",
|
||||
bitrate: 750,
|
||||
},
|
||||
"3": mockStreamSink{
|
||||
id: "3",
|
||||
bitrate: 1000,
|
||||
},
|
||||
"4": mockStreamSink{
|
||||
id: "4",
|
||||
bitrate: 1250,
|
||||
},
|
||||
"5": mockStreamSink{
|
||||
id: "5",
|
||||
bitrate: 1700,
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
peerBitrate: 950,
|
||||
},
|
||||
want: mockStreamSink{
|
||||
id: "2",
|
||||
bitrate: 750,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := BucketsNew(tt.fields.codec, tt.fields.streams, []string{})
|
||||
|
||||
if got := m.findNearestStream(tt.args.peerBitrate); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("findNearestStream() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockStreamSink struct {
|
||||
id string
|
||||
bitrate int
|
||||
types.StreamSinkManager
|
||||
}
|
||||
|
||||
func (m mockStreamSink) ID() string {
|
||||
return m.id
|
||||
}
|
||||
|
||||
func (m mockStreamSink) Bitrate() int {
|
||||
return m.bitrate
|
||||
}
|
@ -1,88 +0,0 @@
|
||||
package buckets
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type queue struct {
|
||||
sync.Mutex
|
||||
q []elem
|
||||
}
|
||||
|
||||
type elem struct {
|
||||
created time.Time
|
||||
bitrate int
|
||||
}
|
||||
|
||||
func (q *queue) push(v elem) {
|
||||
q.Lock()
|
||||
defer q.Unlock()
|
||||
|
||||
// if the first element is older than 10 seconds, remove it
|
||||
if len(q.q) > 0 && time.Since(q.q[0].created) > 10*time.Second {
|
||||
q.q = q.q[1:]
|
||||
}
|
||||
q.q = append(q.q, v)
|
||||
}
|
||||
|
||||
func (q *queue) len() int {
|
||||
q.Lock()
|
||||
defer q.Unlock()
|
||||
return len(q.q)
|
||||
}
|
||||
|
||||
func (q *queue) avg() int {
|
||||
q.Lock()
|
||||
defer q.Unlock()
|
||||
if len(q.q) == 0 {
|
||||
return 0
|
||||
}
|
||||
sum := 0
|
||||
for _, v := range q.q {
|
||||
sum += v.bitrate
|
||||
}
|
||||
return sum / len(q.q)
|
||||
}
|
||||
|
||||
func (q *queue) avgLastN(n int) int {
|
||||
if n <= 0 {
|
||||
return q.avg()
|
||||
}
|
||||
q.Lock()
|
||||
defer q.Unlock()
|
||||
if len(q.q) == 0 {
|
||||
return 0
|
||||
}
|
||||
sum := 0
|
||||
for _, v := range q.q[len(q.q)-n:] {
|
||||
sum += v.bitrate
|
||||
}
|
||||
return sum / n
|
||||
}
|
||||
|
||||
func (q *queue) normaliseBitrate(currentBitrate int) int {
|
||||
avgBitrate := float64(q.avg())
|
||||
histLen := float64(q.len())
|
||||
|
||||
q.push(elem{
|
||||
bitrate: currentBitrate,
|
||||
created: time.Now(),
|
||||
})
|
||||
|
||||
if avgBitrate == 0 || histLen == 0 || currentBitrate == 0 {
|
||||
return currentBitrate
|
||||
}
|
||||
|
||||
lastN := int(math.Floor(float64(currentBitrate) / avgBitrate * histLen))
|
||||
if lastN > q.len() {
|
||||
lastN = q.len()
|
||||
}
|
||||
|
||||
if lastN == 0 {
|
||||
return currentBitrate
|
||||
}
|
||||
|
||||
return q.avgLastN(lastN)
|
||||
}
|
@ -1,99 +0,0 @@
|
||||
package buckets
|
||||
|
||||
import "testing"
|
||||
|
||||
func Queue_normaliseBitrate(t *testing.T) {
|
||||
type fields struct {
|
||||
queue *queue
|
||||
}
|
||||
type args struct {
|
||||
currentBitrate int
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want []int
|
||||
}{
|
||||
{
|
||||
name: "normaliseBitrate: big drop",
|
||||
fields: fields{
|
||||
queue: &queue{
|
||||
q: []elem{
|
||||
{bitrate: 900},
|
||||
{bitrate: 750},
|
||||
{bitrate: 780},
|
||||
{bitrate: 1100},
|
||||
{bitrate: 950},
|
||||
{bitrate: 700},
|
||||
{bitrate: 800},
|
||||
{bitrate: 900},
|
||||
{bitrate: 1000},
|
||||
{bitrate: 1100},
|
||||
// avg = 898
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
currentBitrate: 350,
|
||||
},
|
||||
want: []int{816, 700, 537, 350, 350},
|
||||
}, {
|
||||
name: "normaliseBitrate: small drop",
|
||||
fields: fields{
|
||||
queue: &queue{
|
||||
q: []elem{
|
||||
{bitrate: 900},
|
||||
{bitrate: 750},
|
||||
{bitrate: 780},
|
||||
{bitrate: 1100},
|
||||
{bitrate: 950},
|
||||
{bitrate: 700},
|
||||
{bitrate: 800},
|
||||
{bitrate: 900},
|
||||
{bitrate: 1000},
|
||||
{bitrate: 1100},
|
||||
// avg = 898
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
currentBitrate: 700,
|
||||
},
|
||||
want: []int{878, 842, 825, 825, 812, 787, 750, 700},
|
||||
}, {
|
||||
name: "normaliseBitrate",
|
||||
fields: fields{
|
||||
queue: &queue{
|
||||
q: []elem{
|
||||
{bitrate: 900},
|
||||
{bitrate: 750},
|
||||
{bitrate: 780},
|
||||
{bitrate: 1100},
|
||||
{bitrate: 950},
|
||||
{bitrate: 700},
|
||||
{bitrate: 800},
|
||||
{bitrate: 900},
|
||||
{bitrate: 1000},
|
||||
{bitrate: 1100},
|
||||
// avg = 898
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
currentBitrate: 1350,
|
||||
},
|
||||
want: []int{943, 1003, 1060, 1085},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := tt.fields.queue
|
||||
for i := 0; i < len(tt.want); i++ {
|
||||
if got := m.normaliseBitrate(tt.args.currentBitrate); got != tt.want[i] {
|
||||
t.Errorf("normaliseBitrate() [%d] = %v, want %v", i, got, tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -8,7 +8,6 @@ import (
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/internal/capture/buckets"
|
||||
"github.com/demodesk/neko/internal/config"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/codec"
|
||||
@ -23,7 +22,7 @@ type CaptureManagerCtx struct {
|
||||
broadcast *BroacastManagerCtx
|
||||
screencast *ScreencastManagerCtx
|
||||
audio *StreamSinkManagerCtx
|
||||
video types.BucketsManager
|
||||
video *StreamSelectorManagerCtx
|
||||
|
||||
// sources
|
||||
webcam *StreamSrcManagerCtx
|
||||
@ -68,13 +67,8 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt
|
||||
Str("pipeline", pipeline).
|
||||
Msg("syntax check for video stream pipeline passed")
|
||||
|
||||
getVideoBitrate := pipelineConf.GetBitrateFn(desktop.GetScreenSize)
|
||||
if _, err = getVideoBitrate(); err != nil {
|
||||
logger.Panic().Err(err).Msg("unable to get video bitrate")
|
||||
}
|
||||
|
||||
// append to videos
|
||||
videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id, getVideoBitrate)
|
||||
videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id)
|
||||
}
|
||||
|
||||
return &CaptureManagerCtx{
|
||||
@ -140,8 +134,8 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt
|
||||
"! %s "+
|
||||
"! appsink name=appsink", config.AudioDevice, config.AudioCodec.Pipeline,
|
||||
), nil
|
||||
}, "audio", nil),
|
||||
video: buckets.BucketsNew(config.VideoCodec, videos, config.VideoIDs),
|
||||
}, "audio"),
|
||||
video: streamSelectorNew(config.VideoCodec, videos, config.VideoIDs),
|
||||
|
||||
// sources
|
||||
webcam: streamSrcNew(config.WebcamEnabled, map[string]string{
|
||||
@ -202,7 +196,7 @@ func (manager *CaptureManagerCtx) Start() {
|
||||
}
|
||||
|
||||
manager.desktop.OnBeforeScreenSizeChange(func() {
|
||||
manager.video.DestroyAll()
|
||||
manager.video.destroyPipelines()
|
||||
|
||||
if manager.broadcast.Started() {
|
||||
manager.broadcast.destroyPipeline()
|
||||
@ -214,7 +208,7 @@ func (manager *CaptureManagerCtx) Start() {
|
||||
})
|
||||
|
||||
manager.desktop.OnAfterScreenSizeChange(func() {
|
||||
err := manager.video.RecreateAll()
|
||||
err := manager.video.recreatePipelines()
|
||||
if err != nil {
|
||||
manager.logger.Panic().Err(err).Msg("unable to recreate video pipelines")
|
||||
}
|
||||
@ -242,7 +236,7 @@ func (manager *CaptureManagerCtx) Shutdown() error {
|
||||
manager.screencast.shutdown()
|
||||
|
||||
manager.audio.shutdown()
|
||||
manager.video.Shutdown()
|
||||
manager.video.shutdown()
|
||||
|
||||
manager.webcam.shutdown()
|
||||
manager.microphone.shutdown()
|
||||
@ -250,15 +244,6 @@ func (manager *CaptureManagerCtx) Shutdown() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *CaptureManagerCtx) GetBitrateFromVideoID(videoID string) (int, error) {
|
||||
cfg, ok := manager.config.VideoPipelines[videoID]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("video config not found for %s", videoID)
|
||||
}
|
||||
|
||||
return cfg.GetBitrateFn(manager.desktop.GetScreenSize)()
|
||||
}
|
||||
|
||||
func (manager *CaptureManagerCtx) Broadcast() types.BroadcastManager {
|
||||
return manager.broadcast
|
||||
}
|
||||
@ -271,7 +256,7 @@ func (manager *CaptureManagerCtx) Audio() types.StreamSinkManager {
|
||||
return manager.audio
|
||||
}
|
||||
|
||||
func (manager *CaptureManagerCtx) Video() types.BucketsManager {
|
||||
func (manager *CaptureManagerCtx) Video() types.StreamSelectorManager {
|
||||
return manager.video
|
||||
}
|
||||
|
||||
|
206
internal/capture/streamselector.go
Normal file
206
internal/capture/streamselector.go
Normal file
@ -0,0 +1,206 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/codec"
|
||||
)
|
||||
|
||||
type StreamSelectorManagerCtx struct {
|
||||
logger zerolog.Logger
|
||||
codec codec.RTPCodec
|
||||
streams map[string]types.StreamSinkManager
|
||||
streamIDs []string
|
||||
}
|
||||
|
||||
func streamSelectorNew(codec codec.RTPCodec, streams map[string]types.StreamSinkManager, streamIDs []string) *StreamSelectorManagerCtx {
|
||||
logger := log.With().
|
||||
Str("module", "capture").
|
||||
Str("submodule", "stream-selector").
|
||||
Logger()
|
||||
|
||||
return &StreamSelectorManagerCtx{
|
||||
logger: logger,
|
||||
codec: codec,
|
||||
streams: streams,
|
||||
streamIDs: streamIDs,
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *StreamSelectorManagerCtx) shutdown() {
|
||||
manager.logger.Info().Msgf("shutdown")
|
||||
|
||||
manager.destroyPipelines()
|
||||
}
|
||||
|
||||
func (manager *StreamSelectorManagerCtx) destroyPipelines() {
|
||||
for _, stream := range manager.streams {
|
||||
if stream.Started() {
|
||||
stream.DestroyPipeline()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *StreamSelectorManagerCtx) recreatePipelines() error {
|
||||
for _, stream := range manager.streams {
|
||||
if stream.Started() {
|
||||
err := stream.CreatePipeline()
|
||||
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *StreamSelectorManagerCtx) IDs() []string {
|
||||
return manager.streamIDs
|
||||
}
|
||||
|
||||
func (manager *StreamSelectorManagerCtx) Codec() codec.RTPCodec {
|
||||
return manager.codec
|
||||
}
|
||||
|
||||
func (manager *StreamSelectorManagerCtx) GetStream(selector types.StreamSelector) (types.StreamSinkManager, bool) {
|
||||
// select stream by ID
|
||||
if selector.ID != "" {
|
||||
// select lower stream
|
||||
if selector.Type == types.StreamSelectorTypeLower {
|
||||
var lastStream types.StreamSinkManager
|
||||
for i := len(manager.streamIDs) - 1; i >= 0; i-- {
|
||||
streamID := manager.streamIDs[i]
|
||||
if streamID == selector.ID {
|
||||
return lastStream, lastStream != nil
|
||||
}
|
||||
stream, ok := manager.streams[streamID]
|
||||
if ok {
|
||||
lastStream = stream
|
||||
}
|
||||
}
|
||||
// we couldn't find a lower stream
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// select higher stream
|
||||
if selector.Type == types.StreamSelectorTypeHigher {
|
||||
var lastStream types.StreamSinkManager
|
||||
for _, streamID := range manager.streamIDs {
|
||||
if streamID == selector.ID {
|
||||
return lastStream, lastStream != nil
|
||||
}
|
||||
stream, ok := manager.streams[streamID]
|
||||
if ok {
|
||||
lastStream = stream
|
||||
}
|
||||
}
|
||||
// we couldn't find a higher stream
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// select exact stream
|
||||
stream, ok := manager.streams[selector.ID]
|
||||
return stream, ok
|
||||
}
|
||||
|
||||
// select stream by bitrate
|
||||
if selector.Bitrate != 0 {
|
||||
// select stream by nearest bitrate
|
||||
if selector.Type == types.StreamSelectorTypeNearest {
|
||||
return manager.nearestBitrate(selector.Bitrate), true
|
||||
}
|
||||
|
||||
// select lower stream
|
||||
if selector.Type == types.StreamSelectorTypeLower {
|
||||
// start from the highest stream, and go down, until we find a lower stream
|
||||
for i := len(manager.streamIDs) - 1; i >= 0; i-- {
|
||||
streamID := manager.streamIDs[i]
|
||||
stream := manager.streams[streamID]
|
||||
// if stream should be considered in calculation
|
||||
considered := stream.Bitrate() != 0 && stream.Started()
|
||||
if considered && stream.Bitrate() < selector.Bitrate {
|
||||
return stream, true
|
||||
}
|
||||
}
|
||||
// we couldn't find a lower stream
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// select higher stream
|
||||
if selector.Type == types.StreamSelectorTypeHigher {
|
||||
// start from the lowest stream, and go up, until we find a higher stream
|
||||
for _, streamID := range manager.streamIDs {
|
||||
stream := manager.streams[streamID]
|
||||
// if stream should be considered in calculation
|
||||
considered := stream.Bitrate() != 0 && stream.Started()
|
||||
if considered && stream.Bitrate() > selector.Bitrate {
|
||||
return stream, true
|
||||
}
|
||||
}
|
||||
// we couldn't find a higher stream
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// select stream by exact bitrate
|
||||
for _, stream := range manager.streams {
|
||||
if stream.Bitrate() == selector.Bitrate {
|
||||
return stream, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// we couldn't find a stream
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// TODO: This is a very naive implementation, we should use a binary search instead.
|
||||
func (manager *StreamSelectorManagerCtx) nearestBitrate(bitrate uint64) types.StreamSinkManager {
|
||||
type streamDiff struct {
|
||||
id string
|
||||
bitrateDiff int
|
||||
}
|
||||
|
||||
sortDiff := func(a, b int) bool {
|
||||
switch {
|
||||
case a < 0 && b < 0:
|
||||
return a > b
|
||||
case a >= 0:
|
||||
if b >= 0 {
|
||||
return a <= b
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var diffs []streamDiff
|
||||
|
||||
for _, stream := range manager.streams {
|
||||
// if stream should be considered in calculation
|
||||
considered := stream.Bitrate() != 0 && stream.Started()
|
||||
if !considered {
|
||||
continue
|
||||
}
|
||||
diffs = append(diffs, streamDiff{
|
||||
id: stream.ID(),
|
||||
bitrateDiff: int(bitrate) - int(stream.Bitrate()),
|
||||
})
|
||||
}
|
||||
|
||||
// no streams available
|
||||
if len(diffs) == 0 {
|
||||
// return first (lowest) stream
|
||||
return manager.streams[manager.streamIDs[0]]
|
||||
}
|
||||
|
||||
sort.Slice(diffs, func(i, j int) bool {
|
||||
return sortDiff(diffs[i].bitrateDiff, diffs[j].bitrateDiff)
|
||||
})
|
||||
|
||||
bestDiff := diffs[0]
|
||||
return manager.streams[bestDiff.id]
|
||||
}
|
@ -22,8 +22,9 @@ var moveSinkListenerMu = sync.Mutex{}
|
||||
|
||||
type StreamSinkManagerCtx struct {
|
||||
id string
|
||||
getBitrate func() (int, error)
|
||||
waitForKf bool // wait for a keyframe before sending samples
|
||||
|
||||
// wait for a keyframe before sending samples
|
||||
waitForKf bool
|
||||
|
||||
bitrate uint64 // atomic
|
||||
brBuckets map[int]float64
|
||||
@ -48,7 +49,7 @@ type StreamSinkManagerCtx struct {
|
||||
pipelinesActive prometheus.Gauge
|
||||
}
|
||||
|
||||
func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id string, getBitrate func() (int, error)) *StreamSinkManagerCtx {
|
||||
func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), id string) *StreamSinkManagerCtx {
|
||||
logger := log.With().
|
||||
Str("module", "capture").
|
||||
Str("submodule", "stream-sink").
|
||||
@ -56,14 +57,15 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin
|
||||
|
||||
manager := &StreamSinkManagerCtx{
|
||||
id: id,
|
||||
getBitrate: getBitrate,
|
||||
// only wait for keyframes if the codec is video
|
||||
waitForKf: c.IsVideo(),
|
||||
|
||||
// only wait for keyframes if the codec is video
|
||||
waitForKf: codec.IsVideo(),
|
||||
|
||||
bitrate: 0,
|
||||
brBuckets: map[int]float64{},
|
||||
|
||||
logger: logger,
|
||||
codec: c,
|
||||
codec: codec,
|
||||
pipelineFn: pipelineFn,
|
||||
|
||||
listeners: map[uintptr]types.SampleListener{},
|
||||
@ -77,8 +79,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin
|
||||
Help: "Current number of listeners for a pipeline.",
|
||||
ConstLabels: map[string]string{
|
||||
"video_id": id,
|
||||
"codec_name": c.Name,
|
||||
"codec_type": c.Type.String(),
|
||||
"codec_name": codec.Name,
|
||||
"codec_type": codec.Type.String(),
|
||||
},
|
||||
}),
|
||||
totalBytes: promauto.NewCounter(prometheus.CounterOpts{
|
||||
@ -88,8 +90,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin
|
||||
Help: "Total number of bytes created by the pipeline.",
|
||||
ConstLabels: map[string]string{
|
||||
"video_id": id,
|
||||
"codec_name": c.Name,
|
||||
"codec_type": c.Type.String(),
|
||||
"codec_name": codec.Name,
|
||||
"codec_type": codec.Type.String(),
|
||||
},
|
||||
}),
|
||||
pipelinesCounter: promauto.NewCounter(prometheus.CounterOpts{
|
||||
@ -100,8 +102,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin
|
||||
ConstLabels: map[string]string{
|
||||
"submodule": "streamsink",
|
||||
"video_id": id,
|
||||
"codec_name": c.Name,
|
||||
"codec_type": c.Type.String(),
|
||||
"codec_name": codec.Name,
|
||||
"codec_type": codec.Type.String(),
|
||||
},
|
||||
}),
|
||||
pipelinesActive: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
@ -112,8 +114,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin
|
||||
ConstLabels: map[string]string{
|
||||
"submodule": "streamsink",
|
||||
"video_id": id,
|
||||
"codec_name": c.Name,
|
||||
"codec_type": c.Type.String(),
|
||||
"codec_name": codec.Name,
|
||||
"codec_type": codec.Type.String(),
|
||||
},
|
||||
}),
|
||||
}
|
||||
@ -141,27 +143,8 @@ func (manager *StreamSinkManagerCtx) ID() string {
|
||||
return manager.id
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) Bitrate() int {
|
||||
// TODO: fix bitrate switching calculation
|
||||
// return real bitrate if available
|
||||
//realBitrate := atomic.LoadUint64(&manager.bitrate)
|
||||
//if realBitrate != 0 {
|
||||
// return int(realBitrate)
|
||||
//}
|
||||
|
||||
// if we do not have function to estimate bitrate, return 0
|
||||
if manager.getBitrate == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
// recalculate bitrate every time, take screen resolution (and fps) into account
|
||||
// we called this function during startup, so it shouldn't error here
|
||||
bitrate, err := manager.getBitrate()
|
||||
if err != nil {
|
||||
manager.logger.Err(err).Msg("unexpected error while getting bitrate")
|
||||
}
|
||||
|
||||
return bitrate
|
||||
func (manager *StreamSinkManagerCtx) Bitrate() uint64 {
|
||||
return atomic.LoadUint64(&manager.bitrate)
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec {
|
||||
|
@ -3,6 +3,7 @@ package config
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/cobra"
|
||||
@ -15,6 +16,28 @@ import (
|
||||
// default stun server
|
||||
const defStunSrv = "stun:stun.l.google.com:19302"
|
||||
|
||||
type WebRTCEstimator struct {
|
||||
Enabled bool
|
||||
Passive bool
|
||||
Debug bool
|
||||
InitialBitrate int
|
||||
|
||||
// how often to read and process bandwidth estimation reports
|
||||
ReadInterval time.Duration
|
||||
// how long to wait for stable connection (only neutral or upward trend) before upgrading
|
||||
StableDuration time.Duration
|
||||
// how long to wait for unstable connection (downward trend) before downgrading
|
||||
UnstableDuration time.Duration
|
||||
// how long to wait for stalled connection (neutral trend with low bandwidth) before downgrading
|
||||
StalledDuration time.Duration
|
||||
// how long to wait before downgrading again after previous downgrade
|
||||
DowngradeBackoff time.Duration
|
||||
// how long to wait before upgrading again after previous upgrade
|
||||
UpgradeBackoff time.Duration
|
||||
// how bigger the difference between estimated and stream bitrate must be to trigger upgrade/downgrade
|
||||
DiffThreshold float64
|
||||
}
|
||||
|
||||
type WebRTC struct {
|
||||
ICELite bool
|
||||
ICETrickle bool
|
||||
@ -28,9 +51,7 @@ type WebRTC struct {
|
||||
NAT1To1IPs []string
|
||||
IpRetrievalUrl string
|
||||
|
||||
EstimatorEnabled bool
|
||||
EstimatorPassive bool
|
||||
EstimatorInitialBitrate int
|
||||
Estimator WebRTCEstimator
|
||||
}
|
||||
|
||||
func (WebRTC) Init(cmd *cobra.Command) error {
|
||||
@ -96,11 +117,51 @@ func (WebRTC) Init(cmd *cobra.Command) error {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("webrtc.estimator.debug", false, "enables debug logging for the bandwidth estimator")
|
||||
if err := viper.BindPFlag("webrtc.estimator.debug", cmd.PersistentFlags().Lookup("webrtc.estimator.debug")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Int("webrtc.estimator.initial_bitrate", 1_000_000, "initial bitrate for the bandwidth estimator")
|
||||
if err := viper.BindPFlag("webrtc.estimator.initial_bitrate", cmd.PersistentFlags().Lookup("webrtc.estimator.initial_bitrate")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Duration("webrtc.estimator.read_interval", 2*time.Second, "how often to read and process bandwidth estimation reports")
|
||||
if err := viper.BindPFlag("webrtc.estimator.read_interval", cmd.PersistentFlags().Lookup("webrtc.estimator.read_interval")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Duration("webrtc.estimator.stable_duration", 12*time.Second, "how long to wait for stable connection (upward or neutral trend) before upgrading")
|
||||
if err := viper.BindPFlag("webrtc.estimator.stable_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.stable_duration")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Duration("webrtc.estimator.unstable_duration", 6*time.Second, "how long to wait for stalled connection (neutral trend with low bandwidth) before downgrading")
|
||||
if err := viper.BindPFlag("webrtc.estimator.unstable_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.unstable_duration")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Duration("webrtc.estimator.stalled_duration", 24*time.Second, "how long to wait for stalled bandwidth estimation before downgrading")
|
||||
if err := viper.BindPFlag("webrtc.estimator.stalled_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.stalled_duration")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Duration("webrtc.estimator.downgrade_backoff", 10*time.Second, "how long to wait before downgrading again after previous downgrade")
|
||||
if err := viper.BindPFlag("webrtc.estimator.downgrade_backoff", cmd.PersistentFlags().Lookup("webrtc.estimator.downgrade_backoff")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Duration("webrtc.estimator.upgrade_backoff", 5*time.Second, "how long to wait before upgrading again after previous upgrade")
|
||||
if err := viper.BindPFlag("webrtc.estimator.upgrade_backoff", cmd.PersistentFlags().Lookup("webrtc.estimator.upgrade_backoff")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Float64("webrtc.estimator.diff_threshold", 0.15, "how bigger the difference between estimated and stream bitrate must be to trigger upgrade/downgrade")
|
||||
if err := viper.BindPFlag("webrtc.estimator.diff_threshold", cmd.PersistentFlags().Lookup("webrtc.estimator.diff_threshold")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -197,7 +258,15 @@ func (s *WebRTC) Set() {
|
||||
|
||||
// bandwidth estimator
|
||||
|
||||
s.EstimatorEnabled = viper.GetBool("webrtc.estimator.enabled")
|
||||
s.EstimatorPassive = viper.GetBool("webrtc.estimator.passive")
|
||||
s.EstimatorInitialBitrate = viper.GetInt("webrtc.estimator.initial_bitrate")
|
||||
s.Estimator.Enabled = viper.GetBool("webrtc.estimator.enabled")
|
||||
s.Estimator.Passive = viper.GetBool("webrtc.estimator.passive")
|
||||
s.Estimator.Debug = viper.GetBool("webrtc.estimator.debug")
|
||||
s.Estimator.InitialBitrate = viper.GetInt("webrtc.estimator.initial_bitrate")
|
||||
s.Estimator.ReadInterval = viper.GetDuration("webrtc.estimator.read_interval")
|
||||
s.Estimator.StableDuration = viper.GetDuration("webrtc.estimator.stable_duration")
|
||||
s.Estimator.UnstableDuration = viper.GetDuration("webrtc.estimator.unstable_duration")
|
||||
s.Estimator.StalledDuration = viper.GetDuration("webrtc.estimator.stalled_duration")
|
||||
s.Estimator.DowngradeBackoff = viper.GetDuration("webrtc.estimator.downgrade_backoff")
|
||||
s.Estimator.UpgradeBackoff = viper.GetDuration("webrtc.estimator.upgrade_backoff")
|
||||
s.Estimator.DiffThreshold = viper.GetFloat64("webrtc.estimator.diff_threshold")
|
||||
}
|
||||
|
@ -24,6 +24,7 @@ import (
|
||||
"github.com/demodesk/neko/pkg/types/codec"
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
"github.com/demodesk/neko/pkg/types/message"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -167,7 +168,7 @@ func (manager *WebRTCManagerCtx) ICEServers() []types.ICEServer {
|
||||
return manager.config.ICEServersFrontend
|
||||
}
|
||||
|
||||
func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs []codec.RTPCodec, bitrate int) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) {
|
||||
func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs []codec.RTPCodec) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) {
|
||||
// create media engine
|
||||
engine := &webrtc.MediaEngine{}
|
||||
for _, codec := range codecs {
|
||||
@ -223,14 +224,10 @@ func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs
|
||||
|
||||
// create bandwidth estimator
|
||||
estimatorChan := make(chan cc.BandwidthEstimator, 1)
|
||||
if manager.config.EstimatorEnabled {
|
||||
if manager.config.Estimator.Enabled {
|
||||
congestionController, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) {
|
||||
if bitrate == 0 {
|
||||
bitrate = manager.config.EstimatorInitialBitrate
|
||||
}
|
||||
|
||||
return gcc.NewSendSideBWE(
|
||||
gcc.SendSideBWEInitialBitrate(bitrate),
|
||||
gcc.SendSideBWEInitialBitrate(manager.config.Estimator.InitialBitrate),
|
||||
gcc.SendSideBWEPacer(gcc.NewNoOpPacer()),
|
||||
)
|
||||
})
|
||||
@ -268,7 +265,7 @@ func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs
|
||||
return connection, <-estimatorChan, err
|
||||
}
|
||||
|
||||
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int, videoAuto bool) (*webrtc.SessionDescription, error) {
|
||||
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session) (*webrtc.SessionDescription, types.WebRTCPeer, error) {
|
||||
id := atomic.AddInt32(&manager.peerId, 1)
|
||||
|
||||
// get metrics for session
|
||||
@ -287,12 +284,10 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
|
||||
video := manager.capture.Video()
|
||||
videoCodec := video.Codec()
|
||||
|
||||
connection, estimator, err := manager.newPeerConnection(logger, []codec.RTPCodec{
|
||||
audioCodec,
|
||||
videoCodec,
|
||||
}, bitrate)
|
||||
connection, estimator, err := manager.newPeerConnection(
|
||||
logger, []codec.RTPCodec{audioCodec, videoCodec})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// asynchronously send local ICE Candidates
|
||||
@ -311,47 +306,34 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
|
||||
})
|
||||
}
|
||||
|
||||
// if bitrate is 0, and estimator is enabled, use estimator bitrate
|
||||
if bitrate == 0 && estimator != nil {
|
||||
bitrate = estimator.GetTargetBitrate()
|
||||
}
|
||||
|
||||
// audio track
|
||||
audioTrack, err := NewTrack(logger, audioCodec, connection)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// set stream for audio track
|
||||
_, err = audioTrack.SetStream(audio)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// if estimator is disabled, or in passive mode, disable video auto bitrate
|
||||
if !manager.config.EstimatorEnabled || manager.config.EstimatorPassive {
|
||||
videoAuto = false
|
||||
}
|
||||
|
||||
videoRtcp := make(chan []rtcp.Packet, 1)
|
||||
|
||||
// video track
|
||||
videoTrack, err := NewTrack(logger, videoCodec, connection,
|
||||
WithVideoAuto(videoAuto),
|
||||
WithRtcpChan(videoRtcp),
|
||||
)
|
||||
videoRtcp := make(chan []rtcp.Packet, 1)
|
||||
videoTrack, err := NewTrack(logger, videoCodec, connection, WithRtcpChan(videoRtcp))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// let video stream bucket manager handle stream subscriptions
|
||||
video.SetReceiver(videoTrack)
|
||||
//
|
||||
// stream for video track will be set later
|
||||
//
|
||||
|
||||
// data channel
|
||||
|
||||
dataChannel, err := connection.CreateDataChannel("data", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
peer := &WebRTCPeerCtx{
|
||||
@ -359,7 +341,21 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
|
||||
session: session,
|
||||
metrics: metrics,
|
||||
connection: connection,
|
||||
// bandwidth estimator
|
||||
estimator: estimator,
|
||||
estimateTrend: utils.NewTrendDetector(
|
||||
utils.TrendDetectorParams{
|
||||
// Probing
|
||||
//RequiredSamples: 3,
|
||||
//DownwardTrendThreshold: 0.0,
|
||||
//CollapseValues: false,
|
||||
// Non-Probing
|
||||
RequiredSamples: 8,
|
||||
DownwardTrendThreshold: -0.5,
|
||||
CollapseValues: true,
|
||||
}),
|
||||
// stream selectors
|
||||
videoSelector: manager.capture.Video(),
|
||||
// tracks & channels
|
||||
audioTrack: audioTrack,
|
||||
videoTrack: videoTrack,
|
||||
@ -367,16 +363,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
|
||||
rtcpChannel: videoRtcp,
|
||||
// config
|
||||
iceTrickle: manager.config.ICETrickle,
|
||||
estimatorPassive: manager.config.EstimatorPassive,
|
||||
}
|
||||
|
||||
logger.Info().
|
||||
Int("target_bitrate", bitrate).
|
||||
Msg("estimated initial peer bitrate")
|
||||
|
||||
// set initial video bitrate
|
||||
if err := peer.SetVideoBitrate(bitrate); err != nil {
|
||||
return nil, err
|
||||
estimatorConfig: manager.config.Estimator,
|
||||
}
|
||||
|
||||
connection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
|
||||
@ -492,9 +479,9 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
|
||||
// ensure we only run this once
|
||||
once.Do(func() {
|
||||
session.SetWebRTCConnected(peer, false)
|
||||
if err = video.RemoveReceiver(videoTrack); err != nil {
|
||||
logger.Err(err).Msg("failed to remove video receiver")
|
||||
}
|
||||
//
|
||||
// TODO: Shutdown peer?
|
||||
//
|
||||
audioTrack.Shutdown()
|
||||
videoTrack.Shutdown()
|
||||
close(videoRtcp)
|
||||
@ -542,7 +529,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
|
||||
|
||||
offer, err := peer.CreateOffer(false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// on negotiation needed handler must be registered after creating initial
|
||||
@ -576,7 +563,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
|
||||
// start estimator reader
|
||||
go peer.estimatorReader()
|
||||
|
||||
return offer, nil
|
||||
return offer, peer, nil
|
||||
}
|
||||
|
||||
func (manager *WebRTCManagerCtx) SetCursorPosition(x, y int) {
|
||||
|
@ -121,7 +121,7 @@ func (m *metricsManager) getBySession(session types.Session) *metrics {
|
||||
Name: "receiver_estimated_maximum_bitrate",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Receiver Estimated Maximum Bitrate from SCTP.",
|
||||
Help: "Receiver Estimated Maximum Bitrate from RTCP.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
@ -140,7 +140,7 @@ func (m *metricsManager) getBySession(session types.Session) *metrics {
|
||||
Name: "receiver_report_delay",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Receiver Report Delay from SCTP, expressed in units of 1/65536 seconds.",
|
||||
Help: "Receiver Report Delay from RTCP, expressed in units of 1/65536 seconds.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
@ -149,7 +149,7 @@ func (m *metricsManager) getBySession(session types.Session) *metrics {
|
||||
Name: "receiver_report_jitter",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Receiver Report Jitter from SCTP.",
|
||||
Help: "Receiver Report Jitter from RTCP.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
@ -158,7 +158,17 @@ func (m *metricsManager) getBySession(session types.Session) *metrics {
|
||||
Name: "receiver_report_total_lost",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Receiver Report Total Lost from SCTP.",
|
||||
Help: "Receiver Report Total Lost from RTCP.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
}),
|
||||
|
||||
transportLayerNacks: promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "transport_layer_nacks",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Transport Layer NACKs from RTCP.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
@ -236,6 +246,8 @@ type metrics struct {
|
||||
receiverReportJitter prometheus.Gauge
|
||||
receiverReportTotalLost prometheus.Gauge
|
||||
|
||||
transportLayerNacks prometheus.Counter
|
||||
|
||||
iceBytesSent prometheus.Gauge
|
||||
iceBytesReceived prometheus.Gauge
|
||||
sctpBytesSent prometheus.Gauge
|
||||
@ -386,6 +398,11 @@ func (met *metrics) rtcpReceiver(rtcpCh chan []rtcp.Packet) {
|
||||
// use only last report
|
||||
met.SetReceiverReport(rtcpPacket.Reports[l-1])
|
||||
}
|
||||
case *rtcp.TransportLayerNack:
|
||||
for _, pair := range rtcpPacket.Nacks {
|
||||
packetList := pair.PacketList()
|
||||
met.transportLayerNacks.Add(float64(len(packetList)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -11,15 +11,12 @@ import (
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/demodesk/neko/internal/config"
|
||||
"github.com/demodesk/neko/internal/webrtc/payload"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
"github.com/demodesk/neko/pkg/types/message"
|
||||
)
|
||||
|
||||
const (
|
||||
// how often to read and process bandwidth estimation reports
|
||||
estimatorReadInterval = 250 * time.Millisecond
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type WebRTCPeerCtx struct {
|
||||
@ -28,7 +25,11 @@ type WebRTCPeerCtx struct {
|
||||
session types.Session
|
||||
metrics *metrics
|
||||
connection *webrtc.PeerConnection
|
||||
// bandwidth estimator
|
||||
estimator cc.BandwidthEstimator
|
||||
estimateTrend *utils.TrendDetector
|
||||
// stream selectors
|
||||
videoSelector types.StreamSelectorManager
|
||||
// tracks & channels
|
||||
audioTrack *Track
|
||||
videoTrack *Track
|
||||
@ -36,7 +37,8 @@ type WebRTCPeerCtx struct {
|
||||
rtcpChannel chan []rtcp.Packet
|
||||
// config
|
||||
iceTrickle bool
|
||||
estimatorPassive bool
|
||||
estimatorConfig config.WebRTCEstimator
|
||||
videoAuto bool
|
||||
}
|
||||
|
||||
//
|
||||
@ -102,6 +104,7 @@ func (peer *WebRTCPeerCtx) SetCandidate(candidate webrtc.ICECandidateInit) error
|
||||
return peer.connection.AddICECandidate(candidate)
|
||||
}
|
||||
|
||||
// TODO: Add shutdown function?
|
||||
func (peer *WebRTCPeerCtx) Destroy() {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
@ -111,32 +114,186 @@ func (peer *WebRTCPeerCtx) Destroy() {
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) estimatorReader() {
|
||||
conf := peer.estimatorConfig
|
||||
|
||||
// if estimator is not in debug mode, use a nop logger
|
||||
var debugLogger zerolog.Logger
|
||||
if conf.Debug {
|
||||
debugLogger = peer.logger.With().Str("component", "estimator").Logger().Level(zerolog.DebugLevel)
|
||||
} else {
|
||||
debugLogger = zerolog.Nop()
|
||||
}
|
||||
|
||||
// if estimator is disabled, do nothing
|
||||
if peer.estimator == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// use a ticker to get current client target bitrate
|
||||
ticker := time.NewTicker(estimatorReadInterval)
|
||||
ticker := time.NewTicker(conf.ReadInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// since when is the estimate stable/unstable
|
||||
stableSince := time.Now() // we asume stable at start
|
||||
unstableSince := time.Time{}
|
||||
// since when are we neutral but cannot accomodate current bitrate
|
||||
// we migt be stalled or estimator just reached zer (very bad connection)
|
||||
stalledSince := time.Time{}
|
||||
// when was the last upgrade/downgrade
|
||||
lastUpgradeTime := time.Time{}
|
||||
lastDowngradeTime := time.Time{}
|
||||
|
||||
for range ticker.C {
|
||||
targetBitrate := peer.estimator.GetTargetBitrate()
|
||||
peer.metrics.SetReceiverEstimatedTargetBitrate(float64(targetBitrate))
|
||||
|
||||
// if peer connection is closed, stop reading
|
||||
if peer.connection.ConnectionState() == webrtc.PeerConnectionStateClosed {
|
||||
break
|
||||
}
|
||||
|
||||
if !peer.videoTrack.VideoAuto() {
|
||||
// if estimation is disabled, do nothing
|
||||
if !peer.videoAuto || conf.Passive {
|
||||
continue
|
||||
}
|
||||
|
||||
if !peer.estimatorPassive {
|
||||
err := peer.SetVideoBitrate(targetBitrate)
|
||||
if err != nil {
|
||||
peer.logger.Warn().Err(err).Msg("failed to set video bitrate")
|
||||
// get trend direction to decide if we should upgrade or downgrade
|
||||
peer.estimateTrend.AddValue(int64(targetBitrate))
|
||||
direction := peer.estimateTrend.GetDirection()
|
||||
|
||||
// get current stream bitrate
|
||||
stream, ok := peer.videoTrack.Stream()
|
||||
if !ok {
|
||||
debugLogger.Warn().Msg("looks like we don't have a stream yet, skipping bitrate estimation")
|
||||
continue
|
||||
}
|
||||
|
||||
// if stream bitrate is 0, we need to wait for some time until we get a valid value
|
||||
streamId, streamBitrate := stream.ID(), stream.Bitrate()
|
||||
if streamBitrate == 0 {
|
||||
debugLogger.Warn().Msg("looks like stream bitrate is 0, we need to wait for some time")
|
||||
continue
|
||||
}
|
||||
|
||||
// check whats the difference between target and stream bitrate
|
||||
diff := float64(targetBitrate) / float64(streamBitrate)
|
||||
|
||||
debugLogger.Info().
|
||||
Float64("diff", diff).
|
||||
Int("target_bitrate", targetBitrate).
|
||||
Uint64("stream_bitrate", streamBitrate).
|
||||
Str("direction", direction.String()).
|
||||
Msg("got bitrate from estimator")
|
||||
|
||||
// if we can accomodate current stream or we are not netural anymore,
|
||||
// we are not stalled so we reset the stalled time
|
||||
if direction != utils.TrendDirectionNeutral || diff > 1+conf.DiffThreshold {
|
||||
stalledSince = time.Now()
|
||||
}
|
||||
|
||||
// if we are neutral and stalled for too long, we might be congesting
|
||||
stalled := direction == utils.TrendDirectionNeutral && time.Since(stalledSince) > conf.StalledDuration
|
||||
if stalled {
|
||||
debugLogger.Warn().
|
||||
Time("stalled_since", stalledSince).
|
||||
Msgf("it looks like we are stalled")
|
||||
}
|
||||
|
||||
// if we have an downward trend or are stalled, we might be congesting
|
||||
if direction == utils.TrendDirectionDownward || stalled {
|
||||
// we reset the stable time because we are congesting
|
||||
stableSince = time.Now()
|
||||
|
||||
// if we downgraded recently, we wait for some more time
|
||||
if time.Since(lastDowngradeTime) < conf.DowngradeBackoff {
|
||||
debugLogger.Debug().
|
||||
Time("last_downgrade", lastDowngradeTime).
|
||||
Msgf("downgraded recently, waiting for at least %v", conf.DowngradeBackoff)
|
||||
continue
|
||||
}
|
||||
|
||||
// if we are not unstable but we fluctuate we should wait for some more time
|
||||
if time.Since(unstableSince) < conf.UnstableDuration {
|
||||
debugLogger.Debug().
|
||||
Time("unstable_since", unstableSince).
|
||||
Msgf("we are not unstable long enough, waiting for at least %v", conf.UnstableDuration)
|
||||
continue
|
||||
}
|
||||
|
||||
// if we still have a big difference between target and stream bitrate, we wait for some more time
|
||||
if conf.DiffThreshold >= 0 && diff > 1+conf.DiffThreshold {
|
||||
debugLogger.Debug().
|
||||
Float64("diff", diff).
|
||||
Float64("threshold", conf.DiffThreshold).
|
||||
Msgf("we still have a big difference between target and stream bitrate, " +
|
||||
"therefore we still should be able to accomodate current stream")
|
||||
continue
|
||||
}
|
||||
|
||||
err := peer.SetVideo(types.StreamSelector{
|
||||
ID: streamId,
|
||||
Type: types.StreamSelectorTypeLower,
|
||||
})
|
||||
if err != nil && err != types.ErrWebRTCStreamNotFound {
|
||||
peer.logger.Warn().Err(err).Msg("failed to downgrade video stream")
|
||||
}
|
||||
lastDowngradeTime = time.Now()
|
||||
|
||||
if err == types.ErrWebRTCStreamNotFound {
|
||||
debugLogger.Info().Msg("looks like we are already on the lowest stream")
|
||||
} else {
|
||||
debugLogger.Info().Msg("downgraded video stream")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// we reset the unstable time because we are not congesting
|
||||
unstableSince = time.Now()
|
||||
|
||||
// if we have a neutral or upward trend, that means our estimate is stable
|
||||
// if we are on the highest stream, we don't need to do anything
|
||||
// but if there is a higher stream, we should try to upgrade and see if it works
|
||||
|
||||
// if we upgraded recently, we wait for some more time
|
||||
if time.Since(lastUpgradeTime) < conf.UpgradeBackoff {
|
||||
debugLogger.Debug().
|
||||
Time("last_upgrade", lastUpgradeTime).
|
||||
Msgf("upgraded recently, waiting for at least %v", conf.UpgradeBackoff)
|
||||
continue
|
||||
}
|
||||
|
||||
// if we are not stable for long enough, we wait for some more time
|
||||
// because bandwidth estimation might fluctuate
|
||||
if time.Since(stableSince) < conf.StableDuration {
|
||||
debugLogger.Debug().
|
||||
Time("stable_since", stableSince).
|
||||
Msgf("we are not stable long enough, waiting for at least %v", conf.StableDuration)
|
||||
continue
|
||||
}
|
||||
|
||||
// upgrade only if estimated bitrate passed the threshold
|
||||
if conf.DiffThreshold >= 0 && diff < 1+conf.DiffThreshold {
|
||||
debugLogger.Debug().
|
||||
Float64("diff", diff).
|
||||
Float64("threshold", conf.DiffThreshold).
|
||||
Msgf("looks like we don't have enough bitrate to accomodate higher stream, " +
|
||||
"therefore we should wait for some more time")
|
||||
continue
|
||||
}
|
||||
|
||||
err := peer.SetVideo(types.StreamSelector{
|
||||
ID: streamId,
|
||||
Type: types.StreamSelectorTypeHigher,
|
||||
})
|
||||
if err != nil && err != types.ErrWebRTCStreamNotFound {
|
||||
peer.logger.Warn().Err(err).Msg("failed to upgrade video stream")
|
||||
}
|
||||
lastUpgradeTime = time.Now()
|
||||
|
||||
if err == types.ErrWebRTCStreamNotFound {
|
||||
debugLogger.Info().Msg("looks like we are already on the highest stream")
|
||||
} else {
|
||||
debugLogger.Info().Msg("upgraded video stream")
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -145,88 +302,52 @@ func (peer *WebRTCPeerCtx) estimatorReader() {
|
||||
// video
|
||||
//
|
||||
|
||||
func (peer *WebRTCPeerCtx) SetVideoBitrate(peerBitrate int) error {
|
||||
func (peer *WebRTCPeerCtx) SetVideo(selector types.StreamSelector) error {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
// when switching from manual to auto bitrate estimation, in case the estimator is
|
||||
// idle (lastBitrate > maxBitrate), we want to go back to the previous estimated bitrate
|
||||
if peerBitrate == 0 && peer.estimator != nil && !peer.estimatorPassive {
|
||||
peerBitrate = peer.estimator.GetTargetBitrate()
|
||||
peer.logger.Debug().
|
||||
Int("peer_bitrate", peerBitrate).
|
||||
Msg("evaluated bitrate")
|
||||
// get requested video stream from selector
|
||||
stream, ok := peer.videoSelector.GetStream(selector)
|
||||
if !ok {
|
||||
return types.ErrWebRTCStreamNotFound
|
||||
}
|
||||
|
||||
changed, err := peer.videoTrack.SetBitrate(peerBitrate)
|
||||
// set video stream to track
|
||||
changed, err := peer.videoTrack.SetStream(stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if video stream was already set, do nothing
|
||||
if !changed {
|
||||
// TODO: return error?
|
||||
return nil
|
||||
}
|
||||
|
||||
videoID := peer.videoTrack.stream.ID()
|
||||
bitrate := peer.videoTrack.stream.Bitrate()
|
||||
|
||||
videoID := stream.ID()
|
||||
peer.metrics.SetVideoID(videoID)
|
||||
peer.logger.Debug().
|
||||
Int("peer_bitrate", peerBitrate).
|
||||
Int("video_bitrate", bitrate).
|
||||
Str("video_id", videoID).
|
||||
Msg("peer bitrate triggered video stream change")
|
||||
|
||||
peer.logger.Info().Str("video_id", videoID).Msg("set video")
|
||||
|
||||
go peer.session.Send(
|
||||
event.SIGNAL_VIDEO,
|
||||
message.SignalVideo{
|
||||
Video: videoID,
|
||||
Bitrate: bitrate,
|
||||
VideoAuto: peer.videoTrack.VideoAuto(),
|
||||
Auto: peer.videoAuto,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) SetVideoID(videoID string) error {
|
||||
func (peer *WebRTCPeerCtx) VideoID() (string, bool) {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
changed, err := peer.videoTrack.SetVideoID(videoID)
|
||||
if err != nil {
|
||||
return err
|
||||
stream, ok := peer.videoTrack.Stream()
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if !changed {
|
||||
// TODO: return error?
|
||||
return nil
|
||||
}
|
||||
|
||||
bitrate := peer.videoTrack.stream.Bitrate()
|
||||
|
||||
peer.logger.Debug().
|
||||
Str("video_id", videoID).
|
||||
Int("video_bitrate", bitrate).
|
||||
Msg("peer video id triggered video stream change")
|
||||
|
||||
go peer.session.Send(
|
||||
event.SIGNAL_VIDEO,
|
||||
message.SignalVideo{
|
||||
Video: videoID,
|
||||
Bitrate: bitrate,
|
||||
VideoAuto: peer.videoTrack.VideoAuto(),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) GetVideoID() string {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
// TODO: Refactor.
|
||||
return peer.videoTrack.stream.ID()
|
||||
return stream.ID(), true
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error {
|
||||
@ -239,18 +360,32 @@ func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) Paused() bool {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
return peer.videoTrack.Paused() || peer.audioTrack.Paused()
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) SetVideoAuto(videoAuto bool) {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
// if estimator is enabled and is not passive, enable video auto bitrate
|
||||
if peer.estimator != nil && !peer.estimatorPassive {
|
||||
peer.videoTrack.SetVideoAuto(videoAuto)
|
||||
if peer.estimator != nil && !peer.estimatorConfig.Passive {
|
||||
peer.logger.Info().Bool("video_auto", videoAuto).Msg("set video auto")
|
||||
peer.videoAuto = videoAuto
|
||||
} else {
|
||||
peer.logger.Warn().Msg("estimator is disabled or in passive mode, cannot change video auto")
|
||||
peer.videoTrack.SetVideoAuto(false) // ensure video auto is disabled
|
||||
peer.videoAuto = false // ensure video auto is disabled
|
||||
}
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) VideoAuto() bool {
|
||||
return peer.videoTrack.VideoAuto()
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
return peer.videoAuto
|
||||
}
|
||||
|
||||
//
|
||||
|
@ -2,7 +2,6 @@ package webrtc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
@ -22,25 +21,13 @@ type Track struct {
|
||||
rtcpCh chan []rtcp.Packet
|
||||
sample chan types.Sample
|
||||
|
||||
videoAuto bool
|
||||
videoAutoMu sync.RWMutex
|
||||
|
||||
paused bool
|
||||
stream types.StreamSinkManager
|
||||
streamMu sync.Mutex
|
||||
|
||||
bitrateChange func(int) (bool, error)
|
||||
videoChange func(string) (bool, error)
|
||||
}
|
||||
|
||||
type trackOption func(*Track)
|
||||
|
||||
func WithVideoAuto(auto bool) trackOption {
|
||||
return func(t *Track) {
|
||||
t.videoAuto = auto
|
||||
}
|
||||
}
|
||||
|
||||
func WithRtcpChan(rtcp chan []rtcp.Packet) trackOption {
|
||||
return func(t *Track) {
|
||||
t.rtcpCh = rtcp
|
||||
@ -100,6 +87,8 @@ func (t *Track) rtcpReader(sender *webrtc.RTPSender) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- sample ---
|
||||
|
||||
func (t *Track) sampleReader() {
|
||||
for {
|
||||
sample, ok := <-t.sample
|
||||
@ -120,6 +109,12 @@ func (t *Track) sampleReader() {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Track) WriteSample(sample types.Sample) {
|
||||
t.sample <- sample
|
||||
}
|
||||
|
||||
// --- stream ---
|
||||
|
||||
func (t *Track) SetStream(stream types.StreamSinkManager) (bool, error) {
|
||||
t.streamMu.Lock()
|
||||
defer t.streamMu.Unlock()
|
||||
@ -167,6 +162,15 @@ func (t *Track) RemoveStream() {
|
||||
t.stream = nil
|
||||
}
|
||||
|
||||
func (t *Track) Stream() (types.StreamSinkManager, bool) {
|
||||
t.streamMu.Lock()
|
||||
defer t.streamMu.Unlock()
|
||||
|
||||
return t.stream, t.stream != nil
|
||||
}
|
||||
|
||||
// --- paused ---
|
||||
|
||||
func (t *Track) SetPaused(paused bool) {
|
||||
t.streamMu.Lock()
|
||||
defer t.streamMu.Unlock()
|
||||
@ -190,42 +194,9 @@ func (t *Track) SetPaused(paused bool) {
|
||||
t.paused = paused
|
||||
}
|
||||
|
||||
func (t *Track) WriteSample(sample types.Sample) {
|
||||
t.sample <- sample
|
||||
}
|
||||
|
||||
func (t *Track) SetBitrate(bitrate int) (bool, error) {
|
||||
if t.bitrateChange == nil {
|
||||
return false, fmt.Errorf("bitrate change not supported")
|
||||
}
|
||||
|
||||
return t.bitrateChange(bitrate)
|
||||
}
|
||||
|
||||
func (t *Track) SetVideoID(videoID string) (bool, error) {
|
||||
if t.videoChange == nil {
|
||||
return false, fmt.Errorf("video change not supported")
|
||||
}
|
||||
|
||||
return t.videoChange(videoID)
|
||||
}
|
||||
|
||||
func (t *Track) OnBitrateChange(f func(bitrate int) (bool, error)) {
|
||||
t.bitrateChange = f
|
||||
}
|
||||
|
||||
func (t *Track) OnVideoChange(f func(string) (bool, error)) {
|
||||
t.videoChange = f
|
||||
}
|
||||
|
||||
func (t *Track) SetVideoAuto(auto bool) {
|
||||
t.videoAutoMu.Lock()
|
||||
defer t.videoAutoMu.Unlock()
|
||||
t.videoAuto = auto
|
||||
}
|
||||
|
||||
func (t *Track) VideoAuto() bool {
|
||||
t.videoAutoMu.RLock()
|
||||
defer t.videoAutoMu.RUnlock()
|
||||
return t.videoAuto
|
||||
func (t *Track) Paused() bool {
|
||||
t.streamMu.Lock()
|
||||
defer t.streamMu.Unlock()
|
||||
|
||||
return t.paused
|
||||
}
|
||||
|
@ -20,28 +20,26 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
|
||||
payload.Video = videos[0]
|
||||
}
|
||||
|
||||
var err error
|
||||
if payload.Bitrate == 0 {
|
||||
// get bitrate from video id
|
||||
payload.Bitrate, err = h.capture.GetBitrateFromVideoID(payload.Video)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
offer, err := h.webrtc.CreatePeer(session, payload.Bitrate, payload.VideoAuto)
|
||||
offer, peer, err := h.webrtc.CreatePeer(session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if webrtcPeer := session.GetWebRTCPeer(); webrtcPeer != nil {
|
||||
// set webrtc as paused if session has private mode enabled
|
||||
if session.PrivateModeEnabled() {
|
||||
webrtcPeer.SetPaused(true)
|
||||
peer.SetPaused(true)
|
||||
}
|
||||
|
||||
payload.Video = webrtcPeer.GetVideoID()
|
||||
payload.VideoAuto = webrtcPeer.VideoAuto()
|
||||
// set video auto state
|
||||
peer.SetVideoAuto(payload.Auto)
|
||||
|
||||
// set video stream
|
||||
err = peer.SetVideo(types.StreamSelector{
|
||||
ID: payload.Video,
|
||||
Type: types.StreamSelectorTypeNearest,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session.Send(
|
||||
@ -49,9 +47,6 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
|
||||
message.SignalProvide{
|
||||
SDP: offer.SDP,
|
||||
ICEServers: h.webrtc.ICEServers(),
|
||||
Video: payload.Video, // TODO: Refactor
|
||||
Bitrate: payload.Bitrate,
|
||||
VideoAuto: payload.VideoAuto,
|
||||
})
|
||||
|
||||
return nil
|
||||
@ -133,16 +128,13 @@ func (h *MessageHandlerCtx) signalVideo(session types.Session, payload *message.
|
||||
return errors.New("webRTC peer does not exist")
|
||||
}
|
||||
|
||||
peer.SetVideoAuto(payload.VideoAuto)
|
||||
peer.SetVideoAuto(payload.Auto)
|
||||
|
||||
if payload.Video != "" {
|
||||
if err := peer.SetVideoID(payload.Video); err != nil {
|
||||
h.logger.Error().Err(err).Msg("failed to set video id")
|
||||
}
|
||||
} else {
|
||||
if err := peer.SetVideoBitrate(payload.Bitrate); err != nil {
|
||||
h.logger.Error().Err(err).Msg("failed to set video bitrate")
|
||||
}
|
||||
return peer.SetVideo(types.StreamSelector{
|
||||
ID: payload.Video,
|
||||
Type: types.StreamSelectorTypeNearest,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -31,26 +31,6 @@ type SampleListener interface {
|
||||
WriteSample(Sample)
|
||||
}
|
||||
|
||||
type Receiver interface {
|
||||
SetStream(stream StreamSinkManager) (changed bool, err error)
|
||||
RemoveStream()
|
||||
OnBitrateChange(f func(bitrate int) (changed bool, err error))
|
||||
OnVideoChange(f func(videoID string) (changed bool, err error))
|
||||
VideoAuto() bool
|
||||
SetVideoAuto(videoAuto bool)
|
||||
}
|
||||
|
||||
type BucketsManager interface {
|
||||
IDs() []string
|
||||
Codec() codec.RTPCodec
|
||||
SetReceiver(receiver Receiver)
|
||||
RemoveReceiver(receiver Receiver) error
|
||||
|
||||
DestroyAll()
|
||||
RecreateAll() error
|
||||
Shutdown()
|
||||
}
|
||||
|
||||
type BroadcastManager interface {
|
||||
Start(url string) error
|
||||
Stop()
|
||||
@ -64,10 +44,74 @@ type ScreencastManager interface {
|
||||
Image() ([]byte, error)
|
||||
}
|
||||
|
||||
type StreamSelectorType int
|
||||
|
||||
const (
|
||||
// select exact stream
|
||||
StreamSelectorTypeExact StreamSelectorType = iota
|
||||
// select nearest stream (in either direction) if exact stream is not available
|
||||
StreamSelectorTypeNearest
|
||||
// if exact stream is found select the next lower stream, otherwise select the nearest lower stream
|
||||
StreamSelectorTypeLower
|
||||
// if exact stream is found select the next higher stream, otherwise select the nearest higher stream
|
||||
StreamSelectorTypeHigher
|
||||
)
|
||||
|
||||
func (s StreamSelectorType) String() string {
|
||||
switch s {
|
||||
case StreamSelectorTypeExact:
|
||||
return "exact"
|
||||
case StreamSelectorTypeNearest:
|
||||
return "nearest"
|
||||
case StreamSelectorTypeLower:
|
||||
return "lower"
|
||||
case StreamSelectorTypeHigher:
|
||||
return "higher"
|
||||
default:
|
||||
return fmt.Sprintf("%d", int(s))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StreamSelectorType) UnmarshalText(text []byte) error {
|
||||
switch strings.ToLower(string(text)) {
|
||||
case "exact", "":
|
||||
*s = StreamSelectorTypeExact
|
||||
case "nearest":
|
||||
*s = StreamSelectorTypeNearest
|
||||
case "lower":
|
||||
*s = StreamSelectorTypeLower
|
||||
case "higher":
|
||||
*s = StreamSelectorTypeHigher
|
||||
default:
|
||||
return fmt.Errorf("invalid stream selector type: %s", string(text))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s StreamSelectorType) MarshalText() ([]byte, error) {
|
||||
return []byte(s.String()), nil
|
||||
}
|
||||
|
||||
type StreamSelector struct {
|
||||
// type of stream selector
|
||||
Type StreamSelectorType
|
||||
// select stream by its ID
|
||||
ID string
|
||||
// select stream by its bitrate
|
||||
Bitrate uint64
|
||||
}
|
||||
|
||||
type StreamSelectorManager interface {
|
||||
IDs() []string
|
||||
Codec() codec.RTPCodec
|
||||
|
||||
GetStream(selector StreamSelector) (StreamSinkManager, bool)
|
||||
}
|
||||
|
||||
type StreamSinkManager interface {
|
||||
ID() string
|
||||
Codec() codec.RTPCodec
|
||||
Bitrate() int
|
||||
Bitrate() uint64
|
||||
|
||||
AddListener(listener SampleListener) error
|
||||
RemoveListener(listener SampleListener) error
|
||||
@ -94,12 +138,10 @@ type CaptureManager interface {
|
||||
Start()
|
||||
Shutdown() error
|
||||
|
||||
GetBitrateFromVideoID(videoID string) (int, error)
|
||||
|
||||
Broadcast() BroadcastManager
|
||||
Screencast() ScreencastManager
|
||||
Audio() StreamSinkManager
|
||||
Video() BucketsManager
|
||||
Video() StreamSelectorManager
|
||||
|
||||
Webcam() StreamSrcManager
|
||||
Microphone() StreamSrcManager
|
||||
@ -201,54 +243,3 @@ func (config *VideoConfig) GetPipeline(screen ScreenSize) (string, error) {
|
||||
config.GstSuffix,
|
||||
}[:], " "), nil
|
||||
}
|
||||
|
||||
func (config *VideoConfig) GetBitrateFn(getScreen func() ScreenSize) func() (int, error) {
|
||||
return func() (int, error) {
|
||||
if config.Bitrate > 0 {
|
||||
return config.Bitrate, nil
|
||||
}
|
||||
|
||||
screen := getScreen()
|
||||
|
||||
values := map[string]any{
|
||||
"width": screen.Width,
|
||||
"height": screen.Height,
|
||||
"fps": screen.Rate,
|
||||
}
|
||||
|
||||
language := []gval.Language{
|
||||
gval.Function("round", func(args ...any) (any, error) {
|
||||
return (int)(math.Round(args[0].(float64))), nil
|
||||
}),
|
||||
}
|
||||
|
||||
// TOOD: do not read target-bitrate from pipeline, but only from config.
|
||||
|
||||
// TODO: This is only for vp8.
|
||||
expr, ok := config.GstParams["target-bitrate"]
|
||||
if !ok {
|
||||
// TODO: This is only for h264.
|
||||
expr, ok = config.GstParams["bitrate"]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("bitrate not found")
|
||||
}
|
||||
}
|
||||
|
||||
bitrate, err := gval.Evaluate(expr, values, language...)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to evaluate bitrate: %w", err)
|
||||
}
|
||||
|
||||
var bitrateInt int
|
||||
switch val := bitrate.(type) {
|
||||
case int:
|
||||
bitrateInt = val
|
||||
case float64:
|
||||
bitrateInt = (int)(val)
|
||||
default:
|
||||
return 0, fmt.Errorf("bitrate is not int or float64")
|
||||
}
|
||||
|
||||
return bitrateInt, nil
|
||||
}
|
||||
}
|
||||
|
@ -48,10 +48,6 @@ type SystemDisconnect struct {
|
||||
type SignalProvide struct {
|
||||
SDP string `json:"sdp"`
|
||||
ICEServers []types.ICEServer `json:"iceservers"`
|
||||
// TODO: Use SignalVideo struct.
|
||||
Video string `json:"video"`
|
||||
Bitrate int `json:"bitrate"`
|
||||
VideoAuto bool `json:"video_auto"`
|
||||
}
|
||||
|
||||
type SignalCandidate struct {
|
||||
@ -64,8 +60,7 @@ type SignalDescription struct {
|
||||
|
||||
type SignalVideo struct {
|
||||
Video string `json:"video"`
|
||||
Bitrate int `json:"bitrate"`
|
||||
VideoAuto bool `json:"video_auto"`
|
||||
Auto bool `json:"auto"`
|
||||
}
|
||||
|
||||
/////////////////////////////
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
var (
|
||||
ErrWebRTCDataChannelNotFound = errors.New("webrtc data channel not found")
|
||||
ErrWebRTCConnectionNotFound = errors.New("webrtc connection not found")
|
||||
ErrWebRTCStreamNotFound = errors.New("webrtc stream not found")
|
||||
)
|
||||
|
||||
type ICEServer struct {
|
||||
@ -23,10 +24,10 @@ type WebRTCPeer interface {
|
||||
SetRemoteDescription(webrtc.SessionDescription) error
|
||||
SetCandidate(webrtc.ICECandidateInit) error
|
||||
|
||||
SetVideoBitrate(bitrate int) error
|
||||
SetVideoID(videoID string) error
|
||||
GetVideoID() string
|
||||
SetVideo(StreamSelector) error
|
||||
VideoID() (string, bool)
|
||||
SetPaused(isPaused bool) error
|
||||
Paused() bool
|
||||
SetVideoAuto(auto bool)
|
||||
VideoAuto() bool
|
||||
|
||||
@ -42,6 +43,6 @@ type WebRTCManager interface {
|
||||
|
||||
ICEServers() []ICEServer
|
||||
|
||||
CreatePeer(session Session, bitrate int, videoAuto bool) (*webrtc.SessionDescription, error)
|
||||
CreatePeer(session Session) (*webrtc.SessionDescription, WebRTCPeer, error)
|
||||
SetCursorPosition(x, y int)
|
||||
}
|
||||
|
153
pkg/utils/trenddetector.go
Normal file
153
pkg/utils/trenddetector.go
Normal file
@ -0,0 +1,153 @@
|
||||
// From https://github.com/livekit/livekit/blob/master/pkg/sfu/streamallocator/trenddetector.go
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ------------------------------------------------
|
||||
|
||||
type TrendDirection int
|
||||
|
||||
const (
|
||||
TrendDirectionNeutral TrendDirection = iota
|
||||
TrendDirectionUpward
|
||||
TrendDirectionDownward
|
||||
)
|
||||
|
||||
func (t TrendDirection) String() string {
|
||||
switch t {
|
||||
case TrendDirectionNeutral:
|
||||
return "NEUTRAL"
|
||||
case TrendDirectionUpward:
|
||||
return "UPWARD"
|
||||
case TrendDirectionDownward:
|
||||
return "DOWNWARD"
|
||||
default:
|
||||
return fmt.Sprintf("%d", int(t))
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------------------
|
||||
|
||||
type TrendDetectorParams struct {
|
||||
RequiredSamples int
|
||||
DownwardTrendThreshold float64
|
||||
CollapseValues bool
|
||||
}
|
||||
|
||||
type TrendDetector struct {
|
||||
params TrendDetectorParams
|
||||
|
||||
startTime time.Time
|
||||
numSamples int
|
||||
values []int64
|
||||
lowestValue int64
|
||||
highestValue int64
|
||||
|
||||
direction TrendDirection
|
||||
}
|
||||
|
||||
func NewTrendDetector(params TrendDetectorParams) *TrendDetector {
|
||||
return &TrendDetector{
|
||||
params: params,
|
||||
startTime: time.Now(),
|
||||
direction: TrendDirectionNeutral,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TrendDetector) Seed(value int64) {
|
||||
if len(t.values) != 0 {
|
||||
return
|
||||
}
|
||||
|
||||
t.values = append(t.values, value)
|
||||
}
|
||||
|
||||
func (t *TrendDetector) AddValue(value int64) {
|
||||
t.numSamples++
|
||||
if t.lowestValue == 0 || value < t.lowestValue {
|
||||
t.lowestValue = value
|
||||
}
|
||||
if value > t.highestValue {
|
||||
t.highestValue = value
|
||||
}
|
||||
|
||||
// ignore duplicate values
|
||||
if t.params.CollapseValues && len(t.values) != 0 && t.values[len(t.values)-1] == value {
|
||||
return
|
||||
}
|
||||
|
||||
if len(t.values) == t.params.RequiredSamples {
|
||||
t.values = t.values[1:]
|
||||
}
|
||||
t.values = append(t.values, value)
|
||||
|
||||
t.updateDirection()
|
||||
}
|
||||
|
||||
func (t *TrendDetector) GetLowest() int64 {
|
||||
return t.lowestValue
|
||||
}
|
||||
|
||||
func (t *TrendDetector) GetHighest() int64 {
|
||||
return t.highestValue
|
||||
}
|
||||
|
||||
func (t *TrendDetector) GetValues() []int64 {
|
||||
return t.values
|
||||
}
|
||||
|
||||
func (t *TrendDetector) GetDirection() TrendDirection {
|
||||
return t.direction
|
||||
}
|
||||
|
||||
func (t *TrendDetector) ToString() string {
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(t.startTime).Seconds()
|
||||
str := fmt.Sprintf("t: %+v|%+v|%.2fs", t.startTime.Format(time.UnixDate), now.Format(time.UnixDate), elapsed)
|
||||
str += fmt.Sprintf(", v: %d|%d|%d|%+v|%.2f", t.numSamples, t.lowestValue, t.highestValue, t.values, kendallsTau(t.values))
|
||||
return str
|
||||
}
|
||||
|
||||
func (t *TrendDetector) updateDirection() {
|
||||
if len(t.values) < t.params.RequiredSamples {
|
||||
t.direction = TrendDirectionNeutral
|
||||
return
|
||||
}
|
||||
|
||||
// using Kendall's Tau to find trend
|
||||
kt := kendallsTau(t.values)
|
||||
|
||||
t.direction = TrendDirectionNeutral
|
||||
switch {
|
||||
case kt > 0:
|
||||
t.direction = TrendDirectionUpward
|
||||
case kt < t.params.DownwardTrendThreshold:
|
||||
t.direction = TrendDirectionDownward
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------------------
|
||||
|
||||
func kendallsTau(values []int64) float64 {
|
||||
concordantPairs := 0
|
||||
discordantPairs := 0
|
||||
|
||||
for i := 0; i < len(values)-1; i++ {
|
||||
for j := i + 1; j < len(values); j++ {
|
||||
if values[i] < values[j] {
|
||||
concordantPairs++
|
||||
} else if values[i] > values[j] {
|
||||
discordantPairs++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (concordantPairs + discordantPairs) == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
return (float64(concordantPairs) - float64(discordantPairs)) / (float64(concordantPairs) + float64(discordantPairs))
|
||||
}
|
Loading…
Reference in New Issue
Block a user