move server to server directory.

This commit is contained in:
Miroslav Šedivý
2024-06-23 17:48:14 +02:00
parent da45f62ca8
commit 5b98344205
211 changed files with 18 additions and 10 deletions

View File

@ -0,0 +1,168 @@
package cursor
import (
"reflect"
"sync"
"github.com/rs/zerolog"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
type ImageListener interface {
SendCursorImage(cur *types.CursorImage, img []byte) error
}
type Image interface {
Start()
Shutdown()
GetCurrent() (cur *types.CursorImage, img []byte, err error)
AddListener(listener ImageListener)
RemoveListener(listener ImageListener)
}
type imageEntry struct {
*types.CursorImage
ImagePNG []byte
}
type image struct {
logger zerolog.Logger
desktop types.DesktopManager
listeners map[uintptr]ImageListener
listenersMu sync.RWMutex
cache map[uint64]*imageEntry
cacheMu sync.RWMutex
current *imageEntry
maxSerial uint64
}
func NewImage(logger zerolog.Logger, desktop types.DesktopManager) *image {
return &image{
logger: logger.With().Str("submodule", "cursor-image").Logger(),
desktop: desktop,
listeners: map[uintptr]ImageListener{},
cache: map[uint64]*imageEntry{},
maxSerial: 300, // TODO: Cleanup?
}
}
func (manager *image) Start() {
manager.desktop.OnCursorChanged(func(serial uint64) {
entry, err := manager.getCached(serial)
if err != nil {
manager.logger.Err(err).Msg("failed to get cursor image")
return
}
manager.current = entry
manager.listenersMu.RLock()
for _, l := range manager.listeners {
if err := l.SendCursorImage(entry.CursorImage, entry.ImagePNG); err != nil {
manager.logger.Err(err).Msg("failed to set cursor image")
}
}
manager.listenersMu.RUnlock()
})
manager.logger.Info().Msg("starting")
}
func (manager *image) Shutdown() {
manager.logger.Info().Msg("shutdown")
manager.listenersMu.Lock()
for key := range manager.listeners {
delete(manager.listeners, key)
}
manager.listenersMu.Unlock()
}
func (manager *image) getCached(serial uint64) (*imageEntry, error) {
// zero means no serial available
if serial == 0 || serial > manager.maxSerial {
manager.logger.Debug().Uint64("serial", serial).Msg("cache bypass")
return manager.fetchEntry()
}
manager.cacheMu.RLock()
entry, ok := manager.cache[serial]
manager.cacheMu.RUnlock()
if ok {
return entry, nil
}
manager.logger.Debug().Uint64("serial", serial).Msg("cache miss")
entry, err := manager.fetchEntry()
if err != nil {
return nil, err
}
manager.cacheMu.Lock()
manager.cache[entry.Serial] = entry
manager.cacheMu.Unlock()
if entry.Serial != serial {
manager.logger.Warn().
Uint64("expected-serial", serial).
Uint64("received-serial", entry.Serial).
Msg("serial mismatch")
}
return entry, nil
}
func (manager *image) GetCurrent() (cur *types.CursorImage, img []byte, err error) {
if manager.current != nil {
return manager.current.CursorImage, manager.current.ImagePNG, nil
}
entry, err := manager.fetchEntry()
if err != nil {
return nil, nil, err
}
manager.current = entry
return entry.CursorImage, entry.ImagePNG, nil
}
func (manager *image) AddListener(listener ImageListener) {
manager.listenersMu.Lock()
defer manager.listenersMu.Unlock()
if listener != nil {
ptr := reflect.ValueOf(listener).Pointer()
manager.listeners[ptr] = listener
}
}
func (manager *image) RemoveListener(listener ImageListener) {
manager.listenersMu.Lock()
defer manager.listenersMu.Unlock()
if listener != nil {
ptr := reflect.ValueOf(listener).Pointer()
delete(manager.listeners, ptr)
}
}
func (manager *image) fetchEntry() (*imageEntry, error) {
cur := manager.desktop.GetCursorImage()
img, err := utils.CreatePNGImage(cur.Image)
if err != nil {
return nil, err
}
cur.Image = nil // free memory
return &imageEntry{
CursorImage: cur,
ImagePNG: img,
}, nil
}

View File

@ -0,0 +1,74 @@
package cursor
import (
"reflect"
"sync"
"github.com/rs/zerolog"
)
type PositionListener interface {
SendCursorPosition(x, y int) error
}
type Position interface {
Shutdown()
Set(x, y int)
AddListener(listener PositionListener)
RemoveListener(listener PositionListener)
}
type position struct {
logger zerolog.Logger
listeners map[uintptr]PositionListener
listenersMu sync.RWMutex
}
func NewPosition(logger zerolog.Logger) *position {
return &position{
logger: logger.With().Str("submodule", "cursor-position").Logger(),
listeners: map[uintptr]PositionListener{},
}
}
func (manager *position) Shutdown() {
manager.logger.Info().Msg("shutdown")
manager.listenersMu.Lock()
for key := range manager.listeners {
delete(manager.listeners, key)
}
manager.listenersMu.Unlock()
}
func (manager *position) Set(x, y int) {
manager.listenersMu.RLock()
defer manager.listenersMu.RUnlock()
for _, l := range manager.listeners {
if err := l.SendCursorPosition(x, y); err != nil {
manager.logger.Err(err).Msg("failed to set cursor position")
}
}
}
func (manager *position) AddListener(listener PositionListener) {
manager.listenersMu.Lock()
defer manager.listenersMu.Unlock()
if listener != nil {
ptr := reflect.ValueOf(listener).Pointer()
manager.listeners[ptr] = listener
}
}
func (manager *position) RemoveListener(listener PositionListener) {
manager.listenersMu.Lock()
defer manager.listenersMu.Unlock()
if listener != nil {
ptr := reflect.ValueOf(listener).Pointer()
delete(manager.listeners, ptr)
}
}

View File

@ -0,0 +1,205 @@
package webrtc
import (
"bytes"
"encoding/binary"
"math"
"time"
"github.com/demodesk/neko/internal/webrtc/payload"
"github.com/demodesk/neko/pkg/types"
"github.com/pion/webrtc/v3"
"github.com/rs/zerolog"
)
func (manager *WebRTCManagerCtx) handle(
logger zerolog.Logger, data []byte,
dataChannel *webrtc.DataChannel,
session types.Session,
) error {
isHost := session.IsHost()
//
// parse header
//
buffer := bytes.NewBuffer(data)
header := &payload.Header{}
if err := binary.Read(buffer, binary.BigEndian, header); err != nil {
return err
}
//
// parse body
//
// handle cursor move event
if header.Event == payload.OP_MOVE {
payload := &payload.Move{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
x, y := int(payload.X), int(payload.Y)
if isHost {
// handle active cursor movement
manager.desktop.Move(x, y)
manager.curPosition.Set(x, y)
} else {
// handle inactive cursor movement
session.SetCursor(types.Cursor{
X: x,
Y: y,
})
}
return nil
} else if header.Event == payload.OP_PING {
ping := &payload.Ping{}
if err := binary.Read(buffer, binary.BigEndian, ping); err != nil {
return err
}
// create pong header
header := payload.Header{
Event: payload.OP_PONG,
Length: 19,
}
// generate server timestamp
serverTs := uint64(time.Now().UnixMilli())
// generate pong payload
pong := payload.Pong{
Ping: *ping,
ServerTs1: uint32(serverTs / math.MaxUint32),
ServerTs2: uint32(serverTs % math.MaxUint32),
}
buffer := &bytes.Buffer{}
if err := binary.Write(buffer, binary.BigEndian, header); err != nil {
return err
}
if err := binary.Write(buffer, binary.BigEndian, pong); err != nil {
return err
}
return dataChannel.Send(buffer.Bytes())
}
// continue only if session is host
if !isHost {
return nil
}
switch header.Event {
case payload.OP_SCROLL:
// TODO: remove this once the client is fixed
if header.Length == 4 {
payload := &payload.Scroll_Old{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
manager.desktop.Scroll(int(payload.X), int(payload.Y), false)
logger.Trace().
Int16("x", payload.X).
Int16("y", payload.Y).
Msg("scroll")
} else {
payload := &payload.Scroll{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
manager.desktop.Scroll(int(payload.DeltaX), int(payload.DeltaY), payload.ControlKey)
logger.Trace().
Int16("deltaX", payload.DeltaX).
Int16("deltaY", payload.DeltaY).
Bool("controlKey", payload.ControlKey).
Msg("scroll")
}
case payload.OP_KEY_DOWN:
payload := &payload.Key{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
if err := manager.desktop.KeyDown(payload.Key); err != nil {
logger.Warn().Err(err).Uint32("key", payload.Key).Msg("key down failed")
} else {
logger.Trace().Uint32("key", payload.Key).Msg("key down")
}
case payload.OP_KEY_UP:
payload := &payload.Key{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
if err := manager.desktop.KeyUp(payload.Key); err != nil {
logger.Warn().Err(err).Uint32("key", payload.Key).Msg("key up failed")
} else {
logger.Trace().Uint32("key", payload.Key).Msg("key up")
}
case payload.OP_BTN_DOWN:
payload := &payload.Key{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
if err := manager.desktop.ButtonDown(payload.Key); err != nil {
logger.Warn().Err(err).Uint32("key", payload.Key).Msg("button down failed")
} else {
logger.Trace().Uint32("key", payload.Key).Msg("button down")
}
case payload.OP_BTN_UP:
payload := &payload.Key{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
if err := manager.desktop.ButtonUp(payload.Key); err != nil {
logger.Warn().Err(err).Uint32("key", payload.Key).Msg("button up failed")
} else {
logger.Trace().Uint32("key", payload.Key).Msg("button up")
}
case payload.OP_TOUCH_BEGIN:
payload := &payload.Touch{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
if err := manager.desktop.TouchBegin(payload.TouchId, int(payload.X), int(payload.Y), payload.Pressure); err != nil {
logger.Warn().Err(err).Uint32("touchId", payload.TouchId).Msg("touch begin failed")
} else {
logger.Trace().Uint32("touchId", payload.TouchId).Msg("touch begin")
}
case payload.OP_TOUCH_UPDATE:
payload := &payload.Touch{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
if err := manager.desktop.TouchUpdate(payload.TouchId, int(payload.X), int(payload.Y), payload.Pressure); err != nil {
logger.Warn().Err(err).Uint32("touchId", payload.TouchId).Msg("touch update failed")
} else {
logger.Trace().Uint32("touchId", payload.TouchId).Msg("touch update")
}
case payload.OP_TOUCH_END:
payload := &payload.Touch{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
if err := manager.desktop.TouchEnd(payload.TouchId, int(payload.X), int(payload.Y), payload.Pressure); err != nil {
logger.Warn().Err(err).Uint32("touchId", payload.TouchId).Msg("touch end failed")
} else {
logger.Trace().Uint32("touchId", payload.TouchId).Msg("touch end")
}
}
return nil
}

View File

@ -0,0 +1,576 @@
package webrtc
import (
"fmt"
"net"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/pion/ice/v2"
"github.com/pion/interceptor"
"github.com/pion/interceptor/pkg/cc"
"github.com/pion/interceptor/pkg/gcc"
"github.com/pion/rtcp"
"github.com/pion/webrtc/v3"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/internal/config"
"github.com/demodesk/neko/internal/webrtc/cursor"
"github.com/demodesk/neko/internal/webrtc/pionlog"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
"github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message"
"github.com/demodesk/neko/pkg/utils"
)
const (
// size of receiving channel used to buffer incoming TCP packets
tcpReadChanBufferSize = 50
// size of buffer used to buffer outgoing TCP packets. Default is 4MB
tcpWriteBufferSizeInBytes = 4 * 1024 * 1024
// the duration without network activity before a Agent is considered disconnected. Default is 5 Seconds
disconnectedTimeout = 4 * time.Second
// the duration without network activity before a Agent is considered failed after disconnected. Default is 25 Seconds
failedTimeout = 6 * time.Second
// how often the ICE Agent sends extra traffic if there is no activity, if media is flowing no traffic will be sent. Default is 2 seconds
keepAliveInterval = 2 * time.Second
// send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval
rtcpPLIInterval = 3 * time.Second
)
func New(desktop types.DesktopManager, capture types.CaptureManager, config *config.WebRTC) *WebRTCManagerCtx {
logger := log.With().Str("module", "webrtc").Logger()
configuration := webrtc.Configuration{
SDPSemantics: webrtc.SDPSemanticsUnifiedPlan,
}
if !config.ICELite {
ICEServers := []webrtc.ICEServer{}
for _, server := range config.ICEServersBackend {
var credential any
if server.Credential != "" {
credential = server.Credential
} else {
credential = false
}
ICEServers = append(ICEServers, webrtc.ICEServer{
URLs: server.URLs,
Username: server.Username,
Credential: credential,
})
}
configuration.ICEServers = ICEServers
}
return &WebRTCManagerCtx{
logger: logger,
config: config,
metrics: newMetricsManager(),
webrtcConfiguration: configuration,
desktop: desktop,
capture: capture,
curImage: cursor.NewImage(logger, desktop),
curPosition: cursor.NewPosition(logger),
}
}
type WebRTCManagerCtx struct {
logger zerolog.Logger
config *config.WebRTC
metrics *metricsManager
peerId int32
desktop types.DesktopManager
capture types.CaptureManager
curImage cursor.Image
curPosition cursor.Position
webrtcConfiguration webrtc.Configuration
tcpMux ice.TCPMux
udpMux ice.UDPMux
camStop, micStop *func()
}
func (manager *WebRTCManagerCtx) Start() {
manager.curImage.Start()
logger := pionlog.New(manager.logger)
// add TCP Mux listener
if manager.config.TCPMux > 0 {
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: net.IP{0, 0, 0, 0},
Port: manager.config.TCPMux,
})
if err != nil {
manager.logger.Fatal().Err(err).Msg("unable to setup ice TCP mux")
}
manager.tcpMux = ice.NewTCPMuxDefault(ice.TCPMuxParams{
Listener: tcpListener,
Logger: logger.NewLogger("ice-tcp"),
ReadBufferSize: tcpReadChanBufferSize,
WriteBufferSize: tcpWriteBufferSizeInBytes,
})
}
// add UDP Mux listener
if manager.config.UDPMux > 0 {
var err error
manager.udpMux, err = ice.NewMultiUDPMuxFromPort(manager.config.UDPMux,
ice.UDPMuxFromPortWithLogger(logger.NewLogger("ice-udp")),
)
if err != nil {
manager.logger.Fatal().Err(err).Msg("unable to setup ice UDP mux")
}
}
manager.logger.Info().
Bool("icelite", manager.config.ICELite).
Bool("icetrickle", manager.config.ICETrickle).
Interface("iceservers-frontend", manager.config.ICEServersFrontend).
Interface("iceservers-backend", manager.config.ICEServersBackend).
Str("nat1to1", strings.Join(manager.config.NAT1To1IPs, ",")).
Str("epr", fmt.Sprintf("%d-%d", manager.config.EphemeralMin, manager.config.EphemeralMax)).
Int("tcpmux", manager.config.TCPMux).
Int("udpmux", manager.config.UDPMux).
Msg("webrtc starting")
}
func (manager *WebRTCManagerCtx) Shutdown() error {
manager.logger.Info().Msg("shutdown")
manager.curImage.Shutdown()
manager.curPosition.Shutdown()
return nil
}
func (manager *WebRTCManagerCtx) ICEServers() []types.ICEServer {
return manager.config.ICEServersFrontend
}
func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs []codec.RTPCodec) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) {
// create media engine
engine := &webrtc.MediaEngine{}
for _, codec := range codecs {
if err := codec.Register(engine); err != nil {
return nil, nil, err
}
}
// create setting engine
settings := webrtc.SettingEngine{
LoggerFactory: pionlog.New(logger),
}
settings.DisableMediaEngineCopy(true)
settings.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
settings.SetNAT1To1IPs(manager.config.NAT1To1IPs, webrtc.ICECandidateTypeHost)
settings.SetLite(manager.config.ICELite)
// make sure server answer sdp setup as passive, to not force DTLS renegotiation
// otherwise iOS renegotiation fails with: Failed to set SSL role for the transport.
settings.SetAnsweringDTLSRole(webrtc.DTLSRoleServer)
var networkType []webrtc.NetworkType
// udp candidates
if manager.udpMux != nil {
settings.SetICEUDPMux(manager.udpMux)
networkType = append(networkType,
webrtc.NetworkTypeUDP4,
webrtc.NetworkTypeUDP6,
)
} else if manager.config.EphemeralMax != 0 {
_ = settings.SetEphemeralUDPPortRange(manager.config.EphemeralMin, manager.config.EphemeralMax)
networkType = append(networkType,
webrtc.NetworkTypeUDP4,
webrtc.NetworkTypeUDP6,
)
}
// tcp candidates
if manager.tcpMux != nil {
settings.SetICETCPMux(manager.tcpMux)
networkType = append(networkType,
webrtc.NetworkTypeTCP4,
webrtc.NetworkTypeTCP6,
)
}
// enable support for TCP and UDP ICE candidates
settings.SetNetworkTypes(networkType)
// create interceptor registry
registry := &interceptor.Registry{}
// create bandwidth estimator
estimatorChan := make(chan cc.BandwidthEstimator, 1)
if manager.config.Estimator.Enabled {
congestionController, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) {
return gcc.NewSendSideBWE(
gcc.SendSideBWEInitialBitrate(manager.config.Estimator.InitialBitrate),
gcc.SendSideBWEPacer(gcc.NewNoOpPacer()),
)
})
if err != nil {
return nil, nil, err
}
congestionController.OnNewPeerConnection(func(id string, estimator cc.BandwidthEstimator) {
estimatorChan <- estimator
})
registry.Add(congestionController)
if err = webrtc.ConfigureTWCCHeaderExtensionSender(engine, registry); err != nil {
return nil, nil, err
}
} else {
// no estimator, send nil
estimatorChan <- nil
}
if err := webrtc.RegisterDefaultInterceptors(engine, registry); err != nil {
return nil, nil, err
}
// create new API
api := webrtc.NewAPI(
webrtc.WithMediaEngine(engine),
webrtc.WithSettingEngine(settings),
webrtc.WithInterceptorRegistry(registry),
)
// create new peer connection
configuration := manager.webrtcConfiguration
connection, err := api.NewPeerConnection(configuration)
return connection, <-estimatorChan, err
}
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session) (*webrtc.SessionDescription, types.WebRTCPeer, error) {
id := atomic.AddInt32(&manager.peerId, 1)
// get metrics for session
metrics := manager.metrics.getBySession(session)
metrics.NewConnection()
// add session id to logger context
logger := manager.logger.With().Str("session_id", session.ID()).Int32("peer_id", id).Logger()
logger.Info().Msg("creating webrtc peer")
// all audios must have the same codec
audio := manager.capture.Audio()
audioCodec := audio.Codec()
// all videos must have the same codec
video := manager.capture.Video()
videoCodec := video.Codec()
connection, estimator, err := manager.newPeerConnection(
logger, []codec.RTPCodec{audioCodec, videoCodec})
if err != nil {
return nil, nil, err
}
// asynchronously send local ICE Candidates
if manager.config.ICETrickle {
connection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
if candidate == nil {
logger.Debug().Msg("all local ice candidates sent")
return
}
session.Send(
event.SIGNAL_CANDIDATE,
message.SignalCandidate{
ICECandidateInit: candidate.ToJSON(),
})
})
}
// audio track
audioTrack, err := NewTrack(logger, audioCodec, connection)
if err != nil {
return nil, nil, err
}
// we disable audio by default manually
audioTrack.SetPaused(true)
// set stream for audio track
_, err = audioTrack.SetStream(audio)
if err != nil {
return nil, nil, err
}
// video track
videoRtcp := make(chan []rtcp.Packet, 1)
videoTrack, err := NewTrack(logger, videoCodec, connection, WithRtcpChan(videoRtcp))
if err != nil {
return nil, nil, err
}
//
// stream for video track will be set later
//
// data channel
dataChannel, err := connection.CreateDataChannel("data", nil)
if err != nil {
return nil, nil, err
}
peer := &WebRTCPeerCtx{
logger: logger,
session: session,
metrics: metrics,
connection: connection,
// bandwidth estimator
estimator: estimator,
estimateTrend: utils.NewTrendDetector(
utils.TrendDetectorParams{
// Probing
//RequiredSamples: 3,
//DownwardTrendThreshold: 0.0,
//CollapseValues: false,
// Non-Probing
RequiredSamples: 8,
DownwardTrendThreshold: -0.5,
CollapseValues: true,
}),
// stream selectors
video: video,
audio: audio,
// tracks & channels
audioTrack: audioTrack,
videoTrack: videoTrack,
dataChannel: dataChannel,
rtcpChannel: videoRtcp,
// config
iceTrickle: manager.config.ICETrickle,
estimatorConfig: manager.config.Estimator,
audioDisabled: true, // we disable audio by default manually
}
connection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
logger := logger.With().
Str("kind", track.Kind().String()).
Str("mime", track.Codec().RTPCodecCapability.MimeType).
Logger()
logger.Info().Msgf("received new remote track")
if !session.Profile().CanShareMedia {
err := receiver.Stop()
logger.Warn().Err(err).Msg("media sharing is disabled for this session")
return
}
// parse codec from remote track
codec, ok := codec.ParseRTC(track.Codec())
if !ok {
err := receiver.Stop()
logger.Warn().Err(err).Msg("remote track with unknown codec")
return
}
var srcManager types.StreamSrcManager
stopped := false
stopFn := func() {
if stopped {
return
}
stopped = true
err := receiver.Stop()
srcManager.Stop()
logger.Err(err).Msg("remote track stopped")
}
if track.Kind() == webrtc.RTPCodecTypeAudio {
// audio -> microphone
srcManager = manager.capture.Microphone()
defer stopFn()
if manager.micStop != nil {
(*manager.micStop)()
}
manager.micStop = &stopFn
} else if track.Kind() == webrtc.RTPCodecTypeVideo {
// video -> webcam
srcManager = manager.capture.Webcam()
defer stopFn()
if manager.camStop != nil {
(*manager.camStop)()
}
manager.camStop = &stopFn
} else {
err := receiver.Stop()
logger.Warn().Err(err).Msg("remote track with unsupported codec type")
return
}
err := srcManager.Start(codec)
if err != nil {
logger.Err(err).Msg("failed to start pipeline")
return
}
ticker := time.NewTicker(rtcpPLIInterval)
defer ticker.Stop()
go func() {
for range ticker.C {
err := connection.WriteRTCP([]rtcp.Packet{
&rtcp.PictureLossIndication{
MediaSSRC: uint32(track.SSRC()),
},
})
if err != nil {
logger.Err(err).Msg("remote track rtcp send err")
}
}
}()
buf := make([]byte, 1400)
for {
i, _, err := track.Read(buf)
if err != nil {
logger.Warn().Err(err).Msg("failed read from remote track")
break
}
srcManager.Push(buf[:i])
}
logger.Info().Msg("remote track data finished")
})
connection.OnDataChannel(func(dc *webrtc.DataChannel) {
logger.Info().Interface("data_channel", dc).Msg("got remote data channel")
})
var once sync.Once
connection.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
switch state {
case webrtc.PeerConnectionStateConnected:
session.SetWebRTCConnected(peer, true)
case webrtc.PeerConnectionStateDisconnected,
webrtc.PeerConnectionStateFailed:
peer.Destroy()
case webrtc.PeerConnectionStateClosed:
// ensure we only run this once
once.Do(func() {
session.SetWebRTCConnected(peer, false)
//
// TODO: Shutdown peer?
//
audioTrack.Shutdown()
videoTrack.Shutdown()
close(videoRtcp)
})
}
metrics.SetState(state)
})
dataChannel.OnOpen(func() {
manager.curImage.AddListener(peer)
manager.curPosition.AddListener(peer)
// send initial cursor image
cur, img, err := manager.curImage.GetCurrent()
if err == nil {
err := peer.SendCursorImage(cur, img)
if err != nil {
logger.Err(err).Msg("failed to set cursor image")
}
} else {
logger.Err(err).Msg("failed to get cursor image")
}
// send initial cursor position
x, y := manager.desktop.GetCursorPosition()
err = peer.SendCursorPosition(x, y)
if err != nil {
logger.Err(err).Msg("failed to set cursor position")
}
})
dataChannel.OnClose(func() {
manager.curImage.RemoveListener(peer)
manager.curPosition.RemoveListener(peer)
})
dataChannel.OnMessage(func(message webrtc.DataChannelMessage) {
if err := manager.handle(logger, message.Data, dataChannel, session); err != nil {
logger.Err(err).Msg("data handle failed")
}
})
session.SetWebRTCPeer(peer)
offer, err := peer.CreateOffer(false)
if err != nil {
return nil, nil, err
}
// on negotiation needed handler must be registered after creating initial
// offer, otherwise it can fire and intercept sucessful negotiation
connection.OnNegotiationNeeded(func() {
logger.Warn().Msg("negotiation is needed")
if connection.SignalingState() != webrtc.SignalingStateStable {
logger.Warn().Msg("connection isn't stable yet; postponing...")
return
}
offer, err := peer.CreateOffer(false)
if err != nil {
logger.Err(err).Msg("sdp offer failed")
return
}
session.Send(
event.SIGNAL_OFFER,
message.SignalDescription{
SDP: offer.SDP,
})
})
// start metrics collectors
go metrics.rtcpReceiver(videoRtcp)
go metrics.connectionStats(connection)
// start estimator reader
go peer.estimatorReader()
return offer, peer, nil
}
func (manager *WebRTCManagerCtx) SetCursorPosition(x, y int) {
manager.curPosition.Set(x, y)
}

