use websocket message.

This commit is contained in:
Miroslav Šedivý 2021-08-29 23:00:51 +02:00
parent 47d0359106
commit c82a083fb6
4 changed files with 42 additions and 33 deletions

View File

@ -6,11 +6,6 @@ import (
"demodesk/neko/internal/types" "demodesk/neko/internal/types"
) )
type Message struct {
Event string `json:"event"`
Payload interface{} `json:"payload"` // TODO: New.
}
///////////////////////////// /////////////////////////////
// System // System
///////////////////////////// /////////////////////////////

View File

@ -1,8 +1,16 @@
package types package types
import "net/http" import (
"encoding/json"
"net/http"
)
type HandlerFunction func(Session, []byte) bool type WebSocketMessage struct {
Event string `json:"event"`
Payload json.RawMessage `json:"payload"`
}
type HandlerFunction func(Session, WebSocketMessage) bool
type CheckOrigin func(r *http.Request) bool type CheckOrigin func(r *http.Request) bool

View File

@ -1,8 +1,6 @@
package handler package handler
import ( import (
"encoding/json"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -35,38 +33,35 @@ type MessageHandlerCtx struct {
capture types.CaptureManager capture types.CaptureManager
} }
func (h *MessageHandlerCtx) Message(session types.Session, raw []byte) bool { func (h *MessageHandlerCtx) Message(session types.Session, data types.WebSocketMessage) bool {
logger := h.logger.With().Str("session_id", session.ID()).Logger() logger := h.logger.With().
Str("event", data.Event).
header := message.Message{} Str("session_id", session.ID()).
if err := json.Unmarshal(raw, &header); err != nil { Logger()
logger.Error().Err(err).Msg("message parsing has failed")
return false
}
var err error var err error
switch header.Event { switch data.Event {
// Signal Events // Signal Events
case event.SIGNAL_REQUEST: case event.SIGNAL_REQUEST:
payload := &message.SignalVideo{} payload := &message.SignalVideo{}
err = utils.Unmarshal(payload, raw, func() error { err = utils.Unmarshal(payload, data.Payload, func() error {
return h.signalRequest(session, payload) return h.signalRequest(session, payload)
}) })
case event.SIGNAL_RESTART: case event.SIGNAL_RESTART:
err = h.signalRestart(session) err = h.signalRestart(session)
case event.SIGNAL_ANSWER: case event.SIGNAL_ANSWER:
payload := &message.SignalAnswer{} payload := &message.SignalAnswer{}
err = utils.Unmarshal(payload, raw, func() error { err = utils.Unmarshal(payload, data.Payload, func() error {
return h.signalAnswer(session, payload) return h.signalAnswer(session, payload)
}) })
case event.SIGNAL_CANDIDATE: case event.SIGNAL_CANDIDATE:
payload := &message.SignalCandidate{} payload := &message.SignalCandidate{}
err = utils.Unmarshal(payload, raw, func() error { err = utils.Unmarshal(payload, data.Payload, func() error {
return h.signalCandidate(session, payload) return h.signalCandidate(session, payload)
}) })
case event.SIGNAL_VIDEO: case event.SIGNAL_VIDEO:
payload := &message.SignalVideo{} payload := &message.SignalVideo{}
err = utils.Unmarshal(payload, raw, func() error { err = utils.Unmarshal(payload, data.Payload, func() error {
return h.signalVideo(session, payload) return h.signalVideo(session, payload)
}) })
@ -79,38 +74,38 @@ func (h *MessageHandlerCtx) Message(session types.Session, raw []byte) bool {
// Screen Events // Screen Events
case event.SCREEN_SET: case event.SCREEN_SET:
payload := &message.ScreenSize{} payload := &message.ScreenSize{}
err = utils.Unmarshal(payload, raw, func() error { err = utils.Unmarshal(payload, data.Payload, func() error {
return h.screenSet(session, payload) return h.screenSet(session, payload)
}) })
// Clipboard Events // Clipboard Events
case event.CLIPBOARD_SET: case event.CLIPBOARD_SET:
payload := &message.ClipboardData{} payload := &message.ClipboardData{}
err = utils.Unmarshal(payload, raw, func() error { err = utils.Unmarshal(payload, data.Payload, func() error {
return h.clipboardSet(session, payload) return h.clipboardSet(session, payload)
}) })
// Keyboard Events // Keyboard Events
case event.KEYBOARD_MAP: case event.KEYBOARD_MAP:
payload := &message.KeyboardMap{} payload := &message.KeyboardMap{}
err = utils.Unmarshal(payload, raw, func() error { err = utils.Unmarshal(payload, data.Payload, func() error {
return h.keyboardMap(session, payload) return h.keyboardMap(session, payload)
}) })
case event.KEYBOARD_MODIFIERS: case event.KEYBOARD_MODIFIERS:
payload := &message.KeyboardModifiers{} payload := &message.KeyboardModifiers{}
err = utils.Unmarshal(payload, raw, func() error { err = utils.Unmarshal(payload, data.Payload, func() error {
return h.keyboardModifiers(session, payload) return h.keyboardModifiers(session, payload)
}) })
// Send Events // Send Events
case event.SEND_UNICAST: case event.SEND_UNICAST:
payload := &message.SendUnicast{} payload := &message.SendUnicast{}
err = utils.Unmarshal(payload, raw, func() error { err = utils.Unmarshal(payload, data.Payload, func() error {
return h.sendUnicast(session, payload) return h.sendUnicast(session, payload)
}) })
case event.SEND_BROADCAST: case event.SEND_BROADCAST:
payload := &message.SendBroadcast{} payload := &message.SendBroadcast{}
err = utils.Unmarshal(payload, raw, func() error { err = utils.Unmarshal(payload, data.Payload, func() error {
return h.sendBroadcast(session, payload) return h.sendBroadcast(session, payload)
}) })
default: default:
@ -118,7 +113,7 @@ func (h *MessageHandlerCtx) Message(session types.Session, raw []byte) bool {
} }
if err != nil { if err != nil {
logger.Error().Err(err).Str("event", header.Event).Msg("message handler has failed") logger.Error().Err(err).Msg("message handler has failed")
} }
return true return true

View File

@ -1,6 +1,7 @@
package websocket package websocket
import ( import (
"encoding/json"
"net/http" "net/http"
"time" "time"
@ -296,22 +297,32 @@ func (manager *WebSocketManagerCtx) handle(connection *websocket.Conn, session t
for { for {
select { select {
case raw := <-bytes: case raw := <-bytes:
data := types.WebSocketMessage{}
if err := json.Unmarshal(raw, &data); err != nil {
logger.Error().Err(err).Msg("message parsing has failed")
break
}
// TODO: Switch to payload based messages.
data.Payload = raw
logger.Debug(). logger.Debug().
Str("address", connection.RemoteAddr().String()). Str("address", connection.RemoteAddr().String()).
Str("raw", string(raw)). Str("event", data.Event).
Str("payload", string(data.Payload)).
Msg("received message from client") Msg("received message from client")
handled := manager.handler.Message(session, raw) handled := manager.handler.Message(session, data)
for _, handler := range manager.handlers { for _, handler := range manager.handlers {
if handled { if handled {
break break
} }
handled = handler(session, raw) handled = handler(session, data)
} }
if !handled { if !handled {
logger.Warn().Msg("unhandled message") logger.Warn().Str("event", data.Event).Msg("unhandled message")
} }
case <-cancel: case <-cancel:
return return