diff --git a/README.md b/README.md index 1c1880e..7616243 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ NEKO_DISPLAY=0 // Display number NEKO_WIDTH=1280 // Display width NEKO_HEIGHT=720 // Display height NEKO_PASSWORD=neko // Password -NEKO_ADMIN=neko // Admin Password +NEKO_ADMIN=neko // Admin Password NEKO_BIND=0.0.0.0:8080 // Bind NEKO_KEY= // (SSL)Key NEKO_CERT= // (SSL)Cert @@ -65,4 +65,5 @@ NEKO_CERT= // (SSL)Cert `cd .docker && ./build` ### Non Goals - * Turning n.eko into a service that serves multiple rooms and browsers/desktops. \ No newline at end of file + * Turning n.eko into a service that serves multiple rooms and browsers/desktops. + * Voice chat, use [Discord](https://discordapp.com/)) \ No newline at end of file diff --git a/client/src/neko/base.ts b/client/src/neko/base.ts index 02dbc95..fec5d1e 100644 --- a/client/src/neko/base.ts +++ b/client/src/neko/base.ts @@ -34,7 +34,7 @@ export abstract class BaseClient extends EventEmitter { } get peerConnected() { - return typeof this._peer !== 'undefined' && this._state === 'connected' + return typeof this._peer !== 'undefined' && ['connected', 'checking', 'completed'].includes(this._state) } get connected() { @@ -60,7 +60,7 @@ export abstract class BaseClient extends EventEmitter { this._ws.onmessage = this.onMessage.bind(this) this._ws.onerror = event => this.onError.bind(this) this._ws.onclose = event => this.onDisconnected.bind(this, new Error('websocket closed')) - this._timeout = setTimeout(this.onTimeout.bind(this), 5000) + this._timeout = setTimeout(this.onTimeout.bind(this), 15000) } catch (err) { this.onDisconnected(err) } diff --git a/server/go.sum b/server/go.sum index 6719b1d..ef59363 100644 --- a/server/go.sum +++ b/server/go.sum @@ -127,6 +127,7 @@ github.com/pion/transport v0.8.10 h1:lTiobMEw2PG6BH/mgIVqTV2mBp/mPT+IJLaN8ZxgdHk github.com/pion/transport v0.8.10/go.mod h1:tBmha/UCjpum5hqTWhfAEs3CO4/tHSg0MYRhSzR+CZ8= github.com/pion/turn v1.4.0 h1:7NUMRehQz4fIo53Qv9ui1kJ0Kr1CA82I81RHKHCeM80= github.com/pion/turn v1.4.0/go.mod h1:aDSi6hWX/hd1+gKia9cExZOR0MU95O7zX9p3Gw/P2aU= +github.com/pion/webrtc v1.2.0 h1:3LGGPQEMacwG2hcDfhdvwQPz315gvjZXOfY4vaF4+I4= github.com/pion/webrtc/v2 v2.1.18 h1:g0VN0xfEUSlVNfQmlCD6yOeXy/tMaktESBmHMnBS3bk= github.com/pion/webrtc/v2 v2.1.18/go.mod h1:m0rKlYgLRZWyhmcMWegpF6xtK1ASxmOg8DAR74ttzQY= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/server/internal/config/webrtc.go b/server/internal/config/webrtc.go index aaa1d19..1118759 100644 --- a/server/internal/config/webrtc.go +++ b/server/internal/config/webrtc.go @@ -1,8 +1,7 @@ package config import ( - "strings" - + "github.com/pion/webrtc/v2" "github.com/spf13/cobra" "github.com/spf13/viper" ) @@ -22,13 +21,8 @@ func (WebRTC) Init(cmd *cobra.Command) error { return err } - cmd.PersistentFlags().String("ac", "opus", "Audio codec to use for streaming") - if err := viper.BindPFlag("acodec", cmd.PersistentFlags().Lookup("ac")); err != nil { - return err - } - - cmd.PersistentFlags().String("ap", "", "Audio codec parameters to use for streaming") - if err := viper.BindPFlag("aparams", cmd.PersistentFlags().Lookup("ap")); err != nil { + cmd.PersistentFlags().String("aduio", "", "Audio codec parameters to use for streaming") + if err := viper.BindPFlag("aparams", cmd.PersistentFlags().Lookup("aduio")); err != nil { return err } @@ -37,13 +31,45 @@ func (WebRTC) Init(cmd *cobra.Command) error { return err } - cmd.PersistentFlags().String("vc", "vp8", "Video codec to use for streaming") - if err := viper.BindPFlag("vcodec", cmd.PersistentFlags().Lookup("vc")); err != nil { + cmd.PersistentFlags().String("video", "", "Video codec parameters to use for streaming") + if err := viper.BindPFlag("vparams", cmd.PersistentFlags().Lookup("video")); err != nil { return err } - cmd.PersistentFlags().String("vp", "", "Video codec parameters to use for streaming") - if err := viper.BindPFlag("vparams", cmd.PersistentFlags().Lookup("vp")); err != nil { + // video codecs + cmd.PersistentFlags().Bool("vp8", false, "Use VP8 codec") + if err := viper.BindPFlag("vp8", cmd.PersistentFlags().Lookup("vp8")); err != nil { + return err + } + + cmd.PersistentFlags().Bool("vp9", false, "Use VP9 codec") + if err := viper.BindPFlag("vp9", cmd.PersistentFlags().Lookup("vp9")); err != nil { + return err + } + + cmd.PersistentFlags().Bool("h264", false, "Use H264 codec") + if err := viper.BindPFlag("h264", cmd.PersistentFlags().Lookup("h264")); err != nil { + return err + } + + // audio codecs + cmd.PersistentFlags().Bool("opus", false, "Use Opus codec") + if err := viper.BindPFlag("opus", cmd.PersistentFlags().Lookup("opus")); err != nil { + return err + } + + cmd.PersistentFlags().Bool("g722", false, "Use G722 codec") + if err := viper.BindPFlag("g722", cmd.PersistentFlags().Lookup("g722")); err != nil { + return err + } + + cmd.PersistentFlags().Bool("pcmu", false, "Use PCMU codec") + if err := viper.BindPFlag("pcmu", cmd.PersistentFlags().Lookup("pcmu")); err != nil { + return err + } + + cmd.PersistentFlags().Bool("pcma", false, "Use PCMA codec") + if err := viper.BindPFlag("pcmu", cmd.PersistentFlags().Lookup("pcmu")); err != nil { return err } @@ -51,10 +77,30 @@ func (WebRTC) Init(cmd *cobra.Command) error { } func (s *WebRTC) Set() { - s.Device = strings.ToLower(viper.GetString("device")) - s.AudioCodec = strings.ToLower(viper.GetString("acodec")) - s.AudioParams = strings.ToLower(viper.GetString("aparams")) - s.Display = strings.ToLower(viper.GetString("display")) - s.VideoCodec = strings.ToLower(viper.GetString("vcodec")) - s.VideoParams = strings.ToLower(viper.GetString("vparams")) + videoCodec := webrtc.VP8 + if viper.GetBool("vp8") { + videoCodec = webrtc.VP8 + } else if viper.GetBool("vp9") { + videoCodec = webrtc.VP9 + } else if viper.GetBool("h264") { + videoCodec = webrtc.H264 + } + + audioCodec := webrtc.VP8 + if viper.GetBool("opus") { + audioCodec = webrtc.Opus + } else if viper.GetBool("g722") { + audioCodec = webrtc.G722 + } else if viper.GetBool("pcmu") { + audioCodec = webrtc.PCMU + } else if viper.GetBool("pcma") { + audioCodec = webrtc.PCMA + } + + s.Device = viper.GetString("device") + s.AudioCodec = audioCodec + s.AudioParams = viper.GetString("aparams") + s.Display = viper.GetString("display") + s.VideoCodec = videoCodec + s.VideoParams = viper.GetString("vparams") } diff --git a/server/internal/gst/gst.go b/server/internal/gst/gst.go index 1b2f86b..6ef7dcb 100644 --- a/server/internal/gst/gst.go +++ b/server/internal/gst/gst.go @@ -9,12 +9,13 @@ package gst import "C" import ( "fmt" - "io" "sync" "unsafe" "github.com/pion/webrtc/v2" - "github.com/pion/webrtc/v2/pkg/media" + "github.com/pkg/errors" + + "n.eko.moe/neko/internal/types" ) /* @@ -35,10 +36,10 @@ import ( // Pipeline is a wrapper for a GStreamer Pipeline type Pipeline struct { Pipeline *C.GstElement - tracks []*webrtc.Track + Sample chan types.Sample + CodecName string + ClockRate float32 id int - codecName string - clockRate float32 } var pipelines = make(map[int]*Pipeline) @@ -57,39 +58,39 @@ func init() { } // CreatePipeline creates a GStreamer Pipeline -func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string) *Pipeline { +func CreatePipeline(codecName string, pipelineSrc string) (*Pipeline, error) { pipelineStr := "appsink name=appsink" var clockRate float32 switch codecName { case webrtc.VP8: // https://gstreamer.freedesktop.org/documentation/vpx/vp8enc.html?gi-language=c - // gstreamer1.0-plugins-good + // gstreamer1.0-plugins-good // vp8enc error-resilient=partitions keyframe-max-dist=10 auto-alt-ref=true cpu-used=5 deadline=1 pipelineStr = pipelineSrc + " ! vp8enc error-resilient=partitions keyframe-max-dist=10 auto-alt-ref=true cpu-used=5 deadline=1 ! " + pipelineStr clockRate = videoClockRate if err := CheckPlugins([]string{"ximagesrc", "vpx"}); err != nil { - panic(err) + return nil, err } case webrtc.VP9: // https://gstreamer.freedesktop.org/documentation/vpx/vp9enc.html?gi-language=c - // gstreamer1.0-plugins-good + // gstreamer1.0-plugins-good // vp9enc - // Causes panic! + // Causes panic! pipelineStr = pipelineSrc + " ! vp9enc ! " + pipelineStr clockRate = videoClockRate if err := CheckPlugins([]string{"ximagesrc", "vpx"}); err != nil { - panic(err) + return nil, err } case webrtc.H264: // https://gstreamer.freedesktop.org/documentation/x264/index.html?gi-language=c // gstreamer1.0-plugins-ugly - // video/x-raw,format=I420 ! x264enc bframes=0 key-int-max=60 byte-stream=true tune=zerolatency speed-preset=veryfast ! video/x-h264,stream-format=byte-stream + // video/x-raw,format=I420 ! x264enc bframes=0 key-int-max=60 byte-stream=true tune=zerolatency speed-preset=veryfast ! video/x-h264,stream-format=byte-stream // https://gstreamer.freedesktop.org/documentation/openh264/openh264enc.html?gi-language=c#openh264enc // gstreamer1.0-plugins-bad @@ -98,14 +99,14 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string clockRate = videoClockRate if err := CheckPlugins([]string{"ximagesrc"}); err != nil { - panic(err) + return nil, err } if err := CheckPlugins([]string{"openh264"}); err != nil { pipelineStr = pipelineSrc + " ! video/x-raw,format=I420 ! x264enc bframes=0 key-int-max=60 byte-stream=true tune=zerolatency speed-preset=veryfast ! video/x-h264,stream-format=byte-stream ! " + pipelineStr if err := CheckPlugins([]string{"x264"}); err != nil { - panic(err) + return nil, err } } @@ -115,9 +116,9 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string // opusenc pipelineStr = pipelineSrc + " ! opusenc ! " + pipelineStr clockRate = audioClockRate - + if err := CheckPlugins([]string{"pulseaudio", "opus"}); err != nil { - panic(err) + return nil, err } case webrtc.G722: @@ -128,7 +129,7 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string clockRate = audioClockRate if err := CheckPlugins([]string{"pulseaudio", "libav"}); err != nil { - panic(err) + return nil, err } case webrtc.PCMU: @@ -140,7 +141,7 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string clockRate = pcmClockRate if err := CheckPlugins([]string{"pulseaudio", "mulaw"}); err != nil { - panic(err) + return nil, err } case webrtc.PCMA: @@ -151,11 +152,11 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string clockRate = pcmClockRate if err := CheckPlugins([]string{"pulseaudio", "alaw"}); err != nil { - panic(err) + return nil, err } default: - panic("Unhandled codec " + codecName) + return nil, errors.Errorf("unknown video codec %s", codecName) } pipelineStrUnsafe := C.CString(pipelineStr) @@ -166,14 +167,14 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string pipeline := &Pipeline{ Pipeline: C.gstreamer_send_create_pipeline(pipelineStrUnsafe), - tracks: tracks, + Sample: make(chan types.Sample), + CodecName: codecName, + ClockRate: clockRate, id: len(pipelines), - codecName: codecName, - clockRate: clockRate, } pipelines[pipeline.id] = pipeline - return pipeline + return pipeline, nil } // Start starts the GStreamer Pipeline @@ -193,14 +194,13 @@ func CheckPlugins(plugins []string) error { plugin = C.gst_registry_find_plugin(registry, plugincstr) C.free(unsafe.Pointer(plugincstr)) if plugin == nil { - return fmt.Errorf("Required gstreamer plugin %s not found", pluginstr) + return fmt.Errorf("required gstreamer plugin %s not found", pluginstr) } } return nil } - //export goHandlePipelineBuffer func goHandlePipelineBuffer(buffer unsafe.Pointer, bufferLen C.int, duration C.int, pipelineID C.int) { pipelinesLock.Lock() @@ -208,12 +208,8 @@ func goHandlePipelineBuffer(buffer unsafe.Pointer, bufferLen C.int, duration C.i pipelinesLock.Unlock() if ok { - samples := uint32(pipeline.clockRate * (float32(duration) / 1000000000)) - for _, t := range pipeline.tracks { - if err := t.WriteSample(media.Sample{Data: C.GoBytes(buffer, bufferLen), Samples: samples}); err != nil && err != io.ErrClosedPipe { - panic(err) - } - } + samples := uint32(pipeline.ClockRate * (float32(duration) / 1000000000)) + pipeline.Sample <- types.Sample{Data: C.GoBytes(buffer, bufferLen), Samples: samples} } else { fmt.Printf("discarding buffer, no pipeline with id %d", int(pipelineID)) } diff --git a/server/internal/session/manager.go b/server/internal/session/manager.go index a4daac8..6673882 100644 --- a/server/internal/session/manager.go +++ b/server/internal/session/manager.go @@ -3,15 +3,17 @@ package session import ( "fmt" - "github.com/gorilla/websocket" "github.com/kataras/go-events" - "github.com/pion/webrtc/v2" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "n.eko.moe/neko/internal/types" "n.eko.moe/neko/internal/utils" ) func New() *SessionManager { return &SessionManager{ + logger: log.With().Str("module", "session").Logger(), host: "", members: make(map[string]*Session), emmiter: events.New(), @@ -19,153 +21,100 @@ func New() *SessionManager { } type SessionManager struct { + logger zerolog.Logger host string members map[string]*Session emmiter events.EventEmmiter } -func (m *SessionManager) New(id string, admin bool, socket *websocket.Conn) *Session { +func (manager *SessionManager) New(id string, admin bool, socket types.WebScoket) types.Session { session := &Session{ - ID: id, - Admin: admin, + id: id, + admin: admin, + manager: manager, socket: socket, + logger: manager.logger.With().Str("id", id).Logger(), connected: false, } - m.members[id] = session - m.emmiter.Emit("created", id, session) + manager.members[id] = session + manager.emmiter.Emit("created", id, session) return session } -func (m *SessionManager) IsHost(id string) bool { - return m.host == id +func (manager *SessionManager) HasHost() bool { + return manager.host != "" } -func (m *SessionManager) HasHost() bool { - return m.host != "" +func (manager *SessionManager) IsHost(id string) bool { + return manager.host == id } -func (m *SessionManager) SetHost(id string) error { - _, ok := m.members[id] +func (manager *SessionManager) SetHost(id string) error { + _, ok := manager.members[id] if ok { - m.host = id - m.emmiter.Emit("host", id) + manager.host = id + manager.emmiter.Emit("host", id) return nil } return fmt.Errorf("invalid session id %s", id) } -func (m *SessionManager) GetHost() (*Session, bool) { - host, ok := m.members[m.host] +func (manager *SessionManager) GetHost() (types.Session, bool) { + host, ok := manager.members[manager.host] return host, ok } -func (m *SessionManager) ClearHost() { - id := m.host - m.host = "" - m.emmiter.Emit("host_cleared", id) +func (manager *SessionManager) ClearHost() { + id := manager.host + manager.host = "" + manager.emmiter.Emit("host_cleared", id) } -func (m *SessionManager) Has(id string) bool { - _, ok := m.members[id] +func (manager *SessionManager) Has(id string) bool { + _, ok := manager.members[id] return ok } -func (m *SessionManager) Get(id string) (*Session, bool) { - session, ok := m.members[id] +func (manager *SessionManager) Get(id string) (types.Session, bool) { + session, ok := manager.members[id] return session, ok } -func (m *SessionManager) GetConnected() []*Session { - var sessions []*Session - for _, sess := range m.members { - if sess.connected { - sessions = append(sessions, sess) +func (manager *SessionManager) Members() []*types.Member { + members := []*types.Member{} + for _, session := range manager.members { + if !session.connected { + continue + } + + member := session.Member() + if member != nil { + members = append(members, member) } } - - return sessions + return members } -func (m *SessionManager) Set(id string, session *Session) { - m.members[id] = session -} - -func (m *SessionManager) Destroy(id string) error { - session, ok := m.members[id] +func (manager *SessionManager) Destroy(id string) error { + session, ok := manager.members[id] if ok { err := session.destroy() - delete(m.members, id) - m.emmiter.Emit("destroyed", id) + delete(manager.members, id) + manager.emmiter.Emit("destroyed", id) return err } return nil } -func (m *SessionManager) SetSocket(id string, socket *websocket.Conn) (bool, error) { - session, ok := m.members[id] - if ok { - session.socket = socket - return true, nil - } - - return false, fmt.Errorf("invalid session id %s", id) -} - -func (m *SessionManager) SetPeer(id string, peer *webrtc.PeerConnection) (bool, error) { - session, ok := m.members[id] - if ok { - session.peer = peer - return true, nil - } - - return false, fmt.Errorf("invalid session id %s", id) -} - -func (m *SessionManager) SetName(id string, name string) (bool, error) { - session, ok := m.members[id] - if ok { - session.Name = name - session.connected = true - m.emmiter.Emit("connected", id, session) - return true, nil - } - - return false, fmt.Errorf("invalid session id %s", id) -} - -func (m *SessionManager) Mute(id string) error { - session, ok := m.members[id] - if ok { - session.Muted = true - } +func (manager *SessionManager) Clear() error { return nil } -func (m *SessionManager) Unmute(id string) error { - session, ok := m.members[id] - if ok { - session.Muted = false - } - return nil -} - -func (m *SessionManager) Kick(id string, v interface{}) error { - session, ok := m.members[id] - if ok { - return session.Kick(v) - } - return nil -} - -func (m *SessionManager) Clear() error { - return nil -} - -func (m *SessionManager) Brodcast(v interface{}, exclude interface{}) error { - for id, sess := range m.members { - if !sess.connected { +func (manager *SessionManager) Brodcast(v interface{}, exclude interface{}) error { + for id, session := range manager.members { + if !session.connected { continue } @@ -175,39 +124,65 @@ func (m *SessionManager) Brodcast(v interface{}, exclude interface{}) error { } } - if err := sess.Send(v); err != nil { + if err := session.Send(v); err != nil { return err } } return nil } -func (m *SessionManager) OnHost(listener func(id string)) { - m.emmiter.On("host", func(payload ...interface{}) { +func (manager *SessionManager) WriteVideoSample(sample types.Sample) error { + for _, session := range manager.members { + if !session.connected { + continue + } + + if err := session.WriteVideoSample(sample); err != nil { + return err + } + } + return nil +} + +func (manager *SessionManager) WriteAudioSample(sample types.Sample) error { + for _, session := range manager.members { + if !session.connected { + continue + } + + if err := session.WriteAudioSample(sample); err != nil { + return err + } + } + return nil +} + +func (manager *SessionManager) OnHost(listener func(id string)) { + manager.emmiter.On("host", func(payload ...interface{}) { listener(payload[0].(string)) }) } -func (m *SessionManager) OnHostCleared(listener func(id string)) { - m.emmiter.On("host_cleared", func(payload ...interface{}) { +func (manager *SessionManager) OnHostCleared(listener func(id string)) { + manager.emmiter.On("host_cleared", func(payload ...interface{}) { listener(payload[0].(string)) }) } -func (m *SessionManager) OnCreated(listener func(id string, session *Session)) { - m.emmiter.On("created", func(payload ...interface{}) { +func (manager *SessionManager) OnDestroy(listener func(id string)) { + manager.emmiter.On("destroyed", func(payload ...interface{}) { + listener(payload[0].(string)) + }) +} + +func (manager *SessionManager) OnCreated(listener func(id string, session types.Session)) { + manager.emmiter.On("created", func(payload ...interface{}) { listener(payload[0].(string), payload[1].(*Session)) }) } -func (m *SessionManager) OnConnected(listener func(id string, session *Session)) { - m.emmiter.On("connected", func(payload ...interface{}) { +func (manager *SessionManager) OnConnected(listener func(id string, session types.Session)) { + manager.emmiter.On("connected", func(payload ...interface{}) { listener(payload[0].(string), payload[1].(*Session)) }) } - -func (m *SessionManager) OnDestroy(listener func(id string)) { - m.emmiter.On("destroyed", func(payload ...interface{}) { - listener(payload[0].(string)) - }) -} diff --git a/server/internal/session/session.go b/server/internal/session/session.go index e1cd7ab..78280f2 100644 --- a/server/internal/session/session.go +++ b/server/internal/session/session.go @@ -3,38 +3,90 @@ package session import ( "sync" - "github.com/gorilla/websocket" - "github.com/pion/webrtc/v2" + "github.com/rs/zerolog" + "n.eko.moe/neko/internal/types" + "n.eko.moe/neko/internal/types/event" + "n.eko.moe/neko/internal/types/message" ) type Session struct { - ID string `json:"id"` - Name string `json:"username"` - Admin bool `json:"admin"` - Muted bool `json:"muted"` + logger zerolog.Logger + id string + name string + admin bool + muted bool connected bool - socket *websocket.Conn - peer *webrtc.PeerConnection + manager *SessionManager + socket types.WebScoket + peer types.Peer mu sync.Mutex } -func (session *Session) RemoteAddr() *string { - if session.socket != nil { - address := session.socket.RemoteAddr().String() - return &address +func (session *Session) ID() string { + return session.id +} + +func (session *Session) Name() string { + return session.name +} + +func (session *Session) Admin() bool { + return session.admin +} + +func (session *Session) Muted() bool { + return session.muted +} + +func (session *Session) Connected() bool { + return session.connected +} + +func (session *Session) Address() *string { + if session.socket == nil { + return nil } + return session.socket.Address() +} + +func (session *Session) Member() *types.Member { + return &types.Member{ + ID: session.id, + Name: session.name, + Admin: session.admin, + Muted: session.muted, + } +} + +func (session *Session) SetMuted(muted bool) { + session.muted = muted +} + +func (session *Session) SetName(name string) error { + session.name = name return nil } -// TODO: write to peer data channel -func (session *Session) Write(v interface{}) error { - session.mu.Lock() - defer session.mu.Unlock() +func (session *Session) SetSocket(socket types.WebScoket) error { + session.socket = socket return nil } -func (session *Session) Kick(v interface{}) error { - if err := session.Send(v); err != nil { +func (session *Session) SetPeer(peer types.Peer) error { + session.peer = peer + session.connected = true + session.manager.emmiter.Emit("connected", session.id, session) + return nil +} + +func (session *Session) Kick(reason string) error { + if session.socket == nil { + return nil + } + if err := session.socket.Send(&message.Disconnect{ + Event: event.SYSTEM_DISCONNECT, + Message: reason, + }); err != nil { return err } @@ -42,27 +94,40 @@ func (session *Session) Kick(v interface{}) error { } func (session *Session) Send(v interface{}) error { - session.mu.Lock() - defer session.mu.Unlock() - - if session.socket != nil { - return session.socket.WriteJSON(v) + if session.socket == nil { + return nil } + return session.socket.Send(v) +} - return nil +func (session *Session) Write(v interface{}) error { + if session.socket == nil { + return nil + } + return session.socket.Send(v) +} + +func (session *Session) WriteVideoSample(sample types.Sample) error { + if session.peer == nil || !session.connected { + return nil + } + return session.peer.WriteVideoSample(sample) +} + +func (session *Session) WriteAudioSample(sample types.Sample) error { + if session.peer == nil || !session.connected { + return nil + } + return session.peer.WriteAudioSample(sample) } func (session *Session) destroy() error { - if session.peer != nil && session.peer.ConnectionState() == webrtc.PeerConnectionStateConnected { - if err := session.peer.Close(); err != nil { - return err - } + if err := session.socket.Destroy(); err != nil { + return err } - if session.socket != nil { - if err := session.socket.Close(); err != nil { - return err - } + if err := session.peer.Destroy(); err != nil { + return err } return nil diff --git a/server/internal/event/events.go b/server/internal/types/event/events.go similarity index 100% rename from server/internal/event/events.go rename to server/internal/types/event/events.go diff --git a/server/internal/message/messages.go b/server/internal/types/message/messages.go similarity index 90% rename from server/internal/message/messages.go rename to server/internal/types/message/messages.go index deb0cc9..11fdcd9 100644 --- a/server/internal/message/messages.go +++ b/server/internal/types/message/messages.go @@ -1,6 +1,8 @@ package message -import "n.eko.moe/neko/internal/session" +import ( + "n.eko.moe/neko/internal/types" +) type Message struct { Event string `json:"event"` @@ -27,13 +29,13 @@ type Signal struct { } type MembersList struct { - Event string `json:"event"` - Memebers []*session.Session `json:"members"` + Event string `json:"event"` + Memebers []*types.Member `json:"members"` } type Member struct { Event string `json:"event"` - *session.Session + *types.Member } type MemberDisconnected struct { Event string `json:"event"` diff --git a/server/internal/types/session.go b/server/internal/types/session.go new file mode 100644 index 0000000..4e57cf3 --- /dev/null +++ b/server/internal/types/session.go @@ -0,0 +1,49 @@ +package types + +type Member struct { + ID string `json:"id"` + Name string `json:"username"` + Admin bool `json:"admin"` + Muted bool `json:"muted"` +} + +type Session interface { + ID() string + Name() string + Admin() bool + Muted() bool + Connected() bool + Member() *Member + SetMuted(muted bool) + SetName(name string) error + SetSocket(socket WebScoket) error + SetPeer(peer Peer) error + Address() *string + Kick(message string) error + Write(v interface{}) error + Send(v interface{}) error + WriteAudioSample(sample Sample) error + WriteVideoSample(sample Sample) error +} + +type SessionManager interface { + New(id string, admin bool, socket WebScoket) Session + HasHost() bool + IsHost(id string) bool + SetHost(id string) error + GetHost() (Session, bool) + ClearHost() + Has(id string) bool + Get(id string) (Session, bool) + Members() []*Member + Destroy(id string) error + Clear() error + Brodcast(v interface{}, exclude interface{}) error + WriteAudioSample(sample Sample) error + WriteVideoSample(sample Sample) error + OnHost(listener func(id string)) + OnHostCleared(listener func(id string)) + OnDestroy(listener func(id string)) + OnCreated(listener func(id string, session Session)) + OnConnected(listener func(id string, session Session)) +} diff --git a/server/internal/types/webrtc.go b/server/internal/types/webrtc.go new file mode 100644 index 0000000..76cc571 --- /dev/null +++ b/server/internal/types/webrtc.go @@ -0,0 +1,19 @@ +package types + +type Sample struct { + Data []byte + Samples uint32 +} + +type WebRTCManager interface { + Start() + Shutdown() error + CreatePeer(id string, sdp string) (string, Peer, error) +} + +type Peer interface { + WriteVideoSample(sample Sample) error + WriteAudioSample(sample Sample) error + WriteData(v interface{}) error + Destroy() error +} diff --git a/server/internal/types/webscoket.go b/server/internal/types/webscoket.go new file mode 100644 index 0000000..ad58721 --- /dev/null +++ b/server/internal/types/webscoket.go @@ -0,0 +1,7 @@ +package types + +type WebScoket interface { + Address() *string + Send(v interface{}) error + Destroy() error +} diff --git a/server/internal/webrtc/logger.go b/server/internal/webrtc/logger.go index 63c7195..6468e34 100644 --- a/server/internal/webrtc/logger.go +++ b/server/internal/webrtc/logger.go @@ -1,6 +1,9 @@ package webrtc import ( + "fmt" + "strings" + "github.com/pion/logging" "github.com/rs/zerolog" ) @@ -13,8 +16,19 @@ func (l logger) Trace(msg string) { l.logger.Trace().Ms func (l logger) Tracef(format string, args ...interface{}) { l.logger.Trace().Msgf(format, args...) } func (l logger) Debug(msg string) { l.logger.Debug().Msg(msg) } func (l logger) Debugf(format string, args ...interface{}) { l.logger.Debug().Msgf(format, args...) } -func (l logger) Info(msg string) { l.logger.Info().Msg(msg) } -func (l logger) Infof(format string, args ...interface{}) { l.logger.Info().Msgf(format, args...) } +func (l logger) Info(msg string) { + if strings.Contains(msg, "packetio.Buffer is full") { + l.logger.Panic().Msg(msg) + } + l.logger.Info().Msg(msg) +} +func (l logger) Infof(format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + if strings.Contains(msg, "packetio.Buffer is full") { + l.logger.Panic().Msg(msg) + } + l.logger.Info().Msgf(format, args...) +} func (l logger) Warn(msg string) { l.logger.Warn().Msg(msg) } func (l logger) Warnf(format string, args ...interface{}) { l.logger.Warn().Msgf(format, args...) } func (l logger) Error(msg string) { l.logger.Error().Msg(msg) } diff --git a/server/internal/webrtc/manager.go b/server/internal/webrtc/manager.go deleted file mode 100644 index f462343..0000000 --- a/server/internal/webrtc/manager.go +++ /dev/null @@ -1,226 +0,0 @@ -package webrtc - -import ( - "fmt" - "time" - - "github.com/pion/webrtc/v2" - "github.com/pkg/errors" - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" - - "n.eko.moe/neko/internal/config" - "n.eko.moe/neko/internal/event" - "n.eko.moe/neko/internal/gst" - "n.eko.moe/neko/internal/hid" - "n.eko.moe/neko/internal/message" - "n.eko.moe/neko/internal/session" -) - -func New(sessions *session.SessionManager, conf *config.WebRTC) *WebRTCManager { - logger := log.With().Str("module", "webrtc").Logger() - engine := webrtc.MediaEngine{} - engine.RegisterDefaultCodecs() - - setings := webrtc.SettingEngine{ - LoggerFactory: loggerFactory{ - logger: logger, - }, - } - - return &WebRTCManager{ - logger: logger, - engine: engine, - setings: setings, - api: webrtc.NewAPI(webrtc.WithMediaEngine(engine), webrtc.WithSettingEngine(setings)), - cleanup: time.NewTicker(1 * time.Second), - shutdown: make(chan bool), - sessions: sessions, - conf: conf, - config: webrtc.Configuration{ - ICEServers: []webrtc.ICEServer{ - { - URLs: []string{"stun:stun.l.google.com:19302"}, - }, - }, - SDPSemantics: webrtc.SDPSemanticsUnifiedPlanWithFallback, - }, - } -} - -type WebRTCManager struct { - logger zerolog.Logger - engine webrtc.MediaEngine - setings webrtc.SettingEngine - config webrtc.Configuration - sessions *session.SessionManager - api *webrtc.API - video *webrtc.Track - audio *webrtc.Track - videoPipeline *gst.Pipeline - audioPipeline *gst.Pipeline - cleanup *time.Ticker - conf *config.WebRTC - shutdown chan bool -} - -func (m *WebRTCManager) Start() { - - hid.Display(m.conf.Display) - - switch m.conf.VideoCodec { - case "vp8": - if err := m.createVideoTrack(webrtc.DefaultPayloadTypeVP8); err != nil { - m.logger.Panic().Err(err).Msg("unable to start webrtc manager") - } - case "vp9": - if err := m.createVideoTrack(webrtc.DefaultPayloadTypeVP9); err != nil { - m.logger.Panic().Err(err).Msg("unable to start webrtc manager") - } - case "h264": - if err := m.createVideoTrack(webrtc.DefaultPayloadTypeH264); err != nil { - m.logger.Panic().Err(err).Msg("unable to start webrtc manager") - } - default: - m.logger.Panic().Err(errors.Errorf("unknown video codec %s", m.conf.AudioCodec)).Msg("unable to start webrtc manager") - } - - switch m.conf.AudioCodec { - case "opus": - if err := m.createAudioTrack(webrtc.DefaultPayloadTypeOpus); err != nil { - m.logger.Panic().Err(err).Msg("unable to start webrtc manager") - } - case "g722": - if err := m.createAudioTrack(webrtc.DefaultPayloadTypeG722); err != nil { - m.logger.Panic().Err(err).Msg("unable to start webrtc manager") - } - case "pcmu": - if err := m.createAudioTrack(webrtc.DefaultPayloadTypePCMU); err != nil { - m.logger.Panic().Err(err).Msg("unable to start webrtc manager") - } - case "pcma": - if err := m.createAudioTrack(webrtc.DefaultPayloadTypePCMA); err != nil { - m.logger.Panic().Err(err).Msg("unable to start webrtc manager") - } - default: - m.logger.Panic().Err(errors.Errorf("unknown audio codec %s", m.conf.AudioCodec)).Msg("unable to start webrtc manager") - } - - m.videoPipeline.Start() - m.audioPipeline.Start() - - go func() { - defer func() { - m.logger.Info().Msg("shutdown") - }() - - for { - select { - case <-m.shutdown: - return - case <-m.cleanup.C: - hid.Check(time.Second * 10) - } - } - }() - - m.sessions.OnHostCleared(func(id string) { - hid.Reset() - }) - - m.sessions.OnCreated(func(id string, session *session.Session) { - m.logger.Debug().Str("id", id).Msg("session created") - }) - - m.sessions.OnDestroy(func(id string) { - m.logger.Debug().Str("id", id).Msg("session destroyed") - }) - - // TODO: log resolution, bit rate and codec parameters - m.logger.Info(). - Str("video_display", m.conf.Display). - Str("video_codec", m.conf.VideoCodec). - Str("audio_device", m.conf.Device). - Str("audio_codec", m.conf.AudioCodec). - Msgf("webrtc streaming") -} - -func (m *WebRTCManager) Shutdown() error { - m.logger.Info().Msgf("webrtc shutting down") - - m.cleanup.Stop() - m.shutdown <- true - m.videoPipeline.Stop() - m.audioPipeline.Stop() - return nil -} - -func (m *WebRTCManager) CreatePeer(id string, sdp string) error { - session, ok := m.sessions.Get(id) - if !ok { - return fmt.Errorf("invalid session id %s", id) - } - - peer, err := m.api.NewPeerConnection(m.config) - if err != nil { - return err - } - - if _, err := peer.AddTransceiverFromTrack(m.video, webrtc.RtpTransceiverInit{ - Direction: webrtc.RTPTransceiverDirectionSendonly, - }); err != nil { - return err - } - - if _, err := peer.AddTransceiverFromTrack(m.audio, webrtc.RtpTransceiverInit{ - Direction: webrtc.RTPTransceiverDirectionSendonly, - }); err != nil { - return err - } - - peer.SetRemoteDescription(webrtc.SessionDescription{ - SDP: sdp, - Type: webrtc.SDPTypeOffer, - }) - - answer, err := peer.CreateAnswer(nil) - if err != nil { - return err - } - - if err = peer.SetLocalDescription(answer); err != nil { - return err - } - - if err := session.Send(message.Signal{ - Event: event.SIGNAL_ANSWER, - SDP: answer.SDP, - }); err != nil { - return err - } - - peer.OnDataChannel(func(d *webrtc.DataChannel) { - d.OnMessage(func(msg webrtc.DataChannelMessage) { - if err = m.handle(id, msg); err != nil { - m.logger.Warn().Err(err).Msg("data handle failed") - } - }) - }) - - peer.OnConnectionStateChange(func(connectionState webrtc.PeerConnectionState) { - switch connectionState { - case webrtc.PeerConnectionStateDisconnected: - case webrtc.PeerConnectionStateFailed: - m.logger.Info().Str("id", id).Msg("peer disconnected") - m.sessions.Destroy(id) - break - case webrtc.PeerConnectionStateConnected: - m.logger.Info().Str("id", id).Msg("peer connected") - break - } - }) - - m.sessions.SetPeer(id, peer) - - return nil -} diff --git a/server/internal/webrtc/peer.go b/server/internal/webrtc/peer.go new file mode 100644 index 0000000..6ebac33 --- /dev/null +++ b/server/internal/webrtc/peer.go @@ -0,0 +1,44 @@ +package webrtc + +import ( + "github.com/pion/webrtc/v2" + "github.com/pion/webrtc/v2/pkg/media" + "n.eko.moe/neko/internal/types" +) + +type Peer struct { + id string + engine webrtc.MediaEngine + api *webrtc.API + video *webrtc.Track + audio *webrtc.Track + connection *webrtc.PeerConnection +} + +func (peer *Peer) WriteAudioSample(sample types.Sample) error { + if err := peer.audio.WriteSample(media.Sample(sample)); err != nil { + return err + } + return nil +} + +func (peer *Peer) WriteVideoSample(sample types.Sample) error { + if err := peer.video.WriteSample(media.Sample(sample)); err != nil { + return err + } + return nil +} + +func (peer *Peer) WriteData(v interface{}) error { + return nil +} + +func (peer *Peer) Destroy() error { + if peer.connection != nil && peer.connection.ConnectionState() == webrtc.PeerConnectionStateConnected { + if err := peer.connection.Close(); err != nil { + return err + } + } + + return nil +} diff --git a/server/internal/webrtc/tracks.go b/server/internal/webrtc/tracks.go index c38af64..e00b59b 100644 --- a/server/internal/webrtc/tracks.go +++ b/server/internal/webrtc/tracks.go @@ -5,95 +5,36 @@ import ( "math/rand" "github.com/pion/webrtc/v2" - "github.com/pkg/errors" - - "n.eko.moe/neko/internal/gst" ) -func (m *WebRTCManager) createVideoTrack(payloadType uint8) error { - - clockrate := uint32(90000) +func (m *WebRTCManager) createVideoTrack(engine webrtc.MediaEngine) (*webrtc.Track, error) { var codec *webrtc.RTPCodec - switch payloadType { - case webrtc.DefaultPayloadTypeVP8: - codec = webrtc.NewRTPVP8Codec(payloadType, clockrate) - break - case webrtc.DefaultPayloadTypeVP9: - codec = webrtc.NewRTPVP9Codec(payloadType, clockrate) - break - case webrtc.DefaultPayloadTypeH264: - codec = webrtc.NewRTPH264Codec(payloadType, clockrate) - break - default: - return errors.Errorf("unknown video codec %s", payloadType) + for _, videoCodec := range engine.GetCodecsByKind(webrtc.RTPCodecTypeVideo) { + if videoCodec.Name == m.videoPipeline.CodecName { + codec = videoCodec + break + } } - track, err := webrtc.NewTrack(payloadType, rand.Uint32(), "stream", "stream", codec) - if err != nil { - return err + if codec == nil || codec.PayloadType == 0 { + return nil, fmt.Errorf("remote peer does not support %s", m.videoPipeline.CodecName) } - var pipeline *gst.Pipeline - src := fmt.Sprintf("ximagesrc xid=%s show-pointer=true use-damage=false ! video/x-raw,framerate=30/1 ! videoconvert ! queue", m.conf.Display) - switch payloadType { - case webrtc.DefaultPayloadTypeVP8: - pipeline = gst.CreatePipeline(webrtc.VP8, []*webrtc.Track{track}, src) - break - case webrtc.DefaultPayloadTypeVP9: - pipeline = gst.CreatePipeline(webrtc.VP9, []*webrtc.Track{track}, src) - break - case webrtc.DefaultPayloadTypeH264: - pipeline = gst.CreatePipeline(webrtc.H264, []*webrtc.Track{track}, src) - break - } - - m.video = track - m.videoPipeline = pipeline - return nil + return webrtc.NewTrack(codec.PayloadType, rand.Uint32(), "stream", "stream", codec) } -func (m *WebRTCManager) createAudioTrack(payloadType uint8) error { +func (m *WebRTCManager) createAudioTrack(engine webrtc.MediaEngine) (*webrtc.Track, error) { var codec *webrtc.RTPCodec - switch payloadType { - case webrtc.DefaultPayloadTypeOpus: - codec = webrtc.NewRTPOpusCodec(payloadType, 48000) - break - case webrtc.DefaultPayloadTypeG722: - codec = webrtc.NewRTPG722Codec(payloadType, 48000) - break - case webrtc.DefaultPayloadTypePCMU: - codec = webrtc.NewRTPPCMUCodec(payloadType, 8000) - break - case webrtc.DefaultPayloadTypePCMA: - codec = webrtc.NewRTPPCMACodec(payloadType, 8000) - break - default: - return errors.Errorf("unknown audio codec %s", payloadType) + for _, videoCodec := range engine.GetCodecsByKind(webrtc.RTPCodecTypeAudio) { + if videoCodec.Name == m.audioPipeline.CodecName { + codec = videoCodec + break + } } - track, err := webrtc.NewTrack(payloadType, rand.Uint32(), "stream", "stream", codec) - if err != nil { - return err + if codec == nil || codec.PayloadType == 0 { + return nil, fmt.Errorf("remote peer does not support %s", m.audioPipeline.CodecName) } - var pipeline *gst.Pipeline - src := fmt.Sprintf("pulsesrc device=%s ! audioconvert", m.conf.Device) - switch payloadType { - case webrtc.DefaultPayloadTypeOpus: - pipeline = gst.CreatePipeline(webrtc.Opus, []*webrtc.Track{track}, src) - break - case webrtc.DefaultPayloadTypeG722: - pipeline = gst.CreatePipeline(webrtc.G722, []*webrtc.Track{track}, src) - break - case webrtc.DefaultPayloadTypePCMU: - pipeline = gst.CreatePipeline(webrtc.PCMU, []*webrtc.Track{track}, src) - break - case webrtc.DefaultPayloadTypePCMA: - pipeline = gst.CreatePipeline(webrtc.PCMA, []*webrtc.Track{track}, src) - break - } - - m.audio = track - m.audioPipeline = pipeline - return nil + return webrtc.NewTrack(codec.PayloadType, rand.Uint32(), "stream", "stream", codec) } diff --git a/server/internal/webrtc/webrtc.go b/server/internal/webrtc/webrtc.go new file mode 100644 index 0000000..b3e9d3d --- /dev/null +++ b/server/internal/webrtc/webrtc.go @@ -0,0 +1,237 @@ +package webrtc + +import ( + "fmt" + "time" + + "github.com/pion/webrtc/v2" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + + "n.eko.moe/neko/internal/config" + "n.eko.moe/neko/internal/gst" + "n.eko.moe/neko/internal/hid" + "n.eko.moe/neko/internal/types" +) + +func New(sessions types.SessionManager, config *config.WebRTC) *WebRTCManager { + logger := log.With().Str("module", "webrtc").Logger() + setings := webrtc.SettingEngine{ + LoggerFactory: loggerFactory{ + logger: logger, + }, + } + + return &WebRTCManager{ + logger: logger, + setings: setings, + cleanup: time.NewTicker(1 * time.Second), + shutdown: make(chan bool), + sessions: sessions, + config: config, + configuration: &webrtc.Configuration{ + ICEServers: []webrtc.ICEServer{ + { + URLs: []string{"stun:stun.l.google.com:19302"}, + }, + }, + SDPSemantics: webrtc.SDPSemanticsUnifiedPlanWithFallback, + }, + } +} + +type WebRTCManager struct { + logger zerolog.Logger + setings webrtc.SettingEngine + sessions types.SessionManager + videoPipeline *gst.Pipeline + audioPipeline *gst.Pipeline + cleanup *time.Ticker + config *config.WebRTC + shutdown chan bool + configuration *webrtc.Configuration +} + +func (m *WebRTCManager) Start() { + hid.Display(m.config.Display) + + videoPipeline, err := gst.CreatePipeline( + m.config.VideoCodec, + fmt.Sprintf("ximagesrc xid=%s show-pointer=true use-damage=false ! video/x-raw,framerate=30/1 ! videoconvert ! queue", m.config.Display), + ) + + if err != nil { + m.logger.Panic().Err(err).Msg("unable to start webrtc manager") + } + + audioPipeline, err := gst.CreatePipeline( + m.config.AudioCodec, + fmt.Sprintf("pulsesrc device=%s ! audioconvert", m.config.Device), + ) + + if err != nil { + m.logger.Panic().Err(err).Msg("unable to start webrtc manager") + } + + m.videoPipeline = videoPipeline + m.audioPipeline = audioPipeline + + videoPipeline.Start() + audioPipeline.Start() + + go func() { + defer func() { + m.logger.Info().Msg("shutdown") + }() + + for { + select { + case <-m.shutdown: + return + case sample := <-videoPipeline.Sample: + if err := m.sessions.WriteVideoSample(sample); err != nil { + m.logger.Warn().Err(err).Msg("video pipeline failed") + } + case sample := <-audioPipeline.Sample: + if err := m.sessions.WriteAudioSample(sample); err != nil { + m.logger.Warn().Err(err).Msg("audio pipeline failed") + } + case <-m.cleanup.C: + hid.Check(time.Second * 10) + } + } + }() + + m.sessions.OnHostCleared(func(id string) { + hid.Reset() + }) + + m.sessions.OnCreated(func(id string, session types.Session) { + m.logger.Debug().Str("id", id).Msg("session created") + }) + + m.sessions.OnDestroy(func(id string) { + m.logger.Debug().Str("id", id).Msg("session destroyed") + }) + + // TODO: log resolution, bit rate and codec parameters + m.logger.Info(). + Str("video_display", m.config.Display). + Str("video_codec", m.config.VideoCodec). + Str("audio_device", m.config.Device). + Str("audio_codec", m.config.AudioCodec). + Msgf("webrtc streaming") +} + +func (m *WebRTCManager) Shutdown() error { + m.logger.Info().Msgf("webrtc shutting down") + m.videoPipeline.Stop() + m.audioPipeline.Stop() + m.cleanup.Stop() + m.shutdown <- true + return nil +} + +func (m *WebRTCManager) CreatePeer(id string, sdp string) (string, types.Peer, error) { + // create SessionDescription + description := webrtc.SessionDescription{ + SDP: sdp, + Type: webrtc.SDPTypeOffer, + } + + // create MediaEngine based off sdp + engine := webrtc.MediaEngine{} + engine.PopulateFromSDP(description) + + // create API with MediaEngine and SettingEngine + api := webrtc.NewAPI(webrtc.WithMediaEngine(engine), webrtc.WithSettingEngine(m.setings)) + + // create new peer connection + connection, err := api.NewPeerConnection(*m.configuration) + if err != nil { + return "", nil, err + } + + // create video track + video, err := m.createVideoTrack(engine) + if err != nil { + return "", nil, err + } + + videoTransceiver, err := connection.AddTransceiverFromTrack(video, webrtc.RtpTransceiverInit{ + Direction: webrtc.RTPTransceiverDirectionSendonly, + }) + + if err != nil { + return "", nil, err + } + + // create audio track + audio, err := m.createAudioTrack(engine) + if err != nil { + return "", nil, err + } + + audioTransceiver, err := connection.AddTransceiverFromTrack(audio, webrtc.RtpTransceiverInit{ + Direction: webrtc.RTPTransceiverDirectionSendonly, + }) + + if err != nil { + return "", nil, err + } + + // clear the Transceiver bufers + go func() { + for { + if _, err := audioTransceiver.Sender.ReadRTCP(); err != nil { + return + } + + if _, err = videoTransceiver.Sender.ReadRTCP(); err != nil { + return + } + } + }() + + // set remote description + connection.SetRemoteDescription(description) + + answer, err := connection.CreateAnswer(nil) + if err != nil { + return "", nil, err + } + + if err = connection.SetLocalDescription(answer); err != nil { + return "", nil, err + } + + connection.OnDataChannel(func(d *webrtc.DataChannel) { + d.OnMessage(func(msg webrtc.DataChannelMessage) { + if err = m.handle(id, msg); err != nil { + m.logger.Warn().Err(err).Msg("data handle failed") + } + }) + }) + + connection.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + switch state { + case webrtc.PeerConnectionStateDisconnected: + case webrtc.PeerConnectionStateFailed: + m.logger.Info().Str("id", id).Msg("peer disconnected") + m.sessions.Destroy(id) + break + case webrtc.PeerConnectionStateConnected: + m.logger.Info().Str("id", id).Msg("peer connected") + break + } + }) + + return answer.SDP, &Peer{ + id: id, + api: api, + engine: engine, + video: video, + audio: audio, + connection: connection, + }, nil +} diff --git a/server/internal/websocket/admin.go b/server/internal/websocket/admin.go index d40b4c1..35a74c5 100644 --- a/server/internal/websocket/admin.go +++ b/server/internal/websocket/admin.go @@ -3,13 +3,13 @@ package websocket import ( "strings" - "n.eko.moe/neko/internal/event" - "n.eko.moe/neko/internal/message" - "n.eko.moe/neko/internal/session" + "n.eko.moe/neko/internal/types" + "n.eko.moe/neko/internal/types/event" + "n.eko.moe/neko/internal/types/message" ) -func (h *MessageHandler) adminLock(id string, session *session.Session) error { - if !session.Admin { +func (h *MessageHandler) adminLock(id string, session types.Session) error { + if !session.Admin() { h.logger.Debug().Msg("user not admin") return nil } @@ -33,8 +33,8 @@ func (h *MessageHandler) adminLock(id string, session *session.Session) error { return nil } -func (h *MessageHandler) adminUnlock(id string, session *session.Session) error { - if !session.Admin { +func (h *MessageHandler) adminUnlock(id string, session types.Session) error { + if !session.Admin() { h.logger.Debug().Msg("user not admin") return nil } @@ -58,8 +58,8 @@ func (h *MessageHandler) adminUnlock(id string, session *session.Session) error return nil } -func (h *MessageHandler) adminControl(id string, session *session.Session) error { - if !session.Admin { +func (h *MessageHandler) adminControl(id string, session types.Session) error { + if !session.Admin() { h.logger.Debug().Msg("user not admin") return nil } @@ -73,7 +73,7 @@ func (h *MessageHandler) adminControl(id string, session *session.Session) error message.AdminTarget{ Event: event.ADMIN_CONTROL, ID: id, - Target: host.ID, + Target: host.ID(), }, nil); err != nil { h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_CONTROL) return err @@ -92,8 +92,8 @@ func (h *MessageHandler) adminControl(id string, session *session.Session) error return nil } -func (h *MessageHandler) adminRelease(id string, session *session.Session) error { - if !session.Admin { +func (h *MessageHandler) adminRelease(id string, session types.Session) error { + if !session.Admin() { h.logger.Debug().Msg("user not admin") return nil } @@ -107,7 +107,7 @@ func (h *MessageHandler) adminRelease(id string, session *session.Session) error message.AdminTarget{ Event: event.ADMIN_RELEASE, ID: id, - Target: host.ID, + Target: host.ID(), }, nil); err != nil { h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_RELEASE) return err @@ -126,8 +126,8 @@ func (h *MessageHandler) adminRelease(id string, session *session.Session) error return nil } -func (h *MessageHandler) adminGive(id string, session *session.Session, payload *message.Admin) error { - if !session.Admin { +func (h *MessageHandler) adminGive(id string, session types.Session, payload *message.Admin) error { + if !session.Admin() { h.logger.Debug().Msg("user not admin") return nil } @@ -154,8 +154,8 @@ func (h *MessageHandler) adminGive(id string, session *session.Session, payload return nil } -func (h *MessageHandler) adminMute(id string, session *session.Session, payload *message.Admin) error { - if !session.Admin { +func (h *MessageHandler) adminMute(id string, session types.Session, payload *message.Admin) error { + if !session.Admin() { h.logger.Debug().Msg("user not admin") return nil } @@ -166,17 +166,17 @@ func (h *MessageHandler) adminMute(id string, session *session.Session, payload return nil } - if target.Admin { + if target.Admin() { h.logger.Debug().Msg("target is an admin, baling") return nil } - target.Muted = true + target.SetMuted(true) if err := h.sessions.Brodcast( message.AdminTarget{ Event: event.ADMIN_MUTE, - Target: target.ID, + Target: target.ID(), ID: id, }, nil); err != nil { h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_UNMUTE) @@ -186,8 +186,8 @@ func (h *MessageHandler) adminMute(id string, session *session.Session, payload return nil } -func (h *MessageHandler) adminUnmute(id string, session *session.Session, payload *message.Admin) error { - if !session.Admin { +func (h *MessageHandler) adminUnmute(id string, session types.Session, payload *message.Admin) error { + if !session.Admin() { h.logger.Debug().Msg("user not admin") return nil } @@ -198,12 +198,12 @@ func (h *MessageHandler) adminUnmute(id string, session *session.Session, payloa return nil } - target.Muted = false + target.SetMuted(false) if err := h.sessions.Brodcast( message.AdminTarget{ Event: event.ADMIN_UNMUTE, - Target: target.ID, + Target: target.ID(), ID: id, }, nil); err != nil { h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_UNMUTE) @@ -213,8 +213,8 @@ func (h *MessageHandler) adminUnmute(id string, session *session.Session, payloa return nil } -func (h *MessageHandler) adminKick(id string, session *session.Session, payload *message.Admin) error { - if !session.Admin { +func (h *MessageHandler) adminKick(id string, session types.Session, payload *message.Admin) error { + if !session.Admin() { h.logger.Debug().Msg("user not admin") return nil } @@ -225,22 +225,19 @@ func (h *MessageHandler) adminKick(id string, session *session.Session, payload return nil } - if target.Admin { + if target.Admin() { h.logger.Debug().Msg("target is an admin, baling") return nil } - if err := target.Kick(message.Disconnect{ - Event: event.SYSTEM_DISCONNECT, - Message: "You have been kicked", - }); err != nil { + if err := target.Kick("You have been kicked"); err != nil { return err } if err := h.sessions.Brodcast( message.AdminTarget{ Event: event.ADMIN_KICK, - Target: target.ID, + Target: target.ID(), ID: id, }, []string{payload.ID}); err != nil { h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_KICK) @@ -250,8 +247,8 @@ func (h *MessageHandler) adminKick(id string, session *session.Session, payload return nil } -func (h *MessageHandler) adminBan(id string, session *session.Session, payload *message.Admin) error { - if !session.Admin { +func (h *MessageHandler) adminBan(id string, session types.Session, payload *message.Admin) error { + if !session.Admin() { h.logger.Debug().Msg("user not admin") return nil } @@ -262,12 +259,12 @@ func (h *MessageHandler) adminBan(id string, session *session.Session, payload * return nil } - if target.Admin { + if target.Admin() { h.logger.Debug().Msg("target is an admin, baling") return nil } - remote := target.RemoteAddr() + remote := target.Address() if remote == nil { h.logger.Debug().Msg("no remote address, baling") return nil @@ -283,17 +280,14 @@ func (h *MessageHandler) adminBan(id string, session *session.Session, payload * h.banned[address[0]] = true - if err := target.Kick(message.Disconnect{ - Event: event.SYSTEM_DISCONNECT, - Message: "You have been banned", - }); err != nil { + if err := target.Kick("You have been banned"); err != nil { return err } if err := h.sessions.Brodcast( message.AdminTarget{ Event: event.ADMIN_BAN, - Target: target.ID, + Target: target.ID(), ID: id, }, []string{payload.ID}); err != nil { h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_BAN) diff --git a/server/internal/websocket/chat.go b/server/internal/websocket/chat.go index 26316f8..78dd38d 100644 --- a/server/internal/websocket/chat.go +++ b/server/internal/websocket/chat.go @@ -1,13 +1,13 @@ package websocket import ( - "n.eko.moe/neko/internal/event" - "n.eko.moe/neko/internal/message" - "n.eko.moe/neko/internal/session" + "n.eko.moe/neko/internal/types" + "n.eko.moe/neko/internal/types/event" + "n.eko.moe/neko/internal/types/message" ) -func (h *MessageHandler) chat(id string, session *session.Session, payload *message.ChatRecieve) error { - if session.Muted { +func (h *MessageHandler) chat(id string, session types.Session, payload *message.ChatRecieve) error { + if session.Muted() { return nil } @@ -23,8 +23,8 @@ func (h *MessageHandler) chat(id string, session *session.Session, payload *mess return nil } -func (h *MessageHandler) chatEmote(id string, session *session.Session, payload *message.EmoteRecieve) error { - if session.Muted { +func (h *MessageHandler) chatEmote(id string, session types.Session, payload *message.EmoteRecieve) error { + if session.Muted() { return nil } diff --git a/server/internal/websocket/control.go b/server/internal/websocket/control.go index fbbc79a..ca7c6ce 100644 --- a/server/internal/websocket/control.go +++ b/server/internal/websocket/control.go @@ -1,12 +1,12 @@ package websocket import ( - "n.eko.moe/neko/internal/event" - "n.eko.moe/neko/internal/message" - "n.eko.moe/neko/internal/session" + "n.eko.moe/neko/internal/types" + "n.eko.moe/neko/internal/types/event" + "n.eko.moe/neko/internal/types/message" ) -func (h *MessageHandler) controlRelease(id string, session *session.Session) error { +func (h *MessageHandler) controlRelease(id string, session types.Session) error { // check if session is host if !h.sessions.IsHost(id) { @@ -31,7 +31,7 @@ func (h *MessageHandler) controlRelease(id string, session *session.Session) err return nil } -func (h *MessageHandler) controlRequest(id string, session *session.Session) error { +func (h *MessageHandler) controlRequest(id string, session types.Session) error { // check for host if !h.sessions.HasHost() { // set host @@ -57,7 +57,7 @@ func (h *MessageHandler) controlRequest(id string, session *session.Session) err // tell session there is a host if err := session.Send(message.Control{ Event: event.CONTROL_REQUEST, - ID: host.ID, + ID: host.ID(), }); err != nil { h.logger.Warn().Err(err).Str("id", id).Msgf("sending event %s has failed", event.CONTROL_REQUEST) return err @@ -68,7 +68,7 @@ func (h *MessageHandler) controlRequest(id string, session *session.Session) err Event: event.CONTROL_REQUESTING, ID: id, }); err != nil { - h.logger.Warn().Err(err).Str("id", host.ID).Msgf("sending event %s has failed", event.CONTROL_REQUESTING) + h.logger.Warn().Err(err).Str("id", host.ID()).Msgf("sending event %s has failed", event.CONTROL_REQUESTING) return err } } @@ -76,7 +76,7 @@ func (h *MessageHandler) controlRequest(id string, session *session.Session) err return nil } -func (h *MessageHandler) controlGive(id string, session *session.Session, payload *message.Control) error { +func (h *MessageHandler) controlGive(id string, session types.Session, payload *message.Control) error { // check if session is host if !h.sessions.IsHost(id) { h.logger.Debug().Str("id", id).Msg("is not the host") diff --git a/server/internal/websocket/handler.go b/server/internal/websocket/handler.go index 048525a..ba10c48 100644 --- a/server/internal/websocket/handler.go +++ b/server/internal/websocket/handler.go @@ -1,235 +1,142 @@ package websocket import ( - "fmt" - "net/http" - "time" + "encoding/json" - "github.com/gorilla/websocket" + "github.com/pkg/errors" "github.com/rs/zerolog" - "github.com/rs/zerolog/log" - "n.eko.moe/neko/internal/config" - "n.eko.moe/neko/internal/event" - "n.eko.moe/neko/internal/message" - "n.eko.moe/neko/internal/session" + "n.eko.moe/neko/internal/types" + "n.eko.moe/neko/internal/types/event" + "n.eko.moe/neko/internal/types/message" "n.eko.moe/neko/internal/utils" - "n.eko.moe/neko/internal/webrtc" ) -func New(sessions *session.SessionManager, webrtc *webrtc.WebRTCManager, conf *config.WebSocket) *WebSocketHandler { - logger := log.With().Str("module", "websocket").Logger() - - return &WebSocketHandler{ - logger: logger, - conf: conf, - sessions: sessions, - upgrader: websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, - }, - handler: &MessageHandler{ - logger: logger.With().Str("subsystem", "handler").Logger(), - sessions: sessions, - webrtc: webrtc, - banned: make(map[string]bool), - locked: false, - }, - } -} - -// Send pings to peer with this period. Must be less than pongWait. -const pingPeriod = 60 * time.Second - -type WebSocketHandler struct { +type MessageHandler struct { logger zerolog.Logger - upgrader websocket.Upgrader - handler *MessageHandler - conf *config.WebSocket - sessions *session.SessionManager - shutdown chan bool + sessions types.SessionManager + webrtc types.WebRTCManager + banned map[string]bool + locked bool } -func (ws *WebSocketHandler) Start() error { - - go func() { - defer func() { - ws.logger.Info().Msg("shutdown") - }() - - for { - select { - case <-ws.shutdown: - return - } +func (h *MessageHandler) Connected(id string, socket *WebSocket) (bool, string, error) { + address := socket.Address() + if address == nil { + h.logger.Debug().Msg("no remote address, baling") + } else { + ok, banned := h.banned[*address] + if ok && banned { + h.logger.Debug().Str("address", *address).Msg("banned") + return false, "This IP has been banned", nil } - }() + } - ws.sessions.OnCreated(func(id string, session *session.Session) { - if err := ws.handler.SessionCreated(id, session); err != nil { - ws.logger.Warn().Str("id", id).Err(err).Msg("session created with and error") - } else { - ws.logger.Debug().Str("id", id).Msg("session created") - } - }) + if h.locked { + h.logger.Debug().Msg("server locked") + return false, "Server is currently locked", nil + } - ws.sessions.OnConnected(func(id string, session *session.Session) { - if err := ws.handler.SessionConnected(id, session); err != nil { - ws.logger.Warn().Str("id", id).Err(err).Msg("session connected with and error") - } else { - ws.logger.Debug().Str("id", id).Msg("session connected") - } - }) - - ws.sessions.OnDestroy(func(id string) { - if err := ws.handler.SessionDestroyed(id); err != nil { - ws.logger.Warn().Str("id", id).Err(err).Msg("session destroyed with and error") - } else { - ws.logger.Debug().Str("id", id).Msg("session destroyed") - } - }) - - return nil + return true, "", nil } -func (ws *WebSocketHandler) Shutdown() error { - ws.shutdown <- true - return nil +func (h *MessageHandler) Disconnected(id string) error { + return h.sessions.Destroy(id) } -func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) error { - ws.logger.Debug().Msg("attempting to upgrade connection") - - socket, err := ws.upgrader.Upgrade(w, r, nil) - if err != nil { - ws.logger.Error().Err(err).Msg("failed to upgrade connection") - return err - } - - id, admin, err := ws.authenticate(r) - if err != nil { - ws.logger.Warn().Err(err).Msg("authenticatetion failed") - - if err = socket.WriteJSON(message.Disconnect{ - Event: event.SYSTEM_DISCONNECT, - Message: "invalid password", - }); err != nil { - ws.logger.Error().Err(err).Msg("failed to send disconnect") - } - - if err = socket.Close(); err != nil { - return err - } - return nil - } - - ok, reason, err := ws.handler.SocketConnected(id, socket) - if err != nil { - ws.logger.Error().Err(err).Msg("connection failed") +func (h *MessageHandler) Message(id string, raw []byte) error { + header := message.Message{} + if err := json.Unmarshal(raw, &header); err != nil { return err } + session, ok := h.sessions.Get(id) if !ok { - if err = socket.WriteJSON(message.Disconnect{ - Event: event.SYSTEM_DISCONNECT, - Message: reason, - }); err != nil { - ws.logger.Error().Err(err).Msg("failed to send disconnect") - } - - if err = socket.Close(); err != nil { - return err - } - - return nil + errors.Errorf("unknown session id %s", id) } - ws.sessions.New(id, admin, socket) + switch header.Event { + // Signal Events + case event.SIGNAL_PROVIDE: + payload := &message.Signal{} + return errors.Wrapf( + utils.Unmarshal(payload, raw, func() error { + return h.createPeer(id, session, payload) + }), "%s failed", header.Event) + // Identity Events + case event.IDENTITY_DETAILS: + payload := &message.IdentityDetails{} + return errors.Wrapf( + utils.Unmarshal(payload, raw, func() error { + return h.identityDetails(id, session, payload) + }), "%s failed", header.Event) - ws.logger. - Debug(). - Str("session", id). - Str("address", socket.RemoteAddr().String()). - Msg("new connection created") + // Control Events + case event.CONTROL_RELEASE: + return errors.Wrapf(h.controlRelease(id, session), "%s failed", header.Event) + case event.CONTROL_REQUEST: + return errors.Wrapf(h.controlRequest(id, session), "%s failed", header.Event) + case event.CONTROL_GIVE: + payload := &message.Control{} + return errors.Wrapf( + utils.Unmarshal(payload, raw, func() error { + return h.controlGive(id, session, payload) + }), "%s failed", header.Event) - defer func() { - ws.logger. - Debug(). - Str("session", id). - Str("address", socket.RemoteAddr().String()). - Msg("session ended") - }() + // Chat Events + case event.CHAT_MESSAGE: + payload := &message.ChatRecieve{} + return errors.Wrapf( + utils.Unmarshal(payload, raw, func() error { + return h.chat(id, session, payload) + }), "%s failed", header.Event) + case event.CHAT_EMOTE: + payload := &message.EmoteRecieve{} + return errors.Wrapf( + utils.Unmarshal(payload, raw, func() error { + return h.chatEmote(id, session, payload) + }), "%s failed", header.Event) - ws.handle(socket, id) - return nil -} - -func (ws *WebSocketHandler) authenticate(r *http.Request) (string, bool, error) { - id, err := utils.NewUID(32) - if err != nil { - return "", false, err - } - - passwords, ok := r.URL.Query()["password"] - if !ok || len(passwords[0]) < 1 { - return "", false, fmt.Errorf("no password provided") - } - - if passwords[0] == ws.conf.AdminPassword { - return id, true, nil - } - - if passwords[0] == ws.conf.Password { - return id, false, nil - } - - return "", false, fmt.Errorf("invalid password: %s", passwords[0]) -} - -func (ws *WebSocketHandler) handle(socket *websocket.Conn, id string) { - bytes := make(chan []byte) - cancel := make(chan struct{}) - ticker := time.NewTicker(pingPeriod) - - go func() { - defer func() { - ticker.Stop() - ws.logger.Debug().Str("address", socket.RemoteAddr().String()).Msg("handle socket ending") - ws.handler.SocketDisconnected(id) - }() - - for { - _, raw, err := socket.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - ws.logger.Warn().Err(err).Msg("read message error") - } else { - ws.logger.Debug().Err(err).Msg("read message error") - } - close(cancel) - break - } - bytes <- raw - } - }() - - for { - select { - case raw := <-bytes: - ws.logger.Debug(). - Str("session", id). - Str("raw", string(raw)). - Msg("recieved message from client") - if err := ws.handler.Message(id, raw); err != nil { - ws.logger.Error().Err(err).Msg("message handler has failed") - } - case <-cancel: - return - case _ = <-ticker.C: - if err := socket.WriteMessage(websocket.PingMessage, nil); err != nil { - return - } - } + // Admin Events + case event.ADMIN_LOCK: + return errors.Wrapf(h.adminLock(id, session), "%s failed", header.Event) + case event.ADMIN_UNLOCK: + return errors.Wrapf(h.adminUnlock(id, session), "%s failed", header.Event) + case event.ADMIN_CONTROL: + return errors.Wrapf(h.adminControl(id, session), "%s failed", header.Event) + case event.ADMIN_RELEASE: + return errors.Wrapf(h.adminRelease(id, session), "%s failed", header.Event) + case event.ADMIN_GIVE: + payload := &message.Admin{} + return errors.Wrapf( + utils.Unmarshal(payload, raw, func() error { + return h.adminGive(id, session, payload) + }), "%s failed", header.Event) + case event.ADMIN_BAN: + payload := &message.Admin{} + return errors.Wrapf( + utils.Unmarshal(payload, raw, func() error { + return h.adminBan(id, session, payload) + }), "%s failed", header.Event) + case event.ADMIN_KICK: + payload := &message.Admin{} + return errors.Wrapf( + utils.Unmarshal(payload, raw, func() error { + return h.adminKick(id, session, payload) + }), "%s failed", header.Event) + case event.ADMIN_MUTE: + payload := &message.Admin{} + return errors.Wrapf( + utils.Unmarshal(payload, raw, func() error { + return h.adminMute(id, session, payload) + }), "%s failed", header.Event) + case event.ADMIN_UNMUTE: + payload := &message.Admin{} + return errors.Wrapf( + utils.Unmarshal(payload, raw, func() error { + return h.adminUnmute(id, session, payload) + }), "%s failed", header.Event) + default: + return errors.Errorf("unknown message event %s", header.Event) } } diff --git a/server/internal/websocket/identity.go b/server/internal/websocket/identity.go index 8d7114f..a73cf8f 100644 --- a/server/internal/websocket/identity.go +++ b/server/internal/websocket/identity.go @@ -1,13 +1,34 @@ package websocket import ( - "n.eko.moe/neko/internal/message" - "n.eko.moe/neko/internal/session" + "n.eko.moe/neko/internal/types" + "n.eko.moe/neko/internal/types/event" + "n.eko.moe/neko/internal/types/message" ) -func (h *MessageHandler) identityDetails(id string, session *session.Session, payload *message.IdentityDetails) error { - if _, err := h.sessions.SetName(id, payload.Username); err != nil { +func (h *MessageHandler) identityDetails(id string, session types.Session, payload *message.IdentityDetails) error { + if err := session.SetName(payload.Username); err != nil { return err } return nil } + +func (h *MessageHandler) createPeer(id string, session types.Session, payload *message.Signal) error { + sdp, peer, err := h.webrtc.CreatePeer(id, payload.SDP) + if err != nil { + return err + } + + if err := session.SetPeer(peer); err != nil { + return err + } + + if err := session.Send(message.Signal{ + Event: event.SIGNAL_ANSWER, + SDP: sdp, + }); err != nil { + return err + } + + return nil +} diff --git a/server/internal/websocket/messages.go b/server/internal/websocket/messages.go deleted file mode 100644 index a3a7b8e..0000000 --- a/server/internal/websocket/messages.go +++ /dev/null @@ -1,149 +0,0 @@ -package websocket - -import ( - "encoding/json" - "strings" - - "github.com/gorilla/websocket" - "github.com/pkg/errors" - "github.com/rs/zerolog" - - "n.eko.moe/neko/internal/event" - "n.eko.moe/neko/internal/message" - "n.eko.moe/neko/internal/session" - "n.eko.moe/neko/internal/utils" - "n.eko.moe/neko/internal/webrtc" -) - -type MessageHandler struct { - logger zerolog.Logger - sessions *session.SessionManager - webrtc *webrtc.WebRTCManager - banned map[string]bool - locked bool -} - -func (h *MessageHandler) SocketConnected(id string, socket *websocket.Conn) (bool, string, error) { - remote := socket.RemoteAddr().String() - if remote != "" { - address := strings.SplitN(remote, ":", -1) - if len(address[0]) < 1 { - h.logger.Debug().Str("address", remote).Msg("no remote address, baling") - } else { - - ok, banned := h.banned[address[0]] - if ok && banned { - h.logger.Debug().Str("address", remote).Msg("banned") - return false, "This IP has been banned", nil - } - } - } - - if h.locked { - h.logger.Debug().Str("address", remote).Msg("locked") - return false, "Server is currently locked", nil - } - return true, "", nil -} - -func (h *MessageHandler) SocketDisconnected(id string) error { - return h.sessions.Destroy(id) -} - -func (h *MessageHandler) Message(id string, raw []byte) error { - header := message.Message{} - if err := json.Unmarshal(raw, &header); err != nil { - return err - } - - session, ok := h.sessions.Get(id) - if !ok { - errors.Errorf("unknown session id %s", id) - } - - switch header.Event { - // Signal Events - case event.SIGNAL_PROVIDE: - payload := message.Signal{} - return errors.Wrapf( - utils.Unmarshal(&payload, raw, func() error { - return h.webrtc.CreatePeer(id, payload.SDP) - }), "%s failed", header.Event) - - // Identity Events - case event.IDENTITY_DETAILS: - payload := &message.IdentityDetails{} - return errors.Wrapf( - utils.Unmarshal(payload, raw, func() error { - return h.identityDetails(id, session, payload) - }), "%s failed", header.Event) - - // Control Events - case event.CONTROL_RELEASE: - return errors.Wrapf(h.controlRelease(id, session), "%s failed", header.Event) - case event.CONTROL_REQUEST: - return errors.Wrapf(h.controlRequest(id, session), "%s failed", header.Event) - case event.CONTROL_GIVE: - payload := &message.Control{} - return errors.Wrapf( - utils.Unmarshal(payload, raw, func() error { - return h.controlGive(id, session, payload) - }), "%s failed", header.Event) - - // Chat Events - case event.CHAT_MESSAGE: - payload := &message.ChatRecieve{} - return errors.Wrapf( - utils.Unmarshal(payload, raw, func() error { - return h.chat(id, session, payload) - }), "%s failed", header.Event) - case event.CHAT_EMOTE: - payload := &message.EmoteRecieve{} - return errors.Wrapf( - utils.Unmarshal(payload, raw, func() error { - return h.chatEmote(id, session, payload) - }), "%s failed", header.Event) - - // Admin Events - case event.ADMIN_LOCK: - return errors.Wrapf(h.adminLock(id, session), "%s failed", header.Event) - case event.ADMIN_UNLOCK: - return errors.Wrapf(h.adminUnlock(id, session), "%s failed", header.Event) - case event.ADMIN_CONTROL: - return errors.Wrapf(h.adminControl(id, session), "%s failed", header.Event) - case event.ADMIN_RELEASE: - return errors.Wrapf(h.adminRelease(id, session), "%s failed", header.Event) - case event.ADMIN_GIVE: - payload := &message.Admin{} - return errors.Wrapf( - utils.Unmarshal(payload, raw, func() error { - return h.adminGive(id, session, payload) - }), "%s failed", header.Event) - case event.ADMIN_BAN: - payload := &message.Admin{} - return errors.Wrapf( - utils.Unmarshal(payload, raw, func() error { - return h.adminBan(id, session, payload) - }), "%s failed", header.Event) - case event.ADMIN_KICK: - payload := &message.Admin{} - return errors.Wrapf( - utils.Unmarshal(payload, raw, func() error { - return h.adminKick(id, session, payload) - }), "%s failed", header.Event) - case event.ADMIN_MUTE: - payload := &message.Admin{} - return errors.Wrapf( - utils.Unmarshal(payload, raw, func() error { - return h.adminMute(id, session, payload) - }), "%s failed", header.Event) - case event.ADMIN_UNMUTE: - payload := &message.Admin{} - return errors.Wrapf( - utils.Unmarshal(payload, raw, func() error { - return h.adminUnmute(id, session, payload) - }), "%s failed", header.Event) - default: - return errors.Errorf("unknown message event %s", header.Event) - } -} diff --git a/server/internal/websocket/session.go b/server/internal/websocket/session.go index ae2adfc..dc30792 100644 --- a/server/internal/websocket/session.go +++ b/server/internal/websocket/session.go @@ -1,12 +1,12 @@ package websocket import ( - "n.eko.moe/neko/internal/event" - "n.eko.moe/neko/internal/message" - "n.eko.moe/neko/internal/session" + "n.eko.moe/neko/internal/types" + "n.eko.moe/neko/internal/types/event" + "n.eko.moe/neko/internal/types/message" ) -func (h *MessageHandler) SessionCreated(id string, session *session.Session) error { +func (h *MessageHandler) SessionCreated(id string, session types.Session) error { if err := session.Send(message.Identity{ Event: event.IDENTITY_PROVIDE, ID: id, @@ -17,11 +17,11 @@ func (h *MessageHandler) SessionCreated(id string, session *session.Session) err return nil } -func (h *MessageHandler) SessionConnected(id string, session *session.Session) error { +func (h *MessageHandler) SessionConnected(id string, session types.Session) error { // send list of members to session if err := session.Send(message.MembersList{ Event: event.MEMBER_LIST, - Memebers: h.sessions.GetConnected(), + Memebers: h.sessions.Members(), }); err != nil { h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.MEMBER_LIST) return err @@ -32,7 +32,7 @@ func (h *MessageHandler) SessionConnected(id string, session *session.Session) e if ok { if err := session.Send(message.Control{ Event: event.CONTROL_LOCKED, - ID: host.ID, + ID: host.ID(), }); err != nil { h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.CONTROL_LOCKED) return err @@ -42,8 +42,8 @@ func (h *MessageHandler) SessionConnected(id string, session *session.Session) e // let everyone know there is a new session if err := h.sessions.Brodcast( message.Member{ - Event: event.MEMBER_CONNECTED, - Session: session, + Event: event.MEMBER_CONNECTED, + Member: session.Member(), }, nil); err != nil { h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.CONTROL_RELEASE) return err diff --git a/server/internal/websocket/socket.go b/server/internal/websocket/socket.go new file mode 100644 index 0000000..5af4371 --- /dev/null +++ b/server/internal/websocket/socket.go @@ -0,0 +1,37 @@ +package websocket + +import ( + "strings" + + "github.com/gorilla/websocket" +) + +type WebSocket struct { + id string + connection *websocket.Conn +} + +func (socket *WebSocket) Address() *string { + remote := socket.connection.RemoteAddr() + address := strings.SplitN(remote.String(), ":", -1) + if len(address[0]) < 1 { + return nil + } + return &address[0] +} + +func (socket *WebSocket) Send(v interface{}) error { + if socket.connection == nil { + return nil + } + + return socket.connection.WriteJSON(v) +} + +func (socket *WebSocket) Destroy() error { + if socket.connection == nil { + return nil + } + + return socket.connection.Close() +} diff --git a/server/internal/websocket/websocket.go b/server/internal/websocket/websocket.go new file mode 100644 index 0000000..339d6a9 --- /dev/null +++ b/server/internal/websocket/websocket.go @@ -0,0 +1,239 @@ +package websocket + +import ( + "fmt" + "net/http" + "time" + + "github.com/gorilla/websocket" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + + "n.eko.moe/neko/internal/config" + "n.eko.moe/neko/internal/types" + "n.eko.moe/neko/internal/types/event" + "n.eko.moe/neko/internal/types/message" + "n.eko.moe/neko/internal/utils" +) + +func New(sessions types.SessionManager, webrtc types.WebRTCManager, conf *config.WebSocket) *WebSocketHandler { + logger := log.With().Str("module", "websocket").Logger() + + return &WebSocketHandler{ + logger: logger, + conf: conf, + sessions: sessions, + upgrader: websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + }, + handler: &MessageHandler{ + logger: logger.With().Str("subsystem", "handler").Logger(), + sessions: sessions, + webrtc: webrtc, + banned: make(map[string]bool), + locked: false, + }, + } +} + +// Send pings to peer with this period. Must be less than pongWait. +const pingPeriod = 60 * time.Second + +type WebSocketHandler struct { + logger zerolog.Logger + upgrader websocket.Upgrader + sessions types.SessionManager + conf *config.WebSocket + handler *MessageHandler + shutdown chan bool +} + +func (ws *WebSocketHandler) Start() error { + + go func() { + defer func() { + ws.logger.Info().Msg("shutdown") + }() + + for { + select { + case <-ws.shutdown: + return + } + } + }() + + ws.sessions.OnCreated(func(id string, session types.Session) { + if err := ws.handler.SessionCreated(id, session); err != nil { + ws.logger.Warn().Str("id", id).Err(err).Msg("session created with and error") + } else { + ws.logger.Debug().Str("id", id).Msg("session created") + } + }) + + ws.sessions.OnConnected(func(id string, session types.Session) { + if err := ws.handler.SessionConnected(id, session); err != nil { + ws.logger.Warn().Str("id", id).Err(err).Msg("session connected with and error") + } else { + ws.logger.Debug().Str("id", id).Msg("session connected") + } + }) + + ws.sessions.OnDestroy(func(id string) { + if err := ws.handler.SessionDestroyed(id); err != nil { + ws.logger.Warn().Str("id", id).Err(err).Msg("session destroyed with and error") + } else { + ws.logger.Debug().Str("id", id).Msg("session destroyed") + } + }) + + return nil +} + +func (ws *WebSocketHandler) Shutdown() error { + ws.shutdown <- true + return nil +} + +func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) error { + ws.logger.Debug().Msg("attempting to upgrade connection") + + connection, err := ws.upgrader.Upgrade(w, r, nil) + if err != nil { + ws.logger.Error().Err(err).Msg("failed to upgrade connection") + return err + } + + id, admin, err := ws.authenticate(r) + if err != nil { + ws.logger.Warn().Err(err).Msg("authenticatetion failed") + + if err = connection.WriteJSON(message.Disconnect{ + Event: event.SYSTEM_DISCONNECT, + Message: "invalid password", + }); err != nil { + ws.logger.Error().Err(err).Msg("failed to send disconnect") + } + + if err = connection.Close(); err != nil { + return err + } + return nil + } + + socket := &WebSocket{ + id: id, + connection: connection, + } + + ok, reason, err := ws.handler.Connected(id, socket) + if err != nil { + ws.logger.Error().Err(err).Msg("connection failed") + return err + } + + if !ok { + if err = connection.WriteJSON(message.Disconnect{ + Event: event.SYSTEM_DISCONNECT, + Message: reason, + }); err != nil { + ws.logger.Error().Err(err).Msg("failed to send disconnect") + } + + if err = connection.Close(); err != nil { + return err + } + + return nil + } + + ws.sessions.New(id, admin, socket) + + ws.logger. + Debug(). + Str("session", id). + Str("address", connection.RemoteAddr().String()). + Msg("new connection created") + + defer func() { + ws.logger. + Debug(). + Str("session", id). + Str("address", connection.RemoteAddr().String()). + Msg("session ended") + }() + + ws.handle(connection, id) + return nil +} + +func (ws *WebSocketHandler) authenticate(r *http.Request) (string, bool, error) { + id, err := utils.NewUID(32) + if err != nil { + return "", false, err + } + + passwords, ok := r.URL.Query()["password"] + if !ok || len(passwords[0]) < 1 { + return "", false, fmt.Errorf("no password provided") + } + + if passwords[0] == ws.conf.AdminPassword { + return id, true, nil + } + + if passwords[0] == ws.conf.Password { + return id, false, nil + } + + return "", false, fmt.Errorf("invalid password: %s", passwords[0]) +} + +func (ws *WebSocketHandler) handle(connection *websocket.Conn, id string) { + bytes := make(chan []byte) + cancel := make(chan struct{}) + ticker := time.NewTicker(pingPeriod) + + go func() { + defer func() { + ticker.Stop() + ws.logger.Debug().Str("address", connection.RemoteAddr().String()).Msg("handle socket ending") + ws.handler.Disconnected(id) + }() + + for { + _, raw, err := connection.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + ws.logger.Warn().Err(err).Msg("read message error") + } else { + ws.logger.Debug().Err(err).Msg("read message error") + } + close(cancel) + break + } + bytes <- raw + } + }() + + for { + select { + case raw := <-bytes: + ws.logger.Debug(). + Str("session", id). + Str("raw", string(raw)). + Msg("recieved message from client") + if err := ws.handler.Message(id, raw); err != nil { + ws.logger.Error().Err(err).Msg("message handler has failed") + } + case <-cancel: + return + case _ = <-ticker.C: + if err := connection.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + } + } +}