mirror of
https://github.com/m1k1o/neko.git
synced 2024-07-24 14:40:50 +12:00
Bandwidth estimator refactor (#46)
* rewrite to use stream selector. * WIP. * add nacks to metrics. * add estimate trend. * estimator based on trend detector. * add estimator unstable duration. * add estimator debug. * add stalled duration. * estimator move values to config. * change default estimator values. * minor style changes. * fix websocket video messages. * replace video track with ivdeo id.
This commit is contained in:
parent
8660c1a256
commit
3e8d686c0f
@ -1,145 +0,0 @@
|
|||||||
package buckets
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"sort"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
|
|
||||||
"github.com/demodesk/neko/pkg/types"
|
|
||||||
"github.com/demodesk/neko/pkg/types/codec"
|
|
||||||
)
|
|
||||||
|
|
||||||
type BucketsManagerCtx struct {
|
|
||||||
logger zerolog.Logger
|
|
||||||
codec codec.RTPCodec
|
|
||||||
streams map[string]types.StreamSinkManager
|
|
||||||
streamIDs []string
|
|
||||||
}
|
|
||||||
|
|
||||||
func BucketsNew(codec codec.RTPCodec, streams map[string]types.StreamSinkManager, streamIDs []string) *BucketsManagerCtx {
|
|
||||||
logger := log.With().
|
|
||||||
Str("module", "capture").
|
|
||||||
Str("submodule", "buckets").
|
|
||||||
Logger()
|
|
||||||
|
|
||||||
return &BucketsManagerCtx{
|
|
||||||
logger: logger,
|
|
||||||
codec: codec,
|
|
||||||
streams: streams,
|
|
||||||
streamIDs: streamIDs,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (manager *BucketsManagerCtx) Shutdown() {
|
|
||||||
manager.logger.Info().Msgf("shutdown")
|
|
||||||
|
|
||||||
manager.DestroyAll()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (manager *BucketsManagerCtx) DestroyAll() {
|
|
||||||
for _, stream := range manager.streams {
|
|
||||||
if stream.Started() {
|
|
||||||
stream.DestroyPipeline()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (manager *BucketsManagerCtx) RecreateAll() error {
|
|
||||||
for _, stream := range manager.streams {
|
|
||||||
if stream.Started() {
|
|
||||||
err := stream.CreatePipeline()
|
|
||||||
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (manager *BucketsManagerCtx) IDs() []string {
|
|
||||||
return manager.streamIDs
|
|
||||||
}
|
|
||||||
|
|
||||||
func (manager *BucketsManagerCtx) Codec() codec.RTPCodec {
|
|
||||||
return manager.codec
|
|
||||||
}
|
|
||||||
|
|
||||||
func (manager *BucketsManagerCtx) SetReceiver(receiver types.Receiver) {
|
|
||||||
// bitrate history is per receiver
|
|
||||||
bitrateHistory := &queue{}
|
|
||||||
|
|
||||||
receiver.OnBitrateChange(func(peerBitrate int) (bool, error) {
|
|
||||||
bitrate := peerBitrate
|
|
||||||
if receiver.VideoAuto() {
|
|
||||||
bitrate = bitrateHistory.normaliseBitrate(bitrate)
|
|
||||||
}
|
|
||||||
|
|
||||||
stream := manager.findNearestStream(bitrate)
|
|
||||||
streamID := stream.ID()
|
|
||||||
|
|
||||||
// TODO: make this less noisy in logs
|
|
||||||
manager.logger.Debug().
|
|
||||||
Str("video_id", streamID).
|
|
||||||
Int("len", bitrateHistory.len()).
|
|
||||||
Int("peer_bitrate", peerBitrate).
|
|
||||||
Int("bitrate", bitrate).
|
|
||||||
Msg("change video bitrate")
|
|
||||||
|
|
||||||
return receiver.SetStream(stream)
|
|
||||||
})
|
|
||||||
|
|
||||||
receiver.OnVideoChange(func(videoID string) (bool, error) {
|
|
||||||
stream := manager.streams[videoID]
|
|
||||||
manager.logger.Info().
|
|
||||||
Str("video_id", videoID).
|
|
||||||
Msg("video change")
|
|
||||||
|
|
||||||
return receiver.SetStream(stream)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (manager *BucketsManagerCtx) findNearestStream(peerBitrate int) types.StreamSinkManager {
|
|
||||||
type streamDiff struct {
|
|
||||||
id string
|
|
||||||
bitrateDiff int
|
|
||||||
}
|
|
||||||
|
|
||||||
sortDiff := func(a, b int) bool {
|
|
||||||
switch {
|
|
||||||
case a < 0 && b < 0:
|
|
||||||
return a > b
|
|
||||||
case a >= 0:
|
|
||||||
if b >= 0 {
|
|
||||||
return a <= b
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
var diffs []streamDiff
|
|
||||||
|
|
||||||
for _, stream := range manager.streams {
|
|
||||||
diffs = append(diffs, streamDiff{
|
|
||||||
id: stream.ID(),
|
|
||||||
bitrateDiff: peerBitrate - stream.Bitrate(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Slice(diffs, func(i, j int) bool {
|
|
||||||
return sortDiff(diffs[i].bitrateDiff, diffs[j].bitrateDiff)
|
|
||||||
})
|
|
||||||
|
|
||||||
bestDiff := diffs[0]
|
|
||||||
|
|
||||||
return manager.streams[bestDiff.id]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (manager *BucketsManagerCtx) RemoveReceiver(receiver types.Receiver) error {
|
|
||||||
receiver.OnBitrateChange(nil)
|
|
||||||
receiver.OnVideoChange(nil)
|
|
||||||
receiver.RemoveStream()
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -1,83 +0,0 @@
|
|||||||
package buckets
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/demodesk/neko/pkg/types"
|
|
||||||
"github.com/demodesk/neko/pkg/types/codec"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestBucketsManagerCtx_FindNearestStream(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
codec codec.RTPCodec
|
|
||||||
streams map[string]types.StreamSinkManager
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
peerBitrate int
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
want types.StreamSinkManager
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "findNearestStream",
|
|
||||||
fields: fields{
|
|
||||||
streams: map[string]types.StreamSinkManager{
|
|
||||||
"1": mockStreamSink{
|
|
||||||
id: "1",
|
|
||||||
bitrate: 500,
|
|
||||||
},
|
|
||||||
"2": mockStreamSink{
|
|
||||||
id: "2",
|
|
||||||
bitrate: 750,
|
|
||||||
},
|
|
||||||
"3": mockStreamSink{
|
|
||||||
id: "3",
|
|
||||||
bitrate: 1000,
|
|
||||||
},
|
|
||||||
"4": mockStreamSink{
|
|
||||||
id: "4",
|
|
||||||
bitrate: 1250,
|
|
||||||
},
|
|
||||||
"5": mockStreamSink{
|
|
||||||
id: "5",
|
|
||||||
bitrate: 1700,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
args: args{
|
|
||||||
peerBitrate: 950,
|
|
||||||
},
|
|
||||||
want: mockStreamSink{
|
|
||||||
id: "2",
|
|
||||||
bitrate: 750,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
m := BucketsNew(tt.fields.codec, tt.fields.streams, []string{})
|
|
||||||
|
|
||||||
if got := m.findNearestStream(tt.args.peerBitrate); !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("findNearestStream() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockStreamSink struct {
|
|
||||||
id string
|
|
||||||
bitrate int
|
|
||||||
types.StreamSinkManager
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m mockStreamSink) ID() string {
|
|
||||||
return m.id
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m mockStreamSink) Bitrate() int {
|
|
||||||
return m.bitrate
|
|
||||||
}
|
|
@ -1,88 +0,0 @@
|
|||||||
package buckets
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type queue struct {
|
|
||||||
sync.Mutex
|
|
||||||
q []elem
|
|
||||||
}
|
|
||||||
|
|
||||||
type elem struct {
|
|
||||||
created time.Time
|
|
||||||
bitrate int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *queue) push(v elem) {
|
|
||||||
q.Lock()
|
|
||||||
defer q.Unlock()
|
|
||||||
|
|
||||||
// if the first element is older than 10 seconds, remove it
|
|
||||||
if len(q.q) > 0 && time.Since(q.q[0].created) > 10*time.Second {
|
|
||||||
q.q = q.q[1:]
|
|
||||||
}
|
|
||||||
q.q = append(q.q, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *queue) len() int {
|
|
||||||
q.Lock()
|
|
||||||
defer q.Unlock()
|
|
||||||
return len(q.q)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *queue) avg() int {
|
|
||||||
q.Lock()
|
|
||||||
defer q.Unlock()
|
|
||||||
if len(q.q) == 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
sum := 0
|
|
||||||
for _, v := range q.q {
|
|
||||||
sum += v.bitrate
|
|
||||||
}
|
|
||||||
return sum / len(q.q)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *queue) avgLastN(n int) int {
|
|
||||||
if n <= 0 {
|
|
||||||
return q.avg()
|
|
||||||
}
|
|
||||||
q.Lock()
|
|
||||||
defer q.Unlock()
|
|
||||||
if len(q.q) == 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
sum := 0
|
|
||||||
for _, v := range q.q[len(q.q)-n:] {
|
|
||||||
sum += v.bitrate
|
|
||||||
}
|
|
||||||
return sum / n
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *queue) normaliseBitrate(currentBitrate int) int {
|
|
||||||
avgBitrate := float64(q.avg())
|
|
||||||
histLen := float64(q.len())
|
|
||||||
|
|
||||||
q.push(elem{
|
|
||||||
bitrate: currentBitrate,
|
|
||||||
created: time.Now(),
|
|
||||||
})
|
|
||||||
|
|
||||||
if avgBitrate == 0 || histLen == 0 || currentBitrate == 0 {
|
|
||||||
return currentBitrate
|
|
||||||
}
|
|
||||||
|
|
||||||
lastN := int(math.Floor(float64(currentBitrate) / avgBitrate * histLen))
|
|
||||||
if lastN > q.len() {
|
|
||||||
lastN = q.len()
|
|
||||||
}
|
|
||||||
|
|
||||||
if lastN == 0 {
|
|
||||||
return currentBitrate
|
|
||||||
}
|
|
||||||
|
|
||||||
return q.avgLastN(lastN)
|
|
||||||
}
|
|
@ -1,99 +0,0 @@
|
|||||||
package buckets
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func Queue_normaliseBitrate(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
queue *queue
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
currentBitrate int
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
want []int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "normaliseBitrate: big drop",
|
|
||||||
fields: fields{
|
|
||||||
queue: &queue{
|
|
||||||
q: []elem{
|
|
||||||
{bitrate: 900},
|
|
||||||
{bitrate: 750},
|
|
||||||
{bitrate: 780},
|
|
||||||
{bitrate: 1100},
|
|
||||||
{bitrate: 950},
|
|
||||||
{bitrate: 700},
|
|
||||||
{bitrate: 800},
|
|
||||||
{bitrate: 900},
|
|
||||||
{bitrate: 1000},
|
|
||||||
{bitrate: 1100},
|
|
||||||
// avg = 898
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
args: args{
|
|
||||||
currentBitrate: 350,
|
|
||||||
},
|
|
||||||
want: []int{816, 700, 537, 350, 350},
|
|
||||||
}, {
|
|
||||||
name: "normaliseBitrate: small drop",
|
|
||||||
fields: fields{
|
|
||||||
queue: &queue{
|
|
||||||
q: []elem{
|
|
||||||
{bitrate: 900},
|
|
||||||
{bitrate: 750},
|
|
||||||
{bitrate: 780},
|
|
||||||
{bitrate: 1100},
|
|
||||||
{bitrate: 950},
|
|
||||||
{bitrate: 700},
|
|
||||||
{bitrate: 800},
|
|
||||||
{bitrate: 900},
|
|
||||||
{bitrate: 1000},
|
|
||||||
{bitrate: 1100},
|
|
||||||
// avg = 898
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
args: args{
|
|
||||||
currentBitrate: 700,
|
|
||||||
},
|
|
||||||
want: []int{878, 842, 825, 825, 812, 787, 750, 700},
|
|
||||||
}, {
|
|
||||||
name: "normaliseBitrate",
|
|
||||||
fields: fields{
|
|
||||||
queue: &queue{
|
|
||||||
q: []elem{
|
|
||||||
{bitrate: 900},
|
|
||||||
{bitrate: 750},
|
|
||||||
{bitrate: 780},
|
|
||||||
{bitrate: 1100},
|
|
||||||
{bitrate: 950},
|
|
||||||
{bitrate: 700},
|
|
||||||
{bitrate: 800},
|
|
||||||
{bitrate: 900},
|
|
||||||
{bitrate: 1000},
|
|
||||||
{bitrate: 1100},
|
|
||||||
// avg = 898
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
args: args{
|
|
||||||
currentBitrate: 1350,
|
|
||||||
},
|
|
||||||
want: []int{943, 1003, 1060, 1085},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
m := tt.fields.queue
|
|
||||||
for i := 0; i < len(tt.want); i++ {
|
|
||||||
if got := m.normaliseBitrate(tt.args.currentBitrate); got != tt.want[i] {
|
|
||||||
t.Errorf("normaliseBitrate() [%d] = %v, want %v", i, got, tt.want[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
@ -8,7 +8,6 @@ import (
|
|||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
"github.com/demodesk/neko/internal/capture/buckets"
|
|
||||||
"github.com/demodesk/neko/internal/config"
|
"github.com/demodesk/neko/internal/config"
|
||||||
"github.com/demodesk/neko/pkg/types"
|
"github.com/demodesk/neko/pkg/types"
|
||||||
"github.com/demodesk/neko/pkg/types/codec"
|
"github.com/demodesk/neko/pkg/types/codec"
|
||||||
@ -23,7 +22,7 @@ type CaptureManagerCtx struct {
|
|||||||
broadcast *BroacastManagerCtx
|
broadcast *BroacastManagerCtx
|
||||||
screencast *ScreencastManagerCtx
|
screencast *ScreencastManagerCtx
|
||||||
audio *StreamSinkManagerCtx
|
audio *StreamSinkManagerCtx
|
||||||
video types.BucketsManager
|
video *StreamSelectorManagerCtx
|
||||||
|
|
||||||
// sources
|
// sources
|
||||||
webcam *StreamSrcManagerCtx
|
webcam *StreamSrcManagerCtx
|
||||||
@ -68,13 +67,8 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt
|
|||||||
Str("pipeline", pipeline).
|
Str("pipeline", pipeline).
|
||||||
Msg("syntax check for video stream pipeline passed")
|
Msg("syntax check for video stream pipeline passed")
|
||||||
|
|
||||||
getVideoBitrate := pipelineConf.GetBitrateFn(desktop.GetScreenSize)
|
|
||||||
if _, err = getVideoBitrate(); err != nil {
|
|
||||||
logger.Panic().Err(err).Msg("unable to get video bitrate")
|
|
||||||
}
|
|
||||||
|
|
||||||
// append to videos
|
// append to videos
|
||||||
videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id, getVideoBitrate)
|
videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &CaptureManagerCtx{
|
return &CaptureManagerCtx{
|
||||||
@ -140,8 +134,8 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt
|
|||||||
"! %s "+
|
"! %s "+
|
||||||
"! appsink name=appsink", config.AudioDevice, config.AudioCodec.Pipeline,
|
"! appsink name=appsink", config.AudioDevice, config.AudioCodec.Pipeline,
|
||||||
), nil
|
), nil
|
||||||
}, "audio", nil),
|
}, "audio"),
|
||||||
video: buckets.BucketsNew(config.VideoCodec, videos, config.VideoIDs),
|
video: streamSelectorNew(config.VideoCodec, videos, config.VideoIDs),
|
||||||
|
|
||||||
// sources
|
// sources
|
||||||
webcam: streamSrcNew(config.WebcamEnabled, map[string]string{
|
webcam: streamSrcNew(config.WebcamEnabled, map[string]string{
|
||||||
@ -202,7 +196,7 @@ func (manager *CaptureManagerCtx) Start() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
manager.desktop.OnBeforeScreenSizeChange(func() {
|
manager.desktop.OnBeforeScreenSizeChange(func() {
|
||||||
manager.video.DestroyAll()
|
manager.video.destroyPipelines()
|
||||||
|
|
||||||
if manager.broadcast.Started() {
|
if manager.broadcast.Started() {
|
||||||
manager.broadcast.destroyPipeline()
|
manager.broadcast.destroyPipeline()
|
||||||
@ -214,7 +208,7 @@ func (manager *CaptureManagerCtx) Start() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
manager.desktop.OnAfterScreenSizeChange(func() {
|
manager.desktop.OnAfterScreenSizeChange(func() {
|
||||||
err := manager.video.RecreateAll()
|
err := manager.video.recreatePipelines()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
manager.logger.Panic().Err(err).Msg("unable to recreate video pipelines")
|
manager.logger.Panic().Err(err).Msg("unable to recreate video pipelines")
|
||||||
}
|
}
|
||||||
@ -242,7 +236,7 @@ func (manager *CaptureManagerCtx) Shutdown() error {
|
|||||||
manager.screencast.shutdown()
|
manager.screencast.shutdown()
|
||||||
|
|
||||||
manager.audio.shutdown()
|
manager.audio.shutdown()
|
||||||
manager.video.Shutdown()
|
manager.video.shutdown()
|
||||||
|
|
||||||
manager.webcam.shutdown()
|
manager.webcam.shutdown()
|
||||||
manager.microphone.shutdown()
|
manager.microphone.shutdown()
|
||||||
@ -250,15 +244,6 @@ func (manager *CaptureManagerCtx) Shutdown() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (manager *CaptureManagerCtx) GetBitrateFromVideoID(videoID string) (int, error) {
|
|
||||||
cfg, ok := manager.config.VideoPipelines[videoID]
|
|
||||||
if !ok {
|
|
||||||
return 0, fmt.Errorf("video config not found for %s", videoID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return cfg.GetBitrateFn(manager.desktop.GetScreenSize)()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (manager *CaptureManagerCtx) Broadcast() types.BroadcastManager {
|
func (manager *CaptureManagerCtx) Broadcast() types.BroadcastManager {
|
||||||
return manager.broadcast
|
return manager.broadcast
|
||||||
}
|
}
|
||||||
@ -271,7 +256,7 @@ func (manager *CaptureManagerCtx) Audio() types.StreamSinkManager {
|
|||||||
return manager.audio
|
return manager.audio
|
||||||
}
|
}
|
||||||
|
|
||||||
func (manager *CaptureManagerCtx) Video() types.BucketsManager {
|
func (manager *CaptureManagerCtx) Video() types.StreamSelectorManager {
|
||||||
return manager.video
|
return manager.video
|
||||||
}
|
}
|
||||||
|
|
||||||
|
206
internal/capture/streamselector.go
Normal file
206
internal/capture/streamselector.go
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
package capture
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
|
"github.com/demodesk/neko/pkg/types"
|
||||||
|
"github.com/demodesk/neko/pkg/types/codec"
|
||||||
|
)
|
||||||
|
|
||||||
|
type StreamSelectorManagerCtx struct {
|
||||||
|
logger zerolog.Logger
|
||||||
|
codec codec.RTPCodec
|
||||||
|
streams map[string]types.StreamSinkManager
|
||||||
|
streamIDs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamSelectorNew(codec codec.RTPCodec, streams map[string]types.StreamSinkManager, streamIDs []string) *StreamSelectorManagerCtx {
|
||||||
|
logger := log.With().
|
||||||
|
Str("module", "capture").
|
||||||
|
Str("submodule", "stream-selector").
|
||||||
|
Logger()
|
||||||
|
|
||||||
|
return &StreamSelectorManagerCtx{
|
||||||
|
logger: logger,
|
||||||
|
codec: codec,
|
||||||
|
streams: streams,
|
||||||
|
streamIDs: streamIDs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (manager *StreamSelectorManagerCtx) shutdown() {
|
||||||
|
manager.logger.Info().Msgf("shutdown")
|
||||||
|
|
||||||
|
manager.destroyPipelines()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (manager *StreamSelectorManagerCtx) destroyPipelines() {
|
||||||
|
for _, stream := range manager.streams {
|
||||||
|
if stream.Started() {
|
||||||
|
stream.DestroyPipeline()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (manager *StreamSelectorManagerCtx) recreatePipelines() error {
|
||||||
|
for _, stream := range manager.streams {
|
||||||
|
if stream.Started() {
|
||||||
|
err := stream.CreatePipeline()
|
||||||
|
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (manager *StreamSelectorManagerCtx) IDs() []string {
|
||||||
|
return manager.streamIDs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (manager *StreamSelectorManagerCtx) Codec() codec.RTPCodec {
|
||||||
|
return manager.codec
|
||||||
|
}
|
||||||
|
|
||||||
|
func (manager *StreamSelectorManagerCtx) GetStream(selector types.StreamSelector) (types.StreamSinkManager, bool) {
|
||||||
|
// select stream by ID
|
||||||
|
if selector.ID != "" {
|
||||||
|
// select lower stream
|
||||||
|
if selector.Type == types.StreamSelectorTypeLower {
|
||||||
|
var lastStream types.StreamSinkManager
|
||||||
|
for i := len(manager.streamIDs) - 1; i >= 0; i-- {
|
||||||
|
streamID := manager.streamIDs[i]
|
||||||
|
if streamID == selector.ID {
|
||||||
|
return lastStream, lastStream != nil
|
||||||
|
}
|
||||||
|
stream, ok := manager.streams[streamID]
|
||||||
|
if ok {
|
||||||
|
lastStream = stream
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// we couldn't find a lower stream
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// select higher stream
|
||||||
|
if selector.Type == types.StreamSelectorTypeHigher {
|
||||||
|
var lastStream types.StreamSinkManager
|
||||||
|
for _, streamID := range manager.streamIDs {
|
||||||
|
if streamID == selector.ID {
|
||||||
|
return lastStream, lastStream != nil
|
||||||
|
}
|
||||||
|
stream, ok := manager.streams[streamID]
|
||||||
|
if ok {
|
||||||
|
lastStream = stream
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// we couldn't find a higher stream
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// select exact stream
|
||||||
|
stream, ok := manager.streams[selector.ID]
|
||||||
|
return stream, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// select stream by bitrate
|
||||||
|
if selector.Bitrate != 0 {
|
||||||
|
// select stream by nearest bitrate
|
||||||
|
if selector.Type == types.StreamSelectorTypeNearest {
|
||||||
|
return manager.nearestBitrate(selector.Bitrate), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// select lower stream
|
||||||
|
if selector.Type == types.StreamSelectorTypeLower {
|
||||||
|
// start from the highest stream, and go down, until we find a lower stream
|
||||||
|
for i := len(manager.streamIDs) - 1; i >= 0; i-- {
|
||||||
|
streamID := manager.streamIDs[i]
|
||||||
|
stream := manager.streams[streamID]
|
||||||
|
// if stream should be considered in calculation
|
||||||
|
considered := stream.Bitrate() != 0 && stream.Started()
|
||||||
|
if considered && stream.Bitrate() < selector.Bitrate {
|
||||||
|
return stream, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// we couldn't find a lower stream
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// select higher stream
|
||||||
|
if selector.Type == types.StreamSelectorTypeHigher {
|
||||||
|
// start from the lowest stream, and go up, until we find a higher stream
|
||||||
|
for _, streamID := range manager.streamIDs {
|
||||||
|
stream := manager.streams[streamID]
|
||||||
|
// if stream should be considered in calculation
|
||||||
|
considered := stream.Bitrate() != 0 && stream.Started()
|
||||||
|
if considered && stream.Bitrate() > selector.Bitrate {
|
||||||
|
return stream, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// we couldn't find a higher stream
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// select stream by exact bitrate
|
||||||
|
for _, stream := range manager.streams {
|
||||||
|
if stream.Bitrate() == selector.Bitrate {
|
||||||
|
return stream, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// we couldn't find a stream
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: This is a very naive implementation, we should use a binary search instead.
|
||||||
|
func (manager *StreamSelectorManagerCtx) nearestBitrate(bitrate uint64) types.StreamSinkManager {
|
||||||
|
type streamDiff struct {
|
||||||
|
id string
|
||||||
|
bitrateDiff int
|
||||||
|
}
|
||||||
|
|
||||||
|
sortDiff := func(a, b int) bool {
|
||||||
|
switch {
|
||||||
|
case a < 0 && b < 0:
|
||||||
|
return a > b
|
||||||
|
case a >= 0:
|
||||||
|
if b >= 0 {
|
||||||
|
return a <= b
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var diffs []streamDiff
|
||||||
|
|
||||||
|
for _, stream := range manager.streams {
|
||||||
|
// if stream should be considered in calculation
|
||||||
|
considered := stream.Bitrate() != 0 && stream.Started()
|
||||||
|
if !considered {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
diffs = append(diffs, streamDiff{
|
||||||
|
id: stream.ID(),
|
||||||
|
bitrateDiff: int(bitrate) - int(stream.Bitrate()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// no streams available
|
||||||
|
if len(diffs) == 0 {
|
||||||
|
// return first (lowest) stream
|
||||||
|
return manager.streams[manager.streamIDs[0]]
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(diffs, func(i, j int) bool {
|
||||||
|
return sortDiff(diffs[i].bitrateDiff, diffs[j].bitrateDiff)
|
||||||
|
})
|
||||||
|
|
||||||
|
bestDiff := diffs[0]
|
||||||
|
return manager.streams[bestDiff.id]
|
||||||
|
}
|
@ -21,9 +21,10 @@ import (
|
|||||||
var moveSinkListenerMu = sync.Mutex{}
|
var moveSinkListenerMu = sync.Mutex{}
|
||||||
|
|
||||||
type StreamSinkManagerCtx struct {
|
type StreamSinkManagerCtx struct {
|
||||||
id string
|
id string
|
||||||
getBitrate func() (int, error)
|
|
||||||
waitForKf bool // wait for a keyframe before sending samples
|
// wait for a keyframe before sending samples
|
||||||
|
waitForKf bool
|
||||||
|
|
||||||
bitrate uint64 // atomic
|
bitrate uint64 // atomic
|
||||||
brBuckets map[int]float64
|
brBuckets map[int]float64
|
||||||
@ -48,22 +49,23 @@ type StreamSinkManagerCtx struct {
|
|||||||
pipelinesActive prometheus.Gauge
|
pipelinesActive prometheus.Gauge
|
||||||
}
|
}
|
||||||
|
|
||||||
func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id string, getBitrate func() (int, error)) *StreamSinkManagerCtx {
|
func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), id string) *StreamSinkManagerCtx {
|
||||||
logger := log.With().
|
logger := log.With().
|
||||||
Str("module", "capture").
|
Str("module", "capture").
|
||||||
Str("submodule", "stream-sink").
|
Str("submodule", "stream-sink").
|
||||||
Str("id", id).Logger()
|
Str("id", id).Logger()
|
||||||
|
|
||||||
manager := &StreamSinkManagerCtx{
|
manager := &StreamSinkManagerCtx{
|
||||||
id: id,
|
id: id,
|
||||||
getBitrate: getBitrate,
|
|
||||||
// only wait for keyframes if the codec is video
|
|
||||||
waitForKf: c.IsVideo(),
|
|
||||||
|
|
||||||
|
// only wait for keyframes if the codec is video
|
||||||
|
waitForKf: codec.IsVideo(),
|
||||||
|
|
||||||
|
bitrate: 0,
|
||||||
brBuckets: map[int]float64{},
|
brBuckets: map[int]float64{},
|
||||||
|
|
||||||
logger: logger,
|
logger: logger,
|
||||||
codec: c,
|
codec: codec,
|
||||||
pipelineFn: pipelineFn,
|
pipelineFn: pipelineFn,
|
||||||
|
|
||||||
listeners: map[uintptr]types.SampleListener{},
|
listeners: map[uintptr]types.SampleListener{},
|
||||||
@ -77,8 +79,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin
|
|||||||
Help: "Current number of listeners for a pipeline.",
|
Help: "Current number of listeners for a pipeline.",
|
||||||
ConstLabels: map[string]string{
|
ConstLabels: map[string]string{
|
||||||
"video_id": id,
|
"video_id": id,
|
||||||
"codec_name": c.Name,
|
"codec_name": codec.Name,
|
||||||
"codec_type": c.Type.String(),
|
"codec_type": codec.Type.String(),
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
totalBytes: promauto.NewCounter(prometheus.CounterOpts{
|
totalBytes: promauto.NewCounter(prometheus.CounterOpts{
|
||||||
@ -88,8 +90,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin
|
|||||||
Help: "Total number of bytes created by the pipeline.",
|
Help: "Total number of bytes created by the pipeline.",
|
||||||
ConstLabels: map[string]string{
|
ConstLabels: map[string]string{
|
||||||
"video_id": id,
|
"video_id": id,
|
||||||
"codec_name": c.Name,
|
"codec_name": codec.Name,
|
||||||
"codec_type": c.Type.String(),
|
"codec_type": codec.Type.String(),
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
pipelinesCounter: promauto.NewCounter(prometheus.CounterOpts{
|
pipelinesCounter: promauto.NewCounter(prometheus.CounterOpts{
|
||||||
@ -100,8 +102,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin
|
|||||||
ConstLabels: map[string]string{
|
ConstLabels: map[string]string{
|
||||||
"submodule": "streamsink",
|
"submodule": "streamsink",
|
||||||
"video_id": id,
|
"video_id": id,
|
||||||
"codec_name": c.Name,
|
"codec_name": codec.Name,
|
||||||
"codec_type": c.Type.String(),
|
"codec_type": codec.Type.String(),
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
pipelinesActive: promauto.NewGauge(prometheus.GaugeOpts{
|
pipelinesActive: promauto.NewGauge(prometheus.GaugeOpts{
|
||||||
@ -112,8 +114,8 @@ func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id strin
|
|||||||
ConstLabels: map[string]string{
|
ConstLabels: map[string]string{
|
||||||
"submodule": "streamsink",
|
"submodule": "streamsink",
|
||||||
"video_id": id,
|
"video_id": id,
|
||||||
"codec_name": c.Name,
|
"codec_name": codec.Name,
|
||||||
"codec_type": c.Type.String(),
|
"codec_type": codec.Type.String(),
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
@ -141,27 +143,8 @@ func (manager *StreamSinkManagerCtx) ID() string {
|
|||||||
return manager.id
|
return manager.id
|
||||||
}
|
}
|
||||||
|
|
||||||
func (manager *StreamSinkManagerCtx) Bitrate() int {
|
func (manager *StreamSinkManagerCtx) Bitrate() uint64 {
|
||||||
// TODO: fix bitrate switching calculation
|
return atomic.LoadUint64(&manager.bitrate)
|
||||||
// return real bitrate if available
|
|
||||||
//realBitrate := atomic.LoadUint64(&manager.bitrate)
|
|
||||||
//if realBitrate != 0 {
|
|
||||||
// return int(realBitrate)
|
|
||||||
//}
|
|
||||||
|
|
||||||
// if we do not have function to estimate bitrate, return 0
|
|
||||||
if manager.getBitrate == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// recalculate bitrate every time, take screen resolution (and fps) into account
|
|
||||||
// we called this function during startup, so it shouldn't error here
|
|
||||||
bitrate, err := manager.getBitrate()
|
|
||||||
if err != nil {
|
|
||||||
manager.logger.Err(err).Msg("unexpected error while getting bitrate")
|
|
||||||
}
|
|
||||||
|
|
||||||
return bitrate
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec {
|
func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec {
|
||||||
|
@ -3,6 +3,7 @@ package config
|
|||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
@ -15,6 +16,28 @@ import (
|
|||||||
// default stun server
|
// default stun server
|
||||||
const defStunSrv = "stun:stun.l.google.com:19302"
|
const defStunSrv = "stun:stun.l.google.com:19302"
|
||||||
|
|
||||||
|
type WebRTCEstimator struct {
|
||||||
|
Enabled bool
|
||||||
|
Passive bool
|
||||||
|
Debug bool
|
||||||
|
InitialBitrate int
|
||||||
|
|
||||||
|
// how often to read and process bandwidth estimation reports
|
||||||
|
ReadInterval time.Duration
|
||||||
|
// how long to wait for stable connection (only neutral or upward trend) before upgrading
|
||||||
|
StableDuration time.Duration
|
||||||
|
// how long to wait for unstable connection (downward trend) before downgrading
|
||||||
|
UnstableDuration time.Duration
|
||||||
|
// how long to wait for stalled connection (neutral trend with low bandwidth) before downgrading
|
||||||
|
StalledDuration time.Duration
|
||||||
|
// how long to wait before downgrading again after previous downgrade
|
||||||
|
DowngradeBackoff time.Duration
|
||||||
|
// how long to wait before upgrading again after previous upgrade
|
||||||
|
UpgradeBackoff time.Duration
|
||||||
|
// how bigger the difference between estimated and stream bitrate must be to trigger upgrade/downgrade
|
||||||
|
DiffThreshold float64
|
||||||
|
}
|
||||||
|
|
||||||
type WebRTC struct {
|
type WebRTC struct {
|
||||||
ICELite bool
|
ICELite bool
|
||||||
ICETrickle bool
|
ICETrickle bool
|
||||||
@ -28,9 +51,7 @@ type WebRTC struct {
|
|||||||
NAT1To1IPs []string
|
NAT1To1IPs []string
|
||||||
IpRetrievalUrl string
|
IpRetrievalUrl string
|
||||||
|
|
||||||
EstimatorEnabled bool
|
Estimator WebRTCEstimator
|
||||||
EstimatorPassive bool
|
|
||||||
EstimatorInitialBitrate int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (WebRTC) Init(cmd *cobra.Command) error {
|
func (WebRTC) Init(cmd *cobra.Command) error {
|
||||||
@ -96,11 +117,51 @@ func (WebRTC) Init(cmd *cobra.Command) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cmd.PersistentFlags().Bool("webrtc.estimator.debug", false, "enables debug logging for the bandwidth estimator")
|
||||||
|
if err := viper.BindPFlag("webrtc.estimator.debug", cmd.PersistentFlags().Lookup("webrtc.estimator.debug")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
cmd.PersistentFlags().Int("webrtc.estimator.initial_bitrate", 1_000_000, "initial bitrate for the bandwidth estimator")
|
cmd.PersistentFlags().Int("webrtc.estimator.initial_bitrate", 1_000_000, "initial bitrate for the bandwidth estimator")
|
||||||
if err := viper.BindPFlag("webrtc.estimator.initial_bitrate", cmd.PersistentFlags().Lookup("webrtc.estimator.initial_bitrate")); err != nil {
|
if err := viper.BindPFlag("webrtc.estimator.initial_bitrate", cmd.PersistentFlags().Lookup("webrtc.estimator.initial_bitrate")); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cmd.PersistentFlags().Duration("webrtc.estimator.read_interval", 2*time.Second, "how often to read and process bandwidth estimation reports")
|
||||||
|
if err := viper.BindPFlag("webrtc.estimator.read_interval", cmd.PersistentFlags().Lookup("webrtc.estimator.read_interval")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.PersistentFlags().Duration("webrtc.estimator.stable_duration", 12*time.Second, "how long to wait for stable connection (upward or neutral trend) before upgrading")
|
||||||
|
if err := viper.BindPFlag("webrtc.estimator.stable_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.stable_duration")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.PersistentFlags().Duration("webrtc.estimator.unstable_duration", 6*time.Second, "how long to wait for stalled connection (neutral trend with low bandwidth) before downgrading")
|
||||||
|
if err := viper.BindPFlag("webrtc.estimator.unstable_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.unstable_duration")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.PersistentFlags().Duration("webrtc.estimator.stalled_duration", 24*time.Second, "how long to wait for stalled bandwidth estimation before downgrading")
|
||||||
|
if err := viper.BindPFlag("webrtc.estimator.stalled_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.stalled_duration")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.PersistentFlags().Duration("webrtc.estimator.downgrade_backoff", 10*time.Second, "how long to wait before downgrading again after previous downgrade")
|
||||||
|
if err := viper.BindPFlag("webrtc.estimator.downgrade_backoff", cmd.PersistentFlags().Lookup("webrtc.estimator.downgrade_backoff")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.PersistentFlags().Duration("webrtc.estimator.upgrade_backoff", 5*time.Second, "how long to wait before upgrading again after previous upgrade")
|
||||||
|
if err := viper.BindPFlag("webrtc.estimator.upgrade_backoff", cmd.PersistentFlags().Lookup("webrtc.estimator.upgrade_backoff")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.PersistentFlags().Float64("webrtc.estimator.diff_threshold", 0.15, "how bigger the difference between estimated and stream bitrate must be to trigger upgrade/downgrade")
|
||||||
|
if err := viper.BindPFlag("webrtc.estimator.diff_threshold", cmd.PersistentFlags().Lookup("webrtc.estimator.diff_threshold")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -197,7 +258,15 @@ func (s *WebRTC) Set() {
|
|||||||
|
|
||||||
// bandwidth estimator
|
// bandwidth estimator
|
||||||
|
|
||||||
s.EstimatorEnabled = viper.GetBool("webrtc.estimator.enabled")
|
s.Estimator.Enabled = viper.GetBool("webrtc.estimator.enabled")
|
||||||
s.EstimatorPassive = viper.GetBool("webrtc.estimator.passive")
|
s.Estimator.Passive = viper.GetBool("webrtc.estimator.passive")
|
||||||
s.EstimatorInitialBitrate = viper.GetInt("webrtc.estimator.initial_bitrate")
|
s.Estimator.Debug = viper.GetBool("webrtc.estimator.debug")
|
||||||
|
s.Estimator.InitialBitrate = viper.GetInt("webrtc.estimator.initial_bitrate")
|
||||||
|
s.Estimator.ReadInterval = viper.GetDuration("webrtc.estimator.read_interval")
|
||||||
|
s.Estimator.StableDuration = viper.GetDuration("webrtc.estimator.stable_duration")
|
||||||
|
s.Estimator.UnstableDuration = viper.GetDuration("webrtc.estimator.unstable_duration")
|
||||||
|
s.Estimator.StalledDuration = viper.GetDuration("webrtc.estimator.stalled_duration")
|
||||||
|
s.Estimator.DowngradeBackoff = viper.GetDuration("webrtc.estimator.downgrade_backoff")
|
||||||
|
s.Estimator.UpgradeBackoff = viper.GetDuration("webrtc.estimator.upgrade_backoff")
|
||||||
|
s.Estimator.DiffThreshold = viper.GetFloat64("webrtc.estimator.diff_threshold")
|
||||||
}
|
}
|
||||||
|
@ -24,6 +24,7 @@ import (
|
|||||||
"github.com/demodesk/neko/pkg/types/codec"
|
"github.com/demodesk/neko/pkg/types/codec"
|
||||||
"github.com/demodesk/neko/pkg/types/event"
|
"github.com/demodesk/neko/pkg/types/event"
|
||||||
"github.com/demodesk/neko/pkg/types/message"
|
"github.com/demodesk/neko/pkg/types/message"
|
||||||
|
"github.com/demodesk/neko/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -167,7 +168,7 @@ func (manager *WebRTCManagerCtx) ICEServers() []types.ICEServer {
|
|||||||
return manager.config.ICEServersFrontend
|
return manager.config.ICEServersFrontend
|
||||||
}
|
}
|
||||||
|
|
||||||
func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs []codec.RTPCodec, bitrate int) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) {
|
func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs []codec.RTPCodec) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) {
|
||||||
// create media engine
|
// create media engine
|
||||||
engine := &webrtc.MediaEngine{}
|
engine := &webrtc.MediaEngine{}
|
||||||
for _, codec := range codecs {
|
for _, codec := range codecs {
|
||||||
@ -223,14 +224,10 @@ func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs
|
|||||||
|
|
||||||
// create bandwidth estimator
|
// create bandwidth estimator
|
||||||
estimatorChan := make(chan cc.BandwidthEstimator, 1)
|
estimatorChan := make(chan cc.BandwidthEstimator, 1)
|
||||||
if manager.config.EstimatorEnabled {
|
if manager.config.Estimator.Enabled {
|
||||||
congestionController, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) {
|
congestionController, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) {
|
||||||
if bitrate == 0 {
|
|
||||||
bitrate = manager.config.EstimatorInitialBitrate
|
|
||||||
}
|
|
||||||
|
|
||||||
return gcc.NewSendSideBWE(
|
return gcc.NewSendSideBWE(
|
||||||
gcc.SendSideBWEInitialBitrate(bitrate),
|
gcc.SendSideBWEInitialBitrate(manager.config.Estimator.InitialBitrate),
|
||||||
gcc.SendSideBWEPacer(gcc.NewNoOpPacer()),
|
gcc.SendSideBWEPacer(gcc.NewNoOpPacer()),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
@ -268,7 +265,7 @@ func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs
|
|||||||
return connection, <-estimatorChan, err
|
return connection, <-estimatorChan, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int, videoAuto bool) (*webrtc.SessionDescription, error) {
|
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session) (*webrtc.SessionDescription, types.WebRTCPeer, error) {
|
||||||
id := atomic.AddInt32(&manager.peerId, 1)
|
id := atomic.AddInt32(&manager.peerId, 1)
|
||||||
|
|
||||||
// get metrics for session
|
// get metrics for session
|
||||||
@ -287,12 +284,10 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
|
|||||||
video := manager.capture.Video()
|
video := manager.capture.Video()
|
||||||
videoCodec := video.Codec()
|
videoCodec := video.Codec()
|
||||||
|
|
||||||
connection, estimator, err := manager.newPeerConnection(logger, []codec.RTPCodec{
|
connection, estimator, err := manager.newPeerConnection(
|
||||||
audioCodec,
|
logger, []codec.RTPCodec{audioCodec, videoCodec})
|
||||||
videoCodec,
|
|
||||||
}, bitrate)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// asynchronously send local ICE Candidates
|
// asynchronously send local ICE Candidates
|
||||||
@ -311,47 +306,34 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// if bitrate is 0, and estimator is enabled, use estimator bitrate
|
|
||||||
if bitrate == 0 && estimator != nil {
|
|
||||||
bitrate = estimator.GetTargetBitrate()
|
|
||||||
}
|
|
||||||
|
|
||||||
// audio track
|
// audio track
|
||||||
audioTrack, err := NewTrack(logger, audioCodec, connection)
|
audioTrack, err := NewTrack(logger, audioCodec, connection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// set stream for audio track
|
// set stream for audio track
|
||||||
_, err = audioTrack.SetStream(audio)
|
_, err = audioTrack.SetStream(audio)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// if estimator is disabled, or in passive mode, disable video auto bitrate
|
|
||||||
if !manager.config.EstimatorEnabled || manager.config.EstimatorPassive {
|
|
||||||
videoAuto = false
|
|
||||||
}
|
|
||||||
|
|
||||||
videoRtcp := make(chan []rtcp.Packet, 1)
|
|
||||||
|
|
||||||
// video track
|
// video track
|
||||||
videoTrack, err := NewTrack(logger, videoCodec, connection,
|
videoRtcp := make(chan []rtcp.Packet, 1)
|
||||||
WithVideoAuto(videoAuto),
|
videoTrack, err := NewTrack(logger, videoCodec, connection, WithRtcpChan(videoRtcp))
|
||||||
WithRtcpChan(videoRtcp),
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// let video stream bucket manager handle stream subscriptions
|
//
|
||||||
video.SetReceiver(videoTrack)
|
// stream for video track will be set later
|
||||||
|
//
|
||||||
|
|
||||||
// data channel
|
// data channel
|
||||||
|
|
||||||
dataChannel, err := connection.CreateDataChannel("data", nil)
|
dataChannel, err := connection.CreateDataChannel("data", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
peer := &WebRTCPeerCtx{
|
peer := &WebRTCPeerCtx{
|
||||||
@ -359,24 +341,29 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
|
|||||||
session: session,
|
session: session,
|
||||||
metrics: metrics,
|
metrics: metrics,
|
||||||
connection: connection,
|
connection: connection,
|
||||||
estimator: estimator,
|
// bandwidth estimator
|
||||||
|
estimator: estimator,
|
||||||
|
estimateTrend: utils.NewTrendDetector(
|
||||||
|
utils.TrendDetectorParams{
|
||||||
|
// Probing
|
||||||
|
//RequiredSamples: 3,
|
||||||
|
//DownwardTrendThreshold: 0.0,
|
||||||
|
//CollapseValues: false,
|
||||||
|
// Non-Probing
|
||||||
|
RequiredSamples: 8,
|
||||||
|
DownwardTrendThreshold: -0.5,
|
||||||
|
CollapseValues: true,
|
||||||
|
}),
|
||||||
|
// stream selectors
|
||||||
|
videoSelector: manager.capture.Video(),
|
||||||
// tracks & channels
|
// tracks & channels
|
||||||
audioTrack: audioTrack,
|
audioTrack: audioTrack,
|
||||||
videoTrack: videoTrack,
|
videoTrack: videoTrack,
|
||||||
dataChannel: dataChannel,
|
dataChannel: dataChannel,
|
||||||
rtcpChannel: videoRtcp,
|
rtcpChannel: videoRtcp,
|
||||||
// config
|
// config
|
||||||
iceTrickle: manager.config.ICETrickle,
|
iceTrickle: manager.config.ICETrickle,
|
||||||
estimatorPassive: manager.config.EstimatorPassive,
|
estimatorConfig: manager.config.Estimator,
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info().
|
|
||||||
Int("target_bitrate", bitrate).
|
|
||||||
Msg("estimated initial peer bitrate")
|
|
||||||
|
|
||||||
// set initial video bitrate
|
|
||||||
if err := peer.SetVideoBitrate(bitrate); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
connection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
|
connection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
|
||||||
@ -492,9 +479,9 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
|
|||||||
// ensure we only run this once
|
// ensure we only run this once
|
||||||
once.Do(func() {
|
once.Do(func() {
|
||||||
session.SetWebRTCConnected(peer, false)
|
session.SetWebRTCConnected(peer, false)
|
||||||
if err = video.RemoveReceiver(videoTrack); err != nil {
|
//
|
||||||
logger.Err(err).Msg("failed to remove video receiver")
|
// TODO: Shutdown peer?
|
||||||
}
|
//
|
||||||
audioTrack.Shutdown()
|
audioTrack.Shutdown()
|
||||||
videoTrack.Shutdown()
|
videoTrack.Shutdown()
|
||||||
close(videoRtcp)
|
close(videoRtcp)
|
||||||
@ -542,7 +529,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
|
|||||||
|
|
||||||
offer, err := peer.CreateOffer(false)
|
offer, err := peer.CreateOffer(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// on negotiation needed handler must be registered after creating initial
|
// on negotiation needed handler must be registered after creating initial
|
||||||
@ -576,7 +563,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
|
|||||||
// start estimator reader
|
// start estimator reader
|
||||||
go peer.estimatorReader()
|
go peer.estimatorReader()
|
||||||
|
|
||||||
return offer, nil
|
return offer, peer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (manager *WebRTCManagerCtx) SetCursorPosition(x, y int) {
|
func (manager *WebRTCManagerCtx) SetCursorPosition(x, y int) {
|
||||||
|
@ -121,7 +121,7 @@ func (m *metricsManager) getBySession(session types.Session) *metrics {
|
|||||||
Name: "receiver_estimated_maximum_bitrate",
|
Name: "receiver_estimated_maximum_bitrate",
|
||||||
Namespace: "neko",
|
Namespace: "neko",
|
||||||
Subsystem: "webrtc",
|
Subsystem: "webrtc",
|
||||||
Help: "Receiver Estimated Maximum Bitrate from SCTP.",
|
Help: "Receiver Estimated Maximum Bitrate from RTCP.",
|
||||||
ConstLabels: map[string]string{
|
ConstLabels: map[string]string{
|
||||||
"session_id": sessionId,
|
"session_id": sessionId,
|
||||||
},
|
},
|
||||||
@ -140,7 +140,7 @@ func (m *metricsManager) getBySession(session types.Session) *metrics {
|
|||||||
Name: "receiver_report_delay",
|
Name: "receiver_report_delay",
|
||||||
Namespace: "neko",
|
Namespace: "neko",
|
||||||
Subsystem: "webrtc",
|
Subsystem: "webrtc",
|
||||||
Help: "Receiver Report Delay from SCTP, expressed in units of 1/65536 seconds.",
|
Help: "Receiver Report Delay from RTCP, expressed in units of 1/65536 seconds.",
|
||||||
ConstLabels: map[string]string{
|
ConstLabels: map[string]string{
|
||||||
"session_id": sessionId,
|
"session_id": sessionId,
|
||||||
},
|
},
|
||||||
@ -149,7 +149,7 @@ func (m *metricsManager) getBySession(session types.Session) *metrics {
|
|||||||
Name: "receiver_report_jitter",
|
Name: "receiver_report_jitter",
|
||||||
Namespace: "neko",
|
Namespace: "neko",
|
||||||
Subsystem: "webrtc",
|
Subsystem: "webrtc",
|
||||||
Help: "Receiver Report Jitter from SCTP.",
|
Help: "Receiver Report Jitter from RTCP.",
|
||||||
ConstLabels: map[string]string{
|
ConstLabels: map[string]string{
|
||||||
"session_id": sessionId,
|
"session_id": sessionId,
|
||||||
},
|
},
|
||||||
@ -158,7 +158,17 @@ func (m *metricsManager) getBySession(session types.Session) *metrics {
|
|||||||
Name: "receiver_report_total_lost",
|
Name: "receiver_report_total_lost",
|
||||||
Namespace: "neko",
|
Namespace: "neko",
|
||||||
Subsystem: "webrtc",
|
Subsystem: "webrtc",
|
||||||
Help: "Receiver Report Total Lost from SCTP.",
|
Help: "Receiver Report Total Lost from RTCP.",
|
||||||
|
ConstLabels: map[string]string{
|
||||||
|
"session_id": sessionId,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
|
||||||
|
transportLayerNacks: promauto.NewCounter(prometheus.CounterOpts{
|
||||||
|
Name: "transport_layer_nacks",
|
||||||
|
Namespace: "neko",
|
||||||
|
Subsystem: "webrtc",
|
||||||
|
Help: "Transport Layer NACKs from RTCP.",
|
||||||
ConstLabels: map[string]string{
|
ConstLabels: map[string]string{
|
||||||
"session_id": sessionId,
|
"session_id": sessionId,
|
||||||
},
|
},
|
||||||
@ -236,6 +246,8 @@ type metrics struct {
|
|||||||
receiverReportJitter prometheus.Gauge
|
receiverReportJitter prometheus.Gauge
|
||||||
receiverReportTotalLost prometheus.Gauge
|
receiverReportTotalLost prometheus.Gauge
|
||||||
|
|
||||||
|
transportLayerNacks prometheus.Counter
|
||||||
|
|
||||||
iceBytesSent prometheus.Gauge
|
iceBytesSent prometheus.Gauge
|
||||||
iceBytesReceived prometheus.Gauge
|
iceBytesReceived prometheus.Gauge
|
||||||
sctpBytesSent prometheus.Gauge
|
sctpBytesSent prometheus.Gauge
|
||||||
@ -386,6 +398,11 @@ func (met *metrics) rtcpReceiver(rtcpCh chan []rtcp.Packet) {
|
|||||||
// use only last report
|
// use only last report
|
||||||
met.SetReceiverReport(rtcpPacket.Reports[l-1])
|
met.SetReceiverReport(rtcpPacket.Reports[l-1])
|
||||||
}
|
}
|
||||||
|
case *rtcp.TransportLayerNack:
|
||||||
|
for _, pair := range rtcpPacket.Nacks {
|
||||||
|
packetList := pair.PacketList()
|
||||||
|
met.transportLayerNacks.Add(float64(len(packetList)))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -11,15 +11,12 @@ import (
|
|||||||
"github.com/pion/webrtc/v3"
|
"github.com/pion/webrtc/v3"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"github.com/demodesk/neko/internal/config"
|
||||||
"github.com/demodesk/neko/internal/webrtc/payload"
|
"github.com/demodesk/neko/internal/webrtc/payload"
|
||||||
"github.com/demodesk/neko/pkg/types"
|
"github.com/demodesk/neko/pkg/types"
|
||||||
"github.com/demodesk/neko/pkg/types/event"
|
"github.com/demodesk/neko/pkg/types/event"
|
||||||
"github.com/demodesk/neko/pkg/types/message"
|
"github.com/demodesk/neko/pkg/types/message"
|
||||||
)
|
"github.com/demodesk/neko/pkg/utils"
|
||||||
|
|
||||||
const (
|
|
||||||
// how often to read and process bandwidth estimation reports
|
|
||||||
estimatorReadInterval = 250 * time.Millisecond
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type WebRTCPeerCtx struct {
|
type WebRTCPeerCtx struct {
|
||||||
@ -28,15 +25,20 @@ type WebRTCPeerCtx struct {
|
|||||||
session types.Session
|
session types.Session
|
||||||
metrics *metrics
|
metrics *metrics
|
||||||
connection *webrtc.PeerConnection
|
connection *webrtc.PeerConnection
|
||||||
estimator cc.BandwidthEstimator
|
// bandwidth estimator
|
||||||
|
estimator cc.BandwidthEstimator
|
||||||
|
estimateTrend *utils.TrendDetector
|
||||||
|
// stream selectors
|
||||||
|
videoSelector types.StreamSelectorManager
|
||||||
// tracks & channels
|
// tracks & channels
|
||||||
audioTrack *Track
|
audioTrack *Track
|
||||||
videoTrack *Track
|
videoTrack *Track
|
||||||
dataChannel *webrtc.DataChannel
|
dataChannel *webrtc.DataChannel
|
||||||
rtcpChannel chan []rtcp.Packet
|
rtcpChannel chan []rtcp.Packet
|
||||||
// config
|
// config
|
||||||
iceTrickle bool
|
iceTrickle bool
|
||||||
estimatorPassive bool
|
estimatorConfig config.WebRTCEstimator
|
||||||
|
videoAuto bool
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -102,6 +104,7 @@ func (peer *WebRTCPeerCtx) SetCandidate(candidate webrtc.ICECandidateInit) error
|
|||||||
return peer.connection.AddICECandidate(candidate)
|
return peer.connection.AddICECandidate(candidate)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Add shutdown function?
|
||||||
func (peer *WebRTCPeerCtx) Destroy() {
|
func (peer *WebRTCPeerCtx) Destroy() {
|
||||||
peer.mu.Lock()
|
peer.mu.Lock()
|
||||||
defer peer.mu.Unlock()
|
defer peer.mu.Unlock()
|
||||||
@ -111,32 +114,186 @@ func (peer *WebRTCPeerCtx) Destroy() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (peer *WebRTCPeerCtx) estimatorReader() {
|
func (peer *WebRTCPeerCtx) estimatorReader() {
|
||||||
|
conf := peer.estimatorConfig
|
||||||
|
|
||||||
|
// if estimator is not in debug mode, use a nop logger
|
||||||
|
var debugLogger zerolog.Logger
|
||||||
|
if conf.Debug {
|
||||||
|
debugLogger = peer.logger.With().Str("component", "estimator").Logger().Level(zerolog.DebugLevel)
|
||||||
|
} else {
|
||||||
|
debugLogger = zerolog.Nop()
|
||||||
|
}
|
||||||
|
|
||||||
// if estimator is disabled, do nothing
|
// if estimator is disabled, do nothing
|
||||||
if peer.estimator == nil {
|
if peer.estimator == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// use a ticker to get current client target bitrate
|
// use a ticker to get current client target bitrate
|
||||||
ticker := time.NewTicker(estimatorReadInterval)
|
ticker := time.NewTicker(conf.ReadInterval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
// since when is the estimate stable/unstable
|
||||||
|
stableSince := time.Now() // we asume stable at start
|
||||||
|
unstableSince := time.Time{}
|
||||||
|
// since when are we neutral but cannot accomodate current bitrate
|
||||||
|
// we migt be stalled or estimator just reached zer (very bad connection)
|
||||||
|
stalledSince := time.Time{}
|
||||||
|
// when was the last upgrade/downgrade
|
||||||
|
lastUpgradeTime := time.Time{}
|
||||||
|
lastDowngradeTime := time.Time{}
|
||||||
|
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
targetBitrate := peer.estimator.GetTargetBitrate()
|
targetBitrate := peer.estimator.GetTargetBitrate()
|
||||||
peer.metrics.SetReceiverEstimatedTargetBitrate(float64(targetBitrate))
|
peer.metrics.SetReceiverEstimatedTargetBitrate(float64(targetBitrate))
|
||||||
|
|
||||||
|
// if peer connection is closed, stop reading
|
||||||
if peer.connection.ConnectionState() == webrtc.PeerConnectionStateClosed {
|
if peer.connection.ConnectionState() == webrtc.PeerConnectionStateClosed {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if !peer.videoTrack.VideoAuto() {
|
// if estimation is disabled, do nothing
|
||||||
|
if !peer.videoAuto || conf.Passive {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !peer.estimatorPassive {
|
// get trend direction to decide if we should upgrade or downgrade
|
||||||
err := peer.SetVideoBitrate(targetBitrate)
|
peer.estimateTrend.AddValue(int64(targetBitrate))
|
||||||
if err != nil {
|
direction := peer.estimateTrend.GetDirection()
|
||||||
peer.logger.Warn().Err(err).Msg("failed to set video bitrate")
|
|
||||||
|
// get current stream bitrate
|
||||||
|
stream, ok := peer.videoTrack.Stream()
|
||||||
|
if !ok {
|
||||||
|
debugLogger.Warn().Msg("looks like we don't have a stream yet, skipping bitrate estimation")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// if stream bitrate is 0, we need to wait for some time until we get a valid value
|
||||||
|
streamId, streamBitrate := stream.ID(), stream.Bitrate()
|
||||||
|
if streamBitrate == 0 {
|
||||||
|
debugLogger.Warn().Msg("looks like stream bitrate is 0, we need to wait for some time")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// check whats the difference between target and stream bitrate
|
||||||
|
diff := float64(targetBitrate) / float64(streamBitrate)
|
||||||
|
|
||||||
|
debugLogger.Info().
|
||||||
|
Float64("diff", diff).
|
||||||
|
Int("target_bitrate", targetBitrate).
|
||||||
|
Uint64("stream_bitrate", streamBitrate).
|
||||||
|
Str("direction", direction.String()).
|
||||||
|
Msg("got bitrate from estimator")
|
||||||
|
|
||||||
|
// if we can accomodate current stream or we are not netural anymore,
|
||||||
|
// we are not stalled so we reset the stalled time
|
||||||
|
if direction != utils.TrendDirectionNeutral || diff > 1+conf.DiffThreshold {
|
||||||
|
stalledSince = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we are neutral and stalled for too long, we might be congesting
|
||||||
|
stalled := direction == utils.TrendDirectionNeutral && time.Since(stalledSince) > conf.StalledDuration
|
||||||
|
if stalled {
|
||||||
|
debugLogger.Warn().
|
||||||
|
Time("stalled_since", stalledSince).
|
||||||
|
Msgf("it looks like we are stalled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we have an downward trend or are stalled, we might be congesting
|
||||||
|
if direction == utils.TrendDirectionDownward || stalled {
|
||||||
|
// we reset the stable time because we are congesting
|
||||||
|
stableSince = time.Now()
|
||||||
|
|
||||||
|
// if we downgraded recently, we wait for some more time
|
||||||
|
if time.Since(lastDowngradeTime) < conf.DowngradeBackoff {
|
||||||
|
debugLogger.Debug().
|
||||||
|
Time("last_downgrade", lastDowngradeTime).
|
||||||
|
Msgf("downgraded recently, waiting for at least %v", conf.DowngradeBackoff)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if we are not unstable but we fluctuate we should wait for some more time
|
||||||
|
if time.Since(unstableSince) < conf.UnstableDuration {
|
||||||
|
debugLogger.Debug().
|
||||||
|
Time("unstable_since", unstableSince).
|
||||||
|
Msgf("we are not unstable long enough, waiting for at least %v", conf.UnstableDuration)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we still have a big difference between target and stream bitrate, we wait for some more time
|
||||||
|
if conf.DiffThreshold >= 0 && diff > 1+conf.DiffThreshold {
|
||||||
|
debugLogger.Debug().
|
||||||
|
Float64("diff", diff).
|
||||||
|
Float64("threshold", conf.DiffThreshold).
|
||||||
|
Msgf("we still have a big difference between target and stream bitrate, " +
|
||||||
|
"therefore we still should be able to accomodate current stream")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := peer.SetVideo(types.StreamSelector{
|
||||||
|
ID: streamId,
|
||||||
|
Type: types.StreamSelectorTypeLower,
|
||||||
|
})
|
||||||
|
if err != nil && err != types.ErrWebRTCStreamNotFound {
|
||||||
|
peer.logger.Warn().Err(err).Msg("failed to downgrade video stream")
|
||||||
|
}
|
||||||
|
lastDowngradeTime = time.Now()
|
||||||
|
|
||||||
|
if err == types.ErrWebRTCStreamNotFound {
|
||||||
|
debugLogger.Info().Msg("looks like we are already on the lowest stream")
|
||||||
|
} else {
|
||||||
|
debugLogger.Info().Msg("downgraded video stream")
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// we reset the unstable time because we are not congesting
|
||||||
|
unstableSince = time.Now()
|
||||||
|
|
||||||
|
// if we have a neutral or upward trend, that means our estimate is stable
|
||||||
|
// if we are on the highest stream, we don't need to do anything
|
||||||
|
// but if there is a higher stream, we should try to upgrade and see if it works
|
||||||
|
|
||||||
|
// if we upgraded recently, we wait for some more time
|
||||||
|
if time.Since(lastUpgradeTime) < conf.UpgradeBackoff {
|
||||||
|
debugLogger.Debug().
|
||||||
|
Time("last_upgrade", lastUpgradeTime).
|
||||||
|
Msgf("upgraded recently, waiting for at least %v", conf.UpgradeBackoff)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we are not stable for long enough, we wait for some more time
|
||||||
|
// because bandwidth estimation might fluctuate
|
||||||
|
if time.Since(stableSince) < conf.StableDuration {
|
||||||
|
debugLogger.Debug().
|
||||||
|
Time("stable_since", stableSince).
|
||||||
|
Msgf("we are not stable long enough, waiting for at least %v", conf.StableDuration)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// upgrade only if estimated bitrate passed the threshold
|
||||||
|
if conf.DiffThreshold >= 0 && diff < 1+conf.DiffThreshold {
|
||||||
|
debugLogger.Debug().
|
||||||
|
Float64("diff", diff).
|
||||||
|
Float64("threshold", conf.DiffThreshold).
|
||||||
|
Msgf("looks like we don't have enough bitrate to accomodate higher stream, " +
|
||||||
|
"therefore we should wait for some more time")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := peer.SetVideo(types.StreamSelector{
|
||||||
|
ID: streamId,
|
||||||
|
Type: types.StreamSelectorTypeHigher,
|
||||||
|
})
|
||||||
|
if err != nil && err != types.ErrWebRTCStreamNotFound {
|
||||||
|
peer.logger.Warn().Err(err).Msg("failed to upgrade video stream")
|
||||||
|
}
|
||||||
|
lastUpgradeTime = time.Now()
|
||||||
|
|
||||||
|
if err == types.ErrWebRTCStreamNotFound {
|
||||||
|
debugLogger.Info().Msg("looks like we are already on the highest stream")
|
||||||
|
} else {
|
||||||
|
debugLogger.Info().Msg("upgraded video stream")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -145,88 +302,52 @@ func (peer *WebRTCPeerCtx) estimatorReader() {
|
|||||||
// video
|
// video
|
||||||
//
|
//
|
||||||
|
|
||||||
func (peer *WebRTCPeerCtx) SetVideoBitrate(peerBitrate int) error {
|
func (peer *WebRTCPeerCtx) SetVideo(selector types.StreamSelector) error {
|
||||||
peer.mu.Lock()
|
peer.mu.Lock()
|
||||||
defer peer.mu.Unlock()
|
defer peer.mu.Unlock()
|
||||||
|
|
||||||
// when switching from manual to auto bitrate estimation, in case the estimator is
|
// get requested video stream from selector
|
||||||
// idle (lastBitrate > maxBitrate), we want to go back to the previous estimated bitrate
|
stream, ok := peer.videoSelector.GetStream(selector)
|
||||||
if peerBitrate == 0 && peer.estimator != nil && !peer.estimatorPassive {
|
if !ok {
|
||||||
peerBitrate = peer.estimator.GetTargetBitrate()
|
return types.ErrWebRTCStreamNotFound
|
||||||
peer.logger.Debug().
|
|
||||||
Int("peer_bitrate", peerBitrate).
|
|
||||||
Msg("evaluated bitrate")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
changed, err := peer.videoTrack.SetBitrate(peerBitrate)
|
// set video stream to track
|
||||||
|
changed, err := peer.videoTrack.SetStream(stream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if video stream was already set, do nothing
|
||||||
if !changed {
|
if !changed {
|
||||||
// TODO: return error?
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
videoID := peer.videoTrack.stream.ID()
|
videoID := stream.ID()
|
||||||
bitrate := peer.videoTrack.stream.Bitrate()
|
|
||||||
|
|
||||||
peer.metrics.SetVideoID(videoID)
|
peer.metrics.SetVideoID(videoID)
|
||||||
peer.logger.Debug().
|
|
||||||
Int("peer_bitrate", peerBitrate).
|
peer.logger.Info().Str("video_id", videoID).Msg("set video")
|
||||||
Int("video_bitrate", bitrate).
|
|
||||||
Str("video_id", videoID).
|
|
||||||
Msg("peer bitrate triggered video stream change")
|
|
||||||
|
|
||||||
go peer.session.Send(
|
go peer.session.Send(
|
||||||
event.SIGNAL_VIDEO,
|
event.SIGNAL_VIDEO,
|
||||||
message.SignalVideo{
|
message.SignalVideo{
|
||||||
Video: videoID,
|
Video: videoID,
|
||||||
Bitrate: bitrate,
|
Auto: peer.videoAuto,
|
||||||
VideoAuto: peer.videoTrack.VideoAuto(),
|
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *WebRTCPeerCtx) SetVideoID(videoID string) error {
|
func (peer *WebRTCPeerCtx) VideoID() (string, bool) {
|
||||||
peer.mu.Lock()
|
peer.mu.Lock()
|
||||||
defer peer.mu.Unlock()
|
defer peer.mu.Unlock()
|
||||||
|
|
||||||
changed, err := peer.videoTrack.SetVideoID(videoID)
|
stream, ok := peer.videoTrack.Stream()
|
||||||
if err != nil {
|
if !ok {
|
||||||
return err
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !changed {
|
return stream.ID(), true
|
||||||
// TODO: return error?
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
bitrate := peer.videoTrack.stream.Bitrate()
|
|
||||||
|
|
||||||
peer.logger.Debug().
|
|
||||||
Str("video_id", videoID).
|
|
||||||
Int("video_bitrate", bitrate).
|
|
||||||
Msg("peer video id triggered video stream change")
|
|
||||||
|
|
||||||
go peer.session.Send(
|
|
||||||
event.SIGNAL_VIDEO,
|
|
||||||
message.SignalVideo{
|
|
||||||
Video: videoID,
|
|
||||||
Bitrate: bitrate,
|
|
||||||
VideoAuto: peer.videoTrack.VideoAuto(),
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *WebRTCPeerCtx) GetVideoID() string {
|
|
||||||
peer.mu.Lock()
|
|
||||||
defer peer.mu.Unlock()
|
|
||||||
|
|
||||||
// TODO: Refactor.
|
|
||||||
return peer.videoTrack.stream.ID()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error {
|
func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error {
|
||||||
@ -239,18 +360,32 @@ func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (peer *WebRTCPeerCtx) Paused() bool {
|
||||||
|
peer.mu.Lock()
|
||||||
|
defer peer.mu.Unlock()
|
||||||
|
|
||||||
|
return peer.videoTrack.Paused() || peer.audioTrack.Paused()
|
||||||
|
}
|
||||||
|
|
||||||
func (peer *WebRTCPeerCtx) SetVideoAuto(videoAuto bool) {
|
func (peer *WebRTCPeerCtx) SetVideoAuto(videoAuto bool) {
|
||||||
|
peer.mu.Lock()
|
||||||
|
defer peer.mu.Unlock()
|
||||||
|
|
||||||
// if estimator is enabled and is not passive, enable video auto bitrate
|
// if estimator is enabled and is not passive, enable video auto bitrate
|
||||||
if peer.estimator != nil && !peer.estimatorPassive {
|
if peer.estimator != nil && !peer.estimatorConfig.Passive {
|
||||||
peer.videoTrack.SetVideoAuto(videoAuto)
|
peer.logger.Info().Bool("video_auto", videoAuto).Msg("set video auto")
|
||||||
|
peer.videoAuto = videoAuto
|
||||||
} else {
|
} else {
|
||||||
peer.logger.Warn().Msg("estimator is disabled or in passive mode, cannot change video auto")
|
peer.logger.Warn().Msg("estimator is disabled or in passive mode, cannot change video auto")
|
||||||
peer.videoTrack.SetVideoAuto(false) // ensure video auto is disabled
|
peer.videoAuto = false // ensure video auto is disabled
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *WebRTCPeerCtx) VideoAuto() bool {
|
func (peer *WebRTCPeerCtx) VideoAuto() bool {
|
||||||
return peer.videoTrack.VideoAuto()
|
peer.mu.Lock()
|
||||||
|
defer peer.mu.Unlock()
|
||||||
|
|
||||||
|
return peer.videoAuto
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -2,7 +2,6 @@ package webrtc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@ -22,25 +21,13 @@ type Track struct {
|
|||||||
rtcpCh chan []rtcp.Packet
|
rtcpCh chan []rtcp.Packet
|
||||||
sample chan types.Sample
|
sample chan types.Sample
|
||||||
|
|
||||||
videoAuto bool
|
|
||||||
videoAutoMu sync.RWMutex
|
|
||||||
|
|
||||||
paused bool
|
paused bool
|
||||||
stream types.StreamSinkManager
|
stream types.StreamSinkManager
|
||||||
streamMu sync.Mutex
|
streamMu sync.Mutex
|
||||||
|
|
||||||
bitrateChange func(int) (bool, error)
|
|
||||||
videoChange func(string) (bool, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type trackOption func(*Track)
|
type trackOption func(*Track)
|
||||||
|
|
||||||
func WithVideoAuto(auto bool) trackOption {
|
|
||||||
return func(t *Track) {
|
|
||||||
t.videoAuto = auto
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithRtcpChan(rtcp chan []rtcp.Packet) trackOption {
|
func WithRtcpChan(rtcp chan []rtcp.Packet) trackOption {
|
||||||
return func(t *Track) {
|
return func(t *Track) {
|
||||||
t.rtcpCh = rtcp
|
t.rtcpCh = rtcp
|
||||||
@ -100,6 +87,8 @@ func (t *Track) rtcpReader(sender *webrtc.RTPSender) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- sample ---
|
||||||
|
|
||||||
func (t *Track) sampleReader() {
|
func (t *Track) sampleReader() {
|
||||||
for {
|
for {
|
||||||
sample, ok := <-t.sample
|
sample, ok := <-t.sample
|
||||||
@ -120,6 +109,12 @@ func (t *Track) sampleReader() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Track) WriteSample(sample types.Sample) {
|
||||||
|
t.sample <- sample
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- stream ---
|
||||||
|
|
||||||
func (t *Track) SetStream(stream types.StreamSinkManager) (bool, error) {
|
func (t *Track) SetStream(stream types.StreamSinkManager) (bool, error) {
|
||||||
t.streamMu.Lock()
|
t.streamMu.Lock()
|
||||||
defer t.streamMu.Unlock()
|
defer t.streamMu.Unlock()
|
||||||
@ -167,6 +162,15 @@ func (t *Track) RemoveStream() {
|
|||||||
t.stream = nil
|
t.stream = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Track) Stream() (types.StreamSinkManager, bool) {
|
||||||
|
t.streamMu.Lock()
|
||||||
|
defer t.streamMu.Unlock()
|
||||||
|
|
||||||
|
return t.stream, t.stream != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- paused ---
|
||||||
|
|
||||||
func (t *Track) SetPaused(paused bool) {
|
func (t *Track) SetPaused(paused bool) {
|
||||||
t.streamMu.Lock()
|
t.streamMu.Lock()
|
||||||
defer t.streamMu.Unlock()
|
defer t.streamMu.Unlock()
|
||||||
@ -190,42 +194,9 @@ func (t *Track) SetPaused(paused bool) {
|
|||||||
t.paused = paused
|
t.paused = paused
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Track) WriteSample(sample types.Sample) {
|
func (t *Track) Paused() bool {
|
||||||
t.sample <- sample
|
t.streamMu.Lock()
|
||||||
}
|
defer t.streamMu.Unlock()
|
||||||
|
|
||||||
func (t *Track) SetBitrate(bitrate int) (bool, error) {
|
return t.paused
|
||||||
if t.bitrateChange == nil {
|
|
||||||
return false, fmt.Errorf("bitrate change not supported")
|
|
||||||
}
|
|
||||||
|
|
||||||
return t.bitrateChange(bitrate)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Track) SetVideoID(videoID string) (bool, error) {
|
|
||||||
if t.videoChange == nil {
|
|
||||||
return false, fmt.Errorf("video change not supported")
|
|
||||||
}
|
|
||||||
|
|
||||||
return t.videoChange(videoID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Track) OnBitrateChange(f func(bitrate int) (bool, error)) {
|
|
||||||
t.bitrateChange = f
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Track) OnVideoChange(f func(string) (bool, error)) {
|
|
||||||
t.videoChange = f
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Track) SetVideoAuto(auto bool) {
|
|
||||||
t.videoAutoMu.Lock()
|
|
||||||
defer t.videoAutoMu.Unlock()
|
|
||||||
t.videoAuto = auto
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Track) VideoAuto() bool {
|
|
||||||
t.videoAutoMu.RLock()
|
|
||||||
defer t.videoAutoMu.RUnlock()
|
|
||||||
return t.videoAuto
|
|
||||||
}
|
}
|
||||||
|
@ -20,28 +20,26 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
|
|||||||
payload.Video = videos[0]
|
payload.Video = videos[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
offer, peer, err := h.webrtc.CreatePeer(session)
|
||||||
if payload.Bitrate == 0 {
|
|
||||||
// get bitrate from video id
|
|
||||||
payload.Bitrate, err = h.capture.GetBitrateFromVideoID(payload.Video)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
offer, err := h.webrtc.CreatePeer(session, payload.Bitrate, payload.VideoAuto)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if webrtcPeer := session.GetWebRTCPeer(); webrtcPeer != nil {
|
// set webrtc as paused if session has private mode enabled
|
||||||
// set webrtc as paused if session has private mode enabled
|
if session.PrivateModeEnabled() {
|
||||||
if session.PrivateModeEnabled() {
|
peer.SetPaused(true)
|
||||||
webrtcPeer.SetPaused(true)
|
}
|
||||||
}
|
|
||||||
|
|
||||||
payload.Video = webrtcPeer.GetVideoID()
|
// set video auto state
|
||||||
payload.VideoAuto = webrtcPeer.VideoAuto()
|
peer.SetVideoAuto(payload.Auto)
|
||||||
|
|
||||||
|
// set video stream
|
||||||
|
err = peer.SetVideo(types.StreamSelector{
|
||||||
|
ID: payload.Video,
|
||||||
|
Type: types.StreamSelectorTypeNearest,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
session.Send(
|
session.Send(
|
||||||
@ -49,9 +47,6 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
|
|||||||
message.SignalProvide{
|
message.SignalProvide{
|
||||||
SDP: offer.SDP,
|
SDP: offer.SDP,
|
||||||
ICEServers: h.webrtc.ICEServers(),
|
ICEServers: h.webrtc.ICEServers(),
|
||||||
Video: payload.Video, // TODO: Refactor
|
|
||||||
Bitrate: payload.Bitrate,
|
|
||||||
VideoAuto: payload.VideoAuto,
|
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -133,16 +128,13 @@ func (h *MessageHandlerCtx) signalVideo(session types.Session, payload *message.
|
|||||||
return errors.New("webRTC peer does not exist")
|
return errors.New("webRTC peer does not exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.SetVideoAuto(payload.VideoAuto)
|
peer.SetVideoAuto(payload.Auto)
|
||||||
|
|
||||||
if payload.Video != "" {
|
if payload.Video != "" {
|
||||||
if err := peer.SetVideoID(payload.Video); err != nil {
|
return peer.SetVideo(types.StreamSelector{
|
||||||
h.logger.Error().Err(err).Msg("failed to set video id")
|
ID: payload.Video,
|
||||||
}
|
Type: types.StreamSelectorTypeNearest,
|
||||||
} else {
|
})
|
||||||
if err := peer.SetVideoBitrate(payload.Bitrate); err != nil {
|
|
||||||
h.logger.Error().Err(err).Msg("failed to set video bitrate")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -31,26 +31,6 @@ type SampleListener interface {
|
|||||||
WriteSample(Sample)
|
WriteSample(Sample)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Receiver interface {
|
|
||||||
SetStream(stream StreamSinkManager) (changed bool, err error)
|
|
||||||
RemoveStream()
|
|
||||||
OnBitrateChange(f func(bitrate int) (changed bool, err error))
|
|
||||||
OnVideoChange(f func(videoID string) (changed bool, err error))
|
|
||||||
VideoAuto() bool
|
|
||||||
SetVideoAuto(videoAuto bool)
|
|
||||||
}
|
|
||||||
|
|
||||||
type BucketsManager interface {
|
|
||||||
IDs() []string
|
|
||||||
Codec() codec.RTPCodec
|
|
||||||
SetReceiver(receiver Receiver)
|
|
||||||
RemoveReceiver(receiver Receiver) error
|
|
||||||
|
|
||||||
DestroyAll()
|
|
||||||
RecreateAll() error
|
|
||||||
Shutdown()
|
|
||||||
}
|
|
||||||
|
|
||||||
type BroadcastManager interface {
|
type BroadcastManager interface {
|
||||||
Start(url string) error
|
Start(url string) error
|
||||||
Stop()
|
Stop()
|
||||||
@ -64,10 +44,74 @@ type ScreencastManager interface {
|
|||||||
Image() ([]byte, error)
|
Image() ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type StreamSelectorType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// select exact stream
|
||||||
|
StreamSelectorTypeExact StreamSelectorType = iota
|
||||||
|
// select nearest stream (in either direction) if exact stream is not available
|
||||||
|
StreamSelectorTypeNearest
|
||||||
|
// if exact stream is found select the next lower stream, otherwise select the nearest lower stream
|
||||||
|
StreamSelectorTypeLower
|
||||||
|
// if exact stream is found select the next higher stream, otherwise select the nearest higher stream
|
||||||
|
StreamSelectorTypeHigher
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s StreamSelectorType) String() string {
|
||||||
|
switch s {
|
||||||
|
case StreamSelectorTypeExact:
|
||||||
|
return "exact"
|
||||||
|
case StreamSelectorTypeNearest:
|
||||||
|
return "nearest"
|
||||||
|
case StreamSelectorTypeLower:
|
||||||
|
return "lower"
|
||||||
|
case StreamSelectorTypeHigher:
|
||||||
|
return "higher"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("%d", int(s))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StreamSelectorType) UnmarshalText(text []byte) error {
|
||||||
|
switch strings.ToLower(string(text)) {
|
||||||
|
case "exact", "":
|
||||||
|
*s = StreamSelectorTypeExact
|
||||||
|
case "nearest":
|
||||||
|
*s = StreamSelectorTypeNearest
|
||||||
|
case "lower":
|
||||||
|
*s = StreamSelectorTypeLower
|
||||||
|
case "higher":
|
||||||
|
*s = StreamSelectorTypeHigher
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("invalid stream selector type: %s", string(text))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s StreamSelectorType) MarshalText() ([]byte, error) {
|
||||||
|
return []byte(s.String()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type StreamSelector struct {
|
||||||
|
// type of stream selector
|
||||||
|
Type StreamSelectorType
|
||||||
|
// select stream by its ID
|
||||||
|
ID string
|
||||||
|
// select stream by its bitrate
|
||||||
|
Bitrate uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
type StreamSelectorManager interface {
|
||||||
|
IDs() []string
|
||||||
|
Codec() codec.RTPCodec
|
||||||
|
|
||||||
|
GetStream(selector StreamSelector) (StreamSinkManager, bool)
|
||||||
|
}
|
||||||
|
|
||||||
type StreamSinkManager interface {
|
type StreamSinkManager interface {
|
||||||
ID() string
|
ID() string
|
||||||
Codec() codec.RTPCodec
|
Codec() codec.RTPCodec
|
||||||
Bitrate() int
|
Bitrate() uint64
|
||||||
|
|
||||||
AddListener(listener SampleListener) error
|
AddListener(listener SampleListener) error
|
||||||
RemoveListener(listener SampleListener) error
|
RemoveListener(listener SampleListener) error
|
||||||
@ -94,12 +138,10 @@ type CaptureManager interface {
|
|||||||
Start()
|
Start()
|
||||||
Shutdown() error
|
Shutdown() error
|
||||||
|
|
||||||
GetBitrateFromVideoID(videoID string) (int, error)
|
|
||||||
|
|
||||||
Broadcast() BroadcastManager
|
Broadcast() BroadcastManager
|
||||||
Screencast() ScreencastManager
|
Screencast() ScreencastManager
|
||||||
Audio() StreamSinkManager
|
Audio() StreamSinkManager
|
||||||
Video() BucketsManager
|
Video() StreamSelectorManager
|
||||||
|
|
||||||
Webcam() StreamSrcManager
|
Webcam() StreamSrcManager
|
||||||
Microphone() StreamSrcManager
|
Microphone() StreamSrcManager
|
||||||
@ -201,54 +243,3 @@ func (config *VideoConfig) GetPipeline(screen ScreenSize) (string, error) {
|
|||||||
config.GstSuffix,
|
config.GstSuffix,
|
||||||
}[:], " "), nil
|
}[:], " "), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (config *VideoConfig) GetBitrateFn(getScreen func() ScreenSize) func() (int, error) {
|
|
||||||
return func() (int, error) {
|
|
||||||
if config.Bitrate > 0 {
|
|
||||||
return config.Bitrate, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
screen := getScreen()
|
|
||||||
|
|
||||||
values := map[string]any{
|
|
||||||
"width": screen.Width,
|
|
||||||
"height": screen.Height,
|
|
||||||
"fps": screen.Rate,
|
|
||||||
}
|
|
||||||
|
|
||||||
language := []gval.Language{
|
|
||||||
gval.Function("round", func(args ...any) (any, error) {
|
|
||||||
return (int)(math.Round(args[0].(float64))), nil
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
|
|
||||||
// TOOD: do not read target-bitrate from pipeline, but only from config.
|
|
||||||
|
|
||||||
// TODO: This is only for vp8.
|
|
||||||
expr, ok := config.GstParams["target-bitrate"]
|
|
||||||
if !ok {
|
|
||||||
// TODO: This is only for h264.
|
|
||||||
expr, ok = config.GstParams["bitrate"]
|
|
||||||
if !ok {
|
|
||||||
return 0, fmt.Errorf("bitrate not found")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bitrate, err := gval.Evaluate(expr, values, language...)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("failed to evaluate bitrate: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var bitrateInt int
|
|
||||||
switch val := bitrate.(type) {
|
|
||||||
case int:
|
|
||||||
bitrateInt = val
|
|
||||||
case float64:
|
|
||||||
bitrateInt = (int)(val)
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("bitrate is not int or float64")
|
|
||||||
}
|
|
||||||
|
|
||||||
return bitrateInt, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -48,10 +48,6 @@ type SystemDisconnect struct {
|
|||||||
type SignalProvide struct {
|
type SignalProvide struct {
|
||||||
SDP string `json:"sdp"`
|
SDP string `json:"sdp"`
|
||||||
ICEServers []types.ICEServer `json:"iceservers"`
|
ICEServers []types.ICEServer `json:"iceservers"`
|
||||||
// TODO: Use SignalVideo struct.
|
|
||||||
Video string `json:"video"`
|
|
||||||
Bitrate int `json:"bitrate"`
|
|
||||||
VideoAuto bool `json:"video_auto"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type SignalCandidate struct {
|
type SignalCandidate struct {
|
||||||
@ -63,9 +59,8 @@ type SignalDescription struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SignalVideo struct {
|
type SignalVideo struct {
|
||||||
Video string `json:"video"`
|
Video string `json:"video"`
|
||||||
Bitrate int `json:"bitrate"`
|
Auto bool `json:"auto"`
|
||||||
VideoAuto bool `json:"video_auto"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/////////////////////////////
|
/////////////////////////////
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
var (
|
var (
|
||||||
ErrWebRTCDataChannelNotFound = errors.New("webrtc data channel not found")
|
ErrWebRTCDataChannelNotFound = errors.New("webrtc data channel not found")
|
||||||
ErrWebRTCConnectionNotFound = errors.New("webrtc connection not found")
|
ErrWebRTCConnectionNotFound = errors.New("webrtc connection not found")
|
||||||
|
ErrWebRTCStreamNotFound = errors.New("webrtc stream not found")
|
||||||
)
|
)
|
||||||
|
|
||||||
type ICEServer struct {
|
type ICEServer struct {
|
||||||
@ -23,10 +24,10 @@ type WebRTCPeer interface {
|
|||||||
SetRemoteDescription(webrtc.SessionDescription) error
|
SetRemoteDescription(webrtc.SessionDescription) error
|
||||||
SetCandidate(webrtc.ICECandidateInit) error
|
SetCandidate(webrtc.ICECandidateInit) error
|
||||||
|
|
||||||
SetVideoBitrate(bitrate int) error
|
SetVideo(StreamSelector) error
|
||||||
SetVideoID(videoID string) error
|
VideoID() (string, bool)
|
||||||
GetVideoID() string
|
|
||||||
SetPaused(isPaused bool) error
|
SetPaused(isPaused bool) error
|
||||||
|
Paused() bool
|
||||||
SetVideoAuto(auto bool)
|
SetVideoAuto(auto bool)
|
||||||
VideoAuto() bool
|
VideoAuto() bool
|
||||||
|
|
||||||
@ -42,6 +43,6 @@ type WebRTCManager interface {
|
|||||||
|
|
||||||
ICEServers() []ICEServer
|
ICEServers() []ICEServer
|
||||||
|
|
||||||
CreatePeer(session Session, bitrate int, videoAuto bool) (*webrtc.SessionDescription, error)
|
CreatePeer(session Session) (*webrtc.SessionDescription, WebRTCPeer, error)
|
||||||
SetCursorPosition(x, y int)
|
SetCursorPosition(x, y int)
|
||||||
}
|
}
|
||||||
|
153
pkg/utils/trenddetector.go
Normal file
153
pkg/utils/trenddetector.go
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
// From https://github.com/livekit/livekit/blob/master/pkg/sfu/streamallocator/trenddetector.go
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ------------------------------------------------
|
||||||
|
|
||||||
|
type TrendDirection int
|
||||||
|
|
||||||
|
const (
|
||||||
|
TrendDirectionNeutral TrendDirection = iota
|
||||||
|
TrendDirectionUpward
|
||||||
|
TrendDirectionDownward
|
||||||
|
)
|
||||||
|
|
||||||
|
func (t TrendDirection) String() string {
|
||||||
|
switch t {
|
||||||
|
case TrendDirectionNeutral:
|
||||||
|
return "NEUTRAL"
|
||||||
|
case TrendDirectionUpward:
|
||||||
|
return "UPWARD"
|
||||||
|
case TrendDirectionDownward:
|
||||||
|
return "DOWNWARD"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("%d", int(t))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------
|
||||||
|
|
||||||
|
type TrendDetectorParams struct {
|
||||||
|
RequiredSamples int
|
||||||
|
DownwardTrendThreshold float64
|
||||||
|
CollapseValues bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type TrendDetector struct {
|
||||||
|
params TrendDetectorParams
|
||||||
|
|
||||||
|
startTime time.Time
|
||||||
|
numSamples int
|
||||||
|
values []int64
|
||||||
|
lowestValue int64
|
||||||
|
highestValue int64
|
||||||
|
|
||||||
|
direction TrendDirection
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTrendDetector(params TrendDetectorParams) *TrendDetector {
|
||||||
|
return &TrendDetector{
|
||||||
|
params: params,
|
||||||
|
startTime: time.Now(),
|
||||||
|
direction: TrendDirectionNeutral,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TrendDetector) Seed(value int64) {
|
||||||
|
if len(t.values) != 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
t.values = append(t.values, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TrendDetector) AddValue(value int64) {
|
||||||
|
t.numSamples++
|
||||||
|
if t.lowestValue == 0 || value < t.lowestValue {
|
||||||
|
t.lowestValue = value
|
||||||
|
}
|
||||||
|
if value > t.highestValue {
|
||||||
|
t.highestValue = value
|
||||||
|
}
|
||||||
|
|
||||||
|
// ignore duplicate values
|
||||||
|
if t.params.CollapseValues && len(t.values) != 0 && t.values[len(t.values)-1] == value {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(t.values) == t.params.RequiredSamples {
|
||||||
|
t.values = t.values[1:]
|
||||||
|
}
|
||||||
|
t.values = append(t.values, value)
|
||||||
|
|
||||||
|
t.updateDirection()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TrendDetector) GetLowest() int64 {
|
||||||
|
return t.lowestValue
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TrendDetector) GetHighest() int64 {
|
||||||
|
return t.highestValue
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TrendDetector) GetValues() []int64 {
|
||||||
|
return t.values
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TrendDetector) GetDirection() TrendDirection {
|
||||||
|
return t.direction
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TrendDetector) ToString() string {
|
||||||
|
now := time.Now()
|
||||||
|
elapsed := now.Sub(t.startTime).Seconds()
|
||||||
|
str := fmt.Sprintf("t: %+v|%+v|%.2fs", t.startTime.Format(time.UnixDate), now.Format(time.UnixDate), elapsed)
|
||||||
|
str += fmt.Sprintf(", v: %d|%d|%d|%+v|%.2f", t.numSamples, t.lowestValue, t.highestValue, t.values, kendallsTau(t.values))
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TrendDetector) updateDirection() {
|
||||||
|
if len(t.values) < t.params.RequiredSamples {
|
||||||
|
t.direction = TrendDirectionNeutral
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// using Kendall's Tau to find trend
|
||||||
|
kt := kendallsTau(t.values)
|
||||||
|
|
||||||
|
t.direction = TrendDirectionNeutral
|
||||||
|
switch {
|
||||||
|
case kt > 0:
|
||||||
|
t.direction = TrendDirectionUpward
|
||||||
|
case kt < t.params.DownwardTrendThreshold:
|
||||||
|
t.direction = TrendDirectionDownward
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------
|
||||||
|
|
||||||
|
func kendallsTau(values []int64) float64 {
|
||||||
|
concordantPairs := 0
|
||||||
|
discordantPairs := 0
|
||||||
|
|
||||||
|
for i := 0; i < len(values)-1; i++ {
|
||||||
|
for j := i + 1; j < len(values); j++ {
|
||||||
|
if values[i] < values[j] {
|
||||||
|
concordantPairs++
|
||||||
|
} else if values[i] > values[j] {
|
||||||
|
discordantPairs++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (concordantPairs + discordantPairs) == 0 {
|
||||||
|
return 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
return (float64(concordantPairs) - float64(discordantPairs)) / (float64(concordantPairs) + float64(discordantPairs))
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user