diff --git a/internal/webrtc/cursor/image.go b/internal/webrtc/cursor/image.go index 36e04b5c..542bf9be 100644 --- a/internal/webrtc/cursor/image.go +++ b/internal/webrtc/cursor/image.go @@ -5,43 +5,54 @@ import ( "sync" "github.com/rs/zerolog" - "github.com/rs/zerolog/log" "github.com/demodesk/neko/pkg/types" "github.com/demodesk/neko/pkg/utils" ) -func NewImage(desktop types.DesktopManager) *ImageCtx { - return &ImageCtx{ - logger: log.With().Str("module", "webrtc").Str("submodule", "cursor-image").Logger(), +type ImageListener interface { + SendCursorImage(cur *types.CursorImage, img []byte) error +} + +type Image interface { + Start() + Shutdown() + GetCurrent() (cur *types.CursorImage, img []byte, err error) + AddListener(listener ImageListener) + RemoveListener(listener ImageListener) +} + +type imageEntry struct { + *types.CursorImage + ImagePNG []byte +} + +type image struct { + logger zerolog.Logger + desktop types.DesktopManager + + listeners map[uintptr]ImageListener + listenersMu sync.RWMutex + + cache map[uint64]*imageEntry + cacheMu sync.RWMutex + current *imageEntry + maxSerial uint64 +} + +func NewImage(logger zerolog.Logger, desktop types.DesktopManager) *image { + return &image{ + logger: logger.With().Str("submodule", "cursor-image").Logger(), desktop: desktop, - listeners: map[uintptr]*func(entry *ImageEntry){}, - cache: map[uint64]*ImageEntry{}, + listeners: map[uintptr]ImageListener{}, + cache: map[uint64]*imageEntry{}, maxSerial: 300, // TODO: Cleanup? } } -type ImageCtx struct { - logger zerolog.Logger - desktop types.DesktopManager - - listeners map[uintptr]*func(entry *ImageEntry) - listenersMu sync.Mutex - - cache map[uint64]*ImageEntry - cacheMu sync.Mutex - current *ImageEntry - maxSerial uint64 -} - -type ImageEntry struct { - Cursor *types.CursorImage - Image []byte -} - -func (manager *ImageCtx) Start() { +func (manager *image) Start() { manager.desktop.OnCursorChanged(func(serial uint64) { - entry, err := manager.GetCached(serial) + entry, err := manager.getCached(serial) if err != nil { manager.logger.Err(err).Msg("failed to get cursor image") return @@ -49,17 +60,19 @@ func (manager *ImageCtx) Start() { manager.current = entry - manager.listenersMu.Lock() - for _, emit := range manager.listeners { - (*emit)(entry) + manager.listenersMu.RLock() + for _, l := range manager.listeners { + if err := l.SendCursorImage(entry.CursorImage, entry.ImagePNG); err != nil { + manager.logger.Err(err).Msg("failed to set cursor image") + } } - manager.listenersMu.Unlock() + manager.listenersMu.RUnlock() }) manager.logger.Info().Msg("starting") } -func (manager *ImageCtx) Shutdown() { +func (manager *image) Shutdown() { manager.logger.Info().Msg("shutdown") manager.listenersMu.Lock() @@ -69,43 +82,57 @@ func (manager *ImageCtx) Shutdown() { manager.listenersMu.Unlock() } -func (manager *ImageCtx) GetCached(serial uint64) (*ImageEntry, error) { +func (manager *image) getCached(serial uint64) (*imageEntry, error) { // zero means no serial available if serial == 0 || serial > manager.maxSerial { manager.logger.Debug().Uint64("serial", serial).Msg("cache bypass") return manager.fetchEntry() } - manager.cacheMu.Lock() + manager.cacheMu.RLock() entry, ok := manager.cache[serial] - manager.cacheMu.Unlock() + manager.cacheMu.RUnlock() if ok { return entry, nil } + manager.logger.Debug().Uint64("serial", serial).Msg("cache miss") + entry, err := manager.fetchEntry() if err != nil { return nil, err } manager.cacheMu.Lock() - manager.cache[serial] = entry + manager.cache[entry.Serial] = entry manager.cacheMu.Unlock() - manager.logger.Debug().Uint64("serial", serial).Msg("cache miss") + if entry.Serial != serial { + manager.logger.Warn(). + Uint64("requested_serial", serial). + Uint64("received_serial", entry.Serial). + Msg("serial mismatch") + } + return entry, nil } -func (manager *ImageCtx) Get() (*ImageEntry, error) { +func (manager *image) GetCurrent() (cur *types.CursorImage, img []byte, err error) { if manager.current != nil { - return manager.current, nil + return manager.current.CursorImage, manager.current.ImagePNG, nil } - return manager.fetchEntry() + entry, err := manager.fetchEntry() + if err != nil { + return nil, nil, err + } + + manager.current = entry + return entry.CursorImage, entry.ImagePNG, nil } -func (manager *ImageCtx) AddListener(listener *func(entry *ImageEntry)) { +func (manager *image) AddListener(listener ImageListener) { manager.listenersMu.Lock() defer manager.listenersMu.Unlock() @@ -115,7 +142,7 @@ func (manager *ImageCtx) AddListener(listener *func(entry *ImageEntry)) { } } -func (manager *ImageCtx) RemoveListener(listener *func(entry *ImageEntry)) { +func (manager *image) RemoveListener(listener ImageListener) { manager.listenersMu.Lock() defer manager.listenersMu.Unlock() @@ -125,18 +152,17 @@ func (manager *ImageCtx) RemoveListener(listener *func(entry *ImageEntry)) { } } -func (manager *ImageCtx) fetchEntry() (*ImageEntry, error) { +func (manager *image) fetchEntry() (*imageEntry, error) { cur := manager.desktop.GetCursorImage() img, err := utils.CreatePNGImage(cur.Image) if err != nil { return nil, err } + cur.Image = nil // free memory - entry := &ImageEntry{ - Cursor: cur, - Image: img, - } - - return entry, nil + return &imageEntry{ + CursorImage: cur, + ImagePNG: img, + }, nil } diff --git a/internal/webrtc/cursor/position.go b/internal/webrtc/cursor/position.go index 6c4ea7e2..ac1147bc 100644 --- a/internal/webrtc/cursor/position.go +++ b/internal/webrtc/cursor/position.go @@ -5,24 +5,34 @@ import ( "sync" "github.com/rs/zerolog" - "github.com/rs/zerolog/log" ) -func NewPosition() *PositionCtx { - return &PositionCtx{ - logger: log.With().Str("module", "webrtc").Str("submodule", "cursor-position").Logger(), - listeners: map[uintptr]*func(x, y int){}, +type PositionListener interface { + SendCursorPosition(x, y int) error +} + +type Position interface { + Shutdown() + Set(x, y int) + AddListener(listener PositionListener) + RemoveListener(listener PositionListener) +} + +type position struct { + logger zerolog.Logger + + listeners map[uintptr]PositionListener + listenersMu sync.RWMutex +} + +func NewPosition(logger zerolog.Logger) *position { + return &position{ + logger: logger.With().Str("submodule", "cursor-position").Logger(), + listeners: map[uintptr]PositionListener{}, } } -type PositionCtx struct { - logger zerolog.Logger - - listeners map[uintptr]*func(x, y int) - listenersMu sync.Mutex -} - -func (manager *PositionCtx) Shutdown() { +func (manager *position) Shutdown() { manager.logger.Info().Msg("shutdown") manager.listenersMu.Lock() @@ -32,16 +42,18 @@ func (manager *PositionCtx) Shutdown() { manager.listenersMu.Unlock() } -func (manager *PositionCtx) Set(x, y int) { - manager.listenersMu.Lock() - defer manager.listenersMu.Unlock() +func (manager *position) Set(x, y int) { + manager.listenersMu.RLock() + defer manager.listenersMu.RUnlock() - for _, emit := range manager.listeners { - (*emit)(x, y) + for _, l := range manager.listeners { + if err := l.SendCursorPosition(x, y); err != nil { + manager.logger.Err(err).Msg("failed to set cursor position") + } } } -func (manager *PositionCtx) AddListener(listener *func(x, y int)) { +func (manager *position) AddListener(listener PositionListener) { manager.listenersMu.Lock() defer manager.listenersMu.Unlock() @@ -51,7 +63,7 @@ func (manager *PositionCtx) AddListener(listener *func(x, y int)) { } } -func (manager *PositionCtx) RemoveListener(listener *func(x, y int)) { +func (manager *position) RemoveListener(listener PositionListener) { manager.listenersMu.Lock() defer manager.listenersMu.Unlock() diff --git a/internal/webrtc/manager.go b/internal/webrtc/manager.go index 536d4fd3..cd948db6 100644 --- a/internal/webrtc/manager.go +++ b/internal/webrtc/manager.go @@ -50,6 +50,8 @@ const ( ) func New(desktop types.DesktopManager, capture types.CaptureManager, config *config.WebRTC) *WebRTCManagerCtx { + logger := log.With().Str("module", "webrtc").Logger() + configuration := webrtc.Configuration{ SDPSemantics: webrtc.SDPSemanticsUnifiedPlan, } @@ -75,7 +77,7 @@ func New(desktop types.DesktopManager, capture types.CaptureManager, config *con } return &WebRTCManagerCtx{ - logger: log.With().Str("module", "webrtc").Logger(), + logger: logger, config: config, metrics: newMetricsManager(), @@ -83,8 +85,8 @@ func New(desktop types.DesktopManager, capture types.CaptureManager, config *con desktop: desktop, capture: capture, - curImage: cursor.NewImage(desktop), - curPosition: cursor.NewPosition(), + curImage: cursor.NewImage(logger, desktop), + curPosition: cursor.NewPosition(logger), } } @@ -96,8 +98,8 @@ type WebRTCManagerCtx struct { desktop types.DesktopManager capture types.CaptureManager - curImage *cursor.ImageCtx - curPosition *cursor.PositionCtx + curImage cursor.Image + curPosition cursor.Position webrtcConfiguration webrtc.Configuration @@ -168,7 +170,7 @@ func (manager *WebRTCManagerCtx) ICEServers() []types.ICEServer { return manager.config.ICEServersFrontend } -func (manager *WebRTCManagerCtx) newPeerConnection(bitrate int, codecs []codec.RTPCodec, logger zerolog.Logger) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) { +func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs []codec.RTPCodec, bitrate int) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) { // create media engine engine := &webrtc.MediaEngine{} for _, codec := range codecs { @@ -288,19 +290,14 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int, video := manager.capture.Video() videoCodec := video.Codec() - connection, estimator, err := manager.newPeerConnection(bitrate, []codec.RTPCodec{ + connection, estimator, err := manager.newPeerConnection(logger, []codec.RTPCodec{ audioCodec, videoCodec, - }, logger) + }, bitrate) if err != nil { return nil, err } - // if bitrate is 0, and estimator is enabled, use estimator bitrate - if bitrate == 0 && estimator != nil { - bitrate = estimator.GetTargetBitrate() - } - // asynchronously send local ICE Candidates if manager.config.ICETrickle { connection.OnICECandidate(func(candidate *webrtc.ICECandidate) { @@ -317,6 +314,11 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int, }) } + // if bitrate is 0, and estimator is enabled, use estimator bitrate + if bitrate == 0 && estimator != nil { + bitrate = estimator.GetTargetBitrate() + } + // audio track audioTrack, err := NewTrack(logger, audioCodec, connection) if err != nil { @@ -531,42 +533,32 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int, metrics.SetState(state) }) - cursorImage := func(entry *cursor.ImageEntry) { - if err := peer.SendCursorImage(entry.Cursor, entry.Image); err != nil { - logger.Err(err).Msg("could not send cursor image") - } - } - - cursorPosition := func(x, y int) { - if session.IsHost() { - return - } - - if err := peer.SendCursorPosition(x, y); err != nil { - logger.Err(err).Msg("could not send cursor position") - } - } - dataChannel.OnOpen(func() { - manager.curImage.AddListener(&cursorImage) - manager.curPosition.AddListener(&cursorPosition) + manager.curImage.AddListener(peer) + manager.curPosition.AddListener(peer) // send initial cursor image - entry, err := manager.curImage.Get() + cur, img, err := manager.curImage.GetCurrent() if err == nil { - cursorImage(entry) + err := peer.SendCursorImage(cur, img) + if err != nil { + logger.Err(err).Msg("failed to set cursor image") + } } else { logger.Err(err).Msg("failed to get cursor image") } // send initial cursor position x, y := manager.desktop.GetCursorPosition() - cursorPosition(x, y) + err = peer.SendCursorPosition(x, y) + if err != nil { + logger.Err(err).Msg("failed to set cursor position") + } }) dataChannel.OnClose(func() { - manager.curImage.RemoveListener(&cursorImage) - manager.curPosition.RemoveListener(&cursorPosition) + manager.curImage.RemoveListener(peer) + manager.curPosition.RemoveListener(peer) }) dataChannel.OnMessage(func(message webrtc.DataChannelMessage) { diff --git a/internal/webrtc/peer.go b/internal/webrtc/peer.go index 311276c4..16381e12 100644 --- a/internal/webrtc/peer.go +++ b/internal/webrtc/peer.go @@ -226,6 +226,11 @@ func (peer *WebRTCPeerCtx) SendCursorPosition(x, y int) error { peer.mu.Lock() defer peer.mu.Unlock() + // do not send cursor position to host + if peer.session.IsHost() { + return nil + } + header := payload.Header{ Event: payload.OP_CURSOR_POSITION, Length: 7,