mirror of
https://github.com/m1k1o/neko.git
synced 2024-07-24 14:40:50 +12:00
WebRTC congestion control (#26)
* Add congestion control * Improve stream matching, add manual stream selection, add metrics * Use a ticker for bitrate estimation and make bandwidth drops switch to lower streams more aggressively * Missing signal response, fix video auto bug * Remove redundant mutex * Bitrate history queue * Get bitrate fn support h264 & float64 --------- Co-authored-by: Aleksandar Sukovic <aleksandar.sukovic@gmail.com>
This commit is contained in:
@ -9,6 +9,8 @@ import (
|
||||
|
||||
"github.com/pion/ice/v2"
|
||||
"github.com/pion/interceptor"
|
||||
"github.com/pion/interceptor/pkg/cc"
|
||||
"github.com/pion/interceptor/pkg/gcc"
|
||||
"github.com/pion/rtcp"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/rs/zerolog"
|
||||
@ -35,6 +37,9 @@ const keepAliveInterval = 2 * time.Second
|
||||
// send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval
|
||||
const rtcpPLIInterval = 3 * time.Second
|
||||
|
||||
// how often we check the bitrate of each client. Default is 250ms
|
||||
const bitrateCheckInterval = 250 * time.Millisecond
|
||||
|
||||
func New(desktop types.DesktopManager, capture types.CaptureManager, config *config.WebRTC) *WebRTCManagerCtx {
|
||||
configuration := webrtc.Configuration{
|
||||
SDPSemantics: webrtc.SDPSemanticsUnifiedPlanWithFallback,
|
||||
@ -153,12 +158,12 @@ func (manager *WebRTCManagerCtx) ICEServers() []types.ICEServer {
|
||||
return manager.config.ICEServers
|
||||
}
|
||||
|
||||
func (manager *WebRTCManagerCtx) newPeerConnection(codecs []codec.RTPCodec, logger zerolog.Logger) (*webrtc.PeerConnection, error) {
|
||||
func (manager *WebRTCManagerCtx) newPeerConnection(bitrate int, codecs []codec.RTPCodec, logger zerolog.Logger) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) {
|
||||
// create media engine
|
||||
engine := &webrtc.MediaEngine{}
|
||||
for _, codec := range codecs {
|
||||
if err := codec.Register(engine); err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@ -205,8 +210,29 @@ func (manager *WebRTCManagerCtx) newPeerConnection(codecs []codec.RTPCodec, logg
|
||||
|
||||
// create interceptor registry
|
||||
registry := &interceptor.Registry{}
|
||||
|
||||
congestionController, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) {
|
||||
if bitrate == 0 {
|
||||
bitrate = 1000000
|
||||
}
|
||||
return gcc.NewSendSideBWE(gcc.SendSideBWEInitialBitrate(bitrate))
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
estimatorChan := make(chan cc.BandwidthEstimator, 1)
|
||||
congestionController.OnNewPeerConnection(func(id string, estimator cc.BandwidthEstimator) {
|
||||
estimatorChan <- estimator
|
||||
})
|
||||
|
||||
registry.Add(congestionController)
|
||||
if err = webrtc.ConfigureTWCCHeaderExtensionSender(engine, registry); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := webrtc.RegisterDefaultInterceptors(engine, registry); err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// create new API
|
||||
@ -217,10 +243,12 @@ func (manager *WebRTCManagerCtx) newPeerConnection(codecs []codec.RTPCodec, logg
|
||||
)
|
||||
|
||||
// create new peer connection
|
||||
return api.NewPeerConnection(manager.webrtcConfiguration)
|
||||
configuration := manager.webrtcConfiguration
|
||||
connection, err := api.NewPeerConnection(configuration)
|
||||
return connection, <-estimatorChan, err
|
||||
}
|
||||
|
||||
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int) (*webrtc.SessionDescription, error) {
|
||||
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int, videoAuto bool) (*webrtc.SessionDescription, error) {
|
||||
id := atomic.AddInt32(&manager.peerId, 1)
|
||||
manager.metrics.NewConnection(session)
|
||||
|
||||
@ -236,7 +264,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int)
|
||||
video := manager.capture.Video()
|
||||
videoCodec := video.Codec()
|
||||
|
||||
connection, err := manager.newPeerConnection([]codec.RTPCodec{
|
||||
connection, estimator, err := manager.newPeerConnection(bitrate, []codec.RTPCodec{
|
||||
audioCodec,
|
||||
videoCodec,
|
||||
}, logger)
|
||||
@ -244,6 +272,10 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if bitrate == 0 {
|
||||
bitrate = estimator.GetTargetBitrate()
|
||||
}
|
||||
|
||||
// asynchronously send local ICE Candidates
|
||||
if manager.config.ICETrickle {
|
||||
connection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
|
||||
@ -268,31 +300,117 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int)
|
||||
}
|
||||
|
||||
// set stream for audio track
|
||||
err = audioTrack.SetStream(audio)
|
||||
_, err = audioTrack.SetStream(audio)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// video track
|
||||
|
||||
videoTrack, err := NewTrack(logger, videoCodec, connection)
|
||||
videoTrack, err := NewTrack(logger, videoCodec, connection, WithVideoAuto(videoAuto))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// let video stream bucket manager handle stream subscriptions
|
||||
err = video.SetReceiver(videoTrack)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
video.SetReceiver(videoTrack)
|
||||
|
||||
changeVideoFromBitrate := func(peerBitrate int) {
|
||||
// when switching from manual to auto bitrate estimation, in case the estimator is
|
||||
// idle (lastBitrate > maxBitrate), we want to go back to the previous estimated bitrate
|
||||
if peerBitrate == 0 {
|
||||
peerBitrate = estimator.GetTargetBitrate()
|
||||
manager.logger.Debug().
|
||||
Int("peer_bitrate", peerBitrate).
|
||||
Msg("evaluated bitrate")
|
||||
}
|
||||
|
||||
ok, err := videoTrack.SetBitrate(peerBitrate)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).
|
||||
Int("peer_bitrate", peerBitrate).
|
||||
Msg("unable to set video bitrate")
|
||||
return
|
||||
}
|
||||
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
videoID := videoTrack.stream.ID()
|
||||
bitrate := videoTrack.stream.Bitrate()
|
||||
|
||||
manager.metrics.SetVideoID(session, videoID)
|
||||
manager.logger.Debug().
|
||||
Int("peer_bitrate", peerBitrate).
|
||||
Int("video_bitrate", bitrate).
|
||||
Str("video_id", videoID).
|
||||
Msg("peer bitrate triggered video stream change")
|
||||
|
||||
go session.Send(
|
||||
event.SIGNAL_VIDEO,
|
||||
message.SignalVideo{
|
||||
Video: videoID,
|
||||
Bitrate: bitrate,
|
||||
VideoAuto: videoTrack.VideoAuto(),
|
||||
})
|
||||
}
|
||||
|
||||
changeVideoFromID := func(videoID string) (bitrate int) {
|
||||
changed, err := videoTrack.SetVideoID(videoID)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).
|
||||
Str("video_id", videoID).
|
||||
Msg("unable to set video stream")
|
||||
return
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return
|
||||
}
|
||||
|
||||
bitrate = videoTrack.stream.Bitrate()
|
||||
|
||||
manager.logger.Debug().
|
||||
Str("video_id", videoID).
|
||||
Int("video_bitrate", bitrate).
|
||||
Msg("peer video id triggered video stream change")
|
||||
|
||||
go session.Send(
|
||||
event.SIGNAL_VIDEO,
|
||||
message.SignalVideo{
|
||||
Video: videoID,
|
||||
Bitrate: bitrate,
|
||||
VideoAuto: videoTrack.VideoAuto(),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
manager.logger.Info().
|
||||
Int("target_bitrate", bitrate).
|
||||
Msg("estimated initial peer bitrate")
|
||||
|
||||
// set initial video bitrate
|
||||
if err = videoTrack.SetBitrate(bitrate); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
changeVideoFromBitrate(bitrate)
|
||||
|
||||
videoID := videoTrack.stream.ID()
|
||||
manager.metrics.SetVideoID(session, videoID)
|
||||
// use a ticker to get current client target bitrate
|
||||
go func() {
|
||||
ticker := time.NewTicker(bitrateCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
targetBitrate := estimator.GetTargetBitrate()
|
||||
manager.metrics.SetReceiverEstimatedMaximumBitrate(session, float64(targetBitrate))
|
||||
|
||||
if connection.ConnectionState() == webrtc.PeerConnectionStateClosed {
|
||||
break
|
||||
}
|
||||
if !videoTrack.VideoAuto() {
|
||||
continue
|
||||
}
|
||||
changeVideoFromBitrate(targetBitrate)
|
||||
}
|
||||
}()
|
||||
|
||||
// data channel
|
||||
|
||||
@ -302,27 +420,20 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int)
|
||||
}
|
||||
|
||||
peer := &WebRTCPeerCtx{
|
||||
logger: logger,
|
||||
connection: connection,
|
||||
dataChannel: dataChannel,
|
||||
changeVideo: func(bitrate int) error {
|
||||
if err := videoTrack.SetBitrate(bitrate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
videoID := videoTrack.stream.ID()
|
||||
manager.metrics.SetVideoID(session, videoID)
|
||||
return nil
|
||||
},
|
||||
logger: logger,
|
||||
connection: connection,
|
||||
dataChannel: dataChannel,
|
||||
changeVideoFromBitrate: changeVideoFromBitrate,
|
||||
changeVideoFromID: changeVideoFromID,
|
||||
// TODO: Refactor.
|
||||
videoId: func() string {
|
||||
return videoTrack.stream.ID()
|
||||
},
|
||||
videoId: videoTrack.stream.ID,
|
||||
setPaused: func(isPaused bool) {
|
||||
videoTrack.SetPaused(isPaused)
|
||||
audioTrack.SetPaused(isPaused)
|
||||
},
|
||||
iceTrickle: manager.config.ICETrickle,
|
||||
iceTrickle: manager.config.ICETrickle,
|
||||
setVideoAuto: videoTrack.SetVideoAuto,
|
||||
getVideoAuto: videoTrack.VideoAuto,
|
||||
}
|
||||
|
||||
connection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
|
||||
@ -515,11 +626,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int)
|
||||
})
|
||||
|
||||
videoTrack.OnRTCP(func(p rtcp.Packet) {
|
||||
switch rtcpPacket := p.(type) {
|
||||
case *rtcp.ReceiverEstimatedMaximumBitrate: // TODO: Deprecated.
|
||||
manager.metrics.SetReceiverEstimatedMaximumBitrate(session, rtcpPacket.Bitrate)
|
||||
|
||||
case *rtcp.ReceiverReport:
|
||||
if rtcpPacket, ok := p.(*rtcp.ReceiverReport); ok {
|
||||
l := len(rtcpPacket.Reports)
|
||||
if l > 0 {
|
||||
// use only last report
|
||||
|
@ -327,10 +327,10 @@ func (m *metricsCtx) SetVideoID(session types.Session, videoId string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *metricsCtx) SetReceiverEstimatedMaximumBitrate(session types.Session, bitrate float32) {
|
||||
func (m *metricsCtx) SetReceiverEstimatedMaximumBitrate(session types.Session, bitrate float64) {
|
||||
met := m.getBySession(session)
|
||||
|
||||
met.receiverEstimatedMaximumBitrate.Set(float64(bitrate))
|
||||
met.receiverEstimatedMaximumBitrate.Set(bitrate)
|
||||
}
|
||||
|
||||
func (m *metricsCtx) SetReceiverReport(session types.Session, report rtcp.ReceptionReport) {
|
||||
|
@ -13,14 +13,17 @@ import (
|
||||
)
|
||||
|
||||
type WebRTCPeerCtx struct {
|
||||
mu sync.Mutex
|
||||
logger zerolog.Logger
|
||||
connection *webrtc.PeerConnection
|
||||
dataChannel *webrtc.DataChannel
|
||||
changeVideo func(bitrate int) error
|
||||
videoId func() string
|
||||
setPaused func(isPaused bool)
|
||||
iceTrickle bool
|
||||
mu sync.Mutex
|
||||
logger zerolog.Logger
|
||||
connection *webrtc.PeerConnection
|
||||
dataChannel *webrtc.DataChannel
|
||||
changeVideoFromBitrate func(bitrate int)
|
||||
changeVideoFromID func(id string) int
|
||||
videoId func() string
|
||||
setPaused func(isPaused bool)
|
||||
setVideoAuto func(auto bool)
|
||||
getVideoAuto func() bool
|
||||
iceTrickle bool
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) CreateOffer(ICERestart bool) (*webrtc.SessionDescription, error) {
|
||||
@ -115,7 +118,7 @@ func (peer *WebRTCPeerCtx) SetCandidate(candidate webrtc.ICECandidateInit) error
|
||||
return peer.connection.AddICECandidate(candidate)
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) SetVideoBitrate(bitrate int) error {
|
||||
func (peer *WebRTCPeerCtx) SetVideoBitrate(peerBitrate int) error {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
@ -123,12 +126,24 @@ func (peer *WebRTCPeerCtx) SetVideoBitrate(bitrate int) error {
|
||||
return types.ErrWebRTCConnectionNotFound
|
||||
}
|
||||
|
||||
peer.logger.Info().Int("bitrate", bitrate).Msg("change video bitrate")
|
||||
return peer.changeVideo(bitrate)
|
||||
peer.changeVideoFromBitrate(peerBitrate)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) SetVideoID(videoID string) error {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
if peer.connection == nil {
|
||||
return types.ErrWebRTCConnectionNotFound
|
||||
}
|
||||
|
||||
peer.changeVideoFromID(videoID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: Refactor.
|
||||
func (peer *WebRTCPeerCtx) GetVideoId() string {
|
||||
func (peer *WebRTCPeerCtx) GetVideoID() string {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
@ -215,3 +230,11 @@ func (peer *WebRTCPeerCtx) Destroy() {
|
||||
peer.connection = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) SetVideoAuto(auto bool) {
|
||||
peer.setVideoAuto(auto)
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) VideoAuto() bool {
|
||||
return peer.getVideoAuto()
|
||||
}
|
||||
|
@ -16,10 +16,12 @@ import (
|
||||
)
|
||||
|
||||
type Track struct {
|
||||
logger zerolog.Logger
|
||||
track *webrtc.TrackLocalStaticSample
|
||||
paused bool
|
||||
listener func(sample types.Sample)
|
||||
logger zerolog.Logger
|
||||
track *webrtc.TrackLocalStaticSample
|
||||
paused bool
|
||||
videoAuto bool
|
||||
videoAutoMu sync.RWMutex
|
||||
listener func(sample types.Sample)
|
||||
|
||||
stream types.StreamSinkManager
|
||||
streamMu sync.Mutex
|
||||
@ -27,10 +29,19 @@ type Track struct {
|
||||
onRtcp func(rtcp.Packet)
|
||||
onRtcpMu sync.RWMutex
|
||||
|
||||
bitrateChange func(int) error
|
||||
bitrateChange func(int) (bool, error)
|
||||
videoChange func(string) (bool, error)
|
||||
}
|
||||
|
||||
func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.PeerConnection) (*Track, error) {
|
||||
type option func(*Track)
|
||||
|
||||
func WithVideoAuto(auto bool) option {
|
||||
return func(t *Track) {
|
||||
t.videoAuto = auto
|
||||
}
|
||||
}
|
||||
|
||||
func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.PeerConnection, opts ...option) (*Track, error) {
|
||||
id := codec.Type.String()
|
||||
track, err := webrtc.NewTrackLocalStaticSample(codec.Capability, id, "stream")
|
||||
if err != nil {
|
||||
@ -44,6 +55,10 @@ func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.Pe
|
||||
track: track,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(t)
|
||||
}
|
||||
|
||||
t.listener = func(sample types.Sample) {
|
||||
if t.paused {
|
||||
return
|
||||
@ -96,13 +111,13 @@ func (t *Track) rtcpReader(sender *webrtc.RTPSender) {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Track) SetStream(stream types.StreamSinkManager) error {
|
||||
func (t *Track) SetStream(stream types.StreamSinkManager) (bool, error) {
|
||||
t.streamMu.Lock()
|
||||
defer t.streamMu.Unlock()
|
||||
|
||||
// if we already listen to the stream, do nothing
|
||||
if t.stream == stream {
|
||||
return nil
|
||||
return false, nil
|
||||
}
|
||||
|
||||
var err error
|
||||
@ -111,12 +126,13 @@ func (t *Track) SetStream(stream types.StreamSinkManager) error {
|
||||
} else {
|
||||
err = stream.AddListener(&t.listener)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
t.stream = stream
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return err
|
||||
t.stream = stream
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (t *Track) RemoveStream() {
|
||||
@ -140,14 +156,38 @@ func (t *Track) OnRTCP(f func(rtcp.Packet)) {
|
||||
t.onRtcp = f
|
||||
}
|
||||
|
||||
func (t *Track) SetBitrate(bitrate int) error {
|
||||
func (t *Track) SetBitrate(bitrate int) (bool, error) {
|
||||
if t.bitrateChange == nil {
|
||||
return fmt.Errorf("bitrate change not supported")
|
||||
return false, fmt.Errorf("bitrate change not supported")
|
||||
}
|
||||
|
||||
return t.bitrateChange(bitrate)
|
||||
}
|
||||
|
||||
func (t *Track) OnBitrateChange(f func(int) error) {
|
||||
func (t *Track) SetVideoID(videoID string) (bool, error) {
|
||||
if t.videoChange == nil {
|
||||
return false, fmt.Errorf("video change not supported")
|
||||
}
|
||||
|
||||
return t.videoChange(videoID)
|
||||
}
|
||||
|
||||
func (t *Track) OnBitrateChange(f func(bitrate int) (bool, error)) {
|
||||
t.bitrateChange = f
|
||||
}
|
||||
|
||||
func (t *Track) OnVideoChange(f func(string) (bool, error)) {
|
||||
t.videoChange = f
|
||||
}
|
||||
|
||||
func (t *Track) SetVideoAuto(auto bool) {
|
||||
t.videoAutoMu.Lock()
|
||||
defer t.videoAutoMu.Unlock()
|
||||
t.videoAuto = auto
|
||||
}
|
||||
|
||||
func (t *Track) VideoAuto() bool {
|
||||
t.videoAutoMu.RLock()
|
||||
defer t.videoAutoMu.RUnlock()
|
||||
return t.videoAuto
|
||||
}
|
||||
|
Reference in New Issue
Block a user