diff --git a/client/src/neko/base.ts b/client/src/neko/base.ts index d9cda1c..26fc4f2 100644 --- a/client/src/neko/base.ts +++ b/client/src/neko/base.ts @@ -62,8 +62,8 @@ export abstract class BaseClient extends EventEmitter { this._ws = new WebSocket(`${url}ws?password=${password}`) this.emit('debug', `connecting to ${this._ws.url}`) this._ws.onmessage = this.onMessage.bind(this) - this._ws.onerror = (event) => this.onError.bind(this) - this._ws.onclose = (event) => this.onDisconnected.bind(this, new Error('websocket closed')) + this._ws.onerror = event => this.onError.bind(this) + this._ws.onclose = event => this.onDisconnected.bind(this, new Error('websocket closed')) this._timeout = setTimeout(this.onTimeout.bind(this), 15000) } catch (err) { this.onDisconnected(err) @@ -156,7 +156,7 @@ export abstract class BaseClient extends EventEmitter { this._ws!.send(JSON.stringify({ event, ...payload })) } - public createPeer(sdp: string) { + public createPeer(sdp: string, lite: boolean, servers: string[]) { this.emit('debug', `creating peer`) if (!this.socketOpen) { this.emit( @@ -173,16 +173,21 @@ export abstract class BaseClient extends EventEmitter { } this._peer = new RTCPeerConnection() + if (lite !== true) { + this._peer = new RTCPeerConnection({ + iceServers: [{ urls: servers }], + }) + } - this._peer.onconnectionstatechange = (event) => { + this._peer.onconnectionstatechange = event => { this.emit('debug', `peer connection state changed`, this._peer ? this._peer.connectionState : undefined) } - this._peer.onsignalingstatechange = (event) => { + this._peer.onsignalingstatechange = event => { this.emit('debug', `peer signaling state changed`, this._peer ? this._peer.signalingState : undefined) } - this._peer.oniceconnectionstatechange = (event) => { + this._peer.oniceconnectionstatechange = event => { this._state = this._peer!.iceConnectionState this.emit('debug', `peer ice connection state changed: ${this._peer!.iceConnectionState}`) @@ -217,7 +222,7 @@ export abstract class BaseClient extends EventEmitter { this._peer.setRemoteDescription({ type: 'offer', sdp }) this._peer .createAnswer() - .then((d) => { + .then(d => { this._peer!.setLocalDescription(d) this._ws!.send( JSON.stringify({ @@ -227,7 +232,7 @@ export abstract class BaseClient extends EventEmitter { }), ) }) - .catch((err) => this.emit('error', err)) + .catch(err => this.emit('error', err)) } private onMessage(e: MessageEvent) { @@ -236,9 +241,9 @@ export abstract class BaseClient extends EventEmitter { this.emit('debug', `received websocket event ${event} ${payload ? `with payload: ` : ''}`, payload) if (event === EVENT.SIGNAL.PROVIDE) { - const { sdp, id } = payload as SignalProvidePayload + const { sdp, lite, ice, id } = payload as SignalProvidePayload this._id = id - this.createPeer(sdp) + this.createPeer(sdp, lite, ice) return } diff --git a/client/src/neko/messages.ts b/client/src/neko/messages.ts index 34435c3..b1eec75 100644 --- a/client/src/neko/messages.ts +++ b/client/src/neko/messages.ts @@ -61,6 +61,8 @@ export interface SignalProvideMessage extends WebSocketMessage, SignalProvidePay } export interface SignalProvidePayload { id: string + lite: boolean + ice: string[] sdp: string } diff --git a/server/internal/types/config/webrtc.go b/server/internal/types/config/webrtc.go index d779358..742a1aa 100644 --- a/server/internal/types/config/webrtc.go +++ b/server/internal/types/config/webrtc.go @@ -16,6 +16,8 @@ type WebRTC struct { AudioCodec string AudioParams string Display string + ICELite bool + ICEServers []string VideoCodec string VideoParams string EphemeralMin uint16 @@ -99,6 +101,16 @@ func (WebRTC) Init(cmd *cobra.Command) error { return err } + cmd.PersistentFlags().Bool("icelite", false, "") + if err := viper.BindPFlag("icelite", cmd.PersistentFlags().Lookup("icelite")); err != nil { + return err + } + + cmd.PersistentFlags().StringSlice("iceserver", []string{"stun:stun.l.google.com:19302"}, "") + if err := viper.BindPFlag("iceserver", cmd.PersistentFlags().Lookup("iceserver")); err != nil { + return err + } + return nil } @@ -123,6 +135,9 @@ func (s *WebRTC) Set() { audioCodec = webrtc.PCMA } + s.ICELite = viper.GetBool("icelite") + s.ICEServers = viper.GetStringSlice("iceserver") + s.Device = viper.GetString("device") s.AudioCodec = audioCodec s.AudioParams = viper.GetString("aparams") diff --git a/server/internal/types/message/messages.go b/server/internal/types/message/messages.go index 4f573e2..a2f099f 100644 --- a/server/internal/types/message/messages.go +++ b/server/internal/types/message/messages.go @@ -14,9 +14,11 @@ type Disconnect struct { } type SignalProvide struct { - Event string `json:"event"` - ID string `json:"id"` - SDP string `json:"sdp"` + Event string `json:"event"` + ID string `json:"id"` + SDP string `json:"sdp"` + Lite bool `json:"lite"` + ICE []string `json:"ice"` } type SignalAnswer struct { diff --git a/server/internal/types/webrtc.go b/server/internal/types/webrtc.go index 7b342b0..c997784 100644 --- a/server/internal/types/webrtc.go +++ b/server/internal/types/webrtc.go @@ -8,7 +8,7 @@ type Sample struct { type WebRTCManager interface { Start() Shutdown() error - CreatePeer(id string, session Session) (string, error) + CreatePeer(id string, session Session) (string, bool, []string, error) ChangeScreenSize(width int, height int, rate int) error } diff --git a/server/internal/webrtc/peer.go b/server/internal/webrtc/peer.go index 3b75e29..f82c4fa 100644 --- a/server/internal/webrtc/peer.go +++ b/server/internal/webrtc/peer.go @@ -7,10 +7,14 @@ import ( ) type Peer struct { - id string - manager *WebRTCManager - connection *webrtc.PeerConnection - mu sync.Mutex + id string + api *webrtc.API + engine *webrtc.MediaEngine + manager *WebRTCManager + settings *webrtc.SettingEngine + connection *webrtc.PeerConnection + configuration *webrtc.Configuration + mu sync.Mutex } func (peer *Peer) SignalAnswer(sdp string) error { diff --git a/server/internal/webrtc/tracks.go b/server/internal/webrtc/tracks.go index 7da358e..04ae9cb 100644 --- a/server/internal/webrtc/tracks.go +++ b/server/internal/webrtc/tracks.go @@ -8,7 +8,7 @@ import ( "n.eko.moe/neko/internal/gst" ) -func (m *WebRTCManager) createTrack(codecName string, pipelineDevice string, pipelineSrc string) (*gst.Pipeline, *webrtc.Track, error) { +func (m *WebRTCManager) createTrack(codecName string, pipelineDevice string, pipelineSrc string) (*gst.Pipeline, *webrtc.Track, *webrtc.RTPCodec, error) { pipeline, err := gst.CreatePipeline( codecName, pipelineDevice, @@ -16,7 +16,7 @@ func (m *WebRTCManager) createTrack(codecName string, pipelineDevice string, pip ) if err != nil { - return nil, nil, err + return nil, nil, nil, err } var codec *webrtc.RTPCodec @@ -36,14 +36,13 @@ func (m *WebRTCManager) createTrack(codecName string, pipelineDevice string, pip case webrtc.PCMA: codec = webrtc.NewRTPPCMACodec(webrtc.DefaultPayloadTypePCMA, 8000) default: - return nil, nil, fmt.Errorf("unknown codec %s", codecName) + return nil, nil, nil, fmt.Errorf("unknown codec %s", codecName) } - m.engine.RegisterCodec(codec) track, err := webrtc.NewTrack(codec.PayloadType, rand.Uint32(), "stream", "stream", codec) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - return pipeline, track, nil + return pipeline, track, codec, nil } diff --git a/server/internal/webrtc/webrtc.go b/server/internal/webrtc/webrtc.go index 671f1cf..03274e7 100644 --- a/server/internal/webrtc/webrtc.go +++ b/server/internal/webrtc/webrtc.go @@ -18,53 +18,27 @@ import ( ) func New(sessions types.SessionManager, config *config.WebRTC) *WebRTCManager { - logger := log.With().Str("module", "webrtc").Logger() - settings := webrtc.SettingEngine{ - LoggerFactory: loggerFactory{ - logger: logger, - }, - } - - settings.SetLite(true) - settings.SetEphemeralUDPPortRange(config.EphemeralMin, config.EphemeralMax) - settings.SetNAT1To1IPs(config.NAT1To1IPs, webrtc.ICECandidateTypeHost) - - // Create MediaEngine based off sdp - engine := webrtc.MediaEngine{} - engine.RegisterDefaultCodecs() - - // Create API with MediaEngine and SettingEngine - api := webrtc.NewAPI(webrtc.WithMediaEngine(engine), webrtc.WithSettingEngine(settings)) - return &WebRTCManager{ - logger: logger, - settings: settings, + logger: log.With().Str("module", "webrtc").Logger(), cleanup: time.NewTicker(1 * time.Second), shutdown: make(chan bool), sessions: sessions, - engine: engine, config: config, - api: api, - configuration: &webrtc.Configuration{ - SDPSemantics: webrtc.SDPSemanticsUnifiedPlanWithFallback, - }, } } type WebRTCManager struct { logger zerolog.Logger - settings webrtc.SettingEngine - engine webrtc.MediaEngine - api *webrtc.API videoTrack *webrtc.Track audioTrack *webrtc.Track videoPipeline *gst.Pipeline audioPipeline *gst.Pipeline + videoCodec *webrtc.RTPCodec + audioCodec *webrtc.RTPCodec sessions types.SessionManager cleanup *time.Ticker config *config.WebRTC shutdown chan bool - configuration *webrtc.Configuration } func (m *WebRTCManager) Start() { @@ -79,12 +53,12 @@ func (m *WebRTCManager) Start() { } var err error - m.videoPipeline, m.videoTrack, err = m.createTrack(m.config.VideoCodec, m.config.Display, m.config.VideoParams) + m.videoPipeline, m.videoTrack, m.videoCodec, err = m.createTrack(m.config.VideoCodec, m.config.Display, m.config.VideoParams) if err != nil { m.logger.Panic().Err(err).Msg("unable to start webrtc manager") } - m.audioPipeline, m.audioTrack, err = m.createTrack(m.config.AudioCodec, m.config.Device, m.config.AudioParams) + m.audioPipeline, m.audioTrack, m.audioCodec, err = m.createTrack(m.config.AudioCodec, m.config.Device, m.config.AudioParams) if err != nil { m.logger.Panic().Err(err).Msg("unable to start webrtc manager") } @@ -133,10 +107,12 @@ func (m *WebRTCManager) Start() { Str("video_codec", m.config.VideoCodec). Str("audio_device", m.config.Device). Str("audio_codec", m.config.AudioCodec). - Str("ephemeral_port_range", fmt.Sprintf("%d-%d", m.config.EphemeralMin, m.config.EphemeralMax)). - Str("nat_ips", strings.Join(m.config.NAT1To1IPs, ",")). Str("audio_pipeline_src", m.audioPipeline.Src). Str("video_pipeline_src", m.videoPipeline.Src). + Str("ice_lite", fmt.Sprintf("%t", m.config.ICELite)). + Str("ice_servers", strings.Join(m.config.ICEServers, ",")). + Str("ephemeral_port_range", fmt.Sprintf("%d-%d", m.config.EphemeralMin, m.config.EphemeralMax)). + Str("nat_ips", strings.Join(m.config.NAT1To1IPs, ",")). Msgf("webrtc streaming") } @@ -149,28 +125,62 @@ func (m *WebRTCManager) Shutdown() error { return nil } -func (m *WebRTCManager) CreatePeer(id string, session types.Session) (string, error) { +func (m *WebRTCManager) CreatePeer(id string, session types.Session) (string, bool, []string, error) { + configuration := &webrtc.Configuration{ + ICEServers: []webrtc.ICEServer{ + { + URLs: m.config.ICEServers, + }, + }, + SDPSemantics: webrtc.SDPSemanticsUnifiedPlanWithFallback, + } + + settings := webrtc.SettingEngine{ + LoggerFactory: loggerFactory{ + logger: m.logger, + }, + } + + if m.config.ICELite { + configuration = &webrtc.Configuration{ + SDPSemantics: webrtc.SDPSemanticsUnifiedPlanWithFallback, + } + settings.SetLite(true) + } + + settings.SetEphemeralUDPPortRange(m.config.EphemeralMin, m.config.EphemeralMax) + settings.SetNAT1To1IPs(m.config.NAT1To1IPs, webrtc.ICECandidateTypeHost) + + // Create MediaEngine based off sdp + engine := webrtc.MediaEngine{} + // engine.RegisterDefaultCodecs() + engine.RegisterCodec(m.audioCodec) + engine.RegisterCodec(m.videoCodec) + + // Create API with MediaEngine and SettingEngine + api := webrtc.NewAPI(webrtc.WithMediaEngine(engine), webrtc.WithSettingEngine(settings)) + // Create new peer connection - connection, err := m.api.NewPeerConnection(*m.configuration) + connection, err := api.NewPeerConnection(*configuration) if err != nil { - return "", err + return "", m.config.ICELite, m.config.ICEServers, err } if _, err = connection.AddTransceiverFromTrack(m.videoTrack, webrtc.RtpTransceiverInit{ Direction: webrtc.RTPTransceiverDirectionSendonly, }); err != nil { - return "", err + return "", m.config.ICELite, m.config.ICEServers, err } if _, err = connection.AddTransceiverFromTrack(m.audioTrack, webrtc.RtpTransceiverInit{ Direction: webrtc.RTPTransceiverDirectionSendonly, }); err != nil { - return "", err + return "", m.config.ICELite, m.config.ICEServers, err } description, err := connection.CreateOffer(nil) if err != nil { - return "", err + return "", m.config.ICELite, m.config.ICEServers, err } connection.OnDataChannel(func(d *webrtc.DataChannel) { @@ -200,14 +210,18 @@ func (m *WebRTCManager) CreatePeer(id string, session types.Session) (string, er }) if err := session.SetPeer(&Peer{ - id: id, - manager: m, - connection: connection, + id: id, + api: api, + engine: &engine, + manager: m, + settings: &settings, + connection: connection, + configuration: configuration, }); err != nil { - return "", err + return "", m.config.ICELite, m.config.ICEServers, err } - return description.SDP, nil + return description.SDP, m.config.ICELite, m.config.ICEServers, nil } func (m *WebRTCManager) ChangeScreenSize(width int, height int, rate int) error { diff --git a/server/internal/websocket/signal.go b/server/internal/websocket/signal.go index e3d49e2..e24db55 100644 --- a/server/internal/websocket/signal.go +++ b/server/internal/websocket/signal.go @@ -7,7 +7,7 @@ import ( ) func (h *MessageHandler) signalProvide(id string, session types.Session) error { - sdp, err := h.webrtc.CreatePeer(id, session) + sdp, lite, ice, err := h.webrtc.CreatePeer(id, session) if err != nil { return err } @@ -16,6 +16,8 @@ func (h *MessageHandler) signalProvide(id string, session types.Session) error { Event: event.SIGNAL_PROVIDE, ID: id, SDP: sdp, + Lite: lite, + ICE: ice, }); err != nil { return err }