WebRTC congestion control (#26)

* Add congestion control

* Improve stream matching, add manual stream selection, add metrics

* Use a ticker for bitrate estimation and make bandwidth drops switch to lower streams more aggressively

* Missing signal response, fix video auto bug

* Remove redundant mutex

* Bitrate history queue

* Get bitrate fn support h264 & float64

---------

Co-authored-by: Aleksandar Sukovic <aleksandar.sukovic@gmail.com>
This commit is contained in:
Miroslav Šedivý 2023-02-06 19:45:51 +01:00 committed by GitHub
parent e80ae8019e
commit 2364facd60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 738 additions and 222 deletions

View File

@ -1,105 +0,0 @@
package capture
import (
"errors"
"fmt"
"math"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
)
type BucketsManagerCtx struct {
logger zerolog.Logger
codec codec.RTPCodec
streams map[string]*StreamSinkManagerCtx
streamIDs []string
}
func bucketsNew(codec codec.RTPCodec, streams map[string]*StreamSinkManagerCtx, streamIDs []string) *BucketsManagerCtx {
logger := log.With().
Str("module", "capture").
Str("submodule", "buckets").
Logger()
return &BucketsManagerCtx{
logger: logger,
codec: codec,
streams: streams,
streamIDs: streamIDs,
}
}
func (m *BucketsManagerCtx) shutdown() {
m.logger.Info().Msgf("shutdown")
}
func (m *BucketsManagerCtx) destroyAll() {
for _, stream := range m.streams {
if stream.Started() {
stream.destroyPipeline()
}
}
}
func (m *BucketsManagerCtx) recreateAll() error {
for _, stream := range m.streams {
if stream.Started() {
err := stream.createPipeline()
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
return err
}
}
}
return nil
}
func (m *BucketsManagerCtx) IDs() []string {
return m.streamIDs
}
func (m *BucketsManagerCtx) Codec() codec.RTPCodec {
return m.codec
}
func (m *BucketsManagerCtx) SetReceiver(receiver types.Receiver) error {
receiver.OnBitrateChange(func(bitrate int) error {
stream, ok := m.findNearestStream(bitrate)
if !ok {
return fmt.Errorf("no stream found for bitrate %d", bitrate)
}
return receiver.SetStream(stream)
})
return nil
}
func (m *BucketsManagerCtx) findNearestStream(bitrate int) (ss *StreamSinkManagerCtx, ok bool) {
minDiff := math.MaxInt
for _, s := range m.streams {
streamBitrate, err := s.Bitrate()
if err != nil {
m.logger.Error().Err(err).Msgf("failed to get bitrate for stream %s", s.ID())
continue
}
diffAbs := int(math.Abs(float64(bitrate - streamBitrate)))
if diffAbs < minDiff {
minDiff, ss = diffAbs, s
}
}
ok = ss != nil
return
}
func (m *BucketsManagerCtx) RemoveReceiver(receiver types.Receiver) error {
receiver.OnBitrateChange(nil)
receiver.RemoveStream()
return nil
}

View File

