mirror of
https://github.com/m1k1o/neko.git
synced 2024-07-24 14:40:50 +12:00
Bandwidth estimator refactor (#46)
* rewrite to use stream selector. * WIP. * add nacks to metrics. * add estimate trend. * estimator based on trend detector. * add estimator unstable duration. * add estimator debug. * add stalled duration. * estimator move values to config. * change default estimator values. * minor style changes. * fix websocket video messages. * replace video track with ivdeo id.
This commit is contained in:
@ -31,26 +31,6 @@ type SampleListener interface {
|
||||
WriteSample(Sample)
|
||||
}
|
||||
|
||||
type Receiver interface {
|
||||
SetStream(stream StreamSinkManager) (changed bool, err error)
|
||||
RemoveStream()
|
||||
OnBitrateChange(f func(bitrate int) (changed bool, err error))
|
||||
OnVideoChange(f func(videoID string) (changed bool, err error))
|
||||
VideoAuto() bool
|
||||
SetVideoAuto(videoAuto bool)
|
||||
}
|
||||
|
||||
type BucketsManager interface {
|
||||
IDs() []string
|
||||
Codec() codec.RTPCodec
|
||||
SetReceiver(receiver Receiver)
|
||||
RemoveReceiver(receiver Receiver) error
|
||||
|
||||
DestroyAll()
|
||||
RecreateAll() error
|
||||
Shutdown()
|
||||
}
|
||||
|
||||
type BroadcastManager interface {
|
||||
Start(url string) error
|
||||
Stop()
|
||||
@ -64,10 +44,74 @@ type ScreencastManager interface {
|
||||
Image() ([]byte, error)
|
||||
}
|
||||
|
||||
type StreamSelectorType int
|
||||
|
||||
const (
|
||||
// select exact stream
|
||||
StreamSelectorTypeExact StreamSelectorType = iota
|
||||
// select nearest stream (in either direction) if exact stream is not available
|
||||
StreamSelectorTypeNearest
|
||||
// if exact stream is found select the next lower stream, otherwise select the nearest lower stream
|
||||
StreamSelectorTypeLower
|
||||
// if exact stream is found select the next higher stream, otherwise select the nearest higher stream
|
||||
StreamSelectorTypeHigher
|
||||
)
|
||||
|
||||
func (s StreamSelectorType) String() string {
|
||||
switch s {
|
||||
case StreamSelectorTypeExact:
|
||||
return "exact"
|
||||
case StreamSelectorTypeNearest:
|
||||
return "nearest"
|
||||
case StreamSelectorTypeLower:
|
||||
return "lower"
|
||||
case StreamSelectorTypeHigher:
|
||||
return "higher"
|
||||
default:
|
||||
return fmt.Sprintf("%d", int(s))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StreamSelectorType) UnmarshalText(text []byte) error {
|
||||
switch strings.ToLower(string(text)) {
|
||||
case "exact", "":
|
||||
*s = StreamSelectorTypeExact
|
||||
case "nearest":
|
||||
*s = StreamSelectorTypeNearest
|
||||
case "lower":
|
||||
*s = StreamSelectorTypeLower
|
||||
case "higher":
|
||||
*s = StreamSelectorTypeHigher
|
||||
default:
|
||||
return fmt.Errorf("invalid stream selector type: %s", string(text))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s StreamSelectorType) MarshalText() ([]byte, error) {
|
||||
return []byte(s.String()), nil
|
||||
}
|
||||
|
||||
type StreamSelector struct {
|
||||
// type of stream selector
|
||||
Type StreamSelectorType
|
||||
// select stream by its ID
|
||||
ID string
|
||||
// select stream by its bitrate
|
||||
Bitrate uint64
|
||||
}
|
||||
|
||||
type StreamSelectorManager interface {
|
||||
IDs() []string
|
||||
Codec() codec.RTPCodec
|
||||
|
||||
GetStream(selector StreamSelector) (StreamSinkManager, bool)
|
||||
}
|
||||
|
||||
type StreamSinkManager interface {
|
||||
ID() string
|
||||
Codec() codec.RTPCodec
|
||||
Bitrate() int
|
||||
Bitrate() uint64
|
||||
|
||||
AddListener(listener SampleListener) error
|
||||
RemoveListener(listener SampleListener) error
|
||||
@ -94,12 +138,10 @@ type CaptureManager interface {
|
||||
Start()
|
||||
Shutdown() error
|
||||
|
||||
GetBitrateFromVideoID(videoID string) (int, error)
|
||||
|
||||
Broadcast() BroadcastManager
|
||||
Screencast() ScreencastManager
|
||||
Audio() StreamSinkManager
|
||||
Video() BucketsManager
|
||||
Video() StreamSelectorManager
|
||||
|
||||
Webcam() StreamSrcManager
|
||||
Microphone() StreamSrcManager
|
||||
@ -201,54 +243,3 @@ func (config *VideoConfig) GetPipeline(screen ScreenSize) (string, error) {
|
||||
config.GstSuffix,
|
||||
}[:], " "), nil
|
||||
}
|
||||
|
||||
func (config *VideoConfig) GetBitrateFn(getScreen func() ScreenSize) func() (int, error) {
|
||||
return func() (int, error) {
|
||||
if config.Bitrate > 0 {
|
||||
return config.Bitrate, nil
|
||||
}
|
||||
|
||||
screen := getScreen()
|
||||
|
||||
values := map[string]any{
|
||||
"width": screen.Width,
|
||||
"height": screen.Height,
|
||||
"fps": screen.Rate,
|
||||
}
|
||||
|
||||
language := []gval.Language{
|
||||
gval.Function("round", func(args ...any) (any, error) {
|
||||
return (int)(math.Round(args[0].(float64))), nil
|
||||
}),
|
||||
}
|
||||
|
||||
// TOOD: do not read target-bitrate from pipeline, but only from config.
|
||||
|
||||
// TODO: This is only for vp8.
|
||||
expr, ok := config.GstParams["target-bitrate"]
|
||||
if !ok {
|
||||
// TODO: This is only for h264.
|
||||
expr, ok = config.GstParams["bitrate"]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("bitrate not found")
|
||||
}
|
||||
}
|
||||
|
||||
bitrate, err := gval.Evaluate(expr, values, language...)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to evaluate bitrate: %w", err)
|
||||
}
|
||||
|
||||
var bitrateInt int
|
||||
switch val := bitrate.(type) {
|
||||
case int:
|
||||
bitrateInt = val
|
||||
case float64:
|
||||
bitrateInt = (int)(val)
|
||||
default:
|
||||
return 0, fmt.Errorf("bitrate is not int or float64")
|
||||
}
|
||||
|
||||
return bitrateInt, nil
|
||||
}
|
||||
}
|
||||
|
@ -48,10 +48,6 @@ type SystemDisconnect struct {
|
||||
type SignalProvide struct {
|
||||
SDP string `json:"sdp"`
|
||||
ICEServers []types.ICEServer `json:"iceservers"`
|
||||
// TODO: Use SignalVideo struct.
|
||||
Video string `json:"video"`
|
||||
Bitrate int `json:"bitrate"`
|
||||
VideoAuto bool `json:"video_auto"`
|
||||
}
|
||||
|
||||
type SignalCandidate struct {
|
||||
@ -63,9 +59,8 @@ type SignalDescription struct {
|
||||
}
|
||||
|
||||
type SignalVideo struct {
|
||||
Video string `json:"video"`
|
||||
Bitrate int `json:"bitrate"`
|
||||
VideoAuto bool `json:"video_auto"`
|
||||
Video string `json:"video"`
|
||||
Auto bool `json:"auto"`
|
||||
}
|
||||
|
||||
/////////////////////////////
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
var (
|
||||
ErrWebRTCDataChannelNotFound = errors.New("webrtc data channel not found")
|
||||
ErrWebRTCConnectionNotFound = errors.New("webrtc connection not found")
|
||||
ErrWebRTCStreamNotFound = errors.New("webrtc stream not found")
|
||||
)
|
||||
|
||||
type ICEServer struct {
|
||||
@ -23,10 +24,10 @@ type WebRTCPeer interface {
|
||||
SetRemoteDescription(webrtc.SessionDescription) error
|
||||
SetCandidate(webrtc.ICECandidateInit) error
|
||||
|
||||
SetVideoBitrate(bitrate int) error
|
||||
SetVideoID(videoID string) error
|
||||
GetVideoID() string
|
||||
SetVideo(StreamSelector) error
|
||||
VideoID() (string, bool)
|
||||
SetPaused(isPaused bool) error
|
||||
Paused() bool
|
||||
SetVideoAuto(auto bool)
|
||||
VideoAuto() bool
|
||||
|
||||
@ -42,6 +43,6 @@ type WebRTCManager interface {
|
||||
|
||||
ICEServers() []ICEServer
|
||||
|
||||
CreatePeer(session Session, bitrate int, videoAuto bool) (*webrtc.SessionDescription, error)
|
||||
CreatePeer(session Session) (*webrtc.SessionDescription, WebRTCPeer, error)
|
||||
SetCursorPosition(x, y int)
|
||||
}
|
||||
|
153
pkg/utils/trenddetector.go
Normal file
153
pkg/utils/trenddetector.go
Normal file
@ -0,0 +1,153 @@
|
||||
// From https://github.com/livekit/livekit/blob/master/pkg/sfu/streamallocator/trenddetector.go
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ------------------------------------------------
|
||||
|
||||
type TrendDirection int
|
||||
|
||||
const (
|
||||
TrendDirectionNeutral TrendDirection = iota
|
||||
TrendDirectionUpward
|
||||
TrendDirectionDownward
|
||||
)
|
||||
|
||||
func (t TrendDirection) String() string {
|
||||
switch t {
|
||||
case TrendDirectionNeutral:
|
||||
return "NEUTRAL"
|
||||
case TrendDirectionUpward:
|
||||
return "UPWARD"
|
||||
case TrendDirectionDownward:
|
||||
return "DOWNWARD"
|
||||
default:
|
||||
return fmt.Sprintf("%d", int(t))
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------------------
|
||||
|
||||
type TrendDetectorParams struct {
|
||||
RequiredSamples int
|
||||
DownwardTrendThreshold float64
|
||||
CollapseValues bool
|
||||
}
|
||||
|
||||
type TrendDetector struct {
|
||||
params TrendDetectorParams
|
||||
|
||||
startTime time.Time
|
||||
numSamples int
|
||||
values []int64
|
||||
lowestValue int64
|
||||
highestValue int64
|
||||
|
||||
direction TrendDirection
|
||||
}
|
||||
|
||||
func NewTrendDetector(params TrendDetectorParams) *TrendDetector {
|
||||
return &TrendDetector{
|
||||
params: params,
|
||||
startTime: time.Now(),
|
||||
direction: TrendDirectionNeutral,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TrendDetector) Seed(value int64) {
|
||||
if len(t.values) != 0 {
|
||||
return
|
||||
}
|
||||
|
||||
t.values = append(t.values, value)
|
||||
}
|
||||
|
||||
func (t *TrendDetector) AddValue(value int64) {
|
||||
t.numSamples++
|
||||
if t.lowestValue == 0 || value < t.lowestValue {
|
||||
t.lowestValue = value
|
||||
}
|
||||
if value > t.highestValue {
|
||||
t.highestValue = value
|
||||
}
|
||||
|
||||
// ignore duplicate values
|
||||
if t.params.CollapseValues && len(t.values) != 0 && t.values[len(t.values)-1] == value {
|
||||
return
|
||||
}
|
||||
|
||||
if len(t.values) == t.params.RequiredSamples {
|
||||
t.values = t.values[1:]
|
||||
}
|
||||
t.values = append(t.values, value)
|
||||
|
||||
t.updateDirection()
|
||||
}
|
||||
|
||||
func (t *TrendDetector) GetLowest() int64 {
|
||||
return t.lowestValue
|
||||
}
|
||||
|
||||
func (t *TrendDetector) GetHighest() int64 {
|
||||
return t.highestValue
|
||||
}
|
||||
|
||||
func (t *TrendDetector) GetValues() []int64 {
|
||||
return t.values
|
||||
}
|
||||
|
||||
func (t *TrendDetector) GetDirection() TrendDirection {
|
||||
return t.direction
|
||||
}
|
||||
|
||||
func (t *TrendDetector) ToString() string {
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(t.startTime).Seconds()
|
||||
str := fmt.Sprintf("t: %+v|%+v|%.2fs", t.startTime.Format(time.UnixDate), now.Format(time.UnixDate), elapsed)
|
||||
str += fmt.Sprintf(", v: %d|%d|%d|%+v|%.2f", t.numSamples, t.lowestValue, t.highestValue, t.values, kendallsTau(t.values))
|
||||
return str
|
||||
}
|
||||
|
||||
func (t *TrendDetector) updateDirection() {
|
||||
if len(t.values) < t.params.RequiredSamples {
|
||||
t.direction = TrendDirectionNeutral
|
||||
return
|
||||
}
|
||||
|
||||
// using Kendall's Tau to find trend
|
||||
kt := kendallsTau(t.values)
|
||||
|
||||
t.direction = TrendDirectionNeutral
|
||||
switch {
|
||||
case kt > 0:
|
||||
t.direction = TrendDirectionUpward
|
||||
case kt < t.params.DownwardTrendThreshold:
|
||||
t.direction = TrendDirectionDownward
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------------------
|
||||
|
||||
func kendallsTau(values []int64) float64 {
|
||||
concordantPairs := 0
|
||||
discordantPairs := 0
|
||||
|
||||
for i := 0; i < len(values)-1; i++ {
|
||||
for j := i + 1; j < len(values); j++ {
|
||||
if values[i] < values[j] {
|
||||
concordantPairs++
|
||||
} else if values[i] > values[j] {
|
||||
discordantPairs++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (concordantPairs + discordantPairs) == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
return (float64(concordantPairs) - float64(discordantPairs)) / (float64(concordantPairs) + float64(discordantPairs))
|
||||
}
|
Reference in New Issue
Block a user