Capture bandwidth switch (#14)

* Handle bitrate change by finding the stream with closest bitrate as peer

* Convert video id into bitrate when creating peer or changing bitrate

* Try to fix prometheus panic

* Revert metrics label name change

* minor fixes.

* bitrate selector.

* skip if moving to the same stream.

* no closure for getting target bitrate.

* fix: high res switch to lo video, stream bitrate out of range

* revert dev config change.

* white space.

Co-authored-by: Aleksandar Sukovic <aleksandar.sukovic@gmail.com>
This commit is contained in:
Miroslav Šedivý 2022-10-25 20:25:00 +02:00 committed by GitHub
parent e0bee67e85
commit 6067367acd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 186 additions and 52 deletions

View File

@ -2,6 +2,8 @@ package capture
import ( import (
"errors" "errors"
"fmt"
"math"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -36,17 +38,17 @@ func (m *BucketsManagerCtx) shutdown() {
} }
func (m *BucketsManagerCtx) destroyAll() { func (m *BucketsManagerCtx) destroyAll() {
for _, video := range m.streams { for _, stream := range m.streams {
if video.Started() { if stream.Started() {
video.destroyPipeline() stream.destroyPipeline()
} }
} }
} }
func (m *BucketsManagerCtx) recreateAll() error { func (m *BucketsManagerCtx) recreateAll() error {
for _, video := range m.streams { for _, stream := range m.streams {
if video.Started() { if stream.Started() {
err := video.createPipeline() err := stream.createPipeline()
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) { if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
return err return err
} }
@ -65,22 +67,39 @@ func (m *BucketsManagerCtx) Codec() codec.RTPCodec {
} }
func (m *BucketsManagerCtx) SetReceiver(receiver types.Receiver) error { func (m *BucketsManagerCtx) SetReceiver(receiver types.Receiver) error {
receiver.OnVideoIdChange(func(videoID string) error { receiver.OnBitrateChange(func(bitrate int) error {
videoStream, ok := m.streams[videoID] stream, ok := m.findNearestStream(bitrate)
if !ok { if !ok {
return types.ErrWebRTCVideoNotFound return fmt.Errorf("no stream found for bitrate %d", bitrate)
} }
return receiver.SetStream(videoStream) return receiver.SetStream(stream)
}) })
// TODO: Save receiver.
return nil return nil
} }
func (m *BucketsManagerCtx) findNearestStream(bitrate int) (ss *StreamSinkManagerCtx, ok bool) {
minDiff := math.MaxInt
for _, s := range m.streams {
streamBitrate, err := s.Bitrate()
if err != nil {
m.logger.Error().Err(err).Msgf("failed to get bitrate for stream %s", s.ID())
continue
}
diffAbs := int(math.Abs(float64(bitrate - streamBitrate)))
if diffAbs < minDiff {
minDiff, ss = diffAbs, s
}
}
ok = ss != nil
return
}
func (m *BucketsManagerCtx) RemoveReceiver(receiver types.Receiver) error { func (m *BucketsManagerCtx) RemoveReceiver(receiver types.Receiver) error {
// TODO: Unsubribe from OnVideoIdChange. receiver.OnBitrateChange(nil)
// TODO: Remove receiver.
receiver.RemoveStream() receiver.RemoveStream()
return nil return nil
} }

View File

@ -16,6 +16,7 @@ import (
type CaptureManagerCtx struct { type CaptureManagerCtx struct {
logger zerolog.Logger logger zerolog.Logger
desktop types.DesktopManager desktop types.DesktopManager
config *config.Capture
// sinks // sinks
broadcast *BroacastManagerCtx broadcast *BroacastManagerCtx
@ -66,13 +67,18 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt
Str("pipeline", pipeline). Str("pipeline", pipeline).
Msg("syntax check for video stream pipeline passed") Msg("syntax check for video stream pipeline passed")
getVideoBitrate := pipelineConf.GetBitrateFn(desktop.GetScreenSize)
if err != nil {
logger.Panic().Err(err).Msg("unable to get video bitrate")
}
// append to videos // append to videos
videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id) videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id, getVideoBitrate)
} }
return &CaptureManagerCtx{ return &CaptureManagerCtx{
logger: logger, logger: logger,
desktop: desktop, desktop: desktop,
config: config,
// sinks // sinks
broadcast: broadcastNew(func(url string) (string, error) { broadcast: broadcastNew(func(url string) (string, error) {
@ -132,7 +138,7 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt
"! %s "+ "! %s "+
"! appsink name=appsink", config.AudioDevice, config.AudioCodec.Pipeline, "! appsink name=appsink", config.AudioDevice, config.AudioCodec.Pipeline,
), nil ), nil
}, "audio"), }, "audio", nil),
video: bucketsNew(config.VideoCodec, videos, config.VideoIDs), video: bucketsNew(config.VideoCodec, videos, config.VideoIDs),
// sources // sources
@ -242,6 +248,15 @@ func (manager *CaptureManagerCtx) Shutdown() error {
return nil return nil
} }
func (manager *CaptureManagerCtx) GetBitrateFromVideoID(videoID string) (int, error) {
cfg, ok := manager.config.VideoPipelines[videoID]
if !ok {
return 0, fmt.Errorf("video config not found for %s", videoID)
}
return cfg.GetBitrateFn(manager.desktop.GetScreenSize)()
}
func (manager *CaptureManagerCtx) Broadcast() types.BroadcastManager { func (manager *CaptureManagerCtx) Broadcast() types.BroadcastManager {
return manager.broadcast return manager.broadcast
} }

View File

@ -19,6 +19,9 @@ import (
var moveSinkListenerMu = sync.Mutex{} var moveSinkListenerMu = sync.Mutex{}
type StreamSinkManagerCtx struct { type StreamSinkManagerCtx struct {
id string
getBitrate func() (int, error)
logger zerolog.Logger logger zerolog.Logger
mu sync.Mutex mu sync.Mutex
wg sync.WaitGroup wg sync.WaitGroup
@ -37,13 +40,16 @@ type StreamSinkManagerCtx struct {
pipelinesActive prometheus.Gauge pipelinesActive prometheus.Gauge
} }
func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), video_id string) *StreamSinkManagerCtx { func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), id string, getBitrate func() (int, error)) *StreamSinkManagerCtx {
logger := log.With(). logger := log.With().
Str("module", "capture"). Str("module", "capture").
Str("submodule", "stream-sink"). Str("submodule", "stream-sink").
Str("video_id", video_id).Logger() Str("id", id).Logger()
manager := &StreamSinkManagerCtx{ manager := &StreamSinkManagerCtx{
id: id,
getBitrate: getBitrate,
logger: logger, logger: logger,
codec: codec, codec: codec,
pipelineFn: pipelineFn, pipelineFn: pipelineFn,
@ -56,7 +62,7 @@ func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), vide
Subsystem: "capture", Subsystem: "capture",
Help: "Current number of listeners for a pipeline.", Help: "Current number of listeners for a pipeline.",
ConstLabels: map[string]string{ ConstLabels: map[string]string{
"video_id": video_id, "video_id": id,
"codec_name": codec.Name, "codec_name": codec.Name,
"codec_type": codec.Type.String(), "codec_type": codec.Type.String(),
}, },
@ -68,7 +74,7 @@ func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), vide
Help: "Total number of created pipelines.", Help: "Total number of created pipelines.",
ConstLabels: map[string]string{ ConstLabels: map[string]string{
"submodule": "streamsink", "submodule": "streamsink",
"video_id": video_id, "video_id": id,
"codec_name": codec.Name, "codec_name": codec.Name,
"codec_type": codec.Type.String(), "codec_type": codec.Type.String(),
}, },
@ -80,7 +86,7 @@ func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), vide
Help: "Total number of active pipelines.", Help: "Total number of active pipelines.",
ConstLabels: map[string]string{ ConstLabels: map[string]string{
"submodule": "streamsink", "submodule": "streamsink",
"video_id": video_id, "video_id": id,
"codec_name": codec.Name, "codec_name": codec.Name,
"codec_type": codec.Type.String(), "codec_type": codec.Type.String(),
}, },
@ -103,6 +109,18 @@ func (manager *StreamSinkManagerCtx) shutdown() {
manager.wg.Wait() manager.wg.Wait()
} }
func (manager *StreamSinkManagerCtx) ID() string {
return manager.id
}
func (manager *StreamSinkManagerCtx) Bitrate() (int, error) {
if manager.getBitrate == nil {
return 0, nil
}
// recalculate bitrate every time, take screen resolution (and fps) into account
return manager.getBitrate()
}
func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec { func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec {
return manager.codec return manager.codec
} }

View File

@ -214,7 +214,7 @@ func (manager *WebRTCManagerCtx) newPeerConnection(codecs []codec.RTPCodec, logg
return api.NewPeerConnection(manager.webrtcConfiguration) return api.NewPeerConnection(manager.webrtcConfiguration)
} }
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID string) (*webrtc.SessionDescription, error) { func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int) (*webrtc.SessionDescription, error) {
id := atomic.AddInt32(&manager.peerId, 1) id := atomic.AddInt32(&manager.peerId, 1)
manager.metrics.NewConnection(session) manager.metrics.NewConnection(session)
@ -280,11 +280,12 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
return nil, err return nil, err
} }
// set default video id // set initial video bitrate
err = videoTrack.SetVideoID(videoID) if err = videoTrack.SetBitrate(bitrate); err != nil {
if err != nil {
return nil, err return nil, err
} }
videoID := videoTrack.stream.ID()
manager.metrics.SetVideoID(session, videoID) manager.metrics.SetVideoID(session, videoID)
// data channel // data channel
@ -298,14 +299,19 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
logger: logger, logger: logger,
connection: connection, connection: connection,
dataChannel: dataChannel, dataChannel: dataChannel,
changeVideo: func(videoID string) error { changeVideo: func(bitrate int) error {
if err := videoTrack.SetVideoID(videoID); err != nil { if err := videoTrack.SetBitrate(bitrate); err != nil {
return err return err
} }
videoID := videoTrack.stream.ID()
manager.metrics.SetVideoID(session, videoID) manager.metrics.SetVideoID(session, videoID)
return nil return nil
}, },
// TODO: Refactor.
videoId: func() string {
return videoTrack.stream.ID()
},
setPaused: func(isPaused bool) { setPaused: func(isPaused bool) {
videoTrack.SetPaused(isPaused) videoTrack.SetPaused(isPaused)
audioTrack.SetPaused(isPaused) audioTrack.SetPaused(isPaused)
@ -418,7 +424,9 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
connection.Close() connection.Close()
case webrtc.PeerConnectionStateClosed: case webrtc.PeerConnectionStateClosed:
session.SetWebRTCConnected(peer, false) session.SetWebRTCConnected(peer, false)
video.RemoveReceiver(videoTrack) if err = video.RemoveReceiver(videoTrack); err != nil {
logger.Err(err).Msg("failed to remove video receiver")
}
audioTrack.RemoveStream() audioTrack.RemoveStream()
} }