@ -0,0 +1,145 @@
package buckets
import (
"errors"
"sort"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
)
type BucketsManagerCtx struct {
logger zerolog.Logger
codec codec.RTPCodec
streams map[string]types.StreamSinkManager
streamIDs []string
}
func BucketsNew(codec codec.RTPCodec, streams map[string]types.StreamSinkManager, streamIDs []string) *BucketsManagerCtx {
logger := log.With().
Str("module", "capture").
Str("submodule", "buckets").
Logger()
return &BucketsManagerCtx{
logger: logger,
codec: codec,
streams: streams,
streamIDs: streamIDs,
}
}
func (manager *BucketsManagerCtx) Shutdown() {
manager.logger.Info().Msgf("shutdown")
manager.DestroyAll()
}
func (manager *BucketsManagerCtx) DestroyAll() {
for _, stream := range manager.streams {
if stream.Started() {
stream.DestroyPipeline()
}
}
}
func (manager *BucketsManagerCtx) RecreateAll() error {
for _, stream := range manager.streams {
if stream.Started() {
err := stream.CreatePipeline()
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
return err
}
}
}
return nil
}
func (manager *BucketsManagerCtx) IDs() []string {
return manager.streamIDs
}
func (manager *BucketsManagerCtx) Codec() codec.RTPCodec {
return manager.codec
}
func (manager *BucketsManagerCtx) SetReceiver(receiver types.Receiver) {
// bitrate history is per receiver
bitrateHistory := &queue{}
receiver.OnBitrateChange(func(peerBitrate int) (bool, error) {
bitrate := peerBitrate
if receiver.VideoAuto() {
bitrate = bitrateHistory.normaliseBitrate(bitrate)
}
stream := manager.findNearestStream(bitrate)
streamID := stream.ID()
// TODO: make this less noisy in logs
manager.logger.Debug().
Str("video_id", streamID).
Int("len", bitrateHistory.len()).
Int("peer_bitrate", peerBitrate).
Int("bitrate", bitrate).
Msg("change video bitrate")
return receiver.SetStream(stream)
})
receiver.OnVideoChange(func(videoID string) (bool, error) {
stream := manager.streams[videoID]
manager.logger.Info().
Str("video_id", videoID).
Msg("video change")
return receiver.SetStream(stream)
})
}
func (manager *BucketsManagerCtx) findNearestStream(peerBitrate int) types.StreamSinkManager {
type streamDiff struct {
id string
bitrateDiff int
}
sortDiff := func(a, b int) bool {
switch {
case a < 0 && b < 0:
return a > b
case a >= 0:
if b >= 0 {
return a <= b
}
return true
}
return false
}
var diffs []streamDiff
for _, stream := range manager.streams {
diffs = append(diffs, streamDiff{
id: stream.ID(),
bitrateDiff: peerBitrate - stream.Bitrate(),
})
}
sort.Slice(diffs, func(i, j int) bool {
return sortDiff(diffs[i].bitrateDiff, diffs[j].bitrateDiff)
})
bestDiff := diffs[0]
return manager.streams[bestDiff.id]
}
func (manager *BucketsManagerCtx) RemoveReceiver(receiver types.Receiver) error {
receiver.OnBitrateChange(nil)
receiver.OnVideoChange(nil)
receiver.RemoveStream()
return nil
}

View File

@ -0,0 +1,83 @@
package buckets
import (
"reflect"
"testing"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
)
func TestBucketsManagerCtx_FindNearestStream(t *testing.T) {
type fields struct {
codec codec.RTPCodec
streams map[string]types.StreamSinkManager
}
type args struct {
peerBitrate int
}
tests := []struct {
name string
fields fields
args args
want types.StreamSinkManager
}{
{
name: "findNearestStream",
fields: fields{
streams: map[string]types.StreamSinkManager{
"1": mockStreamSink{
id: "1",
bitrate: 500,
},
"2": mockStreamSink{
id: "2",
bitrate: 750,
},
"3": mockStreamSink{
id: "3",
bitrate: 1000,
},
"4": mockStreamSink{
id: "4",
bitrate: 1250,
},
"5": mockStreamSink{
id: "5",
bitrate: 1700,
},
},
},
args: args{
peerBitrate: 950,
},
want: mockStreamSink{
id: "2",
bitrate: 750,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := BucketsNew(tt.fields.codec, tt.fields.streams, []string{})
if got := m.findNearestStream(tt.args.peerBitrate); !reflect.DeepEqual(got, tt.want) {
t.Errorf("findNearestStream() = %v, want %v", got, tt.want)
}
})
}
}
type mockStreamSink struct {
id string
bitrate int
types.StreamSinkManager
}
func (m mockStreamSink) ID() string {
return m.id
}
func (m mockStreamSink) Bitrate() int {
return m.bitrate
}

View File

@ -0,0 +1,88 @@
package buckets
import (
"math"
"sync"
"time"
)
type queue struct {
sync.Mutex
q []elem
}
type elem struct {
created time.Time
bitrate int
}
func (q *queue) push(v elem) {
q.Lock()
defer q.Unlock()
// if the first element is older than 10 seconds, remove it
if len(q.q) > 0 && time.Since(q.q[0].created) > 10*time.Second {
q.q = q.q[1:]
}
q.q = append(q.q, v)
}
func (q *queue) len() int {
q.Lock()
defer q.Unlock()
return len(q.q)
}
func (q *queue) avg() int {
q.Lock()
defer q.Unlock()
if len(q.q) == 0 {
return 0
}
sum := 0
for _, v := range q.q {
sum += v.bitrate
}
return sum / len(q.q)
}
func (q *queue) avgLastN(n int) int {
if n <= 0 {
return q.avg()
}
q.Lock()
defer q.Unlock()
if len(q.q) == 0 {
return 0
}
sum := 0
for _, v := range q.q[len(q.q)-n:] {
sum += v.bitrate
}
return sum / n
}
func (q *queue) normaliseBitrate(currentBitrate int) int {
avgBitrate := float64(q.avg())
histLen := float64(q.len())
q.push(elem{
bitrate: currentBitrate,
created: time.Now(),
})
if avgBitrate == 0 || histLen == 0 || currentBitrate == 0 {
return currentBitrate
}
lastN := int(math.Floor(float64(currentBitrate) / avgBitrate * histLen))
if lastN > q.len() {
lastN = q.len()
}
if lastN == 0 {
return currentBitrate
}
return q.avgLastN(lastN)
}

