mirror of
https://github.com/m1k1o/neko.git
synced 2024-07-24 14:40:50 +12:00
Capture bandwidth switch (#14)
* Handle bitrate change by finding the stream with closest bitrate as peer * Convert video id into bitrate when creating peer or changing bitrate * Try to fix prometheus panic * Revert metrics label name change * minor fixes. * bitrate selector. * skip if moving to the same stream. * no closure for getting target bitrate. * fix: high res switch to lo video, stream bitrate out of range * revert dev config change. * white space. Co-authored-by: Aleksandar Sukovic <aleksandar.sukovic@gmail.com>
This commit is contained in:
parent
e0bee67e85
commit
6067367acd
@ -2,6 +2,8 @@ package capture
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
@ -36,17 +38,17 @@ func (m *BucketsManagerCtx) shutdown() {
|
||||
}
|
||||
|
||||
func (m *BucketsManagerCtx) destroyAll() {
|
||||
for _, video := range m.streams {
|
||||
if video.Started() {
|
||||
video.destroyPipeline()
|
||||
for _, stream := range m.streams {
|
||||
if stream.Started() {
|
||||
stream.destroyPipeline()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *BucketsManagerCtx) recreateAll() error {
|
||||
for _, video := range m.streams {
|
||||
if video.Started() {
|
||||
err := video.createPipeline()
|
||||
for _, stream := range m.streams {
|
||||
if stream.Started() {
|
||||
err := stream.createPipeline()
|
||||
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
|
||||
return err
|
||||
}
|
||||
@ -65,22 +67,39 @@ func (m *BucketsManagerCtx) Codec() codec.RTPCodec {
|
||||
}
|
||||
|
||||
func (m *BucketsManagerCtx) SetReceiver(receiver types.Receiver) error {
|
||||
receiver.OnVideoIdChange(func(videoID string) error {
|
||||
videoStream, ok := m.streams[videoID]
|
||||
receiver.OnBitrateChange(func(bitrate int) error {
|
||||
stream, ok := m.findNearestStream(bitrate)
|
||||
if !ok {
|
||||
return types.ErrWebRTCVideoNotFound
|
||||
return fmt.Errorf("no stream found for bitrate %d", bitrate)
|
||||
}
|
||||
|
||||
return receiver.SetStream(videoStream)
|
||||
return receiver.SetStream(stream)
|
||||
})
|
||||
|
||||
// TODO: Save receiver.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *BucketsManagerCtx) findNearestStream(bitrate int) (ss *StreamSinkManagerCtx, ok bool) {
|
||||
minDiff := math.MaxInt
|
||||
for _, s := range m.streams {
|
||||
streamBitrate, err := s.Bitrate()
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Msgf("failed to get bitrate for stream %s", s.ID())
|
||||
continue
|
||||
}
|
||||
|
||||
diffAbs := int(math.Abs(float64(bitrate - streamBitrate)))
|
||||
|
||||
if diffAbs < minDiff {
|
||||
minDiff, ss = diffAbs, s
|
||||
}
|
||||
}
|
||||
ok = ss != nil
|
||||
return
|
||||
}
|
||||
|
||||
func (m *BucketsManagerCtx) RemoveReceiver(receiver types.Receiver) error {
|
||||
// TODO: Unsubribe from OnVideoIdChange.
|
||||
// TODO: Remove receiver.
|
||||
receiver.OnBitrateChange(nil)
|
||||
receiver.RemoveStream()
|
||||
return nil
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
type CaptureManagerCtx struct {
|
||||
logger zerolog.Logger
|
||||
desktop types.DesktopManager
|
||||
config *config.Capture
|
||||
|
||||
// sinks
|
||||
broadcast *BroacastManagerCtx
|
||||
@ -66,13 +67,18 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt
|
||||
Str("pipeline", pipeline).
|
||||
Msg("syntax check for video stream pipeline passed")
|
||||
|
||||
getVideoBitrate := pipelineConf.GetBitrateFn(desktop.GetScreenSize)
|
||||
if err != nil {
|
||||
logger.Panic().Err(err).Msg("unable to get video bitrate")
|
||||
}
|
||||
// append to videos
|
||||
videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id)
|
||||
videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id, getVideoBitrate)
|
||||
}
|
||||
|
||||
return &CaptureManagerCtx{
|
||||
logger: logger,
|
||||
desktop: desktop,
|
||||
config: config,
|
||||
|
||||
// sinks
|
||||
broadcast: broadcastNew(func(url string) (string, error) {
|
||||
@ -132,7 +138,7 @@ func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCt
|
||||
"! %s "+
|
||||
"! appsink name=appsink", config.AudioDevice, config.AudioCodec.Pipeline,
|
||||
), nil
|
||||
}, "audio"),
|
||||
}, "audio", nil),
|
||||
video: bucketsNew(config.VideoCodec, videos, config.VideoIDs),
|
||||
|
||||
// sources
|
||||
@ -242,6 +248,15 @@ func (manager *CaptureManagerCtx) Shutdown() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *CaptureManagerCtx) GetBitrateFromVideoID(videoID string) (int, error) {
|
||||
cfg, ok := manager.config.VideoPipelines[videoID]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("video config not found for %s", videoID)
|
||||
}
|
||||
|
||||
return cfg.GetBitrateFn(manager.desktop.GetScreenSize)()
|
||||
}
|
||||
|
||||
func (manager *CaptureManagerCtx) Broadcast() types.BroadcastManager {
|
||||
return manager.broadcast
|
||||
}
|
||||
|
@ -19,6 +19,9 @@ import (
|
||||
var moveSinkListenerMu = sync.Mutex{}
|
||||
|
||||
type StreamSinkManagerCtx struct {
|
||||
id string
|
||||
getBitrate func() (int, error)
|
||||
|
||||
logger zerolog.Logger
|
||||
mu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
@ -37,13 +40,16 @@ type StreamSinkManagerCtx struct {
|
||||
pipelinesActive prometheus.Gauge
|
||||
}
|
||||
|
||||
func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), video_id string) *StreamSinkManagerCtx {
|
||||
func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), id string, getBitrate func() (int, error)) *StreamSinkManagerCtx {
|
||||
logger := log.With().
|
||||
Str("module", "capture").
|
||||
Str("submodule", "stream-sink").
|
||||
Str("video_id", video_id).Logger()
|
||||
Str("id", id).Logger()
|
||||
|
||||
manager := &StreamSinkManagerCtx{
|
||||
id: id,
|
||||
getBitrate: getBitrate,
|
||||
|
||||
logger: logger,
|
||||
codec: codec,
|
||||
pipelineFn: pipelineFn,
|
||||
@ -56,7 +62,7 @@ func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), vide
|
||||
Subsystem: "capture",
|
||||
Help: "Current number of listeners for a pipeline.",
|
||||
ConstLabels: map[string]string{
|
||||
"video_id": video_id,
|
||||
"video_id": id,
|
||||
"codec_name": codec.Name,
|
||||
"codec_type": codec.Type.String(),
|
||||
},
|
||||
@ -68,7 +74,7 @@ func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), vide
|
||||
Help: "Total number of created pipelines.",
|
||||
ConstLabels: map[string]string{
|
||||
"submodule": "streamsink",
|
||||
"video_id": video_id,
|
||||
"video_id": id,
|
||||
"codec_name": codec.Name,
|
||||
"codec_type": codec.Type.String(),
|
||||
},
|
||||
@ -80,7 +86,7 @@ func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), vide
|
||||
Help: "Total number of active pipelines.",
|
||||
ConstLabels: map[string]string{
|
||||
"submodule": "streamsink",
|
||||
"video_id": video_id,
|
||||
"video_id": id,
|
||||
"codec_name": codec.Name,
|
||||
"codec_type": codec.Type.String(),
|
||||
},
|
||||
@ -103,6 +109,18 @@ func (manager *StreamSinkManagerCtx) shutdown() {
|
||||
manager.wg.Wait()
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) ID() string {
|
||||
return manager.id
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) Bitrate() (int, error) {
|
||||
if manager.getBitrate == nil {
|
||||
return 0, nil
|
||||
}
|
||||
// recalculate bitrate every time, take screen resolution (and fps) into account
|
||||
return manager.getBitrate()
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec {
|
||||
return manager.codec
|
||||
}
|
||||
|
@ -214,7 +214,7 @@ func (manager *WebRTCManagerCtx) newPeerConnection(codecs []codec.RTPCodec, logg
|
||||
return api.NewPeerConnection(manager.webrtcConfiguration)
|
||||
}
|
||||
|
||||
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID string) (*webrtc.SessionDescription, error) {
|
||||
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int) (*webrtc.SessionDescription, error) {
|
||||
id := atomic.AddInt32(&manager.peerId, 1)
|
||||
manager.metrics.NewConnection(session)
|
||||
|
||||
@ -280,11 +280,12 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// set default video id
|
||||
err = videoTrack.SetVideoID(videoID)
|
||||
if err != nil {
|
||||
// set initial video bitrate
|
||||
if err = videoTrack.SetBitrate(bitrate); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
videoID := videoTrack.stream.ID()
|
||||
manager.metrics.SetVideoID(session, videoID)
|
||||
|
||||
// data channel
|
||||
@ -298,14 +299,19 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
|
||||
logger: logger,
|
||||
connection: connection,
|
||||
dataChannel: dataChannel,
|
||||
changeVideo: func(videoID string) error {
|
||||
if err := videoTrack.SetVideoID(videoID); err != nil {
|
||||
changeVideo: func(bitrate int) error {
|
||||
if err := videoTrack.SetBitrate(bitrate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
videoID := videoTrack.stream.ID()
|
||||
manager.metrics.SetVideoID(session, videoID)
|
||||
return nil
|
||||
},
|
||||
// TODO: Refactor.
|
||||
videoId: func() string {
|
||||
return videoTrack.stream.ID()
|
||||
},
|
||||
setPaused: func(isPaused bool) {
|
||||
videoTrack.SetPaused(isPaused)
|
||||
audioTrack.SetPaused(isPaused)
|
||||
@ -418,7 +424,9 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
|
||||
connection.Close()
|
||||
case webrtc.PeerConnectionStateClosed:
|
||||
session.SetWebRTCConnected(peer, false)
|
||||
video.RemoveReceiver(videoTrack)
|
||||
if err = video.RemoveReceiver(videoTrack); err != nil {
|
||||
logger.Err(err).Msg("failed to remove video receiver")
|
||||
}
|
||||
audioTrack.RemoveStream()
|
||||
}
|
||||
|
||||
|
@ -17,7 +17,8 @@ type WebRTCPeerCtx struct {
|
||||
logger zerolog.Logger
|
||||
connection *webrtc.PeerConnection
|
||||
dataChannel *webrtc.DataChannel
|
||||
changeVideo func(videoID string) error
|
||||
changeVideo func(bitrate int) error
|
||||
videoId func() string
|
||||
setPaused func(isPaused bool)
|
||||
iceTrickle bool
|
||||
}
|
||||
@ -114,7 +115,7 @@ func (peer *WebRTCPeerCtx) SetCandidate(candidate webrtc.ICECandidateInit) error
|
||||
return peer.connection.AddICECandidate(candidate)
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) SetVideoID(videoID string) error {
|
||||
func (peer *WebRTCPeerCtx) SetVideoBitrate(bitrate int) error {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
@ -122,8 +123,16 @@ func (peer *WebRTCPeerCtx) SetVideoID(videoID string) error {
|
||||
return types.ErrWebRTCConnectionNotFound
|
||||
}
|
||||
|
||||
peer.logger.Info().Str("video_id", videoID).Msg("change video id")
|
||||
return peer.changeVideo(videoID)
|
||||
peer.logger.Info().Int("bitrate", bitrate).Msg("change video bitrate")
|
||||
return peer.changeVideo(bitrate)
|
||||
}
|
||||
|
||||
// TODO: Refactor.
|
||||
func (peer *WebRTCPeerCtx) GetVideoId() string {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
return peer.videoId()
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error {
|
||||
|
@ -27,7 +27,7 @@ type Track struct {
|
||||
onRtcp func(rtcp.Packet)
|
||||
onRtcpMu sync.RWMutex
|
||||
|
||||
videoIdChange func(string) error
|
||||
bitrateChange func(int) error
|
||||
}
|
||||
|
||||
func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.PeerConnection) (*Track, error) {
|
||||
@ -140,14 +140,14 @@ func (t *Track) OnRTCP(f func(rtcp.Packet)) {
|
||||
t.onRtcp = f
|
||||
}
|
||||
|
||||
func (t *Track) SetVideoID(videoID string) error {
|
||||
if t.videoIdChange == nil {
|
||||
return fmt.Errorf("video id change not supported")
|
||||
func (t *Track) SetBitrate(bitrate int) error {
|
||||
if t.bitrateChange == nil {
|
||||
return fmt.Errorf("bitrate change not supported")
|
||||
}
|
||||
|
||||
return t.videoIdChange(videoID)
|
||||
return t.bitrateChange(bitrate)
|
||||
}
|
||||
|
||||
func (t *Track) OnVideoIdChange(f func(string) error) {
|
||||
t.videoIdChange = f
|
||||
func (t *Track) OnBitrateChange(f func(int) error) {
|
||||
t.bitrateChange = f
|
||||
}
|
||||
|
@ -19,14 +19,27 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
|
||||
payload.Video = videos[0]
|
||||
}
|
||||
|
||||
offer, err := h.webrtc.CreatePeer(session, payload.Video)
|
||||
var err error
|
||||
if payload.Bitrate == 0 {
|
||||
// get bitrate from video id
|
||||
payload.Bitrate, err = h.capture.GetBitrateFromVideoID(payload.Video)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
offer, err := h.webrtc.CreatePeer(session, payload.Bitrate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set webrtc as paused if session has private mode enabled
|
||||
if webrtcPeer := session.GetWebRTCPeer(); webrtcPeer != nil && session.PrivateModeEnabled() {
|
||||
webrtcPeer.SetPaused(true)
|
||||
if webrtcPeer := session.GetWebRTCPeer(); webrtcPeer != nil {
|
||||
// set webrtc as paused if session has private mode enabled
|
||||
if session.PrivateModeEnabled() {
|
||||
webrtcPeer.SetPaused(true)
|
||||
}
|
||||
|
||||
payload.Video = webrtcPeer.GetVideoId()
|
||||
}
|
||||
|
||||
session.Send(
|
||||
@ -34,7 +47,7 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
|
||||
message.SignalProvide{
|
||||
SDP: offer.SDP,
|
||||
ICEServers: h.webrtc.ICEServers(),
|
||||
Video: payload.Video,
|
||||
Video: payload.Video, // TODO: Refactor.
|
||||
})
|
||||
|
||||
return nil
|
||||
@ -110,15 +123,24 @@ func (h *MessageHandlerCtx) signalVideo(session types.Session, payload *message.
|
||||
return errors.New("webRTC peer does not exist")
|
||||
}
|
||||
|
||||
err := peer.SetVideoID(payload.Video)
|
||||
if err != nil {
|
||||
var err error
|
||||
if payload.Bitrate == 0 {
|
||||
// get bitrate from video id
|
||||
payload.Bitrate, err = h.capture.GetBitrateFromVideoID(payload.Video)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err = peer.SetVideoBitrate(payload.Bitrate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session.Send(
|
||||
event.SIGNAL_VIDEO,
|
||||
message.SignalVideo{
|
||||
Video: payload.Video,
|
||||
Video: peer.GetVideoId(), // TODO: Refactor.
|
||||
Bitrate: payload.Bitrate,
|
||||
})
|
||||
|
||||
return nil
|
||||
|
@ -22,7 +22,7 @@ type Sample media.Sample
|
||||
type Receiver interface {
|
||||
SetStream(stream StreamSinkManager) error
|
||||
RemoveStream()
|
||||
OnVideoIdChange(f func(string) error)
|
||||
OnBitrateChange(f func(int) error)
|
||||
}
|
||||
|
||||
type BucketsManager interface {
|
||||
@ -46,6 +46,7 @@ type ScreencastManager interface {
|
||||
}
|
||||
|
||||
type StreamSinkManager interface {
|
||||
ID() string
|
||||
Codec() codec.RTPCodec
|
||||
|
||||
AddListener(listener *func(sample Sample)) error
|
||||
@ -70,6 +71,8 @@ type CaptureManager interface {
|
||||
Start()
|
||||
Shutdown() error
|
||||
|
||||
GetBitrateFromVideoID(videoID string) (int, error)
|
||||
|
||||
Broadcast() BroadcastManager
|
||||
Screencast() ScreencastManager
|
||||
Audio() StreamSinkManager
|
||||
@ -83,6 +86,7 @@ type VideoConfig struct {
|
||||
Width string `mapstructure:"width"` // expression
|
||||
Height string `mapstructure:"height"` // expression
|
||||
Fps string `mapstructure:"fps"` // expression
|
||||
Bitrate int `mapstructure:"bitrate"` // pipeline bitrate
|
||||
GstPrefix string `mapstructure:"gst_prefix"` // pipeline prefix, starts with !
|
||||
GstEncoder string `mapstructure:"gst_encoder"` // gst encoder name
|
||||
GstParams map[string]string `mapstructure:"gst_params"` // map of expressions
|
||||
@ -173,3 +177,41 @@ func (config *VideoConfig) GetPipeline(screen ScreenSize) (string, error) {
|
||||
config.GstSuffix,
|
||||
}[:], " "), nil
|
||||
}
|
||||
|
||||
func (config *VideoConfig) GetBitrateFn(getScreen func() *ScreenSize) func() (int, error) {
|
||||
return func() (int, error) {
|
||||
if config.Bitrate > 0 {
|
||||
return config.Bitrate, nil
|
||||
}
|
||||
|
||||
screen := getScreen()
|
||||
if screen == nil {
|
||||
return 0, fmt.Errorf("screen is nil")
|
||||
}
|
||||
|
||||
values := map[string]any{
|
||||
"width": screen.Width,
|
||||
"height": screen.Height,
|
||||
"fps": screen.Rate,
|
||||
}
|
||||
|
||||
language := []gval.Language{
|
||||
gval.Function("round", func(args ...any) (any, error) {
|
||||
return (int)(math.Round(args[0].(float64))), nil
|
||||
}),
|
||||
}
|
||||
|
||||
// TODO: This is only for vp8.
|
||||
expr, ok := config.GstParams["target-bitrate"]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("target-bitrate not found")
|
||||
}
|
||||
|
||||
targetBitrate, err := gval.Evaluate(expr, values, language...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return targetBitrate.(int), nil
|
||||
}
|
||||
}
|
||||
|
@ -48,7 +48,7 @@ type SystemDisconnect struct {
|
||||
type SignalProvide struct {
|
||||
SDP string `json:"sdp"`
|
||||
ICEServers []types.ICEServer `json:"iceservers"`
|
||||
Video string `json:"video"`
|
||||
Video string `json:"video"` // TODO: Refactor.
|
||||
}
|
||||
|
||||
type SignalCandidate struct {
|
||||
@ -60,7 +60,8 @@ type SignalDescription struct {
|
||||
}
|
||||
|
||||
type SignalVideo struct {
|
||||
Video string `json:"video"`
|
||||
Video string `json:"video"` // TODO: Refactor.
|
||||
Bitrate int `json:"bitrate"`
|
||||
}
|
||||
|
||||
/////////////////////////////
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrWebRTCVideoNotFound = errors.New("webrtc video not found")
|
||||
ErrWebRTCDataChannelNotFound = errors.New("webrtc data channel not found")
|
||||
ErrWebRTCConnectionNotFound = errors.New("webrtc connection not found")
|
||||
)
|
||||
@ -25,7 +24,8 @@ type WebRTCPeer interface {
|
||||
SetAnswer(sdp string) error
|
||||
SetCandidate(candidate webrtc.ICECandidateInit) error
|
||||
|
||||
SetVideoID(videoID string) error
|
||||
SetVideoBitrate(bitrate int) error
|
||||
GetVideoId() string
|
||||
SetPaused(isPaused bool) error
|
||||
|
||||
SendCursorPosition(x, y int) error
|
||||
@ -40,6 +40,6 @@ type WebRTCManager interface {
|
||||
|
||||
ICEServers() []ICEServer
|
||||
|
||||
CreatePeer(session Session, videoID string) (*webrtc.SessionDescription, error)
|
||||
CreatePeer(session Session, bitrate int) (*webrtc.SessionDescription, error)
|
||||
SetCursorPosition(x, y int)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user