View File

@ -17,7 +17,8 @@ type WebRTCPeerCtx struct {
logger zerolog.Logger logger zerolog.Logger
connection *webrtc.PeerConnection connection *webrtc.PeerConnection
dataChannel *webrtc.DataChannel dataChannel *webrtc.DataChannel
changeVideo func(videoID string) error changeVideo func(bitrate int) error
videoId func() string
setPaused func(isPaused bool) setPaused func(isPaused bool)
iceTrickle bool iceTrickle bool
} }
@ -114,7 +115,7 @@ func (peer *WebRTCPeerCtx) SetCandidate(candidate webrtc.ICECandidateInit) error
return peer.connection.AddICECandidate(candidate) return peer.connection.AddICECandidate(candidate)
} }
func (peer *WebRTCPeerCtx) SetVideoID(videoID string) error { func (peer *WebRTCPeerCtx) SetVideoBitrate(bitrate int) error {
peer.mu.Lock() peer.mu.Lock()
defer peer.mu.Unlock() defer peer.mu.Unlock()
@ -122,8 +123,16 @@ func (peer *WebRTCPeerCtx) SetVideoID(videoID string) error {
return types.ErrWebRTCConnectionNotFound return types.ErrWebRTCConnectionNotFound
} }
peer.logger.Info().Str("video_id", videoID).Msg("change video id") peer.logger.Info().Int("bitrate", bitrate).Msg("change video bitrate")
return peer.changeVideo(videoID) return peer.changeVideo(bitrate)
}
// TODO: Refactor.
func (peer *WebRTCPeerCtx) GetVideoId() string {
peer.mu.Lock()
defer peer.mu.Unlock()
return peer.videoId()
} }
func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error { func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error {

View File

@ -27,7 +27,7 @@ type Track struct {
onRtcp func(rtcp.Packet) onRtcp func(rtcp.Packet)
onRtcpMu sync.RWMutex onRtcpMu sync.RWMutex
videoIdChange func(string) error bitrateChange func(int) error
} }
func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.PeerConnection) (*Track, error) { func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.PeerConnection) (*Track, error) {
@ -140,14 +140,14 @@ func (t *Track) OnRTCP(f func(rtcp.Packet)) {
t.onRtcp = f t.onRtcp = f
} }
func (t *Track) SetVideoID(videoID string) error { func (t *Track) SetBitrate(bitrate int) error {
if t.videoIdChange == nil { if t.bitrateChange == nil {
return fmt.Errorf("video id change not supported") return fmt.Errorf("bitrate change not supported")
} }
return t.videoIdChange(videoID) return t.bitrateChange(bitrate)
} }
func (t *Track) OnVideoIdChange(f func(string) error) { func (t *Track) OnBitrateChange(f func(int) error) {
t.videoIdChange = f t.bitrateChange = f
} }

