WebRTC congestion control (#26)

* Add congestion control

* Improve stream matching, add manual stream selection, add metrics

* Use a ticker for bitrate estimation and make bandwidth drops switch to lower streams more aggressively

* Missing signal response, fix video auto bug

* Remove redundant mutex

* Bitrate history queue

* Get bitrate fn support h264 & float64

---------

Co-authored-by: Aleksandar Sukovic <aleksandar.sukovic@gmail.com>
This commit is contained in:
Miroslav Šedivý 2023-02-06 19:45:51 +01:00 committed by GitHub
parent e80ae8019e
commit 2364facd60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 738 additions and 222 deletions

View File

@ -1,105 +0,0 @@
package capture
import (
"errors"
"fmt"
"math"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
)
type BucketsManagerCtx struct {
logger zerolog.Logger
codec codec.RTPCodec
streams map[string]*StreamSinkManagerCtx
streamIDs []string
}
func bucketsNew(codec codec.RTPCodec, streams map[string]*StreamSinkManagerCtx, streamIDs []string) *BucketsManagerCtx {
logger := log.With().
Str("module", "capture").
Str("submodule", "buckets").
Logger()
return &BucketsManagerCtx{
logger: logger,
codec: codec,
streams: streams,
streamIDs: streamIDs,
}
}
func (m *BucketsManagerCtx) shutdown() {
m.logger.Info().Msgf("shutdown")
}
func (m *BucketsManagerCtx) destroyAll() {
for _, stream := range m.streams {
if stream.Started() {
stream.destroyPipeline()
}
}
}
func (m *BucketsManagerCtx) recreateAll() error {
for _, stream := range m.streams {
if stream.Started() {
err := stream.createPipeline()
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
return err
}
}
}
return nil
}
func (m *BucketsManagerCtx) IDs() []string {
return m.streamIDs
}
func (m *BucketsManagerCtx) Codec() codec.RTPCodec {
return m.codec
}
func (m *BucketsManagerCtx) SetReceiver(receiver types.Receiver) error {
receiver.OnBitrateChange(func(bitrate int) error {
stream, ok := m.findNearestStream(bitrate)
if !ok {
return fmt.Errorf("no stream found for bitrate %d", bitrate)
}
return receiver.SetStream(stream)
})
return nil
}
func (m *BucketsManagerCtx) findNearestStream(bitrate int) (ss *StreamSinkManagerCtx, ok bool) {
minDiff := math.MaxInt
for _, s := range m.streams {
streamBitrate, err := s.Bitrate()
if err != nil {
m.logger.Error().Err(err).Msgf("failed to get bitrate for stream %s", s.ID())
continue
}
diffAbs := int(math.Abs(float64(bitrate - streamBitrate)))
if diffAbs < minDiff {
minDiff, ss = diffAbs, s
}
}
ok = ss != nil
return
}
func (m *BucketsManagerCtx) RemoveReceiver(receiver types.Receiver) error {
receiver.OnBitrateChange(nil)
receiver.RemoveStream()
return nil
}

View File

@ -0,0 +1,145 @@
package buckets
import (
"errors"
"sort"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
)
type BucketsManagerCtx struct {
logger zerolog.Logger
codec codec.RTPCodec
streams map[string]types.StreamSinkManager
streamIDs []string
}
func BucketsNew(codec codec.RTPCodec, streams map[string]types.StreamSinkManager, streamIDs []string) *BucketsManagerCtx {
logger := log.With().
Str("module", "capture").
Str("submodule", "buckets").
Logger()
return &BucketsManagerCtx{
logger: logger,
codec: codec,
streams: streams,
streamIDs: streamIDs,
}
}
func (manager *BucketsManagerCtx) Shutdown() {
manager.logger.Info().Msgf("shutdown")
manager.DestroyAll()
}
func (manager *BucketsManagerCtx) DestroyAll() {
for _, stream := range manager.streams {
if stream.Started() {
stream.DestroyPipeline()
}
}
}
func (manager *BucketsManagerCtx) RecreateAll() error {
for _, stream := range manager.streams {
if stream.Started() {
err := stream.CreatePipeline()
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
return err
}
}
}
return nil
}
func (manager *BucketsManagerCtx) IDs() []string {
return manager.streamIDs
}
func (manager *BucketsManagerCtx) Codec() codec.RTPCodec {
return manager.codec
}
func (manager *BucketsManagerCtx) SetReceiver(receiver types.Receiver) {
// bitrate history is per receiver
bitrateHistory := &queue{}
receiver.OnBitrateChange(func(peerBitrate int) (bool, error) {
bitrate := peerBitrate
if receiver.VideoAuto() {
bitrate = bitrateHistory.normaliseBitrate(bitrate)
}
stream := manager.findNearestStream(bitrate)
streamID := stream.ID()
// TODO: make this less noisy in logs
manager.logger.Debug().
Str("video_id", streamID).
Int("len", bitrateHistory.len()).
Int("peer_bitrate", peerBitrate).
Int("bitrate", bitrate).
Msg("change video bitrate")
return receiver.SetStream(stream)
})
receiver.OnVideoChange(func(videoID string) (bool, error) {
stream := manager.streams[videoID]
manager.logger.Info().
Str("video_id", videoID).
Msg("video change")
return receiver.SetStream(stream)
})
}
func (manager *BucketsManagerCtx) findNearestStream(peerBitrate int) types.StreamSinkManager {
type streamDiff struct {
id string
bitrateDiff int
}
sortDiff := func(a, b int) bool {
switch {
case a < 0 && b < 0:
return a > b
case a >= 0:
if b >= 0 {
return a <= b
}
return true
}
return false
}
var diffs []streamDiff
for _, stream := range manager.streams {
diffs = append(diffs, streamDiff{
id: stream.ID(),
bitrateDiff: peerBitrate - stream.Bitrate(),
})
}
sort.Slice(diffs, func(i, j int) bool {
return sortDiff(diffs[i].bitrateDiff, diffs[j].bitrateDiff)
})
bestDiff := diffs[0]
return manager.streams[bestDiff.id]
}
func (manager *BucketsManagerCtx) RemoveReceiver(receiver types.Receiver) error {
receiver.OnBitrateChange(nil)
receiver.OnVideoChange(nil)
receiver.RemoveStream()
return nil
}

