This commit is contained in:
Craig 2020-01-24 15:47:37 +00:00
parent a0866a4ab9
commit e3a73aa264
26 changed files with 1154 additions and 934 deletions

View File

@ -66,3 +66,4 @@ NEKO_CERT= // (SSL)Cert
### Non Goals ### Non Goals
* Turning n.eko into a service that serves multiple rooms and browsers/desktops. * Turning n.eko into a service that serves multiple rooms and browsers/desktops.
* Voice chat, use [Discord](https://discordapp.com/))

View File

@ -34,7 +34,7 @@ export abstract class BaseClient extends EventEmitter<BaseEvents> {
} }
get peerConnected() { get peerConnected() {
return typeof this._peer !== 'undefined' && this._state === 'connected' return typeof this._peer !== 'undefined' && ['connected', 'checking', 'completed'].includes(this._state)
} }
get connected() { get connected() {
@ -60,7 +60,7 @@ export abstract class BaseClient extends EventEmitter<BaseEvents> {
this._ws.onmessage = this.onMessage.bind(this) this._ws.onmessage = this.onMessage.bind(this)
this._ws.onerror = event => this.onError.bind(this) this._ws.onerror = event => this.onError.bind(this)
this._ws.onclose = event => this.onDisconnected.bind(this, new Error('websocket closed')) this._ws.onclose = event => this.onDisconnected.bind(this, new Error('websocket closed'))
this._timeout = setTimeout(this.onTimeout.bind(this), 5000) this._timeout = setTimeout(this.onTimeout.bind(this), 15000)
} catch (err) { } catch (err) {
this.onDisconnected(err) this.onDisconnected(err)
} }

View File

@ -127,6 +127,7 @@ github.com/pion/transport v0.8.10 h1:lTiobMEw2PG6BH/mgIVqTV2mBp/mPT+IJLaN8ZxgdHk
github.com/pion/transport v0.8.10/go.mod h1:tBmha/UCjpum5hqTWhfAEs3CO4/tHSg0MYRhSzR+CZ8= github.com/pion/transport v0.8.10/go.mod h1:tBmha/UCjpum5hqTWhfAEs3CO4/tHSg0MYRhSzR+CZ8=
github.com/pion/turn v1.4.0 h1:7NUMRehQz4fIo53Qv9ui1kJ0Kr1CA82I81RHKHCeM80= github.com/pion/turn v1.4.0 h1:7NUMRehQz4fIo53Qv9ui1kJ0Kr1CA82I81RHKHCeM80=
github.com/pion/turn v1.4.0/go.mod h1:aDSi6hWX/hd1+gKia9cExZOR0MU95O7zX9p3Gw/P2aU= github.com/pion/turn v1.4.0/go.mod h1:aDSi6hWX/hd1+gKia9cExZOR0MU95O7zX9p3Gw/P2aU=
github.com/pion/webrtc v1.2.0 h1:3LGGPQEMacwG2hcDfhdvwQPz315gvjZXOfY4vaF4+I4=
github.com/pion/webrtc/v2 v2.1.18 h1:g0VN0xfEUSlVNfQmlCD6yOeXy/tMaktESBmHMnBS3bk= github.com/pion/webrtc/v2 v2.1.18 h1:g0VN0xfEUSlVNfQmlCD6yOeXy/tMaktESBmHMnBS3bk=
github.com/pion/webrtc/v2 v2.1.18/go.mod h1:m0rKlYgLRZWyhmcMWegpF6xtK1ASxmOg8DAR74ttzQY= github.com/pion/webrtc/v2 v2.1.18/go.mod h1:m0rKlYgLRZWyhmcMWegpF6xtK1ASxmOg8DAR74ttzQY=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=

View File

@ -1,8 +1,7 @@
package config package config
import ( import (
"strings" "github.com/pion/webrtc/v2"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
@ -22,13 +21,8 @@ func (WebRTC) Init(cmd *cobra.Command) error {
return err return err
} }
cmd.PersistentFlags().String("ac", "opus", "Audio codec to use for streaming") cmd.PersistentFlags().String("aduio", "", "Audio codec parameters to use for streaming")
if err := viper.BindPFlag("acodec", cmd.PersistentFlags().Lookup("ac")); err != nil { if err := viper.BindPFlag("aparams", cmd.PersistentFlags().Lookup("aduio")); err != nil {
return err
}
cmd.PersistentFlags().String("ap", "", "Audio codec parameters to use for streaming")
if err := viper.BindPFlag("aparams", cmd.PersistentFlags().Lookup("ap")); err != nil {
return err return err
} }
@ -37,13 +31,45 @@ func (WebRTC) Init(cmd *cobra.Command) error {
return err return err
} }
cmd.PersistentFlags().String("vc", "vp8", "Video codec to use for streaming") cmd.PersistentFlags().String("video", "", "Video codec parameters to use for streaming")
if err := viper.BindPFlag("vcodec", cmd.PersistentFlags().Lookup("vc")); err != nil { if err := viper.BindPFlag("vparams", cmd.PersistentFlags().Lookup("video")); err != nil {
return err return err
} }
cmd.PersistentFlags().String("vp", "", "Video codec parameters to use for streaming") // video codecs
if err := viper.BindPFlag("vparams", cmd.PersistentFlags().Lookup("vp")); err != nil { cmd.PersistentFlags().Bool("vp8", false, "Use VP8 codec")
if err := viper.BindPFlag("vp8", cmd.PersistentFlags().Lookup("vp8")); err != nil {
return err
}
cmd.PersistentFlags().Bool("vp9", false, "Use VP9 codec")
if err := viper.BindPFlag("vp9", cmd.PersistentFlags().Lookup("vp9")); err != nil {
return err
}
cmd.PersistentFlags().Bool("h264", false, "Use H264 codec")
if err := viper.BindPFlag("h264", cmd.PersistentFlags().Lookup("h264")); err != nil {
return err
}
// audio codecs
cmd.PersistentFlags().Bool("opus", false, "Use Opus codec")
if err := viper.BindPFlag("opus", cmd.PersistentFlags().Lookup("opus")); err != nil {
return err
}
cmd.PersistentFlags().Bool("g722", false, "Use G722 codec")
if err := viper.BindPFlag("g722", cmd.PersistentFlags().Lookup("g722")); err != nil {
return err
}
cmd.PersistentFlags().Bool("pcmu", false, "Use PCMU codec")
if err := viper.BindPFlag("pcmu", cmd.PersistentFlags().Lookup("pcmu")); err != nil {
return err
}
cmd.PersistentFlags().Bool("pcma", false, "Use PCMA codec")
if err := viper.BindPFlag("pcmu", cmd.PersistentFlags().Lookup("pcmu")); err != nil {
return err return err
} }
@ -51,10 +77,30 @@ func (WebRTC) Init(cmd *cobra.Command) error {
} }
func (s *WebRTC) Set() { func (s *WebRTC) Set() {
s.Device = strings.ToLower(viper.GetString("device")) videoCodec := webrtc.VP8
s.AudioCodec = strings.ToLower(viper.GetString("acodec")) if viper.GetBool("vp8") {
s.AudioParams = strings.ToLower(viper.GetString("aparams")) videoCodec = webrtc.VP8
s.Display = strings.ToLower(viper.GetString("display")) } else if viper.GetBool("vp9") {
s.VideoCodec = strings.ToLower(viper.GetString("vcodec")) videoCodec = webrtc.VP9
s.VideoParams = strings.ToLower(viper.GetString("vparams")) } else if viper.GetBool("h264") {
videoCodec = webrtc.H264
}
audioCodec := webrtc.VP8
if viper.GetBool("opus") {
audioCodec = webrtc.Opus
} else if viper.GetBool("g722") {
audioCodec = webrtc.G722
} else if viper.GetBool("pcmu") {
audioCodec = webrtc.PCMU
} else if viper.GetBool("pcma") {
audioCodec = webrtc.PCMA
}
s.Device = viper.GetString("device")
s.AudioCodec = audioCodec
s.AudioParams = viper.GetString("aparams")
s.Display = viper.GetString("display")
s.VideoCodec = videoCodec
s.VideoParams = viper.GetString("vparams")
} }

View File

