From 6067367acd0a6ed08e8d42e12fcda1467862eb7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= Date: Tue, 25 Oct 2022 20:25:00 +0200 Subject: [PATCH] Capture bandwidth switch (#14) * Handle bitrate change by finding the stream with closest bitrate as peer * Convert video id into bitrate when creating peer or changing bitrate * Try to fix prometheus panic * Revert metrics label name change * minor fixes. * bitrate selector. * skip if moving to the same stream. * no closure for getting target bitrate. * fix: high res switch to lo video, stream bitrate out of range * revert dev config change. * white space. Co-authored-by: Aleksandar Sukovic --- internal/capture/buckets.go | 45 ++++++++++++++++++++-------- internal/capture/manager.go | 19 ++++++++++-- internal/capture/streamsink.go | 28 +++++++++++++---- internal/webrtc/manager.go | 22 +++++++++----- internal/webrtc/peer.go | 17 ++++++++--- internal/webrtc/track.go | 14 ++++----- internal/websocket/handler/signal.go | 38 ++++++++++++++++++----- pkg/types/capture.go | 44 ++++++++++++++++++++++++++- pkg/types/message/messages.go | 5 ++-- pkg/types/webrtc.go | 6 ++-- 10 files changed, 186 insertions(+), 52 deletions(-) diff --git a/internal/capture/buckets.go b/internal/capture/buckets.go index c35986aa..04a87f09 100644 --- a/internal/capture/buckets.go +++ b/internal/capture/buckets.go @@ -2,6 +2,8 @@ package capture import ( "errors" + "fmt" + "math" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -36,17 +38,17 @@ func (m *BucketsManagerCtx) shutdown() { } func (m *BucketsManagerCtx) destroyAll() { - for _, video := range m.streams { - if video.Started() { - video.destroyPipeline() + for _, stream := range m.streams { + if stream.Started() { + stream.destroyPipeline() } } } func (m *BucketsManagerCtx) recreateAll() error { - for _, video := range m.streams { - if video.Started() { - err := video.createPipeline() + for _, stream := range m.streams { + if stream.Started() { + err := stream.createPipeline() if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) { return err } @@ -65,22 +67,39 @@ func (m *BucketsManagerCtx) Codec() codec.RTPCodec { } func (m *BucketsManagerCtx) SetReceiver(receiver types.Receiver) error { - receiver.OnVideoIdChange(func(videoID string) error { - videoStream, ok := m.streams[videoID] + receiver.OnBitrateChange(func(bitrate int) error { + stream, ok := m.findNearestStream(bitrate) if !ok { - return types.ErrWebRTCVideoNotFound + return fmt.Errorf("no stream found for bitrate %d", bitrate) } - return receiver.SetStream(videoStream) + return receiver.SetStream(stream) }) - // TODO: Save receiver. return nil } +func (m *BucketsManagerCtx) findNearestStream(bitrate int) (ss *StreamSinkManagerCtx, ok bool) { + minDiff := math.MaxInt + for _, s := range m.streams { + streamBitrate, err := s.Bitrate() + if err != nil { + m.logger.Error().Err(err).Msgf("failed to get bitrate for stream %s", s.ID()) + continue + } + + diffAbs := int(math.Abs(float64(bitrate - streamBitrate))) + + if diffAbs < minDiff { + minDiff, ss = diffAbs, s + } + } + ok = ss != nil + return +} + func (m *BucketsManagerCtx) RemoveReceiver(receiver types.Receiver) error { - // TODO: Unsubribe from OnVideoIdChange. - // TODO: Remove receiver. + receiver.OnBitrateChange(nil) receiver.RemoveStream() return nil } diff --git a/internal/capture/manager.go b/internal/capture/manager.go index c15b22fc..2f96c8e1 100644 --- a/internal/capture/manager.go +++ b/internal/capture/manager.go @@ -16,6 +16,7 @@ import ( type CaptureManagerCtx struct { logger zerolog.Logger desktop types.DesktopManager + config *config.Capture // sinks broadcast *BroacastManagerCtx @@ -66,13 +67,18 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt Str("pipeline", pipeline). Msg("syntax check for video stream pipeline passed") + getVideoBitrate := pipelineConf.GetBitrateFn(desktop.GetScreenSize) + if err != nil { + logger.Panic().Err(err).Msg("unable to get video bitrate") + } // append to videos - videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id) + videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id, getVideoBitrate) } return &CaptureManagerCtx{ logger: logger, desktop: desktop, + config: config, // sinks broadcast: broadcastNew(func(url string) (string, error) { @@ -132,7 +138,7 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt "! %s "+ "! appsink name=appsink", config.AudioDevice, config.AudioCodec.Pipeline, ), nil - }, "audio"), + }, "audio", nil), video: bucketsNew(config.VideoCodec, videos, config.VideoIDs), // sources @@ -242,6 +248,15 @@ func (manager *CaptureManagerCtx) Shutdown() error { return nil } +func (manager *CaptureManagerCtx) GetBitrateFromVideoID(videoID string) (int, error) { + cfg, ok := manager.config.VideoPipelines[videoID] + if !ok { + return 0, fmt.Errorf("video config not found for %s", videoID) + } + + return cfg.GetBitrateFn(manager.desktop.GetScreenSize)() +} + func (manager *CaptureManagerCtx) Broadcast() types.BroadcastManager { return manager.broadcast } diff --git a/internal/capture/streamsink.go b/internal/capture/streamsink.go index 25df5259..cb6225af 100644 --- a/internal/capture/streamsink.go +++ b/internal/capture/streamsink.go @@ -19,6 +19,9 @@ import ( var moveSinkListenerMu = sync.Mutex{} type StreamSinkManagerCtx struct { + id string + getBitrate func() (int, error) + logger zerolog.Logger mu sync.Mutex wg sync.WaitGroup @@ -37,13 +40,16 @@ type StreamSinkManagerCtx struct { pipelinesActive prometheus.Gauge } -func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), video_id string) *StreamSinkManagerCtx { +func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), id string, getBitrate func() (int, error)) *StreamSinkManagerCtx { logger := log.With(). Str("module", "capture"). Str("submodule", "stream-sink"). - Str("video_id", video_id).Logger() + Str("id", id).Logger() manager := &StreamSinkManagerCtx{ + id: id, + getBitrate: getBitrate, + logger: logger, codec: codec, pipelineFn: pipelineFn, @@ -56,7 +62,7 @@ func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), vide Subsystem: "capture", Help: "Current number of listeners for a pipeline.", ConstLabels: map[string]string{ - "video_id": video_id, + "video_id": id, "codec_name": codec.Name, "codec_type": codec.Type.String(), }, @@ -68,7 +74,7 @@ func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), vide Help: "Total number of created pipelines.", ConstLabels: map[string]string{ "submodule": "streamsink", - "video_id": video_id, + "video_id": id, "codec_name": codec.Name, "codec_type": codec.Type.String(), }, @@ -80,7 +86,7 @@ func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), vide Help: "Total number of active pipelines.", ConstLabels: map[string]string{ "submodule": "streamsink", - "video_id": video_id, + "video_id": id, "codec_name": codec.Name, "codec_type": codec.Type.String(), }, @@ -103,6 +109,18 @@ func (manager *StreamSinkManagerCtx) shutdown() { manager.wg.Wait() } +func (manager *StreamSinkManagerCtx) ID() string { + return manager.id +} + +func (manager *StreamSinkManagerCtx) Bitrate() (int, error) { + if manager.getBitrate == nil { + return 0, nil + } + // recalculate bitrate every time, take screen resolution (and fps) into account + return manager.getBitrate() +} + func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec { return manager.codec } diff --git a/internal/webrtc/manager.go b/internal/webrtc/manager.go index 48f68a92..33a6f3df 100644 --- a/internal/webrtc/manager.go +++ b/internal/webrtc/manager.go @@ -214,7 +214,7 @@ func (manager *WebRTCManagerCtx) newPeerConnection(codecs []codec.RTPCodec, logg return api.NewPeerConnection(manager.webrtcConfiguration) } -func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID string) (*webrtc.SessionDescription, error) { +func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int) (*webrtc.SessionDescription, error) { id := atomic.AddInt32(&manager.peerId, 1) manager.metrics.NewConnection(session) @@ -280,11 +280,12 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin return nil, err } - // set default video id - err = videoTrack.SetVideoID(videoID) - if err != nil { + // set initial video bitrate + if err = videoTrack.SetBitrate(bitrate); err != nil { return nil, err } + + videoID := videoTrack.stream.ID() manager.metrics.SetVideoID(session, videoID) // data channel @@ -298,14 +299,19 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin logger: logger, connection: connection, dataChannel: dataChannel, - changeVideo: func(videoID string) error { - if err := videoTrack.SetVideoID(videoID); err != nil { + changeVideo: func(bitrate int) error { + if err := videoTrack.SetBitrate(bitrate); err != nil { return err } + videoID := videoTrack.stream.ID() manager.metrics.SetVideoID(session, videoID) return nil }, + // TODO: Refactor. + videoId: func() string { + return videoTrack.stream.ID() + }, setPaused: func(isPaused bool) { videoTrack.SetPaused(isPaused) audioTrack.SetPaused(isPaused) @@ -418,7 +424,9 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin connection.Close() case webrtc.PeerConnectionStateClosed: session.SetWebRTCConnected(peer, false) - video.RemoveReceiver(videoTrack) + if err = video.RemoveReceiver(videoTrack); err != nil { + logger.Err(err).Msg("failed to remove video receiver") + } audioTrack.RemoveStream() } diff --git a/internal/webrtc/peer.go b/internal/webrtc/peer.go index e9be20f1..bee30a64 100644 --- a/internal/webrtc/peer.go +++ b/internal/webrtc/peer.go @@ -17,7 +17,8 @@ type WebRTCPeerCtx struct { logger zerolog.Logger connection *webrtc.PeerConnection dataChannel *webrtc.DataChannel - changeVideo func(videoID string) error + changeVideo func(bitrate int) error + videoId func() string setPaused func(isPaused bool) iceTrickle bool } @@ -114,7 +115,7 @@ func (peer *WebRTCPeerCtx) SetCandidate(candidate webrtc.ICECandidateInit) error return peer.connection.AddICECandidate(candidate) } -func (peer *WebRTCPeerCtx) SetVideoID(videoID string) error { +func (peer *WebRTCPeerCtx) SetVideoBitrate(bitrate int) error { peer.mu.Lock() defer peer.mu.Unlock() @@ -122,8 +123,16 @@ func (peer *WebRTCPeerCtx) SetVideoID(videoID string) error { return types.ErrWebRTCConnectionNotFound } - peer.logger.Info().Str("video_id", videoID).Msg("change video id") - return peer.changeVideo(videoID) + peer.logger.Info().Int("bitrate", bitrate).Msg("change video bitrate") + return peer.changeVideo(bitrate) +} + +// TODO: Refactor. +func (peer *WebRTCPeerCtx) GetVideoId() string { + peer.mu.Lock() + defer peer.mu.Unlock() + + return peer.videoId() } func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error { diff --git a/internal/webrtc/track.go b/internal/webrtc/track.go index b62c4800..516af9bf 100644 --- a/internal/webrtc/track.go +++ b/internal/webrtc/track.go @@ -27,7 +27,7 @@ type Track struct { onRtcp func(rtcp.Packet) onRtcpMu sync.RWMutex - videoIdChange func(string) error + bitrateChange func(int) error } func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.PeerConnection) (*Track, error) { @@ -140,14 +140,14 @@ func (t *Track) OnRTCP(f func(rtcp.Packet)) { t.onRtcp = f } -func (t *Track) SetVideoID(videoID string) error { - if t.videoIdChange == nil { - return fmt.Errorf("video id change not supported") +func (t *Track) SetBitrate(bitrate int) error { + if t.bitrateChange == nil { + return fmt.Errorf("bitrate change not supported") } - return t.videoIdChange(videoID) + return t.bitrateChange(bitrate) } -func (t *Track) OnVideoIdChange(f func(string) error) { - t.videoIdChange = f +func (t *Track) OnBitrateChange(f func(int) error) { + t.bitrateChange = f } diff --git a/internal/websocket/handler/signal.go b/internal/websocket/handler/signal.go index 5fa65597..9da371e8 100644 --- a/internal/websocket/handler/signal.go +++ b/internal/websocket/handler/signal.go @@ -19,14 +19,27 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag payload.Video = videos[0] } - offer, err := h.webrtc.CreatePeer(session, payload.Video) + var err error + if payload.Bitrate == 0 { + // get bitrate from video id + payload.Bitrate, err = h.capture.GetBitrateFromVideoID(payload.Video) + if err != nil { + return err + } + } + + offer, err := h.webrtc.CreatePeer(session, payload.Bitrate) if err != nil { return err } - // set webrtc as paused if session has private mode enabled - if webrtcPeer := session.GetWebRTCPeer(); webrtcPeer != nil && session.PrivateModeEnabled() { - webrtcPeer.SetPaused(true) + if webrtcPeer := session.GetWebRTCPeer(); webrtcPeer != nil { + // set webrtc as paused if session has private mode enabled + if session.PrivateModeEnabled() { + webrtcPeer.SetPaused(true) + } + + payload.Video = webrtcPeer.GetVideoId() } session.Send( @@ -34,7 +47,7 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag message.SignalProvide{ SDP: offer.SDP, ICEServers: h.webrtc.ICEServers(), - Video: payload.Video, + Video: payload.Video, // TODO: Refactor. }) return nil @@ -110,15 +123,24 @@ func (h *MessageHandlerCtx) signalVideo(session types.Session, payload *message. return errors.New("webRTC peer does not exist") } - err := peer.SetVideoID(payload.Video) - if err != nil { + var err error + if payload.Bitrate == 0 { + // get bitrate from video id + payload.Bitrate, err = h.capture.GetBitrateFromVideoID(payload.Video) + if err != nil { + return err + } + } + + if err = peer.SetVideoBitrate(payload.Bitrate); err != nil { return err } session.Send( event.SIGNAL_VIDEO, message.SignalVideo{ - Video: payload.Video, + Video: peer.GetVideoId(), // TODO: Refactor. + Bitrate: payload.Bitrate, }) return nil diff --git a/pkg/types/capture.go b/pkg/types/capture.go index 857132f2..a3a490ee 100644 --- a/pkg/types/capture.go +++ b/pkg/types/capture.go @@ -22,7 +22,7 @@ type Sample media.Sample type Receiver interface { SetStream(stream StreamSinkManager) error RemoveStream() - OnVideoIdChange(f func(string) error) + OnBitrateChange(f func(int) error) } type BucketsManager interface { @@ -46,6 +46,7 @@ type ScreencastManager interface { } type StreamSinkManager interface { + ID() string Codec() codec.RTPCodec AddListener(listener *func(sample Sample)) error @@ -70,6 +71,8 @@ type CaptureManager interface { Start() Shutdown() error + GetBitrateFromVideoID(videoID string) (int, error) + Broadcast() BroadcastManager Screencast() ScreencastManager Audio() StreamSinkManager @@ -83,6 +86,7 @@ type VideoConfig struct { Width string `mapstructure:"width"` // expression Height string `mapstructure:"height"` // expression Fps string `mapstructure:"fps"` // expression + Bitrate int `mapstructure:"bitrate"` // pipeline bitrate GstPrefix string `mapstructure:"gst_prefix"` // pipeline prefix, starts with ! GstEncoder string `mapstructure:"gst_encoder"` // gst encoder name GstParams map[string]string `mapstructure:"gst_params"` // map of expressions @@ -173,3 +177,41 @@ func (config *VideoConfig) GetPipeline(screen ScreenSize) (string, error) { config.GstSuffix, }[:], " "), nil } + +func (config *VideoConfig) GetBitrateFn(getScreen func() *ScreenSize) func() (int, error) { + return func() (int, error) { + if config.Bitrate > 0 { + return config.Bitrate, nil + } + + screen := getScreen() + if screen == nil { + return 0, fmt.Errorf("screen is nil") + } + + values := map[string]any{ + "width": screen.Width, + "height": screen.Height, + "fps": screen.Rate, + } + + language := []gval.Language{ + gval.Function("round", func(args ...any) (any, error) { + return (int)(math.Round(args[0].(float64))), nil + }), + } + + // TODO: This is only for vp8. + expr, ok := config.GstParams["target-bitrate"] + if !ok { + return 0, fmt.Errorf("target-bitrate not found") + } + + targetBitrate, err := gval.Evaluate(expr, values, language...) + if err != nil { + return 0, err + } + + return targetBitrate.(int), nil + } +} diff --git a/pkg/types/message/messages.go b/pkg/types/message/messages.go index cc4172f9..c263a006 100644 --- a/pkg/types/message/messages.go +++ b/pkg/types/message/messages.go @@ -48,7 +48,7 @@ type SystemDisconnect struct { type SignalProvide struct { SDP string `json:"sdp"` ICEServers []types.ICEServer `json:"iceservers"` - Video string `json:"video"` + Video string `json:"video"` // TODO: Refactor. } type SignalCandidate struct { @@ -60,7 +60,8 @@ type SignalDescription struct { } type SignalVideo struct { - Video string `json:"video"` + Video string `json:"video"` // TODO: Refactor. + Bitrate int `json:"bitrate"` } ///////////////////////////// diff --git a/pkg/types/webrtc.go b/pkg/types/webrtc.go index 6063f203..c67f4857 100644 --- a/pkg/types/webrtc.go +++ b/pkg/types/webrtc.go @@ -7,7 +7,6 @@ import ( ) var ( - ErrWebRTCVideoNotFound = errors.New("webrtc video not found") ErrWebRTCDataChannelNotFound = errors.New("webrtc data channel not found") ErrWebRTCConnectionNotFound = errors.New("webrtc connection not found") ) @@ -25,7 +24,8 @@ type WebRTCPeer interface { SetAnswer(sdp string) error SetCandidate(candidate webrtc.ICECandidateInit) error - SetVideoID(videoID string) error + SetVideoBitrate(bitrate int) error + GetVideoId() string SetPaused(isPaused bool) error SendCursorPosition(x, y int) error @@ -40,6 +40,6 @@ type WebRTCManager interface { ICEServers() []ICEServer - CreatePeer(session Session, videoID string) (*webrtc.SessionDescription, error) + CreatePeer(session Session, bitrate int) (*webrtc.SessionDescription, error) SetCursorPosition(x, y int) }