View File

@ -0,0 +1,83 @@
package buckets
import (
"reflect"
"testing"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
)
func TestBucketsManagerCtx_FindNearestStream(t *testing.T) {
type fields struct {
codec codec.RTPCodec
streams map[string]types.StreamSinkManager
}
type args struct {
peerBitrate int
}
tests := []struct {
name string
fields fields
args args
want types.StreamSinkManager
}{
{
name: "findNearestStream",
fields: fields{
streams: map[string]types.StreamSinkManager{
"1": mockStreamSink{
id: "1",
bitrate: 500,
},
"2": mockStreamSink{
id: "2",
bitrate: 750,
},
"3": mockStreamSink{
id: "3",
bitrate: 1000,
},
"4": mockStreamSink{
id: "4",
bitrate: 1250,
},
"5": mockStreamSink{
id: "5",
bitrate: 1700,
},
},
},
args: args{
peerBitrate: 950,
},
want: mockStreamSink{
id: "2",
bitrate: 750,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := BucketsNew(tt.fields.codec, tt.fields.streams, []string{})
if got := m.findNearestStream(tt.args.peerBitrate); !reflect.DeepEqual(got, tt.want) {
t.Errorf("findNearestStream() = %v, want %v", got, tt.want)
}
})
}
}
type mockStreamSink struct {
id string
bitrate int
types.StreamSinkManager
}
func (m mockStreamSink) ID() string {
return m.id
}
func (m mockStreamSink) Bitrate() int {
return m.bitrate
}

View File

@ -0,0 +1,88 @@
package buckets
import (
"math"
"sync"
"time"
)
type queue struct {
sync.Mutex
q []elem
}
type elem struct {
created time.Time
bitrate int
}
func (q *queue) push(v elem) {
q.Lock()
defer q.Unlock()
// if the first element is older than 10 seconds, remove it
if len(q.q) > 0 && time.Since(q.q[0].created) > 10*time.Second {
q.q = q.q[1:]
}
q.q = append(q.q, v)
}
func (q *queue) len() int {
q.Lock()
defer q.Unlock()
return len(q.q)
}
func (q *queue) avg() int {
q.Lock()
defer q.Unlock()
if len(q.q) == 0 {
return 0
}
sum := 0
for _, v := range q.q {
sum += v.bitrate
}
return sum / len(q.q)
}
func (q *queue) avgLastN(n int) int {
if n <= 0 {
return q.avg()
}
q.Lock()
defer q.Unlock()
if len(q.q) == 0 {
return 0
}
sum := 0
for _, v := range q.q[len(q.q)-n:] {
sum += v.bitrate
}
return sum / n
}
func (q *queue) normaliseBitrate(currentBitrate int) int {
avgBitrate := float64(q.avg())
histLen := float64(q.len())
q.push(elem{
bitrate: currentBitrate,
created: time.Now(),
})
if avgBitrate == 0 || histLen == 0 || currentBitrate == 0 {
return currentBitrate
}
lastN := int(math.Floor(float64(currentBitrate) / avgBitrate * histLen))
if lastN > q.len() {
lastN = q.len()
}
if lastN == 0 {
return currentBitrate
}
return q.avgLastN(lastN)
}

View File

@ -0,0 +1,99 @@
package buckets
import "testing"
func Queue_normaliseBitrate(t *testing.T) {
type fields struct {
queue *queue
}
type args struct {
currentBitrate int
}
tests := []struct {
name string
fields fields
args args
want []int
}{
{
name: "normaliseBitrate: big drop",
fields: fields{
queue: &queue{
q: []elem{
{bitrate: 900},
{bitrate: 750},
{bitrate: 780},
{bitrate: 1100},
{bitrate: 950},
{bitrate: 700},
{bitrate: 800},
{bitrate: 900},
{bitrate: 1000},
{bitrate: 1100},
// avg = 898
},
},
},
args: args{
currentBitrate: 350,
},
want: []int{816, 700, 537, 350, 350},
}, {
name: "normaliseBitrate: small drop",
fields: fields{
queue: &queue{
q: []elem{
{bitrate: 900},
{bitrate: 750},
{bitrate: 780},
{bitrate: 1100},
{bitrate: 950},
{bitrate: 700},
{bitrate: 800},
{bitrate: 900},
{bitrate: 1000},
{bitrate: 1100},
// avg = 898
},
},
},
args: args{
currentBitrate: 700,
},
want: []int{878, 842, 825, 825, 812, 787, 750, 700},
}, {
name: "normaliseBitrate",
fields: fields{
queue: &queue{
q: []elem{
{bitrate: 900},
{bitrate: 750},
{bitrate: 780},
{bitrate: 1100},
{bitrate: 950},
{bitrate: 700},
{bitrate: 800},
{bitrate: 900},
{bitrate: 1000},
{bitrate: 1100},
// avg = 898
},
},
},
args: args{
currentBitrate: 1350,
},
want: []int{943, 1003, 1060, 1085},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := tt.fields.queue
for i := 0; i < len(tt.want); i++ {
if got := m.normaliseBitrate(tt.args.currentBitrate); got != tt.want[i] {
t.Errorf("normaliseBitrate() [%d] = %v, want %v", i, got, tt.want[i])
}
}
})
}
}

