mirror of
https://github.com/m1k1o/neko.git
synced 2024-07-24 14:40:50 +12:00
move server to server directory.
This commit is contained in:
168
server/internal/webrtc/cursor/image.go
Normal file
168
server/internal/webrtc/cursor/image.go
Normal file
@ -0,0 +1,168 @@
|
||||
package cursor
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type ImageListener interface {
|
||||
SendCursorImage(cur *types.CursorImage, img []byte) error
|
||||
}
|
||||
|
||||
type Image interface {
|
||||
Start()
|
||||
Shutdown()
|
||||
GetCurrent() (cur *types.CursorImage, img []byte, err error)
|
||||
AddListener(listener ImageListener)
|
||||
RemoveListener(listener ImageListener)
|
||||
}
|
||||
|
||||
type imageEntry struct {
|
||||
*types.CursorImage
|
||||
ImagePNG []byte
|
||||
}
|
||||
|
||||
type image struct {
|
||||
logger zerolog.Logger
|
||||
desktop types.DesktopManager
|
||||
|
||||
listeners map[uintptr]ImageListener
|
||||
listenersMu sync.RWMutex
|
||||
|
||||
cache map[uint64]*imageEntry
|
||||
cacheMu sync.RWMutex
|
||||
current *imageEntry
|
||||
maxSerial uint64
|
||||
}
|
||||
|
||||
func NewImage(logger zerolog.Logger, desktop types.DesktopManager) *image {
|
||||
return &image{
|
||||
logger: logger.With().Str("submodule", "cursor-image").Logger(),
|
||||
desktop: desktop,
|
||||
listeners: map[uintptr]ImageListener{},
|
||||
cache: map[uint64]*imageEntry{},
|
||||
maxSerial: 300, // TODO: Cleanup?
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *image) Start() {
|
||||
manager.desktop.OnCursorChanged(func(serial uint64) {
|
||||
entry, err := manager.getCached(serial)
|
||||
if err != nil {
|
||||
manager.logger.Err(err).Msg("failed to get cursor image")
|
||||
return
|
||||
}
|
||||
|
||||
manager.current = entry
|
||||
|
||||
manager.listenersMu.RLock()
|
||||
for _, l := range manager.listeners {
|
||||
if err := l.SendCursorImage(entry.CursorImage, entry.ImagePNG); err != nil {
|
||||
manager.logger.Err(err).Msg("failed to set cursor image")
|
||||
}
|
||||
}
|
||||
manager.listenersMu.RUnlock()
|
||||
})
|
||||
|
||||
manager.logger.Info().Msg("starting")
|
||||
}
|
||||
|
||||
func (manager *image) Shutdown() {
|
||||
manager.logger.Info().Msg("shutdown")
|
||||
|
||||
manager.listenersMu.Lock()
|
||||
for key := range manager.listeners {
|
||||
delete(manager.listeners, key)
|
||||
}
|
||||
manager.listenersMu.Unlock()
|
||||
}
|
||||
|
||||
func (manager *image) getCached(serial uint64) (*imageEntry, error) {
|
||||
// zero means no serial available
|
||||
if serial == 0 || serial > manager.maxSerial {
|
||||
manager.logger.Debug().Uint64("serial", serial).Msg("cache bypass")
|
||||
return manager.fetchEntry()
|
||||
}
|
||||
|
||||
manager.cacheMu.RLock()
|
||||
entry, ok := manager.cache[serial]
|
||||
manager.cacheMu.RUnlock()
|
||||
|
||||
if ok {
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
manager.logger.Debug().Uint64("serial", serial).Msg("cache miss")
|
||||
|
||||
entry, err := manager.fetchEntry()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
manager.cacheMu.Lock()
|
||||
manager.cache[entry.Serial] = entry
|
||||
manager.cacheMu.Unlock()
|
||||
|
||||
if entry.Serial != serial {
|
||||
manager.logger.Warn().
|
||||
Uint64("expected-serial", serial).
|
||||
Uint64("received-serial", entry.Serial).
|
||||
Msg("serial mismatch")
|
||||
}
|
||||
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func (manager *image) GetCurrent() (cur *types.CursorImage, img []byte, err error) {
|
||||
if manager.current != nil {
|
||||
return manager.current.CursorImage, manager.current.ImagePNG, nil
|
||||
}
|
||||
|
||||
entry, err := manager.fetchEntry()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
manager.current = entry
|
||||
return entry.CursorImage, entry.ImagePNG, nil
|
||||
}
|
||||
|
||||
func (manager *image) AddListener(listener ImageListener) {
|
||||
manager.listenersMu.Lock()
|
||||
defer manager.listenersMu.Unlock()
|
||||
|
||||
if listener != nil {
|
||||
ptr := reflect.ValueOf(listener).Pointer()
|
||||
manager.listeners[ptr] = listener
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *image) RemoveListener(listener ImageListener) {
|
||||
manager.listenersMu.Lock()
|
||||
defer manager.listenersMu.Unlock()
|
||||
|
||||
if listener != nil {
|
||||
ptr := reflect.ValueOf(listener).Pointer()
|
||||
delete(manager.listeners, ptr)
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *image) fetchEntry() (*imageEntry, error) {
|
||||
cur := manager.desktop.GetCursorImage()
|
||||
|
||||
img, err := utils.CreatePNGImage(cur.Image)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cur.Image = nil // free memory
|
||||
|
||||
return &imageEntry{
|
||||
CursorImage: cur,
|
||||
ImagePNG: img,
|
||||
}, nil
|
||||
}
|
74
server/internal/webrtc/cursor/position.go
Normal file
74
server/internal/webrtc/cursor/position.go
Normal file
@ -0,0 +1,74 @@
|
||||
package cursor
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type PositionListener interface {
|
||||
SendCursorPosition(x, y int) error
|
||||
}
|
||||
|
||||
type Position interface {
|
||||
Shutdown()
|
||||
Set(x, y int)
|
||||
AddListener(listener PositionListener)
|
||||
RemoveListener(listener PositionListener)
|
||||
}
|
||||
|
||||
type position struct {
|
||||
logger zerolog.Logger
|
||||
|
||||
listeners map[uintptr]PositionListener
|
||||
listenersMu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewPosition(logger zerolog.Logger) *position {
|
||||
return &position{
|
||||
logger: logger.With().Str("submodule", "cursor-position").Logger(),
|
||||
listeners: map[uintptr]PositionListener{},
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *position) Shutdown() {
|
||||
manager.logger.Info().Msg("shutdown")
|
||||
|
||||
manager.listenersMu.Lock()
|
||||
for key := range manager.listeners {
|
||||
delete(manager.listeners, key)
|
||||
}
|
||||
manager.listenersMu.Unlock()
|
||||
}
|
||||
|
||||
func (manager *position) Set(x, y int) {
|
||||
manager.listenersMu.RLock()
|
||||
defer manager.listenersMu.RUnlock()
|
||||
|
||||
for _, l := range manager.listeners {
|
||||
if err := l.SendCursorPosition(x, y); err != nil {
|
||||
manager.logger.Err(err).Msg("failed to set cursor position")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *position) AddListener(listener PositionListener) {
|
||||
manager.listenersMu.Lock()
|
||||
defer manager.listenersMu.Unlock()
|
||||
|
||||
if listener != nil {
|
||||
ptr := reflect.ValueOf(listener).Pointer()
|
||||
manager.listeners[ptr] = listener
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *position) RemoveListener(listener PositionListener) {
|
||||
manager.listenersMu.Lock()
|
||||
defer manager.listenersMu.Unlock()
|
||||
|
||||
if listener != nil {
|
||||
ptr := reflect.ValueOf(listener).Pointer()
|
||||
delete(manager.listeners, ptr)
|
||||
}
|
||||
}
|
205
server/internal/webrtc/handler.go
Normal file
205
server/internal/webrtc/handler.go
Normal file
@ -0,0 +1,205 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/demodesk/neko/internal/webrtc/payload"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func (manager *WebRTCManagerCtx) handle(
|
||||
logger zerolog.Logger, data []byte,
|
||||
dataChannel *webrtc.DataChannel,
|
||||
session types.Session,
|
||||
) error {
|
||||
isHost := session.IsHost()
|
||||
|
||||
//
|
||||
// parse header
|
||||
//
|
||||
|
||||
buffer := bytes.NewBuffer(data)
|
||||
|
||||
header := &payload.Header{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//
|
||||
// parse body
|
||||
//
|
||||
|
||||
// handle cursor move event
|
||||
if header.Event == payload.OP_MOVE {
|
||||
payload := &payload.Move{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
x, y := int(payload.X), int(payload.Y)
|
||||
if isHost {
|
||||
// handle active cursor movement
|
||||
manager.desktop.Move(x, y)
|
||||
manager.curPosition.Set(x, y)
|
||||
} else {
|
||||
// handle inactive cursor movement
|
||||
session.SetCursor(types.Cursor{
|
||||
X: x,
|
||||
Y: y,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
} else if header.Event == payload.OP_PING {
|
||||
ping := &payload.Ping{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, ping); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// create pong header
|
||||
header := payload.Header{
|
||||
Event: payload.OP_PONG,
|
||||
Length: 19,
|
||||
}
|
||||
|
||||
// generate server timestamp
|
||||
serverTs := uint64(time.Now().UnixMilli())
|
||||
|
||||
// generate pong payload
|
||||
pong := payload.Pong{
|
||||
Ping: *ping,
|
||||
ServerTs1: uint32(serverTs / math.MaxUint32),
|
||||
ServerTs2: uint32(serverTs % math.MaxUint32),
|
||||
}
|
||||
|
||||
buffer := &bytes.Buffer{}
|
||||
|
||||
if err := binary.Write(buffer, binary.BigEndian, header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(buffer, binary.BigEndian, pong); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return dataChannel.Send(buffer.Bytes())
|
||||
}
|
||||
|
||||
// continue only if session is host
|
||||
if !isHost {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch header.Event {
|
||||
case payload.OP_SCROLL:
|
||||
// TODO: remove this once the client is fixed
|
||||
if header.Length == 4 {
|
||||
payload := &payload.Scroll_Old{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manager.desktop.Scroll(int(payload.X), int(payload.Y), false)
|
||||
logger.Trace().
|
||||
Int16("x", payload.X).
|
||||
Int16("y", payload.Y).
|
||||
Msg("scroll")
|
||||
} else {
|
||||
payload := &payload.Scroll{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manager.desktop.Scroll(int(payload.DeltaX), int(payload.DeltaY), payload.ControlKey)
|
||||
logger.Trace().
|
||||
Int16("deltaX", payload.DeltaX).
|
||||
Int16("deltaY", payload.DeltaY).
|
||||
Bool("controlKey", payload.ControlKey).
|
||||
Msg("scroll")
|
||||
}
|
||||
case payload.OP_KEY_DOWN:
|
||||
payload := &payload.Key{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := manager.desktop.KeyDown(payload.Key); err != nil {
|
||||
logger.Warn().Err(err).Uint32("key", payload.Key).Msg("key down failed")
|
||||
} else {
|
||||
logger.Trace().Uint32("key", payload.Key).Msg("key down")
|
||||
}
|
||||
case payload.OP_KEY_UP:
|
||||
payload := &payload.Key{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := manager.desktop.KeyUp(payload.Key); err != nil {
|
||||
logger.Warn().Err(err).Uint32("key", payload.Key).Msg("key up failed")
|
||||
} else {
|
||||
logger.Trace().Uint32("key", payload.Key).Msg("key up")
|
||||
}
|
||||
case payload.OP_BTN_DOWN:
|
||||
payload := &payload.Key{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := manager.desktop.ButtonDown(payload.Key); err != nil {
|
||||
logger.Warn().Err(err).Uint32("key", payload.Key).Msg("button down failed")
|
||||
} else {
|
||||
logger.Trace().Uint32("key", payload.Key).Msg("button down")
|
||||
}
|
||||
case payload.OP_BTN_UP:
|
||||
payload := &payload.Key{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := manager.desktop.ButtonUp(payload.Key); err != nil {
|
||||
logger.Warn().Err(err).Uint32("key", payload.Key).Msg("button up failed")
|
||||
} else {
|
||||
logger.Trace().Uint32("key", payload.Key).Msg("button up")
|
||||
}
|
||||
case payload.OP_TOUCH_BEGIN:
|
||||
payload := &payload.Touch{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := manager.desktop.TouchBegin(payload.TouchId, int(payload.X), int(payload.Y), payload.Pressure); err != nil {
|
||||
logger.Warn().Err(err).Uint32("touchId", payload.TouchId).Msg("touch begin failed")
|
||||
} else {
|
||||
logger.Trace().Uint32("touchId", payload.TouchId).Msg("touch begin")
|
||||
}
|
||||
case payload.OP_TOUCH_UPDATE:
|
||||
payload := &payload.Touch{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := manager.desktop.TouchUpdate(payload.TouchId, int(payload.X), int(payload.Y), payload.Pressure); err != nil {
|
||||
logger.Warn().Err(err).Uint32("touchId", payload.TouchId).Msg("touch update failed")
|
||||
} else {
|
||||
logger.Trace().Uint32("touchId", payload.TouchId).Msg("touch update")
|
||||
}
|
||||
case payload.OP_TOUCH_END:
|
||||
payload := &payload.Touch{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := manager.desktop.TouchEnd(payload.TouchId, int(payload.X), int(payload.Y), payload.Pressure); err != nil {
|
||||
logger.Warn().Err(err).Uint32("touchId", payload.TouchId).Msg("touch end failed")
|
||||
} else {
|
||||
logger.Trace().Uint32("touchId", payload.TouchId).Msg("touch end")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
576
server/internal/webrtc/manager.go
Normal file
576
server/internal/webrtc/manager.go
Normal file
@ -0,0 +1,576 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/pion/ice/v2"
|
||||
"github.com/pion/interceptor"
|
||||
"github.com/pion/interceptor/pkg/cc"
|
||||
"github.com/pion/interceptor/pkg/gcc"
|
||||
"github.com/pion/rtcp"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/internal/config"
|
||||
"github.com/demodesk/neko/internal/webrtc/cursor"
|
||||
"github.com/demodesk/neko/internal/webrtc/pionlog"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/codec"
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
"github.com/demodesk/neko/pkg/types/message"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
// size of receiving channel used to buffer incoming TCP packets
|
||||
tcpReadChanBufferSize = 50
|
||||
|
||||
// size of buffer used to buffer outgoing TCP packets. Default is 4MB
|
||||
tcpWriteBufferSizeInBytes = 4 * 1024 * 1024
|
||||
|
||||
// the duration without network activity before a Agent is considered disconnected. Default is 5 Seconds
|
||||
disconnectedTimeout = 4 * time.Second
|
||||
|
||||
// the duration without network activity before a Agent is considered failed after disconnected. Default is 25 Seconds
|
||||
failedTimeout = 6 * time.Second
|
||||
|
||||
// how often the ICE Agent sends extra traffic if there is no activity, if media is flowing no traffic will be sent. Default is 2 seconds
|
||||
keepAliveInterval = 2 * time.Second
|
||||
|
||||
// send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval
|
||||
rtcpPLIInterval = 3 * time.Second
|
||||
)
|
||||
|
||||
func New(desktop types.DesktopManager, capture types.CaptureManager, config *config.WebRTC) *WebRTCManagerCtx {
|
||||
logger := log.With().Str("module", "webrtc").Logger()
|
||||
|
||||
configuration := webrtc.Configuration{
|
||||
SDPSemantics: webrtc.SDPSemanticsUnifiedPlan,
|
||||
}
|
||||
|
||||
if !config.ICELite {
|
||||
ICEServers := []webrtc.ICEServer{}
|
||||
for _, server := range config.ICEServersBackend {
|
||||
var credential any
|
||||
if server.Credential != "" {
|
||||
credential = server.Credential
|
||||
} else {
|
||||
credential = false
|
||||
}
|
||||
|
||||
ICEServers = append(ICEServers, webrtc.ICEServer{
|
||||
URLs: server.URLs,
|
||||
Username: server.Username,
|
||||
Credential: credential,
|
||||
})
|
||||
}
|
||||
|
||||
configuration.ICEServers = ICEServers
|
||||
}
|
||||
|
||||
return &WebRTCManagerCtx{
|
||||
logger: logger,
|
||||
config: config,
|
||||
metrics: newMetricsManager(),
|
||||
|
||||
webrtcConfiguration: configuration,
|
||||
|
||||
desktop: desktop,
|
||||
capture: capture,
|
||||
curImage: cursor.NewImage(logger, desktop),
|
||||
curPosition: cursor.NewPosition(logger),
|
||||
}
|
||||
}
|
||||
|
||||
type WebRTCManagerCtx struct {
|
||||
logger zerolog.Logger
|
||||
config *config.WebRTC
|
||||
metrics *metricsManager
|
||||
peerId int32
|
||||
|
||||
desktop types.DesktopManager
|
||||
capture types.CaptureManager
|
||||
curImage cursor.Image
|
||||
curPosition cursor.Position
|
||||
|
||||
webrtcConfiguration webrtc.Configuration
|
||||
|
||||
tcpMux ice.TCPMux
|
||||
udpMux ice.UDPMux
|
||||
|
||||
camStop, micStop *func()
|
||||
}
|
||||
|
||||
func (manager *WebRTCManagerCtx) Start() {
|
||||
manager.curImage.Start()
|
||||
|
||||
logger := pionlog.New(manager.logger)
|
||||
|
||||
// add TCP Mux listener
|
||||
if manager.config.TCPMux > 0 {
|
||||
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||
IP: net.IP{0, 0, 0, 0},
|
||||
Port: manager.config.TCPMux,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
manager.logger.Fatal().Err(err).Msg("unable to setup ice TCP mux")
|
||||
}
|
||||
|
||||
manager.tcpMux = ice.NewTCPMuxDefault(ice.TCPMuxParams{
|
||||
Listener: tcpListener,
|
||||
Logger: logger.NewLogger("ice-tcp"),
|
||||
ReadBufferSize: tcpReadChanBufferSize,
|
||||
WriteBufferSize: tcpWriteBufferSizeInBytes,
|
||||
})
|
||||
}
|
||||
|
||||
// add UDP Mux listener
|
||||
if manager.config.UDPMux > 0 {
|
||||
var err error
|
||||
manager.udpMux, err = ice.NewMultiUDPMuxFromPort(manager.config.UDPMux,
|
||||
ice.UDPMuxFromPortWithLogger(logger.NewLogger("ice-udp")),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
manager.logger.Fatal().Err(err).Msg("unable to setup ice UDP mux")
|
||||
}
|
||||
}
|
||||
|
||||
manager.logger.Info().
|
||||
Bool("icelite", manager.config.ICELite).
|
||||
Bool("icetrickle", manager.config.ICETrickle).
|
||||
Interface("iceservers-frontend", manager.config.ICEServersFrontend).
|
||||
Interface("iceservers-backend", manager.config.ICEServersBackend).
|
||||
Str("nat1to1", strings.Join(manager.config.NAT1To1IPs, ",")).
|
||||
Str("epr", fmt.Sprintf("%d-%d", manager.config.EphemeralMin, manager.config.EphemeralMax)).
|
||||
Int("tcpmux", manager.config.TCPMux).
|
||||
Int("udpmux", manager.config.UDPMux).
|
||||
Msg("webrtc starting")
|
||||
}
|
||||
|
||||
func (manager *WebRTCManagerCtx) Shutdown() error {
|
||||
manager.logger.Info().Msg("shutdown")
|
||||
|
||||
manager.curImage.Shutdown()
|
||||
manager.curPosition.Shutdown()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *WebRTCManagerCtx) ICEServers() []types.ICEServer {
|
||||
return manager.config.ICEServersFrontend
|
||||
}
|
||||
|
||||
func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs []codec.RTPCodec) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) {
|
||||
// create media engine
|
||||
engine := &webrtc.MediaEngine{}
|
||||
for _, codec := range codecs {
|
||||
if err := codec.Register(engine); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// create setting engine
|
||||
settings := webrtc.SettingEngine{
|
||||
LoggerFactory: pionlog.New(logger),
|
||||
}
|
||||
|
||||
settings.DisableMediaEngineCopy(true)
|
||||
settings.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
|
||||
settings.SetNAT1To1IPs(manager.config.NAT1To1IPs, webrtc.ICECandidateTypeHost)
|
||||
settings.SetLite(manager.config.ICELite)
|
||||
// make sure server answer sdp setup as passive, to not force DTLS renegotiation
|
||||
// otherwise iOS renegotiation fails with: Failed to set SSL role for the transport.
|
||||
settings.SetAnsweringDTLSRole(webrtc.DTLSRoleServer)
|
||||
|
||||
var networkType []webrtc.NetworkType
|
||||
|
||||
// udp candidates
|
||||
if manager.udpMux != nil {
|
||||
settings.SetICEUDPMux(manager.udpMux)
|
||||
networkType = append(networkType,
|
||||
webrtc.NetworkTypeUDP4,
|
||||
webrtc.NetworkTypeUDP6,
|
||||
)
|
||||
} else if manager.config.EphemeralMax != 0 {
|
||||
_ = settings.SetEphemeralUDPPortRange(manager.config.EphemeralMin, manager.config.EphemeralMax)
|
||||
networkType = append(networkType,
|
||||
webrtc.NetworkTypeUDP4,
|
||||
webrtc.NetworkTypeUDP6,
|
||||
)
|
||||
}
|
||||
|
||||
// tcp candidates
|
||||
if manager.tcpMux != nil {
|
||||
settings.SetICETCPMux(manager.tcpMux)
|
||||
networkType = append(networkType,
|
||||
webrtc.NetworkTypeTCP4,
|
||||
webrtc.NetworkTypeTCP6,
|
||||
)
|
||||
}
|
||||
|
||||
// enable support for TCP and UDP ICE candidates
|
||||
settings.SetNetworkTypes(networkType)
|
||||
|
||||
// create interceptor registry
|
||||
registry := &interceptor.Registry{}
|
||||
|
||||
// create bandwidth estimator
|
||||
estimatorChan := make(chan cc.BandwidthEstimator, 1)
|
||||
if manager.config.Estimator.Enabled {
|
||||
congestionController, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) {
|
||||
return gcc.NewSendSideBWE(
|
||||
gcc.SendSideBWEInitialBitrate(manager.config.Estimator.InitialBitrate),
|
||||
gcc.SendSideBWEPacer(gcc.NewNoOpPacer()),
|
||||
)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
congestionController.OnNewPeerConnection(func(id string, estimator cc.BandwidthEstimator) {
|
||||
estimatorChan <- estimator
|
||||
})
|
||||
|
||||
registry.Add(congestionController)
|
||||
if err = webrtc.ConfigureTWCCHeaderExtensionSender(engine, registry); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
} else {
|
||||
// no estimator, send nil
|
||||
estimatorChan <- nil
|
||||
}
|
||||
|
||||
if err := webrtc.RegisterDefaultInterceptors(engine, registry); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// create new API
|
||||
api := webrtc.NewAPI(
|
||||
webrtc.WithMediaEngine(engine),
|
||||
webrtc.WithSettingEngine(settings),
|
||||
webrtc.WithInterceptorRegistry(registry),
|
||||
)
|
||||
|
||||
// create new peer connection
|
||||
configuration := manager.webrtcConfiguration
|
||||
connection, err := api.NewPeerConnection(configuration)
|
||||
return connection, <-estimatorChan, err
|
||||
}
|
||||
|
||||
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session) (*webrtc.SessionDescription, types.WebRTCPeer, error) {
|
||||
id := atomic.AddInt32(&manager.peerId, 1)
|
||||
|
||||
// get metrics for session
|
||||
metrics := manager.metrics.getBySession(session)
|
||||
metrics.NewConnection()
|
||||
|
||||
// add session id to logger context
|
||||
logger := manager.logger.With().Str("session_id", session.ID()).Int32("peer_id", id).Logger()
|
||||
logger.Info().Msg("creating webrtc peer")
|
||||
|
||||
// all audios must have the same codec
|
||||
audio := manager.capture.Audio()
|
||||
audioCodec := audio.Codec()
|
||||
|
||||
// all videos must have the same codec
|
||||
video := manager.capture.Video()
|
||||
videoCodec := video.Codec()
|
||||
|
||||
connection, estimator, err := manager.newPeerConnection(
|
||||
logger, []codec.RTPCodec{audioCodec, videoCodec})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// asynchronously send local ICE Candidates
|
||||
if manager.config.ICETrickle {
|
||||
connection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
|
||||
if candidate == nil {
|
||||
logger.Debug().Msg("all local ice candidates sent")
|
||||
return
|
||||
}
|
||||
|
||||
session.Send(
|
||||
event.SIGNAL_CANDIDATE,
|
||||
message.SignalCandidate{
|
||||
ICECandidateInit: candidate.ToJSON(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// audio track
|
||||
audioTrack, err := NewTrack(logger, audioCodec, connection)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// we disable audio by default manually
|
||||
audioTrack.SetPaused(true)
|
||||
|
||||
// set stream for audio track
|
||||
_, err = audioTrack.SetStream(audio)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// video track
|
||||
videoRtcp := make(chan []rtcp.Packet, 1)
|
||||
videoTrack, err := NewTrack(logger, videoCodec, connection, WithRtcpChan(videoRtcp))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
//
|
||||
// stream for video track will be set later
|
||||
//
|
||||
|
||||
// data channel
|
||||
|
||||
dataChannel, err := connection.CreateDataChannel("data", nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
peer := &WebRTCPeerCtx{
|
||||
logger: logger,
|
||||
session: session,
|
||||
metrics: metrics,
|
||||
connection: connection,
|
||||
// bandwidth estimator
|
||||
estimator: estimator,
|
||||
estimateTrend: utils.NewTrendDetector(
|
||||
utils.TrendDetectorParams{
|
||||
// Probing
|
||||
//RequiredSamples: 3,
|
||||
//DownwardTrendThreshold: 0.0,
|
||||
//CollapseValues: false,
|
||||
// Non-Probing
|
||||
RequiredSamples: 8,
|
||||
DownwardTrendThreshold: -0.5,
|
||||
CollapseValues: true,
|
||||
}),
|
||||
// stream selectors
|
||||
video: video,
|
||||
audio: audio,
|
||||
// tracks & channels
|
||||
audioTrack: audioTrack,
|
||||
videoTrack: videoTrack,
|
||||
dataChannel: dataChannel,
|
||||
rtcpChannel: videoRtcp,
|
||||
// config
|
||||
iceTrickle: manager.config.ICETrickle,
|
||||
estimatorConfig: manager.config.Estimator,
|
||||
audioDisabled: true, // we disable audio by default manually
|
||||
}
|
||||
|
||||
connection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
|
||||
logger := logger.With().
|
||||
Str("kind", track.Kind().String()).
|
||||
Str("mime", track.Codec().RTPCodecCapability.MimeType).
|
||||
Logger()
|
||||
|
||||
logger.Info().Msgf("received new remote track")
|
||||
|
||||
if !session.Profile().CanShareMedia {
|
||||
err := receiver.Stop()
|
||||
logger.Warn().Err(err).Msg("media sharing is disabled for this session")
|
||||
return
|
||||
}
|
||||
|
||||
// parse codec from remote track
|
||||
codec, ok := codec.ParseRTC(track.Codec())
|
||||
if !ok {
|
||||
err := receiver.Stop()
|
||||
logger.Warn().Err(err).Msg("remote track with unknown codec")
|
||||
return
|
||||
}
|
||||
|
||||
var srcManager types.StreamSrcManager
|
||||
|
||||
stopped := false
|
||||
stopFn := func() {
|
||||
if stopped {
|
||||
return
|
||||
}
|
||||
|
||||
stopped = true
|
||||
err := receiver.Stop()
|
||||
srcManager.Stop()
|
||||
logger.Err(err).Msg("remote track stopped")
|
||||
}
|
||||
|
||||
if track.Kind() == webrtc.RTPCodecTypeAudio {
|
||||
// audio -> microphone
|
||||
srcManager = manager.capture.Microphone()
|
||||
defer stopFn()
|
||||
|
||||
if manager.micStop != nil {
|
||||
(*manager.micStop)()
|
||||
}
|
||||
manager.micStop = &stopFn
|
||||
} else if track.Kind() == webrtc.RTPCodecTypeVideo {
|
||||
// video -> webcam
|
||||
srcManager = manager.capture.Webcam()
|
||||
defer stopFn()
|
||||
|
||||
if manager.camStop != nil {
|
||||
(*manager.camStop)()
|
||||
}
|
||||
manager.camStop = &stopFn
|
||||
} else {
|
||||
err := receiver.Stop()
|
||||
logger.Warn().Err(err).Msg("remote track with unsupported codec type")
|
||||
return
|
||||
}
|
||||
|
||||
err := srcManager.Start(codec)
|
||||
if err != nil {
|
||||
logger.Err(err).Msg("failed to start pipeline")
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(rtcpPLIInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
err := connection.WriteRTCP([]rtcp.Packet{
|
||||
&rtcp.PictureLossIndication{
|
||||
MediaSSRC: uint32(track.SSRC()),
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logger.Err(err).Msg("remote track rtcp send err")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1400)
|
||||
for {
|
||||
i, _, err := track.Read(buf)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("failed read from remote track")
|
||||
break
|
||||
}
|
||||
|
||||
srcManager.Push(buf[:i])
|
||||
}
|
||||
|
||||
logger.Info().Msg("remote track data finished")
|
||||
})
|
||||
|
||||
connection.OnDataChannel(func(dc *webrtc.DataChannel) {
|
||||
logger.Info().Interface("data_channel", dc).Msg("got remote data channel")
|
||||
})
|
||||
|
||||
var once sync.Once
|
||||
connection.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
|
||||
switch state {
|
||||
case webrtc.PeerConnectionStateConnected:
|
||||
session.SetWebRTCConnected(peer, true)
|
||||
case webrtc.PeerConnectionStateDisconnected,
|
||||
webrtc.PeerConnectionStateFailed:
|
||||
peer.Destroy()
|
||||
case webrtc.PeerConnectionStateClosed:
|
||||
// ensure we only run this once
|
||||
once.Do(func() {
|
||||
session.SetWebRTCConnected(peer, false)
|
||||
//
|
||||
// TODO: Shutdown peer?
|
||||
//
|
||||
audioTrack.Shutdown()
|
||||
videoTrack.Shutdown()
|
||||
close(videoRtcp)
|
||||
})
|
||||
}
|
||||
|
||||
metrics.SetState(state)
|
||||
})
|
||||
|
||||
dataChannel.OnOpen(func() {
|
||||
manager.curImage.AddListener(peer)
|
||||
manager.curPosition.AddListener(peer)
|
||||
|
||||
// send initial cursor image
|
||||
cur, img, err := manager.curImage.GetCurrent()
|
||||
if err == nil {
|
||||
err := peer.SendCursorImage(cur, img)
|
||||
if err != nil {
|
||||
logger.Err(err).Msg("failed to set cursor image")
|
||||
}
|
||||
} else {
|
||||
logger.Err(err).Msg("failed to get cursor image")
|
||||
}
|
||||
|
||||
// send initial cursor position
|
||||
x, y := manager.desktop.GetCursorPosition()
|
||||
err = peer.SendCursorPosition(x, y)
|
||||
if err != nil {
|
||||
logger.Err(err).Msg("failed to set cursor position")
|
||||
}
|
||||
})
|
||||
|
||||
dataChannel.OnClose(func() {
|
||||
manager.curImage.RemoveListener(peer)
|
||||
manager.curPosition.RemoveListener(peer)
|
||||
})
|
||||
|
||||
dataChannel.OnMessage(func(message webrtc.DataChannelMessage) {
|
||||
if err := manager.handle(logger, message.Data, dataChannel, session); err != nil {
|
||||
logger.Err(err).Msg("data handle failed")
|
||||
}
|
||||
})
|
||||
|
||||
session.SetWebRTCPeer(peer)
|
||||
|
||||
offer, err := peer.CreateOffer(false)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// on negotiation needed handler must be registered after creating initial
|
||||
// offer, otherwise it can fire and intercept sucessful negotiation
|
||||
|
||||
connection.OnNegotiationNeeded(func() {
|
||||
logger.Warn().Msg("negotiation is needed")
|
||||
|
||||
if connection.SignalingState() != webrtc.SignalingStateStable {
|
||||
logger.Warn().Msg("connection isn't stable yet; postponing...")
|
||||
return
|
||||
}
|
||||
|
||||
offer, err := peer.CreateOffer(false)
|
||||
if err != nil {
|
||||
logger.Err(err).Msg("sdp offer failed")
|
||||
return
|
||||
}
|
||||
|
||||
session.Send(
|
||||
event.SIGNAL_OFFER,
|
||||
message.SignalDescription{
|
||||
SDP: offer.SDP,
|
||||
})
|
||||
})
|
||||
|
||||
// start metrics collectors
|
||||
go metrics.rtcpReceiver(videoRtcp)
|
||||
go metrics.connectionStats(connection)
|
||||
|
||||
// start estimator reader
|
||||
go peer.estimatorReader()
|
||||
|
||||
return offer, peer, nil
|
||||
}
|
||||
|
||||
func (manager *WebRTCManagerCtx) SetCursorPosition(x, y int) {
|
||||
manager.curPosition.Set(x, y)
|
||||
}
|
458
server/internal/webrtc/metrics.go
Normal file
458
server/internal/webrtc/metrics.go
Normal file
@ -0,0 +1,458 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/pion/rtcp"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
)
|
||||
|
||||
const (
|
||||
// how often to read and process webrtc connection stats
|
||||
connectionStatsInterval = 5 * time.Second
|
||||
)
|
||||
|
||||
type metricsManager struct {
|
||||
mu sync.Mutex
|
||||
|
||||
sessions map[string]*metrics
|
||||
}
|
||||
|
||||
func newMetricsManager() *metricsManager {
|
||||
return &metricsManager{
|
||||
sessions: map[string]*metrics{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *metricsManager) getBySession(session types.Session) *metrics {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
sessionId := session.ID()
|
||||
|
||||
met, ok := m.sessions[sessionId]
|
||||
if ok {
|
||||
return met
|
||||
}
|
||||
|
||||
met = &metrics{
|
||||
sessionId: sessionId,
|
||||
|
||||
connectionState: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "connection_state",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Connection state of session.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
}),
|
||||
connectionStateCount: promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "connection_state_count",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Count of connection state changes for a session.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
}),
|
||||
connectionCount: promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "connection_count",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Connection count of a session.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
}),
|
||||
|
||||
iceCandidates: map[string]struct{}{},
|
||||
iceCandidatesMu: &sync.Mutex{},
|
||||
iceCandidatesUdpCount: promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "ice_candidates_count",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Count of ICE candidates sent by a remote client.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
"protocol": "udp",
|
||||
},
|
||||
}),
|
||||
iceCandidatesTcpCount: promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "ice_candidates_count",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Count of ICE candidates sent by a remote client.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
"protocol": "tcp",
|
||||
},
|
||||
}),
|
||||
|
||||
iceCandidatesUsedUdp: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "ice_candidates_used",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Used ICE candidates that are currently in use.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
"protocol": "udp",
|
||||
},
|
||||
}),
|
||||
iceCandidatesUsedTcp: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "ice_candidates_used",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Used ICE candidates that are currently in use.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
"protocol": "tcp",
|
||||
},
|
||||
}),
|
||||
|
||||
videoIds: map[string]prometheus.Gauge{},
|
||||
videoIdsMu: &sync.Mutex{},
|
||||
|
||||
receiverEstimatedMaximumBitrate: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "receiver_estimated_maximum_bitrate",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Receiver Estimated Maximum Bitrate from RTCP.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
}),
|
||||
receiverEstimatedTargetBitrate: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "receiver_estimated_target_bitrate",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Receiver Estimated Target Bitrate using Google's congestion control.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
}),
|
||||
|
||||
receiverReportDelay: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "receiver_report_delay",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Receiver Report Delay from RTCP, expressed in units of 1/65536 seconds.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
}),
|
||||
receiverReportJitter: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "receiver_report_jitter",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Receiver Report Jitter from RTCP.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
}),
|
||||
receiverReportTotalLost: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "receiver_report_total_lost",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Receiver Report Total Lost from RTCP.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
}),
|
||||
|
||||
transportLayerNacks: promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "transport_layer_nacks",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Transport Layer NACKs from RTCP.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
}),
|
||||
|
||||
iceBytesSent: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "bytes_sent",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Sent bytes to a session.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
"transport": "ice",
|
||||
},
|
||||
}),
|
||||
iceBytesReceived: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "bytes_received",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Received bytes from a session.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
"transport": "ice",
|
||||
},
|
||||
}),
|
||||
|
||||
sctpBytesSent: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "bytes_sent",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Sent bytes to a session.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
"transport": "sctp",
|
||||
},
|
||||
}),
|
||||
sctpBytesReceived: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "bytes_received",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Received bytes from a session.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
"transport": "sctp",
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
m.sessions[sessionId] = met
|
||||
return met
|
||||
}
|
||||
|
||||
type metrics struct {
|
||||
sessionId string
|
||||
|
||||
connectionState prometheus.Gauge
|
||||
connectionStateCount prometheus.Counter
|
||||
connectionCount prometheus.Counter
|
||||
|
||||
iceCandidates map[string]struct{}
|
||||
iceCandidatesMu *sync.Mutex
|
||||
iceCandidatesUdpCount prometheus.Counter
|
||||
iceCandidatesTcpCount prometheus.Counter
|
||||
|
||||
iceCandidatesUsedUdp prometheus.Gauge
|
||||
iceCandidatesUsedTcp prometheus.Gauge
|
||||
|
||||
videoIds map[string]prometheus.Gauge
|
||||
videoIdsMu *sync.Mutex
|
||||
|
||||
receiverEstimatedMaximumBitrate prometheus.Gauge
|
||||
receiverEstimatedTargetBitrate prometheus.Gauge
|
||||
|
||||
receiverReportDelay prometheus.Gauge
|
||||
receiverReportJitter prometheus.Gauge
|
||||
receiverReportTotalLost prometheus.Gauge
|
||||
|
||||
transportLayerNacks prometheus.Counter
|
||||
|
||||
iceBytesSent prometheus.Gauge
|
||||
iceBytesReceived prometheus.Gauge
|
||||
sctpBytesSent prometheus.Gauge
|
||||
sctpBytesReceived prometheus.Gauge
|
||||
}
|
||||
|
||||
func (met *metrics) reset() {
|
||||
met.videoIdsMu.Lock()
|
||||
for _, entry := range met.videoIds {
|
||||
entry.Set(0)
|
||||
}
|
||||
met.videoIdsMu.Unlock()
|
||||
|
||||
met.iceCandidatesUsedUdp.Set(float64(0))
|
||||
met.iceCandidatesUsedTcp.Set(float64(0))
|
||||
|
||||
met.receiverEstimatedMaximumBitrate.Set(0)
|
||||
|
||||
met.receiverReportDelay.Set(0)
|
||||
met.receiverReportJitter.Set(0)
|
||||
}
|
||||
|
||||
func (met *metrics) NewConnection() {
|
||||
met.connectionCount.Add(1)
|
||||
}
|
||||
|
||||
func (met *metrics) NewICECandidate(candidate webrtc.ICECandidateStats) {
|
||||
met.iceCandidatesMu.Lock()
|
||||
defer met.iceCandidatesMu.Unlock()
|
||||
|
||||
if _, found := met.iceCandidates[candidate.ID]; found {
|
||||
return
|
||||
}
|
||||
|
||||
met.iceCandidates[candidate.ID] = struct{}{}
|
||||
if candidate.Protocol == "udp" {
|
||||
met.iceCandidatesUdpCount.Add(1)
|
||||
} else if candidate.Protocol == "tcp" {
|
||||
met.iceCandidatesTcpCount.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
func (met *metrics) SetICECandidatesUsed(candidates []webrtc.ICECandidateStats) {
|
||||
udp, tcp := 0, 0
|
||||
for _, candidate := range candidates {
|
||||
if candidate.Protocol == "udp" {
|
||||
udp++
|
||||
} else if candidate.Protocol == "tcp" {
|
||||
tcp++
|
||||
}
|
||||
}
|
||||
|
||||
met.iceCandidatesUsedUdp.Set(float64(udp))
|
||||
met.iceCandidatesUsedTcp.Set(float64(tcp))
|
||||
}
|
||||
|
||||
func (met *metrics) SetState(state webrtc.PeerConnectionState) {
|
||||
switch state {
|
||||
case webrtc.PeerConnectionStateNew:
|
||||
met.connectionState.Set(0)
|
||||
case webrtc.PeerConnectionStateConnecting:
|
||||
met.connectionState.Set(4)
|
||||
case webrtc.PeerConnectionStateConnected:
|
||||
met.connectionState.Set(5)
|
||||
case webrtc.PeerConnectionStateDisconnected:
|
||||
met.connectionState.Set(3)
|
||||
case webrtc.PeerConnectionStateFailed:
|
||||
met.connectionState.Set(2)
|
||||
case webrtc.PeerConnectionStateClosed:
|
||||
met.connectionState.Set(1)
|
||||
met.reset()
|
||||
default:
|
||||
met.connectionState.Set(-1)
|
||||
}
|
||||
|
||||
met.connectionStateCount.Add(1)
|
||||
}
|
||||
|
||||
func (met *metrics) SetVideoID(videoId string) {
|
||||
met.videoIdsMu.Lock()
|
||||
defer met.videoIdsMu.Unlock()
|
||||
|
||||
if _, found := met.videoIds[videoId]; !found {
|
||||
met.videoIds[videoId] = promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "video_listeners",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Listeners for Video pipelines by a session.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": met.sessionId,
|
||||
"video_id": videoId,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
for id, entry := range met.videoIds {
|
||||
if id == videoId {
|
||||
entry.Set(1)
|
||||
} else {
|
||||
entry.Set(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (met *metrics) SetReceiverEstimatedMaximumBitrate(bitrate float32) {
|
||||
met.receiverEstimatedMaximumBitrate.Set(float64(bitrate))
|
||||
}
|
||||
|
||||
func (met *metrics) SetReceiverEstimatedTargetBitrate(bitrate float64) {
|
||||
met.receiverEstimatedTargetBitrate.Set(bitrate)
|
||||
}
|
||||
|
||||
func (met *metrics) SetReceiverReport(report rtcp.ReceptionReport) {
|
||||
met.receiverReportDelay.Set(float64(report.Delay))
|
||||
met.receiverReportJitter.Set(float64(report.Jitter))
|
||||
met.receiverReportTotalLost.Set(float64(report.TotalLost))
|
||||
}
|
||||
|
||||
func (met *metrics) SetIceTransportStats(data webrtc.TransportStats) {
|
||||
met.iceBytesSent.Set(float64(data.BytesSent))
|
||||
met.iceBytesReceived.Set(float64(data.BytesReceived))
|
||||
}
|
||||
|
||||
func (met *metrics) SetSctpTransportStats(data webrtc.TransportStats) {
|
||||
met.sctpBytesSent.Set(float64(data.BytesSent))
|
||||
met.sctpBytesReceived.Set(float64(data.BytesReceived))
|
||||
}
|
||||
|
||||
//
|
||||
// collectors
|
||||
//
|
||||
|
||||
func (met *metrics) rtcpReceiver(rtcpCh chan []rtcp.Packet) {
|
||||
for {
|
||||
packets, ok := <-rtcpCh
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
for _, p := range packets {
|
||||
switch rtcpPacket := p.(type) {
|
||||
case *rtcp.ReceiverEstimatedMaximumBitrate: // TODO: Deprecated.
|
||||
met.SetReceiverEstimatedMaximumBitrate(rtcpPacket.Bitrate)
|
||||
|
||||
case *rtcp.ReceiverReport:
|
||||
l := len(rtcpPacket.Reports)
|
||||
if l > 0 {
|
||||
// use only last report
|
||||
met.SetReceiverReport(rtcpPacket.Reports[l-1])
|
||||
}
|
||||
case *rtcp.TransportLayerNack:
|
||||
for _, pair := range rtcpPacket.Nacks {
|
||||
packetList := pair.PacketList()
|
||||
met.transportLayerNacks.Add(float64(len(packetList)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (met *metrics) connectionStats(connection *webrtc.PeerConnection) {
|
||||
ticker := time.NewTicker(connectionStatsInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if connection.ConnectionState() == webrtc.PeerConnectionStateClosed {
|
||||
break
|
||||
}
|
||||
|
||||
stats := connection.GetStats()
|
||||
|
||||
data, ok := stats["iceTransport"].(webrtc.TransportStats)
|
||||
if ok {
|
||||
met.SetIceTransportStats(data)
|
||||
}
|
||||
|
||||
data, ok = stats["sctpTransport"].(webrtc.TransportStats)
|
||||
if ok {
|
||||
met.SetSctpTransportStats(data)
|
||||
}
|
||||
|
||||
remoteCandidates := map[string]webrtc.ICECandidateStats{}
|
||||
nominatedRemoteCandidates := map[string]struct{}{}
|
||||
for _, entry := range stats {
|
||||
// only remote ice candidate stats
|
||||
candidate, ok := entry.(webrtc.ICECandidateStats)
|
||||
if ok && candidate.Type == webrtc.StatsTypeRemoteCandidate {
|
||||
met.NewICECandidate(candidate)
|
||||
remoteCandidates[candidate.ID] = candidate
|
||||
}
|
||||
|
||||
// only nominated ice candidate pair stats
|
||||
pair, ok := entry.(webrtc.ICECandidatePairStats)
|
||||
if ok && pair.Nominated {
|
||||
nominatedRemoteCandidates[pair.RemoteCandidateID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
iceCandidatesUsed := []webrtc.ICECandidateStats{}
|
||||
for id := range nominatedRemoteCandidates {
|
||||
if candidate, ok := remoteCandidates[id]; ok {
|
||||
iceCandidatesUsed = append(iceCandidatesUsed, candidate)
|
||||
}
|
||||
}
|
||||
|
||||
met.SetICECandidatesUsed(iceCandidatesUsed)
|
||||
}
|
||||
}
|
55
server/internal/webrtc/payload/receive.go
Normal file
55
server/internal/webrtc/payload/receive.go
Normal file
@ -0,0 +1,55 @@
|
||||
package payload
|
||||
|
||||
import "math"
|
||||
|
||||
const (
|
||||
OP_MOVE = 0x01
|
||||
OP_SCROLL = 0x02
|
||||
OP_KEY_DOWN = 0x03
|
||||
OP_KEY_UP = 0x04
|
||||
OP_BTN_DOWN = 0x05
|
||||
OP_BTN_UP = 0x06
|
||||
OP_PING = 0x07
|
||||
// touch events
|
||||
OP_TOUCH_BEGIN = 0x08
|
||||
OP_TOUCH_UPDATE = 0x09
|
||||
OP_TOUCH_END = 0x0a
|
||||
)
|
||||
|
||||
type Move struct {
|
||||
X uint16
|
||||
Y uint16
|
||||
}
|
||||
|
||||
// TODO: remove this once the client is fixed
|
||||
type Scroll_Old struct {
|
||||
X int16
|
||||
Y int16
|
||||
}
|
||||
|
||||
type Scroll struct {
|
||||
DeltaX int16
|
||||
DeltaY int16
|
||||
ControlKey bool
|
||||
}
|
||||
|
||||
type Key struct {
|
||||
Key uint32
|
||||
}
|
||||
|
||||
type Ping struct {
|
||||
// client's timestamp split into two uint32
|
||||
ClientTs1 uint32
|
||||
ClientTs2 uint32
|
||||
}
|
||||
|
||||
func (p Ping) ClientTs() uint64 {
|
||||
return (uint64(p.ClientTs1) * uint64(math.MaxUint32)) + uint64(p.ClientTs2)
|
||||
}
|
||||
|
||||
type Touch struct {
|
||||
TouchId uint32
|
||||
X int32
|
||||
Y int32
|
||||
Pressure uint8
|
||||
}
|
33
server/internal/webrtc/payload/send.go
Normal file
33
server/internal/webrtc/payload/send.go
Normal file
@ -0,0 +1,33 @@
|
||||
package payload
|
||||
|
||||
import "math"
|
||||
|
||||
const (
|
||||
OP_CURSOR_POSITION = 0x01
|
||||
OP_CURSOR_IMAGE = 0x02
|
||||
OP_PONG = 0x03
|
||||
)
|
||||
|
||||
type CursorPosition struct {
|
||||
X uint16
|
||||
Y uint16
|
||||
}
|
||||
|
||||
type CursorImage struct {
|
||||
Width uint16
|
||||
Height uint16
|
||||
Xhot uint16
|
||||
Yhot uint16
|
||||
}
|
||||
|
||||
type Pong struct {
|
||||
Ping
|
||||
|
||||
// server's timestamp split into two uint32
|
||||
ServerTs1 uint32
|
||||
ServerTs2 uint32
|
||||
}
|
||||
|
||||
func (p Pong) ServerTs() uint64 {
|
||||
return (uint64(p.ServerTs1) * uint64(math.MaxUint32)) + uint64(p.ServerTs2)
|
||||
}
|
6
server/internal/webrtc/payload/types.go
Normal file
6
server/internal/webrtc/payload/types.go
Normal file
@ -0,0 +1,6 @@
|
||||
package payload
|
||||
|
||||
type Header struct {
|
||||
Event uint8
|
||||
Length uint16
|
||||
}
|
543
server/internal/webrtc/peer.go
Normal file
543
server/internal/webrtc/peer.go
Normal file
@ -0,0 +1,543 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pion/interceptor/pkg/cc"
|
||||
"github.com/pion/rtcp"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/demodesk/neko/internal/config"
|
||||
"github.com/demodesk/neko/internal/webrtc/payload"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type WebRTCPeerCtx struct {
|
||||
mu sync.Mutex
|
||||
logger zerolog.Logger
|
||||
session types.Session
|
||||
metrics *metrics
|
||||
connection *webrtc.PeerConnection
|
||||
// bandwidth estimator
|
||||
estimator cc.BandwidthEstimator
|
||||
estimateTrend *utils.TrendDetector
|
||||
// stream selectors
|
||||
video types.StreamSelectorManager
|
||||
audio types.StreamSinkManager
|
||||
// tracks & channels
|
||||
audioTrack *Track
|
||||
videoTrack *Track
|
||||
dataChannel *webrtc.DataChannel
|
||||
rtcpChannel chan []rtcp.Packet
|
||||
// config
|
||||
iceTrickle bool
|
||||
estimatorConfig config.WebRTCEstimator
|
||||
paused bool
|
||||
videoAuto bool
|
||||
videoDisabled bool
|
||||
audioDisabled bool
|
||||
}
|
||||
|
||||
//
|
||||
// connection
|
||||
//
|
||||
|
||||
func (peer *WebRTCPeerCtx) CreateOffer(ICERestart bool) (*webrtc.SessionDescription, error) {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
offer, err := peer.connection.CreateOffer(&webrtc.OfferOptions{
|
||||
ICERestart: ICERestart,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return peer.setLocalDescription(offer)
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) CreateAnswer() (*webrtc.SessionDescription, error) {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
answer, err := peer.connection.CreateAnswer(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return peer.setLocalDescription(answer)
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) setLocalDescription(description webrtc.SessionDescription) (*webrtc.SessionDescription, error) {
|
||||
if !peer.iceTrickle {
|
||||
// Create channel that is blocked until ICE Gathering is complete
|
||||
gatherComplete := webrtc.GatheringCompletePromise(peer.connection)
|
||||
|
||||
if err := peer.connection.SetLocalDescription(description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
<-gatherComplete
|
||||
} else {
|
||||
if err := peer.connection.SetLocalDescription(description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return peer.connection.LocalDescription(), nil
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) SetRemoteDescription(desc webrtc.SessionDescription) error {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
return peer.connection.SetRemoteDescription(desc)
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) SetCandidate(candidate webrtc.ICECandidateInit) error {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
return peer.connection.AddICECandidate(candidate)
|
||||
}
|
||||
|
||||
// TODO: Add shutdown function?
|
||||
func (peer *WebRTCPeerCtx) Destroy() {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
var err error
|
||||
|
||||
// if peer connection is not closed, close it
|
||||
if peer.connection.ConnectionState() != webrtc.PeerConnectionStateClosed {
|
||||
err = peer.connection.Close()
|
||||
}
|
||||
|
||||
peer.logger.Err(err).Msg("peer connection destroyed")
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) estimatorReader() {
|
||||
conf := peer.estimatorConfig
|
||||
|
||||
// if estimator is not in debug mode, use a nop logger
|
||||
var debugLogger zerolog.Logger
|
||||
if conf.Debug {
|
||||
debugLogger = peer.logger.With().Str("component", "estimator").Logger().Level(zerolog.DebugLevel)
|
||||
} else {
|
||||
debugLogger = zerolog.Nop()
|
||||
}
|
||||
|
||||
// if estimator is disabled, do nothing
|
||||
if peer.estimator == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// use a ticker to get current client target bitrate
|
||||
ticker := time.NewTicker(conf.ReadInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// since when is the estimate stable/unstable
|
||||
stableSince := time.Now() // we asume stable at start
|
||||
unstableSince := time.Time{}
|
||||
// since when are we neutral but cannot accomodate current bitrate
|
||||
// we migt be stalled or estimator just reached zer (very bad connection)
|
||||
stalledSince := time.Time{}
|
||||
// when was the last upgrade/downgrade
|
||||
lastUpgradeTime := time.Time{}
|
||||
lastDowngradeTime := time.Time{}
|
||||
|
||||
for range ticker.C {
|
||||
targetBitrate := peer.estimator.GetTargetBitrate()
|
||||
peer.metrics.SetReceiverEstimatedTargetBitrate(float64(targetBitrate))
|
||||
|
||||
// if peer connection is closed, stop reading
|
||||
if peer.connection.ConnectionState() == webrtc.PeerConnectionStateClosed {
|
||||
break
|
||||
}
|
||||
|
||||
// if estimation or video is disabled, do nothing
|
||||
if !peer.videoAuto || peer.videoDisabled || peer.paused || conf.Passive {
|
||||
continue
|
||||
}
|
||||
|
||||
// get trend direction to decide if we should upgrade or downgrade
|
||||
peer.estimateTrend.AddValue(int64(targetBitrate))
|
||||
direction := peer.estimateTrend.GetDirection()
|
||||
|
||||
// get current stream bitrate
|
||||
stream, ok := peer.videoTrack.Stream()
|
||||
if !ok {
|
||||
debugLogger.Warn().Msg("looks like we don't have a stream yet, skipping bitrate estimation")
|
||||
continue
|
||||
}
|
||||
|
||||
// if stream bitrate is 0, we need to wait for some time until we get a valid value
|
||||
streamId, streamBitrate := stream.ID(), stream.Bitrate()
|
||||
if streamBitrate == 0 {
|
||||
debugLogger.Warn().Msg("looks like stream bitrate is 0, we need to wait for some time")
|
||||
continue
|
||||
}
|
||||
|
||||
// check whats the difference between target and stream bitrate
|
||||
diff := float64(targetBitrate) / float64(streamBitrate)
|
||||
|
||||
debugLogger.Info().
|
||||
Float64("diff", diff).
|
||||
Int("target_bitrate", targetBitrate).
|
||||
Uint64("stream_bitrate", streamBitrate).
|
||||
Str("direction", direction.String()).
|
||||
Msg("got bitrate from estimator")
|
||||
|
||||
// if we can accomodate current stream or we are not netural anymore,
|
||||
// we are not stalled so we reset the stalled time
|
||||
if direction != utils.TrendDirectionNeutral || diff > 1+conf.DiffThreshold {
|
||||
stalledSince = time.Now()
|
||||
}
|
||||
|
||||
// if we are neutral and stalled for too long, we might be congesting
|
||||
stalled := direction == utils.TrendDirectionNeutral && time.Since(stalledSince) > conf.StalledDuration
|
||||
if stalled {
|
||||
debugLogger.Warn().
|
||||
Time("stalled_since", stalledSince).
|
||||
Msgf("it looks like we are stalled")
|
||||
}
|
||||
|
||||
// if we have an downward trend or are stalled, we might be congesting
|
||||
if direction == utils.TrendDirectionDownward || stalled {
|
||||
// we reset the stable time because we are congesting
|
||||
stableSince = time.Now()
|
||||
|
||||
// if we downgraded recently, we wait for some more time
|
||||
if time.Since(lastDowngradeTime) < conf.DowngradeBackoff {
|
||||
debugLogger.Debug().
|
||||
Time("last_downgrade", lastDowngradeTime).
|
||||
Msgf("downgraded recently, waiting for at least %v", conf.DowngradeBackoff)
|
||||
continue
|
||||
}
|
||||
|
||||
// if we are not unstable but we fluctuate we should wait for some more time
|
||||
if time.Since(unstableSince) < conf.UnstableDuration {
|
||||
debugLogger.Debug().
|
||||
Time("unstable_since", unstableSince).
|
||||
Msgf("we are not unstable long enough, waiting for at least %v", conf.UnstableDuration)
|
||||
continue
|
||||
}
|
||||
|
||||
// if we still have a big difference between target and stream bitrate, we wait for some more time
|
||||
if conf.DiffThreshold >= 0 && diff > 1+conf.DiffThreshold {
|
||||
debugLogger.Debug().
|
||||
Float64("diff", diff).
|
||||
Float64("threshold", conf.DiffThreshold).
|
||||
Msgf("we still have a big difference between target and stream bitrate, " +
|
||||
"therefore we still should be able to accomodate current stream")
|
||||
continue
|
||||
}
|
||||
|
||||
err := peer.SetVideo(types.PeerVideoRequest{
|
||||
Selector: &types.StreamSelector{
|
||||
ID: streamId,
|
||||
Type: types.StreamSelectorTypeLower,
|
||||
},
|
||||
})
|
||||
if err != nil && err != types.ErrWebRTCStreamNotFound {
|
||||
peer.logger.Warn().Err(err).Msg("failed to downgrade video stream")
|
||||
}
|
||||
lastDowngradeTime = time.Now()
|
||||
|
||||
if err == types.ErrWebRTCStreamNotFound {
|
||||
debugLogger.Info().Msg("looks like we are already on the lowest stream")
|
||||
} else {
|
||||
debugLogger.Info().Msg("downgraded video stream")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// we reset the unstable time because we are not congesting
|
||||
unstableSince = time.Now()
|
||||
|
||||
// if we have a neutral or upward trend, that means our estimate is stable
|
||||
// if we are on the highest stream, we don't need to do anything
|
||||
// but if there is a higher stream, we should try to upgrade and see if it works
|
||||
|
||||
// if we upgraded recently, we wait for some more time
|
||||
if time.Since(lastUpgradeTime) < conf.UpgradeBackoff {
|
||||
debugLogger.Debug().
|
||||
Time("last_upgrade", lastUpgradeTime).
|
||||
Msgf("upgraded recently, waiting for at least %v", conf.UpgradeBackoff)
|
||||
continue
|
||||
}
|
||||
|
||||
// if we are not stable for long enough, we wait for some more time
|
||||
// because bandwidth estimation might fluctuate
|
||||
if time.Since(stableSince) < conf.StableDuration {
|
||||
debugLogger.Debug().
|
||||
Time("stable_since", stableSince).
|
||||
Msgf("we are not stable long enough, waiting for at least %v", conf.StableDuration)
|
||||
continue
|
||||
}
|
||||
|
||||
// upgrade only if estimated bitrate passed the threshold
|
||||
if conf.DiffThreshold >= 0 && diff < 1+conf.DiffThreshold {
|
||||
debugLogger.Debug().
|
||||
Float64("diff", diff).
|
||||
Float64("threshold", conf.DiffThreshold).
|
||||
Msgf("looks like we don't have enough bitrate to accomodate higher stream, " +
|
||||
"therefore we should wait for some more time")
|
||||
continue
|
||||
}
|
||||
|
||||
err := peer.SetVideo(types.PeerVideoRequest{
|
||||
Selector: &types.StreamSelector{
|
||||
ID: streamId,
|
||||
Type: types.StreamSelectorTypeHigher,
|
||||
},
|
||||
})
|
||||
if err != nil && err != types.ErrWebRTCStreamNotFound {
|
||||
peer.logger.Warn().Err(err).Msg("failed to upgrade video stream")
|
||||
}
|
||||
lastUpgradeTime = time.Now()
|
||||
|
||||
if err == types.ErrWebRTCStreamNotFound {
|
||||
debugLogger.Info().Msg("looks like we are already on the highest stream")
|
||||
} else {
|
||||
debugLogger.Info().Msg("upgraded video stream")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
peer.videoTrack.SetPaused(isPaused || peer.videoDisabled)
|
||||
peer.audioTrack.SetPaused(isPaused || peer.audioDisabled)
|
||||
|
||||
peer.logger.Info().Bool("is_paused", isPaused).Msg("set paused")
|
||||
peer.paused = isPaused
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) Paused() bool {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
return peer.paused
|
||||
}
|
||||
|
||||
//
|
||||
// video
|
||||
//
|
||||
|
||||
func (peer *WebRTCPeerCtx) SetVideo(r types.PeerVideoRequest) error {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
modified := false
|
||||
|
||||
// video disabled
|
||||
if r.Disabled != nil {
|
||||
disabled := *r.Disabled
|
||||
|
||||
// update only if changed
|
||||
if peer.videoDisabled != disabled {
|
||||
peer.videoDisabled = disabled
|
||||
peer.videoTrack.SetPaused(disabled || peer.paused)
|
||||
|
||||
peer.logger.Info().Bool("disabled", disabled).Msg("set video disabled")
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
// video selector
|
||||
if r.Selector != nil {
|
||||
selector := *r.Selector
|
||||
|
||||
// get requested video stream from selector
|
||||
stream, ok := peer.video.GetStream(selector)
|
||||
if !ok {
|
||||
return types.ErrWebRTCStreamNotFound
|
||||
}
|
||||
|
||||
// set video stream to track
|
||||
changed, err := peer.videoTrack.SetStream(stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update only if stream changed
|
||||
if changed {
|
||||
videoID := stream.ID()
|
||||
peer.metrics.SetVideoID(videoID)
|
||||
|
||||
peer.logger.Info().Str("video_id", videoID).Msg("set video")
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
// video auto
|
||||
if r.Auto != nil {
|
||||
videoAuto := *r.Auto
|
||||
|
||||
if peer.estimator == nil || peer.estimatorConfig.Passive {
|
||||
peer.logger.Warn().Msg("estimator is disabled or in passive mode, cannot change video auto")
|
||||
videoAuto = false // ensure video auto is disabled
|
||||
}
|
||||
|
||||
// update only if video auto changed
|
||||
if peer.videoAuto != videoAuto {
|
||||
peer.videoAuto = videoAuto
|
||||
|
||||
peer.logger.Info().Bool("video_auto", videoAuto).Msg("set video auto")
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
// send video signal if modified
|
||||
if modified {
|
||||
go func() {
|
||||
// in goroutine because of mutex and we don't want to block
|
||||
peer.session.Send(event.SIGNAL_VIDEO, peer.Video())
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) Video() types.PeerVideo {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
// get current video stream ID
|
||||
ID := ""
|
||||
stream, ok := peer.videoTrack.Stream()
|
||||
if ok {
|
||||
ID = stream.ID()
|
||||
}
|
||||
|
||||
return types.PeerVideo{
|
||||
Disabled: peer.videoDisabled,
|
||||
ID: ID,
|
||||
Video: ID, // TODO: Remove, used for backward compatibility
|
||||
Auto: peer.videoAuto,
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// audio
|
||||
//
|
||||
|
||||
func (peer *WebRTCPeerCtx) SetAudio(r types.PeerAudioRequest) error {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
modified := false
|
||||
|
||||
// audio disabled
|
||||
if r.Disabled != nil {
|
||||
disabled := *r.Disabled
|
||||
|
||||
// update only if changed
|
||||
if peer.audioDisabled != disabled {
|
||||
peer.audioDisabled = disabled
|
||||
peer.audioTrack.SetPaused(disabled || peer.paused)
|
||||
|
||||
peer.logger.Info().Bool("disabled", disabled).Msg("set audio disabled")
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
// send video signal if modified
|
||||
if modified {
|
||||
go func() {
|
||||
// in goroutine because of mutex and we don't want to block
|
||||
peer.session.Send(event.SIGNAL_AUDIO, peer.Audio())
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) Audio() types.PeerAudio {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
return types.PeerAudio{
|
||||
Disabled: peer.audioDisabled,
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// data channel
|
||||
//
|
||||
|
||||
func (peer *WebRTCPeerCtx) SendCursorPosition(x, y int) error {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
// do not send cursor position to host
|
||||
if peer.session.IsHost() {
|
||||
return nil
|
||||
}
|
||||
|
||||
header := payload.Header{
|
||||
Event: payload.OP_CURSOR_POSITION,
|
||||
Length: 7,
|
||||
}
|
||||
|
||||
data := payload.CursorPosition{
|
||||
X: uint16(x),
|
||||
Y: uint16(y),
|
||||
}
|
||||
|
||||
buffer := &bytes.Buffer{}
|
||||
|
||||
if err := binary.Write(buffer, binary.BigEndian, header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(buffer, binary.BigEndian, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return peer.dataChannel.Send(buffer.Bytes())
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) SendCursorImage(cur *types.CursorImage, img []byte) error {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
header := payload.Header{
|
||||
Event: payload.OP_CURSOR_IMAGE,
|
||||
Length: uint16(11 + len(img)),
|
||||
}
|
||||
|
||||
data := payload.CursorImage{
|
||||
Width: cur.Width,
|
||||
Height: cur.Height,
|
||||
Xhot: cur.Xhot,
|
||||
Yhot: cur.Yhot,
|
||||
}
|
||||
|
||||
buffer := &bytes.Buffer{}
|
||||
|
||||
if err := binary.Write(buffer, binary.BigEndian, header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(buffer, binary.BigEndian, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(buffer, binary.BigEndian, img); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return peer.dataChannel.Send(buffer.Bytes())
|
||||
}
|
27
server/internal/webrtc/pionlog/factory.go
Normal file
27
server/internal/webrtc/pionlog/factory.go
Normal file
@ -0,0 +1,27 @@
|
||||
package pionlog
|
||||
|
||||
import (
|
||||
"github.com/pion/logging"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func New(logger zerolog.Logger) Factory {
|
||||
return Factory{
|
||||
Logger: logger.With().Str("submodule", "pion").Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
type Factory struct {
|
||||
Logger zerolog.Logger
|
||||
}
|
||||
|
||||
func (l Factory) NewLogger(subsystem string) logging.LeveledLogger {
|
||||
if subsystem == "sctp" {
|
||||
return nulllog{}
|
||||
}
|
||||
|
||||
return logger{
|
||||
subsystem: subsystem,
|
||||
logger: l.Logger.With().Str("subsystem", subsystem).Logger(),
|
||||
}
|
||||
}
|
66
server/internal/webrtc/pionlog/logger.go
Normal file
66
server/internal/webrtc/pionlog/logger.go
Normal file
@ -0,0 +1,66 @@
|
||||
package pionlog
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type logger struct {
|
||||
logger zerolog.Logger
|
||||
subsystem string
|
||||
}
|
||||
|
||||
func (l logger) Trace(msg string) {
|
||||
l.logger.Trace().Msg(strings.TrimSpace(msg))
|
||||
}
|
||||
|
||||
func (l logger) Tracef(format string, args ...any) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
l.logger.Trace().Msg(strings.TrimSpace(msg))
|
||||
}
|
||||
|
||||
func (l logger) Debug(msg string) {
|
||||
l.logger.Debug().Msg(strings.TrimSpace(msg))
|
||||
}
|
||||
|
||||
func (l logger) Debugf(format string, args ...any) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
l.logger.Debug().Msg(strings.TrimSpace(msg))
|
||||
}
|
||||
|
||||
func (l logger) Info(msg string) {
|
||||
if strings.Contains(msg, "duplicated packet") {
|
||||
return
|
||||
}
|
||||
|
||||
l.logger.Info().Msg(strings.TrimSpace(msg))
|
||||
}
|
||||
|
||||
func (l logger) Infof(format string, args ...any) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
if strings.Contains(msg, "duplicated packet") {
|
||||
return
|
||||
}
|
||||
|
||||
l.logger.Info().Msg(strings.TrimSpace(msg))
|
||||
}
|
||||
|
||||
func (l logger) Warn(msg string) {
|
||||
l.logger.Warn().Msg(strings.TrimSpace(msg))
|
||||
}
|
||||
|
||||
func (l logger) Warnf(format string, args ...any) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
l.logger.Warn().Msg(strings.TrimSpace(msg))
|
||||
}
|
||||
|
||||
func (l logger) Error(msg string) {
|
||||
l.logger.Error().Msg(strings.TrimSpace(msg))
|
||||
}
|
||||
|
||||
func (l logger) Errorf(format string, args ...any) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
l.logger.Error().Msg(strings.TrimSpace(msg))
|
||||
}
|
14
server/internal/webrtc/pionlog/nullog.go
Normal file
14
server/internal/webrtc/pionlog/nullog.go
Normal file
@ -0,0 +1,14 @@
|
||||
package pionlog
|
||||
|
||||
type nulllog struct{}
|
||||
|
||||
func (l nulllog) Trace(msg string) {}
|
||||
func (l nulllog) Tracef(format string, args ...any) {}
|
||||
func (l nulllog) Debug(msg string) {}
|
||||
func (l nulllog) Debugf(format string, args ...any) {}
|
||||
func (l nulllog) Info(msg string) {}
|
||||
func (l nulllog) Infof(format string, args ...any) {}
|
||||
func (l nulllog) Warn(msg string) {}
|
||||
func (l nulllog) Warnf(format string, args ...any) {}
|
||||
func (l nulllog) Error(msg string) {}
|
||||
func (l nulllog) Errorf(format string, args ...any) {}
|
203
server/internal/webrtc/track.go
Normal file
203
server/internal/webrtc/track.go
Normal file
@ -0,0 +1,203 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/rtcp"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/pion/webrtc/v3/pkg/media"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/codec"
|
||||
)
|
||||
|
||||
type Track struct {
|
||||
logger zerolog.Logger
|
||||
track *webrtc.TrackLocalStaticSample
|
||||
|
||||
rtcpCh chan []rtcp.Packet
|
||||
sample chan types.Sample
|
||||
|
||||
paused bool
|
||||
stream types.StreamSinkManager
|
||||
streamMu sync.Mutex
|
||||
}
|
||||
|
||||
type trackOption func(*Track)
|
||||
|
||||
func WithRtcpChan(rtcp chan []rtcp.Packet) trackOption {
|
||||
return func(t *Track) {
|
||||
t.rtcpCh = rtcp
|
||||
}
|
||||
}
|
||||
|
||||
func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.PeerConnection, opts ...trackOption) (*Track, error) {
|
||||
id := codec.Type.String()
|
||||
track, err := webrtc.NewTrackLocalStaticSample(codec.Capability, id, "stream")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := &Track{
|
||||
logger: logger.With().Str("id", id).Logger(),
|
||||
track: track,
|
||||
rtcpCh: nil,
|
||||
sample: make(chan types.Sample),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(t)
|
||||
}
|
||||
|
||||
sender, err := connection.AddTrack(t.track)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go t.rtcpReader(sender)
|
||||
go t.sampleReader()
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *Track) Shutdown() {
|
||||
t.RemoveStream()
|
||||
close(t.sample)
|
||||
}
|
||||
|
||||
func (t *Track) rtcpReader(sender *webrtc.RTPSender) {
|
||||
for {
|
||||
packets, _, err := sender.ReadRTCP()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
|
||||
t.logger.Debug().Msg("track rtcp reader closed")
|
||||
return
|
||||
}
|
||||
|
||||
t.logger.Warn().Err(err).Msg("failed to read track rtcp")
|
||||
continue
|
||||
}
|
||||
|
||||
if t.rtcpCh != nil {
|
||||
t.rtcpCh <- packets
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- sample ---
|
||||
|
||||
func (t *Track) sampleReader() {
|
||||
for {
|
||||
sample, ok := <-t.sample
|
||||
if !ok {
|
||||
t.logger.Debug().Msg("track sample reader closed")
|
||||
return
|
||||
}
|
||||
|
||||
err := t.track.WriteSample(media.Sample{
|
||||
Data: sample.Data,
|
||||
Duration: sample.Duration,
|
||||
Timestamp: sample.Timestamp,
|
||||
})
|
||||
|
||||
if err != nil && !errors.Is(err, io.ErrClosedPipe) {
|
||||
t.logger.Warn().Err(err).Msg("failed to write sample to track")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Track) WriteSample(sample types.Sample) {
|
||||
t.sample <- sample
|
||||
}
|
||||
|
||||
// --- stream ---
|
||||
|
||||
func (t *Track) SetStream(stream types.StreamSinkManager) (bool, error) {
|
||||
t.streamMu.Lock()
|
||||
defer t.streamMu.Unlock()
|
||||
|
||||
// if we already listen to the stream, do nothing
|
||||
if t.stream == stream {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// if paused, we switch the stream but don't add the listener
|
||||
if t.paused {
|
||||
t.stream = stream
|
||||
return true, nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if t.stream != nil {
|
||||
err = t.stream.MoveListenerTo(t, stream)
|
||||
} else {
|
||||
err = stream.AddListener(t)
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
t.stream = stream
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (t *Track) RemoveStream() {
|
||||
t.streamMu.Lock()
|
||||
defer t.streamMu.Unlock()
|
||||
|
||||
// if there is no stream, or paused we don't need to remove the listener
|
||||
if t.stream == nil || t.paused {
|
||||
t.stream = nil
|
||||
return
|
||||
}
|
||||
|
||||
err := t.stream.RemoveListener(t)
|
||||
if err != nil {
|
||||
t.logger.Warn().Err(err).Msg("failed to remove listener from stream")
|
||||
}
|
||||
|
||||
t.stream = nil
|
||||
}
|
||||
|
||||
func (t *Track) Stream() (types.StreamSinkManager, bool) {
|
||||
t.streamMu.Lock()
|
||||
defer t.streamMu.Unlock()
|
||||
|
||||
return t.stream, t.stream != nil
|
||||
}
|
||||
|
||||
// --- paused ---
|
||||
|
||||
func (t *Track) SetPaused(paused bool) {
|
||||
t.streamMu.Lock()
|
||||
defer t.streamMu.Unlock()
|
||||
|
||||
// if there is no state change or no stream, do nothing
|
||||
if t.paused == paused || t.stream == nil {
|
||||
t.paused = paused
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
if paused {
|
||||
err = t.stream.RemoveListener(t)
|
||||
} else {
|
||||
err = t.stream.AddListener(t)
|
||||
}
|
||||
if err != nil {
|
||||
t.logger.Warn().Err(err).Msg("failed to change listener state")
|
||||
return
|
||||
}
|
||||
|
||||
t.paused = paused
|
||||
}
|
||||
|
||||
func (t *Track) Paused() bool {
|
||||
t.streamMu.Lock()
|
||||
defer t.streamMu.Unlock()
|
||||
|
||||
return t.paused
|
||||
}
|
Reference in New Issue
Block a user