Refactor signaling for video and audio (#51)

* add audio and signal request.

* disable audio by default.

* fix SignalProvide.

* disable estimator when track disabled.
This commit is contained in:
Miroslav Šedivý 2023-06-26 21:27:14 +02:00 committed by GitHub
parent cf17f4f503
commit e3e9d1606d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 253 additions and 106 deletions

View File

@ -312,6 +312,9 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session) (*webrtc.Sess
return nil, nil, err return nil, nil, err
} }
// we disable audio by default manually
audioTrack.SetPaused(true)
// set stream for audio track // set stream for audio track
_, err = audioTrack.SetStream(audio) _, err = audioTrack.SetStream(audio)
if err != nil { if err != nil {
@ -355,7 +358,8 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session) (*webrtc.Sess
CollapseValues: true, CollapseValues: true,
}), }),
// stream selectors // stream selectors
videoSelector: manager.capture.Video(), video: video,
audio: audio,
// tracks & channels // tracks & channels
audioTrack: audioTrack, audioTrack: audioTrack,
videoTrack: videoTrack, videoTrack: videoTrack,
@ -364,6 +368,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session) (*webrtc.Sess
// config // config
iceTrickle: manager.config.ICETrickle, iceTrickle: manager.config.ICETrickle,
estimatorConfig: manager.config.Estimator, estimatorConfig: manager.config.Estimator,
audioDisabled: true, // we disable audio by default manually
} }
connection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { connection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {

View File

@ -15,7 +15,6 @@ import (
"github.com/demodesk/neko/internal/webrtc/payload" "github.com/demodesk/neko/internal/webrtc/payload"
"github.com/demodesk/neko/pkg/types" "github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/event" "github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message"
"github.com/demodesk/neko/pkg/utils" "github.com/demodesk/neko/pkg/utils"
) )
@ -29,7 +28,8 @@ type WebRTCPeerCtx struct {
estimator cc.BandwidthEstimator estimator cc.BandwidthEstimator
estimateTrend *utils.TrendDetector estimateTrend *utils.TrendDetector
// stream selectors // stream selectors
videoSelector types.StreamSelectorManager video types.StreamSelectorManager
audio types.StreamSinkManager
// tracks & channels // tracks & channels
audioTrack *Track audioTrack *Track
videoTrack *Track videoTrack *Track
@ -38,7 +38,10 @@ type WebRTCPeerCtx struct {
// config // config
iceTrickle bool iceTrickle bool
estimatorConfig config.WebRTCEstimator estimatorConfig config.WebRTCEstimator
paused bool
videoAuto bool videoAuto bool
videoDisabled bool
audioDisabled bool
} }
// //
@ -158,8 +161,8 @@ func (peer *WebRTCPeerCtx) estimatorReader() {
break break
} }
// if estimation is disabled, do nothing // if estimation or video is disabled, do nothing
if !peer.videoAuto || conf.Passive { if !peer.videoAuto || peer.videoDisabled || peer.paused || conf.Passive {
continue continue
} }
@ -236,9 +239,11 @@ func (peer *WebRTCPeerCtx) estimatorReader() {
continue continue
} }
err := peer.SetVideo(types.StreamSelector{ err := peer.SetVideo(types.PeerVideoRequest{
Selector: &types.StreamSelector{
ID: streamId, ID: streamId,
Type: types.StreamSelectorTypeLower, Type: types.StreamSelectorTypeLower,
},
}) })
if err != nil && err != types.ErrWebRTCStreamNotFound { if err != nil && err != types.ErrWebRTCStreamNotFound {
peer.logger.Warn().Err(err).Msg("failed to downgrade video stream") peer.logger.Warn().Err(err).Msg("failed to downgrade video stream")
@ -287,9 +292,11 @@ func (peer *WebRTCPeerCtx) estimatorReader() {
continue continue
} }
err := peer.SetVideo(types.StreamSelector{ err := peer.SetVideo(types.PeerVideoRequest{
Selector: &types.StreamSelector{
ID: streamId, ID: streamId,
Type: types.StreamSelectorTypeHigher, Type: types.StreamSelectorTypeHigher,
},
}) })
if err != nil && err != types.ErrWebRTCStreamNotFound { if err != nil && err != types.ErrWebRTCStreamNotFound {
peer.logger.Warn().Err(err).Msg("failed to upgrade video stream") peer.logger.Warn().Err(err).Msg("failed to upgrade video stream")
@ -304,16 +311,56 @@ func (peer *WebRTCPeerCtx) estimatorReader() {
} }
} }
// SetPaused pauses or resumes the peer as a whole and remembers that state.
// A track that has been individually disabled stays paused even when the
// peer itself is resumed. It always returns nil.
func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error {
	peer.mu.Lock()
	defer peer.mu.Unlock()

	// a track is effectively paused when either the whole peer is paused
	// or that particular track has been disabled
	pauseVideo := peer.videoDisabled || isPaused
	pauseAudio := peer.audioDisabled || isPaused

	peer.videoTrack.SetPaused(pauseVideo)
	peer.audioTrack.SetPaused(pauseAudio)

	peer.logger.Info().Bool("is_paused", isPaused).Msg("set paused")
	peer.paused = isPaused

	return nil
}
// Paused reports the peer-level paused flag as last stored by SetPaused.
// It does not reflect whether individual tracks have been disabled.
func (peer *WebRTCPeerCtx) Paused() bool {
	peer.mu.Lock()
	isPaused := peer.paused
	peer.mu.Unlock()

	return isPaused
}
// //
// video // video
// //
func (peer *WebRTCPeerCtx) SetVideo(selector types.StreamSelector) error { func (peer *WebRTCPeerCtx) SetVideo(r types.PeerVideoRequest) error {
peer.mu.Lock() peer.mu.Lock()
defer peer.mu.Unlock() defer peer.mu.Unlock()
modified := false
// video disabled
if r.Disabled != nil {
disabled := *r.Disabled
// update only if changed
if peer.videoDisabled != disabled {
peer.videoDisabled = disabled
peer.videoTrack.SetPaused(disabled || peer.paused)
peer.logger.Info().Bool("disabled", disabled).Msg("set video disabled")
modified = true
}
}
// video selector
if r.Selector != nil {
selector := *r.Selector
// get requested video stream from selector // get requested video stream from selector
stream, ok := peer.videoSelector.GetStream(selector) stream, ok := peer.video.GetStream(selector)
if !ok { if !ok {
return types.ErrWebRTCStreamNotFound return types.ErrWebRTCStreamNotFound
} }
@ -324,74 +371,106 @@ func (peer *WebRTCPeerCtx) SetVideo(selector types.StreamSelector) error {
return err return err
} }
// if video stream was already set, do nothing // update only if stream changed
if !changed { if changed {
return nil
}
videoID := stream.ID() videoID := stream.ID()
peer.metrics.SetVideoID(videoID) peer.metrics.SetVideoID(videoID)
peer.logger.Info().Str("video_id", videoID).Msg("set video") peer.logger.Info().Str("video_id", videoID).Msg("set video")
modified = true
go peer.session.Send( }
event.SIGNAL_VIDEO,
message.SignalVideo{
Video: videoID,
Auto: peer.videoAuto,
})
return nil
} }
func (peer *WebRTCPeerCtx) VideoID() (string, bool) { // video auto
peer.mu.Lock() if r.Auto != nil {
defer peer.mu.Unlock() videoAuto := *r.Auto
stream, ok := peer.videoTrack.Stream() if peer.estimator == nil || peer.estimatorConfig.Passive {
if !ok {
return "", false
}
return stream.ID(), true
}
func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error {
peer.mu.Lock()
defer peer.mu.Unlock()
peer.logger.Info().Bool("is_paused", isPaused).Msg("set paused")
peer.videoTrack.SetPaused(isPaused)
peer.audioTrack.SetPaused(isPaused)
return nil
}
func (peer *WebRTCPeerCtx) Paused() bool {
peer.mu.Lock()
defer peer.mu.Unlock()
return peer.videoTrack.Paused() || peer.audioTrack.Paused()
}
func (peer *WebRTCPeerCtx) SetVideoAuto(videoAuto bool) {
peer.mu.Lock()
defer peer.mu.Unlock()
// if estimator is enabled and is not passive, enable video auto bitrate
if peer.estimator != nil && !peer.estimatorConfig.Passive {
peer.logger.Info().Bool("video_auto", videoAuto).Msg("set video auto")
peer.videoAuto = videoAuto
} else {
peer.logger.Warn().Msg("estimator is disabled or in passive mode, cannot change video auto") peer.logger.Warn().Msg("estimator is disabled or in passive mode, cannot change video auto")
peer.videoAuto = false // ensure video auto is disabled videoAuto = false // ensure video auto is disabled
}
// update only if video auto changed
if peer.videoAuto != videoAuto {
peer.videoAuto = videoAuto
peer.logger.Info().Bool("video_auto", videoAuto).Msg("set video auto")
modified = true
} }
} }
func (peer *WebRTCPeerCtx) VideoAuto() bool { // send video signal if modified
if modified {
go func() {
// in goroutine because of mutex and we don't want to block
peer.session.Send(event.SIGNAL_VIDEO, peer.Video())
}()
}
return nil
}
func (peer *WebRTCPeerCtx) Video() types.PeerVideo {
peer.mu.Lock() peer.mu.Lock()
defer peer.mu.Unlock() defer peer.mu.Unlock()
return peer.videoAuto // get current video stream ID
ID := ""
stream, ok := peer.videoTrack.Stream()
if ok {
ID = stream.ID()
}
return types.PeerVideo{
Disabled: peer.videoDisabled,
ID: ID,
Video: ID, // TODO: Remove, used for backward compatibility
Auto: peer.videoAuto,
}
}
//
// audio
//
// SetAudio applies an audio request to the peer. A nil Disabled field means
// "leave unchanged". When the disabled state actually changes, the audio
// track is paused or resumed accordingly and the new audio state is sent to
// the session as a SIGNAL_AUDIO event. It always returns nil.
func (peer *WebRTCPeerCtx) SetAudio(r types.PeerAudioRequest) error {
	peer.mu.Lock()
	defer peer.mu.Unlock()

	// nothing to do when no change was requested or the state already matches
	if r.Disabled == nil || *r.Disabled == peer.audioDisabled {
		return nil
	}

	audioDisabled := *r.Disabled
	peer.audioDisabled = audioDisabled
	// the track also stays paused for as long as the whole peer is paused
	peer.audioTrack.SetPaused(peer.paused || audioDisabled)
	peer.logger.Info().Bool("disabled", audioDisabled).Msg("set audio disabled")

	// send audio signal; this must happen inside the goroutine because
	// Audio() takes the same mutex that is still held here, and we don't
	// want to block the caller on the send either
	go func() {
		peer.session.Send(event.SIGNAL_AUDIO, peer.Audio())
	}()

	return nil
}
func (peer *WebRTCPeerCtx) Audio() types.PeerAudio {
peer.mu.Lock()
defer peer.mu.Unlock()
return types.PeerAudio{
Disabled: peer.audioDisabled,
}
} }
// //