View File

@ -8,6 +8,7 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/demodesk/neko/internal/capture/buckets"
"github.com/demodesk/neko/internal/config" "github.com/demodesk/neko/internal/config"
"github.com/demodesk/neko/pkg/types" "github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec" "github.com/demodesk/neko/pkg/types/codec"
@ -22,7 +23,7 @@ type CaptureManagerCtx struct {
broadcast *BroacastManagerCtx broadcast *BroacastManagerCtx
screencast *ScreencastManagerCtx screencast *ScreencastManagerCtx
audio *StreamSinkManagerCtx audio *StreamSinkManagerCtx
video *BucketsManagerCtx video types.BucketsManager
// sources // sources
webcam *StreamSrcManagerCtx webcam *StreamSrcManagerCtx
@ -32,7 +33,7 @@ type CaptureManagerCtx struct {
func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCtx { func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCtx {
logger := log.With().Str("module", "capture").Logger() logger := log.With().Str("module", "capture").Logger()
videos := map[string]*StreamSinkManagerCtx{} videos := map[string]types.StreamSinkManager{}
for video_id, cnf := range config.VideoPipelines { for video_id, cnf := range config.VideoPipelines {
pipelineConf := cnf pipelineConf := cnf
@ -68,9 +69,10 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt
Msg("syntax check for video stream pipeline passed") Msg("syntax check for video stream pipeline passed")
getVideoBitrate := pipelineConf.GetBitrateFn(desktop.GetScreenSize) getVideoBitrate := pipelineConf.GetBitrateFn(desktop.GetScreenSize)
if err != nil { if _, err = getVideoBitrate(); err != nil {
logger.Panic().Err(err).Msg("unable to get video bitrate") logger.Panic().Err(err).Msg("unable to get video bitrate")
} }
// append to videos // append to videos
videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id, getVideoBitrate) videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id, getVideoBitrate)
} }
@ -139,7 +141,7 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt
"! appsink name=appsink", config.AudioDevice, config.AudioCodec.Pipeline, "! appsink name=appsink", config.AudioDevice, config.AudioCodec.Pipeline,
), nil ), nil
}, "audio", nil), }, "audio", nil),
video: bucketsNew(config.VideoCodec, videos, config.VideoIDs), video: buckets.BucketsNew(config.VideoCodec, videos, config.VideoIDs),
// sources // sources
webcam: streamSrcNew(config.WebcamEnabled, map[string]string{ webcam: streamSrcNew(config.WebcamEnabled, map[string]string{
@ -200,7 +202,7 @@ func (manager *CaptureManagerCtx) Start() {
} }
manager.desktop.OnBeforeScreenSizeChange(func() { manager.desktop.OnBeforeScreenSizeChange(func() {
manager.video.destroyAll() manager.video.DestroyAll()
if manager.broadcast.Started() { if manager.broadcast.Started() {
manager.broadcast.destroyPipeline() manager.broadcast.destroyPipeline()
@ -212,7 +214,7 @@ func (manager *CaptureManagerCtx) Start() {
}) })
manager.desktop.OnAfterScreenSizeChange(func() { manager.desktop.OnAfterScreenSizeChange(func() {
err := manager.video.recreateAll() err := manager.video.RecreateAll()
if err != nil { if err != nil {
manager.logger.Panic().Err(err).Msg("unable to recreate video pipelines") manager.logger.Panic().Err(err).Msg("unable to recreate video pipelines")
} }
@ -240,7 +242,7 @@ func (manager *CaptureManagerCtx) Shutdown() error {
manager.screencast.shutdown() manager.screencast.shutdown()
manager.audio.shutdown() manager.audio.shutdown()
manager.video.shutdown() manager.video.Shutdown()
manager.webcam.shutdown() manager.webcam.shutdown()
manager.microphone.shutdown() manager.microphone.shutdown()

View File

@ -105,7 +105,7 @@ func (manager *StreamSinkManagerCtx) shutdown() {
} }
manager.listenersMu.Unlock() manager.listenersMu.Unlock()
manager.destroyPipeline() manager.DestroyPipeline()
manager.wg.Wait() manager.wg.Wait()
} }
@ -113,12 +113,19 @@ func (manager *StreamSinkManagerCtx) ID() string {
return manager.id return manager.id
} }
func (manager *StreamSinkManagerCtx) Bitrate() (int, error) { func (manager *StreamSinkManagerCtx) Bitrate() int {
if manager.getBitrate == nil { if manager.getBitrate == nil {
return 0, nil return 0
} }
// recalculate bitrate every time, take screen resolution (and fps) into account // recalculate bitrate every time, take screen resolution (and fps) into account
return manager.getBitrate() // we called this function during startup, so it shouldn't error here
bitrate, err := manager.getBitrate()
if err != nil {
manager.logger.Err(err).Msg("unexpected error while getting bitrate")
}
return bitrate
} }
func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec { func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec {
@ -127,7 +134,7 @@ func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec {
func (manager *StreamSinkManagerCtx) start() error { func (manager *StreamSinkManagerCtx) start() error {
if len(manager.listeners) == 0 { if len(manager.listeners) == 0 {
err := manager.createPipeline() err := manager.CreatePipeline()
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) { if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
return err return err
} }
@ -140,7 +147,7 @@ func (manager *StreamSinkManagerCtx) start() error {
func (manager *StreamSinkManagerCtx) stop() { func (manager *StreamSinkManagerCtx) stop() {
if len(manager.listeners) == 0 { if len(manager.listeners) == 0 {
manager.destroyPipeline() manager.DestroyPipeline()
manager.logger.Info().Msgf("last listener, stopping") manager.logger.Info().Msgf("last listener, stopping")
} }
} }
@ -259,7 +266,7 @@ func (manager *StreamSinkManagerCtx) Started() bool {
return manager.ListenersCount() > 0 return manager.ListenersCount() > 0
} }
func (manager *StreamSinkManagerCtx) createPipeline() error { func (manager *StreamSinkManagerCtx) CreatePipeline() error {
manager.pipelineMu.Lock() manager.pipelineMu.Lock()
defer manager.pipelineMu.Unlock() defer manager.pipelineMu.Unlock()
@ -313,7 +320,7 @@ func (manager *StreamSinkManagerCtx) createPipeline() error {
return nil return nil
} }
func (manager *StreamSinkManagerCtx) destroyPipeline() { func (manager *StreamSinkManagerCtx) DestroyPipeline() {
manager.pipelineMu.Lock() manager.pipelineMu.Lock()
defer manager.pipelineMu.Unlock() defer manager.pipelineMu.Unlock()

View File

@ -9,6 +9,8 @@ import (
"github.com/pion/ice/v2" "github.com/pion/ice/v2"
"github.com/pion/interceptor" "github.com/pion/interceptor"
"github.com/pion/interceptor/pkg/cc"
"github.com/pion/interceptor/pkg/gcc"
"github.com/pion/rtcp" "github.com/pion/rtcp"
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -35,6 +37,9 @@ const keepAliveInterval = 2 * time.Second
// send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval // send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval
const rtcpPLIInterval = 3 * time.Second const rtcpPLIInterval = 3 * time.Second
// how often we check the bitrate of each client. Default is 250ms
const bitrateCheckInterval = 250 * time.Millisecond
func New(desktop types.DesktopManager, capture types.CaptureManager, config *config.WebRTC) *WebRTCManagerCtx { func New(desktop types.DesktopManager, capture types.CaptureManager, config *config.WebRTC) *WebRTCManagerCtx {
configuration := webrtc.Configuration{ configuration := webrtc.Configuration{
SDPSemantics: webrtc.SDPSemanticsUnifiedPlanWithFallback, SDPSemantics: webrtc.SDPSemanticsUnifiedPlanWithFallback,
@ -153,12 +158,12 @@ func (manager *WebRTCManagerCtx) ICEServers() []types.ICEServer {
return manager.config.ICEServers return manager.config.ICEServers
} }
func (manager *WebRTCManagerCtx) newPeerConnection(codecs []codec.RTPCodec, logger zerolog.Logger) (*webrtc.PeerConnection, error) { func (manager *WebRTCManagerCtx) newPeerConnection(bitrate int, codecs []codec.RTPCodec, logger zerolog.Logger) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) {
// create media engine // create media engine
engine := &webrtc.MediaEngine{} engine := &webrtc.MediaEngine{}
for _, codec := range codecs { for _, codec := range codecs {
if err := codec.Register(engine); err != nil { if err := codec.Register(engine); err != nil {
return nil, err return nil, nil, err
} }
} }
@ -205,8 +210,29 @@ func (manager *WebRTCManagerCtx) newPeerConnection(codecs []codec.RTPCodec, logg
// create interceptor registry // create interceptor registry
registry := &interceptor.Registry{} registry := &interceptor.Registry{}
congestionController, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) {
if bitrate == 0 {
bitrate = 1000000
}
return gcc.NewSendSideBWE(gcc.SendSideBWEInitialBitrate(bitrate))
})
if err != nil {
return nil, nil, err
}
estimatorChan := make(chan cc.BandwidthEstimator, 1)
congestionController.OnNewPeerConnection(func(id string, estimator cc.BandwidthEstimator) {
estimatorChan <- estimator
})
registry.Add(congestionController)
if err = webrtc.ConfigureTWCCHeaderExtensionSender(engine, registry); err != nil {
return nil, nil, err
}
if err := webrtc.RegisterDefaultInterceptors(engine, registry); err != nil { if err := webrtc.RegisterDefaultInterceptors(engine, registry); err != nil {
return nil, err return nil, nil, err
} }
// create new API // create new API
@ -217,10 +243,12 @@ func (manager *WebRTCManagerCtx) newPeerConnection(codecs []codec.RTPCodec, logg
) )
// create new peer connection // create new peer connection
return api.NewPeerConnection(manager.webrtcConfiguration) configuration := manager.webrtcConfiguration
connection, err := api.NewPeerConnection(configuration)
return connection, <-estimatorChan, err
} }
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int) (*webrtc.SessionDescription, error) { func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int, videoAuto bool) (*webrtc.SessionDescription, error) {
id := atomic.AddInt32(&manager.peerId, 1) id := atomic.AddInt32(&manager.peerId, 1)
manager.metrics.NewConnection(session) manager.metrics.NewConnection(session)
@ -236,7 +264,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int)
video := manager.capture.Video() video := manager.capture.Video()
videoCodec := video.Codec() videoCodec := video.Codec()
connection, err := manager.newPeerConnection([]codec.RTPCodec{ connection, estimator, err := manager.newPeerConnection(bitrate, []codec.RTPCodec{
audioCodec, audioCodec,
videoCodec, videoCodec,
}, logger) }, logger)
@ -244,6 +272,10 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int)
return nil, err return nil, err
} }
if bitrate == 0 {
bitrate = estimator.GetTargetBitrate()
}
// asynchronously send local ICE Candidates // asynchronously send local ICE Candidates
if manager.config.ICETrickle { if manager.config.ICETrickle {
connection.OnICECandidate(func(candidate *webrtc.ICECandidate) { connection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
@ -268,31 +300,117 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int)
} }
// set stream for audio track // set stream for audio track
err = audioTrack.SetStream(audio) _, err = audioTrack.SetStream(audio)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// video track // video track
videoTrack, err := NewTrack(logger, videoCodec, connection, WithVideoAuto(videoAuto))
videoTrack, err := NewTrack(logger, videoCodec, connection)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// let video stream bucket manager handle stream subscriptions // let video stream bucket manager handle stream subscriptions
err = video.SetReceiver(videoTrack) video.SetReceiver(videoTrack)
if err != nil {
return nil, err changeVideoFromBitrate := func(peerBitrate int) {
// when switching from manual to auto bitrate estimation, in case the estimator is
// idle (lastBitrate > maxBitrate), we want to go back to the previous estimated bitrate
if peerBitrate == 0 {
peerBitrate = estimator.GetTargetBitrate()
manager.logger.Debug().
Int("peer_bitrate", peerBitrate).
Msg("evaluated bitrate")
} }
// set initial video bitrate ok, err := videoTrack.SetBitrate(peerBitrate)
if err = videoTrack.SetBitrate(bitrate); err != nil { if err != nil {
return nil, err logger.Error().Err(err).
Int("peer_bitrate", peerBitrate).
Msg("unable to set video bitrate")
return
}
if !ok {
return
} }
videoID := videoTrack.stream.ID() videoID := videoTrack.stream.ID()
bitrate := videoTrack.stream.Bitrate()
manager.metrics.SetVideoID(session, videoID) manager.metrics.SetVideoID(session, videoID)
manager.logger.Debug().
Int("peer_bitrate", peerBitrate).
Int("video_bitrate", bitrate).
Str("video_id", videoID).
Msg("peer bitrate triggered video stream change")
go session.Send(
event.SIGNAL_VIDEO,
message.SignalVideo{
Video: videoID,
Bitrate: bitrate,
VideoAuto: videoTrack.VideoAuto(),
})
}
changeVideoFromID := func(videoID string) (bitrate int) {
changed, err := videoTrack.SetVideoID(videoID)
if err != nil {
logger.Error().Err(err).
Str("video_id", videoID).
Msg("unable to set video stream")
return
}
if !changed {
return
}
bitrate = videoTrack.stream.Bitrate()
manager.logger.Debug().
Str("video_id", videoID).
Int("video_bitrate", bitrate).
Msg("peer video id triggered video stream change")
go session.Send(
event.SIGNAL_VIDEO,
message.SignalVideo{
Video: videoID,
Bitrate: bitrate,
VideoAuto: videoTrack.VideoAuto(),
})
return
}
manager.logger.Info().
Int("target_bitrate", bitrate).
Msg("estimated initial peer bitrate")
// set initial video bitrate
changeVideoFromBitrate(bitrate)
// use a ticker to get current client target bitrate
go func() {
ticker := time.NewTicker(bitrateCheckInterval)
defer ticker.Stop()
for range ticker.C {
targetBitrate := estimator.GetTargetBitrate()
manager.metrics.SetReceiverEstimatedMaximumBitrate(session, float64(targetBitrate))
if connection.ConnectionState() == webrtc.PeerConnectionStateClosed {
break
}
if !videoTrack.VideoAuto() {
continue
}
changeVideoFromBitrate(targetBitrate)
}
}()
// data channel // data channel
@ -305,24 +423,17 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int)
logger: logger, logger: logger,
connection: connection, connection: connection,
dataChannel: dataChannel, dataChannel: dataChannel,
changeVideo: func(bitrate int) error { changeVideoFromBitrate: changeVideoFromBitrate,
if err := videoTrack.SetBitrate(bitrate); err != nil { changeVideoFromID: changeVideoFromID,
return err
}
videoID := videoTrack.stream.ID()
manager.metrics.SetVideoID(session, videoID)
return nil
},
// TODO: Refactor. // TODO: Refactor.
videoId: func() string { videoId: videoTrack.stream.ID,
return videoTrack.stream.ID()
},
setPaused: func(isPaused bool) { setPaused: func(isPaused bool) {
videoTrack.SetPaused(isPaused) videoTrack.SetPaused(isPaused)
audioTrack.SetPaused(isPaused) audioTrack.SetPaused(isPaused)
}, },
iceTrickle: manager.config.ICETrickle, iceTrickle: manager.config.ICETrickle,
setVideoAuto: videoTrack.SetVideoAuto,
getVideoAuto: videoTrack.VideoAuto,
} }
connection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { connection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
@ -515,11 +626,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int)
}) })
videoTrack.OnRTCP(func(p rtcp.Packet) { videoTrack.OnRTCP(func(p rtcp.Packet) {
switch rtcpPacket := p.(type) { if rtcpPacket, ok := p.(*rtcp.ReceiverReport); ok {
case *rtcp.ReceiverEstimatedMaximumBitrate: // TODO: Deprecated.
manager.metrics.SetReceiverEstimatedMaximumBitrate(session, rtcpPacket.Bitrate)
case *rtcp.ReceiverReport:
l := len(rtcpPacket.Reports) l := len(rtcpPacket.Reports)
if l > 0 { if l > 0 {
// use only last report // use only last report

View File

@ -327,10 +327,10 @@ func (m *metricsCtx) SetVideoID(session types.Session, videoId string) {
} }
} }
func (m *metricsCtx) SetReceiverEstimatedMaximumBitrate(session types.Session, bitrate float32) { func (m *metricsCtx) SetReceiverEstimatedMaximumBitrate(session types.Session, bitrate float64) {
met := m.getBySession(session) met := m.getBySession(session)
met.receiverEstimatedMaximumBitrate.Set(float64(bitrate)) met.receiverEstimatedMaximumBitrate.Set(bitrate)
} }
func (m *metricsCtx) SetReceiverReport(session types.Session, report rtcp.ReceptionReport) { func (m *metricsCtx) SetReceiverReport(session types.Session, report rtcp.ReceptionReport) {

View File

@ -17,9 +17,12 @@ type WebRTCPeerCtx struct {
logger zerolog.Logger logger zerolog.Logger
connection *webrtc.PeerConnection connection *webrtc.PeerConnection
dataChannel *webrtc.DataChannel dataChannel *webrtc.DataChannel
changeVideo func(bitrate int) error changeVideoFromBitrate func(bitrate int)
changeVideoFromID func(id string) int
videoId func() string videoId func() string
setPaused func(isPaused bool) setPaused func(isPaused bool)
setVideoAuto func(auto bool)
getVideoAuto func() bool
iceTrickle bool iceTrickle bool
} }
@ -115,7 +118,7 @@ func (peer *WebRTCPeerCtx) SetCandidate(candidate webrtc.ICECandidateInit) error
return peer.connection.AddICECandidate(candidate) return peer.connection.AddICECandidate(candidate)
} }
func (peer *WebRTCPeerCtx) SetVideoBitrate(bitrate int) error { func (peer *WebRTCPeerCtx) SetVideoBitrate(peerBitrate int) error {
peer.mu.Lock() peer.mu.Lock()
defer peer.mu.Unlock() defer peer.mu.Unlock()
@ -123,12 +126,24 @@ func (peer *WebRTCPeerCtx) SetVideoBitrate(bitrate int) error {
return types.ErrWebRTCConnectionNotFound return types.ErrWebRTCConnectionNotFound
} }
peer.logger.Info().Int("bitrate", bitrate).Msg("change video bitrate") peer.changeVideoFromBitrate(peerBitrate)
return peer.changeVideo(bitrate) return nil
}
func (peer *WebRTCPeerCtx) SetVideoID(videoID string) error {
peer.mu.Lock()
defer peer.mu.Unlock()
if peer.connection == nil {
return types.ErrWebRTCConnectionNotFound
}
peer.changeVideoFromID(videoID)
return nil
} }
// TODO: Refactor. // TODO: Refactor.
func (peer *WebRTCPeerCtx) GetVideoId() string { func (peer *WebRTCPeerCtx) GetVideoID() string {
peer.mu.Lock() peer.mu.Lock()
defer peer.mu.Unlock() defer peer.mu.Unlock()
@ -215,3 +230,11 @@ func (peer *WebRTCPeerCtx) Destroy() {
peer.connection = nil peer.connection = nil
} }
} }
func (peer *WebRTCPeerCtx) SetVideoAuto(auto bool) {
peer.setVideoAuto(auto)
}
func (peer *WebRTCPeerCtx) VideoAuto() bool {
return peer.getVideoAuto()
}

View File

@ -19,6 +19,8 @@ type Track struct {
logger zerolog.Logger logger zerolog.Logger
track *webrtc.TrackLocalStaticSample track *webrtc.TrackLocalStaticSample
paused bool paused bool
videoAuto bool
videoAutoMu sync.RWMutex
listener func(sample types.Sample) listener func(sample types.Sample)
stream types.StreamSinkManager stream types.StreamSinkManager
@ -27,10 +29,19 @@ type Track struct {
onRtcp func(rtcp.Packet) onRtcp func(rtcp.Packet)
onRtcpMu sync.RWMutex onRtcpMu sync.RWMutex
bitrateChange func(int) error bitrateChange func(int) (bool, error)
videoChange func(string) (bool, error)
} }
func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.PeerConnection) (*Track, error) { type option func(*Track)
func WithVideoAuto(auto bool) option {
return func(t *Track) {
t.videoAuto = auto
}
}
func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.PeerConnection, opts ...option) (*Track, error) {
id := codec.Type.String() id := codec.Type.String()
track, err := webrtc.NewTrackLocalStaticSample(codec.Capability, id, "stream") track, err := webrtc.NewTrackLocalStaticSample(codec.Capability, id, "stream")
if err != nil { if err != nil {
@ -44,6 +55,10 @@ func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.Pe
track: track, track: track,
} }
for _, opt := range opts {
opt(t)
}
t.listener = func(sample types.Sample) { t.listener = func(sample types.Sample) {
if t.paused { if t.paused {
return return
@ -96,13 +111,13 @@ func (t *Track) rtcpReader(sender *webrtc.RTPSender) {
} }
} }
func (t *Track) SetStream(stream types.StreamSinkManager) error { func (t *Track) SetStream(stream types.StreamSinkManager) (bool, error) {
t.streamMu.Lock() t.streamMu.Lock()
defer t.streamMu.Unlock() defer t.streamMu.Unlock()
// if we already listen to the stream, do nothing // if we already listen to the stream, do nothing
if t.stream == stream { if t.stream == stream {
return nil return false, nil
} }
var err error var err error
@ -111,12 +126,13 @@ func (t *Track) SetStream(stream types.StreamSinkManager) error {
} else { } else {
err = stream.AddListener(&t.listener) err = stream.AddListener(&t.listener)
} }
if err != nil {
if err == nil { return false, err
t.stream = stream
} }
return err t.stream = stream
return true, nil
} }
func (t *Track) RemoveStream() { func (t *Track) RemoveStream() {
@ -140,14 +156,38 @@ func (t *Track) OnRTCP(f func(rtcp.Packet)) {
t.onRtcp = f t.onRtcp = f
} }
func (t *Track) SetBitrate(bitrate int) error { func (t *Track) SetBitrate(bitrate int) (bool, error) {
if t.bitrateChange == nil { if t.bitrateChange == nil {
return fmt.Errorf("bitrate change not supported") return false, fmt.Errorf("bitrate change not supported")
} }
return t.bitrateChange(bitrate) return t.bitrateChange(bitrate)
} }
func (t *Track) OnBitrateChange(f func(int) error) { func (t *Track) SetVideoID(videoID string) (bool, error) {
if t.videoChange == nil {
return false, fmt.Errorf("video change not supported")
}
return t.videoChange(videoID)
}
func (t *Track) OnBitrateChange(f func(bitrate int) (bool, error)) {
t.bitrateChange = f t.bitrateChange = f
} }
func (t *Track) OnVideoChange(f func(string) (bool, error)) {
t.videoChange = f
}
func (t *Track) SetVideoAuto(auto bool) {
t.videoAutoMu.Lock()
defer t.videoAutoMu.Unlock()
t.videoAuto = auto
}
func (t *Track) VideoAuto() bool {
t.videoAutoMu.RLock()
defer t.videoAutoMu.RUnlock()
return t.videoAuto
}

View File

@ -28,7 +28,7 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
} }
} }
offer, err := h.webrtc.CreatePeer(session, payload.Bitrate) offer, err := h.webrtc.CreatePeer(session, payload.Bitrate, payload.VideoAuto)
if err != nil { if err != nil {
return err return err
} }
@ -39,7 +39,7 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
webrtcPeer.SetPaused(true) webrtcPeer.SetPaused(true)
} }
payload.Video = webrtcPeer.GetVideoId() payload.Video = webrtcPeer.GetVideoID()
} }
session.Send( session.Send(
@ -47,7 +47,9 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
message.SignalProvide{ message.SignalProvide{
SDP: offer.SDP, SDP: offer.SDP,
ICEServers: h.webrtc.ICEServers(), ICEServers: h.webrtc.ICEServers(),
Video: payload.Video, // TODO: Refactor. Video: payload.Video, // TODO: Refactor
Bitrate: payload.Bitrate,
VideoAuto: payload.VideoAuto,
}) })
return nil return nil
@ -64,7 +66,7 @@ func (h *MessageHandlerCtx) signalRestart(session types.Session) error {
return err return err
} }
// TODO: Use offer event intead. // TODO: Use offer event instead.
session.Send( session.Send(
event.SIGNAL_RESTART, event.SIGNAL_RESTART,
message.SignalDescription{ message.SignalDescription{
@ -123,25 +125,17 @@ func (h *MessageHandlerCtx) signalVideo(session types.Session, payload *message.
return errors.New("webRTC peer does not exist") return errors.New("webRTC peer does not exist")
} }
var err error peer.SetVideoAuto(payload.VideoAuto)
if payload.Bitrate == 0 {
// get bitrate from video id
payload.Bitrate, err = h.capture.GetBitrateFromVideoID(payload.Video)
if err != nil {
return err
}
}
if err = peer.SetVideoBitrate(payload.Bitrate); err != nil { if payload.Video != "" {
return err if err := peer.SetVideoID(payload.Video); err != nil {
h.logger.Error().Err(err).Msg("failed to set video id")
}
} else {
if err := peer.SetVideoBitrate(payload.Bitrate); err != nil {
h.logger.Error().Err(err).Msg("failed to set video bitrate")
}
} }
session.Send(
event.SIGNAL_VIDEO,
message.SignalVideo{
Video: peer.GetVideoId(), // TODO: Refactor.
Bitrate: payload.Bitrate,
})
return nil return nil
} }

View File

@ -8,9 +8,8 @@ import (
"strings" "strings"
"github.com/PaesslerAG/gval" "github.com/PaesslerAG/gval"
"github.com/pion/webrtc/v3/pkg/media"
"github.com/demodesk/neko/pkg/types/codec" "github.com/demodesk/neko/pkg/types/codec"
"github.com/pion/webrtc/v3/pkg/media"
) )
var ( var (
@ -20,16 +19,23 @@ var (
type Sample media.Sample type Sample media.Sample
type Receiver interface { type Receiver interface {
SetStream(stream StreamSinkManager) error SetStream(stream StreamSinkManager) (changed bool, err error)
RemoveStream() RemoveStream()
OnBitrateChange(f func(int) error) OnBitrateChange(f func(bitrate int) (changed bool, err error))
OnVideoChange(f func(videoID string) (changed bool, err error))
VideoAuto() bool
SetVideoAuto(videoAuto bool)
} }
type BucketsManager interface { type BucketsManager interface {
IDs() []string IDs() []string
Codec() codec.RTPCodec Codec() codec.RTPCodec
SetReceiver(receiver Receiver) error SetReceiver(receiver Receiver)
RemoveReceiver(receiver Receiver) error RemoveReceiver(receiver Receiver) error
DestroyAll()
RecreateAll() error
Shutdown()
} }
type BroadcastManager interface { type BroadcastManager interface {
@ -48,6 +54,7 @@ type ScreencastManager interface {
type StreamSinkManager interface { type StreamSinkManager interface {
ID() string ID() string
Codec() codec.RTPCodec Codec() codec.RTPCodec
Bitrate() int
AddListener(listener *func(sample Sample)) error AddListener(listener *func(sample Sample)) error
RemoveListener(listener *func(sample Sample)) error RemoveListener(listener *func(sample Sample)) error
@ -55,6 +62,9 @@ type StreamSinkManager interface {
ListenersCount() int ListenersCount() int
Started() bool Started() bool
CreatePipeline() error
DestroyPipeline()
} }
type StreamSrcManager interface { type StreamSrcManager interface {
@ -201,17 +211,33 @@ func (config *VideoConfig) GetBitrateFn(getScreen func() *ScreenSize) func() (in
}), }),
} }
// TOOD: do not read target-bitrate from pipeline, but only from config.
// TODO: This is only for vp8. // TODO: This is only for vp8.
expr, ok := config.GstParams["target-bitrate"] expr, ok := config.GstParams["target-bitrate"]
if !ok { if !ok {
return 0, fmt.Errorf("target-bitrate not found") // TODO: This is only for h264.
expr, ok = config.GstParams["bitrate"]
if !ok {
return 0, fmt.Errorf("bitrate not found")
}
} }
targetBitrate, err := gval.Evaluate(expr, values, language...) bitrate, err := gval.Evaluate(expr, values, language...)
if err != nil { if err != nil {
return 0, err return 0, fmt.Errorf("failed to evaluate bitrate: %w", err)
} }
return targetBitrate.(int), nil var bitrateInt int
switch val := bitrate.(type) {
case int:
bitrateInt = val
case float64:
bitrateInt = (int)(val)
default:
return 0, fmt.Errorf("bitrate is not int or float64")
}
return bitrateInt, nil
} }
} }

