Bandwidth estimator refactor (#46)

* rewrite to use stream selector.

* WIP.

* add nacks to metrics.

* add estimate trend.

* estimator based on trend detector.

* add estimator unstable duration.

* add estimator debug.

* add stalled duration.

* estimator move values to config.

* change default estimator values.

* minor style changes.

* fix websocket video messages.

* replace video track with ivdeo id.
This commit is contained in:
Miroslav Šedivý
2023-05-15 19:29:39 +02:00
committed by GitHub
parent 8660c1a256
commit 3e8d686c0f
17 changed files with 845 additions and 775 deletions

View File

@ -31,26 +31,6 @@ type SampleListener interface {
WriteSample(Sample)
}
type Receiver interface {
SetStream(stream StreamSinkManager) (changed bool, err error)
RemoveStream()
OnBitrateChange(f func(bitrate int) (changed bool, err error))
OnVideoChange(f func(videoID string) (changed bool, err error))
VideoAuto() bool
SetVideoAuto(videoAuto bool)
}
type BucketsManager interface {
IDs() []string
Codec() codec.RTPCodec
SetReceiver(receiver Receiver)
RemoveReceiver(receiver Receiver) error
DestroyAll()
RecreateAll() error
Shutdown()
}
type BroadcastManager interface {
Start(url string) error
Stop()
@ -64,10 +44,74 @@ type ScreencastManager interface {
Image() ([]byte, error)
}
type StreamSelectorType int
const (
// select exact stream
StreamSelectorTypeExact StreamSelectorType = iota
// select nearest stream (in either direction) if exact stream is not available
StreamSelectorTypeNearest
// if exact stream is found select the next lower stream, otherwise select the nearest lower stream
StreamSelectorTypeLower
// if exact stream is found select the next higher stream, otherwise select the nearest higher stream
StreamSelectorTypeHigher
)
func (s StreamSelectorType) String() string {
switch s {
case StreamSelectorTypeExact:
return "exact"
case StreamSelectorTypeNearest:
return "nearest"
case StreamSelectorTypeLower:
return "lower"
case StreamSelectorTypeHigher:
return "higher"
default:
return fmt.Sprintf("%d", int(s))
}
}
func (s *StreamSelectorType) UnmarshalText(text []byte) error {
switch strings.ToLower(string(text)) {
case "exact", "":
*s = StreamSelectorTypeExact
case "nearest":
*s = StreamSelectorTypeNearest
case "lower":
*s = StreamSelectorTypeLower
case "higher":
*s = StreamSelectorTypeHigher
default:
return fmt.Errorf("invalid stream selector type: %s", string(text))
}
return nil
}
func (s StreamSelectorType) MarshalText() ([]byte, error) {
return []byte(s.String()), nil
}
type StreamSelector struct {
// type of stream selector
Type StreamSelectorType
// select stream by its ID
ID string
// select stream by its bitrate
Bitrate uint64
}
type StreamSelectorManager interface {
IDs() []string
Codec() codec.RTPCodec
GetStream(selector StreamSelector) (StreamSinkManager, bool)
}
type StreamSinkManager interface {
ID() string
Codec() codec.RTPCodec
Bitrate() int
Bitrate() uint64
AddListener(listener SampleListener) error
RemoveListener(listener SampleListener) error
@ -94,12 +138,10 @@ type CaptureManager interface {
Start()
Shutdown() error
GetBitrateFromVideoID(videoID string) (int, error)
Broadcast() BroadcastManager
Screencast() ScreencastManager
Audio() StreamSinkManager
Video() BucketsManager
Video() StreamSelectorManager
Webcam() StreamSrcManager
Microphone() StreamSrcManager
@ -201,54 +243,3 @@ func (config *VideoConfig) GetPipeline(screen ScreenSize) (string, error) {
config.GstSuffix,
}[:], " "), nil
}
func (config *VideoConfig) GetBitrateFn(getScreen func() ScreenSize) func() (int, error) {
return func() (int, error) {
if config.Bitrate > 0 {
return config.Bitrate, nil
}
screen := getScreen()
values := map[string]any{
"width": screen.Width,
"height": screen.Height,
"fps": screen.Rate,
}
language := []gval.Language{
gval.Function("round", func(args ...any) (any, error) {
return (int)(math.Round(args[0].(float64))), nil
}),
}
// TOOD: do not read target-bitrate from pipeline, but only from config.
// TODO: This is only for vp8.
expr, ok := config.GstParams["target-bitrate"]
if !ok {
// TODO: This is only for h264.
expr, ok = config.GstParams["bitrate"]
if !ok {
return 0, fmt.Errorf("bitrate not found")
}
}
bitrate, err := gval.Evaluate(expr, values, language...)
if err != nil {
return 0, fmt.Errorf("failed to evaluate bitrate: %w", err)
}
var bitrateInt int
switch val := bitrate.(type) {
case int:
bitrateInt = val
case float64:
bitrateInt = (int)(val)
default:
return 0, fmt.Errorf("bitrate is not int or float64")
}
return bitrateInt, nil
}
}

View File

@ -48,10 +48,6 @@ type SystemDisconnect struct {
type SignalProvide struct {
SDP string `json:"sdp"`
ICEServers []types.ICEServer `json:"iceservers"`
// TODO: Use SignalVideo struct.
Video string `json:"video"`
Bitrate int `json:"bitrate"`
VideoAuto bool `json:"video_auto"`
}
type SignalCandidate struct {
@ -63,9 +59,8 @@ type SignalDescription struct {
}
type SignalVideo struct {
Video string `json:"video"`
Bitrate int `json:"bitrate"`
VideoAuto bool `json:"video_auto"`
Video string `json:"video"`
Auto bool `json:"auto"`
}
/////////////////////////////

View File

@ -9,6 +9,7 @@ import (
var (
ErrWebRTCDataChannelNotFound = errors.New("webrtc data channel not found")
ErrWebRTCConnectionNotFound = errors.New("webrtc connection not found")
ErrWebRTCStreamNotFound = errors.New("webrtc stream not found")
)
type ICEServer struct {
@ -23,10 +24,10 @@ type WebRTCPeer interface {
SetRemoteDescription(webrtc.SessionDescription) error
SetCandidate(webrtc.ICECandidateInit) error
SetVideoBitrate(bitrate int) error
SetVideoID(videoID string) error
GetVideoID() string
SetVideo(StreamSelector) error
VideoID() (string, bool)
SetPaused(isPaused bool) error
Paused() bool
SetVideoAuto(auto bool)
VideoAuto() bool
@ -42,6 +43,6 @@ type WebRTCManager interface {
ICEServers() []ICEServer
CreatePeer(session Session, bitrate int, videoAuto bool) (*webrtc.SessionDescription, error)
CreatePeer(session Session) (*webrtc.SessionDescription, WebRTCPeer, error)
SetCursorPosition(x, y int)
}

153
pkg/utils/trenddetector.go Normal file
View File

@ -0,0 +1,153 @@
// From https://github.com/livekit/livekit/blob/master/pkg/sfu/streamallocator/trenddetector.go
package utils
import (
"fmt"
"time"
)
// ------------------------------------------------
type TrendDirection int
const (
TrendDirectionNeutral TrendDirection = iota
TrendDirectionUpward
TrendDirectionDownward
)
func (t TrendDirection) String() string {
switch t {
case TrendDirectionNeutral:
return "NEUTRAL"
case TrendDirectionUpward:
return "UPWARD"
case TrendDirectionDownward:
return "DOWNWARD"
default:
return fmt.Sprintf("%d", int(t))
}
}
// ------------------------------------------------
type TrendDetectorParams struct {
RequiredSamples int
DownwardTrendThreshold float64
CollapseValues bool
}
type TrendDetector struct {
params TrendDetectorParams
startTime time.Time
numSamples int
values []int64
lowestValue int64
highestValue int64
direction TrendDirection
}
func NewTrendDetector(params TrendDetectorParams) *TrendDetector {
return &TrendDetector{
params: params,
startTime: time.Now(),
direction: TrendDirectionNeutral,
}
}
func (t *TrendDetector) Seed(value int64) {
if len(t.values) != 0 {
return
}
t.values = append(t.values, value)
}
func (t *TrendDetector) AddValue(value int64) {
t.numSamples++
if t.lowestValue == 0 || value < t.lowestValue {
t.lowestValue = value
}
if value > t.highestValue {
t.highestValue = value
}
// ignore duplicate values
if t.params.CollapseValues && len(t.values) != 0 && t.values[len(t.values)-1] == value {
return
}
if len(t.values) == t.params.RequiredSamples {
t.values = t.values[1:]
}
t.values = append(t.values, value)
t.updateDirection()
}
func (t *TrendDetector) GetLowest() int64 {
return t.lowestValue
}
func (t *TrendDetector) GetHighest() int64 {
return t.highestValue
}
func (t *TrendDetector) GetValues() []int64 {
return t.values
}
func (t *TrendDetector) GetDirection() TrendDirection {
return t.direction
}
func (t *TrendDetector) ToString() string {
now := time.Now()
elapsed := now.Sub(t.startTime).Seconds()
str := fmt.Sprintf("t: %+v|%+v|%.2fs", t.startTime.Format(time.UnixDate), now.Format(time.UnixDate), elapsed)
str += fmt.Sprintf(", v: %d|%d|%d|%+v|%.2f", t.numSamples, t.lowestValue, t.highestValue, t.values, kendallsTau(t.values))
return str
}
func (t *TrendDetector) updateDirection() {
if len(t.values) < t.params.RequiredSamples {
t.direction = TrendDirectionNeutral
return
}
// using Kendall's Tau to find trend
kt := kendallsTau(t.values)
t.direction = TrendDirectionNeutral
switch {
case kt > 0:
t.direction = TrendDirectionUpward
case kt < t.params.DownwardTrendThreshold:
t.direction = TrendDirectionDownward
}
}
// ------------------------------------------------
func kendallsTau(values []int64) float64 {
concordantPairs := 0
discordantPairs := 0
for i := 0; i < len(values)-1; i++ {
for j := i + 1; j < len(values); j++ {
if values[i] < values[j] {
concordantPairs++
} else if values[i] > values[j] {
discordantPairs++
}
}
}
if (concordantPairs + discordantPairs) == 0 {
return 0.0
}
return (float64(concordantPairs) - float64(discordantPairs)) / (float64(concordantPairs) + float64(discordantPairs))
}