diff --git a/internal/config/webrtc.go b/internal/config/webrtc.go index d51383ca..8c8cf46e 100644 --- a/internal/config/webrtc.go +++ b/internal/config/webrtc.go @@ -10,12 +10,13 @@ import ( "github.com/spf13/viper" "demodesk/neko/internal/utils" + "demodesk/neko/internal/types" ) type WebRTC struct { ICELite bool ICETrickle bool - ICEServers []string + ICEServers []types.ICEServer EphemeralMin uint16 EphemeralMax uint16 @@ -26,6 +27,7 @@ type WebRTC struct { const ( defEprMin = 59000 defEprMax = 59100 + defStun = "stun:stun.l.google.com:19302" ) func (WebRTC) Init(cmd *cobra.Command) error { @@ -39,8 +41,8 @@ func (WebRTC) Init(cmd *cobra.Command) error { return err } - cmd.PersistentFlags().StringSlice("webrtc.iceserver", []string{"stun:stun.l.google.com:19302"}, "describes a single STUN and TURN server that can be used by the ICEAgent to establish a connection with a peer") - if err := viper.BindPFlag("webrtc.iceserver", cmd.PersistentFlags().Lookup("webrtc.iceserver")); err != nil { + cmd.PersistentFlags().String("webrtc.iceservers", "[]", "STUN and TURN servers in JSON format with `urls`, `username`, `password` keys") + if err := viper.BindPFlag("webrtc.iceservers", cmd.PersistentFlags().Lookup("webrtc.iceservers")); err != nil { return err } @@ -65,7 +67,18 @@ func (WebRTC) Init(cmd *cobra.Command) error { func (s *WebRTC) Set() { s.ICELite = viper.GetBool("webrtc.icelite") s.ICETrickle = viper.GetBool("webrtc.icetrickle") - s.ICEServers = viper.GetStringSlice("webrtc.iceserver") + + if err := viper.UnmarshalKey("webrtc.iceservers", &s.ICEServers, viper.DecodeHook( + utils.JsonStringAutoDecode(s.ICEServers), + )); err != nil { + log.Warn().Err(err).Msgf("unable to parse ICE servers") + } + + if len(s.ICEServers) == 0 { + s.ICEServers = append(s.ICEServers, types.ICEServer{ + URLs: []string{defStun}, + }) + } s.NAT1To1IPs = viper.GetStringSlice("webrtc.nat1to1") s.IpRetrievalUrl = viper.GetString("webrtc.ip_retrieval_url") diff --git a/internal/types/message/messages.go b/internal/types/message/messages.go index f6b6f112..21bb9aaa 100644 --- a/internal/types/message/messages.go +++ b/internal/types/message/messages.go @@ -40,12 +40,11 @@ type SystemDisconnect struct { ///////////////////////////// type SignalProvide struct { - Event string `json:"event,omitempty"` - SDP string `json:"sdp"` - Lite bool `json:"lite"` - ICE []string `json:"ice"` - Videos []string `json:"videos"` - Video string `json:"video"` + Event string `json:"event,omitempty"` + SDP string `json:"sdp"` + ICEServers []types.ICEServer `json:"iceservers"` + Videos []string `json:"videos"` + Video string `json:"video"` } type SignalCandidate struct { diff --git a/internal/types/webrtc.go b/internal/types/webrtc.go index 4560d681..39072103 100644 --- a/internal/types/webrtc.go +++ b/internal/types/webrtc.go @@ -2,6 +2,12 @@ package types import "github.com/pion/webrtc/v3" +type ICEServer struct { + URLs []string `mapstructure:"urls" json:"urls"` + Username string `mapstructure:"username" json:"username"` + Credential string `mapstructure:"credential" json:"credential"` +} + type WebRTCPeer interface { SignalAnswer(sdp string) error SignalCandidate(candidate webrtc.ICECandidateInit) error @@ -17,8 +23,7 @@ type WebRTCManager interface { Start() Shutdown() error - ICELite() bool - ICEServers() []string + ICEServers() []ICEServer CreatePeer(session Session, videoID string) (*webrtc.SessionDescription, error) } diff --git a/internal/utils/json.go b/internal/utils/json.go index 7ea0494d..752a08e1 100644 --- a/internal/utils/json.go +++ b/internal/utils/json.go @@ -1,6 +1,9 @@ package utils -import "encoding/json" +import ( + "encoding/json" + "reflect" +) func Unmarshal(in interface{}, raw []byte, callback func() error) error { if err := json.Unmarshal(raw, &in); err != nil { @@ -8,3 +11,19 @@ func Unmarshal(in interface{}, raw []byte, callback func() error) error { } return callback() } + +func JsonStringAutoDecode(m interface{}) func(rf reflect.Kind, rt reflect.Kind, data interface{}) (interface{}, error) { + return func(rf reflect.Kind, rt reflect.Kind, data interface{}) (interface{}, error) { + if rf != reflect.String || rt == reflect.String { + return data, nil + } + + raw := data.(string) + if raw != "" && (raw[0:1] == "{" || raw[0:1] == "[") { + err := json.Unmarshal([]byte(raw), &m) + return m, err + } + + return data, nil + } +} diff --git a/internal/webrtc/manager.go b/internal/webrtc/manager.go index 10a66706..cd50c671 100644 --- a/internal/webrtc/manager.go +++ b/internal/webrtc/manager.go @@ -77,12 +77,17 @@ func (manager *WebRTCManagerCtx) Start() { audio.RemoveListener(&audioListener) } + var servers []string + for _, server := range manager.config.ICEServers { + servers = append(servers, server.URLs...) + } + manager.logger.Info(). - Str("ice_lite", fmt.Sprintf("%t", manager.config.ICELite)). - Str("ice_trickle", fmt.Sprintf("%t", manager.config.ICETrickle)). - Str("ice_servers", strings.Join(manager.config.ICEServers, ",")). - Str("ephemeral_port_range", fmt.Sprintf("%d-%d", manager.config.EphemeralMin, manager.config.EphemeralMax)). - Str("nat_ips", strings.Join(manager.config.NAT1To1IPs, ",")). + Str("icelite", fmt.Sprintf("%t", manager.config.ICELite)). + Str("icetrickle", fmt.Sprintf("%t", manager.config.ICETrickle)). + Str("iceservers", strings.Join(servers, ",")). + Str("nat1to1", strings.Join(manager.config.NAT1To1IPs, ",")). + Str("epr", fmt.Sprintf("%d-%d", manager.config.EphemeralMin, manager.config.EphemeralMax)). Msgf("webrtc starting") manager.curImage.Start() @@ -98,11 +103,7 @@ func (manager *WebRTCManagerCtx) Shutdown() error { return nil } -func (manager *WebRTCManagerCtx) ICELite() bool { - return manager.config.ICELite -} - -func (manager *WebRTCManagerCtx) ICEServers() []string { +func (manager *WebRTCManagerCtx) ICEServers() []types.ICEServer { return manager.config.ICEServers } @@ -423,12 +424,24 @@ func (manager *WebRTCManagerCtx) apiConfiguration() *webrtc.Configuration { } } + ICEServers := []webrtc.ICEServer{} + for _, server := range manager.config.ICEServers { + var credential interface{} + if server.Credential != "" { + credential = server.Credential + } else { + credential = false + } + + ICEServers = append(ICEServers, webrtc.ICEServer{ + URLs: server.URLs, + Username: server.Username, + Credential: credential, + }) + } + return &webrtc.Configuration{ - ICEServers: []webrtc.ICEServer{ - { - URLs: manager.config.ICEServers, - }, - }, + ICEServers: ICEServers, SDPSemantics: webrtc.SDPSemanticsUnifiedPlanWithFallback, } } diff --git a/internal/websocket/handler/signal.go b/internal/websocket/handler/signal.go index 3c390ab1..ecfb1b59 100644 --- a/internal/websocket/handler/signal.go +++ b/internal/websocket/handler/signal.go @@ -22,12 +22,11 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session) error { return session.Send( message.SignalProvide{ - Event: event.SIGNAL_PROVIDE, - SDP: offer.SDP, - Lite: h.webrtc.ICELite(), - ICE: h.webrtc.ICEServers(), - Videos: videos, - Video: defaultVideo, + Event: event.SIGNAL_PROVIDE, + SDP: offer.SDP, + ICEServers: h.webrtc.ICEServers(), + Videos: videos, + Video: defaultVideo, }) }