View File

@ -0,0 +1,99 @@
package buckets
import "testing"
func Queue_normaliseBitrate(t *testing.T) {
type fields struct {
queue *queue
}
type args struct {
currentBitrate int
}
tests := []struct {
name string
fields fields
args args
want []int
}{
{
name: "normaliseBitrate: big drop",
fields: fields{
queue: &queue{
q: []elem{
{bitrate: 900},
{bitrate: 750},
{bitrate: 780},
{bitrate: 1100},
{bitrate: 950},
{bitrate: 700},
{bitrate: 800},
{bitrate: 900},
{bitrate: 1000},
{bitrate: 1100},
// avg = 898
},
},
},
args: args{
currentBitrate: 350,
},
want: []int{816, 700, 537, 350, 350},
}, {
name: "normaliseBitrate: small drop",
fields: fields{
queue: &queue{
q: []elem{
{bitrate: 900},
{bitrate: 750},
{bitrate: 780},
{bitrate: 1100},
{bitrate: 950},
{bitrate: 700},
{bitrate: 800},
{bitrate: 900},
{bitrate: 1000},
{bitrate: 1100},
// avg = 898
},
},
},
args: args{
currentBitrate: 700,
},
want: []int{878, 842, 825, 825, 812, 787, 750, 700},
}, {
name: "normaliseBitrate",
fields: fields{
queue: &queue{
q: []elem{
{bitrate: 900},
{bitrate: 750},
{bitrate: 780},
{bitrate: 1100},
{bitrate: 950},
{bitrate: 700},
{bitrate: 800},
{bitrate: 900},
{bitrate: 1000},
{bitrate: 1100},
// avg = 898
},
},
},
args: args{
currentBitrate: 1350,
},
want: []int{943, 1003, 1060, 1085},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := tt.fields.queue
for i := 0; i < len(tt.want); i++ {
if got := m.normaliseBitrate(tt.args.currentBitrate); got != tt.want[i] {
t.Errorf("normaliseBitrate() [%d] = %v, want %v", i, got, tt.want[i])
}
}
})
}
}

View File

