mirror of
https://github.com/m1k1o/neko.git
synced 2024-07-24 14:40:50 +12:00
fixes #14
This commit is contained in:
parent
a0866a4ab9
commit
e3a73aa264
@ -66,3 +66,4 @@ NEKO_CERT= // (SSL)Cert
|
|||||||
|
|
||||||
### Non Goals
|
### Non Goals
|
||||||
* Turning n.eko into a service that serves multiple rooms and browsers/desktops.
|
* Turning n.eko into a service that serves multiple rooms and browsers/desktops.
|
||||||
|
* Voice chat, use [Discord](https://discordapp.com/))
|
@ -34,7 +34,7 @@ export abstract class BaseClient extends EventEmitter<BaseEvents> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
get peerConnected() {
|
get peerConnected() {
|
||||||
return typeof this._peer !== 'undefined' && this._state === 'connected'
|
return typeof this._peer !== 'undefined' && ['connected', 'checking', 'completed'].includes(this._state)
|
||||||
}
|
}
|
||||||
|
|
||||||
get connected() {
|
get connected() {
|
||||||
@ -60,7 +60,7 @@ export abstract class BaseClient extends EventEmitter<BaseEvents> {
|
|||||||
this._ws.onmessage = this.onMessage.bind(this)
|
this._ws.onmessage = this.onMessage.bind(this)
|
||||||
this._ws.onerror = event => this.onError.bind(this)
|
this._ws.onerror = event => this.onError.bind(this)
|
||||||
this._ws.onclose = event => this.onDisconnected.bind(this, new Error('websocket closed'))
|
this._ws.onclose = event => this.onDisconnected.bind(this, new Error('websocket closed'))
|
||||||
this._timeout = setTimeout(this.onTimeout.bind(this), 5000)
|
this._timeout = setTimeout(this.onTimeout.bind(this), 15000)
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
this.onDisconnected(err)
|
this.onDisconnected(err)
|
||||||
}
|
}
|
||||||
|
@ -127,6 +127,7 @@ github.com/pion/transport v0.8.10 h1:lTiobMEw2PG6BH/mgIVqTV2mBp/mPT+IJLaN8ZxgdHk
|
|||||||
github.com/pion/transport v0.8.10/go.mod h1:tBmha/UCjpum5hqTWhfAEs3CO4/tHSg0MYRhSzR+CZ8=
|
github.com/pion/transport v0.8.10/go.mod h1:tBmha/UCjpum5hqTWhfAEs3CO4/tHSg0MYRhSzR+CZ8=
|
||||||
github.com/pion/turn v1.4.0 h1:7NUMRehQz4fIo53Qv9ui1kJ0Kr1CA82I81RHKHCeM80=
|
github.com/pion/turn v1.4.0 h1:7NUMRehQz4fIo53Qv9ui1kJ0Kr1CA82I81RHKHCeM80=
|
||||||
github.com/pion/turn v1.4.0/go.mod h1:aDSi6hWX/hd1+gKia9cExZOR0MU95O7zX9p3Gw/P2aU=
|
github.com/pion/turn v1.4.0/go.mod h1:aDSi6hWX/hd1+gKia9cExZOR0MU95O7zX9p3Gw/P2aU=
|
||||||
|
github.com/pion/webrtc v1.2.0 h1:3LGGPQEMacwG2hcDfhdvwQPz315gvjZXOfY4vaF4+I4=
|
||||||
github.com/pion/webrtc/v2 v2.1.18 h1:g0VN0xfEUSlVNfQmlCD6yOeXy/tMaktESBmHMnBS3bk=
|
github.com/pion/webrtc/v2 v2.1.18 h1:g0VN0xfEUSlVNfQmlCD6yOeXy/tMaktESBmHMnBS3bk=
|
||||||
github.com/pion/webrtc/v2 v2.1.18/go.mod h1:m0rKlYgLRZWyhmcMWegpF6xtK1ASxmOg8DAR74ttzQY=
|
github.com/pion/webrtc/v2 v2.1.18/go.mod h1:m0rKlYgLRZWyhmcMWegpF6xtK1ASxmOg8DAR74ttzQY=
|
||||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"github.com/pion/webrtc/v2"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
)
|
)
|
||||||
@ -22,13 +21,8 @@ func (WebRTC) Init(cmd *cobra.Command) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.PersistentFlags().String("ac", "opus", "Audio codec to use for streaming")
|
cmd.PersistentFlags().String("aduio", "", "Audio codec parameters to use for streaming")
|
||||||
if err := viper.BindPFlag("acodec", cmd.PersistentFlags().Lookup("ac")); err != nil {
|
if err := viper.BindPFlag("aparams", cmd.PersistentFlags().Lookup("aduio")); err != nil {
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd.PersistentFlags().String("ap", "", "Audio codec parameters to use for streaming")
|
|
||||||
if err := viper.BindPFlag("aparams", cmd.PersistentFlags().Lookup("ap")); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -37,13 +31,45 @@ func (WebRTC) Init(cmd *cobra.Command) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.PersistentFlags().String("vc", "vp8", "Video codec to use for streaming")
|
cmd.PersistentFlags().String("video", "", "Video codec parameters to use for streaming")
|
||||||
if err := viper.BindPFlag("vcodec", cmd.PersistentFlags().Lookup("vc")); err != nil {
|
if err := viper.BindPFlag("vparams", cmd.PersistentFlags().Lookup("video")); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.PersistentFlags().String("vp", "", "Video codec parameters to use for streaming")
|
// video codecs
|
||||||
if err := viper.BindPFlag("vparams", cmd.PersistentFlags().Lookup("vp")); err != nil {
|
cmd.PersistentFlags().Bool("vp8", false, "Use VP8 codec")
|
||||||
|
if err := viper.BindPFlag("vp8", cmd.PersistentFlags().Lookup("vp8")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.PersistentFlags().Bool("vp9", false, "Use VP9 codec")
|
||||||
|
if err := viper.BindPFlag("vp9", cmd.PersistentFlags().Lookup("vp9")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.PersistentFlags().Bool("h264", false, "Use H264 codec")
|
||||||
|
if err := viper.BindPFlag("h264", cmd.PersistentFlags().Lookup("h264")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// audio codecs
|
||||||
|
cmd.PersistentFlags().Bool("opus", false, "Use Opus codec")
|
||||||
|
if err := viper.BindPFlag("opus", cmd.PersistentFlags().Lookup("opus")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.PersistentFlags().Bool("g722", false, "Use G722 codec")
|
||||||
|
if err := viper.BindPFlag("g722", cmd.PersistentFlags().Lookup("g722")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.PersistentFlags().Bool("pcmu", false, "Use PCMU codec")
|
||||||
|
if err := viper.BindPFlag("pcmu", cmd.PersistentFlags().Lookup("pcmu")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.PersistentFlags().Bool("pcma", false, "Use PCMA codec")
|
||||||
|
if err := viper.BindPFlag("pcmu", cmd.PersistentFlags().Lookup("pcmu")); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -51,10 +77,30 @@ func (WebRTC) Init(cmd *cobra.Command) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *WebRTC) Set() {
|
func (s *WebRTC) Set() {
|
||||||
s.Device = strings.ToLower(viper.GetString("device"))
|
videoCodec := webrtc.VP8
|
||||||
s.AudioCodec = strings.ToLower(viper.GetString("acodec"))
|
if viper.GetBool("vp8") {
|
||||||
s.AudioParams = strings.ToLower(viper.GetString("aparams"))
|
videoCodec = webrtc.VP8
|
||||||
s.Display = strings.ToLower(viper.GetString("display"))
|
} else if viper.GetBool("vp9") {
|
||||||
s.VideoCodec = strings.ToLower(viper.GetString("vcodec"))
|
videoCodec = webrtc.VP9
|
||||||
s.VideoParams = strings.ToLower(viper.GetString("vparams"))
|
} else if viper.GetBool("h264") {
|
||||||
|
videoCodec = webrtc.H264
|
||||||
|
}
|
||||||
|
|
||||||
|
audioCodec := webrtc.VP8
|
||||||
|
if viper.GetBool("opus") {
|
||||||
|
audioCodec = webrtc.Opus
|
||||||
|
} else if viper.GetBool("g722") {
|
||||||
|
audioCodec = webrtc.G722
|
||||||
|
} else if viper.GetBool("pcmu") {
|
||||||
|
audioCodec = webrtc.PCMU
|
||||||
|
} else if viper.GetBool("pcma") {
|
||||||
|
audioCodec = webrtc.PCMA
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Device = viper.GetString("device")
|
||||||
|
s.AudioCodec = audioCodec
|
||||||
|
s.AudioParams = viper.GetString("aparams")
|
||||||
|
s.Display = viper.GetString("display")
|
||||||
|
s.VideoCodec = videoCodec
|
||||||
|
s.VideoParams = viper.GetString("vparams")
|
||||||
}
|
}
|
||||||
|
@ -9,12 +9,13 @@ package gst
|
|||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"sync"
|
"sync"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/pion/webrtc/v2"
|
"github.com/pion/webrtc/v2"
|
||||||
"github.com/pion/webrtc/v2/pkg/media"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
|
"n.eko.moe/neko/internal/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -35,10 +36,10 @@ import (
|
|||||||
// Pipeline is a wrapper for a GStreamer Pipeline
|
// Pipeline is a wrapper for a GStreamer Pipeline
|
||||||
type Pipeline struct {
|
type Pipeline struct {
|
||||||
Pipeline *C.GstElement
|
Pipeline *C.GstElement
|
||||||
tracks []*webrtc.Track
|
Sample chan types.Sample
|
||||||
|
CodecName string
|
||||||
|
ClockRate float32
|
||||||
id int
|
id int
|
||||||
codecName string
|
|
||||||
clockRate float32
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var pipelines = make(map[int]*Pipeline)
|
var pipelines = make(map[int]*Pipeline)
|
||||||
@ -57,7 +58,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreatePipeline creates a GStreamer Pipeline
|
// CreatePipeline creates a GStreamer Pipeline
|
||||||
func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string) *Pipeline {
|
func CreatePipeline(codecName string, pipelineSrc string) (*Pipeline, error) {
|
||||||
pipelineStr := "appsink name=appsink"
|
pipelineStr := "appsink name=appsink"
|
||||||
var clockRate float32
|
var clockRate float32
|
||||||
|
|
||||||
@ -70,7 +71,7 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
|
|||||||
clockRate = videoClockRate
|
clockRate = videoClockRate
|
||||||
|
|
||||||
if err := CheckPlugins([]string{"ximagesrc", "vpx"}); err != nil {
|
if err := CheckPlugins([]string{"ximagesrc", "vpx"}); err != nil {
|
||||||
panic(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
case webrtc.VP9:
|
case webrtc.VP9:
|
||||||
@ -83,7 +84,7 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
|
|||||||
clockRate = videoClockRate
|
clockRate = videoClockRate
|
||||||
|
|
||||||
if err := CheckPlugins([]string{"ximagesrc", "vpx"}); err != nil {
|
if err := CheckPlugins([]string{"ximagesrc", "vpx"}); err != nil {
|
||||||
panic(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
case webrtc.H264:
|
case webrtc.H264:
|
||||||
@ -98,14 +99,14 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
|
|||||||
clockRate = videoClockRate
|
clockRate = videoClockRate
|
||||||
|
|
||||||
if err := CheckPlugins([]string{"ximagesrc"}); err != nil {
|
if err := CheckPlugins([]string{"ximagesrc"}); err != nil {
|
||||||
panic(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := CheckPlugins([]string{"openh264"}); err != nil {
|
if err := CheckPlugins([]string{"openh264"}); err != nil {
|
||||||
pipelineStr = pipelineSrc + " ! video/x-raw,format=I420 ! x264enc bframes=0 key-int-max=60 byte-stream=true tune=zerolatency speed-preset=veryfast ! video/x-h264,stream-format=byte-stream ! " + pipelineStr
|
pipelineStr = pipelineSrc + " ! video/x-raw,format=I420 ! x264enc bframes=0 key-int-max=60 byte-stream=true tune=zerolatency speed-preset=veryfast ! video/x-h264,stream-format=byte-stream ! " + pipelineStr
|
||||||
|
|
||||||
if err := CheckPlugins([]string{"x264"}); err != nil {
|
if err := CheckPlugins([]string{"x264"}); err != nil {
|
||||||
panic(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -117,7 +118,7 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
|
|||||||
clockRate = audioClockRate
|
clockRate = audioClockRate
|
||||||
|
|
||||||
if err := CheckPlugins([]string{"pulseaudio", "opus"}); err != nil {
|
if err := CheckPlugins([]string{"pulseaudio", "opus"}); err != nil {
|
||||||
panic(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
case webrtc.G722:
|
case webrtc.G722:
|
||||||
@ -128,7 +129,7 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
|
|||||||
clockRate = audioClockRate
|
clockRate = audioClockRate
|
||||||
|
|
||||||
if err := CheckPlugins([]string{"pulseaudio", "libav"}); err != nil {
|
if err := CheckPlugins([]string{"pulseaudio", "libav"}); err != nil {
|
||||||
panic(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
case webrtc.PCMU:
|
case webrtc.PCMU:
|
||||||
@ -140,7 +141,7 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
|
|||||||
clockRate = pcmClockRate
|
clockRate = pcmClockRate
|
||||||
|
|
||||||
if err := CheckPlugins([]string{"pulseaudio", "mulaw"}); err != nil {
|
if err := CheckPlugins([]string{"pulseaudio", "mulaw"}); err != nil {
|
||||||
panic(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
case webrtc.PCMA:
|
case webrtc.PCMA:
|
||||||
@ -151,11 +152,11 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
|
|||||||
clockRate = pcmClockRate
|
clockRate = pcmClockRate
|
||||||
|
|
||||||
if err := CheckPlugins([]string{"pulseaudio", "alaw"}); err != nil {
|
if err := CheckPlugins([]string{"pulseaudio", "alaw"}); err != nil {
|
||||||
panic(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
panic("Unhandled codec " + codecName)
|
return nil, errors.Errorf("unknown video codec %s", codecName)
|
||||||
}
|
}
|
||||||
|
|
||||||
pipelineStrUnsafe := C.CString(pipelineStr)
|
pipelineStrUnsafe := C.CString(pipelineStr)
|
||||||
@ -166,14 +167,14 @@ func CreatePipeline(codecName string, tracks []*webrtc.Track, pipelineSrc string
|
|||||||
|
|
||||||
pipeline := &Pipeline{
|
pipeline := &Pipeline{
|
||||||
Pipeline: C.gstreamer_send_create_pipeline(pipelineStrUnsafe),
|
Pipeline: C.gstreamer_send_create_pipeline(pipelineStrUnsafe),
|
||||||
tracks: tracks,
|
Sample: make(chan types.Sample),
|
||||||
|
CodecName: codecName,
|
||||||
|
ClockRate: clockRate,
|
||||||
id: len(pipelines),
|
id: len(pipelines),
|
||||||
codecName: codecName,
|
|
||||||
clockRate: clockRate,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pipelines[pipeline.id] = pipeline
|
pipelines[pipeline.id] = pipeline
|
||||||
return pipeline
|
return pipeline, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start starts the GStreamer Pipeline
|
// Start starts the GStreamer Pipeline
|
||||||
@ -193,14 +194,13 @@ func CheckPlugins(plugins []string) error {
|
|||||||
plugin = C.gst_registry_find_plugin(registry, plugincstr)
|
plugin = C.gst_registry_find_plugin(registry, plugincstr)
|
||||||
C.free(unsafe.Pointer(plugincstr))
|
C.free(unsafe.Pointer(plugincstr))
|
||||||
if plugin == nil {
|
if plugin == nil {
|
||||||
return fmt.Errorf("Required gstreamer plugin %s not found", pluginstr)
|
return fmt.Errorf("required gstreamer plugin %s not found", pluginstr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//export goHandlePipelineBuffer
|
//export goHandlePipelineBuffer
|
||||||
func goHandlePipelineBuffer(buffer unsafe.Pointer, bufferLen C.int, duration C.int, pipelineID C.int) {
|
func goHandlePipelineBuffer(buffer unsafe.Pointer, bufferLen C.int, duration C.int, pipelineID C.int) {
|
||||||
pipelinesLock.Lock()
|
pipelinesLock.Lock()
|
||||||
@ -208,12 +208,8 @@ func goHandlePipelineBuffer(buffer unsafe.Pointer, bufferLen C.int, duration C.i
|
|||||||
pipelinesLock.Unlock()
|
pipelinesLock.Unlock()
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
samples := uint32(pipeline.clockRate * (float32(duration) / 1000000000))
|
samples := uint32(pipeline.ClockRate * (float32(duration) / 1000000000))
|
||||||
for _, t := range pipeline.tracks {
|
pipeline.Sample <- types.Sample{Data: C.GoBytes(buffer, bufferLen), Samples: samples}
|
||||||
if err := t.WriteSample(media.Sample{Data: C.GoBytes(buffer, bufferLen), Samples: samples}); err != nil && err != io.ErrClosedPipe {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
fmt.Printf("discarding buffer, no pipeline with id %d", int(pipelineID))
|
fmt.Printf("discarding buffer, no pipeline with id %d", int(pipelineID))
|
||||||
}
|
}
|
||||||
|
@ -3,15 +3,17 @@ package session
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"github.com/kataras/go-events"
|
"github.com/kataras/go-events"
|
||||||
"github.com/pion/webrtc/v2"
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
|
"n.eko.moe/neko/internal/types"
|
||||||
"n.eko.moe/neko/internal/utils"
|
"n.eko.moe/neko/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func New() *SessionManager {
|
func New() *SessionManager {
|
||||||
return &SessionManager{
|
return &SessionManager{
|
||||||
|
logger: log.With().Str("module", "session").Logger(),
|
||||||
host: "",
|
host: "",
|
||||||
members: make(map[string]*Session),
|
members: make(map[string]*Session),
|
||||||
emmiter: events.New(),
|
emmiter: events.New(),
|
||||||
@ -19,153 +21,100 @@ func New() *SessionManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SessionManager struct {
|
type SessionManager struct {
|
||||||
|
logger zerolog.Logger
|
||||||
host string
|
host string
|
||||||
members map[string]*Session
|
members map[string]*Session
|
||||||
emmiter events.EventEmmiter
|
emmiter events.EventEmmiter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) New(id string, admin bool, socket *websocket.Conn) *Session {
|
func (manager *SessionManager) New(id string, admin bool, socket types.WebScoket) types.Session {
|
||||||
session := &Session{
|
session := &Session{
|
||||||
ID: id,
|
id: id,
|
||||||
Admin: admin,
|
admin: admin,
|
||||||
|
manager: manager,
|
||||||
socket: socket,
|
socket: socket,
|
||||||
|
logger: manager.logger.With().Str("id", id).Logger(),
|
||||||
connected: false,
|
connected: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
m.members[id] = session
|
manager.members[id] = session
|
||||||
m.emmiter.Emit("created", id, session)
|
manager.emmiter.Emit("created", id, session)
|
||||||
|
|
||||||
return session
|
return session
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) IsHost(id string) bool {
|
func (manager *SessionManager) HasHost() bool {
|
||||||
return m.host == id
|
return manager.host != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) HasHost() bool {
|
func (manager *SessionManager) IsHost(id string) bool {
|
||||||
return m.host != ""
|
return manager.host == id
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) SetHost(id string) error {
|
func (manager *SessionManager) SetHost(id string) error {
|
||||||
_, ok := m.members[id]
|
_, ok := manager.members[id]
|
||||||
if ok {
|
if ok {
|
||||||
m.host = id
|
manager.host = id
|
||||||
m.emmiter.Emit("host", id)
|
manager.emmiter.Emit("host", id)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("invalid session id %s", id)
|
return fmt.Errorf("invalid session id %s", id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) GetHost() (*Session, bool) {
|
func (manager *SessionManager) GetHost() (types.Session, bool) {
|
||||||
host, ok := m.members[m.host]
|
host, ok := manager.members[manager.host]
|
||||||
return host, ok
|
return host, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) ClearHost() {
|
func (manager *SessionManager) ClearHost() {
|
||||||
id := m.host
|
id := manager.host
|
||||||
m.host = ""
|
manager.host = ""
|
||||||
m.emmiter.Emit("host_cleared", id)
|
manager.emmiter.Emit("host_cleared", id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) Has(id string) bool {
|
func (manager *SessionManager) Has(id string) bool {
|
||||||
_, ok := m.members[id]
|
_, ok := manager.members[id]
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) Get(id string) (*Session, bool) {
|
func (manager *SessionManager) Get(id string) (types.Session, bool) {
|
||||||
session, ok := m.members[id]
|
session, ok := manager.members[id]
|
||||||
return session, ok
|
return session, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) GetConnected() []*Session {
|
func (manager *SessionManager) Members() []*types.Member {
|
||||||
var sessions []*Session
|
members := []*types.Member{}
|
||||||
for _, sess := range m.members {
|
for _, session := range manager.members {
|
||||||
if sess.connected {
|
if !session.connected {
|
||||||
sessions = append(sessions, sess)
|
continue
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return sessions
|
member := session.Member()
|
||||||
|
if member != nil {
|
||||||
|
members = append(members, member)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return members
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) Set(id string, session *Session) {
|
func (manager *SessionManager) Destroy(id string) error {
|
||||||
m.members[id] = session
|
session, ok := manager.members[id]
|
||||||
}
|
|
||||||
|
|
||||||
func (m *SessionManager) Destroy(id string) error {
|
|
||||||
session, ok := m.members[id]
|
|
||||||
if ok {
|
if ok {
|
||||||
err := session.destroy()
|
err := session.destroy()
|
||||||
delete(m.members, id)
|
delete(manager.members, id)
|
||||||
m.emmiter.Emit("destroyed", id)
|
manager.emmiter.Emit("destroyed", id)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) SetSocket(id string, socket *websocket.Conn) (bool, error) {
|
func (manager *SessionManager) Clear() error {
|
||||||
session, ok := m.members[id]
|
|
||||||
if ok {
|
|
||||||
session.socket = socket
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, fmt.Errorf("invalid session id %s", id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *SessionManager) SetPeer(id string, peer *webrtc.PeerConnection) (bool, error) {
|
|
||||||
session, ok := m.members[id]
|
|
||||||
if ok {
|
|
||||||
session.peer = peer
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, fmt.Errorf("invalid session id %s", id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *SessionManager) SetName(id string, name string) (bool, error) {
|
|
||||||
session, ok := m.members[id]
|
|
||||||
if ok {
|
|
||||||
session.Name = name
|
|
||||||
session.connected = true
|
|
||||||
m.emmiter.Emit("connected", id, session)
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, fmt.Errorf("invalid session id %s", id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *SessionManager) Mute(id string) error {
|
|
||||||
session, ok := m.members[id]
|
|
||||||
if ok {
|
|
||||||
session.Muted = true
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) Unmute(id string) error {
|
func (manager *SessionManager) Brodcast(v interface{}, exclude interface{}) error {
|
||||||
session, ok := m.members[id]
|
for id, session := range manager.members {
|
||||||
if ok {
|
if !session.connected {
|
||||||
session.Muted = false
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *SessionManager) Kick(id string, v interface{}) error {
|
|
||||||
session, ok := m.members[id]
|
|
||||||
if ok {
|
|
||||||
return session.Kick(v)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *SessionManager) Clear() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *SessionManager) Brodcast(v interface{}, exclude interface{}) error {
|
|
||||||
for id, sess := range m.members {
|
|
||||||
if !sess.connected {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -175,39 +124,65 @@ func (m *SessionManager) Brodcast(v interface{}, exclude interface{}) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := sess.Send(v); err != nil {
|
if err := session.Send(v); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) OnHost(listener func(id string)) {
|
func (manager *SessionManager) WriteVideoSample(sample types.Sample) error {
|
||||||
m.emmiter.On("host", func(payload ...interface{}) {
|
for _, session := range manager.members {
|
||||||
|
if !session.connected {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := session.WriteVideoSample(sample); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (manager *SessionManager) WriteAudioSample(sample types.Sample) error {
|
||||||
|
for _, session := range manager.members {
|
||||||
|
if !session.connected {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := session.WriteAudioSample(sample); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (manager *SessionManager) OnHost(listener func(id string)) {
|
||||||
|
manager.emmiter.On("host", func(payload ...interface{}) {
|
||||||
listener(payload[0].(string))
|
listener(payload[0].(string))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) OnHostCleared(listener func(id string)) {
|
func (manager *SessionManager) OnHostCleared(listener func(id string)) {
|
||||||
m.emmiter.On("host_cleared", func(payload ...interface{}) {
|
manager.emmiter.On("host_cleared", func(payload ...interface{}) {
|
||||||
listener(payload[0].(string))
|
listener(payload[0].(string))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) OnCreated(listener func(id string, session *Session)) {
|
func (manager *SessionManager) OnDestroy(listener func(id string)) {
|
||||||
m.emmiter.On("created", func(payload ...interface{}) {
|
manager.emmiter.On("destroyed", func(payload ...interface{}) {
|
||||||
|
listener(payload[0].(string))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (manager *SessionManager) OnCreated(listener func(id string, session types.Session)) {
|
||||||
|
manager.emmiter.On("created", func(payload ...interface{}) {
|
||||||
listener(payload[0].(string), payload[1].(*Session))
|
listener(payload[0].(string), payload[1].(*Session))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) OnConnected(listener func(id string, session *Session)) {
|
func (manager *SessionManager) OnConnected(listener func(id string, session types.Session)) {
|
||||||
m.emmiter.On("connected", func(payload ...interface{}) {
|
manager.emmiter.On("connected", func(payload ...interface{}) {
|
||||||
listener(payload[0].(string), payload[1].(*Session))
|
listener(payload[0].(string), payload[1].(*Session))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *SessionManager) OnDestroy(listener func(id string)) {
|
|
||||||
m.emmiter.On("destroyed", func(payload ...interface{}) {
|
|
||||||
listener(payload[0].(string))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
@ -3,38 +3,90 @@ package session
|
|||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/rs/zerolog"
|
||||||
"github.com/pion/webrtc/v2"
|
"n.eko.moe/neko/internal/types"
|
||||||
|
"n.eko.moe/neko/internal/types/event"
|
||||||
|
"n.eko.moe/neko/internal/types/message"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
ID string `json:"id"`
|
logger zerolog.Logger
|
||||||
Name string `json:"username"`
|
id string
|
||||||
Admin bool `json:"admin"`
|
name string
|
||||||
Muted bool `json:"muted"`
|
admin bool
|
||||||
|
muted bool
|
||||||
connected bool
|
connected bool
|
||||||
socket *websocket.Conn
|
manager *SessionManager
|
||||||
peer *webrtc.PeerConnection
|
socket types.WebScoket
|
||||||
|
peer types.Peer
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (session *Session) RemoteAddr() *string {
|
func (session *Session) ID() string {
|
||||||
if session.socket != nil {
|
return session.id
|
||||||
address := session.socket.RemoteAddr().String()
|
|
||||||
return &address
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (session *Session) Name() string {
|
||||||
|
return session.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *Session) Admin() bool {
|
||||||
|
return session.admin
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *Session) Muted() bool {
|
||||||
|
return session.muted
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *Session) Connected() bool {
|
||||||
|
return session.connected
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *Session) Address() *string {
|
||||||
|
if session.socket == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return session.socket.Address()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *Session) Member() *types.Member {
|
||||||
|
return &types.Member{
|
||||||
|
ID: session.id,
|
||||||
|
Name: session.name,
|
||||||
|
Admin: session.admin,
|
||||||
|
Muted: session.muted,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *Session) SetMuted(muted bool) {
|
||||||
|
session.muted = muted
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *Session) SetName(name string) error {
|
||||||
|
session.name = name
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: write to peer data channel
|
func (session *Session) SetSocket(socket types.WebScoket) error {
|
||||||
func (session *Session) Write(v interface{}) error {
|
session.socket = socket
|
||||||
session.mu.Lock()
|
|
||||||
defer session.mu.Unlock()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (session *Session) Kick(v interface{}) error {
|
func (session *Session) SetPeer(peer types.Peer) error {
|
||||||
if err := session.Send(v); err != nil {
|
session.peer = peer
|
||||||
|
session.connected = true
|
||||||
|
session.manager.emmiter.Emit("connected", session.id, session)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *Session) Kick(reason string) error {
|
||||||
|
if session.socket == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := session.socket.Send(&message.Disconnect{
|
||||||
|
Event: event.SYSTEM_DISCONNECT,
|
||||||
|
Message: reason,
|
||||||
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,28 +94,41 @@ func (session *Session) Kick(v interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (session *Session) Send(v interface{}) error {
|
func (session *Session) Send(v interface{}) error {
|
||||||
session.mu.Lock()
|
if session.socket == nil {
|
||||||
defer session.mu.Unlock()
|
return nil
|
||||||
|
}
|
||||||
if session.socket != nil {
|
return session.socket.Send(v)
|
||||||
return session.socket.WriteJSON(v)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (session *Session) Write(v interface{}) error {
|
||||||
|
if session.socket == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
return session.socket.Send(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *Session) WriteVideoSample(sample types.Sample) error {
|
||||||
|
if session.peer == nil || !session.connected {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return session.peer.WriteVideoSample(sample)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *Session) WriteAudioSample(sample types.Sample) error {
|
||||||
|
if session.peer == nil || !session.connected {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return session.peer.WriteAudioSample(sample)
|
||||||
|
}
|
||||||
|
|
||||||
func (session *Session) destroy() error {
|
func (session *Session) destroy() error {
|
||||||
if session.peer != nil && session.peer.ConnectionState() == webrtc.PeerConnectionStateConnected {
|
if err := session.socket.Destroy(); err != nil {
|
||||||
if err := session.peer.Close(); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if session.socket != nil {
|
if err := session.peer.Destroy(); err != nil {
|
||||||
if err := session.socket.Close(); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package message
|
package message
|
||||||
|
|
||||||
import "n.eko.moe/neko/internal/session"
|
import (
|
||||||
|
"n.eko.moe/neko/internal/types"
|
||||||
|
)
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Event string `json:"event"`
|
Event string `json:"event"`
|
||||||
@ -28,12 +30,12 @@ type Signal struct {
|
|||||||
|
|
||||||
type MembersList struct {
|
type MembersList struct {
|
||||||
Event string `json:"event"`
|
Event string `json:"event"`
|
||||||
Memebers []*session.Session `json:"members"`
|
Memebers []*types.Member `json:"members"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Member struct {
|
type Member struct {
|
||||||
Event string `json:"event"`
|
Event string `json:"event"`
|
||||||
*session.Session
|
*types.Member
|
||||||
}
|
}
|
||||||
type MemberDisconnected struct {
|
type MemberDisconnected struct {
|
||||||
Event string `json:"event"`
|
Event string `json:"event"`
|
49
server/internal/types/session.go
Normal file
49
server/internal/types/session.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
type Member struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"username"`
|
||||||
|
Admin bool `json:"admin"`
|
||||||
|
Muted bool `json:"muted"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Session interface {
|
||||||
|
ID() string
|
||||||
|
Name() string
|
||||||
|
Admin() bool
|
||||||
|
Muted() bool
|
||||||
|
Connected() bool
|
||||||
|
Member() *Member
|
||||||
|
SetMuted(muted bool)
|
||||||
|
SetName(name string) error
|
||||||
|
SetSocket(socket WebScoket) error
|
||||||
|
SetPeer(peer Peer) error
|
||||||
|
Address() *string
|
||||||
|
Kick(message string) error
|
||||||
|
Write(v interface{}) error
|
||||||
|
Send(v interface{}) error
|
||||||
|
WriteAudioSample(sample Sample) error
|
||||||
|
WriteVideoSample(sample Sample) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type SessionManager interface {
|
||||||
|
New(id string, admin bool, socket WebScoket) Session
|
||||||
|
HasHost() bool
|
||||||
|
IsHost(id string) bool
|
||||||
|
SetHost(id string) error
|
||||||
|
GetHost() (Session, bool)
|
||||||
|
ClearHost()
|
||||||
|
Has(id string) bool
|
||||||
|
Get(id string) (Session, bool)
|
||||||
|
Members() []*Member
|
||||||
|
Destroy(id string) error
|
||||||
|
Clear() error
|
||||||
|
Brodcast(v interface{}, exclude interface{}) error
|
||||||
|
WriteAudioSample(sample Sample) error
|
||||||
|
WriteVideoSample(sample Sample) error
|
||||||
|
OnHost(listener func(id string))
|
||||||
|
OnHostCleared(listener func(id string))
|
||||||
|
OnDestroy(listener func(id string))
|
||||||
|
OnCreated(listener func(id string, session Session))
|
||||||
|
OnConnected(listener func(id string, session Session))
|
||||||
|
}
|
19
server/internal/types/webrtc.go
Normal file
19
server/internal/types/webrtc.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
type Sample struct {
|
||||||
|
Data []byte
|
||||||
|
Samples uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type WebRTCManager interface {
|
||||||
|
Start()
|
||||||
|
Shutdown() error
|
||||||
|
CreatePeer(id string, sdp string) (string, Peer, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Peer interface {
|
||||||
|
WriteVideoSample(sample Sample) error
|
||||||
|
WriteAudioSample(sample Sample) error
|
||||||
|
WriteData(v interface{}) error
|
||||||
|
Destroy() error
|
||||||
|
}
|
7
server/internal/types/webscoket.go
Normal file
7
server/internal/types/webscoket.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
type WebScoket interface {
|
||||||
|
Address() *string
|
||||||
|
Send(v interface{}) error
|
||||||
|
Destroy() error
|
||||||
|
}
|
@ -1,6 +1,9 @@
|
|||||||
package webrtc
|
package webrtc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/pion/logging"
|
"github.com/pion/logging"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
)
|
)
|
||||||
@ -13,8 +16,19 @@ func (l logger) Trace(msg string) { l.logger.Trace().Ms
|
|||||||
func (l logger) Tracef(format string, args ...interface{}) { l.logger.Trace().Msgf(format, args...) }
|
func (l logger) Tracef(format string, args ...interface{}) { l.logger.Trace().Msgf(format, args...) }
|
||||||
func (l logger) Debug(msg string) { l.logger.Debug().Msg(msg) }
|
func (l logger) Debug(msg string) { l.logger.Debug().Msg(msg) }
|
||||||
func (l logger) Debugf(format string, args ...interface{}) { l.logger.Debug().Msgf(format, args...) }
|
func (l logger) Debugf(format string, args ...interface{}) { l.logger.Debug().Msgf(format, args...) }
|
||||||
func (l logger) Info(msg string) { l.logger.Info().Msg(msg) }
|
func (l logger) Info(msg string) {
|
||||||
func (l logger) Infof(format string, args ...interface{}) { l.logger.Info().Msgf(format, args...) }
|
if strings.Contains(msg, "packetio.Buffer is full") {
|
||||||
|
l.logger.Panic().Msg(msg)
|
||||||
|
}
|
||||||
|
l.logger.Info().Msg(msg)
|
||||||
|
}
|
||||||
|
func (l logger) Infof(format string, args ...interface{}) {
|
||||||
|
msg := fmt.Sprintf(format, args...)
|
||||||
|
if strings.Contains(msg, "packetio.Buffer is full") {
|
||||||
|
l.logger.Panic().Msg(msg)
|
||||||
|
}
|
||||||
|
l.logger.Info().Msgf(format, args...)
|
||||||
|
}
|
||||||
func (l logger) Warn(msg string) { l.logger.Warn().Msg(msg) }
|
func (l logger) Warn(msg string) { l.logger.Warn().Msg(msg) }
|
||||||
func (l logger) Warnf(format string, args ...interface{}) { l.logger.Warn().Msgf(format, args...) }
|
func (l logger) Warnf(format string, args ...interface{}) { l.logger.Warn().Msgf(format, args...) }
|
||||||
func (l logger) Error(msg string) { l.logger.Error().Msg(msg) }
|
func (l logger) Error(msg string) { l.logger.Error().Msg(msg) }
|
||||||
|
@ -1,226 +0,0 @@
|
|||||||
package webrtc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pion/webrtc/v2"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
|
|
||||||
"n.eko.moe/neko/internal/config"
|
|
||||||
"n.eko.moe/neko/internal/event"
|
|
||||||
"n.eko.moe/neko/internal/gst"
|
|
||||||
"n.eko.moe/neko/internal/hid"
|
|
||||||
"n.eko.moe/neko/internal/message"
|
|
||||||
"n.eko.moe/neko/internal/session"
|
|
||||||
)
|
|
||||||
|
|
||||||
func New(sessions *session.SessionManager, conf *config.WebRTC) *WebRTCManager {
|
|
||||||
logger := log.With().Str("module", "webrtc").Logger()
|
|
||||||
engine := webrtc.MediaEngine{}
|
|
||||||
engine.RegisterDefaultCodecs()
|
|
||||||
|
|
||||||
setings := webrtc.SettingEngine{
|
|
||||||
LoggerFactory: loggerFactory{
|
|
||||||
logger: logger,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return &WebRTCManager{
|
|
||||||
logger: logger,
|
|
||||||
engine: engine,
|
|
||||||
setings: setings,
|
|
||||||
api: webrtc.NewAPI(webrtc.WithMediaEngine(engine), webrtc.WithSettingEngine(setings)),
|
|
||||||
cleanup: time.NewTicker(1 * time.Second),
|
|
||||||
shutdown: make(chan bool),
|
|
||||||
sessions: sessions,
|
|
||||||
conf: conf,
|
|
||||||
config: webrtc.Configuration{
|
|
||||||
ICEServers: []webrtc.ICEServer{
|
|
||||||
{
|
|
||||||
URLs: []string{"stun:stun.l.google.com:19302"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
SDPSemantics: webrtc.SDPSemanticsUnifiedPlanWithFallback,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type WebRTCManager struct {
|
|
||||||
logger zerolog.Logger
|
|
||||||
engine webrtc.MediaEngine
|
|
||||||
setings webrtc.SettingEngine
|
|
||||||
config webrtc.Configuration
|
|
||||||
sessions *session.SessionManager
|
|
||||||
api *webrtc.API
|
|
||||||
video *webrtc.Track
|
|
||||||
audio *webrtc.Track
|
|
||||||
videoPipeline *gst.Pipeline
|
|
||||||
audioPipeline *gst.Pipeline
|
|
||||||
cleanup *time.Ticker
|
|
||||||
conf *config.WebRTC
|
|
||||||
shutdown chan bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *WebRTCManager) Start() {
|
|
||||||
|
|
||||||
hid.Display(m.conf.Display)
|
|
||||||
|
|
||||||
switch m.conf.VideoCodec {
|
|
||||||
case "vp8":
|
|
||||||
if err := m.createVideoTrack(webrtc.DefaultPayloadTypeVP8); err != nil {
|
|
||||||
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
|
|
||||||
}
|
|
||||||
case "vp9":
|
|
||||||
if err := m.createVideoTrack(webrtc.DefaultPayloadTypeVP9); err != nil {
|
|
||||||
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
|
|
||||||
}
|
|
||||||
case "h264":
|
|
||||||
if err := m.createVideoTrack(webrtc.DefaultPayloadTypeH264); err != nil {
|
|
||||||
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
m.logger.Panic().Err(errors.Errorf("unknown video codec %s", m.conf.AudioCodec)).Msg("unable to start webrtc manager")
|
|
||||||
}
|
|
||||||
|
|
||||||
switch m.conf.AudioCodec {
|
|
||||||
case "opus":
|
|
||||||
if err := m.createAudioTrack(webrtc.DefaultPayloadTypeOpus); err != nil {
|
|
||||||
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
|
|
||||||
}
|
|
||||||
case "g722":
|
|
||||||
if err := m.createAudioTrack(webrtc.DefaultPayloadTypeG722); err != nil {
|
|
||||||
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
|
|
||||||
}
|
|
||||||
case "pcmu":
|
|
||||||
if err := m.createAudioTrack(webrtc.DefaultPayloadTypePCMU); err != nil {
|
|
||||||
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
|
|
||||||
}
|
|
||||||
case "pcma":
|
|
||||||
if err := m.createAudioTrack(webrtc.DefaultPayloadTypePCMA); err != nil {
|
|
||||||
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
m.logger.Panic().Err(errors.Errorf("unknown audio codec %s", m.conf.AudioCodec)).Msg("unable to start webrtc manager")
|
|
||||||
}
|
|
||||||
|
|
||||||
m.videoPipeline.Start()
|
|
||||||
m.audioPipeline.Start()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer func() {
|
|
||||||
m.logger.Info().Msg("shutdown")
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-m.shutdown:
|
|
||||||
return
|
|
||||||
case <-m.cleanup.C:
|
|
||||||
hid.Check(time.Second * 10)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
m.sessions.OnHostCleared(func(id string) {
|
|
||||||
hid.Reset()
|
|
||||||
})
|
|
||||||
|
|
||||||
m.sessions.OnCreated(func(id string, session *session.Session) {
|
|
||||||
m.logger.Debug().Str("id", id).Msg("session created")
|
|
||||||
})
|
|
||||||
|
|
||||||
m.sessions.OnDestroy(func(id string) {
|
|
||||||
m.logger.Debug().Str("id", id).Msg("session destroyed")
|
|
||||||
})
|
|
||||||
|
|
||||||
// TODO: log resolution, bit rate and codec parameters
|
|
||||||
m.logger.Info().
|
|
||||||
Str("video_display", m.conf.Display).
|
|
||||||
Str("video_codec", m.conf.VideoCodec).
|
|
||||||
Str("audio_device", m.conf.Device).
|
|
||||||
Str("audio_codec", m.conf.AudioCodec).
|
|
||||||
Msgf("webrtc streaming")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *WebRTCManager) Shutdown() error {
|
|
||||||
m.logger.Info().Msgf("webrtc shutting down")
|
|
||||||
|
|
||||||
m.cleanup.Stop()
|
|
||||||
m.shutdown <- true
|
|
||||||
m.videoPipeline.Stop()
|
|
||||||
m.audioPipeline.Stop()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *WebRTCManager) CreatePeer(id string, sdp string) error {
|
|
||||||
session, ok := m.sessions.Get(id)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("invalid session id %s", id)
|
|
||||||
}
|
|
||||||
|
|
||||||
peer, err := m.api.NewPeerConnection(m.config)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := peer.AddTransceiverFromTrack(m.video, webrtc.RtpTransceiverInit{
|
|
||||||
Direction: webrtc.RTPTransceiverDirectionSendonly,
|
|
||||||
}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := peer.AddTransceiverFromTrack(m.audio, webrtc.RtpTransceiverInit{
|
|
||||||
Direction: webrtc.RTPTransceiverDirectionSendonly,
|
|
||||||
}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.SetRemoteDescription(webrtc.SessionDescription{
|
|
||||||
SDP: sdp,
|
|
||||||
Type: webrtc.SDPTypeOffer,
|
|
||||||
})
|
|
||||||
|
|
||||||
answer, err := peer.CreateAnswer(nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = peer.SetLocalDescription(answer); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := session.Send(message.Signal{
|
|
||||||
Event: event.SIGNAL_ANSWER,
|
|
||||||
SDP: answer.SDP,
|
|
||||||
}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.OnDataChannel(func(d *webrtc.DataChannel) {
|
|
||||||
d.OnMessage(func(msg webrtc.DataChannelMessage) {
|
|
||||||
if err = m.handle(id, msg); err != nil {
|
|
||||||
m.logger.Warn().Err(err).Msg("data handle failed")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
peer.OnConnectionStateChange(func(connectionState webrtc.PeerConnectionState) {
|
|
||||||
switch connectionState {
|
|
||||||
case webrtc.PeerConnectionStateDisconnected:
|
|
||||||
case webrtc.PeerConnectionStateFailed:
|
|
||||||
m.logger.Info().Str("id", id).Msg("peer disconnected")
|
|
||||||
m.sessions.Destroy(id)
|
|
||||||
break
|
|
||||||
case webrtc.PeerConnectionStateConnected:
|
|
||||||
m.logger.Info().Str("id", id).Msg("peer connected")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
m.sessions.SetPeer(id, peer)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
44
server/internal/webrtc/peer.go
Normal file
44
server/internal/webrtc/peer.go
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
package webrtc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/pion/webrtc/v2"
|
||||||
|
"github.com/pion/webrtc/v2/pkg/media"
|
||||||
|
"n.eko.moe/neko/internal/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Peer struct {
|
||||||
|
id string
|
||||||
|
engine webrtc.MediaEngine
|
||||||
|
api *webrtc.API
|
||||||
|
video *webrtc.Track
|
||||||
|
audio *webrtc.Track
|
||||||
|
connection *webrtc.PeerConnection
|
||||||
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) WriteAudioSample(sample types.Sample) error {
|
||||||
|
if err := peer.audio.WriteSample(media.Sample(sample)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) WriteVideoSample(sample types.Sample) error {
|
||||||
|
if err := peer.video.WriteSample(media.Sample(sample)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) WriteData(v interface{}) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) Destroy() error {
|
||||||
|
if peer.connection != nil && peer.connection.ConnectionState() == webrtc.PeerConnectionStateConnected {
|
||||||
|
if err := peer.connection.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -5,95 +5,36 @@ import (
|
|||||||
"math/rand"
|
"math/rand"
|
||||||
|
|
||||||
"github.com/pion/webrtc/v2"
|
"github.com/pion/webrtc/v2"
|
||||||
"github.com/pkg/errors"
|
|
||||||
|
|
||||||
"n.eko.moe/neko/internal/gst"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (m *WebRTCManager) createVideoTrack(payloadType uint8) error {
|
func (m *WebRTCManager) createVideoTrack(engine webrtc.MediaEngine) (*webrtc.Track, error) {
|
||||||
|
|
||||||
clockrate := uint32(90000)
|
|
||||||
var codec *webrtc.RTPCodec
|
var codec *webrtc.RTPCodec
|
||||||
switch payloadType {
|
for _, videoCodec := range engine.GetCodecsByKind(webrtc.RTPCodecTypeVideo) {
|
||||||
case webrtc.DefaultPayloadTypeVP8:
|
if videoCodec.Name == m.videoPipeline.CodecName {
|
||||||
codec = webrtc.NewRTPVP8Codec(payloadType, clockrate)
|
codec = videoCodec
|
||||||
break
|
|
||||||
case webrtc.DefaultPayloadTypeVP9:
|
|
||||||
codec = webrtc.NewRTPVP9Codec(payloadType, clockrate)
|
|
||||||
break
|
|
||||||
case webrtc.DefaultPayloadTypeH264:
|
|
||||||
codec = webrtc.NewRTPH264Codec(payloadType, clockrate)
|
|
||||||
break
|
|
||||||
default:
|
|
||||||
return errors.Errorf("unknown video codec %s", payloadType)
|
|
||||||
}
|
|
||||||
|
|
||||||
track, err := webrtc.NewTrack(payloadType, rand.Uint32(), "stream", "stream", codec)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var pipeline *gst.Pipeline
|
|
||||||
src := fmt.Sprintf("ximagesrc xid=%s show-pointer=true use-damage=false ! video/x-raw,framerate=30/1 ! videoconvert ! queue", m.conf.Display)
|
|
||||||
switch payloadType {
|
|
||||||
case webrtc.DefaultPayloadTypeVP8:
|
|
||||||
pipeline = gst.CreatePipeline(webrtc.VP8, []*webrtc.Track{track}, src)
|
|
||||||
break
|
|
||||||
case webrtc.DefaultPayloadTypeVP9:
|
|
||||||
pipeline = gst.CreatePipeline(webrtc.VP9, []*webrtc.Track{track}, src)
|
|
||||||
break
|
|
||||||
case webrtc.DefaultPayloadTypeH264:
|
|
||||||
pipeline = gst.CreatePipeline(webrtc.H264, []*webrtc.Track{track}, src)
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
m.video = track
|
|
||||||
m.videoPipeline = pipeline
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *WebRTCManager) createAudioTrack(payloadType uint8) error {
|
if codec == nil || codec.PayloadType == 0 {
|
||||||
|
return nil, fmt.Errorf("remote peer does not support %s", m.videoPipeline.CodecName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return webrtc.NewTrack(codec.PayloadType, rand.Uint32(), "stream", "stream", codec)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *WebRTCManager) createAudioTrack(engine webrtc.MediaEngine) (*webrtc.Track, error) {
|
||||||
var codec *webrtc.RTPCodec
|
var codec *webrtc.RTPCodec
|
||||||
switch payloadType {
|
for _, videoCodec := range engine.GetCodecsByKind(webrtc.RTPCodecTypeAudio) {
|
||||||
case webrtc.DefaultPayloadTypeOpus:
|
if videoCodec.Name == m.audioPipeline.CodecName {
|
||||||
codec = webrtc.NewRTPOpusCodec(payloadType, 48000)
|
codec = videoCodec
|
||||||
break
|
break
|
||||||
case webrtc.DefaultPayloadTypeG722:
|
}
|
||||||
codec = webrtc.NewRTPG722Codec(payloadType, 48000)
|
|
||||||
break
|
|
||||||
case webrtc.DefaultPayloadTypePCMU:
|
|
||||||
codec = webrtc.NewRTPPCMUCodec(payloadType, 8000)
|
|
||||||
break
|
|
||||||
case webrtc.DefaultPayloadTypePCMA:
|
|
||||||
codec = webrtc.NewRTPPCMACodec(payloadType, 8000)
|
|
||||||
break
|
|
||||||
default:
|
|
||||||
return errors.Errorf("unknown audio codec %s", payloadType)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
track, err := webrtc.NewTrack(payloadType, rand.Uint32(), "stream", "stream", codec)
|
if codec == nil || codec.PayloadType == 0 {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("remote peer does not support %s", m.audioPipeline.CodecName)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var pipeline *gst.Pipeline
|
return webrtc.NewTrack(codec.PayloadType, rand.Uint32(), "stream", "stream", codec)
|
||||||
src := fmt.Sprintf("pulsesrc device=%s ! audioconvert", m.conf.Device)
|
|
||||||
switch payloadType {
|
|
||||||
case webrtc.DefaultPayloadTypeOpus:
|
|
||||||
pipeline = gst.CreatePipeline(webrtc.Opus, []*webrtc.Track{track}, src)
|
|
||||||
break
|
|
||||||
case webrtc.DefaultPayloadTypeG722:
|
|
||||||
pipeline = gst.CreatePipeline(webrtc.G722, []*webrtc.Track{track}, src)
|
|
||||||
break
|
|
||||||
case webrtc.DefaultPayloadTypePCMU:
|
|
||||||
pipeline = gst.CreatePipeline(webrtc.PCMU, []*webrtc.Track{track}, src)
|
|
||||||
break
|
|
||||||
case webrtc.DefaultPayloadTypePCMA:
|
|
||||||
pipeline = gst.CreatePipeline(webrtc.PCMA, []*webrtc.Track{track}, src)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
m.audio = track
|
|
||||||
m.audioPipeline = pipeline
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
237
server/internal/webrtc/webrtc.go
Normal file
237
server/internal/webrtc/webrtc.go
Normal file
@ -0,0 +1,237 @@
|
|||||||
|
package webrtc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/webrtc/v2"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
|
"n.eko.moe/neko/internal/config"
|
||||||
|
"n.eko.moe/neko/internal/gst"
|
||||||
|
"n.eko.moe/neko/internal/hid"
|
||||||
|
"n.eko.moe/neko/internal/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func New(sessions types.SessionManager, config *config.WebRTC) *WebRTCManager {
|
||||||
|
logger := log.With().Str("module", "webrtc").Logger()
|
||||||
|
setings := webrtc.SettingEngine{
|
||||||
|
LoggerFactory: loggerFactory{
|
||||||
|
logger: logger,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return &WebRTCManager{
|
||||||
|
logger: logger,
|
||||||
|
setings: setings,
|
||||||
|
cleanup: time.NewTicker(1 * time.Second),
|
||||||
|
shutdown: make(chan bool),
|
||||||
|
sessions: sessions,
|
||||||
|
config: config,
|
||||||
|
configuration: &webrtc.Configuration{
|
||||||
|
ICEServers: []webrtc.ICEServer{
|
||||||
|
{
|
||||||
|
URLs: []string{"stun:stun.l.google.com:19302"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SDPSemantics: webrtc.SDPSemanticsUnifiedPlanWithFallback,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type WebRTCManager struct {
|
||||||
|
logger zerolog.Logger
|
||||||
|
setings webrtc.SettingEngine
|
||||||
|
sessions types.SessionManager
|
||||||
|
videoPipeline *gst.Pipeline
|
||||||
|
audioPipeline *gst.Pipeline
|
||||||
|
cleanup *time.Ticker
|
||||||
|
config *config.WebRTC
|
||||||
|
shutdown chan bool
|
||||||
|
configuration *webrtc.Configuration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *WebRTCManager) Start() {
|
||||||
|
hid.Display(m.config.Display)
|
||||||
|
|
||||||
|
videoPipeline, err := gst.CreatePipeline(
|
||||||
|
m.config.VideoCodec,
|
||||||
|
fmt.Sprintf("ximagesrc xid=%s show-pointer=true use-damage=false ! video/x-raw,framerate=30/1 ! videoconvert ! queue", m.config.Display),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
|
||||||
|
}
|
||||||
|
|
||||||
|
audioPipeline, err := gst.CreatePipeline(
|
||||||
|
m.config.AudioCodec,
|
||||||
|
fmt.Sprintf("pulsesrc device=%s ! audioconvert", m.config.Device),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
m.logger.Panic().Err(err).Msg("unable to start webrtc manager")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.videoPipeline = videoPipeline
|
||||||
|
m.audioPipeline = audioPipeline
|
||||||
|
|
||||||
|
videoPipeline.Start()
|
||||||
|
audioPipeline.Start()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
m.logger.Info().Msg("shutdown")
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-m.shutdown:
|
||||||
|
return
|
||||||
|
case sample := <-videoPipeline.Sample:
|
||||||
|
if err := m.sessions.WriteVideoSample(sample); err != nil {
|
||||||
|
m.logger.Warn().Err(err).Msg("video pipeline failed")
|
||||||
|
}
|
||||||
|
case sample := <-audioPipeline.Sample:
|
||||||
|
if err := m.sessions.WriteAudioSample(sample); err != nil {
|
||||||
|
m.logger.Warn().Err(err).Msg("audio pipeline failed")
|
||||||
|
}
|
||||||
|
case <-m.cleanup.C:
|
||||||
|
hid.Check(time.Second * 10)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
m.sessions.OnHostCleared(func(id string) {
|
||||||
|
hid.Reset()
|
||||||
|
})
|
||||||
|
|
||||||
|
m.sessions.OnCreated(func(id string, session types.Session) {
|
||||||
|
m.logger.Debug().Str("id", id).Msg("session created")
|
||||||
|
})
|
||||||
|
|
||||||
|
m.sessions.OnDestroy(func(id string) {
|
||||||
|
m.logger.Debug().Str("id", id).Msg("session destroyed")
|
||||||
|
})
|
||||||
|
|
||||||
|
// TODO: log resolution, bit rate and codec parameters
|
||||||
|
m.logger.Info().
|
||||||
|
Str("video_display", m.config.Display).
|
||||||
|
Str("video_codec", m.config.VideoCodec).
|
||||||
|
Str("audio_device", m.config.Device).
|
||||||
|
Str("audio_codec", m.config.AudioCodec).
|
||||||
|
Msgf("webrtc streaming")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *WebRTCManager) Shutdown() error {
|
||||||
|
m.logger.Info().Msgf("webrtc shutting down")
|
||||||
|
m.videoPipeline.Stop()
|
||||||
|
m.audioPipeline.Stop()
|
||||||
|
m.cleanup.Stop()
|
||||||
|
m.shutdown <- true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *WebRTCManager) CreatePeer(id string, sdp string) (string, types.Peer, error) {
|
||||||
|
// create SessionDescription
|
||||||
|
description := webrtc.SessionDescription{
|
||||||
|
SDP: sdp,
|
||||||
|
Type: webrtc.SDPTypeOffer,
|
||||||
|
}
|
||||||
|
|
||||||
|
// create MediaEngine based off sdp
|
||||||
|
engine := webrtc.MediaEngine{}
|
||||||
|
engine.PopulateFromSDP(description)
|
||||||
|
|
||||||
|
// create API with MediaEngine and SettingEngine
|
||||||
|
api := webrtc.NewAPI(webrtc.WithMediaEngine(engine), webrtc.WithSettingEngine(m.setings))
|
||||||
|
|
||||||
|
// create new peer connection
|
||||||
|
connection, err := api.NewPeerConnection(*m.configuration)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// create video track
|
||||||
|
video, err := m.createVideoTrack(engine)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
videoTransceiver, err := connection.AddTransceiverFromTrack(video, webrtc.RtpTransceiverInit{
|
||||||
|
Direction: webrtc.RTPTransceiverDirectionSendonly,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// create audio track
|
||||||
|
audio, err := m.createAudioTrack(engine)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
audioTransceiver, err := connection.AddTransceiverFromTrack(audio, webrtc.RtpTransceiverInit{
|
||||||
|
Direction: webrtc.RTPTransceiverDirectionSendonly,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear the Transceiver bufers
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
if _, err := audioTransceiver.Sender.ReadRTCP(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = videoTransceiver.Sender.ReadRTCP(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// set remote description
|
||||||
|
connection.SetRemoteDescription(description)
|
||||||
|
|
||||||
|
answer, err := connection.CreateAnswer(nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = connection.SetLocalDescription(answer); err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
connection.OnDataChannel(func(d *webrtc.DataChannel) {
|
||||||
|
d.OnMessage(func(msg webrtc.DataChannelMessage) {
|
||||||
|
if err = m.handle(id, msg); err != nil {
|
||||||
|
m.logger.Warn().Err(err).Msg("data handle failed")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
connection.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
|
||||||
|
switch state {
|
||||||
|
case webrtc.PeerConnectionStateDisconnected:
|
||||||
|
case webrtc.PeerConnectionStateFailed:
|
||||||
|
m.logger.Info().Str("id", id).Msg("peer disconnected")
|
||||||
|
m.sessions.Destroy(id)
|
||||||
|
break
|
||||||
|
case webrtc.PeerConnectionStateConnected:
|
||||||
|
m.logger.Info().Str("id", id).Msg("peer connected")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return answer.SDP, &Peer{
|
||||||
|
id: id,
|
||||||
|
api: api,
|
||||||
|
engine: engine,
|
||||||
|
video: video,
|
||||||
|
audio: audio,
|
||||||
|
connection: connection,
|
||||||
|
}, nil
|
||||||
|
}
|
@ -3,13 +3,13 @@ package websocket
|
|||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"n.eko.moe/neko/internal/event"
|
"n.eko.moe/neko/internal/types"
|
||||||
"n.eko.moe/neko/internal/message"
|
"n.eko.moe/neko/internal/types/event"
|
||||||
"n.eko.moe/neko/internal/session"
|
"n.eko.moe/neko/internal/types/message"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *MessageHandler) adminLock(id string, session *session.Session) error {
|
func (h *MessageHandler) adminLock(id string, session types.Session) error {
|
||||||
if !session.Admin {
|
if !session.Admin() {
|
||||||
h.logger.Debug().Msg("user not admin")
|
h.logger.Debug().Msg("user not admin")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -33,8 +33,8 @@ func (h *MessageHandler) adminLock(id string, session *session.Session) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MessageHandler) adminUnlock(id string, session *session.Session) error {
|
func (h *MessageHandler) adminUnlock(id string, session types.Session) error {
|
||||||
if !session.Admin {
|
if !session.Admin() {
|
||||||
h.logger.Debug().Msg("user not admin")
|
h.logger.Debug().Msg("user not admin")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -58,8 +58,8 @@ func (h *MessageHandler) adminUnlock(id string, session *session.Session) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MessageHandler) adminControl(id string, session *session.Session) error {
|
func (h *MessageHandler) adminControl(id string, session types.Session) error {
|
||||||
if !session.Admin {
|
if !session.Admin() {
|
||||||
h.logger.Debug().Msg("user not admin")
|
h.logger.Debug().Msg("user not admin")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -73,7 +73,7 @@ func (h *MessageHandler) adminControl(id string, session *session.Session) error
|
|||||||
message.AdminTarget{
|
message.AdminTarget{
|
||||||
Event: event.ADMIN_CONTROL,
|
Event: event.ADMIN_CONTROL,
|
||||||
ID: id,
|
ID: id,
|
||||||
Target: host.ID,
|
Target: host.ID(),
|
||||||
}, nil); err != nil {
|
}, nil); err != nil {
|
||||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_CONTROL)
|
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_CONTROL)
|
||||||
return err
|
return err
|
||||||
@ -92,8 +92,8 @@ func (h *MessageHandler) adminControl(id string, session *session.Session) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MessageHandler) adminRelease(id string, session *session.Session) error {
|
func (h *MessageHandler) adminRelease(id string, session types.Session) error {
|
||||||
if !session.Admin {
|
if !session.Admin() {
|
||||||
h.logger.Debug().Msg("user not admin")
|
h.logger.Debug().Msg("user not admin")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -107,7 +107,7 @@ func (h *MessageHandler) adminRelease(id string, session *session.Session) error
|
|||||||
message.AdminTarget{
|
message.AdminTarget{
|
||||||
Event: event.ADMIN_RELEASE,
|
Event: event.ADMIN_RELEASE,
|
||||||
ID: id,
|
ID: id,
|
||||||
Target: host.ID,
|
Target: host.ID(),
|
||||||
}, nil); err != nil {
|
}, nil); err != nil {
|
||||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_RELEASE)
|
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_RELEASE)
|
||||||
return err
|
return err
|
||||||
@ -126,8 +126,8 @@ func (h *MessageHandler) adminRelease(id string, session *session.Session) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MessageHandler) adminGive(id string, session *session.Session, payload *message.Admin) error {
|
func (h *MessageHandler) adminGive(id string, session types.Session, payload *message.Admin) error {
|
||||||
if !session.Admin {
|
if !session.Admin() {
|
||||||
h.logger.Debug().Msg("user not admin")
|
h.logger.Debug().Msg("user not admin")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -154,8 +154,8 @@ func (h *MessageHandler) adminGive(id string, session *session.Session, payload
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MessageHandler) adminMute(id string, session *session.Session, payload *message.Admin) error {
|
func (h *MessageHandler) adminMute(id string, session types.Session, payload *message.Admin) error {
|
||||||
if !session.Admin {
|
if !session.Admin() {
|
||||||
h.logger.Debug().Msg("user not admin")
|
h.logger.Debug().Msg("user not admin")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -166,17 +166,17 @@ func (h *MessageHandler) adminMute(id string, session *session.Session, payload
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if target.Admin {
|
if target.Admin() {
|
||||||
h.logger.Debug().Msg("target is an admin, baling")
|
h.logger.Debug().Msg("target is an admin, baling")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
target.Muted = true
|
target.SetMuted(true)
|
||||||
|
|
||||||
if err := h.sessions.Brodcast(
|
if err := h.sessions.Brodcast(
|
||||||
message.AdminTarget{
|
message.AdminTarget{
|
||||||
Event: event.ADMIN_MUTE,
|
Event: event.ADMIN_MUTE,
|
||||||
Target: target.ID,
|
Target: target.ID(),
|
||||||
ID: id,
|
ID: id,
|
||||||
}, nil); err != nil {
|
}, nil); err != nil {
|
||||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_UNMUTE)
|
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_UNMUTE)
|
||||||
@ -186,8 +186,8 @@ func (h *MessageHandler) adminMute(id string, session *session.Session, payload
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MessageHandler) adminUnmute(id string, session *session.Session, payload *message.Admin) error {
|
func (h *MessageHandler) adminUnmute(id string, session types.Session, payload *message.Admin) error {
|
||||||
if !session.Admin {
|
if !session.Admin() {
|
||||||
h.logger.Debug().Msg("user not admin")
|
h.logger.Debug().Msg("user not admin")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -198,12 +198,12 @@ func (h *MessageHandler) adminUnmute(id string, session *session.Session, payloa
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
target.Muted = false
|
target.SetMuted(false)
|
||||||
|
|
||||||
if err := h.sessions.Brodcast(
|
if err := h.sessions.Brodcast(
|
||||||
message.AdminTarget{
|
message.AdminTarget{
|
||||||
Event: event.ADMIN_UNMUTE,
|
Event: event.ADMIN_UNMUTE,
|
||||||
Target: target.ID,
|
Target: target.ID(),
|
||||||
ID: id,
|
ID: id,
|
||||||
}, nil); err != nil {
|
}, nil); err != nil {
|
||||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_UNMUTE)
|
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_UNMUTE)
|
||||||
@ -213,8 +213,8 @@ func (h *MessageHandler) adminUnmute(id string, session *session.Session, payloa
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MessageHandler) adminKick(id string, session *session.Session, payload *message.Admin) error {
|
func (h *MessageHandler) adminKick(id string, session types.Session, payload *message.Admin) error {
|
||||||
if !session.Admin {
|
if !session.Admin() {
|
||||||
h.logger.Debug().Msg("user not admin")
|
h.logger.Debug().Msg("user not admin")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -225,22 +225,19 @@ func (h *MessageHandler) adminKick(id string, session *session.Session, payload
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if target.Admin {
|
if target.Admin() {
|
||||||
h.logger.Debug().Msg("target is an admin, baling")
|
h.logger.Debug().Msg("target is an admin, baling")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := target.Kick(message.Disconnect{
|
if err := target.Kick("You have been kicked"); err != nil {
|
||||||
Event: event.SYSTEM_DISCONNECT,
|
|
||||||
Message: "You have been kicked",
|
|
||||||
}); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.sessions.Brodcast(
|
if err := h.sessions.Brodcast(
|
||||||
message.AdminTarget{
|
message.AdminTarget{
|
||||||
Event: event.ADMIN_KICK,
|
Event: event.ADMIN_KICK,
|
||||||
Target: target.ID,
|
Target: target.ID(),
|
||||||
ID: id,
|
ID: id,
|
||||||
}, []string{payload.ID}); err != nil {
|
}, []string{payload.ID}); err != nil {
|
||||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_KICK)
|
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_KICK)
|
||||||
@ -250,8 +247,8 @@ func (h *MessageHandler) adminKick(id string, session *session.Session, payload
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MessageHandler) adminBan(id string, session *session.Session, payload *message.Admin) error {
|
func (h *MessageHandler) adminBan(id string, session types.Session, payload *message.Admin) error {
|
||||||
if !session.Admin {
|
if !session.Admin() {
|
||||||
h.logger.Debug().Msg("user not admin")
|
h.logger.Debug().Msg("user not admin")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -262,12 +259,12 @@ func (h *MessageHandler) adminBan(id string, session *session.Session, payload *
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if target.Admin {
|
if target.Admin() {
|
||||||
h.logger.Debug().Msg("target is an admin, baling")
|
h.logger.Debug().Msg("target is an admin, baling")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
remote := target.RemoteAddr()
|
remote := target.Address()
|
||||||
if remote == nil {
|
if remote == nil {
|
||||||
h.logger.Debug().Msg("no remote address, baling")
|
h.logger.Debug().Msg("no remote address, baling")
|
||||||
return nil
|
return nil
|
||||||
@ -283,17 +280,14 @@ func (h *MessageHandler) adminBan(id string, session *session.Session, payload *
|
|||||||
|
|
||||||
h.banned[address[0]] = true
|
h.banned[address[0]] = true
|
||||||
|
|
||||||
if err := target.Kick(message.Disconnect{
|
if err := target.Kick("You have been banned"); err != nil {
|
||||||
Event: event.SYSTEM_DISCONNECT,
|
|
||||||
Message: "You have been banned",
|
|
||||||
}); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.sessions.Brodcast(
|
if err := h.sessions.Brodcast(
|
||||||
message.AdminTarget{
|
message.AdminTarget{
|
||||||
Event: event.ADMIN_BAN,
|
Event: event.ADMIN_BAN,
|
||||||
Target: target.ID,
|
Target: target.ID(),
|
||||||
ID: id,
|
ID: id,
|
||||||
}, []string{payload.ID}); err != nil {
|
}, []string{payload.ID}); err != nil {
|
||||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_BAN)
|
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_BAN)
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"n.eko.moe/neko/internal/event"
|
"n.eko.moe/neko/internal/types"
|
||||||
"n.eko.moe/neko/internal/message"
|
"n.eko.moe/neko/internal/types/event"
|
||||||
"n.eko.moe/neko/internal/session"
|
"n.eko.moe/neko/internal/types/message"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *MessageHandler) chat(id string, session *session.Session, payload *message.ChatRecieve) error {
|
func (h *MessageHandler) chat(id string, session types.Session, payload *message.ChatRecieve) error {
|
||||||
if session.Muted {
|
if session.Muted() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -23,8 +23,8 @@ func (h *MessageHandler) chat(id string, session *session.Session, payload *mess
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MessageHandler) chatEmote(id string, session *session.Session, payload *message.EmoteRecieve) error {
|
func (h *MessageHandler) chatEmote(id string, session types.Session, payload *message.EmoteRecieve) error {
|
||||||
if session.Muted {
|
if session.Muted() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"n.eko.moe/neko/internal/event"
|
"n.eko.moe/neko/internal/types"
|
||||||
"n.eko.moe/neko/internal/message"
|
"n.eko.moe/neko/internal/types/event"
|
||||||
"n.eko.moe/neko/internal/session"
|
"n.eko.moe/neko/internal/types/message"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *MessageHandler) controlRelease(id string, session *session.Session) error {
|
func (h *MessageHandler) controlRelease(id string, session types.Session) error {
|
||||||
|
|
||||||
// check if session is host
|
// check if session is host
|
||||||
if !h.sessions.IsHost(id) {
|
if !h.sessions.IsHost(id) {
|
||||||
@ -31,7 +31,7 @@ func (h *MessageHandler) controlRelease(id string, session *session.Session) err
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MessageHandler) controlRequest(id string, session *session.Session) error {
|
func (h *MessageHandler) controlRequest(id string, session types.Session) error {
|
||||||
// check for host
|
// check for host
|
||||||
if !h.sessions.HasHost() {
|
if !h.sessions.HasHost() {
|
||||||
// set host
|
// set host
|
||||||
@ -57,7 +57,7 @@ func (h *MessageHandler) controlRequest(id string, session *session.Session) err
|
|||||||
// tell session there is a host
|
// tell session there is a host
|
||||||
if err := session.Send(message.Control{
|
if err := session.Send(message.Control{
|
||||||
Event: event.CONTROL_REQUEST,
|
Event: event.CONTROL_REQUEST,
|
||||||
ID: host.ID,
|
ID: host.ID(),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
h.logger.Warn().Err(err).Str("id", id).Msgf("sending event %s has failed", event.CONTROL_REQUEST)
|
h.logger.Warn().Err(err).Str("id", id).Msgf("sending event %s has failed", event.CONTROL_REQUEST)
|
||||||
return err
|
return err
|
||||||
@ -68,7 +68,7 @@ func (h *MessageHandler) controlRequest(id string, session *session.Session) err
|
|||||||
Event: event.CONTROL_REQUESTING,
|
Event: event.CONTROL_REQUESTING,
|
||||||
ID: id,
|
ID: id,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
h.logger.Warn().Err(err).Str("id", host.ID).Msgf("sending event %s has failed", event.CONTROL_REQUESTING)
|
h.logger.Warn().Err(err).Str("id", host.ID()).Msgf("sending event %s has failed", event.CONTROL_REQUESTING)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -76,7 +76,7 @@ func (h *MessageHandler) controlRequest(id string, session *session.Session) err
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MessageHandler) controlGive(id string, session *session.Session, payload *message.Control) error {
|
func (h *MessageHandler) controlGive(id string, session types.Session, payload *message.Control) error {
|
||||||
// check if session is host
|
// check if session is host
|
||||||
if !h.sessions.IsHost(id) {
|
if !h.sessions.IsHost(id) {
|
||||||
h.logger.Debug().Str("id", id).Msg("is not the host")
|
h.logger.Debug().Str("id", id).Msg("is not the host")
|
||||||
|
@ -1,235 +1,142 @@
|
|||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"encoding/json"
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
|
|
||||||
"n.eko.moe/neko/internal/config"
|
"n.eko.moe/neko/internal/types"
|
||||||
"n.eko.moe/neko/internal/event"
|
"n.eko.moe/neko/internal/types/event"
|
||||||
"n.eko.moe/neko/internal/message"
|
"n.eko.moe/neko/internal/types/message"
|
||||||
"n.eko.moe/neko/internal/session"
|
|
||||||
"n.eko.moe/neko/internal/utils"
|
"n.eko.moe/neko/internal/utils"
|
||||||
"n.eko.moe/neko/internal/webrtc"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func New(sessions *session.SessionManager, webrtc *webrtc.WebRTCManager, conf *config.WebSocket) *WebSocketHandler {
|
type MessageHandler struct {
|
||||||
logger := log.With().Str("module", "websocket").Logger()
|
|
||||||
|
|
||||||
return &WebSocketHandler{
|
|
||||||
logger: logger,
|
|
||||||
conf: conf,
|
|
||||||
sessions: sessions,
|
|
||||||
upgrader: websocket.Upgrader{
|
|
||||||
CheckOrigin: func(r *http.Request) bool {
|
|
||||||
return true
|
|
||||||
},
|
|
||||||
},
|
|
||||||
handler: &MessageHandler{
|
|
||||||
logger: logger.With().Str("subsystem", "handler").Logger(),
|
|
||||||
sessions: sessions,
|
|
||||||
webrtc: webrtc,
|
|
||||||
banned: make(map[string]bool),
|
|
||||||
locked: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send pings to peer with this period. Must be less than pongWait.
|
|
||||||
const pingPeriod = 60 * time.Second
|
|
||||||
|
|
||||||
type WebSocketHandler struct {
|
|
||||||
logger zerolog.Logger
|
logger zerolog.Logger
|
||||||
upgrader websocket.Upgrader
|
sessions types.SessionManager
|
||||||
handler *MessageHandler
|
webrtc types.WebRTCManager
|
||||||
conf *config.WebSocket
|
banned map[string]bool
|
||||||
sessions *session.SessionManager
|
locked bool
|
||||||
shutdown chan bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WebSocketHandler) Start() error {
|
func (h *MessageHandler) Connected(id string, socket *WebSocket) (bool, string, error) {
|
||||||
|
address := socket.Address()
|
||||||
go func() {
|
if address == nil {
|
||||||
defer func() {
|
h.logger.Debug().Msg("no remote address, baling")
|
||||||
ws.logger.Info().Msg("shutdown")
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ws.shutdown:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
ws.sessions.OnCreated(func(id string, session *session.Session) {
|
|
||||||
if err := ws.handler.SessionCreated(id, session); err != nil {
|
|
||||||
ws.logger.Warn().Str("id", id).Err(err).Msg("session created with and error")
|
|
||||||
} else {
|
} else {
|
||||||
ws.logger.Debug().Str("id", id).Msg("session created")
|
ok, banned := h.banned[*address]
|
||||||
|
if ok && banned {
|
||||||
|
h.logger.Debug().Str("address", *address).Msg("banned")
|
||||||
|
return false, "This IP has been banned", nil
|
||||||
}
|
}
|
||||||
})
|
|
||||||
|
|
||||||
ws.sessions.OnConnected(func(id string, session *session.Session) {
|
|
||||||
if err := ws.handler.SessionConnected(id, session); err != nil {
|
|
||||||
ws.logger.Warn().Str("id", id).Err(err).Msg("session connected with and error")
|
|
||||||
} else {
|
|
||||||
ws.logger.Debug().Str("id", id).Msg("session connected")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
ws.sessions.OnDestroy(func(id string) {
|
|
||||||
if err := ws.handler.SessionDestroyed(id); err != nil {
|
|
||||||
ws.logger.Warn().Str("id", id).Err(err).Msg("session destroyed with and error")
|
|
||||||
} else {
|
|
||||||
ws.logger.Debug().Str("id", id).Msg("session destroyed")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WebSocketHandler) Shutdown() error {
|
if h.locked {
|
||||||
ws.shutdown <- true
|
h.logger.Debug().Msg("server locked")
|
||||||
return nil
|
return false, "Server is currently locked", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) error {
|
return true, "", nil
|
||||||
ws.logger.Debug().Msg("attempting to upgrade connection")
|
}
|
||||||
|
|
||||||
socket, err := ws.upgrader.Upgrade(w, r, nil)
|
func (h *MessageHandler) Disconnected(id string) error {
|
||||||
if err != nil {
|
return h.sessions.Destroy(id)
|
||||||
ws.logger.Error().Err(err).Msg("failed to upgrade connection")
|
}
|
||||||
return err
|
|
||||||
}
|
func (h *MessageHandler) Message(id string, raw []byte) error {
|
||||||
|
header := message.Message{}
|
||||||
id, admin, err := ws.authenticate(r)
|
if err := json.Unmarshal(raw, &header); err != nil {
|
||||||
if err != nil {
|
|
||||||
ws.logger.Warn().Err(err).Msg("authenticatetion failed")
|
|
||||||
|
|
||||||
if err = socket.WriteJSON(message.Disconnect{
|
|
||||||
Event: event.SYSTEM_DISCONNECT,
|
|
||||||
Message: "invalid password",
|
|
||||||
}); err != nil {
|
|
||||||
ws.logger.Error().Err(err).Msg("failed to send disconnect")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = socket.Close(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, reason, err := ws.handler.SocketConnected(id, socket)
|
|
||||||
if err != nil {
|
|
||||||
ws.logger.Error().Err(err).Msg("connection failed")
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
session, ok := h.sessions.Get(id)
|
||||||
if !ok {
|
if !ok {
|
||||||
if err = socket.WriteJSON(message.Disconnect{
|
errors.Errorf("unknown session id %s", id)
|
||||||
Event: event.SYSTEM_DISCONNECT,
|
|
||||||
Message: reason,
|
|
||||||
}); err != nil {
|
|
||||||
ws.logger.Error().Err(err).Msg("failed to send disconnect")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = socket.Close(); err != nil {
|
switch header.Event {
|
||||||
return err
|
// Signal Events
|
||||||
}
|
case event.SIGNAL_PROVIDE:
|
||||||
|
payload := &message.Signal{}
|
||||||
|
return errors.Wrapf(
|
||||||
|
utils.Unmarshal(payload, raw, func() error {
|
||||||
|
return h.createPeer(id, session, payload)
|
||||||
|
}), "%s failed", header.Event)
|
||||||
|
// Identity Events
|
||||||
|
case event.IDENTITY_DETAILS:
|
||||||
|
payload := &message.IdentityDetails{}
|
||||||
|
return errors.Wrapf(
|
||||||
|
utils.Unmarshal(payload, raw, func() error {
|
||||||
|
return h.identityDetails(id, session, payload)
|
||||||
|
}), "%s failed", header.Event)
|
||||||
|
|
||||||
return nil
|
// Control Events
|
||||||
}
|
case event.CONTROL_RELEASE:
|
||||||
|
return errors.Wrapf(h.controlRelease(id, session), "%s failed", header.Event)
|
||||||
|
case event.CONTROL_REQUEST:
|
||||||
|
return errors.Wrapf(h.controlRequest(id, session), "%s failed", header.Event)
|
||||||
|
case event.CONTROL_GIVE:
|
||||||
|
payload := &message.Control{}
|
||||||
|
return errors.Wrapf(
|
||||||
|
utils.Unmarshal(payload, raw, func() error {
|
||||||
|
return h.controlGive(id, session, payload)
|
||||||
|
}), "%s failed", header.Event)
|
||||||
|
|
||||||
ws.sessions.New(id, admin, socket)
|
// Chat Events
|
||||||
|
case event.CHAT_MESSAGE:
|
||||||
|
payload := &message.ChatRecieve{}
|
||||||
|
return errors.Wrapf(
|
||||||
|
utils.Unmarshal(payload, raw, func() error {
|
||||||
|
return h.chat(id, session, payload)
|
||||||
|
}), "%s failed", header.Event)
|
||||||
|
case event.CHAT_EMOTE:
|
||||||
|
payload := &message.EmoteRecieve{}
|
||||||
|
return errors.Wrapf(
|
||||||
|
utils.Unmarshal(payload, raw, func() error {
|
||||||
|
return h.chatEmote(id, session, payload)
|
||||||
|
}), "%s failed", header.Event)
|
||||||
|
|
||||||
ws.logger.
|
// Admin Events
|
||||||
Debug().
|
case event.ADMIN_LOCK:
|
||||||
Str("session", id).
|
return errors.Wrapf(h.adminLock(id, session), "%s failed", header.Event)
|
||||||
Str("address", socket.RemoteAddr().String()).
|
case event.ADMIN_UNLOCK:
|
||||||
Msg("new connection created")
|
return errors.Wrapf(h.adminUnlock(id, session), "%s failed", header.Event)
|
||||||
|
case event.ADMIN_CONTROL:
|
||||||
defer func() {
|
return errors.Wrapf(h.adminControl(id, session), "%s failed", header.Event)
|
||||||
ws.logger.
|
case event.ADMIN_RELEASE:
|
||||||
Debug().
|
return errors.Wrapf(h.adminRelease(id, session), "%s failed", header.Event)
|
||||||
Str("session", id).
|
case event.ADMIN_GIVE:
|
||||||
Str("address", socket.RemoteAddr().String()).
|
payload := &message.Admin{}
|
||||||
Msg("session ended")
|
return errors.Wrapf(
|
||||||
}()
|
utils.Unmarshal(payload, raw, func() error {
|
||||||
|
return h.adminGive(id, session, payload)
|
||||||
ws.handle(socket, id)
|
}), "%s failed", header.Event)
|
||||||
return nil
|
case event.ADMIN_BAN:
|
||||||
}
|
payload := &message.Admin{}
|
||||||
|
return errors.Wrapf(
|
||||||
func (ws *WebSocketHandler) authenticate(r *http.Request) (string, bool, error) {
|
utils.Unmarshal(payload, raw, func() error {
|
||||||
id, err := utils.NewUID(32)
|
return h.adminBan(id, session, payload)
|
||||||
if err != nil {
|
}), "%s failed", header.Event)
|
||||||
return "", false, err
|
case event.ADMIN_KICK:
|
||||||
}
|
payload := &message.Admin{}
|
||||||
|
return errors.Wrapf(
|
||||||
passwords, ok := r.URL.Query()["password"]
|
utils.Unmarshal(payload, raw, func() error {
|
||||||
if !ok || len(passwords[0]) < 1 {
|
return h.adminKick(id, session, payload)
|
||||||
return "", false, fmt.Errorf("no password provided")
|
}), "%s failed", header.Event)
|
||||||
}
|
case event.ADMIN_MUTE:
|
||||||
|
payload := &message.Admin{}
|
||||||
if passwords[0] == ws.conf.AdminPassword {
|
return errors.Wrapf(
|
||||||
return id, true, nil
|
utils.Unmarshal(payload, raw, func() error {
|
||||||
}
|
return h.adminMute(id, session, payload)
|
||||||
|
}), "%s failed", header.Event)
|
||||||
if passwords[0] == ws.conf.Password {
|
case event.ADMIN_UNMUTE:
|
||||||
return id, false, nil
|
payload := &message.Admin{}
|
||||||
}
|
return errors.Wrapf(
|
||||||
|
utils.Unmarshal(payload, raw, func() error {
|
||||||
return "", false, fmt.Errorf("invalid password: %s", passwords[0])
|
return h.adminUnmute(id, session, payload)
|
||||||
}
|
}), "%s failed", header.Event)
|
||||||
|
default:
|
||||||
func (ws *WebSocketHandler) handle(socket *websocket.Conn, id string) {
|
return errors.Errorf("unknown message event %s", header.Event)
|
||||||
bytes := make(chan []byte)
|
|
||||||
cancel := make(chan struct{})
|
|
||||||
ticker := time.NewTicker(pingPeriod)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer func() {
|
|
||||||
ticker.Stop()
|
|
||||||
ws.logger.Debug().Str("address", socket.RemoteAddr().String()).Msg("handle socket ending")
|
|
||||||
ws.handler.SocketDisconnected(id)
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
_, raw, err := socket.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
|
||||||
ws.logger.Warn().Err(err).Msg("read message error")
|
|
||||||
} else {
|
|
||||||
ws.logger.Debug().Err(err).Msg("read message error")
|
|
||||||
}
|
|
||||||
close(cancel)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
bytes <- raw
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case raw := <-bytes:
|
|
||||||
ws.logger.Debug().
|
|
||||||
Str("session", id).
|
|
||||||
Str("raw", string(raw)).
|
|
||||||
Msg("recieved message from client")
|
|
||||||
if err := ws.handler.Message(id, raw); err != nil {
|
|
||||||
ws.logger.Error().Err(err).Msg("message handler has failed")
|
|
||||||
}
|
|
||||||
case <-cancel:
|
|
||||||
return
|
|
||||||
case _ = <-ticker.C:
|
|
||||||
if err := socket.WriteMessage(websocket.PingMessage, nil); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,13 +1,34 @@
|
|||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"n.eko.moe/neko/internal/message"
|
"n.eko.moe/neko/internal/types"
|
||||||
"n.eko.moe/neko/internal/session"
|
"n.eko.moe/neko/internal/types/event"
|
||||||
|
"n.eko.moe/neko/internal/types/message"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *MessageHandler) identityDetails(id string, session *session.Session, payload *message.IdentityDetails) error {
|
func (h *MessageHandler) identityDetails(id string, session types.Session, payload *message.IdentityDetails) error {
|
||||||
if _, err := h.sessions.SetName(id, payload.Username); err != nil {
|
if err := session.SetName(payload.Username); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *MessageHandler) createPeer(id string, session types.Session, payload *message.Signal) error {
|
||||||
|
sdp, peer, err := h.webrtc.CreatePeer(id, payload.SDP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := session.SetPeer(peer); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := session.Send(message.Signal{
|
||||||
|
Event: event.SIGNAL_ANSWER,
|
||||||
|
SDP: sdp,
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -1,149 +0,0 @@
|
|||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
|
|
||||||
"n.eko.moe/neko/internal/event"
|
|
||||||
"n.eko.moe/neko/internal/message"
|
|
||||||
"n.eko.moe/neko/internal/session"
|
|
||||||
"n.eko.moe/neko/internal/utils"
|
|
||||||
"n.eko.moe/neko/internal/webrtc"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MessageHandler struct {
|
|
||||||
logger zerolog.Logger
|
|
||||||
sessions *session.SessionManager
|
|
||||||
webrtc *webrtc.WebRTCManager
|
|
||||||
banned map[string]bool
|
|
||||||
locked bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *MessageHandler) SocketConnected(id string, socket *websocket.Conn) (bool, string, error) {
|
|
||||||
remote := socket.RemoteAddr().String()
|
|
||||||
if remote != "" {
|
|
||||||
address := strings.SplitN(remote, ":", -1)
|
|
||||||
if len(address[0]) < 1 {
|
|
||||||
h.logger.Debug().Str("address", remote).Msg("no remote address, baling")
|
|
||||||
} else {
|
|
||||||
|
|
||||||
ok, banned := h.banned[address[0]]
|
|
||||||
if ok && banned {
|
|
||||||
h.logger.Debug().Str("address", remote).Msg("banned")
|
|
||||||
return false, "This IP has been banned", nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.locked {
|
|
||||||
h.logger.Debug().Str("address", remote).Msg("locked")
|
|
||||||
return false, "Server is currently locked", nil
|
|
||||||
}
|
|
||||||
return true, "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *MessageHandler) SocketDisconnected(id string) error {
|
|
||||||
return h.sessions.Destroy(id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *MessageHandler) Message(id string, raw []byte) error {
|
|
||||||
header := message.Message{}
|
|
||||||
if err := json.Unmarshal(raw, &header); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
session, ok := h.sessions.Get(id)
|
|
||||||
if !ok {
|
|
||||||
errors.Errorf("unknown session id %s", id)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch header.Event {
|
|
||||||
// Signal Events
|
|
||||||
case event.SIGNAL_PROVIDE:
|
|
||||||
payload := message.Signal{}
|
|
||||||
return errors.Wrapf(
|
|
||||||
utils.Unmarshal(&payload, raw, func() error {
|
|
||||||
return h.webrtc.CreatePeer(id, payload.SDP)
|
|
||||||
}), "%s failed", header.Event)
|
|
||||||
|
|
||||||
// Identity Events
|
|
||||||
case event.IDENTITY_DETAILS:
|
|
||||||
payload := &message.IdentityDetails{}
|
|
||||||
return errors.Wrapf(
|
|
||||||
utils.Unmarshal(payload, raw, func() error {
|
|
||||||
return h.identityDetails(id, session, payload)
|
|
||||||
}), "%s failed", header.Event)
|
|
||||||
|
|
||||||
// Control Events
|
|
||||||
case event.CONTROL_RELEASE:
|
|
||||||
return errors.Wrapf(h.controlRelease(id, session), "%s failed", header.Event)
|
|
||||||
case event.CONTROL_REQUEST:
|
|
||||||
return errors.Wrapf(h.controlRequest(id, session), "%s failed", header.Event)
|
|
||||||
case event.CONTROL_GIVE:
|
|
||||||
payload := &message.Control{}
|
|
||||||
return errors.Wrapf(
|
|
||||||
utils.Unmarshal(payload, raw, func() error {
|
|
||||||
return h.controlGive(id, session, payload)
|
|
||||||
}), "%s failed", header.Event)
|
|
||||||
|
|
||||||
// Chat Events
|
|
||||||
case event.CHAT_MESSAGE:
|
|
||||||
payload := &message.ChatRecieve{}
|
|
||||||
return errors.Wrapf(
|
|
||||||
utils.Unmarshal(payload, raw, func() error {
|
|
||||||
return h.chat(id, session, payload)
|
|
||||||
}), "%s failed", header.Event)
|
|
||||||
case event.CHAT_EMOTE:
|
|
||||||
payload := &message.EmoteRecieve{}
|
|
||||||
return errors.Wrapf(
|
|
||||||
utils.Unmarshal(payload, raw, func() error {
|
|
||||||
return h.chatEmote(id, session, payload)
|
|
||||||
}), "%s failed", header.Event)
|
|
||||||
|
|
||||||
// Admin Events
|
|
||||||
case event.ADMIN_LOCK:
|
|
||||||
return errors.Wrapf(h.adminLock(id, session), "%s failed", header.Event)
|
|
||||||
case event.ADMIN_UNLOCK:
|
|
||||||
return errors.Wrapf(h.adminUnlock(id, session), "%s failed", header.Event)
|
|
||||||
case event.ADMIN_CONTROL:
|
|
||||||
return errors.Wrapf(h.adminControl(id, session), "%s failed", header.Event)
|
|
||||||
case event.ADMIN_RELEASE:
|
|
||||||
return errors.Wrapf(h.adminRelease(id, session), "%s failed", header.Event)
|
|
||||||
case event.ADMIN_GIVE:
|
|
||||||
payload := &message.Admin{}
|
|
||||||
return errors.Wrapf(
|
|
||||||
utils.Unmarshal(payload, raw, func() error {
|
|
||||||
return h.adminGive(id, session, payload)
|
|
||||||
}), "%s failed", header.Event)
|
|
||||||
case event.ADMIN_BAN:
|
|
||||||
payload := &message.Admin{}
|
|
||||||
return errors.Wrapf(
|
|
||||||
utils.Unmarshal(payload, raw, func() error {
|
|
||||||
return h.adminBan(id, session, payload)
|
|
||||||
}), "%s failed", header.Event)
|
|
||||||
case event.ADMIN_KICK:
|
|
||||||
payload := &message.Admin{}
|
|
||||||
return errors.Wrapf(
|
|
||||||
utils.Unmarshal(payload, raw, func() error {
|
|
||||||
return h.adminKick(id, session, payload)
|
|
||||||
}), "%s failed", header.Event)
|
|
||||||
case event.ADMIN_MUTE:
|
|
||||||
payload := &message.Admin{}
|
|
||||||
return errors.Wrapf(
|
|
||||||
utils.Unmarshal(payload, raw, func() error {
|
|
||||||
return h.adminMute(id, session, payload)
|
|
||||||
}), "%s failed", header.Event)
|
|
||||||
case event.ADMIN_UNMUTE:
|
|
||||||
payload := &message.Admin{}
|
|
||||||
return errors.Wrapf(
|
|
||||||
utils.Unmarshal(payload, raw, func() error {
|
|
||||||
return h.adminUnmute(id, session, payload)
|
|
||||||
}), "%s failed", header.Event)
|
|
||||||
default:
|
|
||||||
return errors.Errorf("unknown message event %s", header.Event)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,12 +1,12 @@
|
|||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"n.eko.moe/neko/internal/event"
|
"n.eko.moe/neko/internal/types"
|
||||||
"n.eko.moe/neko/internal/message"
|
"n.eko.moe/neko/internal/types/event"
|
||||||
"n.eko.moe/neko/internal/session"
|
"n.eko.moe/neko/internal/types/message"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *MessageHandler) SessionCreated(id string, session *session.Session) error {
|
func (h *MessageHandler) SessionCreated(id string, session types.Session) error {
|
||||||
if err := session.Send(message.Identity{
|
if err := session.Send(message.Identity{
|
||||||
Event: event.IDENTITY_PROVIDE,
|
Event: event.IDENTITY_PROVIDE,
|
||||||
ID: id,
|
ID: id,
|
||||||
@ -17,11 +17,11 @@ func (h *MessageHandler) SessionCreated(id string, session *session.Session) err
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MessageHandler) SessionConnected(id string, session *session.Session) error {
|
func (h *MessageHandler) SessionConnected(id string, session types.Session) error {
|
||||||
// send list of members to session
|
// send list of members to session
|
||||||
if err := session.Send(message.MembersList{
|
if err := session.Send(message.MembersList{
|
||||||
Event: event.MEMBER_LIST,
|
Event: event.MEMBER_LIST,
|
||||||
Memebers: h.sessions.GetConnected(),
|
Memebers: h.sessions.Members(),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.MEMBER_LIST)
|
h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.MEMBER_LIST)
|
||||||
return err
|
return err
|
||||||
@ -32,7 +32,7 @@ func (h *MessageHandler) SessionConnected(id string, session *session.Session) e
|
|||||||
if ok {
|
if ok {
|
||||||
if err := session.Send(message.Control{
|
if err := session.Send(message.Control{
|
||||||
Event: event.CONTROL_LOCKED,
|
Event: event.CONTROL_LOCKED,
|
||||||
ID: host.ID,
|
ID: host.ID(),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.CONTROL_LOCKED)
|
h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.CONTROL_LOCKED)
|
||||||
return err
|
return err
|
||||||
@ -43,7 +43,7 @@ func (h *MessageHandler) SessionConnected(id string, session *session.Session) e
|
|||||||
if err := h.sessions.Brodcast(
|
if err := h.sessions.Brodcast(
|
||||||
message.Member{
|
message.Member{
|
||||||
Event: event.MEMBER_CONNECTED,
|
Event: event.MEMBER_CONNECTED,
|
||||||
Session: session,
|
Member: session.Member(),
|
||||||
}, nil); err != nil {
|
}, nil); err != nil {
|
||||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.CONTROL_RELEASE)
|
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.CONTROL_RELEASE)
|
||||||
return err
|
return err
|
||||||
|
37
server/internal/websocket/socket.go
Normal file
37
server/internal/websocket/socket.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WebSocket struct {
|
||||||
|
id string
|
||||||
|
connection *websocket.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (socket *WebSocket) Address() *string {
|
||||||
|
remote := socket.connection.RemoteAddr()
|
||||||
|
address := strings.SplitN(remote.String(), ":", -1)
|
||||||
|
if len(address[0]) < 1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &address[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (socket *WebSocket) Send(v interface{}) error {
|
||||||
|
if socket.connection == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return socket.connection.WriteJSON(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (socket *WebSocket) Destroy() error {
|
||||||
|
if socket.connection == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return socket.connection.Close()
|
||||||
|
}
|
239
server/internal/websocket/websocket.go
Normal file
239
server/internal/websocket/websocket.go
Normal file
@ -0,0 +1,239 @@
|
|||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
|
"n.eko.moe/neko/internal/config"
|
||||||
|
"n.eko.moe/neko/internal/types"
|
||||||
|
"n.eko.moe/neko/internal/types/event"
|
||||||
|
"n.eko.moe/neko/internal/types/message"
|
||||||
|
"n.eko.moe/neko/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func New(sessions types.SessionManager, webrtc types.WebRTCManager, conf *config.WebSocket) *WebSocketHandler {
|
||||||
|
logger := log.With().Str("module", "websocket").Logger()
|
||||||
|
|
||||||
|
return &WebSocketHandler{
|
||||||
|
logger: logger,
|
||||||
|
conf: conf,
|
||||||
|
sessions: sessions,
|
||||||
|
upgrader: websocket.Upgrader{
|
||||||
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
},
|
||||||
|
handler: &MessageHandler{
|
||||||
|
logger: logger.With().Str("subsystem", "handler").Logger(),
|
||||||
|
sessions: sessions,
|
||||||
|
webrtc: webrtc,
|
||||||
|
banned: make(map[string]bool),
|
||||||
|
locked: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send pings to peer with this period. Must be less than pongWait.
|
||||||
|
const pingPeriod = 60 * time.Second
|
||||||
|
|
||||||
|
type WebSocketHandler struct {
|
||||||
|
logger zerolog.Logger
|
||||||
|
upgrader websocket.Upgrader
|
||||||
|
sessions types.SessionManager
|
||||||
|
conf *config.WebSocket
|
||||||
|
handler *MessageHandler
|
||||||
|
shutdown chan bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebSocketHandler) Start() error {
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
ws.logger.Info().Msg("shutdown")
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ws.shutdown:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ws.sessions.OnCreated(func(id string, session types.Session) {
|
||||||
|
if err := ws.handler.SessionCreated(id, session); err != nil {
|
||||||
|
ws.logger.Warn().Str("id", id).Err(err).Msg("session created with and error")
|
||||||
|
} else {
|
||||||
|
ws.logger.Debug().Str("id", id).Msg("session created")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
ws.sessions.OnConnected(func(id string, session types.Session) {
|
||||||
|
if err := ws.handler.SessionConnected(id, session); err != nil {
|
||||||
|
ws.logger.Warn().Str("id", id).Err(err).Msg("session connected with and error")
|
||||||
|
} else {
|
||||||
|
ws.logger.Debug().Str("id", id).Msg("session connected")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
ws.sessions.OnDestroy(func(id string) {
|
||||||
|
if err := ws.handler.SessionDestroyed(id); err != nil {
|
||||||
|
ws.logger.Warn().Str("id", id).Err(err).Msg("session destroyed with and error")
|
||||||
|
} else {
|
||||||
|
ws.logger.Debug().Str("id", id).Msg("session destroyed")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebSocketHandler) Shutdown() error {
|
||||||
|
ws.shutdown <- true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
ws.logger.Debug().Msg("attempting to upgrade connection")
|
||||||
|
|
||||||
|
connection, err := ws.upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
ws.logger.Error().Err(err).Msg("failed to upgrade connection")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
id, admin, err := ws.authenticate(r)
|
||||||
|
if err != nil {
|
||||||
|
ws.logger.Warn().Err(err).Msg("authenticatetion failed")
|
||||||
|
|
||||||
|
if err = connection.WriteJSON(message.Disconnect{
|
||||||
|
Event: event.SYSTEM_DISCONNECT,
|
||||||
|
Message: "invalid password",
|
||||||
|
}); err != nil {
|
||||||
|
ws.logger.Error().Err(err).Msg("failed to send disconnect")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = connection.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
socket := &WebSocket{
|
||||||
|
id: id,
|
||||||
|
connection: connection,
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, reason, err := ws.handler.Connected(id, socket)
|
||||||
|
if err != nil {
|
||||||
|
ws.logger.Error().Err(err).Msg("connection failed")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
if err = connection.WriteJSON(message.Disconnect{
|
||||||
|
Event: event.SYSTEM_DISCONNECT,
|
||||||
|
Message: reason,
|
||||||
|
}); err != nil {
|
||||||
|
ws.logger.Error().Err(err).Msg("failed to send disconnect")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = connection.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ws.sessions.New(id, admin, socket)
|
||||||
|
|
||||||
|
ws.logger.
|
||||||
|
Debug().
|
||||||
|
Str("session", id).
|
||||||
|
Str("address", connection.RemoteAddr().String()).
|
||||||
|
Msg("new connection created")
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
ws.logger.
|
||||||
|
Debug().
|
||||||
|
Str("session", id).
|
||||||
|
Str("address", connection.RemoteAddr().String()).
|
||||||
|
Msg("session ended")
|
||||||
|
}()
|
||||||
|
|
||||||
|
ws.handle(connection, id)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebSocketHandler) authenticate(r *http.Request) (string, bool, error) {
|
||||||
|
id, err := utils.NewUID(32)
|
||||||
|
if err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
passwords, ok := r.URL.Query()["password"]
|
||||||
|
if !ok || len(passwords[0]) < 1 {
|
||||||
|
return "", false, fmt.Errorf("no password provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
if passwords[0] == ws.conf.AdminPassword {
|
||||||
|
return id, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if passwords[0] == ws.conf.Password {
|
||||||
|
return id, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", false, fmt.Errorf("invalid password: %s", passwords[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebSocketHandler) handle(connection *websocket.Conn, id string) {
|
||||||
|
bytes := make(chan []byte)
|
||||||
|
cancel := make(chan struct{})
|
||||||
|
ticker := time.NewTicker(pingPeriod)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
ticker.Stop()
|
||||||
|
ws.logger.Debug().Str("address", connection.RemoteAddr().String()).Msg("handle socket ending")
|
||||||
|
ws.handler.Disconnected(id)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
_, raw, err := connection.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||||
|
ws.logger.Warn().Err(err).Msg("read message error")
|
||||||
|
} else {
|
||||||
|
ws.logger.Debug().Err(err).Msg("read message error")
|
||||||
|
}
|
||||||
|
close(cancel)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
bytes <- raw
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case raw := <-bytes:
|
||||||
|
ws.logger.Debug().
|
||||||
|
Str("session", id).
|
||||||
|
Str("raw", string(raw)).
|
||||||
|
Msg("recieved message from client")
|
||||||
|
if err := ws.handler.Message(id, raw); err != nil {
|
||||||
|
ws.logger.Error().Err(err).Msg("message handler has failed")
|
||||||
|
}
|
||||||
|
case <-cancel:
|
||||||
|
return
|
||||||
|
case _ = <-ticker.C:
|
||||||
|
if err := connection.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user