View File

@ -48,7 +48,10 @@ type SystemDisconnect struct {
type SignalProvide struct { type SignalProvide struct {
SDP string `json:"sdp"` SDP string `json:"sdp"`
ICEServers []types.ICEServer `json:"iceservers"` ICEServers []types.ICEServer `json:"iceservers"`
Video string `json:"video"` // TODO: Refactor. // TODO: Use SignalVideo struct.
Video string `json:"video"`
Bitrate int `json:"bitrate"`
VideoAuto bool `json:"video_auto"`
} }
type SignalCandidate struct { type SignalCandidate struct {
@ -60,8 +63,9 @@ type SignalDescription struct {
} }
type SignalVideo struct { type SignalVideo struct {
Video string `json:"video"` // TODO: Refactor. Video string `json:"video"`
Bitrate int `json:"bitrate"` Bitrate int `json:"bitrate"`
VideoAuto bool `json:"video_auto"`
} }
///////////////////////////// /////////////////////////////

View File

@ -25,8 +25,11 @@ type WebRTCPeer interface {
SetCandidate(candidate webrtc.ICECandidateInit) error SetCandidate(candidate webrtc.ICECandidateInit) error
SetVideoBitrate(bitrate int) error SetVideoBitrate(bitrate int) error
GetVideoId() string SetVideoID(videoID string) error
GetVideoID() string
SetPaused(isPaused bool) error SetPaused(isPaused bool) error
SetVideoAuto(auto bool)
VideoAuto() bool
SendCursorPosition(x, y int) error SendCursorPosition(x, y int) error
SendCursorImage(cur *CursorImage, img []byte) error SendCursorImage(cur *CursorImage, img []byte) error
@ -40,6 +43,6 @@ type WebRTCManager interface {
ICEServers() []ICEServer ICEServers() []ICEServer
CreatePeer(session Session, bitrate int) (*webrtc.SessionDescription, error) CreatePeer(session Session, bitrate int, videoAuto bool) (*webrtc.SessionDescription, error)
SetCursorPosition(x, y int) SetCursorPosition(x, y int)
} }