View File

@ -19,14 +19,27 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
payload.Video = videos[0] payload.Video = videos[0]
} }
offer, err := h.webrtc.CreatePeer(session, payload.Video) var err error
if payload.Bitrate == 0 {
// get bitrate from video id
payload.Bitrate, err = h.capture.GetBitrateFromVideoID(payload.Video)
if err != nil {
return err
}
}
offer, err := h.webrtc.CreatePeer(session, payload.Bitrate)
if err != nil { if err != nil {
return err return err
} }
// set webrtc as paused if session has private mode enabled if webrtcPeer := session.GetWebRTCPeer(); webrtcPeer != nil {
if webrtcPeer := session.GetWebRTCPeer(); webrtcPeer != nil && session.PrivateModeEnabled() { // set webrtc as paused if session has private mode enabled
webrtcPeer.SetPaused(true) if session.PrivateModeEnabled() {
webrtcPeer.SetPaused(true)
}
payload.Video = webrtcPeer.GetVideoId()
} }
session.Send( session.Send(
@ -34,7 +47,7 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
message.SignalProvide{ message.SignalProvide{
SDP: offer.SDP, SDP: offer.SDP,
ICEServers: h.webrtc.ICEServers(), ICEServers: h.webrtc.ICEServers(),
Video: payload.Video, Video: payload.Video, // TODO: Refactor.
}) })
return nil return nil
@ -110,15 +123,24 @@ func (h *MessageHandlerCtx) signalVideo(session types.Session, payload *message.
return errors.New("webRTC peer does not exist") return errors.New("webRTC peer does not exist")
} }
err := peer.SetVideoID(payload.Video) var err error
if err != nil { if payload.Bitrate == 0 {
// get bitrate from video id
payload.Bitrate, err = h.capture.GetBitrateFromVideoID(payload.Video)
if err != nil {
return err
}
}
if err = peer.SetVideoBitrate(payload.Bitrate); err != nil {
return err return err
} }
session.Send( session.Send(
event.SIGNAL_VIDEO, event.SIGNAL_VIDEO,
message.SignalVideo{ message.SignalVideo{
Video: payload.Video, Video: peer.GetVideoId(), // TODO: Refactor.
Bitrate: payload.Bitrate,
}) })
return nil return nil