@ -8,6 +8,7 @@ import (
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/internal/capture/buckets"
"github.com/demodesk/neko/internal/config"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
@ -22,7 +23,7 @@ type CaptureManagerCtx struct {
broadcast *BroacastManagerCtx
screencast *ScreencastManagerCtx
audio *StreamSinkManagerCtx
video *BucketsManagerCtx
video types.BucketsManager
// sources
webcam *StreamSrcManagerCtx
@ -32,7 +33,7 @@ type CaptureManagerCtx struct {
func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCtx {
logger := log.With().Str("module", "capture").Logger()
videos := map[string]*StreamSinkManagerCtx{}
videos := map[string]types.StreamSinkManager{}
for video_id, cnf := range config.VideoPipelines {
pipelineConf := cnf
@ -68,9 +69,10 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt
Msg("syntax check for video stream pipeline passed")
getVideoBitrate := pipelineConf.GetBitrateFn(desktop.GetScreenSize)
if err != nil {
if _, err = getVideoBitrate(); err != nil {
logger.Panic().Err(err).Msg("unable to get video bitrate")
}
// append to videos
videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id, getVideoBitrate)
}
@ -139,7 +141,7 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt
"! appsink name=appsink", config.AudioDevice, config.AudioCodec.Pipeline,
), nil
}, "audio", nil),
video: bucketsNew(config.VideoCodec, videos, config.VideoIDs),
video: buckets.BucketsNew(config.VideoCodec, videos, config.VideoIDs),
// sources
webcam: streamSrcNew(config.WebcamEnabled, map[string]string{
@ -200,7 +202,7 @@ func (manager *CaptureManagerCtx) Start() {
}
manager.desktop.OnBeforeScreenSizeChange(func() {
manager.video.destroyAll()
manager.video.DestroyAll()
if manager.broadcast.Started() {
manager.broadcast.destroyPipeline()
@ -212,7 +214,7 @@ func (manager *CaptureManagerCtx) Start() {
})
manager.desktop.OnAfterScreenSizeChange(func() {
err := manager.video.recreateAll()
err := manager.video.RecreateAll()
if err != nil {
manager.logger.Panic().Err(err).Msg("unable to recreate video pipelines")
}
@ -240,7 +242,7 @@ func (manager *CaptureManagerCtx) Shutdown() error {
manager.screencast.shutdown()
manager.audio.shutdown()
manager.video.shutdown()
manager.video.Shutdown()
manager.webcam.shutdown()
manager.microphone.shutdown()

View File

@ -105,7 +105,7 @@ func (manager *StreamSinkManagerCtx) shutdown() {
}
manager.listenersMu.Unlock()
manager.destroyPipeline()
manager.DestroyPipeline()
manager.wg.Wait()
}
@ -113,12 +113,19 @@ func (manager *StreamSinkManagerCtx) ID() string {
return manager.id
}
func (manager *StreamSinkManagerCtx) Bitrate() (int, error) {
func (manager *StreamSinkManagerCtx) Bitrate() int {
if manager.getBitrate == nil {
return 0, nil
return 0
}
// recalculate bitrate every time, take screen resolution (and fps) into account
return manager.getBitrate()
// we called this function during startup, so it shouldn't error here
bitrate, err := manager.getBitrate()
if err != nil {
manager.logger.Err(err).Msg("unexpected error while getting bitrate")
}
return bitrate
}
func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec {
@ -127,7 +134,7 @@ func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec {
func (manager *StreamSinkManagerCtx) start() error {
if len(manager.listeners) == 0 {
err := manager.createPipeline()
err := manager.CreatePipeline()
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
return err
}
@ -140,7 +147,7 @@ func (manager *StreamSinkManagerCtx) start() error {
func (manager *StreamSinkManagerCtx) stop() {
if len(manager.listeners) == 0 {
manager.destroyPipeline()
manager.DestroyPipeline()
manager.logger.Info().Msgf("last listener, stopping")
}
}
@ -259,7 +266,7 @@ func (manager *StreamSinkManagerCtx) Started() bool {
return manager.ListenersCount() > 0
}
func (manager *StreamSinkManagerCtx) createPipeline() error {
func (manager *StreamSinkManagerCtx) CreatePipeline() error {
manager.pipelineMu.Lock()
defer manager.pipelineMu.Unlock()
@ -313,7 +320,7 @@ func (manager *StreamSinkManagerCtx) createPipeline() error {
return nil
}
func (manager *StreamSinkManagerCtx) destroyPipeline() {
func (manager *StreamSinkManagerCtx) DestroyPipeline() {
manager.pipelineMu.Lock()
defer manager.pipelineMu.Unlock()

View File

@ -9,6 +9,8 @@ import (
"github.com/pion/ice/v2"
"github.com/pion/interceptor"
"github.com/pion/interceptor/pkg/cc"
"github.com/pion/interceptor/pkg/gcc"
"github.com/pion/rtcp"
"github.com/pion/webrtc/v3"
"github.com/rs/zerolog"
@ -35,6 +37,9 @@ const keepAliveInterval = 2 * time.Second
// send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval
const rtcpPLIInterval = 3 * time.Second
// how often we check the bitrate of each client. Default is 250ms
const bitrateCheckInterval = 250 * time.Millisecond
func New(desktop types.DesktopManager, capture types.CaptureManager, config *config.WebRTC) *WebRTCManagerCtx {
configuration := webrtc.Configuration{
SDPSemantics: webrtc.SDPSemanticsUnifiedPlanWithFallback,
@ -153,12 +158,12 @@ func (manager *WebRTCManagerCtx) ICEServers() []types.ICEServer {
return manager.config.ICEServers
}
func (manager *WebRTCManagerCtx) newPeerConnection(codecs []codec.RTPCodec, logger zerolog.Logger) (*webrtc.PeerConnection, error) {
func (manager *WebRTCManagerCtx) newPeerConnection(bitrate int, codecs []codec.RTPCodec, logger zerolog.Logger) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) {
// create media engine
engine := &webrtc.MediaEngine{}
for _, codec := range codecs {
if err := codec.Register(engine); err != nil {
return nil, err
return nil, nil, err
}
}
@ -205,8 +210,29 @@ func (manager *WebRTCManagerCtx) newPeerConnection(codecs []codec.RTPCodec, logg
// create interceptor registry
registry := &interceptor.Registry{}
congestionController, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) {
if bitrate == 0 {
bitrate = 1000000
}
return gcc.NewSendSideBWE(gcc.SendSideBWEInitialBitrate(bitrate))
})
if err != nil {
return nil, nil, err
}
estimatorChan := make(chan cc.BandwidthEstimator, 1)
congestionController.OnNewPeerConnection(func(id string, estimator cc.BandwidthEstimator) {
estimatorChan <- estimator
})
registry.Add(congestionController)
if err = webrtc.ConfigureTWCCHeaderExtensionSender(engine, registry); err != nil {
return nil, nil, err
}
if err := webrtc.RegisterDefaultInterceptors(engine, registry); err != nil {
return nil, err
return nil, nil, err
}
// create new API
@ -217,10 +243,12 @@ func (manager *WebRTCManagerCtx) newPeerConnection(codecs []codec.RTPCodec, logg
)
// create new peer connection
return api.NewPeerConnection(manager.webrtcConfiguration)
configuration := manager.webrtcConfiguration
connection, err := api.NewPeerConnection(configuration)
return connection, <-estimatorChan, err
}
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int) (*webrtc.SessionDescription, error) {
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int, videoAuto bool) (*webrtc.SessionDescription, error) {
id := atomic.AddInt32(&manager.peerId, 1)
manager.metrics.NewConnection(session)
@ -236,7 +264,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int)
video := manager.capture.Video()
videoCodec := video.Codec()
connection, err := manager.newPeerConnection([]codec.RTPCodec{
connection, estimator, err := manager.newPeerConnection(bitrate, []codec.RTPCodec{
audioCodec,
videoCodec,
}, logger)
@ -244,6 +272,10 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int)
return nil, err
}
if bitrate == 0 {
bitrate = estimator.GetTargetBitrate()
}
// asynchronously send local ICE Candidates
if manager.config.ICETrickle {
connection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
@ -268,31 +300,117 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int)
}
// set stream for audio track
err = audioTrack.SetStream(audio)
_, err = audioTrack.SetStream(audio)
if err != nil {
return nil, err
}
// video track
videoTrack, err := NewTrack(logger, videoCodec, connection)
videoTrack, err := NewTrack(logger, videoCodec, connection, WithVideoAuto(videoAuto))
if err != nil {
return nil, err
}
// let video stream bucket manager handle stream subscriptions
err = video.SetReceiver(videoTrack)
if err != nil {
return nil, err
video.SetReceiver(videoTrack)
changeVideoFromBitrate := func(peerBitrate int) {
// when switching from manual to auto bitrate estimation, in case the estimator is
// idle (lastBitrate > maxBitrate), we want to go back to the previous estimated bitrate
if peerBitrate == 0 {
peerBitrate = estimator.GetTargetBitrate()
manager.logger.Debug().
Int("peer_bitrate", peerBitrate).
Msg("evaluated bitrate")
}
// set initial video bitrate
if err = videoTrack.SetBitrate(bitrate); err != nil {
return nil, err
ok, err := videoTrack.SetBitrate(peerBitrate)
if err != nil {
logger.Error().Err(err).
Int("peer_bitrate", peerBitrate).
Msg("unable to set video bitrate")
return
}
if !ok {
return
}
videoID := videoTrack.stream.ID()
bitrate := videoTrack.stream.Bitrate()
manager.metrics.SetVideoID(session, videoID)
manager.logger.Debug().
Int("peer_bitrate", peerBitrate).
Int("video_bitrate", bitrate).
Str("video_id", videoID).
Msg("peer bitrate triggered video stream change")
go session.Send(
event.SIGNAL_VIDEO,
message.SignalVideo{
Video: videoID,
Bitrate: bitrate,
VideoAuto: videoTrack.VideoAuto(),
})
}
changeVideoFromID := func(videoID string) (bitrate int) {
changed, err := videoTrack.SetVideoID(videoID)
if err != nil {
logger.Error().Err(err).
Str("video_id", videoID).
Msg("unable to set video stream")
return
}
if !changed {
return
}
bitrate = videoTrack.stream.Bitrate()
manager.logger.Debug().
Str("video_id", videoID).
Int("video_bitrate", bitrate).
Msg("peer video id triggered video stream change")
go session.Send(
event.SIGNAL_VIDEO,
message.SignalVideo{
Video: videoID,
Bitrate: bitrate,
VideoAuto: videoTrack.VideoAuto(),
})
return
}
manager.logger.Info().
Int("target_bitrate", bitrate).
Msg("estimated initial peer bitrate")
// set initial video bitrate
changeVideoFromBitrate(bitrate)
// use a ticker to get current client target bitrate
go func() {
ticker := time.NewTicker(bitrateCheckInterval)
defer ticker.Stop()
for range ticker.C {
targetBitrate := estimator.GetTargetBitrate()
manager.metrics.SetReceiverEstimatedMaximumBitrate(session, float64(targetBitrate))
if connection.ConnectionState() == webrtc.PeerConnectionStateClosed {
break
}
if !videoTrack.VideoAuto() {
continue
}
changeVideoFromBitrate(targetBitrate)
}
}()
// data channel
@ -305,24 +423,17 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int)
logger: logger,
connection: connection,
dataChannel: dataChannel,
changeVideo: func(bitrate int) error {
if err := videoTrack.SetBitrate(bitrate); err != nil {
return err
}
videoID := videoTrack.stream.ID()
manager.metrics.SetVideoID(session, videoID)
return nil
},
changeVideoFromBitrate: changeVideoFromBitrate,
changeVideoFromID: changeVideoFromID,
// TODO: Refactor.
videoId: func() string {
return videoTrack.stream.ID()
},
videoId: videoTrack.stream.ID,
setPaused: func(isPaused bool) {
videoTrack.SetPaused(isPaused)
audioTrack.SetPaused(isPaused)
},
iceTrickle: manager.config.ICETrickle,
setVideoAuto: videoTrack.SetVideoAuto,
getVideoAuto: videoTrack.VideoAuto,
}
connection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
@ -515,11 +626,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int)
})
videoTrack.OnRTCP(func(p rtcp.Packet) {
switch rtcpPacket := p.(type) {
case *rtcp.ReceiverEstimatedMaximumBitrate: // TODO: Deprecated.
manager.metrics.SetReceiverEstimatedMaximumBitrate(session, rtcpPacket.Bitrate)
case *rtcp.ReceiverReport:
if rtcpPacket, ok := p.(*rtcp.ReceiverReport); ok {
l := len(rtcpPacket.Reports)
if l > 0 {
// use only last report

View File

@ -327,10 +327,10 @@ func (m *metricsCtx) SetVideoID(session types.Session, videoId string) {
}
}
func (m *metricsCtx) SetReceiverEstimatedMaximumBitrate(session types.Session, bitrate float32) {
func (m *metricsCtx) SetReceiverEstimatedMaximumBitrate(session types.Session, bitrate float64) {
met := m.getBySession(session)
met.receiverEstimatedMaximumBitrate.Set(float64(bitrate))
met.receiverEstimatedMaximumBitrate.Set(bitrate)
}
func (m *metricsCtx) SetReceiverReport(session types.Session, report rtcp.ReceptionReport) {

View File

@ -17,9 +17,12 @@ type WebRTCPeerCtx struct {
logger zerolog.Logger
connection *webrtc.PeerConnection
dataChannel *webrtc.DataChannel
changeVideo func(bitrate int) error
changeVideoFromBitrate func(bitrate int)
changeVideoFromID func(id string) int
videoId func() string
setPaused func(isPaused bool)
setVideoAuto func(auto bool)
getVideoAuto func() bool
iceTrickle bool
}
@ -115,7 +118,7 @@ func (peer *WebRTCPeerCtx) SetCandidate(candidate webrtc.ICECandidateInit) error
return peer.connection.AddICECandidate(candidate)
}
func (peer *WebRTCPeerCtx) SetVideoBitrate(bitrate int) error {
func (peer *WebRTCPeerCtx) SetVideoBitrate(peerBitrate int) error {
peer.mu.Lock()
defer peer.mu.Unlock()
@ -123,12 +126,24 @@ func (peer *WebRTCPeerCtx) SetVideoBitrate(bitrate int) error {
return types.ErrWebRTCConnectionNotFound
}
peer.logger.Info().Int("bitrate", bitrate).Msg("change video bitrate")
return peer.changeVideo(bitrate)
peer.changeVideoFromBitrate(peerBitrate)
return nil
}
func (peer *WebRTCPeerCtx) SetVideoID(videoID string) error {
peer.mu.Lock()
defer peer.mu.Unlock()
if peer.connection == nil {
return types.ErrWebRTCConnectionNotFound
}
peer.changeVideoFromID(videoID)
return nil
}
// TODO: Refactor.
func (peer *WebRTCPeerCtx) GetVideoId() string {
func (peer *WebRTCPeerCtx) GetVideoID() string {
peer.mu.Lock()
defer peer.mu.Unlock()
@ -215,3 +230,11 @@ func (peer *WebRTCPeerCtx) Destroy() {
peer.connection = nil
}
}
func (peer *WebRTCPeerCtx) SetVideoAuto(auto bool) {
peer.setVideoAuto(auto)
}
func (peer *WebRTCPeerCtx) VideoAuto() bool {
return peer.getVideoAuto()
}

View File

@ -19,6 +19,8 @@ type Track struct {
logger zerolog.Logger
track *webrtc.TrackLocalStaticSample
paused bool
videoAuto bool
videoAutoMu sync.RWMutex
listener func(sample types.Sample)
stream types.StreamSinkManager
@ -27,10 +29,19 @@ type Track struct {
onRtcp func(rtcp.Packet)
onRtcpMu sync.RWMutex
bitrateChange func(int) error
bitrateChange func(int) (bool, error)
videoChange func(string) (bool, error)
}
func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.PeerConnection) (*Track, error) {
type option func(*Track)
func WithVideoAuto(auto bool) option {
return func(t *Track) {
t.videoAuto = auto
}
}
func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.PeerConnection, opts ...option) (*Track, error) {
id := codec.Type.String()
track, err := webrtc.NewTrackLocalStaticSample(codec.Capability, id, "stream")
if err != nil {
@ -44,6 +55,10 @@ func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.Pe
track: track,
}
for _, opt := range opts {
opt(t)
}
t.listener = func(sample types.Sample) {
if t.paused {
return
@ -96,13 +111,13 @@ func (t *Track) rtcpReader(sender *webrtc.RTPSender) {
}
}
func (t *Track) SetStream(stream types.StreamSinkManager) error {
func (t *Track) SetStream(stream types.StreamSinkManager) (bool, error) {
t.streamMu.Lock()
defer t.streamMu.Unlock()
// if we already listen to the stream, do nothing
if t.stream == stream {
return nil
return false, nil
}
var err error
@ -111,12 +126,13 @@ func (t *Track) SetStream(stream types.StreamSinkManager) error {
} else {
err = stream.AddListener(&t.listener)
}
if err == nil {
t.stream = stream
if err != nil {
return false, err
}
return err
t.stream = stream
return true, nil
}
func (t *Track) RemoveStream() {
@ -140,14 +156,38 @@ func (t *Track) OnRTCP(f func(rtcp.Packet)) {
t.onRtcp = f
}
func (t *Track) SetBitrate(bitrate int) error {
func (t *Track) SetBitrate(bitrate int) (bool, error) {
if t.bitrateChange == nil {
return fmt.Errorf("bitrate change not supported")
return false, fmt.Errorf("bitrate change not supported")
}
return t.bitrateChange(bitrate)
}
func (t *Track) OnBitrateChange(f func(int) error) {
func (t *Track) SetVideoID(videoID string) (bool, error) {
if t.videoChange == nil {
return false, fmt.Errorf("video change not supported")
}
return t.videoChange(videoID)
}
func (t *Track) OnBitrateChange(f func(bitrate int) (bool, error)) {
t.bitrateChange = f
}
func (t *Track) OnVideoChange(f func(string) (bool, error)) {
t.videoChange = f
}
func (t *Track) SetVideoAuto(auto bool) {
t.videoAutoMu.Lock()
defer t.videoAutoMu.Unlock()
t.videoAuto = auto
}
func (t *Track) VideoAuto() bool {
t.videoAutoMu.RLock()
defer t.videoAutoMu.RUnlock()
return t.videoAuto
}

View File

@ -28,7 +28,7 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
}
}
offer, err := h.webrtc.CreatePeer(session, payload.Bitrate)
offer, err := h.webrtc.CreatePeer(session, payload.Bitrate, payload.VideoAuto)
if err != nil {
return err
}
@ -39,7 +39,7 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
webrtcPeer.SetPaused(true)
}
payload.Video = webrtcPeer.GetVideoId()
payload.Video = webrtcPeer.GetVideoID()
}
session.Send(
@ -47,7 +47,9 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
message.SignalProvide{
SDP: offer.SDP,
ICEServers: h.webrtc.ICEServers(),
Video: payload.Video, // TODO: Refactor.
Video: payload.Video, // TODO: Refactor
Bitrate: payload.Bitrate,
VideoAuto: payload.VideoAuto,
})
return nil
@ -64,7 +66,7 @@ func (h *MessageHandlerCtx) signalRestart(session types.Session) error {
return err
}
// TODO: Use offer event intead.
// TODO: Use offer event instead.
session.Send(
event.SIGNAL_RESTART,
message.SignalDescription{
@ -123,25 +125,17 @@ func (h *MessageHandlerCtx) signalVideo(session types.Session, payload *message.
return errors.New("webRTC peer does not exist")
}
var err error
if payload.Bitrate == 0 {
// get bitrate from video id
payload.Bitrate, err = h.capture.GetBitrateFromVideoID(payload.Video)
if err != nil {
return err
}
}
peer.SetVideoAuto(payload.VideoAuto)
if err = peer.SetVideoBitrate(payload.Bitrate); err != nil {
return err
if payload.Video != "" {
if err := peer.SetVideoID(payload.Video); err != nil {
h.logger.Error().Err(err).Msg("failed to set video id")
}
} else {
if err := peer.SetVideoBitrate(payload.Bitrate); err != nil {
h.logger.Error().Err(err).Msg("failed to set video bitrate")
}
}
session.Send(
event.SIGNAL_VIDEO,
message.SignalVideo{
Video: peer.GetVideoId(), // TODO: Refactor.
Bitrate: payload.Bitrate,
})
return nil
}

View File

@ -8,9 +8,8 @@ import (
"strings"
"github.com/PaesslerAG/gval"
"github.com/pion/webrtc/v3/pkg/media"
"github.com/demodesk/neko/pkg/types/codec"
"github.com/pion/webrtc/v3/pkg/media"
)
var (
@ -20,16 +19,23 @@ var (
type Sample media.Sample
type Receiver interface {
SetStream(stream StreamSinkManager) error
SetStream(stream StreamSinkManager) (changed bool, err error)
RemoveStream()
OnBitrateChange(f func(int) error)
OnBitrateChange(f func(bitrate int) (changed bool, err error))
OnVideoChange(f func(videoID string) (changed bool, err error))
VideoAuto() bool
SetVideoAuto(videoAuto bool)
}
type BucketsManager interface {
IDs() []string
Codec() codec.RTPCodec
SetReceiver(receiver Receiver) error
SetReceiver(receiver Receiver)
RemoveReceiver(receiver Receiver) error
DestroyAll()
RecreateAll() error
Shutdown()
}
type BroadcastManager interface {
@ -48,6 +54,7 @@ type ScreencastManager interface {
type StreamSinkManager interface {
ID() string
Codec() codec.RTPCodec
Bitrate() int
AddListener(listener *func(sample Sample)) error
RemoveListener(listener *func(sample Sample)) error
@ -55,6 +62,9 @@ type StreamSinkManager interface {
ListenersCount() int
Started() bool
CreatePipeline() error
DestroyPipeline()
}
type StreamSrcManager interface {
@ -201,17 +211,33 @@ func (config *VideoConfig) GetBitrateFn(getScreen func() *ScreenSize) func() (in
}),
}
// TOOD: do not read target-bitrate from pipeline, but only from config.
// TODO: This is only for vp8.
expr, ok := config.GstParams["target-bitrate"]
if !ok {
return 0, fmt.Errorf("target-bitrate not found")
// TODO: This is only for h264.
expr, ok = config.GstParams["bitrate"]
if !ok {
return 0, fmt.Errorf("bitrate not found")
}
}
targetBitrate, err := gval.Evaluate(expr, values, language...)
bitrate, err := gval.Evaluate(expr, values, language...)
if err != nil {
return 0, err
return 0, fmt.Errorf("failed to evaluate bitrate: %w", err)
}
return targetBitrate.(int), nil
var bitrateInt int
switch val := bitrate.(type) {
case int:
bitrateInt = val
case float64:
bitrateInt = (int)(val)
default:
return 0, fmt.Errorf("bitrate is not int or float64")
}
return bitrateInt, nil
}
}

View File

@ -48,7 +48,10 @@ type SystemDisconnect struct {
type SignalProvide struct {
SDP string `json:"sdp"`
ICEServers []types.ICEServer `json:"iceservers"`
Video string `json:"video"` // TODO: Refactor.
// TODO: Use SignalVideo struct.
Video string `json:"video"`
Bitrate int `json:"bitrate"`
VideoAuto bool `json:"video_auto"`
}
type SignalCandidate struct {
@ -60,8 +63,9 @@ type SignalDescription struct {
}
type SignalVideo struct {
Video string `json:"video"` // TODO: Refactor.
Video string `json:"video"`
Bitrate int `json:"bitrate"`
VideoAuto bool `json:"video_auto"`
}
/////////////////////////////

View File

@ -25,8 +25,11 @@ type WebRTCPeer interface {
SetCandidate(candidate webrtc.ICECandidateInit) error
SetVideoBitrate(bitrate int) error
GetVideoId() string
SetVideoID(videoID string) error
GetVideoID() string
SetPaused(isPaused bool) error
SetVideoAuto(auto bool)
VideoAuto() bool
SendCursorPosition(x, y int) error
SendCursorImage(cur *CursorImage, img []byte) error
@ -40,6 +43,6 @@ type WebRTCManager interface {
ICEServers() []ICEServer
CreatePeer(session Session, bitrate int) (*webrtc.SessionDescription, error)
CreatePeer(session Session, bitrate int, videoAuto bool) (*webrtc.SessionDescription, error)
SetCursorPosition(x, y int)
}