View File

@ -0,0 +1,458 @@
package webrtc
import (
"sync"
"time"
"github.com/demodesk/neko/pkg/types"
"github.com/pion/rtcp"
"github.com/pion/webrtc/v3"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
const (
// how often to read and process webrtc connection stats
connectionStatsInterval = 5 * time.Second
)
type metricsManager struct {
mu sync.Mutex
sessions map[string]*metrics
}
func newMetricsManager() *metricsManager {
return &metricsManager{
sessions: map[string]*metrics{},
}
}
func (m *metricsManager) getBySession(session types.Session) *metrics {
m.mu.Lock()
defer m.mu.Unlock()
sessionId := session.ID()
met, ok := m.sessions[sessionId]
if ok {
return met
}
met = &metrics{
sessionId: sessionId,
connectionState: promauto.NewGauge(prometheus.GaugeOpts{
Name: "connection_state",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Connection state of session.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
connectionStateCount: promauto.NewCounter(prometheus.CounterOpts{
Name: "connection_state_count",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Count of connection state changes for a session.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
connectionCount: promauto.NewCounter(prometheus.CounterOpts{
Name: "connection_count",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Connection count of a session.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
iceCandidates: map[string]struct{}{},
iceCandidatesMu: &sync.Mutex{},
iceCandidatesUdpCount: promauto.NewCounter(prometheus.CounterOpts{
Name: "ice_candidates_count",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Count of ICE candidates sent by a remote client.",
ConstLabels: map[string]string{
"session_id": sessionId,
"protocol": "udp",
},
}),
iceCandidatesTcpCount: promauto.NewCounter(prometheus.CounterOpts{
Name: "ice_candidates_count",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Count of ICE candidates sent by a remote client.",
ConstLabels: map[string]string{
"session_id": sessionId,
"protocol": "tcp",
},
}),
iceCandidatesUsedUdp: promauto.NewGauge(prometheus.GaugeOpts{
Name: "ice_candidates_used",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Used ICE candidates that are currently in use.",
ConstLabels: map[string]string{
"session_id": sessionId,
"protocol": "udp",
},
}),
iceCandidatesUsedTcp: promauto.NewGauge(prometheus.GaugeOpts{
Name: "ice_candidates_used",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Used ICE candidates that are currently in use.",
ConstLabels: map[string]string{
"session_id": sessionId,
"protocol": "tcp",
},
}),
videoIds: map[string]prometheus.Gauge{},
videoIdsMu: &sync.Mutex{},
receiverEstimatedMaximumBitrate: promauto.NewGauge(prometheus.GaugeOpts{
Name: "receiver_estimated_maximum_bitrate",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Receiver Estimated Maximum Bitrate from RTCP.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
receiverEstimatedTargetBitrate: promauto.NewGauge(prometheus.GaugeOpts{
Name: "receiver_estimated_target_bitrate",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Receiver Estimated Target Bitrate using Google's congestion control.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
receiverReportDelay: promauto.NewGauge(prometheus.GaugeOpts{
Name: "receiver_report_delay",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Receiver Report Delay from RTCP, expressed in units of 1/65536 seconds.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
receiverReportJitter: promauto.NewGauge(prometheus.GaugeOpts{
Name: "receiver_report_jitter",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Receiver Report Jitter from RTCP.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
receiverReportTotalLost: promauto.NewGauge(prometheus.GaugeOpts{
Name: "receiver_report_total_lost",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Receiver Report Total Lost from RTCP.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
transportLayerNacks: promauto.NewCounter(prometheus.CounterOpts{
Name: "transport_layer_nacks",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Transport Layer NACKs from RTCP.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
iceBytesSent: promauto.NewGauge(prometheus.GaugeOpts{
Name: "bytes_sent",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Sent bytes to a session.",
ConstLabels: map[string]string{
"session_id": sessionId,
"transport": "ice",
},
}),
iceBytesReceived: promauto.NewGauge(prometheus.GaugeOpts{
Name: "bytes_received",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Received bytes from a session.",
ConstLabels: map[string]string{
"session_id": sessionId,
"transport": "ice",
},
}),
sctpBytesSent: promauto.NewGauge(prometheus.GaugeOpts{
Name: "bytes_sent",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Sent bytes to a session.",
ConstLabels: map[string]string{
"session_id": sessionId,
"transport": "sctp",
},
}),
sctpBytesReceived: promauto.NewGauge(prometheus.GaugeOpts{
Name: "bytes_received",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Received bytes from a session.",
ConstLabels: map[string]string{
"session_id": sessionId,
"transport": "sctp",
},
}),
}
m.sessions[sessionId] = met
return met
}
type metrics struct {
sessionId string
connectionState prometheus.Gauge
connectionStateCount prometheus.Counter
connectionCount prometheus.Counter
iceCandidates map[string]struct{}
iceCandidatesMu *sync.Mutex
iceCandidatesUdpCount prometheus.Counter
iceCandidatesTcpCount prometheus.Counter
iceCandidatesUsedUdp prometheus.Gauge
iceCandidatesUsedTcp prometheus.Gauge
videoIds map[string]prometheus.Gauge
videoIdsMu *sync.Mutex
receiverEstimatedMaximumBitrate prometheus.Gauge
receiverEstimatedTargetBitrate prometheus.Gauge
receiverReportDelay prometheus.Gauge
receiverReportJitter prometheus.Gauge
receiverReportTotalLost prometheus.Gauge
transportLayerNacks prometheus.Counter
iceBytesSent prometheus.Gauge
iceBytesReceived prometheus.Gauge
sctpBytesSent prometheus.Gauge
sctpBytesReceived prometheus.Gauge
}
func (met *metrics) reset() {
met.videoIdsMu.Lock()
for _, entry := range met.videoIds {
entry.Set(0)
}
met.videoIdsMu.Unlock()
met.iceCandidatesUsedUdp.Set(float64(0))
met.iceCandidatesUsedTcp.Set(float64(0))
met.receiverEstimatedMaximumBitrate.Set(0)
met.receiverReportDelay.Set(0)
met.receiverReportJitter.Set(0)
}
func (met *metrics) NewConnection() {
met.connectionCount.Add(1)
}
func (met *metrics) NewICECandidate(candidate webrtc.ICECandidateStats) {
met.iceCandidatesMu.Lock()
defer met.iceCandidatesMu.Unlock()
if _, found := met.iceCandidates[candidate.ID]; found {
return
}
met.iceCandidates[candidate.ID] = struct{}{}
if candidate.Protocol == "udp" {
met.iceCandidatesUdpCount.Add(1)
} else if candidate.Protocol == "tcp" {
met.iceCandidatesTcpCount.Add(1)
}
}
func (met *metrics) SetICECandidatesUsed(candidates []webrtc.ICECandidateStats) {
udp, tcp := 0, 0
for _, candidate := range candidates {
if candidate.Protocol == "udp" {
udp++
} else if candidate.Protocol == "tcp" {
tcp++
}
}
met.iceCandidatesUsedUdp.Set(float64(udp))
met.iceCandidatesUsedTcp.Set(float64(tcp))
}
func (met *metrics) SetState(state webrtc.PeerConnectionState) {
switch state {
case webrtc.PeerConnectionStateNew:
met.connectionState.Set(0)
case webrtc.PeerConnectionStateConnecting:
met.connectionState.Set(4)
case webrtc.PeerConnectionStateConnected:
met.connectionState.Set(5)
case webrtc.PeerConnectionStateDisconnected:
met.connectionState.Set(3)
case webrtc.PeerConnectionStateFailed:
met.connectionState.Set(2)
case webrtc.PeerConnectionStateClosed:
met.connectionState.Set(1)
met.reset()
default:
met.connectionState.Set(-1)
}
met.connectionStateCount.Add(1)
}
func (met *metrics) SetVideoID(videoId string) {
met.videoIdsMu.Lock()
defer met.videoIdsMu.Unlock()
if _, found := met.videoIds[videoId]; !found {
met.videoIds[videoId] = promauto.NewGauge(prometheus.GaugeOpts{
Name: "video_listeners",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Listeners for Video pipelines by a session.",
ConstLabels: map[string]string{
"session_id": met.sessionId,
"video_id": videoId,
},
})
}
for id, entry := range met.videoIds {
if id == videoId {
entry.Set(1)
} else {
entry.Set(0)
}
}
}
func (met *metrics) SetReceiverEstimatedMaximumBitrate(bitrate float32) {
met.receiverEstimatedMaximumBitrate.Set(float64(bitrate))
}
func (met *metrics) SetReceiverEstimatedTargetBitrate(bitrate float64) {
met.receiverEstimatedTargetBitrate.Set(bitrate)
}
func (met *metrics) SetReceiverReport(report rtcp.ReceptionReport) {
met.receiverReportDelay.Set(float64(report.Delay))
met.receiverReportJitter.Set(float64(report.Jitter))
met.receiverReportTotalLost.Set(float64(report.TotalLost))
}
func (met *metrics) SetIceTransportStats(data webrtc.TransportStats) {
met.iceBytesSent.Set(float64(data.BytesSent))
met.iceBytesReceived.Set(float64(data.BytesReceived))
}
func (met *metrics) SetSctpTransportStats(data webrtc.TransportStats) {
met.sctpBytesSent.Set(float64(data.BytesSent))
met.sctpBytesReceived.Set(float64(data.BytesReceived))
}
//
// collectors
//
func (met *metrics) rtcpReceiver(rtcpCh chan []rtcp.Packet) {
for {
packets, ok := <-rtcpCh
if !ok {
break
}
for _, p := range packets {
switch rtcpPacket := p.(type) {
case *rtcp.ReceiverEstimatedMaximumBitrate: // TODO: Deprecated.
met.SetReceiverEstimatedMaximumBitrate(rtcpPacket.Bitrate)
case *rtcp.ReceiverReport:
l := len(rtcpPacket.Reports)
if l > 0 {
// use only last report
met.SetReceiverReport(rtcpPacket.Reports[l-1])
}
case *rtcp.TransportLayerNack:
for _, pair := range rtcpPacket.Nacks {
packetList := pair.PacketList()
met.transportLayerNacks.Add(float64(len(packetList)))
}
}
}
}
}
func (met *metrics) connectionStats(connection *webrtc.PeerConnection) {
ticker := time.NewTicker(connectionStatsInterval)
defer ticker.Stop()
for range ticker.C {
if connection.ConnectionState() == webrtc.PeerConnectionStateClosed {
break
}
stats := connection.GetStats()
data, ok := stats["iceTransport"].(webrtc.TransportStats)
if ok {
met.SetIceTransportStats(data)
}
data, ok = stats["sctpTransport"].(webrtc.TransportStats)
if ok {
met.SetSctpTransportStats(data)
}
remoteCandidates := map[string]webrtc.ICECandidateStats{}
nominatedRemoteCandidates := map[string]struct{}{}
for _, entry := range stats {
// only remote ice candidate stats
candidate, ok := entry.(webrtc.ICECandidateStats)
if ok && candidate.Type == webrtc.StatsTypeRemoteCandidate {
met.NewICECandidate(candidate)
remoteCandidates[candidate.ID] = candidate
}
// only nominated ice candidate pair stats
pair, ok := entry.(webrtc.ICECandidatePairStats)
if ok && pair.Nominated {
nominatedRemoteCandidates[pair.RemoteCandidateID] = struct{}{}
}
}
iceCandidatesUsed := []webrtc.ICECandidateStats{}
for id := range nominatedRemoteCandidates {
if candidate, ok := remoteCandidates[id]; ok {
iceCandidatesUsed = append(iceCandidatesUsed, candidate)
}
}
met.SetICECandidatesUsed(iceCandidatesUsed)
}
}

View File

@ -0,0 +1,55 @@
package payload
import "math"
const (
OP_MOVE = 0x01
OP_SCROLL = 0x02
OP_KEY_DOWN = 0x03
OP_KEY_UP = 0x04
OP_BTN_DOWN = 0x05
OP_BTN_UP = 0x06
OP_PING = 0x07
// touch events
OP_TOUCH_BEGIN = 0x08
OP_TOUCH_UPDATE = 0x09
OP_TOUCH_END = 0x0a
)
type Move struct {
X uint16
Y uint16
}
// TODO: remove this once the client is fixed
type Scroll_Old struct {
X int16
Y int16
}
type Scroll struct {
DeltaX int16
DeltaY int16
ControlKey bool
}
type Key struct {
Key uint32
}
type Ping struct {
// client's timestamp split into two uint32
ClientTs1 uint32
ClientTs2 uint32
}
func (p Ping) ClientTs() uint64 {
return (uint64(p.ClientTs1) * uint64(math.MaxUint32)) + uint64(p.ClientTs2)
}
type Touch struct {
TouchId uint32
X int32
Y int32
Pressure uint8
}

View File

@ -0,0 +1,33 @@
package payload
import "math"
const (
OP_CURSOR_POSITION = 0x01
OP_CURSOR_IMAGE = 0x02
OP_PONG = 0x03
)
type CursorPosition struct {
X uint16
Y uint16
}
type CursorImage struct {
Width uint16
Height uint16
Xhot uint16
Yhot uint16
}
type Pong struct {
Ping
// server's timestamp split into two uint32
ServerTs1 uint32
ServerTs2 uint32
}
func (p Pong) ServerTs() uint64 {
return (uint64(p.ServerTs1) * uint64(math.MaxUint32)) + uint64(p.ServerTs2)
}

View File

@ -0,0 +1,6 @@
package payload
type Header struct {
Event uint8
Length uint16
}

View File

@ -0,0 +1,543 @@
package webrtc
import (
"bytes"
"encoding/binary"
"sync"
"time"
"github.com/pion/interceptor/pkg/cc"
"github.com/pion/rtcp"
"github.com/pion/webrtc/v3"
"github.com/rs/zerolog"
"github.com/demodesk/neko/internal/config"
"github.com/demodesk/neko/internal/webrtc/payload"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/utils"
)
type WebRTCPeerCtx struct {
mu sync.Mutex
logger zerolog.Logger
session types.Session
metrics *metrics
connection *webrtc.PeerConnection
// bandwidth estimator
estimator cc.BandwidthEstimator
estimateTrend *utils.TrendDetector
// stream selectors
video types.StreamSelectorManager
audio types.StreamSinkManager
// tracks & channels
audioTrack *Track
videoTrack *Track
dataChannel *webrtc.DataChannel
rtcpChannel chan []rtcp.Packet
// config
iceTrickle bool
estimatorConfig config.WebRTCEstimator
paused bool
videoAuto bool
videoDisabled bool
audioDisabled bool
}
//
// connection
//
func (peer *WebRTCPeerCtx) CreateOffer(ICERestart bool) (*webrtc.SessionDescription, error) {
peer.mu.Lock()
defer peer.mu.Unlock()
offer, err := peer.connection.CreateOffer(&webrtc.OfferOptions{
ICERestart: ICERestart,
})
if err != nil {
return nil, err
}
return peer.setLocalDescription(offer)
}
func (peer *WebRTCPeerCtx) CreateAnswer() (*webrtc.SessionDescription, error) {
peer.mu.Lock()
defer peer.mu.Unlock()
answer, err := peer.connection.CreateAnswer(nil)
if err != nil {
return nil, err
}
return peer.setLocalDescription(answer)
}
func (peer *WebRTCPeerCtx) setLocalDescription(description webrtc.SessionDescription) (*webrtc.SessionDescription, error) {
if !peer.iceTrickle {
// Create channel that is blocked until ICE Gathering is complete
gatherComplete := webrtc.GatheringCompletePromise(peer.connection)
if err := peer.connection.SetLocalDescription(description); err != nil {
return nil, err
}
<-gatherComplete
} else {
if err := peer.connection.SetLocalDescription(description); err != nil {
return nil, err
}
}
return peer.connection.LocalDescription(), nil
}
func (peer *WebRTCPeerCtx) SetRemoteDescription(desc webrtc.SessionDescription) error {
peer.mu.Lock()
defer peer.mu.Unlock()
return peer.connection.SetRemoteDescription(desc)
}
func (peer *WebRTCPeerCtx) SetCandidate(candidate webrtc.ICECandidateInit) error {
peer.mu.Lock()
defer peer.mu.Unlock()
return peer.connection.AddICECandidate(candidate)
}
// TODO: Add shutdown function?
func (peer *WebRTCPeerCtx) Destroy() {
peer.mu.Lock()
defer peer.mu.Unlock()
var err error
// if peer connection is not closed, close it
if peer.connection.ConnectionState() != webrtc.PeerConnectionStateClosed {
err = peer.connection.Close()
}
peer.logger.Err(err).Msg("peer connection destroyed")
}
func (peer *WebRTCPeerCtx) estimatorReader() {
conf := peer.estimatorConfig
// if estimator is not in debug mode, use a nop logger
var debugLogger zerolog.Logger
if conf.Debug {
debugLogger = peer.logger.With().Str("component", "estimator").Logger().Level(zerolog.DebugLevel)
} else {
debugLogger = zerolog.Nop()
}
// if estimator is disabled, do nothing
if peer.estimator == nil {
return
}
// use a ticker to get current client target bitrate
ticker := time.NewTicker(conf.ReadInterval)
defer ticker.Stop()
// since when is the estimate stable/unstable
stableSince := time.Now() // we asume stable at start
unstableSince := time.Time{}
// since when are we neutral but cannot accomodate current bitrate
// we migt be stalled or estimator just reached zer (very bad connection)
stalledSince := time.Time{}
// when was the last upgrade/downgrade
lastUpgradeTime := time.Time{}
lastDowngradeTime := time.Time{}
for range ticker.C {
targetBitrate := peer.estimator.GetTargetBitrate()
peer.metrics.SetReceiverEstimatedTargetBitrate(float64(targetBitrate))
// if peer connection is closed, stop reading
if peer.connection.ConnectionState() == webrtc.PeerConnectionStateClosed {
break
}
// if estimation or video is disabled, do nothing
if !peer.videoAuto || peer.videoDisabled || peer.paused || conf.Passive {
continue
}
// get trend direction to decide if we should upgrade or downgrade
peer.estimateTrend.AddValue(int64(targetBitrate))
direction := peer.estimateTrend.GetDirection()
// get current stream bitrate
stream, ok := peer.videoTrack.Stream()
if !ok {
debugLogger.Warn().Msg("looks like we don't have a stream yet, skipping bitrate estimation")
continue
}
// if stream bitrate is 0, we need to wait for some time until we get a valid value
streamId, streamBitrate := stream.ID(), stream.Bitrate()
if streamBitrate == 0 {
debugLogger.Warn().Msg("looks like stream bitrate is 0, we need to wait for some time")
continue
}
// check whats the difference between target and stream bitrate
diff := float64(targetBitrate) / float64(streamBitrate)
debugLogger.Info().
Float64("diff", diff).
Int("target_bitrate", targetBitrate).
Uint64("stream_bitrate", streamBitrate).
Str("direction", direction.String()).
Msg("got bitrate from estimator")
// if we can accomodate current stream or we are not netural anymore,
// we are not stalled so we reset the stalled time
if direction != utils.TrendDirectionNeutral || diff > 1+conf.DiffThreshold {
stalledSince = time.Now()
}
// if we are neutral and stalled for too long, we might be congesting
stalled := direction == utils.TrendDirectionNeutral && time.Since(stalledSince) > conf.StalledDuration
if stalled {
debugLogger.Warn().
Time("stalled_since", stalledSince).
Msgf("it looks like we are stalled")
}
// if we have an downward trend or are stalled, we might be congesting
if direction == utils.TrendDirectionDownward || stalled {
// we reset the stable time because we are congesting
stableSince = time.Now()
// if we downgraded recently, we wait for some more time
if time.Since(lastDowngradeTime) < conf.DowngradeBackoff {
debugLogger.Debug().
Time("last_downgrade", lastDowngradeTime).
Msgf("downgraded recently, waiting for at least %v", conf.DowngradeBackoff)
continue
}
// if we are not unstable but we fluctuate we should wait for some more time
if time.Since(unstableSince) < conf.UnstableDuration {
debugLogger.Debug().
Time("unstable_since", unstableSince).
Msgf("we are not unstable long enough, waiting for at least %v", conf.UnstableDuration)
continue
}
// if we still have a big difference between target and stream bitrate, we wait for some more time
if conf.DiffThreshold >= 0 && diff > 1+conf.DiffThreshold {
debugLogger.Debug().
Float64("diff", diff).
Float64("threshold", conf.DiffThreshold).
Msgf("we still have a big difference between target and stream bitrate, " +
"therefore we still should be able to accomodate current stream")
continue
}
err := peer.SetVideo(types.PeerVideoRequest{
Selector: &types.StreamSelector{
ID: streamId,
Type: types.StreamSelectorTypeLower,
},
})
if err != nil && err != types.ErrWebRTCStreamNotFound {
peer.logger.Warn().Err(err).Msg("failed to downgrade video stream")
}
lastDowngradeTime = time.Now()
if err == types.ErrWebRTCStreamNotFound {
debugLogger.Info().Msg("looks like we are already on the lowest stream")
} else {
debugLogger.Info().Msg("downgraded video stream")
}
continue
}
// we reset the unstable time because we are not congesting
unstableSince = time.Now()
// if we have a neutral or upward trend, that means our estimate is stable
// if we are on the highest stream, we don't need to do anything
// but if there is a higher stream, we should try to upgrade and see if it works
// if we upgraded recently, we wait for some more time
if time.Since(lastUpgradeTime) < conf.UpgradeBackoff {
debugLogger.Debug().
Time("last_upgrade", lastUpgradeTime).
Msgf("upgraded recently, waiting for at least %v", conf.UpgradeBackoff)
continue
}
// if we are not stable for long enough, we wait for some more time
// because bandwidth estimation might fluctuate
if time.Since(stableSince) < conf.StableDuration {
debugLogger.Debug().
Time("stable_since", stableSince).
Msgf("we are not stable long enough, waiting for at least %v", conf.StableDuration)
continue
}
// upgrade only if estimated bitrate passed the threshold
if conf.DiffThreshold >= 0 && diff < 1+conf.DiffThreshold {
debugLogger.Debug().
Float64("diff", diff).
Float64("threshold", conf.DiffThreshold).
Msgf("looks like we don't have enough bitrate to accomodate higher stream, " +
"therefore we should wait for some more time")
continue
}
err := peer.SetVideo(types.PeerVideoRequest{
Selector: &types.StreamSelector{
ID: streamId,
Type: types.StreamSelectorTypeHigher,
},
})
if err != nil && err != types.ErrWebRTCStreamNotFound {
peer.logger.Warn().Err(err).Msg("failed to upgrade video stream")
}
lastUpgradeTime = time.Now()
if err == types.ErrWebRTCStreamNotFound {
debugLogger.Info().Msg("looks like we are already on the highest stream")
} else {
debugLogger.Info().Msg("upgraded video stream")
}
}
}
func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error {
peer.mu.Lock()
defer peer.mu.Unlock()
peer.videoTrack.SetPaused(isPaused || peer.videoDisabled)
peer.audioTrack.SetPaused(isPaused || peer.audioDisabled)
peer.logger.Info().Bool("is_paused", isPaused).Msg("set paused")
peer.paused = isPaused
return nil
}
func (peer *WebRTCPeerCtx) Paused() bool {
peer.mu.Lock()
defer peer.mu.Unlock()
return peer.paused
}
//
// video
//
func (peer *WebRTCPeerCtx) SetVideo(r types.PeerVideoRequest) error {
peer.mu.Lock()
defer peer.mu.Unlock()
modified := false
// video disabled
if r.Disabled != nil {
disabled := *r.Disabled
// update only if changed
if peer.videoDisabled != disabled {
peer.videoDisabled = disabled
peer.videoTrack.SetPaused(disabled || peer.paused)
peer.logger.Info().Bool("disabled", disabled).Msg("set video disabled")
modified = true
}
}
// video selector
if r.Selector != nil {
selector := *r.Selector
// get requested video stream from selector
stream, ok := peer.video.GetStream(selector)
if !ok {
return types.ErrWebRTCStreamNotFound
}
// set video stream to track
changed, err := peer.videoTrack.SetStream(stream)
if err != nil {
return err
}
// update only if stream changed
if changed {
videoID := stream.ID()
peer.metrics.SetVideoID(videoID)
peer.logger.Info().Str("video_id", videoID).Msg("set video")
modified = true
}
}
// video auto
if r.Auto != nil {
videoAuto := *r.Auto
if peer.estimator == nil || peer.estimatorConfig.Passive {
peer.logger.Warn().Msg("estimator is disabled or in passive mode, cannot change video auto")
videoAuto = false // ensure video auto is disabled
}
// update only if video auto changed
if peer.videoAuto != videoAuto {
peer.videoAuto = videoAuto
peer.logger.Info().Bool("video_auto", videoAuto).Msg("set video auto")
modified = true
}
}
// send video signal if modified
if modified {
go func() {
// in goroutine because of mutex and we don't want to block
peer.session.Send(event.SIGNAL_VIDEO, peer.Video())
}()
}
return nil
}
func (peer *WebRTCPeerCtx) Video() types.PeerVideo {
peer.mu.Lock()
defer peer.mu.Unlock()
// get current video stream ID
ID := ""
stream, ok := peer.videoTrack.Stream()
if ok {
ID = stream.ID()
}
return types.PeerVideo{
Disabled: peer.videoDisabled,
ID: ID,
Video: ID, // TODO: Remove, used for backward compatibility
Auto: peer.videoAuto,
}
}
//
// audio
//
func (peer *WebRTCPeerCtx) SetAudio(r types.PeerAudioRequest) error {
peer.mu.Lock()
defer peer.mu.Unlock()
modified := false
// audio disabled
if r.Disabled != nil {
disabled := *r.Disabled
// update only if changed
if peer.audioDisabled != disabled {
peer.audioDisabled = disabled
peer.audioTrack.SetPaused(disabled || peer.paused)
peer.logger.Info().Bool("disabled", disabled).Msg("set audio disabled")
modified = true
}
}
// send video signal if modified
if modified {
go func() {
// in goroutine because of mutex and we don't want to block
peer.session.Send(event.SIGNAL_AUDIO, peer.Audio())
}()
}
return nil
}
func (peer *WebRTCPeerCtx) Audio() types.PeerAudio {
peer.mu.Lock()
defer peer.mu.Unlock()
return types.PeerAudio{
Disabled: peer.audioDisabled,
}
}
//
// data channel
//
func (peer *WebRTCPeerCtx) SendCursorPosition(x, y int) error {
peer.mu.Lock()
defer peer.mu.Unlock()
// do not send cursor position to host
if peer.session.IsHost() {
return nil
}
header := payload.Header{
Event: payload.OP_CURSOR_POSITION,
Length: 7,
}
data := payload.CursorPosition{
X: uint16(x),
Y: uint16(y),
}
buffer := &bytes.Buffer{}
if err := binary.Write(buffer, binary.BigEndian, header); err != nil {
return err
}
if err := binary.Write(buffer, binary.BigEndian, data); err != nil {
return err
}
return peer.dataChannel.Send(buffer.Bytes())
}
func (peer *WebRTCPeerCtx) SendCursorImage(cur *types.CursorImage, img []byte) error {
peer.mu.Lock()
defer peer.mu.Unlock()
header := payload.Header{
Event: payload.OP_CURSOR_IMAGE,
Length: uint16(11 + len(img)),
}
data := payload.CursorImage{
Width: cur.Width,
Height: cur.Height,
Xhot: cur.Xhot,
Yhot: cur.Yhot,
}
buffer := &bytes.Buffer{}
if err := binary.Write(buffer, binary.BigEndian, header); err != nil {
return err
}
if err := binary.Write(buffer, binary.BigEndian, data); err != nil {
return err
}
if err := binary.Write(buffer, binary.BigEndian, img); err != nil {
return err
}
return peer.dataChannel.Send(buffer.Bytes())
}

View File

@ -0,0 +1,27 @@
package pionlog
import (
"github.com/pion/logging"
"github.com/rs/zerolog"
)
func New(logger zerolog.Logger) Factory {
return Factory{
Logger: logger.With().Str("submodule", "pion").Logger(),
}
}
type Factory struct {
Logger zerolog.Logger
}
func (l Factory) NewLogger(subsystem string) logging.LeveledLogger {
if subsystem == "sctp" {
return nulllog{}
}
return logger{
subsystem: subsystem,
logger: l.Logger.With().Str("subsystem", subsystem).Logger(),
}
}

View File

@ -0,0 +1,66 @@
package pionlog
import (
"fmt"
"strings"
"github.com/rs/zerolog"
)
type logger struct {
logger zerolog.Logger
subsystem string
}
func (l logger) Trace(msg string) {
l.logger.Trace().Msg(strings.TrimSpace(msg))
}
func (l logger) Tracef(format string, args ...any) {
msg := fmt.Sprintf(format, args...)
l.logger.Trace().Msg(strings.TrimSpace(msg))
}
func (l logger) Debug(msg string) {
l.logger.Debug().Msg(strings.TrimSpace(msg))
}
func (l logger) Debugf(format string, args ...any) {
msg := fmt.Sprintf(format, args...)
l.logger.Debug().Msg(strings.TrimSpace(msg))
}
func (l logger) Info(msg string) {
if strings.Contains(msg, "duplicated packet") {
return
}
l.logger.Info().Msg(strings.TrimSpace(msg))
}
func (l logger) Infof(format string, args ...any) {
msg := fmt.Sprintf(format, args...)
if strings.Contains(msg, "duplicated packet") {
return
}
l.logger.Info().Msg(strings.TrimSpace(msg))
}
func (l logger) Warn(msg string) {
l.logger.Warn().Msg(strings.TrimSpace(msg))
}
func (l logger) Warnf(format string, args ...any) {
msg := fmt.Sprintf(format, args...)
l.logger.Warn().Msg(strings.TrimSpace(msg))
}
func (l logger) Error(msg string) {
l.logger.Error().Msg(strings.TrimSpace(msg))
}
func (l logger) Errorf(format string, args ...any) {
msg := fmt.Sprintf(format, args...)
l.logger.Error().Msg(strings.TrimSpace(msg))
}

View File

@ -0,0 +1,14 @@
package pionlog
type nulllog struct{}
func (l nulllog) Trace(msg string) {}
func (l nulllog) Tracef(format string, args ...any) {}
func (l nulllog) Debug(msg string) {}
func (l nulllog) Debugf(format string, args ...any) {}
func (l nulllog) Info(msg string) {}
func (l nulllog) Infof(format string, args ...any) {}
func (l nulllog) Warn(msg string) {}
func (l nulllog) Warnf(format string, args ...any) {}
func (l nulllog) Error(msg string) {}
func (l nulllog) Errorf(format string, args ...any) {}

View File

@ -0,0 +1,203 @@
package webrtc
import (
"errors"
"io"
"sync"
"github.com/pion/rtcp"
"github.com/pion/webrtc/v3"
"github.com/pion/webrtc/v3/pkg/media"
"github.com/rs/zerolog"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
)
type Track struct {
logger zerolog.Logger
track *webrtc.TrackLocalStaticSample
rtcpCh chan []rtcp.Packet
sample chan types.Sample
paused bool
stream types.StreamSinkManager
streamMu sync.Mutex
}
type trackOption func(*Track)
func WithRtcpChan(rtcp chan []rtcp.Packet) trackOption {
return func(t *Track) {
t.rtcpCh = rtcp
}
}
func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.PeerConnection, opts ...trackOption) (*Track, error) {
id := codec.Type.String()
track, err := webrtc.NewTrackLocalStaticSample(codec.Capability, id, "stream")
if err != nil {
return nil, err
}
t := &Track{
logger: logger.With().Str("id", id).Logger(),
track: track,
rtcpCh: nil,
sample: make(chan types.Sample),
}
for _, opt := range opts {
opt(t)
}
sender, err := connection.AddTrack(t.track)
if err != nil {
return nil, err
}
go t.rtcpReader(sender)
go t.sampleReader()
return t, nil
}
func (t *Track) Shutdown() {
t.RemoveStream()
close(t.sample)
}
func (t *Track) rtcpReader(sender *webrtc.RTPSender) {
for {
packets, _, err := sender.ReadRTCP()
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
t.logger.Debug().Msg("track rtcp reader closed")
return
}
t.logger.Warn().Err(err).Msg("failed to read track rtcp")
continue
}
if t.rtcpCh != nil {
t.rtcpCh <- packets
}
}
}
// --- sample ---
func (t *Track) sampleReader() {
for {
sample, ok := <-t.sample
if !ok {
t.logger.Debug().Msg("track sample reader closed")
return
}
err := t.track.WriteSample(media.Sample{
Data: sample.Data,
Duration: sample.Duration,
Timestamp: sample.Timestamp,
})
if err != nil && !errors.Is(err, io.ErrClosedPipe) {
t.logger.Warn().Err(err).Msg("failed to write sample to track")
}
}
}
func (t *Track) WriteSample(sample types.Sample) {
t.sample <- sample
}
// --- stream ---
func (t *Track) SetStream(stream types.StreamSinkManager) (bool, error) {
t.streamMu.Lock()
defer t.streamMu.Unlock()
// if we already listen to the stream, do nothing
if t.stream == stream {
return false, nil
}
// if paused, we switch the stream but don't add the listener
if t.paused {
t.stream = stream
return true, nil
}
var err error
if t.stream != nil {
err = t.stream.MoveListenerTo(t, stream)
} else {
err = stream.AddListener(t)
}
if err != nil {
return false, err
}
t.stream = stream
return true, nil
}
func (t *Track) RemoveStream() {
t.streamMu.Lock()
defer t.streamMu.Unlock()
// if there is no stream, or paused we don't need to remove the listener
if t.stream == nil || t.paused {
t.stream = nil
return
}
err := t.stream.RemoveListener(t)
if err != nil {
t.logger.Warn().Err(err).Msg("failed to remove listener from stream")
}
t.stream = nil
}
func (t *Track) Stream() (types.StreamSinkManager, bool) {
t.streamMu.Lock()
defer t.streamMu.Unlock()
return t.stream, t.stream != nil
}
// --- paused ---
func (t *Track) SetPaused(paused bool) {
t.streamMu.Lock()
defer t.streamMu.Unlock()
// if there is no state change or no stream, do nothing
if t.paused == paused || t.stream == nil {
t.paused = paused
return
}
var err error
if paused {
err = t.stream.RemoveListener(t)
} else {
err = t.stream.AddListener(t)
}
if err != nil {
t.logger.Warn().Err(err).Msg("failed to change listener state")
return
}
t.paused = paused
}
func (t *Track) Paused() bool {
t.streamMu.Lock()
defer t.streamMu.Unlock()
return t.paused
}