@ -9,12 +9,13 @@ package gst
import "C" import "C"
import ( import (
"fmt" "fmt"
"io"
"sync" "sync"
"unsafe" "unsafe"
"github.com/pion/webrtc/v2" "github.com/pion/webrtc/v2"
"github.com/pion/webrtc/v2/pkg/media" "github.com/pkg/errors"
"n.eko.moe/neko/internal/types"
) )
/* /*
@ -35,10 +36,10 @@ import (
// Pipeline is a wrapper for a GStreamer Pipeline // Pipeline is a wrapper for a GStreamer Pipeline
type Pipeline struct { type Pipeline struct {
Pipeline *C.GstElement Pipeline *C.GstElement
tracks []*webrtc.Track Sample chan types.Sample
CodecName string
ClockRate float32
id int id int
codecName string
clockRate float32
} }
var pipelines = make(map[int]*Pipeline) var pipelines = make(map[int]*Pipeline)
@ -57,7 +58,7 @@ func init() {
} }
// CreatePipeline creates a GStreamer Pipeline // CreatePipeline creates a GStreamer Pipeline
func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string) *Pipeline { func CreatePipeline(codecName string, pipelineSrc string) (*Pipeline, error) {
pipelineStr := "appsink name=appsink" pipelineStr := "appsink name=appsink"
var clockRate float32 var clockRate float32
@ -70,7 +71,7 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
clockRate = videoClockRate clockRate = videoClockRate
if err := CheckPlugins([]string{"ximagesrc", "vpx"}); err != nil { if err := CheckPlugins([]string{"ximagesrc", "vpx"}); err != nil {
panic(err) return nil, err
} }
case webrtc.VP9: case webrtc.VP9:
@ -83,7 +84,7 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
clockRate = videoClockRate clockRate = videoClockRate
if err := CheckPlugins([]string{"ximagesrc", "vpx"}); err != nil { if err := CheckPlugins([]string{"ximagesrc", "vpx"}); err != nil {
panic(err) return nil, err
} }
case webrtc.H264: case webrtc.H264:
@ -98,14 +99,14 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
clockRate = videoClockRate clockRate = videoClockRate
if err := CheckPlugins([]string{"ximagesrc"}); err != nil { if err := CheckPlugins([]string{"ximagesrc"}); err != nil {
panic(err) return nil, err
} }
if err := CheckPlugins([]string{"openh264"}); err != nil { if err := CheckPlugins([]string{"openh264"}); err != nil {
pipelineStr = pipelineSrc + " ! video/x-raw,format=I420 ! x264enc bframes=0 key-int-max=60 byte-stream=true tune=zerolatency speed-preset=veryfast ! video/x-h264,stream-format=byte-stream ! " + pipelineStr pipelineStr = pipelineSrc + " ! video/x-raw,format=I420 ! x264enc bframes=0 key-int-max=60 byte-stream=true tune=zerolatency speed-preset=veryfast ! video/x-h264,stream-format=byte-stream ! " + pipelineStr
if err := CheckPlugins([]string{"x264"}); err != nil { if err := CheckPlugins([]string{"x264"}); err != nil {
panic(err) return nil, err
} }
} }
@ -117,7 +118,7 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
clockRate = audioClockRate clockRate = audioClockRate
if err := CheckPlugins([]string{"pulseaudio", "opus"}); err != nil { if err := CheckPlugins([]string{"pulseaudio", "opus"}); err != nil {
panic(err) return nil, err
} }
case webrtc.G722: case webrtc.G722:
@ -128,7 +129,7 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
clockRate = audioClockRate clockRate = audioClockRate
if err := CheckPlugins([]string{"pulseaudio", "libav"}); err != nil { if err := CheckPlugins([]string{"pulseaudio", "libav"}); err != nil {
panic(err) return nil, err
} }
case webrtc.PCMU: case webrtc.PCMU:
@ -140,7 +141,7 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
clockRate = pcmClockRate clockRate = pcmClockRate
if err := CheckPlugins([]string{"pulseaudio", "mulaw"}); err != nil { if err := CheckPlugins([]string{"pulseaudio", "mulaw"}); err != nil {
panic(err) return nil, err
} }
case webrtc.PCMA: case webrtc.PCMA:
@ -151,11 +152,11 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
clockRate = pcmClockRate clockRate = pcmClockRate
if err := CheckPlugins([]string{"pulseaudio", "alaw"}); err != nil { if err := CheckPlugins([]string{"pulseaudio", "alaw"}); err != nil {
panic(err) return nil, err
} }
default: default:
panic("Unhandled codec " + codecName) return nil, errors.Errorf("unknown video codec %s", codecName)
} }
pipelineStrUnsafe := C.CString(pipelineStr) pipelineStrUnsafe := C.CString(pipelineStr)
@ -166,14 +167,14 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
pipeline := &Pipeline{ pipeline := &Pipeline{
Pipeline: C.gstreamer_send_create_pipeline(pipelineStrUnsafe), Pipeline: C.gstreamer_send_create_pipeline(pipelineStrUnsafe),
tracks: tracks, Sample: make(chan types.Sample),
CodecName: codecName,
ClockRate: clockRate,
id: len(pipelines), id: len(pipelines),
codecName: codecName,
clockRate: clockRate,
} }
pipelines[pipeline.id] = pipeline pipelines[pipeline.id] = pipeline
return pipeline return pipeline, nil
} }
// Start starts the GStreamer Pipeline // Start starts the GStreamer Pipeline
@ -193,14 +194,13 @@ func CheckPlugins(plugins []string) error {
plugin = C.gst_registry_find_plugin(registry, plugincstr) plugin = C.gst_registry_find_plugin(registry, plugincstr)
C.free(unsafe.Pointer(plugincstr)) C.free(unsafe.Pointer(plugincstr))
if plugin == nil { if plugin == nil {
return fmt.Errorf("Required gstreamer plugin %s not found", pluginstr) return fmt.Errorf("required gstreamer plugin %s not found", pluginstr)
} }
} }
return nil return nil
} }
//export goHandlePipelineBuffer //export goHandlePipelineBuffer
func goHandlePipelineBuffer(buffer unsafe.Pointer, bufferLen C.int, duration C.int, pipelineID C.int) { func goHandlePipelineBuffer(buffer unsafe.Pointer, bufferLen C.int, duration C.int, pipelineID C.int) {
pipelinesLock.Lock() pipelinesLock.Lock()
@ -208,12 +208,8 @@ func goHandlePipelineBuffer(buffer unsafe.Pointer, bufferLen C.int, duration C.i
pipelinesLock.Unlock() pipelinesLock.Unlock()
if ok { if ok {
samples := uint32(pipeline.clockRate * (float32(duration) / 1000000000)) samples := uint32(pipeline.ClockRate * (float32(duration) / 1000000000))
for _, t := range pipeline.tracks { pipeline.Sample <- types.Sample{Data: C.GoBytes(buffer, bufferLen), Samples: samples}
if err := t.WriteSample(media.Sample{Data: C.GoBytes(buffer, bufferLen), Samples: samples}); err != nil && err != io.ErrClosedPipe {
panic(err)
}
}
} else { } else {
fmt.Printf("discarding buffer, no pipeline with id %d", int(pipelineID)) fmt.Printf("discarding buffer, no pipeline with id %d", int(pipelineID))
} }

View File

@ -3,15 +3,17 @@ package session
import ( import (
"fmt" "fmt"
"github.com/gorilla/websocket"
"github.com/kataras/go-events" "github.com/kataras/go-events"
"github.com/pion/webrtc/v2" "github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/utils" "n.eko.moe/neko/internal/utils"
) )
func New() *SessionManager { func New() *SessionManager {
return &SessionManager{ return &SessionManager{
logger: log.With().Str("module", "session").Logger(),
host: "", host: "",
members: make(map[string]*Session), members: make(map[string]*Session),
emmiter: events.New(), emmiter: events.New(),
@ -19,153 +21,100 @@ func New() *SessionManager {
} }
type SessionManager struct { type SessionManager struct {
logger zerolog.Logger
host string host string
members map[string]*Session members map[string]*Session
emmiter events.EventEmmiter emmiter events.EventEmmiter
} }
func (m *SessionManager) New(id string, admin bool, socket *websocket.Conn) *Session { func (manager *SessionManager) New(id string, admin bool, socket types.WebScoket) types.Session {
session := &Session{ session := &Session{
ID: id, id: id,
Admin: admin, admin: admin,
manager: manager,
socket: socket, socket: socket,
logger: manager.logger.With().Str("id", id).Logger(),
connected: false, connected: false,
} }
m.members[id] = session manager.members[id] = session
m.emmiter.Emit("created", id, session) manager.emmiter.Emit("created", id, session)
return session return session
} }
func (m *SessionManager) IsHost(id string) bool { func (manager *SessionManager) HasHost() bool {
return m.host == id return manager.host != ""
} }
func (m *SessionManager) HasHost() bool { func (manager *SessionManager) IsHost(id string) bool {
return m.host != "" return manager.host == id
} }
func (m *SessionManager) SetHost(id string) error { func (manager *SessionManager) SetHost(id string) error {
_, ok := m.members[id] _, ok := manager.members[id]
if ok { if ok {
m.host = id manager.host = id
m.emmiter.Emit("host", id) manager.emmiter.Emit("host", id)
return nil return nil
} }
return fmt.Errorf("invalid session id %s", id) return fmt.Errorf("invalid session id %s", id)
} }
func (m *SessionManager) GetHost() (*Session, bool) { func (manager *SessionManager) GetHost() (types.Session, bool) {
host, ok := m.members[m.host] host, ok := manager.members[manager.host]
return host, ok return host, ok
} }
func (m *SessionManager) ClearHost() { func (manager *SessionManager) ClearHost() {
id := m.host id := manager.host
m.host = "" manager.host = ""
m.emmiter.Emit("host_cleared", id) manager.emmiter.Emit("host_cleared", id)
} }
func (m *SessionManager) Has(id string) bool { func (manager *SessionManager) Has(id string) bool {
_, ok := m.members[id] _, ok := manager.members[id]
return ok return ok
} }
func (m *SessionManager) Get(id string) (*Session, bool) { func (manager *SessionManager) Get(id string) (types.Session, bool) {
session, ok := m.members[id] session, ok := manager.members[id]
return session, ok return session, ok
} }
func (m *SessionManager) GetConnected() []*Session { func (manager *SessionManager) Members() []*types.Member {
var sessions []*Session members := []*types.Member{}
for _, sess := range m.members { for _, session := range manager.members {
if sess.connected { if !session.connected {
sessions = append(sessions, sess) continue
}
} }
return sessions member := session.Member()
if member != nil {
members = append(members, member)
}
}
return members
} }
func (m *SessionManager) Set(id string, session *Session) { func (manager *SessionManager) Destroy(id string) error {
m.members[id] = session session, ok := manager.members[id]
}
func (m *SessionManager) Destroy(id string) error {
session, ok := m.members[id]
if ok { if ok {
err := session.destroy() err := session.destroy()
delete(m.members, id) delete(manager.members, id)
m.emmiter.Emit("destroyed", id) manager.emmiter.Emit("destroyed", id)
return err return err
} }
return nil return nil
} }
func (m *SessionManager) SetSocket(id string, socket *websocket.Conn) (bool, error) { func (manager *SessionManager) Clear() error {
session, ok := m.members[id]
if ok {
session.socket = socket
return true, nil
}
return false, fmt.Errorf("invalid session id %s", id)
}
func (m *SessionManager) SetPeer(id string, peer *webrtc.PeerConnection) (bool, error) {
session, ok := m.members[id]
if ok {
session.peer = peer
return true, nil
}
return false, fmt.Errorf("invalid session id %s", id)
}
func (m *SessionManager) SetName(id string, name string) (bool, error) {
session, ok := m.members[id]
if ok {
session.Name = name
session.connected = true
m.emmiter.Emit("connected", id, session)
return true, nil
}
return false, fmt.Errorf("invalid session id %s", id)
}
func (m *SessionManager) Mute(id string) error {
session, ok := m.members[id]
if ok {
session.Muted = true
}
return nil return nil
} }
func (m *SessionManager) Unmute(id string) error { func (manager *SessionManager) Brodcast(v interface{}, exclude interface{}) error {
session, ok := m.members[id] for id, session := range manager.members {
if ok { if !session.connected {
session.Muted = false
}
return nil
}
func (m *SessionManager) Kick(id string, v interface{}) error {
session, ok := m.members[id]
if ok {
return session.Kick(v)
}
return nil
}
func (m *SessionManager) Clear() error {
return nil
}
func (m *SessionManager) Brodcast(v interface{}, exclude interface{}) error {
for id, sess := range m.members {
if !sess.connected {
continue continue
} }
@ -175,39 +124,65 @@ func (m *SessionManager) Brodcast(v interface{}, exclude interface{}) error {
} }
} }
if err := sess.Send(v); err != nil { if err := session.Send(v); err != nil {
return err return err
} }
} }
return nil return nil
} }
func (m *SessionManager) OnHost(listener func(id string)) { func (manager *SessionManager) WriteVideoSample(sample types.Sample) error {
m.emmiter.On("host", func(payload ...interface{}) { for _, session := range manager.members {
if !session.connected {
continue
}
if err := session.WriteVideoSample(sample); err != nil {
return err
}
}
return nil
}
func (manager *SessionManager) WriteAudioSample(sample types.Sample) error {
for _, session := range manager.members {
if !session.connected {
continue
}
if err := session.WriteAudioSample(sample); err != nil {
return err
}
}
return nil
}
func (manager *SessionManager) OnHost(listener func(id string)) {
manager.emmiter.On("host", func(payload ...interface{}) {
listener(payload[0].(string)) listener(payload[0].(string))
}) })
} }
func (m *SessionManager) OnHostCleared(listener func(id string)) { func (manager *SessionManager) OnHostCleared(listener func(id string)) {
m.emmiter.On("host_cleared", func(payload ...interface{}) { manager.emmiter.On("host_cleared", func(payload ...interface{}) {
listener(payload[0].(string)) listener(payload[0].(string))
}) })
} }
func (m *SessionManager) OnCreated(listener func(id string, session *Session)) { func (manager *SessionManager) OnDestroy(listener func(id string)) {
m.emmiter.On("created", func(payload ...interface{}) { manager.emmiter.On("destroyed", func(payload ...interface{}) {
listener(payload[0].(string))
})
}
func (manager *SessionManager) OnCreated(listener func(id string, session types.Session)) {
manager.emmiter.On("created", func(payload ...interface{}) {
listener(payload[0].(string), payload[1].(*Session)) listener(payload[0].(string), payload[1].(*Session))
}) })
} }
func (m *SessionManager) OnConnected(listener func(id string, session *Session)) { func (manager *SessionManager) OnConnected(listener func(id string, session types.Session)) {
m.emmiter.On("connected", func(payload ...interface{}) { manager.emmiter.On("connected", func(payload ...interface{}) {
listener(payload[0].(string), payload[1].(*Session)) listener(payload[0].(string), payload[1].(*Session))
}) })
} }
func (m *SessionManager) OnDestroy(listener func(id string)) {
m.emmiter.On("destroyed", func(payload ...interface{}) {
listener(payload[0].(string))
})
}

View File

@ -3,38 +3,90 @@ package session
import ( import (
"sync" "sync"
"github.com/gorilla/websocket" "github.com/rs/zerolog"
"github.com/pion/webrtc/v2" "n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/types/event"
"n.eko.moe/neko/internal/types/message"
) )
type Session struct { type Session struct {
ID string `json:"id"` logger zerolog.Logger
Name string `json:"username"` id string
Admin bool `json:"admin"` name string
Muted bool `json:"muted"` admin bool
muted bool
connected bool connected bool
socket *websocket.Conn manager *SessionManager
peer *webrtc.PeerConnection socket types.WebScoket
peer types.Peer
mu sync.Mutex mu sync.Mutex
} }
func (session *Session) RemoteAddr() *string { func (session *Session) ID() string {
if session.socket != nil { return session.id
address := session.socket.RemoteAddr().String() }
return &address
func (session *Session) Name() string {
return session.name
}
func (session *Session) Admin() bool {
return session.admin
}
func (session *Session) Muted() bool {
return session.muted
}
func (session *Session) Connected() bool {
return session.connected
}
func (session *Session) Address() *string {
if session.socket == nil {
return nil
} }
return session.socket.Address()
}
func (session *Session) Member() *types.Member {
return &types.Member{
ID: session.id,
Name: session.name,
Admin: session.admin,
Muted: session.muted,
}
}
func (session *Session) SetMuted(muted bool) {
session.muted = muted
}
func (session *Session) SetName(name string) error {
session.name = name
return nil return nil
} }
// TODO: write to peer data channel func (session *Session) SetSocket(socket types.WebScoket) error {
func (session *Session) Write(v interface{}) error { session.socket = socket
session.mu.Lock()
defer session.mu.Unlock()
return nil return nil
} }
func (session *Session) Kick(v interface{}) error { func (session *Session) SetPeer(peer types.Peer) error {
if err := session.Send(v); err != nil { session.peer = peer
session.connected = true
session.manager.emmiter.Emit("connected", session.id, session)
return nil
}
func (session *Session) Kick(reason string) error {
if session.socket == nil {
return nil
}
if err := session.socket.Send(&message.Disconnect{
Event: event.SYSTEM_DISCONNECT,
Message: reason,
}); err != nil {
return err return err
} }
@ -42,28 +94,41 @@ func (session *Session) Kick(v interface{}) error {
} }
func (session *Session) Send(v interface{}) error { func (session *Session) Send(v interface{}) error {
session.mu.Lock() if session.socket == nil {
defer session.mu.Unlock()
if session.socket != nil {
return session.socket.WriteJSON(v)
}
return nil return nil
}
return session.socket.Send(v)
}
func (session *Session) Write(v interface{}) error {
if session.socket == nil {
return nil
}
return session.socket.Send(v)
}
func (session *Session) WriteVideoSample(sample types.Sample) error {
if session.peer == nil || !session.connected {
return nil
}
return session.peer.WriteVideoSample(sample)
}
func (session *Session) WriteAudioSample(sample types.Sample) error {
if session.peer == nil || !session.connected {
return nil
}
return session.peer.WriteAudioSample(sample)
} }
func (session *Session) destroy() error { func (session *Session) destroy() error {
if session.peer != nil && session.peer.ConnectionState() == webrtc.PeerConnectionStateConnected { if err := session.socket.Destroy(); err != nil {
if err := session.peer.Close(); err != nil {
return err return err
} }
}
if session.socket != nil { if err := session.peer.Destroy(); err != nil {
if err := session.socket.Close(); err != nil {
return err return err
} }
}
return nil return nil
} }

View File

@ -1,6 +1,8 @@
package message package message
import "n.eko.moe/neko/internal/session" import (
"n.eko.moe/neko/internal/types"
)
type Message struct { type Message struct {
Event string `json:"event"` Event string `json:"event"`
@ -28,12 +30,12 @@ type Signal struct {
type MembersList struct { type MembersList struct {
Event string `json:"event"` Event string `json:"event"`
Memebers []*session.Session `json:"members"` Memebers []*types.Member `json:"members"`
} }
type Member struct { type Member struct {
Event string `json:"event"` Event string `json:"event"`
*session.Session *types.Member
} }
type MemberDisconnected struct { type MemberDisconnected struct {
Event string `json:"event"` Event string `json:"event"`

View File

@ -0,0 +1,49 @@
package types
type Member struct {
ID string `json:"id"`
Name string `json:"username"`
Admin bool `json:"admin"`
Muted bool `json:"muted"`
}
type Session interface {
ID() string
Name() string
Admin() bool
Muted() bool
Connected() bool
Member() *Member
SetMuted(muted bool)
SetName(name string) error
SetSocket(socket WebScoket) error
SetPeer(peer Peer) error
Address() *string
Kick(message string) error
Write(v interface{}) error
Send(v interface{}) error
WriteAudioSample(sample Sample) error
WriteVideoSample(sample Sample) error
}
type SessionManager interface {
New(id string, admin bool, socket WebScoket) Session
HasHost() bool
IsHost(id string) bool
SetHost(id string) error
GetHost() (Session, bool)
ClearHost()
Has(id string) bool
Get(id string) (Session, bool)
Members() []*Member
Destroy(id string) error
Clear() error
Brodcast(v interface{}, exclude interface{}) error
WriteAudioSample(sample Sample) error
WriteVideoSample(sample Sample) error
OnHost(listener func(id string))
OnHostCleared(listener func(id string))
OnDestroy(listener func(id string))
OnCreated(listener func(id string, session Session))
OnConnected(listener func(id string, session Session))
}

View File

@ -0,0 +1,19 @@
package types
type Sample struct {
Data []byte
Samples uint32
}
type WebRTCManager interface {
Start()
Shutdown() error
CreatePeer(id string, sdp string) (string, Peer, error)
}
type Peer interface {
WriteVideoSample(sample Sample) error
WriteAudioSample(sample Sample) error
WriteData(v interface{}) error
Destroy() error
}

View File

@ -0,0 +1,7 @@
package types
type WebScoket interface {
Address() *string
Send(v interface{}) error
Destroy() error
}

View File

@ -1,6 +1,9 @@
package webrtc package webrtc
import ( import (
"fmt"
"strings"
"github.com/pion/logging" "github.com/pion/logging"
"github.com/rs/zerolog" "github.com/rs/zerolog"
) )
@ -13,8 +16,19 @@ func (l logger) Trace(msg string) { l.logger.Trace().Ms
func (l logger) Tracef(format string, args ...interface{}) { l.logger.Trace().Msgf(format, args...) } func (l logger) Tracef(format string, args ...interface{}) { l.logger.Trace().Msgf(format, args...) }
func (l logger) Debug(msg string) { l.logger.Debug().Msg(msg) } func (l logger) Debug(msg string) { l.logger.Debug().Msg(msg) }
func (l logger) Debugf(format string, args ...interface{}) { l.logger.Debug().Msgf(format, args...) } func (l logger) Debugf(format string, args ...interface{}) { l.logger.Debug().Msgf(format, args...) }
func (l logger) Info(msg string) { l.logger.Info().Msg(msg) } func (l logger) Info(msg string) {
func (l logger) Infof(format string, args ...interface{}) { l.logger.Info().Msgf(format, args...) } if strings.Contains(msg, "packetio.Buffer is full") {
l.logger.Panic().Msg(msg)
}
l.logger.Info().Msg(msg)
}
func (l logger) Infof(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
if strings.Contains(msg, "packetio.Buffer is full") {
l.logger.Panic().Msg(msg)
}
l.logger.Info().Msgf(format, args...)
}
func (l logger) Warn(msg string) { l.logger.Warn().Msg(msg) } func (l logger) Warn(msg string) { l.logger.Warn().Msg(msg) }
func (l logger) Warnf(format string, args ...interface{}) { l.logger.Warn().Msgf(format, args...) } func (l logger) Warnf(format string, args ...interface{}) { l.logger.Warn().Msgf(format, args...) }
func (l logger) Error(msg string) { l.logger.Error().Msg(msg) } func (l logger) Error(msg string) { l.logger.Error().Msg(msg) }

View File

@ -1,226 +0,0 @@
package webrtc
import (
"fmt"
"time"
"github.com/pion/webrtc/v2"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"n.eko.moe/neko/internal/config"
"n.eko.moe/neko/internal/event"
"n.eko.moe/neko/internal/gst"
"n.eko.moe/neko/internal/hid"
"n.eko.moe/neko/internal/message"
"n.eko.moe/neko/internal/session"
)
func New(sessions *session.SessionManager, conf *config.WebRTC) *WebRTCManager {
logger := log.With().Str("module", "webrtc").Logger()
engine := webrtc.MediaEngine{}
engine.RegisterDefaultCodecs()
setings := webrtc.SettingEngine{
LoggerFactory: loggerFactory{
logger: logger,
},
}
return &WebRTCManager{
logger: logger,
engine: engine,
setings: setings,
api: webrtc.NewAPI(webrtc.WithMediaEngine(engine), webrtc.WithSettingEngine(setings)),
cleanup: time.NewTicker(1 * time.Second),
shutdown: make(chan bool),
sessions: sessions,
conf: conf,
config: webrtc.Configuration{
ICEServers: []webrtc.ICEServer{
{
URLs: []string{"stun:stun.l.google.com:19302"},
},
},
SDPSemantics: webrtc.SDPSemanticsUnifiedPlanWithFallback,
},
}
}
type WebRTCManager struct {
logger zerolog.Logger
engine webrtc.MediaEngine
setings webrtc.SettingEngine
config webrtc.Configuration
sessions *session.SessionManager
api *webrtc.API
video *webrtc.Track
audio *webrtc.Track
videoPipeline *gst.Pipeline
audioPipeline *gst.Pipeline
cleanup *time.Ticker
conf *config.WebRTC
shutdown chan bool
}
func (m *WebRTCManager) Start() {
hid.Display(m.conf.Display)
switch m.conf.VideoCodec {
case "vp8":
if err := m.createVideoTrack(webrtc.DefaultPayloadTypeVP8); err != nil {
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
}
case "vp9":
if err := m.createVideoTrack(webrtc.DefaultPayloadTypeVP9); err != nil {
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
}
case "h264":
if err := m.createVideoTrack(webrtc.DefaultPayloadTypeH264); err != nil {
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
}
default:
m.logger.Panic().Err(errors.Errorf("unknown video codec %s", m.conf.AudioCodec)).Msg("unable to start webrtc manager")
}
switch m.conf.AudioCodec {
case "opus":
if err := m.createAudioTrack(webrtc.DefaultPayloadTypeOpus); err != nil {
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
}
case "g722":
if err := m.createAudioTrack(webrtc.DefaultPayloadTypeG722); err != nil {
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
}
case "pcmu":
if err := m.createAudioTrack(webrtc.DefaultPayloadTypePCMU); err != nil {
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
}
case "pcma":
if err := m.createAudioTrack(webrtc.DefaultPayloadTypePCMA); err != nil {
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
}
default:
m.logger.Panic().Err(errors.Errorf("unknown audio codec %s", m.conf.AudioCodec)).Msg("unable to start webrtc manager")
}
m.videoPipeline.Start()
m.audioPipeline.Start()
go func() {
defer func() {
m.logger.Info().Msg("shutdown")
}()
for {
select {
case <-m.shutdown:
return
case <-m.cleanup.C:
hid.Check(time.Second * 10)
}
}
}()
m.sessions.OnHostCleared(func(id string) {
hid.Reset()
})
m.sessions.OnCreated(func(id string, session *session.Session) {
m.logger.Debug().Str("id", id).Msg("session created")
})
m.sessions.OnDestroy(func(id string) {
m.logger.Debug().Str("id", id).Msg("session destroyed")
})
// TODO: log resolution, bit rate and codec parameters
m.logger.Info().
Str("video_display", m.conf.Display).
Str("video_codec", m.conf.VideoCodec).
Str("audio_device", m.conf.Device).
Str("audio_codec", m.conf.AudioCodec).
Msgf("webrtc streaming")
}
func (m *WebRTCManager) Shutdown() error {
m.logger.Info().Msgf("webrtc shutting down")
m.cleanup.Stop()
m.shutdown <- true
m.videoPipeline.Stop()
m.audioPipeline.Stop()
return nil
}
func (m *WebRTCManager) CreatePeer(id string, sdp string) error {
session, ok := m.sessions.Get(id)
if !ok {
return fmt.Errorf("invalid session id %s", id)
}
peer, err := m.api.NewPeerConnection(m.config)
if err != nil {
return err
}
if _, err := peer.AddTransceiverFromTrack(m.video, webrtc.RtpTransceiverInit{
Direction: webrtc.RTPTransceiverDirectionSendonly,
}); err != nil {
return err
}
if _, err := peer.AddTransceiverFromTrack(m.audio, webrtc.RtpTransceiverInit{
Direction: webrtc.RTPTransceiverDirectionSendonly,
}); err != nil {
return err
}
peer.SetRemoteDescription(webrtc.SessionDescription{
SDP: sdp,
Type: webrtc.SDPTypeOffer,
})
answer, err := peer.CreateAnswer(nil)
if err != nil {
return err
}
if err = peer.SetLocalDescription(answer); err != nil {
return err
}
if err := session.Send(message.Signal{
Event: event.SIGNAL_ANSWER,
SDP: answer.SDP,
}); err != nil {
return err
}
peer.OnDataChannel(func(d *webrtc.DataChannel) {
d.OnMessage(func(msg webrtc.DataChannelMessage) {
if err = m.handle(id, msg); err != nil {
m.logger.Warn().Err(err).Msg("data handle failed")
}
})
})
peer.OnConnectionStateChange(func(connectionState webrtc.PeerConnectionState) {
switch connectionState {
case webrtc.PeerConnectionStateDisconnected:
case webrtc.PeerConnectionStateFailed:
m.logger.Info().Str("id", id).Msg("peer disconnected")
m.sessions.Destroy(id)
break
case webrtc.PeerConnectionStateConnected:
m.logger.Info().Str("id", id).Msg("peer connected")
break
}
})
m.sessions.SetPeer(id, peer)
return nil
}

View File

@ -0,0 +1,44 @@
package webrtc
import (
"github.com/pion/webrtc/v2"
"github.com/pion/webrtc/v2/pkg/media"
"n.eko.moe/neko/internal/types"
)
type Peer struct {
id string
engine webrtc.MediaEngine
api *webrtc.API
video *webrtc.Track
audio *webrtc.Track
connection *webrtc.PeerConnection
}
func (peer *Peer) WriteAudioSample(sample types.Sample) error {
if err := peer.audio.WriteSample(media.Sample(sample)); err != nil {
return err
}
return nil
}
func (peer *Peer) WriteVideoSample(sample types.Sample) error {
if err := peer.video.WriteSample(media.Sample(sample)); err != nil {
return err
}
return nil
}
func (peer *Peer) WriteData(v interface{}) error {
return nil
}
func (peer *Peer) Destroy() error {
if peer.connection != nil && peer.connection.ConnectionState() == webrtc.PeerConnectionStateConnected {
if err := peer.connection.Close(); err != nil {
return err
}
}
return nil
}

View File

@ -5,95 +5,36 @@ import (
"math/rand" "math/rand"
"github.com/pion/webrtc/v2" "github.com/pion/webrtc/v2"
"github.com/pkg/errors"
"n.eko.moe/neko/internal/gst"
) )
func (m *WebRTCManager) createVideoTrack(payloadType uint8) error { func (m *WebRTCManager) createVideoTrack(engine webrtc.MediaEngine) (*webrtc.Track, error) {
clockrate := uint32(90000)
var codec *webrtc.RTPCodec var codec *webrtc.RTPCodec
switch payloadType { for _, videoCodec := range engine.GetCodecsByKind(webrtc.RTPCodecTypeVideo) {
case webrtc.DefaultPayloadTypeVP8: if videoCodec.Name == m.videoPipeline.CodecName {
codec = webrtc.NewRTPVP8Codec(payloadType, clockrate) codec = videoCodec
break
case webrtc.DefaultPayloadTypeVP9:
codec = webrtc.NewRTPVP9Codec(payloadType, clockrate)
break
case webrtc.DefaultPayloadTypeH264:
codec = webrtc.NewRTPH264Codec(payloadType, clockrate)
break
default:
return errors.Errorf("unknown video codec %s", payloadType)
}
track, err := webrtc.NewTrack(payloadType, rand.Uint32(), "stream", "stream", codec)
if err != nil {
return err
}
var pipeline *gst.Pipeline
src := fmt.Sprintf("ximagesrc xid=%s show-pointer=true use-damage=false ! video/x-raw,framerate=30/1 ! videoconvert ! queue", m.conf.Display)
switch payloadType {
case webrtc.DefaultPayloadTypeVP8:
pipeline = gst.CreatePipeline(webrtc.VP8, []*webrtc.Track{track}, src)
break
case webrtc.DefaultPayloadTypeVP9:
pipeline = gst.CreatePipeline(webrtc.VP9, []*webrtc.Track{track}, src)
break
case webrtc.DefaultPayloadTypeH264:
pipeline = gst.CreatePipeline(webrtc.H264, []*webrtc.Track{track}, src)
break break
} }
}
m.video = track if codec == nil || codec.PayloadType == 0 {
m.videoPipeline = pipeline return nil, fmt.Errorf("remote peer does not support %s", m.videoPipeline.CodecName)
return nil }
return webrtc.NewTrack(codec.PayloadType, rand.Uint32(), "stream", "stream", codec)
} }
func (m *WebRTCManager) createAudioTrack(payloadType uint8) error { func (m *WebRTCManager) createAudioTrack(engine webrtc.MediaEngine) (*webrtc.Track, error) {
var codec *webrtc.RTPCodec var codec *webrtc.RTPCodec
switch payloadType { for _, videoCodec := range engine.GetCodecsByKind(webrtc.RTPCodecTypeAudio) {
case webrtc.DefaultPayloadTypeOpus: if videoCodec.Name == m.audioPipeline.CodecName {
codec = webrtc.NewRTPOpusCodec(payloadType, 48000) codec = videoCodec
break
case webrtc.DefaultPayloadTypeG722:
codec = webrtc.NewRTPG722Codec(payloadType, 48000)
break
case webrtc.DefaultPayloadTypePCMU:
codec = webrtc.NewRTPPCMUCodec(payloadType, 8000)
break
case webrtc.DefaultPayloadTypePCMA:
codec = webrtc.NewRTPPCMACodec(payloadType, 8000)
break
default:
return errors.Errorf("unknown audio codec %s", payloadType)
}
track, err := webrtc.NewTrack(payloadType, rand.Uint32(), "stream", "stream", codec)
if err != nil {
return err
}
var pipeline *gst.Pipeline
src := fmt.Sprintf("pulsesrc device=%s ! audioconvert", m.conf.Device)
switch payloadType {
case webrtc.DefaultPayloadTypeOpus:
pipeline = gst.CreatePipeline(webrtc.Opus, []*webrtc.Track{track}, src)
break
case webrtc.DefaultPayloadTypeG722:
pipeline = gst.CreatePipeline(webrtc.G722, []*webrtc.Track{track}, src)
break
case webrtc.DefaultPayloadTypePCMU:
pipeline = gst.CreatePipeline(webrtc.PCMU, []*webrtc.Track{track}, src)
break
case webrtc.DefaultPayloadTypePCMA:
pipeline = gst.CreatePipeline(webrtc.PCMA, []*webrtc.Track{track}, src)
break break
} }
}
m.audio = track if codec == nil || codec.PayloadType == 0 {
m.audioPipeline = pipeline return nil, fmt.Errorf("remote peer does not support %s", m.audioPipeline.CodecName)
return nil }
return webrtc.NewTrack(codec.PayloadType, rand.Uint32(), "stream", "stream", codec)
} }

View File

@ -0,0 +1,237 @@
package webrtc
import (
"fmt"
"time"
"github.com/pion/webrtc/v2"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"n.eko.moe/neko/internal/config"
"n.eko.moe/neko/internal/gst"
"n.eko.moe/neko/internal/hid"
"n.eko.moe/neko/internal/types"
)
func New(sessions types.SessionManager, config *config.WebRTC) *WebRTCManager {
logger := log.With().Str("module", "webrtc").Logger()
setings := webrtc.SettingEngine{
LoggerFactory: loggerFactory{
logger: logger,
},
}
return &WebRTCManager{
logger: logger,
setings: setings,
cleanup: time.NewTicker(1 * time.Second),
shutdown: make(chan bool),
sessions: sessions,
config: config,
configuration: &webrtc.Configuration{
ICEServers: []webrtc.ICEServer{
{
URLs: []string{"stun:stun.l.google.com:19302"},
},
},
SDPSemantics: webrtc.SDPSemanticsUnifiedPlanWithFallback,
},
}
}
type WebRTCManager struct {
logger zerolog.Logger
setings webrtc.SettingEngine
sessions types.SessionManager
videoPipeline *gst.Pipeline
audioPipeline *gst.Pipeline
cleanup *time.Ticker
config *config.WebRTC
shutdown chan bool
configuration *webrtc.Configuration
}
func (m *WebRTCManager) Start() {
hid.Display(m.config.Display)
videoPipeline, err := gst.CreatePipeline(
m.config.VideoCodec,
fmt.Sprintf("ximagesrc xid=%s show-pointer=true use-damage=false ! video/x-raw,framerate=30/1 ! videoconvert ! queue", m.config.Display),
)
if err != nil {
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
}
audioPipeline, err := gst.CreatePipeline(
m.config.AudioCodec,
fmt.Sprintf("pulsesrc device=%s ! audioconvert", m.config.Device),
)
if err != nil {
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
}
m.videoPipeline = videoPipeline
m.audioPipeline = audioPipeline
videoPipeline.Start()
audioPipeline.Start()
go func() {
defer func() {
m.logger.Info().Msg("shutdown")
}()
for {
select {
case <-m.shutdown:
return
case sample := <-videoPipeline.Sample:
if err := m.sessions.WriteVideoSample(sample); err != nil {
m.logger.Warn().Err(err).Msg("video pipeline failed")
}
case sample := <-audioPipeline.Sample:
if err := m.sessions.WriteAudioSample(sample); err != nil {
m.logger.Warn().Err(err).Msg("audio pipeline failed")
}
case <-m.cleanup.C:
hid.Check(time.Second * 10)
}
}
}()
m.sessions.OnHostCleared(func(id string) {
hid.Reset()
})
m.sessions.OnCreated(func(id string, session types.Session) {
m.logger.Debug().Str("id", id).Msg("session created")
})
m.sessions.OnDestroy(func(id string) {
m.logger.Debug().Str("id", id).Msg("session destroyed")
})
// TODO: log resolution, bit rate and codec parameters
m.logger.Info().
Str("video_display", m.config.Display).
Str("video_codec", m.config.VideoCodec).
Str("audio_device", m.config.Device).
Str("audio_codec", m.config.AudioCodec).
Msgf("webrtc streaming")
}
func (m *WebRTCManager) Shutdown() error {
m.logger.Info().Msgf("webrtc shutting down")
m.videoPipeline.Stop()
m.audioPipeline.Stop()
m.cleanup.Stop()
m.shutdown <- true
return nil
}
func (m *WebRTCManager) CreatePeer(id string, sdp string) (string, types.Peer, error) {
// create SessionDescription
description := webrtc.SessionDescription{
SDP: sdp,
Type: webrtc.SDPTypeOffer,
}
// create MediaEngine based off sdp
engine := webrtc.MediaEngine{}
engine.PopulateFromSDP(description)
// create API with MediaEngine and SettingEngine
api := webrtc.NewAPI(webrtc.WithMediaEngine(engine), webrtc.WithSettingEngine(m.setings))
// create new peer connection
connection, err := api.NewPeerConnection(*m.configuration)
if err != nil {
return "", nil, err
}
// create video track
video, err := m.createVideoTrack(engine)
if err != nil {
return "", nil, err
}
videoTransceiver, err := connection.AddTransceiverFromTrack(video, webrtc.RtpTransceiverInit{
Direction: webrtc.RTPTransceiverDirectionSendonly,
})
if err != nil {
return "", nil, err
}
// create audio track
audio, err := m.createAudioTrack(engine)
if err != nil {
return "", nil, err
}
audioTransceiver, err := connection.AddTransceiverFromTrack(audio, webrtc.RtpTransceiverInit{
Direction: webrtc.RTPTransceiverDirectionSendonly,
})
if err != nil {
return "", nil, err
}
// clear the Transceiver bufers
go func() {
for {
if _, err := audioTransceiver.Sender.ReadRTCP(); err != nil {
return
}
if _, err = videoTransceiver.Sender.ReadRTCP(); err != nil {
return
}
}
}()
// set remote description
connection.SetRemoteDescription(description)
answer, err := connection.CreateAnswer(nil)
if err != nil {
return "", nil, err
}
if err = connection.SetLocalDescription(answer); err != nil {
return "", nil, err
}
connection.OnDataChannel(func(d *webrtc.DataChannel) {
d.OnMessage(func(msg webrtc.DataChannelMessage) {
if err = m.handle(id, msg); err != nil {
m.logger.Warn().Err(err).Msg("data handle failed")
}
})
})
connection.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
switch state {
case webrtc.PeerConnectionStateDisconnected:
case webrtc.PeerConnectionStateFailed:
m.logger.Info().Str("id", id).Msg("peer disconnected")
m.sessions.Destroy(id)
break
case webrtc.PeerConnectionStateConnected:
m.logger.Info().Str("id", id).Msg("peer connected")
break
}
})
return answer.SDP, &Peer{
id: id,
api: api,
engine: engine,
video: video,
audio: audio,
connection: connection,
}, nil
}

View File

@ -3,13 +3,13 @@ package websocket
import ( import (
"strings" "strings"
"n.eko.moe/neko/internal/event" "n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/message" "n.eko.moe/neko/internal/types/event"
"n.eko.moe/neko/internal/session" "n.eko.moe/neko/internal/types/message"
) )
func (h *MessageHandler) adminLock(id string, session *session.Session) error { func (h *MessageHandler) adminLock(id string, session types.Session) error {
if !session.Admin { if !session.Admin() {
h.logger.Debug().Msg("user not admin") h.logger.Debug().Msg("user not admin")
return nil return nil
} }
@ -33,8 +33,8 @@ func (h *MessageHandler) adminLock(id string, session *session.Session) error {
return nil return nil
} }
func (h *MessageHandler) adminUnlock(id string, session *session.Session) error { func (h *MessageHandler) adminUnlock(id string, session types.Session) error {
if !session.Admin { if !session.Admin() {
h.logger.Debug().Msg("user not admin") h.logger.Debug().Msg("user not admin")
return nil return nil
} }
@ -58,8 +58,8 @@ func (h *MessageHandler) adminUnlock(id string, session *session.Session) error
return nil return nil
} }
func (h *MessageHandler) adminControl(id string, session *session.Session) error { func (h *MessageHandler) adminControl(id string, session types.Session) error {
if !session.Admin { if !session.Admin() {
h.logger.Debug().Msg("user not admin") h.logger.Debug().Msg("user not admin")
return nil return nil
} }
@ -73,7 +73,7 @@ func (h *MessageHandler) adminControl(id string, session *session.Session) error
message.AdminTarget{ message.AdminTarget{
Event: event.ADMIN_CONTROL, Event: event.ADMIN_CONTROL,
ID: id, ID: id,
Target: host.ID, Target: host.ID(),
}, nil); err != nil { }, nil); err != nil {
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_CONTROL) h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_CONTROL)
return err return err
@ -92,8 +92,8 @@ func (h *MessageHandler) adminControl(id string, session *session.Session) error
return nil return nil
} }
func (h *MessageHandler) adminRelease(id string, session *session.Session) error { func (h *MessageHandler) adminRelease(id string, session types.Session) error {
if !session.Admin { if !session.Admin() {
h.logger.Debug().Msg("user not admin") h.logger.Debug().Msg("user not admin")
return nil return nil
} }
@ -107,7 +107,7 @@ func (h *MessageHandler) adminRelease(id string, session *session.Session) error
message.AdminTarget{ message.AdminTarget{
Event: event.ADMIN_RELEASE, Event: event.ADMIN_RELEASE,
ID: id, ID: id,
Target: host.ID, Target: host.ID(),
}, nil); err != nil { }, nil); err != nil {
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_RELEASE) h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_RELEASE)
return err return err
@ -126,8 +126,8 @@ func (h *MessageHandler) adminRelease(id string, session *session.Session) error
return nil return nil
} }
func (h *MessageHandler) adminGive(id string, session *session.Session, payload *message.Admin) error { func (h *MessageHandler) adminGive(id string, session types.Session, payload *message.Admin) error {
if !session.Admin { if !session.Admin() {
h.logger.Debug().Msg("user not admin") h.logger.Debug().Msg("user not admin")
return nil return nil
} }
@ -154,8 +154,8 @@ func (h *MessageHandler) adminGive(id string, session *session.Session, payload
return nil return nil
} }
func (h *MessageHandler) adminMute(id string, session *session.Session, payload *message.Admin) error { func (h *MessageHandler) adminMute(id string, session types.Session, payload *message.Admin) error {
if !session.Admin { if !session.Admin() {
h.logger.Debug().Msg("user not admin") h.logger.Debug().Msg("user not admin")
return nil return nil
} }
@ -166,17 +166,17 @@ func (h *MessageHandler) adminMute(id string, session *session.Session, payload
return nil return nil
} }
if target.Admin { if target.Admin() {
h.logger.Debug().Msg("target is an admin, baling") h.logger.Debug().Msg("target is an admin, baling")
return nil return nil
} }
target.Muted = true target.SetMuted(true)
if err := h.sessions.Brodcast( if err := h.sessions.Brodcast(
message.AdminTarget{ message.AdminTarget{
Event: event.ADMIN_MUTE, Event: event.ADMIN_MUTE,
Target: target.ID, Target: target.ID(),
ID: id, ID: id,
}, nil); err != nil { }, nil); err != nil {
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_UNMUTE) h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_UNMUTE)
@ -186,8 +186,8 @@ func (h *MessageHandler) adminMute(id string, session *session.Session, payload
return nil return nil
} }
func (h *MessageHandler) adminUnmute(id string, session *session.Session, payload *message.Admin) error { func (h *MessageHandler) adminUnmute(id string, session types.Session, payload *message.Admin) error {
if !session.Admin { if !session.Admin() {
h.logger.Debug().Msg("user not admin") h.logger.Debug().Msg("user not admin")
return nil return nil
} }
@ -198,12 +198,12 @@ func (h *MessageHandler) adminUnmute(id string, session *session.Session, payloa
return nil return nil
} }
target.Muted = false target.SetMuted(false)
if err := h.sessions.Brodcast( if err := h.sessions.Brodcast(
message.AdminTarget{ message.AdminTarget{
Event: event.ADMIN_UNMUTE, Event: event.ADMIN_UNMUTE,
Target: target.ID, Target: target.ID(),
ID: id, ID: id,
}, nil); err != nil { }, nil); err != nil {
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_UNMUTE) h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_UNMUTE)
@ -213,8 +213,8 @@ func (h *MessageHandler) adminUnmute(id string, session *session.Session, payloa
return nil return nil
} }
func (h *MessageHandler) adminKick(id string, session *session.Session, payload *message.Admin) error { func (h *MessageHandler) adminKick(id string, session types.Session, payload *message.Admin) error {
if !session.Admin { if !session.Admin() {
h.logger.Debug().Msg("user not admin") h.logger.Debug().Msg("user not admin")
return nil return nil
} }
@ -225,22 +225,19 @@ func (h *MessageHandler) adminKick(id string, session *session.Session, payload
return nil return nil
} }
if target.Admin { if target.Admin() {
h.logger.Debug().Msg("target is an admin, baling") h.logger.Debug().Msg("target is an admin, baling")
return nil return nil
} }
if err := target.Kick(message.Disconnect{ if err := target.Kick("You have been kicked"); err != nil {
Event: event.SYSTEM_DISCONNECT,
Message: "You have been kicked",
}); err != nil {
return err return err
} }
if err := h.sessions.Brodcast( if err := h.sessions.Brodcast(
message.AdminTarget{ message.AdminTarget{
Event: event.ADMIN_KICK, Event: event.ADMIN_KICK,
Target: target.ID, Target: target.ID(),
ID: id, ID: id,
}, []string{payload.ID}); err != nil { }, []string{payload.ID}); err != nil {
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_KICK) h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_KICK)
@ -250,8 +247,8 @@ func (h *MessageHandler) adminKick(id string, session *session.Session, payload
return nil return nil
} }
func (h *MessageHandler) adminBan(id string, session *session.Session, payload *message.Admin) error { func (h *MessageHandler) adminBan(id string, session types.Session, payload *message.Admin) error {
if !session.Admin { if !session.Admin() {
h.logger.Debug().Msg("user not admin") h.logger.Debug().Msg("user not admin")
return nil return nil
} }
@ -262,12 +259,12 @@ func (h *MessageHandler) adminBan(id string, session *session.Session, payload *
return nil return nil
} }
if target.Admin { if target.Admin() {
h.logger.Debug().Msg("target is an admin, baling") h.logger.Debug().Msg("target is an admin, baling")
return nil return nil
} }
remote := target.RemoteAddr() remote := target.Address()
if remote == nil { if remote == nil {
h.logger.Debug().Msg("no remote address, baling") h.logger.Debug().Msg("no remote address, baling")
return nil return nil
@ -283,17 +280,14 @@ func (h *MessageHandler) adminBan(id string, session *session.Session, payload *
h.banned[address[0]] = true h.banned[address[0]] = true
if err := target.Kick(message.Disconnect{ if err := target.Kick("You have been banned"); err != nil {
Event: event.SYSTEM_DISCONNECT,
Message: "You have been banned",
}); err != nil {
return err return err
} }
if err := h.sessions.Brodcast( if err := h.sessions.Brodcast(
message.AdminTarget{ message.AdminTarget{
Event: event.ADMIN_BAN, Event: event.ADMIN_BAN,
Target: target.ID, Target: target.ID(),
ID: id, ID: id,
}, []string{payload.ID}); err != nil { }, []string{payload.ID}); err != nil {
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_BAN) h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_BAN)

View File

@ -1,13 +1,13 @@
package websocket package websocket
import ( import (
"n.eko.moe/neko/internal/event" "n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/message" "n.eko.moe/neko/internal/types/event"
"n.eko.moe/neko/internal/session" "n.eko.moe/neko/internal/types/message"
) )
func (h *MessageHandler) chat(id string, session *session.Session, payload *message.ChatRecieve) error { func (h *MessageHandler) chat(id string, session types.Session, payload *message.ChatRecieve) error {
if session.Muted { if session.Muted() {
return nil return nil
} }
@ -23,8 +23,8 @@ func (h *MessageHandler) chat(id string, session *session.Session, payload *mess
return nil return nil
} }
func (h *MessageHandler) chatEmote(id string, session *session.Session, payload *message.EmoteRecieve) error { func (h *MessageHandler) chatEmote(id string, session types.Session, payload *message.EmoteRecieve) error {
if session.Muted { if session.Muted() {
return nil return nil
} }

View File

@ -1,12 +1,12 @@
package websocket package websocket
import ( import (
"n.eko.moe/neko/internal/event" "n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/message" "n.eko.moe/neko/internal/types/event"
"n.eko.moe/neko/internal/session" "n.eko.moe/neko/internal/types/message"
) )
func (h *MessageHandler) controlRelease(id string, session *session.Session) error { func (h *MessageHandler) controlRelease(id string, session types.Session) error {
// check if session is host // check if session is host
if !h.sessions.IsHost(id) { if !h.sessions.IsHost(id) {
@ -31,7 +31,7 @@ func (h *MessageHandler) controlRelease(id string, session *session.Session) err
return nil return nil
} }
func (h *MessageHandler) controlRequest(id string, session *session.Session) error { func (h *MessageHandler) controlRequest(id string, session types.Session) error {
// check for host // check for host
if !h.sessions.HasHost() { if !h.sessions.HasHost() {
// set host // set host
@ -57,7 +57,7 @@ func (h *MessageHandler) controlRequest(id string, session *session.Session) err
// tell session there is a host // tell session there is a host
if err := session.Send(message.Control{ if err := session.Send(message.Control{
Event: event.CONTROL_REQUEST, Event: event.CONTROL_REQUEST,
ID: host.ID, ID: host.ID(),
}); err != nil { }); err != nil {
h.logger.Warn().Err(err).Str("id", id).Msgf("sending event %s has failed", event.CONTROL_REQUEST) h.logger.Warn().Err(err).Str("id", id).Msgf("sending event %s has failed", event.CONTROL_REQUEST)
return err return err
@ -68,7 +68,7 @@ func (h *MessageHandler) controlRequest(id string, session *session.Session) err
Event: event.CONTROL_REQUESTING, Event: event.CONTROL_REQUESTING,
ID: id, ID: id,
}); err != nil { }); err != nil {
h.logger.Warn().Err(err).Str("id", host.ID).Msgf("sending event %s has failed", event.CONTROL_REQUESTING) h.logger.Warn().Err(err).Str("id", host.ID()).Msgf("sending event %s has failed", event.CONTROL_REQUESTING)
return err return err
} }
} }
@ -76,7 +76,7 @@ func (h *MessageHandler) controlRequest(id string, session *session.Session) err
return nil return nil
} }
func (h *MessageHandler) controlGive(id string, session *session.Session, payload *message.Control) error { func (h *MessageHandler) controlGive(id string, session types.Session, payload *message.Control) error {
// check if session is host // check if session is host
if !h.sessions.IsHost(id) { if !h.sessions.IsHost(id) {
h.logger.Debug().Str("id", id).Msg("is not the host") h.logger.Debug().Str("id", id).Msg("is not the host")

View File

@ -1,235 +1,142 @@
package websocket package websocket
import ( import (
"fmt" "encoding/json"
"net/http"
"time"
"github.com/gorilla/websocket" "github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"n.eko.moe/neko/internal/config" "n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/event" "n.eko.moe/neko/internal/types/event"
"n.eko.moe/neko/internal/message" "n.eko.moe/neko/internal/types/message"
"n.eko.moe/neko/internal/session"
"n.eko.moe/neko/internal/utils" "n.eko.moe/neko/internal/utils"
"n.eko.moe/neko/internal/webrtc"
) )
func New(sessions *session.SessionManager, webrtc *webrtc.WebRTCManager, conf *config.WebSocket) *WebSocketHandler { type MessageHandler struct {
logger := log.With().Str("module", "websocket").Logger()
return &WebSocketHandler{
logger: logger,
conf: conf,
sessions: sessions,
upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
},
handler: &MessageHandler{
logger: logger.With().Str("subsystem", "handler").Logger(),
sessions: sessions,
webrtc: webrtc,
banned: make(map[string]bool),
locked: false,
},
}
}
// Send pings to peer with this period. Must be less than pongWait.
const pingPeriod = 60 * time.Second
type WebSocketHandler struct {
logger zerolog.Logger logger zerolog.Logger
upgrader websocket.Upgrader sessions types.SessionManager
handler *MessageHandler webrtc types.WebRTCManager
conf *config.WebSocket banned map[string]bool
sessions *session.SessionManager locked bool
shutdown chan bool
} }
func (ws *WebSocketHandler) Start() error { func (h *MessageHandler) Connected(id string, socket *WebSocket) (bool, string, error) {
address := socket.Address()
go func() { if address == nil {
defer func() { h.logger.Debug().Msg("no remote address, baling")
ws.logger.Info().Msg("shutdown")
}()
for {
select {
case <-ws.shutdown:
return
}
}
}()
ws.sessions.OnCreated(func(id string, session *session.Session) {
if err := ws.handler.SessionCreated(id, session); err != nil {
ws.logger.Warn().Str("id", id).Err(err).Msg("session created with and error")
} else { } else {
ws.logger.Debug().Str("id", id).Msg("session created") ok, banned := h.banned[*address]
if ok && banned {
h.logger.Debug().Str("address", *address).Msg("banned")
return false, "This IP has been banned", nil
} }
})
ws.sessions.OnConnected(func(id string, session *session.Session) {
if err := ws.handler.SessionConnected(id, session); err != nil {
ws.logger.Warn().Str("id", id).Err(err).Msg("session connected with and error")
} else {
ws.logger.Debug().Str("id", id).Msg("session connected")
} }
})
ws.sessions.OnDestroy(func(id string) { if h.locked {
if err := ws.handler.SessionDestroyed(id); err != nil { h.logger.Debug().Msg("server locked")
ws.logger.Warn().Str("id", id).Err(err).Msg("session destroyed with and error") return false, "Server is currently locked", nil
} else {
ws.logger.Debug().Str("id", id).Msg("session destroyed")
} }
})
return nil return true, "", nil
} }
func (ws *WebSocketHandler) Shutdown() error { func (h *MessageHandler) Disconnected(id string) error {
ws.shutdown <- true return h.sessions.Destroy(id)
return nil
} }
func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) error { func (h *MessageHandler) Message(id string, raw []byte) error {
ws.logger.Debug().Msg("attempting to upgrade connection") header := message.Message{}
if err := json.Unmarshal(raw, &header); err != nil {
socket, err := ws.upgrader.Upgrade(w, r, nil)
if err != nil {
ws.logger.Error().Err(err).Msg("failed to upgrade connection")
return err
}
id, admin, err := ws.authenticate(r)
if err != nil {
ws.logger.Warn().Err(err).Msg("authenticatetion failed")
if err = socket.WriteJSON(message.Disconnect{
Event: event.SYSTEM_DISCONNECT,
Message: "invalid password",
}); err != nil {
ws.logger.Error().Err(err).Msg("failed to send disconnect")
}
if err = socket.Close(); err != nil {
return err
}
return nil
}
ok, reason, err := ws.handler.SocketConnected(id, socket)
if err != nil {
ws.logger.Error().Err(err).Msg("connection failed")
return err return err
} }
session, ok := h.sessions.Get(id)
if !ok { if !ok {
if err = socket.WriteJSON(message.Disconnect{ errors.Errorf("unknown session id %s", id)
Event: event.SYSTEM_DISCONNECT,
Message: reason,
}); err != nil {
ws.logger.Error().Err(err).Msg("failed to send disconnect")
} }
if err = socket.Close(); err != nil { switch header.Event {
return err // Signal Events
} case event.SIGNAL_PROVIDE:
payload := &message.Signal{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.createPeer(id, session, payload)
}), "%s failed", header.Event)
// Identity Events
case event.IDENTITY_DETAILS:
payload := &message.IdentityDetails{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.identityDetails(id, session, payload)
}), "%s failed", header.Event)
return nil // Control Events
} case event.CONTROL_RELEASE:
return errors.Wrapf(h.controlRelease(id, session), "%s failed", header.Event)
case event.CONTROL_REQUEST:
return errors.Wrapf(h.controlRequest(id, session), "%s failed", header.Event)
case event.CONTROL_GIVE:
payload := &message.Control{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.controlGive(id, session, payload)
}), "%s failed", header.Event)
ws.sessions.New(id, admin, socket) // Chat Events
case event.CHAT_MESSAGE:
payload := &message.ChatRecieve{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.chat(id, session, payload)
}), "%s failed", header.Event)
case event.CHAT_EMOTE:
payload := &message.EmoteRecieve{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.chatEmote(id, session, payload)
}), "%s failed", header.Event)
ws.logger. // Admin Events
Debug(). case event.ADMIN_LOCK:
Str("session", id). return errors.Wrapf(h.adminLock(id, session), "%s failed", header.Event)
Str("address", socket.RemoteAddr().String()). case event.ADMIN_UNLOCK:
Msg("new connection created") return errors.Wrapf(h.adminUnlock(id, session), "%s failed", header.Event)
case event.ADMIN_CONTROL:
defer func() { return errors.Wrapf(h.adminControl(id, session), "%s failed", header.Event)
ws.logger. case event.ADMIN_RELEASE:
Debug(). return errors.Wrapf(h.adminRelease(id, session), "%s failed", header.Event)
Str("session", id). case event.ADMIN_GIVE:
Str("address", socket.RemoteAddr().String()). payload := &message.Admin{}
Msg("session ended") return errors.Wrapf(
}() utils.Unmarshal(payload, raw, func() error {
return h.adminGive(id, session, payload)
ws.handle(socket, id) }), "%s failed", header.Event)
return nil case event.ADMIN_BAN:
} payload := &message.Admin{}
return errors.Wrapf(
func (ws *WebSocketHandler) authenticate(r *http.Request) (string, bool, error) { utils.Unmarshal(payload, raw, func() error {
id, err := utils.NewUID(32) return h.adminBan(id, session, payload)
if err != nil { }), "%s failed", header.Event)
return "", false, err case event.ADMIN_KICK:
} payload := &message.Admin{}
return errors.Wrapf(
passwords, ok := r.URL.Query()["password"] utils.Unmarshal(payload, raw, func() error {
if !ok || len(passwords[0]) < 1 { return h.adminKick(id, session, payload)
return "", false, fmt.Errorf("no password provided") }), "%s failed", header.Event)
} case event.ADMIN_MUTE:
payload := &message.Admin{}
if passwords[0] == ws.conf.AdminPassword { return errors.Wrapf(
return id, true, nil utils.Unmarshal(payload, raw, func() error {
} return h.adminMute(id, session, payload)
}), "%s failed", header.Event)
if passwords[0] == ws.conf.Password { case event.ADMIN_UNMUTE:
return id, false, nil payload := &message.Admin{}
} return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return "", false, fmt.Errorf("invalid password: %s", passwords[0]) return h.adminUnmute(id, session, payload)
} }), "%s failed", header.Event)
default:
func (ws *WebSocketHandler) handle(socket *websocket.Conn, id string) { return errors.Errorf("unknown message event %s", header.Event)
bytes := make(chan []byte)
cancel := make(chan struct{})
ticker := time.NewTicker(pingPeriod)
go func() {
defer func() {
ticker.Stop()
ws.logger.Debug().Str("address", socket.RemoteAddr().String()).Msg("handle socket ending")
ws.handler.SocketDisconnected(id)
}()
for {
_, raw, err := socket.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
ws.logger.Warn().Err(err).Msg("read message error")
} else {
ws.logger.Debug().Err(err).Msg("read message error")
}
close(cancel)
break
}
bytes <- raw
}
}()
for {
select {
case raw := <-bytes:
ws.logger.Debug().
Str("session", id).
Str("raw", string(raw)).
Msg("recieved message from client")
if err := ws.handler.Message(id, raw); err != nil {
ws.logger.Error().Err(err).Msg("message handler has failed")
}
case <-cancel:
return
case _ = <-ticker.C:
if err := socket.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
}
} }
} }

View File

@ -1,13 +1,34 @@
package websocket package websocket
import ( import (
"n.eko.moe/neko/internal/message" "n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/session" "n.eko.moe/neko/internal/types/event"
"n.eko.moe/neko/internal/types/message"
) )
func (h *MessageHandler) identityDetails(id string, session *session.Session, payload *message.IdentityDetails) error { func (h *MessageHandler) identityDetails(id string, session types.Session, payload *message.IdentityDetails) error {
if _, err := h.sessions.SetName(id, payload.Username); err != nil { if err := session.SetName(payload.Username); err != nil {
return err return err
} }
return nil return nil
} }
func (h *MessageHandler) createPeer(id string, session types.Session, payload *message.Signal) error {
sdp, peer, err := h.webrtc.CreatePeer(id, payload.SDP)
if err != nil {
return err
}
if err := session.SetPeer(peer); err != nil {
return err
}
if err := session.Send(message.Signal{
Event: event.SIGNAL_ANSWER,
SDP: sdp,
}); err != nil {
return err
}
return nil
}

View File

@ -1,149 +0,0 @@
package websocket
import (
"encoding/json"
"strings"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"n.eko.moe/neko/internal/event"
"n.eko.moe/neko/internal/message"
"n.eko.moe/neko/internal/session"
"n.eko.moe/neko/internal/utils"
"n.eko.moe/neko/internal/webrtc"
)
type MessageHandler struct {
logger zerolog.Logger
sessions *session.SessionManager
webrtc *webrtc.WebRTCManager
banned map[string]bool
locked bool
}
func (h *MessageHandler) SocketConnected(id string, socket *websocket.Conn) (bool, string, error) {
remote := socket.RemoteAddr().String()
if remote != "" {
address := strings.SplitN(remote, ":", -1)
if len(address[0]) < 1 {
h.logger.Debug().Str("address", remote).Msg("no remote address, baling")
} else {
ok, banned := h.banned[address[0]]
if ok && banned {
h.logger.Debug().Str("address", remote).Msg("banned")
return false, "This IP has been banned", nil
}
}
}
if h.locked {
h.logger.Debug().Str("address", remote).Msg("locked")
return false, "Server is currently locked", nil
}
return true, "", nil
}
func (h *MessageHandler) SocketDisconnected(id string) error {
return h.sessions.Destroy(id)
}
func (h *MessageHandler) Message(id string, raw []byte) error {
header := message.Message{}
if err := json.Unmarshal(raw, &header); err != nil {
return err
}
session, ok := h.sessions.Get(id)
if !ok {
errors.Errorf("unknown session id %s", id)
}
switch header.Event {
// Signal Events
case event.SIGNAL_PROVIDE:
payload := message.Signal{}
return errors.Wrapf(
utils.Unmarshal(&payload, raw, func() error {
return h.webrtc.CreatePeer(id, payload.SDP)
}), "%s failed", header.Event)
// Identity Events
case event.IDENTITY_DETAILS:
payload := &message.IdentityDetails{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.identityDetails(id, session, payload)
}), "%s failed", header.Event)
// Control Events
case event.CONTROL_RELEASE:
return errors.Wrapf(h.controlRelease(id, session), "%s failed", header.Event)
case event.CONTROL_REQUEST:
return errors.Wrapf(h.controlRequest(id, session), "%s failed", header.Event)
case event.CONTROL_GIVE:
payload := &message.Control{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.controlGive(id, session, payload)
}), "%s failed", header.Event)
// Chat Events
case event.CHAT_MESSAGE:
payload := &message.ChatRecieve{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.chat(id, session, payload)
}), "%s failed", header.Event)
case event.CHAT_EMOTE:
payload := &message.EmoteRecieve{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.chatEmote(id, session, payload)
}), "%s failed", header.Event)
// Admin Events
case event.ADMIN_LOCK:
return errors.Wrapf(h.adminLock(id, session), "%s failed", header.Event)
case event.ADMIN_UNLOCK:
return errors.Wrapf(h.adminUnlock(id, session), "%s failed", header.Event)
case event.ADMIN_CONTROL:
return errors.Wrapf(h.adminControl(id, session), "%s failed", header.Event)
case event.ADMIN_RELEASE:
return errors.Wrapf(h.adminRelease(id, session), "%s failed", header.Event)
case event.ADMIN_GIVE:
payload := &message.Admin{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.adminGive(id, session, payload)
}), "%s failed", header.Event)
case event.ADMIN_BAN:
payload := &message.Admin{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.adminBan(id, session, payload)
}), "%s failed", header.Event)
case event.ADMIN_KICK:
payload := &message.Admin{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.adminKick(id, session, payload)
}), "%s failed", header.Event)
case event.ADMIN_MUTE:
payload := &message.Admin{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.adminMute(id, session, payload)
}), "%s failed", header.Event)
case event.ADMIN_UNMUTE:
payload := &message.Admin{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.adminUnmute(id, session, payload)
}), "%s failed", header.Event)
default:
return errors.Errorf("unknown message event %s", header.Event)
}
}

View File

@ -1,12 +1,12 @@
package websocket package websocket
import ( import (
"n.eko.moe/neko/internal/event" "n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/message" "n.eko.moe/neko/internal/types/event"
"n.eko.moe/neko/internal/session" "n.eko.moe/neko/internal/types/message"
) )
func (h *MessageHandler) SessionCreated(id string, session *session.Session) error { func (h *MessageHandler) SessionCreated(id string, session types.Session) error {
if err := session.Send(message.Identity{ if err := session.Send(message.Identity{
Event: event.IDENTITY_PROVIDE, Event: event.IDENTITY_PROVIDE,
ID: id, ID: id,
@ -17,11 +17,11 @@ func (h *MessageHandler) SessionCreated(id string, session *session.Session) err
return nil return nil
} }
func (h *MessageHandler) SessionConnected(id string, session *session.Session) error { func (h *MessageHandler) SessionConnected(id string, session types.Session) error {
// send list of members to session // send list of members to session
if err := session.Send(message.MembersList{ if err := session.Send(message.MembersList{
Event: event.MEMBER_LIST, Event: event.MEMBER_LIST,
Memebers: h.sessions.GetConnected(), Memebers: h.sessions.Members(),
}); err != nil { }); err != nil {
h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.MEMBER_LIST) h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.MEMBER_LIST)
return err return err
@ -32,7 +32,7 @@ func (h *MessageHandler) SessionConnected(id string, session *session.Session) e
if ok { if ok {
if err := session.Send(message.Control{ if err := session.Send(message.Control{
Event: event.CONTROL_LOCKED, Event: event.CONTROL_LOCKED,
ID: host.ID, ID: host.ID(),
}); err != nil { }); err != nil {
h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.CONTROL_LOCKED) h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.CONTROL_LOCKED)
return err return err
@ -43,7 +43,7 @@ func (h *MessageHandler) SessionConnected(id string, session *session.Session) e
if err := h.sessions.Brodcast( if err := h.sessions.Brodcast(
message.Member{ message.Member{
Event: event.MEMBER_CONNECTED, Event: event.MEMBER_CONNECTED,
Session: session, Member: session.Member(),
}, nil); err != nil { }, nil); err != nil {
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.CONTROL_RELEASE) h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.CONTROL_RELEASE)
return err return err

View File

@ -0,0 +1,37 @@
package websocket
import (
"strings"
"github.com/gorilla/websocket"
)
type WebSocket struct {
id string
connection *websocket.Conn
}
func (socket *WebSocket) Address() *string {
remote := socket.connection.RemoteAddr()
address := strings.SplitN(remote.String(), ":", -1)
if len(address[0]) < 1 {
return nil
}
return &address[0]
}
func (socket *WebSocket) Send(v interface{}) error {
if socket.connection == nil {
return nil
}
return socket.connection.WriteJSON(v)
}
func (socket *WebSocket) Destroy() error {
if socket.connection == nil {
return nil
}
return socket.connection.Close()
}

View File

@ -0,0 +1,239 @@
package websocket
import (
"fmt"
"net/http"
"time"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"n.eko.moe/neko/internal/config"
"n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/types/event"
"n.eko.moe/neko/internal/types/message"
"n.eko.moe/neko/internal/utils"
)
func New(sessions types.SessionManager, webrtc types.WebRTCManager, conf *config.WebSocket) *WebSocketHandler {
logger := log.With().Str("module", "websocket").Logger()
return &WebSocketHandler{
logger: logger,
conf: conf,
sessions: sessions,
upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
},
handler: &MessageHandler{
logger: logger.With().Str("subsystem", "handler").Logger(),
sessions: sessions,
webrtc: webrtc,
banned: make(map[string]bool),
locked: false,
},
}
}
// Send pings to peer with this period. Must be less than pongWait.
const pingPeriod = 60 * time.Second
type WebSocketHandler struct {
logger zerolog.Logger
upgrader websocket.Upgrader
sessions types.SessionManager
conf *config.WebSocket
handler *MessageHandler
shutdown chan bool
}
func (ws *WebSocketHandler) Start() error {
go func() {
defer func() {
ws.logger.Info().Msg("shutdown")
}()
for {
select {
case <-ws.shutdown:
return
}
}
}()
ws.sessions.OnCreated(func(id string, session types.Session) {
if err := ws.handler.SessionCreated(id, session); err != nil {
ws.logger.Warn().Str("id", id).Err(err).Msg("session created with and error")
} else {
ws.logger.Debug().Str("id", id).Msg("session created")
}
})
ws.sessions.OnConnected(func(id string, session types.Session) {
if err := ws.handler.SessionConnected(id, session); err != nil {
ws.logger.Warn().Str("id", id).Err(err).Msg("session connected with and error")
} else {
ws.logger.Debug().Str("id", id).Msg("session connected")
}
})
ws.sessions.OnDestroy(func(id string) {
if err := ws.handler.SessionDestroyed(id); err != nil {
ws.logger.Warn().Str("id", id).Err(err).Msg("session destroyed with and error")
} else {
ws.logger.Debug().Str("id", id).Msg("session destroyed")
}
})
return nil
}
func (ws *WebSocketHandler) Shutdown() error {
ws.shutdown <- true
return nil
}
func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) error {
ws.logger.Debug().Msg("attempting to upgrade connection")
connection, err := ws.upgrader.Upgrade(w, r, nil)
if err != nil {
ws.logger.Error().Err(err).Msg("failed to upgrade connection")
return err
}
id, admin, err := ws.authenticate(r)
if err != nil {
ws.logger.Warn().Err(err).Msg("authenticatetion failed")
if err = connection.WriteJSON(message.Disconnect{
Event: event.SYSTEM_DISCONNECT,
Message: "invalid password",
}); err != nil {
ws.logger.Error().Err(err).Msg("failed to send disconnect")
}
if err = connection.Close(); err != nil {
return err
}
return nil
}
socket := &WebSocket{
id: id,
connection: connection,
}
ok, reason, err := ws.handler.Connected(id, socket)
if err != nil {
ws.logger.Error().Err(err).Msg("connection failed")
return err
}
if !ok {
if err = connection.WriteJSON(message.Disconnect{
Event: event.SYSTEM_DISCONNECT,
Message: reason,
}); err != nil {
ws.logger.Error().Err(err).Msg("failed to send disconnect")
}
if err = connection.Close(); err != nil {
return err
}
return nil
}
ws.sessions.New(id, admin, socket)
ws.logger.
Debug().
Str("session", id).
Str("address", connection.RemoteAddr().String()).
Msg("new connection created")
defer func() {
ws.logger.
Debug().
Str("session", id).
Str("address", connection.RemoteAddr().String()).
Msg("session ended")
}()
ws.handle(connection, id)
return nil
}
func (ws *WebSocketHandler) authenticate(r *http.Request) (string, bool, error) {
id, err := utils.NewUID(32)
if err != nil {
return "", false, err
}
passwords, ok := r.URL.Query()["password"]
if !ok || len(passwords[0]) < 1 {
return "", false, fmt.Errorf("no password provided")
}
if passwords[0] == ws.conf.AdminPassword {
return id, true, nil
}
if passwords[0] == ws.conf.Password {
return id, false, nil
}
return "", false, fmt.Errorf("invalid password: %s", passwords[0])
}
func (ws *WebSocketHandler) handle(connection *websocket.Conn, id string) {
bytes := make(chan []byte)
cancel := make(chan struct{})
ticker := time.NewTicker(pingPeriod)
go func() {
defer func() {
ticker.Stop()
ws.logger.Debug().Str("address", connection.RemoteAddr().String()).Msg("handle socket ending")
ws.handler.Disconnected(id)
}()
for {
_, raw, err := connection.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
ws.logger.Warn().Err(err).Msg("read message error")
} else {
ws.logger.Debug().Err(err).Msg("read message error")
}
close(cancel)
break
}
bytes <- raw
}
}()
for {
select {
case raw := <-bytes:
ws.logger.Debug().
Str("session", id).
Str("raw", string(raw)).
Msg("recieved message from client")
if err := ws.handler.Message(id, raw); err != nil {
ws.logger.Error().Err(err).Msg("message handler has failed")
}
case <-cancel:
return
case _ = <-ticker.C:
if err := connection.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
}
}
}