fixes #14
This commit is contained in:
parent
a0866a4ab9
commit
e3a73aa264
@ -54,7 +54,7 @@ NEKO_DISPLAY=0 // Display number
|
||||
NEKO_WIDTH=1280 // Display width
|
||||
NEKO_HEIGHT=720 // Display height
|
||||
NEKO_PASSWORD=neko // Password
|
||||
NEKO_ADMIN=neko // Admin Password
|
||||
NEKO_ADMIN=neko // Admin Password
|
||||
NEKO_BIND=0.0.0.0:8080 // Bind
|
||||
NEKO_KEY= // (SSL)Key
|
||||
NEKO_CERT= // (SSL)Cert
|
||||
@ -66,3 +66,4 @@ NEKO_CERT= // (SSL)Cert
|
||||
|
||||
### Non Goals
|
||||
* Turning n.eko into a service that serves multiple rooms and browsers/desktops.
|
||||
* Voice chat, use [Discord](https://discordapp.com/))
|
@ -34,7 +34,7 @@ export abstract class BaseClient extends EventEmitter<BaseEvents> {
|
||||
}
|
||||
|
||||
get peerConnected() {
|
||||
return typeof this._peer !== 'undefined' && this._state === 'connected'
|
||||
return typeof this._peer !== 'undefined' && ['connected', 'checking', 'completed'].includes(this._state)
|
||||
}
|
||||
|
||||
get connected() {
|
||||
@ -60,7 +60,7 @@ export abstract class BaseClient extends EventEmitter<BaseEvents> {
|
||||
this._ws.onmessage = this.onMessage.bind(this)
|
||||
this._ws.onerror = event => this.onError.bind(this)
|
||||
this._ws.onclose = event => this.onDisconnected.bind(this, new Error('websocket closed'))
|
||||
this._timeout = setTimeout(this.onTimeout.bind(this), 5000)
|
||||
this._timeout = setTimeout(this.onTimeout.bind(this), 15000)
|
||||
} catch (err) {
|
||||
this.onDisconnected(err)
|
||||
}
|
||||
|
@ -127,6 +127,7 @@ github.com/pion/transport v0.8.10 h1:lTiobMEw2PG6BH/mgIVqTV2mBp/mPT+IJLaN8ZxgdHk
|
||||
github.com/pion/transport v0.8.10/go.mod h1:tBmha/UCjpum5hqTWhfAEs3CO4/tHSg0MYRhSzR+CZ8=
|
||||
github.com/pion/turn v1.4.0 h1:7NUMRehQz4fIo53Qv9ui1kJ0Kr1CA82I81RHKHCeM80=
|
||||
github.com/pion/turn v1.4.0/go.mod h1:aDSi6hWX/hd1+gKia9cExZOR0MU95O7zX9p3Gw/P2aU=
|
||||
github.com/pion/webrtc v1.2.0 h1:3LGGPQEMacwG2hcDfhdvwQPz315gvjZXOfY4vaF4+I4=
|
||||
github.com/pion/webrtc/v2 v2.1.18 h1:g0VN0xfEUSlVNfQmlCD6yOeXy/tMaktESBmHMnBS3bk=
|
||||
github.com/pion/webrtc/v2 v2.1.18/go.mod h1:m0rKlYgLRZWyhmcMWegpF6xtK1ASxmOg8DAR74ttzQY=
|
||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
|
@ -1,8 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/pion/webrtc/v2"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
@ -22,13 +21,8 @@ func (WebRTC) Init(cmd *cobra.Command) error {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("ac", "opus", "Audio codec to use for streaming")
|
||||
if err := viper.BindPFlag("acodec", cmd.PersistentFlags().Lookup("ac")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("ap", "", "Audio codec parameters to use for streaming")
|
||||
if err := viper.BindPFlag("aparams", cmd.PersistentFlags().Lookup("ap")); err != nil {
|
||||
cmd.PersistentFlags().String("aduio", "", "Audio codec parameters to use for streaming")
|
||||
if err := viper.BindPFlag("aparams", cmd.PersistentFlags().Lookup("aduio")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -37,13 +31,45 @@ func (WebRTC) Init(cmd *cobra.Command) error {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("vc", "vp8", "Video codec to use for streaming")
|
||||
if err := viper.BindPFlag("vcodec", cmd.PersistentFlags().Lookup("vc")); err != nil {
|
||||
cmd.PersistentFlags().String("video", "", "Video codec parameters to use for streaming")
|
||||
if err := viper.BindPFlag("vparams", cmd.PersistentFlags().Lookup("video")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("vp", "", "Video codec parameters to use for streaming")
|
||||
if err := viper.BindPFlag("vparams", cmd.PersistentFlags().Lookup("vp")); err != nil {
|
||||
// video codecs
|
||||
cmd.PersistentFlags().Bool("vp8", false, "Use VP8 codec")
|
||||
if err := viper.BindPFlag("vp8", cmd.PersistentFlags().Lookup("vp8")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("vp9", false, "Use VP9 codec")
|
||||
if err := viper.BindPFlag("vp9", cmd.PersistentFlags().Lookup("vp9")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("h264", false, "Use H264 codec")
|
||||
if err := viper.BindPFlag("h264", cmd.PersistentFlags().Lookup("h264")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// audio codecs
|
||||
cmd.PersistentFlags().Bool("opus", false, "Use Opus codec")
|
||||
if err := viper.BindPFlag("opus", cmd.PersistentFlags().Lookup("opus")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("g722", false, "Use G722 codec")
|
||||
if err := viper.BindPFlag("g722", cmd.PersistentFlags().Lookup("g722")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("pcmu", false, "Use PCMU codec")
|
||||
if err := viper.BindPFlag("pcmu", cmd.PersistentFlags().Lookup("pcmu")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("pcma", false, "Use PCMA codec")
|
||||
if err := viper.BindPFlag("pcmu", cmd.PersistentFlags().Lookup("pcmu")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -51,10 +77,30 @@ func (WebRTC) Init(cmd *cobra.Command) error {
|
||||
}
|
||||
|
||||
func (s *WebRTC) Set() {
|
||||
s.Device = strings.ToLower(viper.GetString("device"))
|
||||
s.AudioCodec = strings.ToLower(viper.GetString("acodec"))
|
||||
s.AudioParams = strings.ToLower(viper.GetString("aparams"))
|
||||
s.Display = strings.ToLower(viper.GetString("display"))
|
||||
s.VideoCodec = strings.ToLower(viper.GetString("vcodec"))
|
||||
s.VideoParams = strings.ToLower(viper.GetString("vparams"))
|
||||
videoCodec := webrtc.VP8
|
||||
if viper.GetBool("vp8") {
|
||||
videoCodec = webrtc.VP8
|
||||
} else if viper.GetBool("vp9") {
|
||||
videoCodec = webrtc.VP9
|
||||
} else if viper.GetBool("h264") {
|
||||
videoCodec = webrtc.H264
|
||||
}
|
||||
|
||||
audioCodec := webrtc.VP8
|
||||
if viper.GetBool("opus") {
|
||||
audioCodec = webrtc.Opus
|
||||
} else if viper.GetBool("g722") {
|
||||
audioCodec = webrtc.G722
|
||||
} else if viper.GetBool("pcmu") {
|
||||
audioCodec = webrtc.PCMU
|
||||
} else if viper.GetBool("pcma") {
|
||||
audioCodec = webrtc.PCMA
|
||||
}
|
||||
|
||||
s.Device = viper.GetString("device")
|
||||
s.AudioCodec = audioCodec
|
||||
s.AudioParams = viper.GetString("aparams")
|
||||
s.Display = viper.GetString("display")
|
||||
s.VideoCodec = videoCodec
|
||||
s.VideoParams = viper.GetString("vparams")
|
||||
}
|
||||
|
@ -9,12 +9,13 @@ package gst
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/pion/webrtc/v2"
|
||||
"github.com/pion/webrtc/v2/pkg/media"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"n.eko.moe/neko/internal/types"
|
||||
)
|
||||
|
||||
/*
|
||||
@ -35,10 +36,10 @@ import (
|
||||
// Pipeline is a wrapper for a GStreamer Pipeline
|
||||
type Pipeline struct {
|
||||
Pipeline *C.GstElement
|
||||
tracks []*webrtc.Track
|
||||
Sample chan types.Sample
|
||||
CodecName string
|
||||
ClockRate float32
|
||||
id int
|
||||
codecName string
|
||||
clockRate float32
|
||||
}
|
||||
|
||||
var pipelines = make(map[int]*Pipeline)
|
||||
@ -57,7 +58,7 @@ func init() {
|
||||
}
|
||||
|
||||
// CreatePipeline creates a GStreamer Pipeline
|
||||
func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string) *Pipeline {
|
||||
func CreatePipeline(codecName string, pipelineSrc string) (*Pipeline, error) {
|
||||
pipelineStr := "appsink name=appsink"
|
||||
var clockRate float32
|
||||
|
||||
@ -70,7 +71,7 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
|
||||
clockRate = videoClockRate
|
||||
|
||||
if err := CheckPlugins([]string{"ximagesrc", "vpx"}); err != nil {
|
||||
panic(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
case webrtc.VP9:
|
||||
@ -83,7 +84,7 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
|
||||
clockRate = videoClockRate
|
||||
|
||||
if err := CheckPlugins([]string{"ximagesrc", "vpx"}); err != nil {
|
||||
panic(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
case webrtc.H264:
|
||||
@ -98,14 +99,14 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
|
||||
clockRate = videoClockRate
|
||||
|
||||
if err := CheckPlugins([]string{"ximagesrc"}); err != nil {
|
||||
panic(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := CheckPlugins([]string{"openh264"}); err != nil {
|
||||
pipelineStr = pipelineSrc + " ! video/x-raw,format=I420 ! x264enc bframes=0 key-int-max=60 byte-stream=true tune=zerolatency speed-preset=veryfast ! video/x-h264,stream-format=byte-stream ! " + pipelineStr
|
||||
|
||||
if err := CheckPlugins([]string{"x264"}); err != nil {
|
||||
panic(err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@ -117,7 +118,7 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
|
||||
clockRate = audioClockRate
|
||||
|
||||
if err := CheckPlugins([]string{"pulseaudio", "opus"}); err != nil {
|
||||
panic(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
case webrtc.G722:
|
||||
@ -128,7 +129,7 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
|
||||
clockRate = audioClockRate
|
||||
|
||||
if err := CheckPlugins([]string{"pulseaudio", "libav"}); err != nil {
|
||||
panic(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
case webrtc.PCMU:
|
||||
@ -140,7 +141,7 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
|
||||
clockRate = pcmClockRate
|
||||
|
||||
if err := CheckPlugins([]string{"pulseaudio", "mulaw"}); err != nil {
|
||||
panic(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
case webrtc.PCMA:
|
||||
@ -151,11 +152,11 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
|
||||
clockRate = pcmClockRate
|
||||
|
||||
if err := CheckPlugins([]string{"pulseaudio", "alaw"}); err != nil {
|
||||
panic(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
default:
|
||||
panic("Unhandled codec " + codecName)
|
||||
return nil, errors.Errorf("unknown video codec %s", codecName)
|
||||
}
|
||||
|
||||
pipelineStrUnsafe := C.CString(pipelineStr)
|
||||
@ -166,14 +167,14 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
|
||||
|
||||
pipeline := &Pipeline{
|
||||
Pipeline: C.gstreamer_send_create_pipeline(pipelineStrUnsafe),
|
||||
tracks: tracks,
|
||||
Sample: make(chan types.Sample),
|
||||
CodecName: codecName,
|
||||
ClockRate: clockRate,
|
||||
id: len(pipelines),
|
||||
codecName: codecName,
|
||||
clockRate: clockRate,
|
||||
}
|
||||
|
||||
pipelines[pipeline.id] = pipeline
|
||||
return pipeline
|
||||
return pipeline, nil
|
||||
}
|
||||
|
||||
// Start starts the GStreamer Pipeline
|
||||
@ -193,14 +194,13 @@ func CheckPlugins(plugins []string) error {
|
||||
plugin = C.gst_registry_find_plugin(registry, plugincstr)
|
||||
C.free(unsafe.Pointer(plugincstr))
|
||||
if plugin == nil {
|
||||
return fmt.Errorf("Required gstreamer plugin %s not found", pluginstr)
|
||||
return fmt.Errorf("required gstreamer plugin %s not found", pluginstr)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
//export goHandlePipelineBuffer
|
||||
func goHandlePipelineBuffer(buffer unsafe.Pointer, bufferLen C.int, duration C.int, pipelineID C.int) {
|
||||
pipelinesLock.Lock()
|
||||
@ -208,12 +208,8 @@ func goHandlePipelineBuffer(buffer unsafe.Pointer, bufferLen C.int, duration C.i
|
||||
pipelinesLock.Unlock()
|
||||
|
||||
if ok {
|
||||
samples := uint32(pipeline.clockRate * (float32(duration) / 1000000000))
|
||||
for _, t := range pipeline.tracks {
|
||||
if err := t.WriteSample(media.Sample{Data: C.GoBytes(buffer, bufferLen), Samples: samples}); err != nil && err != io.ErrClosedPipe {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
samples := uint32(pipeline.ClockRate * (float32(duration) / 1000000000))
|
||||
pipeline.Sample <- types.Sample{Data: C.GoBytes(buffer, bufferLen), Samples: samples}
|
||||
} else {
|
||||
fmt.Printf("discarding buffer, no pipeline with id %d", int(pipelineID))
|
||||
}
|
||||
|
@ -3,15 +3,17 @@ package session
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/kataras/go-events"
|
||||
"github.com/pion/webrtc/v2"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"n.eko.moe/neko/internal/types"
|
||||
"n.eko.moe/neko/internal/utils"
|
||||
)
|
||||
|
||||
func New() *SessionManager {
|
||||
return &SessionManager{
|
||||
logger: log.With().Str("module", "session").Logger(),
|
||||
host: "",
|
||||
members: make(map[string]*Session),
|
||||
emmiter: events.New(),
|
||||
@ -19,153 +21,100 @@ func New() *SessionManager {
|
||||
}
|
||||
|
||||
type SessionManager struct {
|
||||
logger zerolog.Logger
|
||||
host string
|
||||
members map[string]*Session
|
||||
emmiter events.EventEmmiter
|
||||
}
|
||||
|
||||
func (m *SessionManager) New(id string, admin bool, socket *websocket.Conn) *Session {
|
||||
func (manager *SessionManager) New(id string, admin bool, socket types.WebScoket) types.Session {
|
||||
session := &Session{
|
||||
ID: id,
|
||||
Admin: admin,
|
||||
id: id,
|
||||
admin: admin,
|
||||
manager: manager,
|
||||
socket: socket,
|
||||
logger: manager.logger.With().Str("id", id).Logger(),
|
||||
connected: false,
|
||||
}
|
||||
|
||||
m.members[id] = session
|
||||
m.emmiter.Emit("created", id, session)
|
||||
manager.members[id] = session
|
||||
manager.emmiter.Emit("created", id, session)
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
func (m *SessionManager) IsHost(id string) bool {
|
||||
return m.host == id
|
||||
func (manager *SessionManager) HasHost() bool {
|
||||
return manager.host != ""
|
||||
}
|
||||
|
||||
func (m *SessionManager) HasHost() bool {
|
||||
return m.host != ""
|
||||
func (manager *SessionManager) IsHost(id string) bool {
|
||||
return manager.host == id
|
||||
}
|
||||
|
||||
func (m *SessionManager) SetHost(id string) error {
|
||||
_, ok := m.members[id]
|
||||
func (manager *SessionManager) SetHost(id string) error {
|
||||
_, ok := manager.members[id]
|
||||
if ok {
|
||||
m.host = id
|
||||
m.emmiter.Emit("host", id)
|
||||
manager.host = id
|
||||
manager.emmiter.Emit("host", id)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("invalid session id %s", id)
|
||||
}
|
||||
|
||||
func (m *SessionManager) GetHost() (*Session, bool) {
|
||||
host, ok := m.members[m.host]
|
||||
func (manager *SessionManager) GetHost() (types.Session, bool) {
|
||||
host, ok := manager.members[manager.host]
|
||||
return host, ok
|
||||
}
|
||||
|
||||
func (m *SessionManager) ClearHost() {
|
||||
id := m.host
|
||||
m.host = ""
|
||||
m.emmiter.Emit("host_cleared", id)
|
||||
func (manager *SessionManager) ClearHost() {
|
||||
id := manager.host
|
||||
manager.host = ""
|
||||
manager.emmiter.Emit("host_cleared", id)
|
||||
}
|
||||
|
||||
func (m *SessionManager) Has(id string) bool {
|
||||
_, ok := m.members[id]
|
||||
func (manager *SessionManager) Has(id string) bool {
|
||||
_, ok := manager.members[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (m *SessionManager) Get(id string) (*Session, bool) {
|
||||
session, ok := m.members[id]
|
||||
func (manager *SessionManager) Get(id string) (types.Session, bool) {
|
||||
session, ok := manager.members[id]
|
||||
return session, ok
|
||||
}
|
||||
|
||||
func (m *SessionManager) GetConnected() []*Session {
|
||||
var sessions []*Session
|
||||
for _, sess := range m.members {
|
||||
if sess.connected {
|
||||
sessions = append(sessions, sess)
|
||||
func (manager *SessionManager) Members() []*types.Member {
|
||||
members := []*types.Member{}
|
||||
for _, session := range manager.members {
|
||||
if !session.connected {
|
||||
continue
|
||||
}
|
||||
|
||||
member := session.Member()
|
||||
if member != nil {
|
||||
members = append(members, member)
|
||||
}
|
||||
}
|
||||
|
||||
return sessions
|
||||
return members
|
||||
}
|
||||
|
||||
func (m *SessionManager) Set(id string, session *Session) {
|
||||
m.members[id] = session
|
||||
}
|
||||
|
||||
func (m *SessionManager) Destroy(id string) error {
|
||||
session, ok := m.members[id]
|
||||
func (manager *SessionManager) Destroy(id string) error {
|
||||
session, ok := manager.members[id]
|
||||
if ok {
|
||||
err := session.destroy()
|
||||
delete(m.members, id)
|
||||
m.emmiter.Emit("destroyed", id)
|
||||
delete(manager.members, id)
|
||||
manager.emmiter.Emit("destroyed", id)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) SetSocket(id string, socket *websocket.Conn) (bool, error) {
|
||||
session, ok := m.members[id]
|
||||
if ok {
|
||||
session.socket = socket
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("invalid session id %s", id)
|
||||
}
|
||||
|
||||
func (m *SessionManager) SetPeer(id string, peer *webrtc.PeerConnection) (bool, error) {
|
||||
session, ok := m.members[id]
|
||||
if ok {
|
||||
session.peer = peer
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("invalid session id %s", id)
|
||||
}
|
||||
|
||||
func (m *SessionManager) SetName(id string, name string) (bool, error) {
|
||||
session, ok := m.members[id]
|
||||
if ok {
|
||||
session.Name = name
|
||||
session.connected = true
|
||||
m.emmiter.Emit("connected", id, session)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("invalid session id %s", id)
|
||||
}
|
||||
|
||||
func (m *SessionManager) Mute(id string) error {
|
||||
session, ok := m.members[id]
|
||||
if ok {
|
||||
session.Muted = true
|
||||
}
|
||||
func (manager *SessionManager) Clear() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) Unmute(id string) error {
|
||||
session, ok := m.members[id]
|
||||
if ok {
|
||||
session.Muted = false
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) Kick(id string, v interface{}) error {
|
||||
session, ok := m.members[id]
|
||||
if ok {
|
||||
return session.Kick(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) Clear() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) Brodcast(v interface{}, exclude interface{}) error {
|
||||
for id, sess := range m.members {
|
||||
if !sess.connected {
|
||||
func (manager *SessionManager) Brodcast(v interface{}, exclude interface{}) error {
|
||||
for id, session := range manager.members {
|
||||
if !session.connected {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -175,39 +124,65 @@ func (m *SessionManager) Brodcast(v interface{}, exclude interface{}) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := sess.Send(v); err != nil {
|
||||
if err := session.Send(v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) OnHost(listener func(id string)) {
|
||||
m.emmiter.On("host", func(payload ...interface{}) {
|
||||
func (manager *SessionManager) WriteVideoSample(sample types.Sample) error {
|
||||
for _, session := range manager.members {
|
||||
if !session.connected {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := session.WriteVideoSample(sample); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *SessionManager) WriteAudioSample(sample types.Sample) error {
|
||||
for _, session := range manager.members {
|
||||
if !session.connected {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := session.WriteAudioSample(sample); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *SessionManager) OnHost(listener func(id string)) {
|
||||
manager.emmiter.On("host", func(payload ...interface{}) {
|
||||
listener(payload[0].(string))
|
||||
})
|
||||
}
|
||||
|
||||
func (m *SessionManager) OnHostCleared(listener func(id string)) {
|
||||
m.emmiter.On("host_cleared", func(payload ...interface{}) {
|
||||
func (manager *SessionManager) OnHostCleared(listener func(id string)) {
|
||||
manager.emmiter.On("host_cleared", func(payload ...interface{}) {
|
||||
listener(payload[0].(string))
|
||||
})
|
||||
}
|
||||
|
||||
func (m *SessionManager) OnCreated(listener func(id string, session *Session)) {
|
||||
m.emmiter.On("created", func(payload ...interface{}) {
|
||||
func (manager *SessionManager) OnDestroy(listener func(id string)) {
|
||||
manager.emmiter.On("destroyed", func(payload ...interface{}) {
|
||||
listener(payload[0].(string))
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *SessionManager) OnCreated(listener func(id string, session types.Session)) {
|
||||
manager.emmiter.On("created", func(payload ...interface{}) {
|
||||
listener(payload[0].(string), payload[1].(*Session))
|
||||
})
|
||||
}
|
||||
|
||||
func (m *SessionManager) OnConnected(listener func(id string, session *Session)) {
|
||||
m.emmiter.On("connected", func(payload ...interface{}) {
|
||||
func (manager *SessionManager) OnConnected(listener func(id string, session types.Session)) {
|
||||
manager.emmiter.On("connected", func(payload ...interface{}) {
|
||||
listener(payload[0].(string), payload[1].(*Session))
|
||||
})
|
||||
}
|
||||
|
||||
func (m *SessionManager) OnDestroy(listener func(id string)) {
|
||||
m.emmiter.On("destroyed", func(payload ...interface{}) {
|
||||
listener(payload[0].(string))
|
||||
})
|
||||
}
|
||||
|
@ -3,38 +3,90 @@ package session
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pion/webrtc/v2"
|
||||
"github.com/rs/zerolog"
|
||||
"n.eko.moe/neko/internal/types"
|
||||
"n.eko.moe/neko/internal/types/event"
|
||||
"n.eko.moe/neko/internal/types/message"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"username"`
|
||||
Admin bool `json:"admin"`
|
||||
Muted bool `json:"muted"`
|
||||
logger zerolog.Logger
|
||||
id string
|
||||
name string
|
||||
admin bool
|
||||
muted bool
|
||||
connected bool
|
||||
socket *websocket.Conn
|
||||
peer *webrtc.PeerConnection
|
||||
manager *SessionManager
|
||||
socket types.WebScoket
|
||||
peer types.Peer
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (session *Session) RemoteAddr() *string {
|
||||
if session.socket != nil {
|
||||
address := session.socket.RemoteAddr().String()
|
||||
return &address
|
||||
func (session *Session) ID() string {
|
||||
return session.id
|
||||
}
|
||||
|
||||
func (session *Session) Name() string {
|
||||
return session.name
|
||||
}
|
||||
|
||||
func (session *Session) Admin() bool {
|
||||
return session.admin
|
||||
}
|
||||
|
||||
func (session *Session) Muted() bool {
|
||||
return session.muted
|
||||
}
|
||||
|
||||
func (session *Session) Connected() bool {
|
||||
return session.connected
|
||||
}
|
||||
|
||||
func (session *Session) Address() *string {
|
||||
if session.socket == nil {
|
||||
return nil
|
||||
}
|
||||
return session.socket.Address()
|
||||
}
|
||||
|
||||
func (session *Session) Member() *types.Member {
|
||||
return &types.Member{
|
||||
ID: session.id,
|
||||
Name: session.name,
|
||||
Admin: session.admin,
|
||||
Muted: session.muted,
|
||||
}
|
||||
}
|
||||
|
||||
func (session *Session) SetMuted(muted bool) {
|
||||
session.muted = muted
|
||||
}
|
||||
|
||||
func (session *Session) SetName(name string) error {
|
||||
session.name = name
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: write to peer data channel
|
||||
func (session *Session) Write(v interface{}) error {
|
||||
session.mu.Lock()
|
||||
defer session.mu.Unlock()
|
||||
func (session *Session) SetSocket(socket types.WebScoket) error {
|
||||
session.socket = socket
|
||||
return nil
|
||||
}
|
||||
|
||||
func (session *Session) Kick(v interface{}) error {
|
||||
if err := session.Send(v); err != nil {
|
||||
func (session *Session) SetPeer(peer types.Peer) error {
|
||||
session.peer = peer
|
||||
session.connected = true
|
||||
session.manager.emmiter.Emit("connected", session.id, session)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (session *Session) Kick(reason string) error {
|
||||
if session.socket == nil {
|
||||
return nil
|
||||
}
|
||||
if err := session.socket.Send(&message.Disconnect{
|
||||
Event: event.SYSTEM_DISCONNECT,
|
||||
Message: reason,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -42,27 +94,40 @@ func (session *Session) Kick(v interface{}) error {
|
||||
}
|
||||
|
||||
func (session *Session) Send(v interface{}) error {
|
||||
session.mu.Lock()
|
||||
defer session.mu.Unlock()
|
||||
|
||||
if session.socket != nil {
|
||||
return session.socket.WriteJSON(v)
|
||||
if session.socket == nil {
|
||||
return nil
|
||||
}
|
||||
return session.socket.Send(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
func (session *Session) Write(v interface{}) error {
|
||||
if session.socket == nil {
|
||||
return nil
|
||||
}
|
||||
return session.socket.Send(v)
|
||||
}
|
||||
|
||||
func (session *Session) WriteVideoSample(sample types.Sample) error {
|
||||
if session.peer == nil || !session.connected {
|
||||
return nil
|
||||
}
|
||||
return session.peer.WriteVideoSample(sample)
|
||||
}
|
||||
|
||||
func (session *Session) WriteAudioSample(sample types.Sample) error {
|
||||
if session.peer == nil || !session.connected {
|
||||
return nil
|
||||
}
|
||||
return session.peer.WriteAudioSample(sample)
|
||||
}
|
||||
|
||||
func (session *Session) destroy() error {
|
||||
if session.peer != nil && session.peer.ConnectionState() == webrtc.PeerConnectionStateConnected {
|
||||
if err := session.peer.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := session.socket.Destroy(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if session.socket != nil {
|
||||
if err := session.socket.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := session.peer.Destroy(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -1,6 +1,8 @@
|
||||
package message
|
||||
|
||||
import "n.eko.moe/neko/internal/session"
|
||||
import (
|
||||
"n.eko.moe/neko/internal/types"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Event string `json:"event"`
|
||||
@ -27,13 +29,13 @@ type Signal struct {
|
||||
}
|
||||
|
||||
type MembersList struct {
|
||||
Event string `json:"event"`
|
||||
Memebers []*session.Session `json:"members"`
|
||||
Event string `json:"event"`
|
||||
Memebers []*types.Member `json:"members"`
|
||||
}
|
||||
|
||||
type Member struct {
|
||||
Event string `json:"event"`
|
||||
*session.Session
|
||||
*types.Member
|
||||
}
|
||||
type MemberDisconnected struct {
|
||||
Event string `json:"event"`
|
49
server/internal/types/session.go
Normal file
49
server/internal/types/session.go
Normal file
@ -0,0 +1,49 @@
|
||||
package types
|
||||
|
||||
type Member struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"username"`
|
||||
Admin bool `json:"admin"`
|
||||
Muted bool `json:"muted"`
|
||||
}
|
||||
|
||||
type Session interface {
|
||||
ID() string
|
||||
Name() string
|
||||
Admin() bool
|
||||
Muted() bool
|
||||
Connected() bool
|
||||
Member() *Member
|
||||
SetMuted(muted bool)
|
||||
SetName(name string) error
|
||||
SetSocket(socket WebScoket) error
|
||||
SetPeer(peer Peer) error
|
||||
Address() *string
|
||||
Kick(message string) error
|
||||
Write(v interface{}) error
|
||||
Send(v interface{}) error
|
||||
WriteAudioSample(sample Sample) error
|
||||
WriteVideoSample(sample Sample) error
|
||||
}
|
||||
|
||||
type SessionManager interface {
|
||||
New(id string, admin bool, socket WebScoket) Session
|
||||
HasHost() bool
|
||||
IsHost(id string) bool
|
||||
SetHost(id string) error
|
||||
GetHost() (Session, bool)
|
||||
ClearHost()
|
||||
Has(id string) bool
|
||||
Get(id string) (Session, bool)
|
||||
Members() []*Member
|
||||
Destroy(id string) error
|
||||
Clear() error
|
||||
Brodcast(v interface{}, exclude interface{}) error
|
||||
WriteAudioSample(sample Sample) error
|
||||
WriteVideoSample(sample Sample) error
|
||||
OnHost(listener func(id string))
|
||||
OnHostCleared(listener func(id string))
|
||||
OnDestroy(listener func(id string))
|
||||
OnCreated(listener func(id string, session Session))
|
||||
OnConnected(listener func(id string, session Session))
|
||||
}
|
19
server/internal/types/webrtc.go
Normal file
19
server/internal/types/webrtc.go
Normal file
@ -0,0 +1,19 @@
|
||||
package types
|
||||
|
||||
type Sample struct {
|
||||
Data []byte
|
||||
Samples uint32
|
||||
}
|
||||
|
||||
type WebRTCManager interface {
|
||||
Start()
|
||||
Shutdown() error
|
||||
CreatePeer(id string, sdp string) (string, Peer, error)
|
||||
}
|
||||
|
||||
type Peer interface {
|
||||
WriteVideoSample(sample Sample) error
|
||||
WriteAudioSample(sample Sample) error
|
||||
WriteData(v interface{}) error
|
||||
Destroy() error
|
||||
}
|
7
server/internal/types/webscoket.go
Normal file
7
server/internal/types/webscoket.go
Normal file
@ -0,0 +1,7 @@
|
||||
package types
|
||||
|
||||
type WebScoket interface {
|
||||
Address() *string
|
||||
Send(v interface{}) error
|
||||
Destroy() error
|
||||
}
|
@ -1,6 +1,9 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pion/logging"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
@ -13,8 +16,19 @@ func (l logger) Trace(msg string) { l.logger.Trace().Ms
|
||||
func (l logger) Tracef(format string, args ...interface{}) { l.logger.Trace().Msgf(format, args...) }
|
||||
func (l logger) Debug(msg string) { l.logger.Debug().Msg(msg) }
|
||||
func (l logger) Debugf(format string, args ...interface{}) { l.logger.Debug().Msgf(format, args...) }
|
||||
func (l logger) Info(msg string) { l.logger.Info().Msg(msg) }
|
||||
func (l logger) Infof(format string, args ...interface{}) { l.logger.Info().Msgf(format, args...) }
|
||||
func (l logger) Info(msg string) {
|
||||
if strings.Contains(msg, "packetio.Buffer is full") {
|
||||
l.logger.Panic().Msg(msg)
|
||||
}
|
||||
l.logger.Info().Msg(msg)
|
||||
}
|
||||
func (l logger) Infof(format string, args ...interface{}) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
if strings.Contains(msg, "packetio.Buffer is full") {
|
||||
l.logger.Panic().Msg(msg)
|
||||
}
|
||||
l.logger.Info().Msgf(format, args...)
|
||||
}
|
||||
func (l logger) Warn(msg string) { l.logger.Warn().Msg(msg) }
|
||||
func (l logger) Warnf(format string, args ...interface{}) { l.logger.Warn().Msgf(format, args...) }
|
||||
func (l logger) Error(msg string) { l.logger.Error().Msg(msg) }
|
||||
|
@ -1,226 +0,0 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/pion/webrtc/v2"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"n.eko.moe/neko/internal/config"
|
||||
"n.eko.moe/neko/internal/event"
|
||||
"n.eko.moe/neko/internal/gst"
|
||||
"n.eko.moe/neko/internal/hid"
|
||||
"n.eko.moe/neko/internal/message"
|
||||
"n.eko.moe/neko/internal/session"
|
||||
)
|
||||
|
||||
func New(sessions *session.SessionManager, conf *config.WebRTC) *WebRTCManager {
|
||||
logger := log.With().Str("module", "webrtc").Logger()
|
||||
engine := webrtc.MediaEngine{}
|
||||
engine.RegisterDefaultCodecs()
|
||||
|
||||
setings := webrtc.SettingEngine{
|
||||
LoggerFactory: loggerFactory{
|
||||
logger: logger,
|
||||
},
|
||||
}
|
||||
|
||||
return &WebRTCManager{
|
||||
logger: logger,
|
||||
engine: engine,
|
||||
setings: setings,
|
||||
api: webrtc.NewAPI(webrtc.WithMediaEngine(engine), webrtc.WithSettingEngine(setings)),
|
||||
cleanup: time.NewTicker(1 * time.Second),
|
||||
shutdown: make(chan bool),
|
||||
sessions: sessions,
|
||||
conf: conf,
|
||||
config: webrtc.Configuration{
|
||||
ICEServers: []webrtc.ICEServer{
|
||||
{
|
||||
URLs: []string{"stun:stun.l.google.com:19302"},
|
||||
},
|
||||
},
|
||||
SDPSemantics: webrtc.SDPSemanticsUnifiedPlanWithFallback,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type WebRTCManager struct {
|
||||
logger zerolog.Logger
|
||||
engine webrtc.MediaEngine
|
||||
setings webrtc.SettingEngine
|
||||
config webrtc.Configuration
|
||||
sessions *session.SessionManager
|
||||
api *webrtc.API
|
||||
video *webrtc.Track
|
||||
audio *webrtc.Track
|
||||
videoPipeline *gst.Pipeline
|
||||
audioPipeline *gst.Pipeline
|
||||
cleanup *time.Ticker
|
||||
conf *config.WebRTC
|
||||
shutdown chan bool
|
||||
}
|
||||
|
||||
func (m *WebRTCManager) Start() {
|
||||
|
||||
hid.Display(m.conf.Display)
|
||||
|
||||
switch m.conf.VideoCodec {
|
||||
case "vp8":
|
||||
if err := m.createVideoTrack(webrtc.DefaultPayloadTypeVP8); err != nil {
|
||||
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
|
||||
}
|
||||
case "vp9":
|
||||
if err := m.createVideoTrack(webrtc.DefaultPayloadTypeVP9); err != nil {
|
||||
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
|
||||
}
|
||||
case "h264":
|
||||
if err := m.createVideoTrack(webrtc.DefaultPayloadTypeH264); err != nil {
|
||||
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
|
||||
}
|
||||
default:
|
||||
m.logger.Panic().Err(errors.Errorf("unknown video codec %s", m.conf.AudioCodec)).Msg("unable to start webrtc manager")
|
||||
}
|
||||
|
||||
switch m.conf.AudioCodec {
|
||||
case "opus":
|
||||
if err := m.createAudioTrack(webrtc.DefaultPayloadTypeOpus); err != nil {
|
||||
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
|
||||
}
|
||||
case "g722":
|
||||
if err := m.createAudioTrack(webrtc.DefaultPayloadTypeG722); err != nil {
|
||||
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
|
||||
}
|
||||
case "pcmu":
|
||||
if err := m.createAudioTrack(webrtc.DefaultPayloadTypePCMU); err != nil {
|
||||
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
|
||||
}
|
||||
case "pcma":
|
||||
if err := m.createAudioTrack(webrtc.DefaultPayloadTypePCMA); err != nil {
|
||||
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
|
||||
}
|
||||
default:
|
||||
m.logger.Panic().Err(errors.Errorf("unknown audio codec %s", m.conf.AudioCodec)).Msg("unable to start webrtc manager")
|
||||
}
|
||||
|
||||
m.videoPipeline.Start()
|
||||
m.audioPipeline.Start()
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
m.logger.Info().Msg("shutdown")
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.shutdown:
|
||||
return
|
||||
case <-m.cleanup.C:
|
||||
hid.Check(time.Second * 10)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
m.sessions.OnHostCleared(func(id string) {
|
||||
hid.Reset()
|
||||
})
|
||||
|
||||
m.sessions.OnCreated(func(id string, session *session.Session) {
|
||||
m.logger.Debug().Str("id", id).Msg("session created")
|
||||
})
|
||||
|
||||
m.sessions.OnDestroy(func(id string) {
|
||||
m.logger.Debug().Str("id", id).Msg("session destroyed")
|
||||
})
|
||||
|
||||
// TODO: log resolution, bit rate and codec parameters
|
||||
m.logger.Info().
|
||||
Str("video_display", m.conf.Display).
|
||||
Str("video_codec", m.conf.VideoCodec).
|
||||
Str("audio_device", m.conf.Device).
|
||||
Str("audio_codec", m.conf.AudioCodec).
|
||||
Msgf("webrtc streaming")
|
||||
}
|
||||
|
||||
func (m *WebRTCManager) Shutdown() error {
|
||||
m.logger.Info().Msgf("webrtc shutting down")
|
||||
|
||||
m.cleanup.Stop()
|
||||
m.shutdown <- true
|
||||
m.videoPipeline.Stop()
|
||||
m.audioPipeline.Stop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *WebRTCManager) CreatePeer(id string, sdp string) error {
|
||||
session, ok := m.sessions.Get(id)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid session id %s", id)
|
||||
}
|
||||
|
||||
peer, err := m.api.NewPeerConnection(m.config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := peer.AddTransceiverFromTrack(m.video, webrtc.RtpTransceiverInit{
|
||||
Direction: webrtc.RTPTransceiverDirectionSendonly,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := peer.AddTransceiverFromTrack(m.audio, webrtc.RtpTransceiverInit{
|
||||
Direction: webrtc.RTPTransceiverDirectionSendonly,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peer.SetRemoteDescription(webrtc.SessionDescription{
|
||||
SDP: sdp,
|
||||
Type: webrtc.SDPTypeOffer,
|
||||
})
|
||||
|
||||
answer, err := peer.CreateAnswer(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = peer.SetLocalDescription(answer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := session.Send(message.Signal{
|
||||
Event: event.SIGNAL_ANSWER,
|
||||
SDP: answer.SDP,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peer.OnDataChannel(func(d *webrtc.DataChannel) {
|
||||
d.OnMessage(func(msg webrtc.DataChannelMessage) {
|
||||
if err = m.handle(id, msg); err != nil {
|
||||
m.logger.Warn().Err(err).Msg("data handle failed")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
peer.OnConnectionStateChange(func(connectionState webrtc.PeerConnectionState) {
|
||||
switch connectionState {
|
||||
case webrtc.PeerConnectionStateDisconnected:
|
||||
case webrtc.PeerConnectionStateFailed:
|
||||
m.logger.Info().Str("id", id).Msg("peer disconnected")
|
||||
m.sessions.Destroy(id)
|
||||
break
|
||||
case webrtc.PeerConnectionStateConnected:
|
||||
m.logger.Info().Str("id", id).Msg("peer connected")
|
||||
break
|
||||
}
|
||||
})
|
||||
|
||||
m.sessions.SetPeer(id, peer)
|
||||
|
||||
return nil
|
||||
}
|
44
server/internal/webrtc/peer.go
Normal file
44
server/internal/webrtc/peer.go
Normal file
@ -0,0 +1,44 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"github.com/pion/webrtc/v2"
|
||||
"github.com/pion/webrtc/v2/pkg/media"
|
||||
"n.eko.moe/neko/internal/types"
|
||||
)
|
||||
|
||||
type Peer struct {
|
||||
id string
|
||||
engine webrtc.MediaEngine
|
||||
api *webrtc.API
|
||||
video *webrtc.Track
|
||||
audio *webrtc.Track
|
||||
connection *webrtc.PeerConnection
|
||||
}
|
||||
|
||||
func (peer *Peer) WriteAudioSample(sample types.Sample) error {
|
||||
if err := peer.audio.WriteSample(media.Sample(sample)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (peer *Peer) WriteVideoSample(sample types.Sample) error {
|
||||
if err := peer.video.WriteSample(media.Sample(sample)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (peer *Peer) WriteData(v interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (peer *Peer) Destroy() error {
|
||||
if peer.connection != nil && peer.connection.ConnectionState() == webrtc.PeerConnectionStateConnected {
|
||||
if err := peer.connection.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -5,95 +5,36 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/pion/webrtc/v2"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"n.eko.moe/neko/internal/gst"
|
||||
)
|
||||
|
||||
func (m *WebRTCManager) createVideoTrack(payloadType uint8) error {
|
||||
|
||||
clockrate := uint32(90000)
|
||||
func (m *WebRTCManager) createVideoTrack(engine webrtc.MediaEngine) (*webrtc.Track, error) {
|
||||
var codec *webrtc.RTPCodec
|
||||
switch payloadType {
|
||||
case webrtc.DefaultPayloadTypeVP8:
|
||||
codec = webrtc.NewRTPVP8Codec(payloadType, clockrate)
|
||||
break
|
||||
case webrtc.DefaultPayloadTypeVP9:
|
||||
codec = webrtc.NewRTPVP9Codec(payloadType, clockrate)
|
||||
break
|
||||
case webrtc.DefaultPayloadTypeH264:
|
||||
codec = webrtc.NewRTPH264Codec(payloadType, clockrate)
|
||||
break
|
||||
default:
|
||||
return errors.Errorf("unknown video codec %s", payloadType)
|
||||
for _, videoCodec := range engine.GetCodecsByKind(webrtc.RTPCodecTypeVideo) {
|
||||
if videoCodec.Name == m.videoPipeline.CodecName {
|
||||
codec = videoCodec
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
track, err := webrtc.NewTrack(payloadType, rand.Uint32(), "stream", "stream", codec)
|
||||
if err != nil {
|
||||
return err
|
||||
if codec == nil || codec.PayloadType == 0 {
|
||||
return nil, fmt.Errorf("remote peer does not support %s", m.videoPipeline.CodecName)
|
||||
}
|
||||
|
||||
var pipeline *gst.Pipeline
|
||||
src := fmt.Sprintf("ximagesrc xid=%s show-pointer=true use-damage=false ! video/x-raw,framerate=30/1 ! videoconvert ! queue", m.conf.Display)
|
||||
switch payloadType {
|
||||
case webrtc.DefaultPayloadTypeVP8:
|
||||
pipeline = gst.CreatePipeline(webrtc.VP8, []*webrtc.Track{track}, src)
|
||||
break
|
||||
case webrtc.DefaultPayloadTypeVP9:
|
||||
pipeline = gst.CreatePipeline(webrtc.VP9, []*webrtc.Track{track}, src)
|
||||
break
|
||||
case webrtc.DefaultPayloadTypeH264:
|
||||
pipeline = gst.CreatePipeline(webrtc.H264, []*webrtc.Track{track}, src)
|
||||
break
|
||||
}
|
||||
|
||||
m.video = track
|
||||
m.videoPipeline = pipeline
|
||||
return nil
|
||||
return webrtc.NewTrack(codec.PayloadType, rand.Uint32(), "stream", "stream", codec)
|
||||
}
|
||||
|
||||
func (m *WebRTCManager) createAudioTrack(payloadType uint8) error {
|
||||
func (m *WebRTCManager) createAudioTrack(engine webrtc.MediaEngine) (*webrtc.Track, error) {
|
||||
var codec *webrtc.RTPCodec
|
||||
switch payloadType {
|
||||
case webrtc.DefaultPayloadTypeOpus:
|
||||
codec = webrtc.NewRTPOpusCodec(payloadType, 48000)
|
||||
break
|
||||
case webrtc.DefaultPayloadTypeG722:
|
||||
codec = webrtc.NewRTPG722Codec(payloadType, 48000)
|
||||
break
|
||||
case webrtc.DefaultPayloadTypePCMU:
|
||||
codec = webrtc.NewRTPPCMUCodec(payloadType, 8000)
|
||||
break
|
||||
case webrtc.DefaultPayloadTypePCMA:
|
||||
codec = webrtc.NewRTPPCMACodec(payloadType, 8000)
|
||||
break
|
||||
default:
|
||||
return errors.Errorf("unknown audio codec %s", payloadType)
|
||||
for _, videoCodec := range engine.GetCodecsByKind(webrtc.RTPCodecTypeAudio) {
|
||||
if videoCodec.Name == m.audioPipeline.CodecName {
|
||||
codec = videoCodec
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
track, err := webrtc.NewTrack(payloadType, rand.Uint32(), "stream", "stream", codec)
|
||||
if err != nil {
|
||||
return err
|
||||
if codec == nil || codec.PayloadType == 0 {
|
||||
return nil, fmt.Errorf("remote peer does not support %s", m.audioPipeline.CodecName)
|
||||
}
|
||||
|
||||
var pipeline *gst.Pipeline
|
||||
src := fmt.Sprintf("pulsesrc device=%s ! audioconvert", m.conf.Device)
|
||||
switch payloadType {
|
||||
case webrtc.DefaultPayloadTypeOpus:
|
||||
pipeline = gst.CreatePipeline(webrtc.Opus, []*webrtc.Track{track}, src)
|
||||
break
|
||||
case webrtc.DefaultPayloadTypeG722:
|
||||
pipeline = gst.CreatePipeline(webrtc.G722, []*webrtc.Track{track}, src)
|
||||
break
|
||||
case webrtc.DefaultPayloadTypePCMU:
|
||||
pipeline = gst.CreatePipeline(webrtc.PCMU, []*webrtc.Track{track}, src)
|
||||
break
|
||||
case webrtc.DefaultPayloadTypePCMA:
|
||||
pipeline = gst.CreatePipeline(webrtc.PCMA, []*webrtc.Track{track}, src)
|
||||
break
|
||||
}
|
||||
|
||||
m.audio = track
|
||||
m.audioPipeline = pipeline
|
||||
return nil
|
||||
return webrtc.NewTrack(codec.PayloadType, rand.Uint32(), "stream", "stream", codec)
|
||||
}
|
||||
|
237
server/internal/webrtc/webrtc.go
Normal file
237
server/internal/webrtc/webrtc.go
Normal file
@ -0,0 +1,237 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/pion/webrtc/v2"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"n.eko.moe/neko/internal/config"
|
||||
"n.eko.moe/neko/internal/gst"
|
||||
"n.eko.moe/neko/internal/hid"
|
||||
"n.eko.moe/neko/internal/types"
|
||||
)
|
||||
|
||||
func New(sessions types.SessionManager, config *config.WebRTC) *WebRTCManager {
|
||||
logger := log.With().Str("module", "webrtc").Logger()
|
||||
setings := webrtc.SettingEngine{
|
||||
LoggerFactory: loggerFactory{
|
||||
logger: logger,
|
||||
},
|
||||
}
|
||||
|
||||
return &WebRTCManager{
|
||||
logger: logger,
|
||||
setings: setings,
|
||||
cleanup: time.NewTicker(1 * time.Second),
|
||||
shutdown: make(chan bool),
|
||||
sessions: sessions,
|
||||
config: config,
|
||||
configuration: &webrtc.Configuration{
|
||||
ICEServers: []webrtc.ICEServer{
|
||||
{
|
||||
URLs: []string{"stun:stun.l.google.com:19302"},
|
||||
},
|
||||
},
|
||||
SDPSemantics: webrtc.SDPSemanticsUnifiedPlanWithFallback,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type WebRTCManager struct {
|
||||
logger zerolog.Logger
|
||||
setings webrtc.SettingEngine
|
||||
sessions types.SessionManager
|
||||
videoPipeline *gst.Pipeline
|
||||
audioPipeline *gst.Pipeline
|
||||
cleanup *time.Ticker
|
||||
config *config.WebRTC
|
||||
shutdown chan bool
|
||||
configuration *webrtc.Configuration
|
||||
}
|
||||
|
||||
func (m *WebRTCManager) Start() {
|
||||
hid.Display(m.config.Display)
|
||||
|
||||
videoPipeline, err := gst.CreatePipeline(
|
||||
m.config.VideoCodec,
|
||||
fmt.Sprintf("ximagesrc xid=%s show-pointer=true use-damage=false ! video/x-raw,framerate=30/1 ! videoconvert ! queue", m.config.Display),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
|
||||
}
|
||||
|
||||
audioPipeline, err := gst.CreatePipeline(
|
||||
m.config.AudioCodec,
|
||||
fmt.Sprintf("pulsesrc device=%s ! audioconvert", m.config.Device),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
|
||||
}
|
||||
|
||||
m.videoPipeline = videoPipeline
|
||||
m.audioPipeline = audioPipeline
|
||||
|
||||
videoPipeline.Start()
|
||||
audioPipeline.Start()
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
m.logger.Info().Msg("shutdown")
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.shutdown:
|
||||
return
|
||||
case sample := <-videoPipeline.Sample:
|
||||
if err := m.sessions.WriteVideoSample(sample); err != nil {
|
||||
m.logger.Warn().Err(err).Msg("video pipeline failed")
|
||||
}
|
||||
case sample := <-audioPipeline.Sample:
|
||||
if err := m.sessions.WriteAudioSample(sample); err != nil {
|
||||
m.logger.Warn().Err(err).Msg("audio pipeline failed")
|
||||
}
|
||||
case <-m.cleanup.C:
|
||||
hid.Check(time.Second * 10)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
m.sessions.OnHostCleared(func(id string) {
|
||||
hid.Reset()
|
||||
})
|
||||
|
||||
m.sessions.OnCreated(func(id string, session types.Session) {
|
||||
m.logger.Debug().Str("id", id).Msg("session created")
|
||||
})
|
||||
|
||||
m.sessions.OnDestroy(func(id string) {
|
||||
m.logger.Debug().Str("id", id).Msg("session destroyed")
|
||||
})
|
||||
|
||||
// TODO: log resolution, bit rate and codec parameters
|
||||
m.logger.Info().
|
||||
Str("video_display", m.config.Display).
|
||||
Str("video_codec", m.config.VideoCodec).
|
||||
Str("audio_device", m.config.Device).
|
||||
Str("audio_codec", m.config.AudioCodec).
|
||||
Msgf("webrtc streaming")
|
||||
}
|
||||
|
||||
func (m *WebRTCManager) Shutdown() error {
|
||||
m.logger.Info().Msgf("webrtc shutting down")
|
||||
m.videoPipeline.Stop()
|
||||
m.audioPipeline.Stop()
|
||||
m.cleanup.Stop()
|
||||
m.shutdown <- true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *WebRTCManager) CreatePeer(id string, sdp string) (string, types.Peer, error) {
|
||||
// create SessionDescription
|
||||
description := webrtc.SessionDescription{
|
||||
SDP: sdp,
|
||||
Type: webrtc.SDPTypeOffer,
|
||||
}
|
||||
|
||||
// create MediaEngine based off sdp
|
||||
engine := webrtc.MediaEngine{}
|
||||
engine.PopulateFromSDP(description)
|
||||
|
||||
// create API with MediaEngine and SettingEngine
|
||||
api := webrtc.NewAPI(webrtc.WithMediaEngine(engine), webrtc.WithSettingEngine(m.setings))
|
||||
|
||||
// create new peer connection
|
||||
connection, err := api.NewPeerConnection(*m.configuration)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// create video track
|
||||
video, err := m.createVideoTrack(engine)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
videoTransceiver, err := connection.AddTransceiverFromTrack(video, webrtc.RtpTransceiverInit{
|
||||
Direction: webrtc.RTPTransceiverDirectionSendonly,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// create audio track
|
||||
audio, err := m.createAudioTrack(engine)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
audioTransceiver, err := connection.AddTransceiverFromTrack(audio, webrtc.RtpTransceiverInit{
|
||||
Direction: webrtc.RTPTransceiverDirectionSendonly,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// clear the Transceiver bufers
|
||||
go func() {
|
||||
for {
|
||||
if _, err := audioTransceiver.Sender.ReadRTCP(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = videoTransceiver.Sender.ReadRTCP(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// set remote description
|
||||
connection.SetRemoteDescription(description)
|
||||
|
||||
answer, err := connection.CreateAnswer(nil)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if err = connection.SetLocalDescription(answer); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
connection.OnDataChannel(func(d *webrtc.DataChannel) {
|
||||
d.OnMessage(func(msg webrtc.DataChannelMessage) {
|
||||
if err = m.handle(id, msg); err != nil {
|
||||
m.logger.Warn().Err(err).Msg("data handle failed")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
connection.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
|
||||
switch state {
|
||||
case webrtc.PeerConnectionStateDisconnected:
|
||||
case webrtc.PeerConnectionStateFailed:
|
||||
m.logger.Info().Str("id", id).Msg("peer disconnected")
|
||||
m.sessions.Destroy(id)
|
||||
break
|
||||
case webrtc.PeerConnectionStateConnected:
|
||||
m.logger.Info().Str("id", id).Msg("peer connected")
|
||||
break
|
||||
}
|
||||
})
|
||||
|
||||
return answer.SDP, &Peer{
|
||||
id: id,
|
||||
api: api,
|
||||
engine: engine,
|
||||
video: video,
|
||||
audio: audio,
|
||||
connection: connection,
|
||||
}, nil
|
||||
}
|
@ -3,13 +3,13 @@ package websocket
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"n.eko.moe/neko/internal/event"
|
||||
"n.eko.moe/neko/internal/message"
|
||||
"n.eko.moe/neko/internal/session"
|
||||
"n.eko.moe/neko/internal/types"
|
||||
"n.eko.moe/neko/internal/types/event"
|
||||
"n.eko.moe/neko/internal/types/message"
|
||||
)
|
||||
|
||||
func (h *MessageHandler) adminLock(id string, session *session.Session) error {
|
||||
if !session.Admin {
|
||||
func (h *MessageHandler) adminLock(id string, session types.Session) error {
|
||||
if !session.Admin() {
|
||||
h.logger.Debug().Msg("user not admin")
|
||||
return nil
|
||||
}
|
||||
@ -33,8 +33,8 @@ func (h *MessageHandler) adminLock(id string, session *session.Session) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) adminUnlock(id string, session *session.Session) error {
|
||||
if !session.Admin {
|
||||
func (h *MessageHandler) adminUnlock(id string, session types.Session) error {
|
||||
if !session.Admin() {
|
||||
h.logger.Debug().Msg("user not admin")
|
||||
return nil
|
||||
}
|
||||
@ -58,8 +58,8 @@ func (h *MessageHandler) adminUnlock(id string, session *session.Session) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) adminControl(id string, session *session.Session) error {
|
||||
if !session.Admin {
|
||||
func (h *MessageHandler) adminControl(id string, session types.Session) error {
|
||||
if !session.Admin() {
|
||||
h.logger.Debug().Msg("user not admin")
|
||||
return nil
|
||||
}
|
||||
@ -73,7 +73,7 @@ func (h *MessageHandler) adminControl(id string, session *session.Session) error
|
||||
message.AdminTarget{
|
||||
Event: event.ADMIN_CONTROL,
|
||||
ID: id,
|
||||
Target: host.ID,
|
||||
Target: host.ID(),
|
||||
}, nil); err != nil {
|
||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_CONTROL)
|
||||
return err
|
||||
@ -92,8 +92,8 @@ func (h *MessageHandler) adminControl(id string, session *session.Session) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) adminRelease(id string, session *session.Session) error {
|
||||
if !session.Admin {
|
||||
func (h *MessageHandler) adminRelease(id string, session types.Session) error {
|
||||
if !session.Admin() {
|
||||
h.logger.Debug().Msg("user not admin")
|
||||
return nil
|
||||
}
|
||||
@ -107,7 +107,7 @@ func (h *MessageHandler) adminRelease(id string, session *session.Session) error
|
||||
message.AdminTarget{
|
||||
Event: event.ADMIN_RELEASE,
|
||||
ID: id,
|
||||
Target: host.ID,
|
||||
Target: host.ID(),
|
||||
}, nil); err != nil {
|
||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_RELEASE)
|
||||
return err
|
||||
@ -126,8 +126,8 @@ func (h *MessageHandler) adminRelease(id string, session *session.Session) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) adminGive(id string, session *session.Session, payload *message.Admin) error {
|
||||
if !session.Admin {
|
||||
func (h *MessageHandler) adminGive(id string, session types.Session, payload *message.Admin) error {
|
||||
if !session.Admin() {
|
||||
h.logger.Debug().Msg("user not admin")
|
||||
return nil
|
||||
}
|
||||
@ -154,8 +154,8 @@ func (h *MessageHandler) adminGive(id string, session *session.Session, payload
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) adminMute(id string, session *session.Session, payload *message.Admin) error {
|
||||
if !session.Admin {
|
||||
func (h *MessageHandler) adminMute(id string, session types.Session, payload *message.Admin) error {
|
||||
if !session.Admin() {
|
||||
h.logger.Debug().Msg("user not admin")
|
||||
return nil
|
||||
}
|
||||
@ -166,17 +166,17 @@ func (h *MessageHandler) adminMute(id string, session *session.Session, payload
|
||||
return nil
|
||||
}
|
||||
|
||||
if target.Admin {
|
||||
if target.Admin() {
|
||||
h.logger.Debug().Msg("target is an admin, baling")
|
||||
return nil
|
||||
}
|
||||
|
||||
target.Muted = true
|
||||
target.SetMuted(true)
|
||||
|
||||
if err := h.sessions.Brodcast(
|
||||
message.AdminTarget{
|
||||
Event: event.ADMIN_MUTE,
|
||||
Target: target.ID,
|
||||
Target: target.ID(),
|
||||
ID: id,
|
||||
}, nil); err != nil {
|
||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_UNMUTE)
|
||||
@ -186,8 +186,8 @@ func (h *MessageHandler) adminMute(id string, session *session.Session, payload
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) adminUnmute(id string, session *session.Session, payload *message.Admin) error {
|
||||
if !session.Admin {
|
||||
func (h *MessageHandler) adminUnmute(id string, session types.Session, payload *message.Admin) error {
|
||||
if !session.Admin() {
|
||||
h.logger.Debug().Msg("user not admin")
|
||||
return nil
|
||||
}
|
||||
@ -198,12 +198,12 @@ func (h *MessageHandler) adminUnmute(id string, session *session.Session, payloa
|
||||
return nil
|
||||
}
|
||||
|
||||
target.Muted = false
|
||||
target.SetMuted(false)
|
||||
|
||||
if err := h.sessions.Brodcast(
|
||||
message.AdminTarget{
|
||||
Event: event.ADMIN_UNMUTE,
|
||||
Target: target.ID,
|
||||
Target: target.ID(),
|
||||
ID: id,
|
||||
}, nil); err != nil {
|
||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_UNMUTE)
|
||||
@ -213,8 +213,8 @@ func (h *MessageHandler) adminUnmute(id string, session *session.Session, payloa
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) adminKick(id string, session *session.Session, payload *message.Admin) error {
|
||||
if !session.Admin {
|
||||
func (h *MessageHandler) adminKick(id string, session types.Session, payload *message.Admin) error {
|
||||
if !session.Admin() {
|
||||
h.logger.Debug().Msg("user not admin")
|
||||
return nil
|
||||
}
|
||||
@ -225,22 +225,19 @@ func (h *MessageHandler) adminKick(id string, session *session.Session, payload
|
||||
return nil
|
||||
}
|
||||
|
||||
if target.Admin {
|
||||
if target.Admin() {
|
||||
h.logger.Debug().Msg("target is an admin, baling")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := target.Kick(message.Disconnect{
|
||||
Event: event.SYSTEM_DISCONNECT,
|
||||
Message: "You have been kicked",
|
||||
}); err != nil {
|
||||
if err := target.Kick("You have been kicked"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.sessions.Brodcast(
|
||||
message.AdminTarget{
|
||||
Event: event.ADMIN_KICK,
|
||||
Target: target.ID,
|
||||
Target: target.ID(),
|
||||
ID: id,
|
||||
}, []string{payload.ID}); err != nil {
|
||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_KICK)
|
||||
@ -250,8 +247,8 @@ func (h *MessageHandler) adminKick(id string, session *session.Session, payload
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) adminBan(id string, session *session.Session, payload *message.Admin) error {
|
||||
if !session.Admin {
|
||||
func (h *MessageHandler) adminBan(id string, session types.Session, payload *message.Admin) error {
|
||||
if !session.Admin() {
|
||||
h.logger.Debug().Msg("user not admin")
|
||||
return nil
|
||||
}
|
||||
@ -262,12 +259,12 @@ func (h *MessageHandler) adminBan(id string, session *session.Session, payload *
|
||||
return nil
|
||||
}
|
||||
|
||||
if target.Admin {
|
||||
if target.Admin() {
|
||||
h.logger.Debug().Msg("target is an admin, baling")
|
||||
return nil
|
||||
}
|
||||
|
||||
remote := target.RemoteAddr()
|
||||
remote := target.Address()
|
||||
if remote == nil {
|
||||
h.logger.Debug().Msg("no remote address, baling")
|
||||
return nil
|
||||
@ -283,17 +280,14 @@ func (h *MessageHandler) adminBan(id string, session *session.Session, payload *
|
||||
|
||||
h.banned[address[0]] = true
|
||||
|
||||
if err := target.Kick(message.Disconnect{
|
||||
Event: event.SYSTEM_DISCONNECT,
|
||||
Message: "You have been banned",
|
||||
}); err != nil {
|
||||
if err := target.Kick("You have been banned"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.sessions.Brodcast(
|
||||
message.AdminTarget{
|
||||
Event: event.ADMIN_BAN,
|
||||
Target: target.ID,
|
||||
Target: target.ID(),
|
||||
ID: id,
|
||||
}, []string{payload.ID}); err != nil {
|
||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_BAN)
|
||||
|
@ -1,13 +1,13 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"n.eko.moe/neko/internal/event"
|
||||
"n.eko.moe/neko/internal/message"
|
||||
"n.eko.moe/neko/internal/session"
|
||||
"n.eko.moe/neko/internal/types"
|
||||
"n.eko.moe/neko/internal/types/event"
|
||||
"n.eko.moe/neko/internal/types/message"
|
||||
)
|
||||
|
||||
func (h *MessageHandler) chat(id string, session *session.Session, payload *message.ChatRecieve) error {
|
||||
if session.Muted {
|
||||
func (h *MessageHandler) chat(id string, session types.Session, payload *message.ChatRecieve) error {
|
||||
if session.Muted() {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -23,8 +23,8 @@ func (h *MessageHandler) chat(id string, session *session.Session, payload *mess
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) chatEmote(id string, session *session.Session, payload *message.EmoteRecieve) error {
|
||||
if session.Muted {
|
||||
func (h *MessageHandler) chatEmote(id string, session types.Session, payload *message.EmoteRecieve) error {
|
||||
if session.Muted() {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -1,12 +1,12 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"n.eko.moe/neko/internal/event"
|
||||
"n.eko.moe/neko/internal/message"
|
||||
"n.eko.moe/neko/internal/session"
|
||||
"n.eko.moe/neko/internal/types"
|
||||
"n.eko.moe/neko/internal/types/event"
|
||||
"n.eko.moe/neko/internal/types/message"
|
||||
)
|
||||
|
||||
func (h *MessageHandler) controlRelease(id string, session *session.Session) error {
|
||||
func (h *MessageHandler) controlRelease(id string, session types.Session) error {
|
||||
|
||||
// check if session is host
|
||||
if !h.sessions.IsHost(id) {
|
||||
@ -31,7 +31,7 @@ func (h *MessageHandler) controlRelease(id string, session *session.Session) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) controlRequest(id string, session *session.Session) error {
|
||||
func (h *MessageHandler) controlRequest(id string, session types.Session) error {
|
||||
// check for host
|
||||
if !h.sessions.HasHost() {
|
||||
// set host
|
||||
@ -57,7 +57,7 @@ func (h *MessageHandler) controlRequest(id string, session *session.Session) err
|
||||
// tell session there is a host
|
||||
if err := session.Send(message.Control{
|
||||
Event: event.CONTROL_REQUEST,
|
||||
ID: host.ID,
|
||||
ID: host.ID(),
|
||||
}); err != nil {
|
||||
h.logger.Warn().Err(err).Str("id", id).Msgf("sending event %s has failed", event.CONTROL_REQUEST)
|
||||
return err
|
||||
@ -68,7 +68,7 @@ func (h *MessageHandler) controlRequest(id string, session *session.Session) err
|
||||
Event: event.CONTROL_REQUESTING,
|
||||
ID: id,
|
||||
}); err != nil {
|
||||
h.logger.Warn().Err(err).Str("id", host.ID).Msgf("sending event %s has failed", event.CONTROL_REQUESTING)
|
||||
h.logger.Warn().Err(err).Str("id", host.ID()).Msgf("sending event %s has failed", event.CONTROL_REQUESTING)
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -76,7 +76,7 @@ func (h *MessageHandler) controlRequest(id string, session *session.Session) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) controlGive(id string, session *session.Session, payload *message.Control) error {
|
||||
func (h *MessageHandler) controlGive(id string, session types.Session, payload *message.Control) error {
|
||||
// check if session is host
|
||||
if !h.sessions.IsHost(id) {
|
||||
h.logger.Debug().Str("id", id).Msg("is not the host")
|
||||
|
@ -1,235 +1,142 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"n.eko.moe/neko/internal/config"
|
||||
"n.eko.moe/neko/internal/event"
|
||||
"n.eko.moe/neko/internal/message"
|
||||
"n.eko.moe/neko/internal/session"
|
||||
"n.eko.moe/neko/internal/types"
|
||||
"n.eko.moe/neko/internal/types/event"
|
||||
"n.eko.moe/neko/internal/types/message"
|
||||
"n.eko.moe/neko/internal/utils"
|
||||
"n.eko.moe/neko/internal/webrtc"
|
||||
)
|
||||
|
||||
func New(sessions *session.SessionManager, webrtc *webrtc.WebRTCManager, conf *config.WebSocket) *WebSocketHandler {
|
||||
logger := log.With().Str("module", "websocket").Logger()
|
||||
|
||||
return &WebSocketHandler{
|
||||
logger: logger,
|
||||
conf: conf,
|
||||
sessions: sessions,
|
||||
upgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
},
|
||||
handler: &MessageHandler{
|
||||
logger: logger.With().Str("subsystem", "handler").Logger(),
|
||||
sessions: sessions,
|
||||
webrtc: webrtc,
|
||||
banned: make(map[string]bool),
|
||||
locked: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Send pings to peer with this period. Must be less than pongWait.
|
||||
const pingPeriod = 60 * time.Second
|
||||
|
||||
type WebSocketHandler struct {
|
||||
type MessageHandler struct {
|
||||
logger zerolog.Logger
|
||||
upgrader websocket.Upgrader
|
||||
handler *MessageHandler
|
||||
conf *config.WebSocket
|
||||
sessions *session.SessionManager
|
||||
shutdown chan bool
|
||||
sessions types.SessionManager
|
||||
webrtc types.WebRTCManager
|
||||
banned map[string]bool
|
||||
locked bool
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) Start() error {
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
ws.logger.Info().Msg("shutdown")
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ws.shutdown:
|
||||
return
|
||||
}
|
||||
func (h *MessageHandler) Connected(id string, socket *WebSocket) (bool, string, error) {
|
||||
address := socket.Address()
|
||||
if address == nil {
|
||||
h.logger.Debug().Msg("no remote address, baling")
|
||||
} else {
|
||||
ok, banned := h.banned[*address]
|
||||
if ok && banned {
|
||||
h.logger.Debug().Str("address", *address).Msg("banned")
|
||||
return false, "This IP has been banned", nil
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
ws.sessions.OnCreated(func(id string, session *session.Session) {
|
||||
if err := ws.handler.SessionCreated(id, session); err != nil {
|
||||
ws.logger.Warn().Str("id", id).Err(err).Msg("session created with and error")
|
||||
} else {
|
||||
ws.logger.Debug().Str("id", id).Msg("session created")
|
||||
}
|
||||
})
|
||||
if h.locked {
|
||||
h.logger.Debug().Msg("server locked")
|
||||
return false, "Server is currently locked", nil
|
||||
}
|
||||
|
||||
ws.sessions.OnConnected(func(id string, session *session.Session) {
|
||||
if err := ws.handler.SessionConnected(id, session); err != nil {
|
||||
ws.logger.Warn().Str("id", id).Err(err).Msg("session connected with and error")
|
||||
} else {
|
||||
ws.logger.Debug().Str("id", id).Msg("session connected")
|
||||
}
|
||||
})
|
||||
|
||||
ws.sessions.OnDestroy(func(id string) {
|
||||
if err := ws.handler.SessionDestroyed(id); err != nil {
|
||||
ws.logger.Warn().Str("id", id).Err(err).Msg("session destroyed with and error")
|
||||
} else {
|
||||
ws.logger.Debug().Str("id", id).Msg("session destroyed")
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
return true, "", nil
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) Shutdown() error {
|
||||
ws.shutdown <- true
|
||||
return nil
|
||||
func (h *MessageHandler) Disconnected(id string) error {
|
||||
return h.sessions.Destroy(id)
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) error {
|
||||
ws.logger.Debug().Msg("attempting to upgrade connection")
|
||||
|
||||
socket, err := ws.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
ws.logger.Error().Err(err).Msg("failed to upgrade connection")
|
||||
return err
|
||||
}
|
||||
|
||||
id, admin, err := ws.authenticate(r)
|
||||
if err != nil {
|
||||
ws.logger.Warn().Err(err).Msg("authenticatetion failed")
|
||||
|
||||
if err = socket.WriteJSON(message.Disconnect{
|
||||
Event: event.SYSTEM_DISCONNECT,
|
||||
Message: "invalid password",
|
||||
}); err != nil {
|
||||
ws.logger.Error().Err(err).Msg("failed to send disconnect")
|
||||
}
|
||||
|
||||
if err = socket.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
ok, reason, err := ws.handler.SocketConnected(id, socket)
|
||||
if err != nil {
|
||||
ws.logger.Error().Err(err).Msg("connection failed")
|
||||
func (h *MessageHandler) Message(id string, raw []byte) error {
|
||||
header := message.Message{}
|
||||
if err := json.Unmarshal(raw, &header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session, ok := h.sessions.Get(id)
|
||||
if !ok {
|
||||
if err = socket.WriteJSON(message.Disconnect{
|
||||
Event: event.SYSTEM_DISCONNECT,
|
||||
Message: reason,
|
||||
}); err != nil {
|
||||
ws.logger.Error().Err(err).Msg("failed to send disconnect")
|
||||
}
|
||||
|
||||
if err = socket.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
errors.Errorf("unknown session id %s", id)
|
||||
}
|
||||
|
||||
ws.sessions.New(id, admin, socket)
|
||||
switch header.Event {
|
||||
// Signal Events
|
||||
case event.SIGNAL_PROVIDE:
|
||||
payload := &message.Signal{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.createPeer(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
// Identity Events
|
||||
case event.IDENTITY_DETAILS:
|
||||
payload := &message.IdentityDetails{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.identityDetails(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
|
||||
ws.logger.
|
||||
Debug().
|
||||
Str("session", id).
|
||||
Str("address", socket.RemoteAddr().String()).
|
||||
Msg("new connection created")
|
||||
// Control Events
|
||||
case event.CONTROL_RELEASE:
|
||||
return errors.Wrapf(h.controlRelease(id, session), "%s failed", header.Event)
|
||||
case event.CONTROL_REQUEST:
|
||||
return errors.Wrapf(h.controlRequest(id, session), "%s failed", header.Event)
|
||||
case event.CONTROL_GIVE:
|
||||
payload := &message.Control{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.controlGive(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
|
||||
defer func() {
|
||||
ws.logger.
|
||||
Debug().
|
||||
Str("session", id).
|
||||
Str("address", socket.RemoteAddr().String()).
|
||||
Msg("session ended")
|
||||
}()
|
||||
// Chat Events
|
||||
case event.CHAT_MESSAGE:
|
||||
payload := &message.ChatRecieve{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.chat(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.CHAT_EMOTE:
|
||||
payload := &message.EmoteRecieve{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.chatEmote(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
|
||||
ws.handle(socket, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) authenticate(r *http.Request) (string, bool, error) {
|
||||
id, err := utils.NewUID(32)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
passwords, ok := r.URL.Query()["password"]
|
||||
if !ok || len(passwords[0]) < 1 {
|
||||
return "", false, fmt.Errorf("no password provided")
|
||||
}
|
||||
|
||||
if passwords[0] == ws.conf.AdminPassword {
|
||||
return id, true, nil
|
||||
}
|
||||
|
||||
if passwords[0] == ws.conf.Password {
|
||||
return id, false, nil
|
||||
}
|
||||
|
||||
return "", false, fmt.Errorf("invalid password: %s", passwords[0])
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) handle(socket *websocket.Conn, id string) {
|
||||
bytes := make(chan []byte)
|
||||
cancel := make(chan struct{})
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
ws.logger.Debug().Str("address", socket.RemoteAddr().String()).Msg("handle socket ending")
|
||||
ws.handler.SocketDisconnected(id)
|
||||
}()
|
||||
|
||||
for {
|
||||
_, raw, err := socket.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
ws.logger.Warn().Err(err).Msg("read message error")
|
||||
} else {
|
||||
ws.logger.Debug().Err(err).Msg("read message error")
|
||||
}
|
||||
close(cancel)
|
||||
break
|
||||
}
|
||||
bytes <- raw
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case raw := <-bytes:
|
||||
ws.logger.Debug().
|
||||
Str("session", id).
|
||||
Str("raw", string(raw)).
|
||||
Msg("recieved message from client")
|
||||
if err := ws.handler.Message(id, raw); err != nil {
|
||||
ws.logger.Error().Err(err).Msg("message handler has failed")
|
||||
}
|
||||
case <-cancel:
|
||||
return
|
||||
case _ = <-ticker.C:
|
||||
if err := socket.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
// Admin Events
|
||||
case event.ADMIN_LOCK:
|
||||
return errors.Wrapf(h.adminLock(id, session), "%s failed", header.Event)
|
||||
case event.ADMIN_UNLOCK:
|
||||
return errors.Wrapf(h.adminUnlock(id, session), "%s failed", header.Event)
|
||||
case event.ADMIN_CONTROL:
|
||||
return errors.Wrapf(h.adminControl(id, session), "%s failed", header.Event)
|
||||
case event.ADMIN_RELEASE:
|
||||
return errors.Wrapf(h.adminRelease(id, session), "%s failed", header.Event)
|
||||
case event.ADMIN_GIVE:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminGive(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.ADMIN_BAN:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminBan(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.ADMIN_KICK:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminKick(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.ADMIN_MUTE:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminMute(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.ADMIN_UNMUTE:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminUnmute(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
default:
|
||||
return errors.Errorf("unknown message event %s", header.Event)
|
||||
}
|
||||
}
|
||||
|
@ -1,13 +1,34 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"n.eko.moe/neko/internal/message"
|
||||
"n.eko.moe/neko/internal/session"
|
||||
"n.eko.moe/neko/internal/types"
|
||||
"n.eko.moe/neko/internal/types/event"
|
||||
"n.eko.moe/neko/internal/types/message"
|
||||
)
|
||||
|
||||
func (h *MessageHandler) identityDetails(id string, session *session.Session, payload *message.IdentityDetails) error {
|
||||
if _, err := h.sessions.SetName(id, payload.Username); err != nil {
|
||||
func (h *MessageHandler) identityDetails(id string, session types.Session, payload *message.IdentityDetails) error {
|
||||
if err := session.SetName(payload.Username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) createPeer(id string, session types.Session, payload *message.Signal) error {
|
||||
sdp, peer, err := h.webrtc.CreatePeer(id, payload.SDP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := session.SetPeer(peer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := session.Send(message.Signal{
|
||||
Event: event.SIGNAL_ANSWER,
|
||||
SDP: sdp,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -1,149 +0,0 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"n.eko.moe/neko/internal/event"
|
||||
"n.eko.moe/neko/internal/message"
|
||||
"n.eko.moe/neko/internal/session"
|
||||
"n.eko.moe/neko/internal/utils"
|
||||
"n.eko.moe/neko/internal/webrtc"
|
||||
)
|
||||
|
||||
type MessageHandler struct {
|
||||
logger zerolog.Logger
|
||||
sessions *session.SessionManager
|
||||
webrtc *webrtc.WebRTCManager
|
||||
banned map[string]bool
|
||||
locked bool
|
||||
}
|
||||
|
||||
func (h *MessageHandler) SocketConnected(id string, socket *websocket.Conn) (bool, string, error) {
|
||||
remote := socket.RemoteAddr().String()
|
||||
if remote != "" {
|
||||
address := strings.SplitN(remote, ":", -1)
|
||||
if len(address[0]) < 1 {
|
||||
h.logger.Debug().Str("address", remote).Msg("no remote address, baling")
|
||||
} else {
|
||||
|
||||
ok, banned := h.banned[address[0]]
|
||||
if ok && banned {
|
||||
h.logger.Debug().Str("address", remote).Msg("banned")
|
||||
return false, "This IP has been banned", nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if h.locked {
|
||||
h.logger.Debug().Str("address", remote).Msg("locked")
|
||||
return false, "Server is currently locked", nil
|
||||
}
|
||||
return true, "", nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) SocketDisconnected(id string) error {
|
||||
return h.sessions.Destroy(id)
|
||||
}
|
||||
|
||||
func (h *MessageHandler) Message(id string, raw []byte) error {
|
||||
header := message.Message{}
|
||||
if err := json.Unmarshal(raw, &header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session, ok := h.sessions.Get(id)
|
||||
if !ok {
|
||||
errors.Errorf("unknown session id %s", id)
|
||||
}
|
||||
|
||||
switch header.Event {
|
||||
// Signal Events
|
||||
case event.SIGNAL_PROVIDE:
|
||||
payload := message.Signal{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(&payload, raw, func() error {
|
||||
return h.webrtc.CreatePeer(id, payload.SDP)
|
||||
}), "%s failed", header.Event)
|
||||
|
||||
// Identity Events
|
||||
case event.IDENTITY_DETAILS:
|
||||
payload := &message.IdentityDetails{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.identityDetails(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
|
||||
// Control Events
|
||||
case event.CONTROL_RELEASE:
|
||||
return errors.Wrapf(h.controlRelease(id, session), "%s failed", header.Event)
|
||||
case event.CONTROL_REQUEST:
|
||||
return errors.Wrapf(h.controlRequest(id, session), "%s failed", header.Event)
|
||||
case event.CONTROL_GIVE:
|
||||
payload := &message.Control{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.controlGive(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
|
||||
// Chat Events
|
||||
case event.CHAT_MESSAGE:
|
||||
payload := &message.ChatRecieve{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.chat(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.CHAT_EMOTE:
|
||||
payload := &message.EmoteRecieve{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.chatEmote(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
|
||||
// Admin Events
|
||||
case event.ADMIN_LOCK:
|
||||
return errors.Wrapf(h.adminLock(id, session), "%s failed", header.Event)
|
||||
case event.ADMIN_UNLOCK:
|
||||
return errors.Wrapf(h.adminUnlock(id, session), "%s failed", header.Event)
|
||||
case event.ADMIN_CONTROL:
|
||||
return errors.Wrapf(h.adminControl(id, session), "%s failed", header.Event)
|
||||
case event.ADMIN_RELEASE:
|
||||
return errors.Wrapf(h.adminRelease(id, session), "%s failed", header.Event)
|
||||
case event.ADMIN_GIVE:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminGive(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.ADMIN_BAN:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminBan(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.ADMIN_KICK:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminKick(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.ADMIN_MUTE:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminMute(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.ADMIN_UNMUTE:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminUnmute(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
default:
|
||||
return errors.Errorf("unknown message event %s", header.Event)
|
||||
}
|
||||
}
|
@ -1,12 +1,12 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"n.eko.moe/neko/internal/event"
|
||||
"n.eko.moe/neko/internal/message"
|
||||
"n.eko.moe/neko/internal/session"
|
||||
"n.eko.moe/neko/internal/types"
|
||||
"n.eko.moe/neko/internal/types/event"
|
||||
"n.eko.moe/neko/internal/types/message"
|
||||
)
|
||||
|
||||
func (h *MessageHandler) SessionCreated(id string, session *session.Session) error {
|
||||
func (h *MessageHandler) SessionCreated(id string, session types.Session) error {
|
||||
if err := session.Send(message.Identity{
|
||||
Event: event.IDENTITY_PROVIDE,
|
||||
ID: id,
|
||||
@ -17,11 +17,11 @@ func (h *MessageHandler) SessionCreated(id string, session *session.Session) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) SessionConnected(id string, session *session.Session) error {
|
||||
func (h *MessageHandler) SessionConnected(id string, session types.Session) error {
|
||||
// send list of members to session
|
||||
if err := session.Send(message.MembersList{
|
||||
Event: event.MEMBER_LIST,
|
||||
Memebers: h.sessions.GetConnected(),
|
||||
Memebers: h.sessions.Members(),
|
||||
}); err != nil {
|
||||
h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.MEMBER_LIST)
|
||||
return err
|
||||
@ -32,7 +32,7 @@ func (h *MessageHandler) SessionConnected(id string, session *session.Session) e
|
||||
if ok {
|
||||
if err := session.Send(message.Control{
|
||||
Event: event.CONTROL_LOCKED,
|
||||
ID: host.ID,
|
||||
ID: host.ID(),
|
||||
}); err != nil {
|
||||
h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.CONTROL_LOCKED)
|
||||
return err
|
||||
@ -42,8 +42,8 @@ func (h *MessageHandler) SessionConnected(id string, session *session.Session) e
|
||||
// let everyone know there is a new session
|
||||
if err := h.sessions.Brodcast(
|
||||
message.Member{
|
||||
Event: event.MEMBER_CONNECTED,
|
||||
Session: session,
|
||||
Event: event.MEMBER_CONNECTED,
|
||||
Member: session.Member(),
|
||||
}, nil); err != nil {
|
||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.CONTROL_RELEASE)
|
||||
return err
|
||||
|
37
server/internal/websocket/socket.go
Normal file
37
server/internal/websocket/socket.go
Normal file
@ -0,0 +1,37 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type WebSocket struct {
|
||||
id string
|
||||
connection *websocket.Conn
|
||||
}
|
||||
|
||||
func (socket *WebSocket) Address() *string {
|
||||
remote := socket.connection.RemoteAddr()
|
||||
address := strings.SplitN(remote.String(), ":", -1)
|
||||
if len(address[0]) < 1 {
|
||||
return nil
|
||||
}
|
||||
return &address[0]
|
||||
}
|
||||
|
||||
func (socket *WebSocket) Send(v interface{}) error {
|
||||
if socket.connection == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return socket.connection.WriteJSON(v)
|
||||
}
|
||||
|
||||
func (socket *WebSocket) Destroy() error {
|
||||
if socket.connection == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return socket.connection.Close()
|
||||
}
|
239
server/internal/websocket/websocket.go
Normal file
239
server/internal/websocket/websocket.go
Normal file
@ -0,0 +1,239 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"n.eko.moe/neko/internal/config"
|
||||
"n.eko.moe/neko/internal/types"
|
||||
"n.eko.moe/neko/internal/types/event"
|
||||
"n.eko.moe/neko/internal/types/message"
|
||||
"n.eko.moe/neko/internal/utils"
|
||||
)
|
||||
|
||||
func New(sessions types.SessionManager, webrtc types.WebRTCManager, conf *config.WebSocket) *WebSocketHandler {
|
||||
logger := log.With().Str("module", "websocket").Logger()
|
||||
|
||||
return &WebSocketHandler{
|
||||
logger: logger,
|
||||
conf: conf,
|
||||
sessions: sessions,
|
||||
upgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
},
|
||||
handler: &MessageHandler{
|
||||
logger: logger.With().Str("subsystem", "handler").Logger(),
|
||||
sessions: sessions,
|
||||
webrtc: webrtc,
|
||||
banned: make(map[string]bool),
|
||||
locked: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Send pings to peer with this period. Must be less than pongWait.
|
||||
const pingPeriod = 60 * time.Second
|
||||
|
||||
type WebSocketHandler struct {
|
||||
logger zerolog.Logger
|
||||
upgrader websocket.Upgrader
|
||||
sessions types.SessionManager
|
||||
conf *config.WebSocket
|
||||
handler *MessageHandler
|
||||
shutdown chan bool
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) Start() error {
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
ws.logger.Info().Msg("shutdown")
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ws.shutdown:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
ws.sessions.OnCreated(func(id string, session types.Session) {
|
||||
if err := ws.handler.SessionCreated(id, session); err != nil {
|
||||
ws.logger.Warn().Str("id", id).Err(err).Msg("session created with and error")
|
||||
} else {
|
||||
ws.logger.Debug().Str("id", id).Msg("session created")
|
||||
}
|
||||
})
|
||||
|
||||
ws.sessions.OnConnected(func(id string, session types.Session) {
|
||||
if err := ws.handler.SessionConnected(id, session); err != nil {
|
||||
ws.logger.Warn().Str("id", id).Err(err).Msg("session connected with and error")
|
||||
} else {
|
||||
ws.logger.Debug().Str("id", id).Msg("session connected")
|
||||
}
|
||||
})
|
||||
|
||||
ws.sessions.OnDestroy(func(id string) {
|
||||
if err := ws.handler.SessionDestroyed(id); err != nil {
|
||||
ws.logger.Warn().Str("id", id).Err(err).Msg("session destroyed with and error")
|
||||
} else {
|
||||
ws.logger.Debug().Str("id", id).Msg("session destroyed")
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) Shutdown() error {
|
||||
ws.shutdown <- true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) error {
|
||||
ws.logger.Debug().Msg("attempting to upgrade connection")
|
||||
|
||||
connection, err := ws.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
ws.logger.Error().Err(err).Msg("failed to upgrade connection")
|
||||
return err
|
||||
}
|
||||
|
||||
id, admin, err := ws.authenticate(r)
|
||||
if err != nil {
|
||||
ws.logger.Warn().Err(err).Msg("authenticatetion failed")
|
||||
|
||||
if err = connection.WriteJSON(message.Disconnect{
|
||||
Event: event.SYSTEM_DISCONNECT,
|
||||
Message: "invalid password",
|
||||
}); err != nil {
|
||||
ws.logger.Error().Err(err).Msg("failed to send disconnect")
|
||||
}
|
||||
|
||||
if err = connection.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
socket := &WebSocket{
|
||||
id: id,
|
||||
connection: connection,
|
||||
}
|
||||
|
||||
ok, reason, err := ws.handler.Connected(id, socket)
|
||||
if err != nil {
|
||||
ws.logger.Error().Err(err).Msg("connection failed")
|
||||
return err
|
||||
}
|
||||
|
||||
if !ok {
|
||||
if err = connection.WriteJSON(message.Disconnect{
|
||||
Event: event.SYSTEM_DISCONNECT,
|
||||
Message: reason,
|
||||
}); err != nil {
|
||||
ws.logger.Error().Err(err).Msg("failed to send disconnect")
|
||||
}
|
||||
|
||||
if err = connection.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
ws.sessions.New(id, admin, socket)
|
||||
|
||||
ws.logger.
|
||||
Debug().
|
||||
Str("session", id).
|
||||
Str("address", connection.RemoteAddr().String()).
|
||||
Msg("new connection created")
|
||||
|
||||
defer func() {
|
||||
ws.logger.
|
||||
Debug().
|
||||
Str("session", id).
|
||||
Str("address", connection.RemoteAddr().String()).
|
||||
Msg("session ended")
|
||||
}()
|
||||
|
||||
ws.handle(connection, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) authenticate(r *http.Request) (string, bool, error) {
|
||||
id, err := utils.NewUID(32)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
passwords, ok := r.URL.Query()["password"]
|
||||
if !ok || len(passwords[0]) < 1 {
|
||||
return "", false, fmt.Errorf("no password provided")
|
||||
}
|
||||
|
||||
if passwords[0] == ws.conf.AdminPassword {
|
||||
return id, true, nil
|
||||
}
|
||||
|
||||
if passwords[0] == ws.conf.Password {
|
||||
return id, false, nil
|
||||
}
|
||||
|
||||
return "", false, fmt.Errorf("invalid password: %s", passwords[0])
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) handle(connection *websocket.Conn, id string) {
|
||||
bytes := make(chan []byte)
|
||||
cancel := make(chan struct{})
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
ws.logger.Debug().Str("address", connection.RemoteAddr().String()).Msg("handle socket ending")
|
||||
ws.handler.Disconnected(id)
|
||||
}()
|
||||
|
||||
for {
|
||||
_, raw, err := connection.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
ws.logger.Warn().Err(err).Msg("read message error")
|
||||
} else {
|
||||
ws.logger.Debug().Err(err).Msg("read message error")
|
||||
}
|
||||
close(cancel)
|
||||
break
|
||||
}
|
||||
bytes <- raw
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case raw := <-bytes:
|
||||
ws.logger.Debug().
|
||||
Str("session", id).
|
||||
Str("raw", string(raw)).
|
||||
Msg("recieved message from client")
|
||||
if err := ws.handler.Message(id, raw); err != nil {
|
||||
ws.logger.Error().Err(err).Msg("message handler has failed")
|
||||
}
|
||||
case <-cancel:
|
||||
return
|
||||
case _ = <-ticker.C:
|
||||
if err := connection.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user