mirror of
https://github.com/m1k1o/neko.git
synced 2024-07-24 14:40:50 +12:00
webrtc refactor peer track.
This commit is contained in:
@ -1,15 +1,11 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/pion/webrtc/v3/pkg/media"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
@ -39,13 +35,10 @@ func New(desktop types.DesktopManager, capture types.CaptureManager, config *con
|
||||
capture: capture,
|
||||
curImage: cursor.NewImage(desktop),
|
||||
curPosition: cursor.NewPosition(desktop),
|
||||
|
||||
participants: 0,
|
||||
}
|
||||
}
|
||||
|
||||
type WebRTCManagerCtx struct {
|
||||
mu sync.Mutex
|
||||
logger zerolog.Logger
|
||||
config *config.WebRTC
|
||||
|
||||
@ -53,33 +46,9 @@ type WebRTCManagerCtx struct {
|
||||
capture types.CaptureManager
|
||||
curImage *cursor.ImageCtx
|
||||
curPosition *cursor.PositionCtx
|
||||
|
||||
audioTrack *webrtc.TrackLocalStaticSample
|
||||
audioListener func(sample types.Sample)
|
||||
participants uint32
|
||||
}
|
||||
|
||||
func (manager *WebRTCManagerCtx) Start() {
|
||||
var err error
|
||||
|
||||
// create audio track
|
||||
audio := manager.capture.Audio()
|
||||
manager.audioTrack, err = webrtc.NewTrackLocalStaticSample(audio.Codec().Capability, "audio", "stream")
|
||||
if err != nil {
|
||||
manager.logger.Panic().Err(err).Msg("unable to create audio track")
|
||||
}
|
||||
|
||||
manager.audioListener = func(sample types.Sample) {
|
||||
if err := manager.audioTrack.WriteSample(media.Sample(sample)); err != nil {
|
||||
if errors.Is(err, io.ErrClosedPipe) {
|
||||
// The peerConnection has been closed.
|
||||
return
|
||||
}
|
||||
manager.logger.Warn().Err(err).Msg("audio pipeline failed to write")
|
||||
}
|
||||
}
|
||||
audio.AddListener(&manager.audioListener)
|
||||
|
||||
manager.curImage.Start()
|
||||
|
||||
manager.logger.Info().
|
||||
@ -97,9 +66,6 @@ func (manager *WebRTCManagerCtx) Shutdown() error {
|
||||
manager.curImage.Shutdown()
|
||||
manager.curPosition.Shutdown()
|
||||
|
||||
audio := manager.capture.Audio()
|
||||
audio.RemoveListener(&manager.audioListener)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -112,6 +78,9 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
|
||||
logger := manager.logger.With().Str("session_id", session.ID()).Logger()
|
||||
logger.Info().Msg("creating webrtc peer")
|
||||
|
||||
// all audios must have the same codec
|
||||
audioStream := manager.capture.Audio()
|
||||
|
||||
// all videos must have the same codec
|
||||
videoStream, ok := manager.capture.Video(videoID)
|
||||
if !ok {
|
||||
@ -119,8 +88,8 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
|
||||
}
|
||||
|
||||
connection, err := manager.newPeerConnection([]codec.RTPCodec{
|
||||
audioStream.Codec(),
|
||||
videoStream.Codec(),
|
||||
manager.capture.Audio().Codec(),
|
||||
}, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -142,79 +111,32 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
|
||||
})
|
||||
}
|
||||
|
||||
// create video track
|
||||
videoTrack, err := webrtc.NewTrackLocalStaticSample(videoStream.Codec().Capability, "video", "stream")
|
||||
// audio track
|
||||
|
||||
audioTrack, err := manager.newPeerTrack(audioStream, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
videoListener := func(sample types.Sample) {
|
||||
if err := videoTrack.WriteSample(media.Sample(sample)); err != nil {
|
||||
if errors.Is(err, io.ErrClosedPipe) {
|
||||
// The peerConnection has been closed.
|
||||
return
|
||||
}
|
||||
logger.Warn().Err(err).Msg("video pipeline failed to write")
|
||||
}
|
||||
}
|
||||
|
||||
manager.mu.Lock()
|
||||
|
||||
// should be stream started
|
||||
if videoStream.ListenersCount() == 0 {
|
||||
if err := videoStream.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
videoStream.AddListener(&videoListener)
|
||||
|
||||
// start audio, when first participant connects
|
||||
if !manager.capture.Audio().Started() {
|
||||
if err := manager.capture.Audio().Start(); err != nil {
|
||||
manager.logger.Panic().Err(err).Msg("unable to start audio stream")
|
||||
}
|
||||
}
|
||||
|
||||
manager.participants = manager.participants + 1
|
||||
manager.mu.Unlock()
|
||||
|
||||
changeVideo := func(videoID string) error {
|
||||
newVideoStream, ok := manager.capture.Video(videoID)
|
||||
if !ok {
|
||||
return types.ErrWebRTCVideoNotFound
|
||||
}
|
||||
|
||||
// should be new stream started
|
||||
if newVideoStream.ListenersCount() == 0 {
|
||||
if err := newVideoStream.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// switch videoListeners
|
||||
videoStream.RemoveListener(&videoListener)
|
||||
newVideoStream.AddListener(&videoListener)
|
||||
|
||||
// should be old stream stopped
|
||||
if videoStream.ListenersCount() == 0 {
|
||||
videoStream.Stop()
|
||||
}
|
||||
|
||||
videoStream = newVideoStream
|
||||
return nil
|
||||
}
|
||||
|
||||
rtpAudio, err := connection.AddTrack(manager.audioTrack)
|
||||
audioTrack.AddToConnection(connection)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rtpVideo, err := connection.AddTrack(videoTrack)
|
||||
// video track
|
||||
|
||||
videoTrack, err := manager.newPeerTrack(videoStream, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
videoTrack.AddToConnection(connection)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// data channel
|
||||
|
||||
dataChannel, err := connection.CreateDataChannel("data", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -224,8 +146,15 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
|
||||
logger: logger,
|
||||
connection: connection,
|
||||
dataChannel: dataChannel,
|
||||
changeVideo: changeVideo,
|
||||
iceTrickle: manager.config.ICETrickle,
|
||||
changeVideo: func(videoID string) error {
|
||||
videoStream, ok := manager.capture.Video(videoID)
|
||||
if !ok {
|
||||
return types.ErrWebRTCVideoNotFound
|
||||
}
|
||||
|
||||
return videoTrack.SetStream(videoStream)
|
||||
},
|
||||
iceTrickle: manager.config.ICETrickle,
|
||||
}
|
||||
|
||||
cursorImage := func(entry *cursor.ImageEntry) {
|
||||
@ -252,29 +181,9 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
|
||||
webrtc.PeerConnectionStateFailed:
|
||||
connection.Close()
|
||||
case webrtc.PeerConnectionStateClosed:
|
||||
manager.mu.Lock()
|
||||
|
||||
session.SetWebRTCConnected(peer, false)
|
||||
videoStream.RemoveListener(&videoListener)
|
||||
|
||||
// should be stream stopped
|
||||
if videoStream.ListenersCount() == 0 {
|
||||
videoStream.Stop()
|
||||
}
|
||||
|
||||
// decrease participants
|
||||
manager.participants = manager.participants - 1
|
||||
|
||||
// stop audio, if last participant disonnects
|
||||
if manager.participants <= 0 {
|
||||
manager.participants = 0
|
||||
|
||||
if manager.capture.Audio().Started() {
|
||||
manager.capture.Audio().Stop()
|
||||
}
|
||||
}
|
||||
|
||||
manager.mu.Unlock()
|
||||
videoTrack.RemoveStream()
|
||||
audioTrack.RemoveStream()
|
||||
}
|
||||
})
|
||||
|
||||
@ -310,24 +219,6 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin
|
||||
}
|
||||
})
|
||||
|
||||
go func() {
|
||||
rtcpBuf := make([]byte, 1500)
|
||||
for {
|
||||
if _, _, err := rtpAudio.Read(rtcpBuf); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
rtcpBuf := make([]byte, 1500)
|
||||
for {
|
||||
if _, _, err := rtpVideo.Read(rtcpBuf); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
session.SetWebRTCPeer(peer)
|
||||
return peer.CreateOffer(false)
|
||||
}
|
||||
|
97
internal/webrtc/peertrack.go
Normal file
97
internal/webrtc/peertrack.go
Normal file
@ -0,0 +1,97 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"demodesk/neko/internal/types"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/pion/webrtc/v3/pkg/media"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func (manager *WebRTCManagerCtx) newPeerTrack(stream types.StreamManager, logger zerolog.Logger) (*PeerTrack, error) {
|
||||
codec := stream.Codec()
|
||||
|
||||
id := codec.Type.String()
|
||||
track, err := webrtc.NewTrackLocalStaticSample(codec.Capability, id, "stream")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger = logger.With().Str("id", id).Logger()
|
||||
|
||||
peer := &PeerTrack{
|
||||
logger: logger,
|
||||
track: track,
|
||||
listener: func(sample types.Sample) {
|
||||
err := track.WriteSample(media.Sample(sample))
|
||||
if err != nil && errors.Is(err, io.ErrClosedPipe) {
|
||||
logger.Warn().Err(err).Msg("pipeline failed to write")
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
peer.SetStream(stream)
|
||||
return peer, nil
|
||||
|
||||
}
|
||||
|
||||
type PeerTrack struct {
|
||||
logger zerolog.Logger
|
||||
track *webrtc.TrackLocalStaticSample
|
||||
listener func(sample types.Sample)
|
||||
|
||||
streamMu sync.Mutex
|
||||
stream types.StreamManager
|
||||
}
|
||||
|
||||
func (peer *PeerTrack) SetStream(stream types.StreamManager) error {
|
||||
peer.streamMu.Lock()
|
||||
defer peer.streamMu.Unlock()
|
||||
|
||||
// prepare new listener
|
||||
addListener, err := stream.NewListener(&peer.listener)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// remove previous listener (in case it existed)
|
||||
if peer.stream != nil {
|
||||
peer.stream.RemoveListener(&peer.listener)
|
||||
}
|
||||
|
||||
// add new listener
|
||||
addListener()
|
||||
|
||||
peer.stream = stream
|
||||
return nil
|
||||
}
|
||||
|
||||
func (peer *PeerTrack) RemoveStream() {
|
||||
peer.streamMu.Lock()
|
||||
defer peer.streamMu.Unlock()
|
||||
|
||||
if peer.stream != nil {
|
||||
peer.stream.RemoveListener(&peer.listener)
|
||||
}
|
||||
}
|
||||
|
||||
func (peer *PeerTrack) AddToConnection(connection *webrtc.PeerConnection) error {
|
||||
sender, err := connection.AddTrack(peer.track)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
rtcpBuf := make([]byte, 1500)
|
||||
for {
|
||||
if _, _, err := sender.Read(rtcpBuf); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
Reference in New Issue
Block a user