diff --git a/internal/types/event/events.go b/internal/types/event/events.go index d6a74812..fd41b804 100644 --- a/internal/types/event/events.go +++ b/internal/types/event/events.go @@ -8,6 +8,7 @@ const ( const ( SIGNAL_REQUEST = "signal/request" + SIGNAL_RESTART = "signal/restart" SIGNAL_ANSWER = "signal/answer" SIGNAL_PROVIDE = "signal/provide" SIGNAL_CANDIDATE = "signal/candidate" diff --git a/internal/types/webrtc.go b/internal/types/webrtc.go index c31b11a8..f53aea6a 100644 --- a/internal/types/webrtc.go +++ b/internal/types/webrtc.go @@ -9,6 +9,7 @@ type ICEServer struct { } type WebRTCPeer interface { + CreateOffer(ICETrickle bool, ICERestart bool) (*webrtc.SessionDescription, error) SignalAnswer(sdp string) error SignalCandidate(candidate webrtc.ICECandidateInit) error diff --git a/internal/webrtc/manager.go b/internal/webrtc/manager.go index 7f589303..d6a65b62 100644 --- a/internal/webrtc/manager.go +++ b/internal/webrtc/manager.go @@ -20,9 +20,6 @@ import ( "demodesk/neko/internal/webrtc/cursor" ) -// how long is can take between sending offer and connecting -const offerTimeout = 4 * time.Second - // the duration without network activity before a Agent is considered disconnected. Default is 5 Seconds const disconnectedTimeout = 4 * time.Second @@ -237,26 +234,6 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin return nil, err } - offer, err := connection.CreateOffer(nil) - if err != nil { - return nil, err - } - - if !manager.config.ICETrickle { - // Create channel that is blocked until ICE Gathering is complete - gatherComplete := webrtc.GatheringCompletePromise(connection) - - if err := connection.SetLocalDescription(offer); err != nil { - return nil, err - } - - <-gatherComplete - } else { - if err := connection.SetLocalDescription(offer); err != nil { - return nil, err - } - } - peer := &WebRTCPeerCtx{ api: api, connection: connection, @@ -347,22 +324,6 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin } }) - // offer timeout - go func() { - time.Sleep(offerTimeout) - - // already disconnected - if connection.ConnectionState() == webrtc.PeerConnectionStateClosed { - return - } - - // not connected - if connection.ConnectionState() != webrtc.PeerConnectionStateConnected { - logger.Warn().Msg("connection timeouted, closing") - connection.Close() - } - }() - go func() { rtcpBuf := make([]byte, 1500) for { @@ -382,7 +343,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin }() session.SetWebRTCPeer(peer) - return connection.LocalDescription(), nil + return peer.CreateOffer(manager.config.ICETrickle, false) } func (manager *WebRTCManagerCtx) mediaEngine(videoID string) (*webrtc.MediaEngine, error) { diff --git a/internal/webrtc/peer.go b/internal/webrtc/peer.go index a9bdba34..e5bb933f 100644 --- a/internal/webrtc/peer.go +++ b/internal/webrtc/peer.go @@ -1,6 +1,14 @@ package webrtc -import "github.com/pion/webrtc/v3" +import ( + "time" + + "github.com/pion/webrtc/v3" + "github.com/rs/zerolog/log" +) + +// how long is can take between sending offer and connecting +const offerTimeout = 10 * time.Second type WebRTCPeerCtx struct { api *webrtc.API @@ -9,6 +17,48 @@ type WebRTCPeerCtx struct { changeVideo func(videoID string) error } +func (peer *WebRTCPeerCtx) CreateOffer(ICETrickle bool, ICERestart bool) (*webrtc.SessionDescription, error) { + // offer timeout + go func() { + time.Sleep(offerTimeout) + + // already disconnected + if peer.connection.ConnectionState() == webrtc.PeerConnectionStateClosed { + return + } + + // not connected + if peer.connection.ConnectionState() != webrtc.PeerConnectionStateConnected { + log.Warn().Msg("connection timeouted, closing") + peer.connection.Close() + } + }() + + offer, err := peer.connection.CreateOffer(&webrtc.OfferOptions{ + ICERestart: ICERestart, + }) + if err != nil { + return nil, err + } + + if !ICETrickle { + // Create channel that is blocked until ICE Gathering is complete + gatherComplete := webrtc.GatheringCompletePromise(peer.connection) + + if err := peer.connection.SetLocalDescription(offer); err != nil { + return nil, err + } + + <-gatherComplete + } else { + if err := peer.connection.SetLocalDescription(offer); err != nil { + return nil, err + } + } + + return peer.connection.LocalDescription(), nil +} + func (peer *WebRTCPeerCtx) SignalAnswer(sdp string) error { return peer.connection.SetRemoteDescription(webrtc.SessionDescription{ SDP: sdp, diff --git a/internal/websocket/handler/handler.go b/internal/websocket/handler/handler.go index 0c9a0982..5b1531d9 100644 --- a/internal/websocket/handler/handler.go +++ b/internal/websocket/handler/handler.go @@ -52,6 +52,8 @@ func (h *MessageHandlerCtx) Message(session types.Session, raw []byte) bool { err = utils.Unmarshal(payload, raw, func() error { return h.signalRequest(session, payload) }) + case event.SIGNAL_RESTART: + err = h.signalRestart(session) case event.SIGNAL_ANSWER: payload := &message.SignalAnswer{} err = utils.Unmarshal(payload, raw, func() error { diff --git a/internal/websocket/handler/signal.go b/internal/websocket/handler/signal.go index 57c67add..635db77d 100644 --- a/internal/websocket/handler/signal.go +++ b/internal/websocket/handler/signal.go @@ -32,6 +32,25 @@ func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *messag }) } +func (h *MessageHandlerCtx) signalRestart(session types.Session) error { + peer := session.GetWebRTCPeer() + if peer == nil { + h.logger.Debug().Str("session_id", session.ID()).Msg("webRTC peer does not exist") + return nil + } + + offer, err := peer.CreateOffer(true, true) + if err != nil { + return err + } + + return session.Send( + message.SignalAnswer{ + Event: event.SIGNAL_RESTART, + SDP: offer.SDP, + }) +} + func (h *MessageHandlerCtx) signalAnswer(session types.Session, payload *message.SignalAnswer) error { peer := session.GetWebRTCPeer() if peer == nil {