refactro cursor image and pos.

This commit is contained in:
Miroslav Šedivý 2023-04-17 00:42:29 +02:00
parent 728e27da34
commit e8aab98012
4 changed files with 139 additions and 104 deletions

View File

@ -5,43 +5,54 @@ import (
"sync" "sync"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/pkg/types" "github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils" "github.com/demodesk/neko/pkg/utils"
) )
func NewImage(desktop types.DesktopManager) *ImageCtx { type ImageListener interface {
return &ImageCtx{ SendCursorImage(cur *types.CursorImage, img []byte) error
logger: log.With().Str("module", "webrtc").Str("submodule", "cursor-image").Logger(), }
type Image interface {
Start()
Shutdown()
GetCurrent() (cur *types.CursorImage, img []byte, err error)
AddListener(listener ImageListener)
RemoveListener(listener ImageListener)
}
type imageEntry struct {
*types.CursorImage
ImagePNG []byte
}
type image struct {
logger zerolog.Logger
desktop types.DesktopManager
listeners map[uintptr]ImageListener
listenersMu sync.RWMutex
cache map[uint64]*imageEntry
cacheMu sync.RWMutex
current *imageEntry
maxSerial uint64
}
func NewImage(logger zerolog.Logger, desktop types.DesktopManager) *image {
return &image{
logger: logger.With().Str("submodule", "cursor-image").Logger(),
desktop: desktop, desktop: desktop,
listeners: map[uintptr]*func(entry *ImageEntry){}, listeners: map[uintptr]ImageListener{},
cache: map[uint64]*ImageEntry{}, cache: map[uint64]*imageEntry{},
maxSerial: 300, // TODO: Cleanup? maxSerial: 300, // TODO: Cleanup?
} }
} }
type ImageCtx struct { func (manager *image) Start() {
logger zerolog.Logger
desktop types.DesktopManager
listeners map[uintptr]*func(entry *ImageEntry)
listenersMu sync.Mutex
cache map[uint64]*ImageEntry
cacheMu sync.Mutex
current *ImageEntry
maxSerial uint64
}
type ImageEntry struct {
Cursor *types.CursorImage
Image []byte
}
func (manager *ImageCtx) Start() {
manager.desktop.OnCursorChanged(func(serial uint64) { manager.desktop.OnCursorChanged(func(serial uint64) {
entry, err := manager.GetCached(serial) entry, err := manager.getCached(serial)
if err != nil { if err != nil {
manager.logger.Err(err).Msg("failed to get cursor image") manager.logger.Err(err).Msg("failed to get cursor image")
return return
@ -49,17 +60,19 @@ func (manager *ImageCtx) Start() {
manager.current = entry manager.current = entry
manager.listenersMu.Lock() manager.listenersMu.RLock()
for _, emit := range manager.listeners { for _, l := range manager.listeners {
(*emit)(entry) if err := l.SendCursorImage(entry.CursorImage, entry.ImagePNG); err != nil {
manager.logger.Err(err).Msg("failed to set cursor image")
}
} }
manager.listenersMu.Unlock() manager.listenersMu.RUnlock()
}) })
manager.logger.Info().Msg("starting") manager.logger.Info().Msg("starting")
} }
func (manager *ImageCtx) Shutdown() { func (manager *image) Shutdown() {
manager.logger.Info().Msg("shutdown") manager.logger.Info().Msg("shutdown")
manager.listenersMu.Lock() manager.listenersMu.Lock()
@ -69,43 +82,57 @@ func (manager *ImageCtx) Shutdown() {
manager.listenersMu.Unlock() manager.listenersMu.Unlock()
} }
func (manager *ImageCtx) GetCached(serial uint64) (*ImageEntry, error) { func (manager *image) getCached(serial uint64) (*imageEntry, error) {
// zero means no serial available // zero means no serial available
if serial == 0 || serial > manager.maxSerial { if serial == 0 || serial > manager.maxSerial {
manager.logger.Debug().Uint64("serial", serial).Msg("cache bypass") manager.logger.Debug().Uint64("serial", serial).Msg("cache bypass")
return manager.fetchEntry() return manager.fetchEntry()
} }
manager.cacheMu.Lock() manager.cacheMu.RLock()
entry, ok := manager.cache[serial] entry, ok := manager.cache[serial]
manager.cacheMu.Unlock() manager.cacheMu.RUnlock()
if ok { if ok {
return entry, nil return entry, nil
} }
manager.logger.Debug().Uint64("serial", serial).Msg("cache miss")
entry, err := manager.fetchEntry() entry, err := manager.fetchEntry()
if err != nil { if err != nil {
return nil, err return nil, err
} }
manager.cacheMu.Lock() manager.cacheMu.Lock()
manager.cache[serial] = entry manager.cache[entry.Serial] = entry
manager.cacheMu.Unlock() manager.cacheMu.Unlock()
manager.logger.Debug().Uint64("serial", serial).Msg("cache miss") if entry.Serial != serial {
manager.logger.Warn().
Uint64("requested_serial", serial).
Uint64("received_serial", entry.Serial).
Msg("serial mismatch")
}
return entry, nil return entry, nil
} }
func (manager *ImageCtx) Get() (*ImageEntry, error) { func (manager *image) GetCurrent() (cur *types.CursorImage, img []byte, err error) {
if manager.current != nil { if manager.current != nil {
return manager.current, nil return manager.current.CursorImage, manager.current.ImagePNG, nil
} }
return manager.fetchEntry() entry, err := manager.fetchEntry()
if err != nil {
return nil, nil, err
}
manager.current = entry
return entry.CursorImage, entry.ImagePNG, nil
} }
func (manager *ImageCtx) AddListener(listener *func(entry *ImageEntry)) { func (manager *image) AddListener(listener ImageListener) {
manager.listenersMu.Lock() manager.listenersMu.Lock()
defer manager.listenersMu.Unlock() defer manager.listenersMu.Unlock()
@ -115,7 +142,7 @@ func (manager *ImageCtx) AddListener(listener *func(entry *ImageEntry)) {
} }
} }
func (manager *ImageCtx) RemoveListener(listener *func(entry *ImageEntry)) { func (manager *image) RemoveListener(listener ImageListener) {
manager.listenersMu.Lock() manager.listenersMu.Lock()
defer manager.listenersMu.Unlock() defer manager.listenersMu.Unlock()
@ -125,18 +152,17 @@ func (manager *ImageCtx) RemoveListener(listener *func(entry *ImageEntry)) {
} }
} }
func (manager *ImageCtx) fetchEntry() (*ImageEntry, error) { func (manager *image) fetchEntry() (*imageEntry, error) {
cur := manager.desktop.GetCursorImage() cur := manager.desktop.GetCursorImage()
img, err := utils.CreatePNGImage(cur.Image) img, err := utils.CreatePNGImage(cur.Image)
if err != nil { if err != nil {
return nil, err return nil, err
} }
cur.Image = nil // free memory
entry := &ImageEntry{ return &imageEntry{
Cursor: cur, CursorImage: cur,
Image: img, ImagePNG: img,
} }, nil
return entry, nil
} }

View File

@ -5,24 +5,34 @@ import (
"sync" "sync"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log"
) )
func NewPosition() *PositionCtx { type PositionListener interface {
return &PositionCtx{ SendCursorPosition(x, y int) error
logger: log.With().Str("module", "webrtc").Str("submodule", "cursor-position").Logger(), }
listeners: map[uintptr]*func(x, y int){},
type Position interface {
Shutdown()
Set(x, y int)
AddListener(listener PositionListener)
RemoveListener(listener PositionListener)
}
type position struct {
logger zerolog.Logger
listeners map[uintptr]PositionListener
listenersMu sync.RWMutex
}
func NewPosition(logger zerolog.Logger) *position {
return &position{
logger: logger.With().Str("submodule", "cursor-position").Logger(),
listeners: map[uintptr]PositionListener{},
} }
} }
type PositionCtx struct { func (manager *position) Shutdown() {
logger zerolog.Logger
listeners map[uintptr]*func(x, y int)
listenersMu sync.Mutex
}
func (manager *PositionCtx) Shutdown() {
manager.logger.Info().Msg("shutdown") manager.logger.Info().Msg("shutdown")
manager.listenersMu.Lock() manager.listenersMu.Lock()
@ -32,16 +42,18 @@ func (manager *PositionCtx) Shutdown() {
manager.listenersMu.Unlock() manager.listenersMu.Unlock()
} }
func (manager *PositionCtx) Set(x, y int) { func (manager *position) Set(x, y int) {
manager.listenersMu.Lock() manager.listenersMu.RLock()
defer manager.listenersMu.Unlock() defer manager.listenersMu.RUnlock()
for _, emit := range manager.listeners { for _, l := range manager.listeners {
(*emit)(x, y) if err := l.SendCursorPosition(x, y); err != nil {
manager.logger.Err(err).Msg("failed to set cursor position")
}
} }
} }
func (manager *PositionCtx) AddListener(listener *func(x, y int)) { func (manager *position) AddListener(listener PositionListener) {
manager.listenersMu.Lock() manager.listenersMu.Lock()
defer manager.listenersMu.Unlock() defer manager.listenersMu.Unlock()
@ -51,7 +63,7 @@ func (manager *PositionCtx) AddListener(listener *func(x, y int)) {
} }
} }
func (manager *PositionCtx) RemoveListener(listener *func(x, y int)) { func (manager *position) RemoveListener(listener PositionListener) {
manager.listenersMu.Lock() manager.listenersMu.Lock()
defer manager.listenersMu.Unlock() defer manager.listenersMu.Unlock()

View File

@ -50,6 +50,8 @@ const (
) )
func New(desktop types.DesktopManager, capture types.CaptureManager, config *config.WebRTC) *WebRTCManagerCtx { func New(desktop types.DesktopManager, capture types.CaptureManager, config *config.WebRTC) *WebRTCManagerCtx {
logger := log.With().Str("module", "webrtc").Logger()
configuration := webrtc.Configuration{ configuration := webrtc.Configuration{
SDPSemantics: webrtc.SDPSemanticsUnifiedPlan, SDPSemantics: webrtc.SDPSemanticsUnifiedPlan,
} }
@ -75,7 +77,7 @@ func New(desktop types.DesktopManager, capture types.CaptureManager, config *con
} }
return &WebRTCManagerCtx{ return &WebRTCManagerCtx{
logger: log.With().Str("module", "webrtc").Logger(), logger: logger,
config: config, config: config,
metrics: newMetricsManager(), metrics: newMetricsManager(),
@ -83,8 +85,8 @@ func New(desktop types.DesktopManager, capture types.CaptureManager, config *con
desktop: desktop, desktop: desktop,
capture: capture, capture: capture,
curImage: cursor.NewImage(desktop), curImage: cursor.NewImage(logger, desktop),
curPosition: cursor.NewPosition(), curPosition: cursor.NewPosition(logger),
} }
} }
@ -96,8 +98,8 @@ type WebRTCManagerCtx struct {
desktop types.DesktopManager desktop types.DesktopManager
capture types.CaptureManager capture types.CaptureManager
curImage *cursor.ImageCtx curImage cursor.Image
curPosition *cursor.PositionCtx curPosition cursor.Position
webrtcConfiguration webrtc.Configuration webrtcConfiguration webrtc.Configuration
@ -168,7 +170,7 @@ func (manager *WebRTCManagerCtx) ICEServers() []types.ICEServer {
return manager.config.ICEServersFrontend return manager.config.ICEServersFrontend
} }
func (manager *WebRTCManagerCtx) newPeerConnection(bitrate int, codecs []codec.RTPCodec, logger zerolog.Logger) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) { func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs []codec.RTPCodec, bitrate int) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) {
// create media engine // create media engine
engine := &webrtc.MediaEngine{} engine := &webrtc.MediaEngine{}
for _, codec := range codecs { for _, codec := range codecs {
@ -288,19 +290,14 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
video := manager.capture.Video() video := manager.capture.Video()
videoCodec := video.Codec() videoCodec := video.Codec()
connection, estimator, err := manager.newPeerConnection(bitrate, []codec.RTPCodec{ connection, estimator, err := manager.newPeerConnection(logger, []codec.RTPCodec{
audioCodec, audioCodec,
videoCodec, videoCodec,
}, logger) }, bitrate)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// if bitrate is 0, and estimator is enabled, use estimator bitrate
if bitrate == 0 && estimator != nil {
bitrate = estimator.GetTargetBitrate()
}
// asynchronously send local ICE Candidates // asynchronously send local ICE Candidates
if manager.config.ICETrickle { if manager.config.ICETrickle {
connection.OnICECandidate(func(candidate *webrtc.ICECandidate) { connection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
@ -317,6 +314,11 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
}) })
} }
// if bitrate is 0, and estimator is enabled, use estimator bitrate
if bitrate == 0 && estimator != nil {
bitrate = estimator.GetTargetBitrate()
}
// audio track // audio track
audioTrack, err := NewTrack(logger, audioCodec, connection) audioTrack, err := NewTrack(logger, audioCodec, connection)
if err != nil { if err != nil {
@ -531,42 +533,32 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int,
metrics.SetState(state) metrics.SetState(state)
}) })
cursorImage := func(entry *cursor.ImageEntry) {
if err := peer.SendCursorImage(entry.Cursor, entry.Image); err != nil {
logger.Err(err).Msg("could not send cursor image")
}
}
cursorPosition := func(x, y int) {
if session.IsHost() {
return
}
if err := peer.SendCursorPosition(x, y); err != nil {
logger.Err(err).Msg("could not send cursor position")
}
}
dataChannel.OnOpen(func() { dataChannel.OnOpen(func() {
manager.curImage.AddListener(&cursorImage) manager.curImage.AddListener(peer)
manager.curPosition.AddListener(&cursorPosition) manager.curPosition.AddListener(peer)
// send initial cursor image // send initial cursor image
entry, err := manager.curImage.Get() cur, img, err := manager.curImage.GetCurrent()
if err == nil { if err == nil {
cursorImage(entry) err := peer.SendCursorImage(cur, img)
if err != nil {
logger.Err(err).Msg("failed to set cursor image")
}
} else { } else {
logger.Err(err).Msg("failed to get cursor image") logger.Err(err).Msg("failed to get cursor image")
} }
// send initial cursor position // send initial cursor position
x, y := manager.desktop.GetCursorPosition() x, y := manager.desktop.GetCursorPosition()
cursorPosition(x, y) err = peer.SendCursorPosition(x, y)
if err != nil {
logger.Err(err).Msg("failed to set cursor position")
}
}) })
dataChannel.OnClose(func() { dataChannel.OnClose(func() {
manager.curImage.RemoveListener(&cursorImage) manager.curImage.RemoveListener(peer)
manager.curPosition.RemoveListener(&cursorPosition) manager.curPosition.RemoveListener(peer)
}) })
dataChannel.OnMessage(func(message webrtc.DataChannelMessage) { dataChannel.OnMessage(func(message webrtc.DataChannelMessage) {

View File

@ -226,6 +226,11 @@ func (peer *WebRTCPeerCtx) SendCursorPosition(x, y int) error {
peer.mu.Lock() peer.mu.Lock()
defer peer.mu.Unlock() defer peer.mu.Unlock()
// do not send cursor position to host
if peer.session.IsHost() {
return nil
}
header := payload.Header{ header := payload.Header{
Event: payload.OP_CURSOR_POSITION, Event: payload.OP_CURSOR_POSITION,
Length: 7, Length: 7,