webrtc refactor peer track.

This commit is contained in:
Miroslav Šedivý 2021-09-27 00:50:49 +02:00
parent beac1cb088
commit 9d4d5766ef
4 changed files with 219 additions and 209 deletions

View File

@ -1,6 +1,7 @@
package capture package capture
import ( import (
"errors"
"reflect" "reflect"
"sync" "sync"
@ -13,18 +14,22 @@ import (
) )
type StreamManagerCtx struct { type StreamManagerCtx struct {
logger zerolog.Logger logger zerolog.Logger
mu sync.Mutex mu sync.Mutex
wg sync.WaitGroup wg sync.WaitGroup
codec codec.RTPCodec codec codec.RTPCodec
pipelineStr func() string
pipeline *gst.Pipeline pipeline *gst.Pipeline
sample chan types.Sample pipelineMu sync.Mutex
listeners map[uintptr]*func(sample types.Sample) pipelineStr func() string
emitMu sync.Mutex
emitUpdate chan bool sample chan types.Sample
emitStop chan bool sampleStop chan interface{}
started bool sampleUpdate chan interface{}
listeners map[uintptr]*func(sample types.Sample)
listenersMu sync.Mutex
listenersCount uint32
} }
func streamNew(codec codec.RTPCodec, pipelineStr func() string, video_id string) *StreamManagerCtx { func streamNew(codec codec.RTPCodec, pipelineStr func() string, video_id string) *StreamManagerCtx {
@ -34,13 +39,12 @@ func streamNew(codec codec.RTPCodec, pipelineStr func() string, video_id string)
Str("video_id", video_id).Logger() Str("video_id", video_id).Logger()
manager := &StreamManagerCtx{ manager := &StreamManagerCtx{
logger: logger, logger: logger,
codec: codec, codec: codec,
pipelineStr: pipelineStr, pipelineStr: pipelineStr,
listeners: map[uintptr]*func(sample types.Sample){}, sampleStop: make(chan interface{}),
emitUpdate: make(chan bool), sampleUpdate: make(chan interface{}),
emitStop: make(chan bool), listeners: map[uintptr]*func(sample types.Sample){},
started: false,
} }
manager.wg.Add(1) manager.wg.Add(1)
@ -51,17 +55,17 @@ func streamNew(codec codec.RTPCodec, pipelineStr func() string, video_id string)
for { for {
select { select {
case <-manager.emitStop: case <-manager.sampleStop:
manager.logger.Debug().Msg("stopped emitting samples") manager.logger.Debug().Msg("stopped emitting samples")
return return
case <-manager.emitUpdate: case <-manager.sampleUpdate:
manager.logger.Debug().Msg("update emitting samples") manager.logger.Debug().Msg("update emitting samples")
case sample := <-manager.sample: case sample := <-manager.sample:
manager.emitMu.Lock() manager.listenersMu.Lock()
for _, emit := range manager.listeners { for _, emit := range manager.listeners {
(*emit)(sample) (*emit)(sample)
} }
manager.emitMu.Unlock() manager.listenersMu.Unlock()
} }
} }
}() }()
@ -72,15 +76,15 @@ func streamNew(codec codec.RTPCodec, pipelineStr func() string, video_id string)
func (manager *StreamManagerCtx) shutdown() { func (manager *StreamManagerCtx) shutdown() {
manager.logger.Info().Msgf("shutdown") manager.logger.Info().Msgf("shutdown")
manager.emitMu.Lock() manager.listenersMu.Lock()
for key := range manager.listeners { for key := range manager.listeners {
delete(manager.listeners, key) delete(manager.listeners, key)
} }
manager.emitMu.Unlock() manager.listenersMu.Unlock()
manager.destroyPipeline() manager.destroyPipeline()
manager.emitStop <- true close(manager.sampleStop)
manager.wg.Wait() manager.wg.Wait()
} }
@ -88,63 +92,78 @@ func (manager *StreamManagerCtx) Codec() codec.RTPCodec {
return manager.codec return manager.codec
} }
func (manager *StreamManagerCtx) AddListener(listener *func(sample types.Sample)) { func (manager *StreamManagerCtx) NewListener(listener *func(sample types.Sample)) (addListener func(), err error) {
manager.emitMu.Lock() if listener == nil {
defer manager.emitMu.Unlock() return addListener, errors.New("listener cannot be nil")
if listener != nil {
ptr := reflect.ValueOf(listener).Pointer()
manager.listeners[ptr] = listener
manager.logger.Debug().Interface("ptr", ptr).Msgf("adding listener")
} }
manager.mu.Lock()
defer manager.mu.Unlock()
if manager.listenersCount == 0 {
err := manager.createPipeline()
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
return addListener, err
}
manager.listenersCount++
manager.logger.Info().Msgf("first listener, starting")
}
return func() {
ptr := reflect.ValueOf(listener).Pointer()
manager.listenersMu.Lock()
manager.listeners[ptr] = listener
manager.listenersMu.Unlock()
manager.logger.Debug().Interface("ptr", ptr).Msgf("adding listener")
}, nil
} }
func (manager *StreamManagerCtx) RemoveListener(listener *func(sample types.Sample)) { func (manager *StreamManagerCtx) RemoveListener(listener *func(sample types.Sample)) {
manager.emitMu.Lock() if listener == nil {
defer manager.emitMu.Unlock() return
if listener != nil {
ptr := reflect.ValueOf(listener).Pointer()
delete(manager.listeners, ptr)
manager.logger.Debug().Interface("ptr", ptr).Msgf("removing listener")
} }
ptr := reflect.ValueOf(listener).Pointer()
manager.listenersMu.Lock()
delete(manager.listeners, ptr)
manager.listenersMu.Unlock()
manager.logger.Debug().Interface("ptr", ptr).Msgf("removing listener")
go func() {
manager.mu.Lock()
defer manager.mu.Unlock()
if manager.listenersCount == 1 {
manager.destroyPipeline()
manager.listenersCount = 0
manager.logger.Info().Msgf("last listener, stopping")
}
}()
} }
func (manager *StreamManagerCtx) ListenersCount() int { func (manager *StreamManagerCtx) ListenersCount() int {
manager.emitMu.Lock() manager.listenersMu.Lock()
defer manager.emitMu.Unlock() defer manager.listenersMu.Unlock()
return len(manager.listeners) return len(manager.listeners)
} }
func (manager *StreamManagerCtx) Start() error {
manager.mu.Lock()
defer manager.mu.Unlock()
err := manager.createPipeline()
if err != nil {
return err
}
manager.logger.Info().Msgf("start")
manager.started = true
return nil
}
func (manager *StreamManagerCtx) Stop() {
manager.mu.Lock()
defer manager.mu.Unlock()
manager.logger.Info().Msgf("stop")
manager.started = false
manager.destroyPipeline()
}
func (manager *StreamManagerCtx) Started() bool { func (manager *StreamManagerCtx) Started() bool {
return manager.started manager.mu.Lock()
defer manager.mu.Unlock()
return manager.listenersCount > 0
} }
func (manager *StreamManagerCtx) createPipeline() error { func (manager *StreamManagerCtx) createPipeline() error {
manager.pipelineMu.Lock()
defer manager.pipelineMu.Unlock()
if manager.pipeline != nil { if manager.pipeline != nil {
return types.ErrCapturePipelineAlreadyExists return types.ErrCapturePipelineAlreadyExists
} }
@ -166,11 +185,14 @@ func (manager *StreamManagerCtx) createPipeline() error {
manager.pipeline.Start() manager.pipeline.Start()
manager.sample = manager.pipeline.Sample manager.sample = manager.pipeline.Sample
manager.emitUpdate <- true manager.sampleUpdate <- struct{}{}
return nil return nil
} }
func (manager *StreamManagerCtx) destroyPipeline() { func (manager *StreamManagerCtx) destroyPipeline() {
manager.pipelineMu.Lock()
defer manager.pipelineMu.Unlock()
if manager.pipeline == nil { if manager.pipeline == nil {
return return
} }

View File

@ -35,12 +35,12 @@ type ScreencastManager interface {
type StreamManager interface { type StreamManager interface {
Codec() codec.RTPCodec Codec() codec.RTPCodec
AddListener(listener *func(sample Sample)) // starts pipeline if was not running before and returns register function
NewListener(listener *func(sample Sample)) (addListener func(), err error)
// stops pipeline if it was last listener
RemoveListener(listener *func(sample Sample)) RemoveListener(listener *func(sample Sample))
ListenersCount() int
Start() error ListenersCount() int
Stop()
Started() bool Started() bool
} }

View File

@ -1,15 +1,11 @@
package webrtc package webrtc
import ( import (
"errors"
"fmt" "fmt"
"io"
"strings" "strings"
"sync"
"time" "time"
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
"github.com/pion/webrtc/v3/pkg/media"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -39,13 +35,10 @@ func New(desktop types.DesktopManager, capture types.CaptureManager, config *con
capture: capture, capture: capture,
curImage: cursor.NewImage(desktop), curImage: cursor.NewImage(desktop),
curPosition: cursor.NewPosition(desktop), curPosition: cursor.NewPosition(desktop),
participants: 0,
} }
} }
type WebRTCManagerCtx struct { type WebRTCManagerCtx struct {
mu sync.Mutex
logger zerolog.Logger logger zerolog.Logger
config *config.WebRTC config *config.WebRTC
@ -53,33 +46,9 @@ type WebRTCManagerCtx struct {
capture types.CaptureManager capture types.CaptureManager
curImage *cursor.ImageCtx curImage *cursor.ImageCtx
curPosition *cursor.PositionCtx curPosition *cursor.PositionCtx
audioTrack *webrtc.TrackLocalStaticSample
audioListener func(sample types.Sample)
participants uint32
} }
func (manager *WebRTCManagerCtx) Start() { func (manager *WebRTCManagerCtx) Start() {
var err error
// create audio track
audio := manager.capture.Audio()
manager.audioTrack, err = webrtc.NewTrackLocalStaticSample(audio.Codec().Capability, "audio", "stream")
if err != nil {
manager.logger.Panic().Err(err).Msg("unable to create audio track")
}
manager.audioListener = func(sample types.Sample) {
if err := manager.audioTrack.WriteSample(media.Sample(sample)); err != nil {
if errors.Is(err, io.ErrClosedPipe) {
// The peerConnection has been closed.
return
}
manager.logger.Warn().Err(err).Msg("audio pipeline failed to write")
}
}
audio.AddListener(&manager.audioListener)
manager.curImage.Start() manager.curImage.Start()
manager.logger.Info(). manager.logger.Info().
@ -97,9 +66,6 @@ func (manager *WebRTCManagerCtx) Shutdown() error {
manager.curImage.Shutdown() manager.curImage.Shutdown()
manager.curPosition.Shutdown() manager.curPosition.Shutdown()
audio := manager.capture.Audio()
audio.RemoveListener(&manager.audioListener)
return nil return nil
} }
@ -112,6 +78,9 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
logger := manager.logger.With().Str("session_id", session.ID()).Logger() logger := manager.logger.With().Str("session_id", session.ID()).Logger()
logger.Info().Msg("creating webrtc peer") logger.Info().Msg("creating webrtc peer")
// all audios must have the same codec
audioStream := manager.capture.Audio()
// all videos must have the same codec // all videos must have the same codec
videoStream, ok := manager.capture.Video(videoID) videoStream, ok := manager.capture.Video(videoID)
if !ok { if !ok {
@ -119,8 +88,8 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
} }
connection, err := manager.newPeerConnection([]codec.RTPCodec{ connection, err := manager.newPeerConnection([]codec.RTPCodec{
audioStream.Codec(),
videoStream.Codec(), videoStream.Codec(),
manager.capture.Audio().Codec(),
}, logger) }, logger)
if err != nil { if err != nil {
return nil, err return nil, err
@ -142,79 +111,32 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
}) })
} }
// create video track // audio track
videoTrack, err := webrtc.NewTrackLocalStaticSample(videoStream.Codec().Capability, "video", "stream")
audioTrack, err := manager.newPeerTrack(audioStream, logger)
if err != nil { if err != nil {
return nil, err return nil, err
} }
videoListener := func(sample types.Sample) { audioTrack.AddToConnection(connection)
if err := videoTrack.WriteSample(media.Sample(sample)); err != nil {
if errors.Is(err, io.ErrClosedPipe) {
// The peerConnection has been closed.
return
}
logger.Warn().Err(err).Msg("video pipeline failed to write")
}
}
manager.mu.Lock()
// should be stream started
if videoStream.ListenersCount() == 0 {
if err := videoStream.Start(); err != nil {
return nil, err
}
}
videoStream.AddListener(&videoListener)
// start audio, when first participant connects
if !manager.capture.Audio().Started() {
if err := manager.capture.Audio().Start(); err != nil {
manager.logger.Panic().Err(err).Msg("unable to start audio stream")
}
}
manager.participants = manager.participants + 1
manager.mu.Unlock()
changeVideo := func(videoID string) error {
newVideoStream, ok := manager.capture.Video(videoID)
if !ok {
return types.ErrWebRTCVideoNotFound
}
// should be new stream started
if newVideoStream.ListenersCount() == 0 {
if err := newVideoStream.Start(); err != nil {
return err
}
}
// switch videoListeners
videoStream.RemoveListener(&videoListener)
newVideoStream.AddListener(&videoListener)
// should be old stream stopped
if videoStream.ListenersCount() == 0 {
videoStream.Stop()
}
videoStream = newVideoStream
return nil
}
rtpAudio, err := connection.AddTrack(manager.audioTrack)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rtpVideo, err := connection.AddTrack(videoTrack) // video track
videoTrack, err := manager.newPeerTrack(videoStream, logger)
if err != nil { if err != nil {
return nil, err return nil, err
} }
videoTrack.AddToConnection(connection)
if err != nil {
return nil, err
}
// data channel
dataChannel, err := connection.CreateDataChannel("data", nil) dataChannel, err := connection.CreateDataChannel("data", nil)
if err != nil { if err != nil {
return nil, err return nil, err
@ -224,8 +146,15 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
logger: logger, logger: logger,
connection: connection, connection: connection,
dataChannel: dataChannel, dataChannel: dataChannel,
changeVideo: changeVideo, changeVideo: func(videoID string) error {
iceTrickle: manager.config.ICETrickle, videoStream, ok := manager.capture.Video(videoID)
if !ok {
return types.ErrWebRTCVideoNotFound
}
return videoTrack.SetStream(videoStream)
},
iceTrickle: manager.config.ICETrickle,
} }
cursorImage := func(entry *cursor.ImageEntry) { cursorImage := func(entry *cursor.ImageEntry) {
@ -252,29 +181,9 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
webrtc.PeerConnectionStateFailed: webrtc.PeerConnectionStateFailed:
connection.Close() connection.Close()
case webrtc.PeerConnectionStateClosed: case webrtc.PeerConnectionStateClosed:
manager.mu.Lock()
session.SetWebRTCConnected(peer, false) session.SetWebRTCConnected(peer, false)
videoStream.RemoveListener(&videoListener) videoTrack.RemoveStream()
audioTrack.RemoveStream()
// should be stream stopped
if videoStream.ListenersCount() == 0 {
videoStream.Stop()
}
// decrease participants
manager.participants = manager.participants - 1
// stop audio, if last participant disonnects
if manager.participants <= 0 {
manager.participants = 0
if manager.capture.Audio().Started() {
manager.capture.Audio().Stop()
}
}
manager.mu.Unlock()
} }
}) })
@ -310,24 +219,6 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
} }
}) })
go func() {
rtcpBuf := make([]byte, 1500)
for {
if _, _, err := rtpAudio.Read(rtcpBuf); err != nil {
return
}
}
}()
go func() {
rtcpBuf := make([]byte, 1500)
for {
if _, _, err := rtpVideo.Read(rtcpBuf); err != nil {
return
}
}
}()
session.SetWebRTCPeer(peer) session.SetWebRTCPeer(peer)
return peer.CreateOffer(false) return peer.CreateOffer(false)
} }

