Capture bandwidth switch (#14)

* Handle bitrate change by finding the stream with closest bitrate as peer

* Convert video id into bitrate when creating peer or changing bitrate

* Try to fix prometheus panic

* Revert metrics label name change

* minor fixes.

* bitrate selector.

* skip if moving to the same stream.

* no closure for getting target bitrate.

* fix: high res switch to lo video, stream bitrate out of range

* revert dev config change.

* white space.

Co-authored-by: Aleksandar Sukovic <aleksandar.sukovic@gmail.com>
This commit is contained in:
Miroslav Šedivý 2022-10-25 20:25:00 +02:00 committed by GitHub
parent e0bee67e85
commit 6067367acd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 186 additions and 52 deletions

View File

@ -2,6 +2,8 @@ package capture
import (
"errors"
"fmt"
"math"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
@ -36,17 +38,17 @@ func (m *BucketsManagerCtx) shutdown() {
}
func (m *BucketsManagerCtx) destroyAll() {
for _, video := range m.streams {
if video.Started() {
video.destroyPipeline()
for _, stream := range m.streams {
if stream.Started() {
stream.destroyPipeline()
}
}
}
func (m *BucketsManagerCtx) recreateAll() error {
for _, video := range m.streams {
if video.Started() {
err := video.createPipeline()
for _, stream := range m.streams {
if stream.Started() {
err := stream.createPipeline()
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
return err
}
@ -65,22 +67,39 @@ func (m *BucketsManagerCtx) Codec() codec.RTPCodec {
}
func (m *BucketsManagerCtx) SetReceiver(receiver types.Receiver) error {
receiver.OnVideoIdChange(func(videoID string) error {
videoStream, ok := m.streams[videoID]
receiver.OnBitrateChange(func(bitrate int) error {
stream, ok := m.findNearestStream(bitrate)
if !ok {
return types.ErrWebRTCVideoNotFound
return fmt.Errorf("no stream found for bitrate %d", bitrate)
}
return receiver.SetStream(videoStream)
return receiver.SetStream(stream)
})
// TODO: Save receiver.
return nil
}
func (m *BucketsManagerCtx) findNearestStream(bitrate int) (ss *StreamSinkManagerCtx, ok bool) {
minDiff := math.MaxInt
for _, s := range m.streams {
streamBitrate, err := s.Bitrate()
if err != nil {
m.logger.Error().Err(err).Msgf("failed to get bitrate for stream %s", s.ID())
continue
}
diffAbs := int(math.Abs(float64(bitrate - streamBitrate)))
if diffAbs < minDiff {
minDiff, ss = diffAbs, s
}
}
ok = ss != nil
return
}
func (m *BucketsManagerCtx) RemoveReceiver(receiver types.Receiver) error {
// TODO: Unsubribe from OnVideoIdChange.
// TODO: Remove receiver.
receiver.OnBitrateChange(nil)
receiver.RemoveStream()
return nil
}

View File

@ -16,6 +16,7 @@ import (
type CaptureManagerCtx struct {
logger zerolog.Logger
desktop types.DesktopManager
config *config.Capture
// sinks
broadcast *BroacastManagerCtx
@ -66,13 +67,18 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt
Str("pipeline", pipeline).
Msg("syntax check for video stream pipeline passed")
getVideoBitrate := pipelineConf.GetBitrateFn(desktop.GetScreenSize)
if err != nil {
logger.Panic().Err(err).Msg("unable to get video bitrate")
}
// append to videos
videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id)
videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id, getVideoBitrate)
}
return &CaptureManagerCtx{
logger: logger,
desktop: desktop,
config: config,
// sinks
broadcast: broadcastNew(func(url string) (string, error) {
@ -132,7 +138,7 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt
"! %s "+
"! appsink name=appsink", config.AudioDevice, config.AudioCodec.Pipeline,
), nil
}, "audio"),
}, "audio", nil),
video: bucketsNew(config.VideoCodec, videos, config.VideoIDs),
// sources
@ -242,6 +248,15 @@ func (manager *CaptureManagerCtx) Shutdown() error {
return nil
}
func (manager *CaptureManagerCtx) GetBitrateFromVideoID(videoID string) (int, error) {
cfg, ok := manager.config.VideoPipelines[videoID]
if !ok {
return 0, fmt.Errorf("video config not found for %s", videoID)
}
return cfg.GetBitrateFn(manager.desktop.GetScreenSize)()
}
func (manager *CaptureManagerCtx) Broadcast() types.BroadcastManager {
return manager.broadcast
}

View File

@ -19,6 +19,9 @@ import (
var moveSinkListenerMu = sync.Mutex{}
type StreamSinkManagerCtx struct {
id string
getBitrate func() (int, error)
logger zerolog.Logger
mu sync.Mutex
wg sync.WaitGroup
@ -37,13 +40,16 @@ type StreamSinkManagerCtx struct {
pipelinesActive prometheus.Gauge
}
func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), video_id string) *StreamSinkManagerCtx {
func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), id string, getBitrate func() (int, error)) *StreamSinkManagerCtx {
logger := log.With().
Str("module", "capture").
Str("submodule", "stream-sink").
Str("video_id", video_id).Logger()
Str("id", id).Logger()
manager := &StreamSinkManagerCtx{
id: id,
getBitrate: getBitrate,
logger: logger,
codec: codec,
pipelineFn: pipelineFn,
@ -56,7 +62,7 @@ func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), vide
Subsystem: "capture",
Help: "Current number of listeners for a pipeline.",
ConstLabels: map[string]string{
"video_id": video_id,
"video_id": id,
"codec_name": codec.Name,
"codec_type": codec.Type.String(),
},
@ -68,7 +74,7 @@ func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), vide
Help: "Total number of created pipelines.",
ConstLabels: map[string]string{
"submodule": "streamsink",
"video_id": video_id,
"video_id": id,
"codec_name": codec.Name,
"codec_type": codec.Type.String(),
},
@ -80,7 +86,7 @@ func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), vide
Help: "Total number of active pipelines.",
ConstLabels: map[string]string{
"submodule": "streamsink",
"video_id": video_id,
"video_id": id,
"codec_name": codec.Name,
"codec_type": codec.Type.String(),
},
@ -103,6 +109,18 @@ func (manager *StreamSinkManagerCtx) shutdown() {
manager.wg.Wait()
}
func (manager *StreamSinkManagerCtx) ID() string {
return manager.id
}
func (manager *StreamSinkManagerCtx) Bitrate() (int, error) {
if manager.getBitrate == nil {
return 0, nil
}
// recalculate bitrate every time, take screen resolution (and fps) into account
return manager.getBitrate()
}
func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec {
return manager.codec
}

View File

@ -214,7 +214,7 @@ func (manager *WebRTCManagerCtx) newPeerConnection(codecs []codec.RTPCodec, logg
return api.NewPeerConnection(manager.webrtcConfiguration)
}
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID string) (*webrtc.SessionDescription, error) {
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int) (*webrtc.SessionDescription, error) {
id := atomic.AddInt32(&manager.peerId, 1)
manager.metrics.NewConnection(session)
@ -280,11 +280,12 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
return nil, err
}
// set default video id
err = videoTrack.SetVideoID(videoID)
if err != nil {
// set initial video bitrate
if err = videoTrack.SetBitrate(bitrate); err != nil {
return nil, err
}
videoID := videoTrack.stream.ID()
manager.metrics.SetVideoID(session, videoID)
// data channel
@ -298,14 +299,19 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
logger: logger,
connection: connection,
dataChannel: dataChannel,
changeVideo: func(videoID string) error {
if err := videoTrack.SetVideoID(videoID); err != nil {
changeVideo: func(bitrate int) error {
if err := videoTrack.SetBitrate(bitrate); err != nil {
return err
}
videoID := videoTrack.stream.ID()
manager.metrics.SetVideoID(session, videoID)
return nil
},
// TODO: Refactor.
videoId: func() string {
return videoTrack.stream.ID()
},
setPaused: func(isPaused bool) {
videoTrack.SetPaused(isPaused)
audioTrack.SetPaused(isPaused)
@ -418,7 +424,9 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
connection.Close()
case webrtc.PeerConnectionStateClosed:
session.SetWebRTCConnected(peer, false)
video.RemoveReceiver(videoTrack)
if err = video.RemoveReceiver(videoTrack); err != nil {
logger.Err(err).Msg("failed to remove video receiver")
}
audioTrack.RemoveStream()
}

View File

@ -17,7 +17,8 @@ type WebRTCPeerCtx struct {
logger zerolog.Logger
connection *webrtc.PeerConnection
dataChannel *webrtc.DataChannel
changeVideo func(videoID string) error
changeVideo func(bitrate int) error
videoId func() string
setPaused func(isPaused bool)
iceTrickle bool
}
@ -114,7 +115,7 @@ func (peer *WebRTCPeerCtx) SetCandidate(candidate webrtc.ICECandidateInit) error
return peer.connection.AddICECandidate(candidate)
}
func (peer *WebRTCPeerCtx) SetVideoID(videoID string) error {
func (peer *WebRTCPeerCtx) SetVideoBitrate(bitrate int) error {
peer.mu.Lock()
defer peer.mu.Unlock()
@ -122,8 +123,16 @@ func (peer *WebRTCPeerCtx) SetVideoID(videoID string) error {
return types.ErrWebRTCConnectionNotFound
}
peer.logger.Info().Str("video_id", videoID).Msg("change video id")
return peer.changeVideo(videoID)
peer.logger.Info().Int("bitrate", bitrate).Msg("change video bitrate")
return peer.changeVideo(bitrate)
}
// TODO: Refactor.
func (peer *WebRTCPeerCtx) GetVideoId() string {
peer.mu.Lock()
defer peer.mu.Unlock()
return peer.videoId()
}
func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error {

View File

@ -27,7 +27,7 @@ type Track struct {
onRtcp func(rtcp.Packet)
onRtcpMu sync.RWMutex
videoIdChange func(string) error
bitrateChange func(int) error
}
func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.PeerConnection) (*Track, error) {
@ -140,14 +140,14 @@ func (t *Track) OnRTCP(f func(rtcp.Packet)) {
t.onRtcp = f
}
func (t *Track) SetVideoID(videoID string) error {
if t.videoIdChange == nil {
return fmt.Errorf("video id change not supported")
func (t *Track) SetBitrate(bitrate int) error {
if t.bitrateChange == nil {
return fmt.Errorf("bitrate change not supported")
}
return t.videoIdChange(videoID)
return t.bitrateChange(bitrate)
}
func (t *Track) OnVideoIdChange(f func(string) error) {
t.videoIdChange = f
func (t *Track) OnBitrateChange(f func(int) error) {
t.bitrateChange = f
}

View File

@ -19,22 +19,35 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
payload.Video = videos[0]
}
offer, err := h.webrtc.CreatePeer(session, payload.Video)
var err error
if payload.Bitrate == 0 {
// get bitrate from video id
payload.Bitrate, err = h.capture.GetBitrateFromVideoID(payload.Video)
if err != nil {
return err
}
}
offer, err := h.webrtc.CreatePeer(session, payload.Bitrate)
if err != nil {
return err
}
if webrtcPeer := session.GetWebRTCPeer(); webrtcPeer != nil {
// set webrtc as paused if session has private mode enabled
if webrtcPeer := session.GetWebRTCPeer(); webrtcPeer != nil && session.PrivateModeEnabled() {
if session.PrivateModeEnabled() {
webrtcPeer.SetPaused(true)
}
payload.Video = webrtcPeer.GetVideoId()
}
session.Send(
event.SIGNAL_PROVIDE,
message.SignalProvide{
SDP: offer.SDP,
ICEServers: h.webrtc.ICEServers(),
Video: payload.Video,
Video: payload.Video, // TODO: Refactor.
})
return nil
@ -110,15 +123,24 @@ func (h *MessageHandlerCtx) signalVideo(session types.Session, payload *message.
return errors.New("webRTC peer does not exist")
}
err := peer.SetVideoID(payload.Video)
var err error
if payload.Bitrate == 0 {
// get bitrate from video id
payload.Bitrate, err = h.capture.GetBitrateFromVideoID(payload.Video)
if err != nil {
return err
}
}
if err = peer.SetVideoBitrate(payload.Bitrate); err != nil {
return err
}
session.Send(
event.SIGNAL_VIDEO,
message.SignalVideo{
Video: payload.Video,
Video: peer.GetVideoId(), // TODO: Refactor.
Bitrate: payload.Bitrate,
})
return nil

View File

@ -22,7 +22,7 @@ type Sample media.Sample
type Receiver interface {
SetStream(stream StreamSinkManager) error
RemoveStream()
OnVideoIdChange(f func(string) error)
OnBitrateChange(f func(int) error)
}
type BucketsManager interface {
@ -46,6 +46,7 @@ type ScreencastManager interface {
}
type StreamSinkManager interface {
ID() string
Codec() codec.RTPCodec
AddListener(listener *func(sample Sample)) error
@ -70,6 +71,8 @@ type CaptureManager interface {
Start()
Shutdown() error
GetBitrateFromVideoID(videoID string) (int, error)
Broadcast() BroadcastManager
Screencast() ScreencastManager
Audio() StreamSinkManager
@ -83,6 +86,7 @@ type VideoConfig struct {
Width string `mapstructure:"width"` // expression
Height string `mapstructure:"height"` // expression
Fps string `mapstructure:"fps"` // expression
Bitrate int `mapstructure:"bitrate"` // pipeline bitrate
GstPrefix string `mapstructure:"gst_prefix"` // pipeline prefix, starts with !
GstEncoder string `mapstructure:"gst_encoder"` // gst encoder name
GstParams map[string]string `mapstructure:"gst_params"` // map of expressions
@ -173,3 +177,41 @@ func (config *VideoConfig) GetPipeline(screen ScreenSize) (string, error) {
config.GstSuffix,
}[:], " "), nil
}
func (config *VideoConfig) GetBitrateFn(getScreen func() *ScreenSize) func() (int, error) {
return func() (int, error) {
if config.Bitrate > 0 {
return config.Bitrate, nil
}
screen := getScreen()
if screen == nil {
return 0, fmt.Errorf("screen is nil")
}
values := map[string]any{
"width": screen.Width,
"height": screen.Height,
"fps": screen.Rate,
}
language := []gval.Language{
gval.Function("round", func(args ...any) (any, error) {
return (int)(math.Round(args[0].(float64))), nil
}),
}
// TODO: This is only for vp8.
expr, ok := config.GstParams["target-bitrate"]
if !ok {
return 0, fmt.Errorf("target-bitrate not found")
}
targetBitrate, err := gval.Evaluate(expr, values, language...)
if err != nil {
return 0, err
}
return targetBitrate.(int), nil
}
}

View File

@ -48,7 +48,7 @@ type SystemDisconnect struct {
type SignalProvide struct {
SDP string `json:"sdp"`
ICEServers []types.ICEServer `json:"iceservers"`
Video string `json:"video"`
Video string `json:"video"` // TODO: Refactor.
}
type SignalCandidate struct {
@ -60,7 +60,8 @@ type SignalDescription struct {
}
type SignalVideo struct {
Video string `json:"video"`
Video string `json:"video"` // TODO: Refactor.
Bitrate int `json:"bitrate"`
}
/////////////////////////////

View File

@ -7,7 +7,6 @@ import (
)
var (
ErrWebRTCVideoNotFound = errors.New("webrtc video not found")
ErrWebRTCDataChannelNotFound = errors.New("webrtc data channel not found")
ErrWebRTCConnectionNotFound = errors.New("webrtc connection not found")
)
@ -25,7 +24,8 @@ type WebRTCPeer interface {
SetAnswer(sdp string) error
SetCandidate(candidate webrtc.ICECandidateInit) error
SetVideoID(videoID string) error
SetVideoBitrate(bitrate int) error
GetVideoId() string
SetPaused(isPaused bool) error
SendCursorPosition(x, y int) error
@ -40,6 +40,6 @@ type WebRTCManager interface {
ICEServers() []ICEServer
CreatePeer(session Session, videoID string) (*webrtc.SessionDescription, error)
CreatePeer(session Session, bitrate int) (*webrtc.SessionDescription, error)
SetCursorPosition(x, y int)
}