View File

@ -22,7 +22,7 @@ type Sample media.Sample
type Receiver interface { type Receiver interface {
SetStream(stream StreamSinkManager) error SetStream(stream StreamSinkManager) error
RemoveStream() RemoveStream()
OnVideoIdChange(f func(string) error) OnBitrateChange(f func(int) error)
} }
type BucketsManager interface { type BucketsManager interface {
@ -46,6 +46,7 @@ type ScreencastManager interface {
} }
type StreamSinkManager interface { type StreamSinkManager interface {
ID() string
Codec() codec.RTPCodec Codec() codec.RTPCodec
AddListener(listener *func(sample Sample)) error AddListener(listener *func(sample Sample)) error
@ -70,6 +71,8 @@ type CaptureManager interface {
Start() Start()
Shutdown() error Shutdown() error
GetBitrateFromVideoID(videoID string) (int, error)
Broadcast() BroadcastManager Broadcast() BroadcastManager
Screencast() ScreencastManager Screencast() ScreencastManager
Audio() StreamSinkManager Audio() StreamSinkManager
@ -83,6 +86,7 @@ type VideoConfig struct {
Width string `mapstructure:"width"` // expression Width string `mapstructure:"width"` // expression
Height string `mapstructure:"height"` // expression Height string `mapstructure:"height"` // expression
Fps string `mapstructure:"fps"` // expression Fps string `mapstructure:"fps"` // expression
Bitrate int `mapstructure:"bitrate"` // pipeline bitrate
GstPrefix string `mapstructure:"gst_prefix"` // pipeline prefix, starts with ! GstPrefix string `mapstructure:"gst_prefix"` // pipeline prefix, starts with !
GstEncoder string `mapstructure:"gst_encoder"` // gst encoder name GstEncoder string `mapstructure:"gst_encoder"` // gst encoder name
GstParams map[string]string `mapstructure:"gst_params"` // map of expressions GstParams map[string]string `mapstructure:"gst_params"` // map of expressions
@ -173,3 +177,41 @@ func (config *VideoConfig) GetPipeline(screen ScreenSize) (string, error) {
config.GstSuffix, config.GstSuffix,
}[:], " "), nil }[:], " "), nil
} }
func (config *VideoConfig) GetBitrateFn(getScreen func() *ScreenSize) func() (int, error) {
return func() (int, error) {
if config.Bitrate > 0 {
return config.Bitrate, nil
}
screen := getScreen()
if screen == nil {
return 0, fmt.Errorf("screen is nil")
}
values := map[string]any{
"width": screen.Width,
"height": screen.Height,
"fps": screen.Rate,
}
language := []gval.Language{
gval.Function("round", func(args ...any) (any, error) {
return (int)(math.Round(args[0].(float64))), nil
}),
}
// TODO: This is only for vp8.
expr, ok := config.GstParams["target-bitrate"]
if !ok {
return 0, fmt.Errorf("target-bitrate not found")
}
targetBitrate, err := gval.Evaluate(expr, values, language...)
if err != nil {
return 0, err
}
return targetBitrate.(int), nil
}
}