View File

@ -0,0 +1,97 @@
package webrtc
import (
"demodesk/neko/internal/types"
"errors"
"io"
"sync"
"github.com/pion/webrtc/v3"
"github.com/pion/webrtc/v3/pkg/media"
"github.com/rs/zerolog"
)
func (manager *WebRTCManagerCtx) newPeerTrack(stream types.StreamManager, logger zerolog.Logger) (*PeerTrack, error) {
codec := stream.Codec()
id := codec.Type.String()
track, err := webrtc.NewTrackLocalStaticSample(codec.Capability, id, "stream")
if err != nil {
return nil, err
}
logger = logger.With().Str("id", id).Logger()
peer := &PeerTrack{
logger: logger,
track: track,
listener: func(sample types.Sample) {
err := track.WriteSample(media.Sample(sample))
if err != nil && errors.Is(err, io.ErrClosedPipe) {
logger.Warn().Err(err).Msg("pipeline failed to write")
}
},
}
peer.SetStream(stream)
return peer, nil
}
type PeerTrack struct {
logger zerolog.Logger
track *webrtc.TrackLocalStaticSample
listener func(sample types.Sample)
streamMu sync.Mutex
stream types.StreamManager
}
func (peer *PeerTrack) SetStream(stream types.StreamManager) error {
peer.streamMu.Lock()
defer peer.streamMu.Unlock()
// prepare new listener
addListener, err := stream.NewListener(&peer.listener)
if err != nil {
return err
}
// remove previous listener (in case it existed)
if peer.stream != nil {
peer.stream.RemoveListener(&peer.listener)
}
// add new listener
addListener()
peer.stream = stream
return nil
}
func (peer *PeerTrack) RemoveStream() {
peer.streamMu.Lock()
defer peer.streamMu.Unlock()
if peer.stream != nil {
peer.stream.RemoveListener(&peer.listener)
}
}
func (peer *PeerTrack) AddToConnection(connection *webrtc.PeerConnection) error {
sender, err := connection.AddTrack(peer.track)
if err != nil {
return err
}
go func() {
rtcpBuf := make([]byte, 1500)
for {
if _, _, err := sender.Read(rtcpBuf); err != nil {
return
}
}
}()
return nil
}