View File

@ -177,6 +177,7 @@ func (t *Track) SetPaused(paused bool) {
// if there is no state change or no stream, do nothing // if there is no state change or no stream, do nothing
if t.paused == paused || t.stream == nil { if t.paused == paused || t.stream == nil {
t.paused = paused
return return
} }

View File

@ -45,7 +45,7 @@ func (h *MessageHandlerCtx) Message(session types.Session, data types.WebSocketM
// Signal Events // Signal Events
case event.SIGNAL_REQUEST: case event.SIGNAL_REQUEST:
payload := &message.SignalVideo{} payload := &message.SignalRequest{}
err = utils.Unmarshal(payload, data.Payload, func() error { err = utils.Unmarshal(payload, data.Payload, func() error {
return h.signalRequest(session, payload) return h.signalRequest(session, payload)
}) })
@ -71,6 +71,11 @@ func (h *MessageHandlerCtx) Message(session types.Session, data types.WebSocketM
err = utils.Unmarshal(payload, data.Payload, func() error { err = utils.Unmarshal(payload, data.Payload, func() error {
return h.signalVideo(session, payload) return h.signalVideo(session, payload)
}) })
case event.SIGNAL_AUDIO:
payload := &message.SignalAudio{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.signalAudio(session, payload)
})
// Control Events // Control Events
case event.CONTROL_RELEASE: case event.CONTROL_RELEASE:

View File

@ -9,17 +9,11 @@ import (
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
) )
func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *message.SignalVideo) error { func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *message.SignalRequest) error {
if !session.Profile().CanWatch { if !session.Profile().CanWatch {
return errors.New("not allowed to watch") return errors.New("not allowed to watch")
} }
// use default first video, if not provided
if payload.Video == "" {
videos := h.capture.Video().IDs()
payload.Video = videos[0]
}
offer, peer, err := h.webrtc.CreatePeer(session) offer, peer, err := h.webrtc.CreatePeer(session)
if err != nil { if err != nil {
return err return err
@ -30,14 +24,38 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
peer.SetPaused(true) peer.SetPaused(true)
} }
// set video auto state video := payload.Video
peer.SetVideoAuto(payload.Auto)
// use default first video, if not provided
if video.Selector == nil {
videos := h.capture.Video().IDs()
video.Selector = &types.StreamSelector{
ID: videos[0],
Type: types.StreamSelectorTypeExact,
}
}
// TODO: Remove, used for compatibility with old clients.
if video.Auto == nil {
video.Auto = &payload.Auto
}
// set video stream // set video stream
err = peer.SetVideo(types.StreamSelector{ err = peer.SetVideo(video)
ID: payload.Video, if err != nil {
Type: types.StreamSelectorTypeNearest, return err
}) }
audio := payload.Audio
// enable by default if not requested otherwise
if audio.Disabled == nil {
disabled := false
audio.Disabled = &disabled
}
// set audio stream
err = peer.SetAudio(audio)
if err != nil { if err != nil {
return err return err
} }
@ -47,6 +65,9 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag
message.SignalProvide{ message.SignalProvide{
SDP: offer.SDP, SDP: offer.SDP,
ICEServers: h.webrtc.ICEServers(), ICEServers: h.webrtc.ICEServers(),
Video: peer.Video(),
Audio: peer.Audio(),
}) })
return nil return nil
@ -128,14 +149,14 @@ func (h *MessageHandlerCtx) signalVideo(session types.Session, payload *message.
return errors.New("webRTC peer does not exist") return errors.New("webRTC peer does not exist")
} }
peer.SetVideoAuto(payload.Auto) return peer.SetVideo(payload.PeerVideoRequest)
if payload.Video != "" {
return peer.SetVideo(types.StreamSelector{
ID: payload.Video,
Type: types.StreamSelectorTypeNearest,
})
} }
return nil func (h *MessageHandlerCtx) signalAudio(session types.Session, payload *message.SignalAudio) error {
peer := session.GetWebRTCPeer()
if peer == nil {
return errors.New("webRTC peer does not exist")
}
return peer.SetAudio(payload.PeerAudioRequest)
} }