View File

@ -48,7 +48,7 @@ type SystemDisconnect struct {
type SignalProvide struct { type SignalProvide struct {
SDP string `json:"sdp"` SDP string `json:"sdp"`
ICEServers []types.ICEServer `json:"iceservers"` ICEServers []types.ICEServer `json:"iceservers"`
Video string `json:"video"` Video string `json:"video"` // TODO: Refactor.
} }
type SignalCandidate struct { type SignalCandidate struct {
@ -60,7 +60,8 @@ type SignalDescription struct {
} }
type SignalVideo struct { type SignalVideo struct {
Video string `json:"video"` Video string `json:"video"` // TODO: Refactor.
Bitrate int `json:"bitrate"`
} }
///////////////////////////// /////////////////////////////

View File

@ -7,7 +7,6 @@ import (
) )
var ( var (
ErrWebRTCVideoNotFound = errors.New("webrtc video not found")
ErrWebRTCDataChannelNotFound = errors.New("webrtc data channel not found") ErrWebRTCDataChannelNotFound = errors.New("webrtc data channel not found")
ErrWebRTCConnectionNotFound = errors.New("webrtc connection not found") ErrWebRTCConnectionNotFound = errors.New("webrtc connection not found")
) )
@ -25,7 +24,8 @@ type WebRTCPeer interface {
SetAnswer(sdp string) error SetAnswer(sdp string) error
SetCandidate(candidate webrtc.ICECandidateInit) error SetCandidate(candidate webrtc.ICECandidateInit) error
SetVideoID(videoID string) error SetVideoBitrate(bitrate int) error
GetVideoId() string
SetPaused(isPaused bool) error SetPaused(isPaused bool) error
SendCursorPosition(x, y int) error SendCursorPosition(x, y int) error
@ -40,6 +40,6 @@ type WebRTCManager interface {
ICEServers() []ICEServer ICEServers() []ICEServer
CreatePeer(session Session, videoID string) (*webrtc.SessionDescription, error) CreatePeer(session Session, bitrate int) (*webrtc.SessionDescription, error)
SetCursorPosition(x, y int) SetCursorPosition(x, y int)
} }