diff --git a/internal/webrtc/handler.go b/internal/webrtc/handler.go index 34df576e..28a64a1e 100644 --- a/internal/webrtc/handler.go +++ b/internal/webrtc/handler.go @@ -3,12 +3,15 @@ package webrtc import ( "bytes" "encoding/binary" + "math" + "time" "github.com/demodesk/neko/internal/webrtc/payload" "github.com/demodesk/neko/pkg/types" + "github.com/pion/webrtc/v3" ) -func (manager *WebRTCManagerCtx) handle(data []byte, session types.Session) error { +func (manager *WebRTCManagerCtx) handle(data []byte, dataChannel *webrtc.DataChannel, session types.Session) error { // add session id to logger context logger := manager.logger.With().Str("session_id", session.ID()).Logger() @@ -55,6 +58,34 @@ func (manager *WebRTCManagerCtx) handle(data []byte, session types.Session) erro } return nil + } else if header.Event == payload.OP_PING { + ping := &payload.Ping{} + if err := binary.Read(buffer, binary.BigEndian, ping); err != nil { + return err + } + + // change header event to pong + ping.Header = payload.Header{ + Event: payload.OP_PONG, + Length: 19, + } + + // generate server timestamp + serverTs := uint64(time.Now().UnixMilli()) + + // generate pong payload + pong := payload.Pong{ + Ping: *ping, + ServerTs1: uint32(serverTs / math.MaxUint32), + ServerTs2: uint32(serverTs % math.MaxUint32), + } + + buffer := &bytes.Buffer{} + if err := binary.Write(buffer, binary.BigEndian, pong); err != nil { + return err + } + + return dataChannel.Send(buffer.Bytes()) } // continue only if session is host diff --git a/internal/webrtc/manager.go b/internal/webrtc/manager.go index 794fda3a..c4f771b6 100644 --- a/internal/webrtc/manager.go +++ b/internal/webrtc/manager.go @@ -475,7 +475,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, bitrate int) }) dataChannel.OnMessage(func(message webrtc.DataChannelMessage) { - if err := manager.handle(message.Data, session); err != nil { + if err := manager.handle(message.Data, dataChannel, session); err != nil { logger.Err(err).Msg("data handle failed") } }) diff --git a/internal/webrtc/payload/receive.go b/internal/webrtc/payload/receive.go index 8cbff5c8..856c067c 100644 --- a/internal/webrtc/payload/receive.go +++ b/internal/webrtc/payload/receive.go @@ -1,5 +1,7 @@ package payload +import "math" + const ( OP_MOVE = 0x01 OP_SCROLL = 0x02 @@ -7,6 +9,7 @@ const ( OP_KEY_UP = 0x04 OP_BTN_DOWN = 0x05 OP_BTN_UP = 0x06 + OP_PING = 0x07 ) type Move struct { @@ -28,3 +31,15 @@ type Key struct { Key uint32 } + +type Ping struct { + Header + + // client's timestamp split into two uint32 + ClientTs1 uint32 + ClientTs2 uint32 +} + +func (p Ping) ClientTs() uint64 { + return (uint64(p.ClientTs1) * uint64(math.MaxUint32)) + uint64(p.ClientTs2) +} diff --git a/internal/webrtc/payload/send.go b/internal/webrtc/payload/send.go index b3436f56..b539084e 100644 --- a/internal/webrtc/payload/send.go +++ b/internal/webrtc/payload/send.go @@ -1,8 +1,11 @@ package payload +import "math" + const ( OP_CURSOR_POSITION = 0x01 OP_CURSOR_IMAGE = 0x02 + OP_PONG = 0x03 ) type CursorPosition struct { @@ -20,3 +23,15 @@ type CursorImage struct { Xhot uint16 Yhot uint16 } + +type Pong struct { + Ping + + // server's timestamp split into two uint32 + ServerTs1 uint32 + ServerTs2 uint32 +} + +func (p Pong) ServerTs() uint64 { + return (uint64(p.ServerTs1) * uint64(math.MaxUint32)) + uint64(p.ServerTs2) +}