View File

@ -94,11 +94,11 @@ func (s StreamSelectorType) MarshalText() ([]byte, error) {
type StreamSelector struct { type StreamSelector struct {
// type of stream selector // type of stream selector
Type StreamSelectorType Type StreamSelectorType `json:"type"`
// select stream by its ID // select stream by its ID
ID string ID string `json:"id"`
// select stream by its bitrate // select stream by its bitrate
Bitrate uint64 Bitrate uint64 `json:"bitrate"`
} }
type StreamSelectorManager interface { type StreamSelectorManager interface {

View File

@ -17,6 +17,7 @@ const (
SIGNAL_PROVIDE = "signal/provide" SIGNAL_PROVIDE = "signal/provide"
SIGNAL_CANDIDATE = "signal/candidate" SIGNAL_CANDIDATE = "signal/candidate"
SIGNAL_VIDEO = "signal/video" SIGNAL_VIDEO = "signal/video"
SIGNAL_AUDIO = "signal/audio"
SIGNAL_CLOSE = "signal/close" SIGNAL_CLOSE = "signal/close"
) )

View File

@ -45,9 +45,19 @@ type SystemDisconnect struct {
// Signal // Signal
///////////////////////////// /////////////////////////////
// SignalRequest is the payload of a signal request message. The handler uses
// it to create the WebRTC peer and to apply the initial video and audio
// settings.
type SignalRequest struct {
// requested video settings
Video types.PeerVideoRequest `json:"video"`
// requested audio settings
Audio types.PeerAudioRequest `json:"audio"`
// legacy top-level auto flag; the request handler falls back to it when
// Video.Auto is not set
Auto bool `json:"auto"` // TODO: Remove this
}
type SignalProvide struct { type SignalProvide struct {
SDP string `json:"sdp"` SDP string `json:"sdp"`
ICEServers []types.ICEServer `json:"iceservers"` ICEServers []types.ICEServer `json:"iceservers"`
Video types.PeerVideo `json:"video"`
Audio types.PeerAudio `json:"audio"`
} }
type SignalCandidate struct { type SignalCandidate struct {
@ -59,8 +69,11 @@ type SignalDescription struct {
} }
type SignalVideo struct { type SignalVideo struct {
Video string `json:"video"` types.PeerVideoRequest
Auto bool `json:"auto"` }
// SignalAudio is the payload of a signal/audio message. It consists solely
// of the embedded audio request.
type SignalAudio struct {
types.PeerAudioRequest
}
///////////////////////////// /////////////////////////////

View File

@ -18,18 +18,40 @@ type ICEServer struct {
Credential string `mapstructure:"credential" json:"credential,omitempty"` Credential string `mapstructure:"credential" json:"credential,omitempty"`
} }
// PeerVideo describes the current video state of a peer, as reported to the
// client.
type PeerVideo struct {
// whether video has been disabled for the peer
Disabled bool `json:"disabled"`
// ID of the currently set video stream; empty when no stream is set
ID string `json:"id"`
// same value as ID
Video string `json:"video"` // TODO: Remove this, used for compatibility with old clients.
// whether automatic video stream switching is enabled
Auto bool `json:"auto"`
}
// PeerVideoRequest is a partial update of a peer's video settings. Every
// field is optional: a nil field leaves the corresponding setting unchanged.
type PeerVideoRequest struct {
// disable or enable video
Disabled *bool `json:"disabled,omitempty"`
// selector used to look up the video stream to switch to
Selector *StreamSelector `json:"selector,omitempty"`
// turn automatic video stream switching on or off
Auto *bool `json:"auto,omitempty"`
}
// PeerAudio describes the current audio state of a peer, as reported to the
// client.
type PeerAudio struct {
// whether audio has been disabled for the peer
Disabled bool `json:"disabled"`
}
// PeerAudioRequest is a partial update of a peer's audio settings. A nil
// field leaves the corresponding setting unchanged.
type PeerAudioRequest struct {
// disable or enable audio
Disabled *bool `json:"disabled,omitempty"`
}
type WebRTCPeer interface { type WebRTCPeer interface {
CreateOffer(ICERestart bool) (*webrtc.SessionDescription, error) CreateOffer(ICERestart bool) (*webrtc.SessionDescription, error)
CreateAnswer() (*webrtc.SessionDescription, error) CreateAnswer() (*webrtc.SessionDescription, error)
SetRemoteDescription(webrtc.SessionDescription) error SetRemoteDescription(webrtc.SessionDescription) error
SetCandidate(webrtc.ICECandidateInit) error SetCandidate(webrtc.ICECandidateInit) error
SetVideo(StreamSelector) error
VideoID() (string, bool)
SetPaused(isPaused bool) error SetPaused(isPaused bool) error
Paused() bool Paused() bool
SetVideoAuto(auto bool)
VideoAuto() bool SetVideo(PeerVideoRequest) error
Video() PeerVideo
SetAudio(PeerAudioRequest) error
Audio() PeerAudio
SendCursorPosition(x, y int) error SendCursorPosition(x, y int) error
SendCursorImage(cur *CursorImage, img []byte) error SendCursorImage(cur *CursorImage, img []byte) error