From 124c5ae1172f2c64b724502e2147f58ecc43af42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= Date: Tue, 14 Feb 2023 21:19:02 +0100 Subject: [PATCH] Wait for keyframe on switching streams (#28) * stream sink add keyframe lobby. * change streamsink keyframe identifier. * add h264. * use gstreamers is delta unit for sample. * use delta unit. --- internal/capture/streamsink.go | 56 ++++++++++++++++++++++++++-------- internal/config/capture.go | 5 ++- internal/webrtc/track.go | 5 ++- pkg/gst/gst.c | 7 +++-- pkg/gst/gst.go | 11 ++++--- pkg/gst/gst.h | 4 +-- pkg/types/capture.go | 8 +++-- pkg/types/codec/codecs.go | 12 ++++++++ 8 files changed, 80 insertions(+), 28 deletions(-) diff --git a/internal/capture/streamsink.go b/internal/capture/streamsink.go index 3400840e..c084e869 100644 --- a/internal/capture/streamsink.go +++ b/internal/capture/streamsink.go @@ -21,6 +21,7 @@ var moveSinkListenerMu = sync.Mutex{} type StreamSinkManagerCtx struct { id string getBitrate func() (int, error) + waitForKf bool // wait for a keyframe before sending samples logger zerolog.Logger mu sync.Mutex @@ -32,6 +33,7 @@ type StreamSinkManagerCtx struct { pipelineFn func() (string, error) listeners map[uintptr]*func(sample types.Sample) + listenersKf map[uintptr]*func(sample types.Sample) // keyframe lobby listenersMu sync.Mutex // metrics @@ -40,7 +42,7 @@ type StreamSinkManagerCtx struct { pipelinesActive prometheus.Gauge } -func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), id string, getBitrate func() (int, error)) *StreamSinkManagerCtx { +func streamSinkNew(c codec.RTPCodec, pipelineFn func() (string, error), id string, getBitrate func() (int, error)) *StreamSinkManagerCtx { logger := log.With(). Str("module", "capture"). Str("submodule", "stream-sink"). @@ -49,11 +51,15 @@ func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), id s manager := &StreamSinkManagerCtx{ id: id, getBitrate: getBitrate, + // only wait for keyframes if the codec is video + waitForKf: c.IsVideo(), logger: logger, - codec: codec, + codec: c, pipelineFn: pipelineFn, - listeners: map[uintptr]*func(sample types.Sample){}, + + listeners: map[uintptr]*func(sample types.Sample){}, + listenersKf: map[uintptr]*func(sample types.Sample){}, // metrics currentListeners: promauto.NewGauge(prometheus.GaugeOpts{ @@ -63,8 +69,8 @@ func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), id s Help: "Current number of listeners for a pipeline.", ConstLabels: map[string]string{ "video_id": id, - "codec_name": codec.Name, - "codec_type": codec.Type.String(), + "codec_name": c.Name, + "codec_type": c.Type.String(), }, }), pipelinesCounter: promauto.NewCounter(prometheus.CounterOpts{ @@ -75,8 +81,8 @@ func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), id s ConstLabels: map[string]string{ "submodule": "streamsink", "video_id": id, - "codec_name": codec.Name, - "codec_type": codec.Type.String(), + "codec_name": c.Name, + "codec_type": c.Type.String(), }, }), pipelinesActive: promauto.NewGauge(prometheus.GaugeOpts{ @@ -87,8 +93,8 @@ func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), id s ConstLabels: map[string]string{ "submodule": "streamsink", "video_id": id, - "codec_name": codec.Name, - "codec_type": codec.Type.String(), + "codec_name": c.Name, + "codec_type": c.Type.String(), }, }), } @@ -103,6 +109,9 @@ func (manager *StreamSinkManagerCtx) shutdown() { for key := range manager.listeners { delete(manager.listeners, key) } + for key := range manager.listenersKf { + delete(manager.listenersKf, key) + } manager.listenersMu.Unlock() manager.DestroyPipeline() @@ -133,7 +142,7 @@ func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec { } func (manager *StreamSinkManagerCtx) start() error { - if len(manager.listeners) == 0 { + if len(manager.listeners)+len(manager.listenersKf) == 0 { err := manager.CreatePipeline() if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) { return err @@ -146,7 +155,7 @@ func (manager *StreamSinkManagerCtx) start() error { } func (manager *StreamSinkManagerCtx) stop() { - if len(manager.listeners) == 0 { + if len(manager.listeners)+len(manager.listenersKf) == 0 { manager.DestroyPipeline() manager.logger.Info().Msgf("last listener, stopping") } @@ -156,11 +165,22 @@ func (manager *StreamSinkManagerCtx) addListener(listener *func(sample types.Sam ptr := reflect.ValueOf(listener).Pointer() manager.listenersMu.Lock() - manager.listeners[ptr] = listener + if manager.waitForKf { + // if we're waiting for a keyframe, add it to the keyframe lobby + manager.listenersKf[ptr] = listener + } else { + // otherwise, add it as a regular listener + manager.listeners[ptr] = listener + } manager.listenersMu.Unlock() manager.logger.Debug().Interface("ptr", ptr).Msgf("adding listener") manager.currentListeners.Set(float64(manager.ListenersCount())) + + // if we will be waiting for a keyframe, emit one now + if manager.pipeline != nil && manager.waitForKf { + manager.pipeline.EmitVideoKeyframe() + } } func (manager *StreamSinkManagerCtx) removeListener(listener *func(sample types.Sample)) { @@ -168,6 +188,7 @@ func (manager *StreamSinkManagerCtx) removeListener(listener *func(sample types. manager.listenersMu.Lock() delete(manager.listeners, ptr) + delete(manager.listenersKf, ptr) // if it's a keyframe listener, remove it too manager.listenersMu.Unlock() manager.logger.Debug().Interface("ptr", ptr).Msgf("removing listener") @@ -259,7 +280,7 @@ func (manager *StreamSinkManagerCtx) ListenersCount() int { manager.listenersMu.Lock() defer manager.listenersMu.Unlock() - return len(manager.listeners) + return len(manager.listeners) + len(manager.listenersKf) } func (manager *StreamSinkManagerCtx) Started() bool { @@ -307,6 +328,15 @@ func (manager *StreamSinkManagerCtx) CreatePipeline() error { } manager.listenersMu.Lock() + // if is not delta unit -> it can be decoded independently -> it is a keyframe + if manager.waitForKf && !sample.DeltaUnit && len(manager.listenersKf) > 0 { + // if current sample is a keyframe, move listeners from + // keyframe lobby to actual listeners map and clear lobby + for k, v := range manager.listenersKf { + manager.listeners[k] = v + } + manager.listenersKf = make(map[uintptr]*func(sample types.Sample)) + } for _, emit := range manager.listeners { (*emit)(sample) } diff --git a/internal/config/capture.go b/internal/config/capture.go index 6b4b4fee..6e23b4a8 100644 --- a/internal/config/capture.go +++ b/internal/config/capture.go @@ -3,7 +3,6 @@ package config import ( "os" - "github.com/pion/webrtc/v3" "github.com/rs/zerolog/log" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -171,7 +170,7 @@ func (s *Capture) Set() { // video videoCodec := viper.GetString("capture.video.codec") s.VideoCodec, ok = codec.ParseStr(videoCodec) - if !ok || s.VideoCodec.Type != webrtc.RTPCodecTypeVideo { + if !ok || !s.VideoCodec.IsVideo() { log.Warn().Str("codec", videoCodec).Msgf("unknown video codec, using Vp8") s.VideoCodec = codec.VP8() } @@ -217,7 +216,7 @@ func (s *Capture) Set() { audioCodec := viper.GetString("capture.audio.codec") s.AudioCodec, ok = codec.ParseStr(audioCodec) - if !ok || s.AudioCodec.Type != webrtc.RTPCodecTypeAudio { + if !ok || !s.AudioCodec.IsAudio() { log.Warn().Str("codec", audioCodec).Msgf("unknown audio codec, using Opus") s.AudioCodec = codec.Opus() } diff --git a/internal/webrtc/track.go b/internal/webrtc/track.go index 6528d459..d3fb4082 100644 --- a/internal/webrtc/track.go +++ b/internal/webrtc/track.go @@ -64,7 +64,10 @@ func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.Pe return } - err := track.WriteSample(media.Sample(sample)) + err := track.WriteSample(media.Sample{ + Data: sample.Data, + Duration: sample.Duration, + }) if err != nil && !errors.Is(err, io.ErrClosedPipe) { logger.Warn().Err(err).Msg("failed to write sample to track") } diff --git a/pkg/gst/gst.c b/pkg/gst/gst.c index 86e97dce..020eb92f 100644 --- a/pkg/gst/gst.c +++ b/pkg/gst/gst.c @@ -6,7 +6,7 @@ static void gstreamer_pipeline_log(GstPipelineCtx *ctx, char* level, const char* char buffer[100]; vsprintf(buffer, format, argptr); va_end(argptr); - goPipelineLog(level, buffer, ctx->pipelineId); + goPipelineLog(ctx->pipelineId, level, buffer); } static gboolean gstreamer_bus_call(GstBus *bus, GstMessage *msg, gpointer user_data) { @@ -95,7 +95,10 @@ static GstFlowReturn gstreamer_send_new_sample_handler(GstElement *object, gpoin buffer = gst_sample_get_buffer(sample); if (buffer) { gst_buffer_extract_dup(buffer, 0, gst_buffer_get_size(buffer), ©, ©_size); - goHandlePipelineBuffer(copy, copy_size, GST_BUFFER_DURATION(buffer), ctx->pipelineId); + goHandlePipelineBuffer(ctx->pipelineId, copy, copy_size, + GST_BUFFER_DURATION(buffer), + GST_BUFFER_FLAG_IS_SET(buffer, GST_BUFFER_FLAG_DELTA_UNIT) + ); } gst_sample_unref(sample); } diff --git a/pkg/gst/gst.go b/pkg/gst/gst.go index 14bdc6dd..9d4fb874 100644 --- a/pkg/gst/gst.go +++ b/pkg/gst/gst.go @@ -200,8 +200,8 @@ func CheckPlugins(plugins []string) error { } //export goHandlePipelineBuffer -func goHandlePipelineBuffer(buffer unsafe.Pointer, bufferLen C.int, duration C.int, pipelineID C.int) { - defer C.free(buffer) +func goHandlePipelineBuffer(pipelineID C.int, buf unsafe.Pointer, bufLen C.int, duration C.guint64, deltaUnit C.gboolean) { + defer C.free(buf) pipelinesLock.Lock() pipeline, ok := pipelines[int(pipelineID)] @@ -209,8 +209,9 @@ func goHandlePipelineBuffer(buffer unsafe.Pointer, bufferLen C.int, duration C.i if ok { pipeline.sample <- types.Sample{ - Data: C.GoBytes(buffer, bufferLen), - Duration: time.Duration(duration), + Data: C.GoBytes(buf, bufLen), + Duration: time.Duration(duration), + DeltaUnit: deltaUnit == C.TRUE, } } else { log.Warn(). @@ -222,7 +223,7 @@ func goHandlePipelineBuffer(buffer unsafe.Pointer, bufferLen C.int, duration C.i } //export goPipelineLog -func goPipelineLog(levelUnsafe *C.char, msgUnsafe *C.char, pipelineID C.int) { +func goPipelineLog(pipelineID C.int, levelUnsafe *C.char, msgUnsafe *C.char) { levelStr := C.GoString(levelUnsafe) msg := C.GoString(msgUnsafe) diff --git a/pkg/gst/gst.h b/pkg/gst/gst.h index 619170e7..1d6a1527 100644 --- a/pkg/gst/gst.h +++ b/pkg/gst/gst.h @@ -12,8 +12,8 @@ typedef struct GstPipelineCtx { GstElement *appsrc; } GstPipelineCtx; -extern void goHandlePipelineBuffer(void *buffer, int bufferLen, int samples, int pipelineId); -extern void goPipelineLog(char *level, char *msg, int pipelineId); +extern void goHandlePipelineBuffer(int pipelineId, void *buffer, int bufferLen, guint64 duration, gboolean deltaUnit); +extern void goPipelineLog(int pipelineId, char *level, char *msg); GstPipelineCtx *gstreamer_pipeline_create(char *pipelineStr, int pipelineId, GError **error); void gstreamer_pipeline_attach_appsink(GstPipelineCtx *ctx, char *sinkName); diff --git a/pkg/types/capture.go b/pkg/types/capture.go index 6ab9a8e7..9ef238ed 100644 --- a/pkg/types/capture.go +++ b/pkg/types/capture.go @@ -6,17 +6,21 @@ import ( "fmt" "math" "strings" + "time" "github.com/PaesslerAG/gval" "github.com/demodesk/neko/pkg/types/codec" - "github.com/pion/webrtc/v3/pkg/media" ) var ( ErrCapturePipelineAlreadyExists = errors.New("capture pipeline already exists") ) -type Sample media.Sample +type Sample struct { + Data []byte + Duration time.Duration + DeltaUnit bool // this unit cannot be decoded independently. +} type Receiver interface { SetStream(stream StreamSinkManager) (changed bool, err error) diff --git a/pkg/types/codec/codecs.go b/pkg/types/codec/codecs.go index 0bb6adf0..c71057ba 100644 --- a/pkg/types/codec/codecs.go +++ b/pkg/types/codec/codecs.go @@ -63,6 +63,18 @@ func (codec *RTPCodec) Register(engine *webrtc.MediaEngine) error { }, codec.Type) } +func (codec *RTPCodec) IsVideo() bool { + return codec.Type == webrtc.RTPCodecTypeVideo +} + +func (codec *RTPCodec) IsAudio() bool { + return codec.Type == webrtc.RTPCodecTypeAudio +} + +func (codec *RTPCodec) String() string { + return codec.Type.String() + "/" + codec.Name +} + func VP8() RTPCodec { return RTPCodec{ Name: "vp8",