initial commit - from neko open source repository..

This commit is contained in:
Miroslav Šedivý
2020-10-22 16:54:50 +02:00
commit 56de805f54
66 changed files with 5498 additions and 0 deletions

View File

@ -0,0 +1,83 @@
package broadcast
import (
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"n.eko.moe/neko/internal/gst"
"n.eko.moe/neko/internal/types/config"
)
type BroadcastManager struct {
logger zerolog.Logger
pipeline *gst.Pipeline
remote *config.Remote
config *config.Broadcast
enabled bool
url string
}
func New(remote *config.Remote, config *config.Broadcast) *BroadcastManager {
return &BroadcastManager{
logger: log.With().Str("module", "remote").Logger(),
remote: remote,
config: config,
enabled: false,
url: "",
}
}
func (manager *BroadcastManager) Start() {
if !manager.enabled || manager.IsActive() {
return
}
var err error
manager.pipeline, err = gst.CreateRTMPPipeline(
manager.remote.Device,
manager.remote.Display,
manager.config.Pipeline,
manager.url,
)
manager.logger.Info().
Str("audio_device", manager.remote.Device).
Str("video_display", manager.remote.Display).
Str("rtmp_pipeline_src", manager.pipeline.Src).
Msgf("RTMP pipeline is starting...")
if err != nil {
manager.logger.Panic().Err(err).Msg("unable to create rtmp pipeline")
return
}
manager.pipeline.Play()
}
func (manager *BroadcastManager) Stop() {
if !manager.IsActive() {
return
}
manager.pipeline.Stop()
manager.pipeline = nil
}
func (manager *BroadcastManager) IsActive() bool {
return manager.pipeline != nil
}
func (manager *BroadcastManager) Create(url string) {
manager.url = url
manager.enabled = true
manager.Start()
}
func (manager *BroadcastManager) Destroy() {
manager.Stop()
manager.enabled = false
}
func (manager *BroadcastManager) GetUrl() string {
return manager.url
}

95
internal/gst/gst.c Normal file
View File

@ -0,0 +1,95 @@
#include "gst.h"
#include <gst/app/gstappsrc.h>
typedef struct SampleHandlerUserData {
int pipelineId;
} SampleHandlerUserData;
void gstreamer_init(void) {
gst_init(NULL, NULL);
}
GMainLoop *gstreamer_send_main_loop = NULL;
void gstreamer_send_start_mainloop(void) {
gstreamer_send_main_loop = g_main_loop_new(NULL, FALSE);
g_main_loop_run(gstreamer_send_main_loop);
}
static gboolean gstreamer_send_bus_call(GstBus *bus, GstMessage *msg, gpointer data) {
switch (GST_MESSAGE_TYPE(msg)) {
case GST_MESSAGE_EOS:
g_print("End of stream\n");
exit(1);
break;
case GST_MESSAGE_ERROR: {
gchar *debug;
GError *error;
gst_message_parse_error(msg, &error, &debug);
g_free(debug);
g_printerr("Error: %s\n", error->message);
g_error_free(error);
exit(1);
}
default:
break;
}
return TRUE;
}
GstFlowReturn gstreamer_send_new_sample_handler(GstElement *object, gpointer user_data) {
GstSample *sample = NULL;
GstBuffer *buffer = NULL;
gpointer copy = NULL;
gsize copy_size = 0;
SampleHandlerUserData *s = (SampleHandlerUserData *)user_data;
g_signal_emit_by_name (object, "pull-sample", &sample);
if (sample) {
buffer = gst_sample_get_buffer(sample);
if (buffer) {
gst_buffer_extract_dup(buffer, 0, gst_buffer_get_size(buffer), &copy, &copy_size);
goHandlePipelineBuffer(copy, copy_size, GST_BUFFER_DURATION(buffer), s->pipelineId);
}
gst_sample_unref (sample);
}
return GST_FLOW_OK;
}
GstElement *gstreamer_send_create_pipeline(char *pipeline) {
GError *error = NULL;
return gst_parse_launch(pipeline, &error);
}
void gstreamer_send_start_pipeline(GstElement *pipeline, int pipelineId) {
SampleHandlerUserData *s = calloc(1, sizeof(SampleHandlerUserData));
s->pipelineId = pipelineId;
GstBus *bus = gst_pipeline_get_bus(GST_PIPELINE(pipeline));
gst_bus_add_watch(bus, gstreamer_send_bus_call, NULL);
gst_object_unref(bus);
GstElement *appsink = gst_bin_get_by_name(GST_BIN(pipeline), "appsink");
g_object_set(appsink, "emit-signals", TRUE, NULL);
g_signal_connect(appsink, "new-sample", G_CALLBACK(gstreamer_send_new_sample_handler), s);
gst_object_unref(appsink);
gst_element_set_state(pipeline, GST_STATE_PLAYING);
}
void gstreamer_send_play_pipeline(GstElement *pipeline) {
gst_element_set_state(pipeline, GST_STATE_PLAYING);
}
void gstreamer_send_stop_pipeline(GstElement *pipeline) {
gst_element_set_state(pipeline, GST_STATE_NULL);
}

277
internal/gst/gst.go Normal file
View File

@ -0,0 +1,277 @@
package gst
/*
#cgo pkg-config: gstreamer-1.0 gstreamer-app-1.0
#include "gst.h"
*/
import "C"
import (
"fmt"
"sync"
"unsafe"
"github.com/pion/webrtc/v2"
"n.eko.moe/neko/internal/types"
)
/*
apt-get install \
libgstreamer1.0-0 \
gstreamer1.0-plugins-base \
gstreamer1.0-plugins-good \
gstreamer1.0-plugins-bad \
gstreamer1.0-plugins-ugly\
gstreamer1.0-libav \
gstreamer1.0-doc \
gstreamer1.0-tools \
gstreamer1.0-x \
gstreamer1.0-alsa \
gstreamer1.0-pulseaudio
gst-inspect-1.0 --version
gst-inspect-1.0 plugin
gst-launch-1.0 ximagesrc show-pointer=true use-damage=false ! video/x-raw,framerate=30/1 ! videoconvert ! queue ! vp8enc error-resilient=partitions keyframe-max-dist=10 auto-alt-ref=true cpu-used=5 deadline=1 ! autovideosink
gst-launch-1.0 pulsesrc ! audioconvert ! opusenc ! autoaudiosink
*/
// Pipeline is a wrapper for a GStreamer Pipeline
type Pipeline struct {
Pipeline *C.GstElement
Sample chan types.Sample
CodecName string
ClockRate float32
Src string
id int
}
var pipelines = make(map[int]*Pipeline)
var pipelinesLock sync.Mutex
var registry *C.GstRegistry
const (
videoClockRate = 90000
audioClockRate = 48000
pcmClockRate = 8000
videoSrc = "ximagesrc xid=%s show-pointer=true use-damage=false ! video/x-raw ! videoconvert ! queue ! "
audioSrc = "pulsesrc device=%s ! audio/x-raw,channels=2 ! audioconvert ! "
)
func init() {
C.gstreamer_init()
registry = C.gst_registry_get()
}
// CreateRTMPPipeline creates a GStreamer Pipeline
func CreateRTMPPipeline(pipelineDevice string, pipelineDisplay string, pipelineSrc string, pipelineRTMP string) (*Pipeline, error) {
video := fmt.Sprintf(videoSrc, pipelineDisplay)
audio := fmt.Sprintf(audioSrc, pipelineDevice)
var pipelineStr string
if pipelineSrc != "" {
pipelineStr = fmt.Sprintf(pipelineSrc, pipelineRTMP, pipelineDevice, pipelineDisplay)
} else {
pipelineStr = fmt.Sprintf("flvmux name=mux ! rtmpsink location='%s live=1' %s audio/x-raw,channels=2 ! audioconvert ! voaacenc ! mux. %s x264enc bframes=0 key-int-max=60 byte-stream=true tune=zerolatency speed-preset=veryfast ! mux.", pipelineRTMP, audio, video)
}
return CreatePipeline(pipelineStr, "", 0)
}
// CreateAppPipeline creates a GStreamer Pipeline
func CreateAppPipeline(codecName string, pipelineDevice string, pipelineSrc string) (*Pipeline, error) {
pipelineStr := " ! appsink name=appsink"
var clockRate float32
switch codecName {
case webrtc.VP8:
// https://gstreamer.freedesktop.org/documentation/vpx/vp8enc.html?gi-language=c
// gstreamer1.0-plugins-good
// vp8enc error-resilient=partitions keyframe-max-dist=10 auto-alt-ref=true cpu-used=5 deadline=1
if err := CheckPlugins([]string{"ximagesrc", "vpx"}); err != nil {
return nil, err
}
clockRate = videoClockRate
if pipelineSrc != "" {
pipelineStr = fmt.Sprintf(pipelineSrc+pipelineStr, pipelineDevice)
} else {
pipelineStr = fmt.Sprintf(videoSrc+"vp8enc cpu-used=8 threads=2 deadline=1 error-resilient=partitions keyframe-max-dist=10 auto-alt-ref=true"+pipelineStr, pipelineDevice)
}
case webrtc.VP9:
// https://gstreamer.freedesktop.org/documentation/vpx/vp9enc.html?gi-language=c
// gstreamer1.0-plugins-good
// vp9enc
if err := CheckPlugins([]string{"ximagesrc", "vpx"}); err != nil {
return nil, err
}
clockRate = videoClockRate
// Causes panic! not sure why...
if pipelineSrc != "" {
pipelineStr = fmt.Sprintf(pipelineSrc+pipelineStr, pipelineDevice)
} else {
pipelineStr = fmt.Sprintf(videoSrc+"vp9enc"+pipelineStr, pipelineDevice)
}
case webrtc.H264:
// https://gstreamer.freedesktop.org/documentation/openh264/openh264enc.html?gi-language=c#openh264enc
// gstreamer1.0-plugins-bad
// openh264enc multi-thread=4 complexity=high bitrate=3072000 max-bitrate=4096000
if err := CheckPlugins([]string{"ximagesrc"}); err != nil {
return nil, err
}
clockRate = videoClockRate
if pipelineSrc != "" {
pipelineStr = fmt.Sprintf(pipelineSrc+pipelineStr, pipelineDevice)
} else {
pipelineStr = fmt.Sprintf(videoSrc+"openh264enc multi-thread=4 complexity=high bitrate=3072000 max-bitrate=4096000 ! video/x-h264,stream-format=byte-stream"+pipelineStr, pipelineDevice)
// https://gstreamer.freedesktop.org/documentation/x264/index.html?gi-language=c
// gstreamer1.0-plugins-ugly
// video/x-raw,format=I420 ! x264enc bframes=0 key-int-max=60 byte-stream=true tune=zerolatency speed-preset=veryfast ! video/x-h264,stream-format=byte-stream
if err := CheckPlugins([]string{"openh264"}); err != nil {
pipelineStr = fmt.Sprintf(videoSrc+"video/x-raw,format=I420 ! x264enc bframes=0 key-int-max=60 byte-stream=true tune=zerolatency speed-preset=veryfast ! video/x-h264,stream-format=byte-stream"+pipelineStr, pipelineDevice)
if err := CheckPlugins([]string{"x264"}); err != nil {
return nil, err
}
}
}
case webrtc.Opus:
// https://gstreamer.freedesktop.org/documentation/opus/opusenc.html
// gstreamer1.0-plugins-base
// opusenc
if err := CheckPlugins([]string{"pulseaudio", "opus"}); err != nil {
return nil, err
}
clockRate = audioClockRate
if pipelineSrc != "" {
pipelineStr = fmt.Sprintf(pipelineSrc+pipelineStr, pipelineDevice)
} else {
pipelineStr = fmt.Sprintf(audioSrc+"opusenc"+pipelineStr, pipelineDevice)
}
case webrtc.G722:
// https://gstreamer.freedesktop.org/documentation/libav/avenc_g722.html?gi-language=c
// gstreamer1.0-libav
// avenc_g722
if err := CheckPlugins([]string{"pulseaudio", "libav"}); err != nil {
return nil, err
}
clockRate = audioClockRate
if pipelineSrc != "" {
pipelineStr = fmt.Sprintf(pipelineSrc+pipelineStr, pipelineDevice)
} else {
pipelineStr = fmt.Sprintf(audioSrc+"avenc_g722"+pipelineStr, pipelineDevice)
}
case webrtc.PCMU:
// https://gstreamer.freedesktop.org/documentation/mulaw/mulawenc.html?gi-language=c
// gstreamer1.0-plugins-good
// audio/x-raw, rate=8000 ! mulawenc
if err := CheckPlugins([]string{"pulseaudio", "mulaw"}); err != nil {
return nil, err
}
clockRate = pcmClockRate
if pipelineSrc != "" {
pipelineStr = fmt.Sprintf(pipelineSrc+pipelineStr, pipelineDevice)
} else {
pipelineStr = fmt.Sprintf(audioSrc+"audio/x-raw, rate=8000 ! mulawenc"+pipelineStr, pipelineDevice)
}
case webrtc.PCMA:
// https://gstreamer.freedesktop.org/documentation/alaw/alawenc.html?gi-language=c
// gstreamer1.0-plugins-good
// audio/x-raw, rate=8000 ! alawenc
if err := CheckPlugins([]string{"pulseaudio", "alaw"}); err != nil {
return nil, err
}
clockRate = pcmClockRate
if pipelineSrc != "" {
pipelineStr = fmt.Sprintf(pipelineSrc+pipelineStr, pipelineDevice)
} else {
pipelineStr = fmt.Sprintf(audioSrc+"audio/x-raw, rate=8000 ! alawenc"+pipelineStr, pipelineDevice)
}
default:
return nil, fmt.Errorf("unknown codec %s", codecName)
}
return CreatePipeline(pipelineStr, codecName, clockRate)
}
// CreatePipeline creates a GStreamer Pipeline
func CreatePipeline(pipelineStr string, codecName string, clockRate float32) (*Pipeline, error) {
pipelineStrUnsafe := C.CString(pipelineStr)
defer C.free(unsafe.Pointer(pipelineStrUnsafe))
pipelinesLock.Lock()
defer pipelinesLock.Unlock()
p := &Pipeline{
Pipeline: C.gstreamer_send_create_pipeline(pipelineStrUnsafe),
Sample: make(chan types.Sample),
CodecName: codecName,
ClockRate: clockRate,
Src: pipelineStr,
id: len(pipelines),
}
pipelines[p.id] = p
return p, nil
}
// Start starts the GStreamer Pipeline
func (p *Pipeline) Start() {
C.gstreamer_send_start_pipeline(p.Pipeline, C.int(p.id))
}
// Play starts the GStreamer Pipeline
func (p *Pipeline) Play() {
C.gstreamer_send_play_pipeline(p.Pipeline)
}
// Stop stops the GStreamer Pipeline
func (p *Pipeline) Stop() {
C.gstreamer_send_stop_pipeline(p.Pipeline)
}
// gst-inspect-1.0
func CheckPlugins(plugins []string) error {
var plugin *C.GstPlugin
for _, pluginstr := range plugins {
plugincstr := C.CString(pluginstr)
plugin = C.gst_registry_find_plugin(registry, plugincstr)
C.free(unsafe.Pointer(plugincstr))
if plugin == nil {
return fmt.Errorf("required gstreamer plugin %s not found", pluginstr)
}
}
return nil
}
//export goHandlePipelineBuffer
func goHandlePipelineBuffer(buffer unsafe.Pointer, bufferLen C.int, duration C.int, pipelineID C.int) {
pipelinesLock.Lock()
pipeline, ok := pipelines[int(pipelineID)]
pipelinesLock.Unlock()
if ok {
samples := uint32(pipeline.ClockRate * (float32(duration) / 1000000000))
pipeline.Sample <- types.Sample{Data: C.GoBytes(buffer, bufferLen), Samples: samples}
} else {
fmt.Printf("discarding buffer, no pipeline with id %d", int(pipelineID))
}
C.free(buffer)
}

19
internal/gst/gst.h Normal file
View File

@ -0,0 +1,19 @@
#ifndef GST_H
#define GST_H
#include <glib.h>
#include <gst/gst.h>
#include <stdint.h>
#include <stdlib.h>
extern void goHandlePipelineBuffer(void *buffer, int bufferLen, int samples, int pipelineId);
GstElement *gstreamer_send_create_pipeline(char *pipeline);
void gstreamer_send_start_pipeline(GstElement *pipeline, int pipelineId);
void gstreamer_send_play_pipeline(GstElement *pipeline);
void gstreamer_send_stop_pipeline(GstElement *pipeline);
void gstreamer_send_start_mainloop(void);
void gstreamer_init(void);
#endif

View File

@ -0,0 +1,102 @@
package endpoint
import (
"encoding/json"
"fmt"
"net/http"
"runtime/debug"
"github.com/go-chi/chi/middleware"
"github.com/rs/zerolog/log"
)
type (
Endpoint func(http.ResponseWriter, *http.Request) error
ErrResponse struct {
Status int `json:"status,omitempty"`
Err string `json:"error,omitempty"`
Message string `json:"message,omitempty"`
Details string `json:"details,omitempty"`
Code string `json:"code,omitempty"`
RequestID string `json:"request,omitempty"`
}
)
func Handle(handler Endpoint) http.HandlerFunc {
fn := func(w http.ResponseWriter, r *http.Request) {
if err := handler(w, r); err != nil {
WriteError(w, r, err)
}
}
return http.HandlerFunc(fn)
}
var nonErrorsCodes = map[int]bool{
404: true,
}
func errResponse(input interface{}) *ErrResponse {
var res *ErrResponse
var err interface{}
switch input.(type) {
case *HandlerError:
e := input.(*HandlerError)
res = &ErrResponse{
Status: e.Status,
Err: http.StatusText(e.Status),
Message: e.Message,
}
err = e.Err
default:
res = &ErrResponse{
Status: http.StatusInternalServerError,
Err: http.StatusText(http.StatusInternalServerError),
}
err = input
}
if err != nil {
switch err.(type) {
case *error:
e := err.(error)
res.Details = e.Error()
break
default:
res.Details = fmt.Sprintf("%+v", err)
break
}
}
return res
}
func WriteError(w http.ResponseWriter, r *http.Request, err interface{}) {
hlog := log.With().
Str("module", "http").
Logger()
res := errResponse(err)
if reqID := middleware.GetReqID(r.Context()); reqID != "" {
res.RequestID = reqID
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(res.Status)
if err := json.NewEncoder(w).Encode(res); err != nil {
hlog.Warn().Err(err).Msg("Failed writing json error response")
}
if !nonErrorsCodes[res.Status] {
logEntry := middleware.GetLogEntry(r)
if logEntry != nil {
logEntry.Panic(err, debug.Stack())
} else {
hlog.Error().Str("stack", string(debug.Stack())).Msgf("%+v", err)
}
}
}

View File

@ -0,0 +1,17 @@
package endpoint
import "fmt"
type HandlerError struct {
Status int
Message string
Err error
}
func (e *HandlerError) Error() string {
if e.Err != nil {
return fmt.Sprintf("%s: %s", e.Message, e.Err.Error())
}
return e.Message
}

87
internal/http/http.go Normal file
View File

@ -0,0 +1,87 @@
package http
import (
"context"
"fmt"
"net/http"
"os"
"github.com/go-chi/chi"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"n.eko.moe/neko/internal/http/endpoint"
"n.eko.moe/neko/internal/http/middleware"
"n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/types/config"
)
type Server struct {
logger zerolog.Logger
router *chi.Mux
http *http.Server
conf *config.Server
}
func New(conf *config.Server, webSocketHandler types.WebSocketHandler) *Server {
logger := log.With().Str("module", "webrtc").Logger()
router := chi.NewRouter()
// router.Use(middleware.Recoverer) // Recover from panics without crashing server
router.Use(middleware.RequestID) // Create a request ID for each request
router.Use(middleware.Logger) // Log API request calls
router.Get("/ws", func(w http.ResponseWriter, r *http.Request) {
webSocketHandler.Upgrade(w, r)
})
fs := http.FileServer(http.Dir(conf.Static))
router.Get("/*", func(w http.ResponseWriter, r *http.Request) {
if _, err := os.Stat(conf.Static + r.RequestURI); os.IsNotExist(err) {
http.StripPrefix(r.RequestURI, fs).ServeHTTP(w, r)
} else {
fs.ServeHTTP(w, r)
}
})
router.NotFound(endpoint.Handle(func(w http.ResponseWriter, r *http.Request) error {
return &endpoint.HandlerError{
Status: http.StatusNotFound,
Message: fmt.Sprintf("file '%s' is not found", r.RequestURI),
}
}))
server := &http.Server{
Addr: conf.Bind,
Handler: router,
}
return &Server{
logger: logger,
router: router,
http: server,
conf: conf,
}
}
func (s *Server) Start() {
if s.conf.Cert != "" && s.conf.Key != "" {
go func() {
if err := s.http.ListenAndServeTLS(s.conf.Cert, s.conf.Key); err != http.ErrServerClosed {
s.logger.Panic().Err(err).Msg("unable to start https server")
}
}()
s.logger.Info().Msgf("https listening on %s", s.http.Addr)
} else {
go func() {
if err := s.http.ListenAndServe(); err != http.ErrServerClosed {
s.logger.Panic().Err(err).Msg("unable to start http server")
}
}()
s.logger.Warn().Msgf("http listening on %s", s.http.Addr)
}
}
func (s *Server) Shutdown() error {
return s.http.Shutdown(context.Background())
}

View File

@ -0,0 +1,80 @@
package middleware
import (
"fmt"
"net/http"
"time"
"github.com/go-chi/chi/middleware"
"github.com/rs/zerolog/log"
)
func Logger(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
req := map[string]interface{}{}
if reqID := middleware.GetReqID(r.Context()); reqID != "" {
req["id"] = reqID
}
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
req["scheme"] = scheme
req["proto"] = r.Proto
req["method"] = r.Method
req["remote"] = r.RemoteAddr
req["agent"] = r.UserAgent()
req["uri"] = fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI)
fields := map[string]interface{}{}
fields["req"] = req
entry := &entry{
fields: fields,
}
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
t1 := time.Now()
defer func() {
entry.Write(ww.Status(), ww.BytesWritten(), time.Since(t1))
}()
next.ServeHTTP(ww, r)
}
return http.HandlerFunc(fn)
}
type entry struct {
fields map[string]interface{}
errors []map[string]interface{}
}
func (e *entry) Write(status, bytes int, elapsed time.Duration) {
res := map[string]interface{}{}
res["time"] = time.Now().UTC().Format(time.RFC1123)
res["status"] = status
res["bytes"] = bytes
res["elapsed"] = float64(elapsed.Nanoseconds()) / 1000000.0
e.fields["res"] = res
e.fields["module"] = "http"
if len(e.errors) > 0 {
e.fields["errors"] = e.errors
log.Error().Fields(e.fields).Msgf("request failed (%d)", status)
} else {
log.Debug().Fields(e.fields).Msgf("request complete (%d)", status)
}
}
func (e *entry) Panic(v interface{}, stack []byte) {
err := map[string]interface{}{}
err["message"] = fmt.Sprintf("%+v", v)
err["stack"] = string(stack)
e.errors = append(e.errors, err)
}

View File

@ -0,0 +1,12 @@
package middleware
// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation. This technique
// for defining context keys was copied from Go 1.7's new use of context in net/http.
type ctxKey struct {
name string
}
func (k *ctxKey) String() string {
return "neko/ctx/" + k.name
}

View File

@ -0,0 +1,24 @@
package middleware
// The original work was derived from Goji's middleware, source:
// https://github.com/zenazn/goji/tree/master/web/middleware
import (
"net/http"
"n.eko.moe/neko/internal/http/endpoint"
)
func Recoverer(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
defer func() {
if rvr := recover(); rvr != nil {
endpoint.WriteError(w, r, rvr)
}
}()
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}

View File

@ -0,0 +1,89 @@
package middleware
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"net/http"
"os"
"strings"
"sync/atomic"
)
// Key to use when setting the request ID.
type ctxKeyRequestID int
// RequestIDKey is the key that holds the unique request ID in a request context.
const RequestIDKey ctxKeyRequestID = 0
var prefix string
var reqid uint64
// A quick note on the statistics here: we're trying to calculate the chance that
// two randomly generated base62 prefixes will collide. We use the formula from
// http://en.wikipedia.org/wiki/Birthday_problem
//
// P[m, n] \approx 1 - e^{-m^2/2n}
//
// We ballpark an upper bound for $m$ by imagining (for whatever reason) a server
// that restarts every second over 10 years, for $m = 86400 * 365 * 10 = 315360000$
//
// For a $k$ character base-62 identifier, we have $n(k) = 62^k$
//
// Plugging this in, we find $P[m, n(10)] \approx 5.75%$, which is good enough for
// our purposes, and is surely more than anyone would ever need in practice -- a
// process that is rebooted a handful of times a day for a hundred years has less
// than a millionth of a percent chance of generating two colliding IDs.
func init() {
hostname, err := os.Hostname()
if hostname == "" || err != nil {
hostname = "localhost"
}
var buf [12]byte
var b64 string
for len(b64) < 10 {
rand.Read(buf[:])
b64 = base64.StdEncoding.EncodeToString(buf[:])
b64 = strings.NewReplacer("+", "", "/", "").Replace(b64)
}
prefix = fmt.Sprintf("%s/%s", hostname, b64[0:10])
}
// RequestID is a middleware that injects a request ID into the context of each
// request. A request ID is a string of the form "host.example.com/random-0001",
// where "random" is a base62 random string that uniquely identifies this go
// process, and where the last number is an atomically incremented request
// counter.
func RequestID(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := r.Header.Get("X-Request-Id")
if requestID == "" {
myid := atomic.AddUint64(&reqid, 1)
requestID = fmt.Sprintf("%s-%06d", prefix, myid)
}
ctx = context.WithValue(ctx, RequestIDKey, requestID)
next.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(fn)
}
// GetReqID returns a request ID from the given context if one is present.
// Returns the empty string if a request ID cannot be found.
func GetReqID(ctx context.Context) string {
if ctx == nil {
return ""
}
if reqID, ok := ctx.Value(RequestIDKey).(string); ok {
return reqID
}
return ""
}
// NextRequestID generates the next request ID in the sequence.
func NextRequestID() uint64 {
return atomic.AddUint64(&reqid, 1)
}

View File

@ -0,0 +1,32 @@
package response
import (
"encoding/json"
"net/http"
"n.eko.moe/neko/internal/http/endpoint"
)
// JSON encodes data to rw in JSON format. Returns a pointer to a
// HandlerError if encoding fails.
func JSON(w http.ResponseWriter, data interface{}, status int) error {
w.WriteHeader(status)
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(data)
if err != nil {
return &endpoint.HandlerError{
Status: http.StatusInternalServerError,
Message: "unable to write JSON response",
Err: err,
}
}
return nil
}
// Empty merely sets the response code to NoContent (204).
func Empty(w http.ResponseWriter) error {
w.WriteHeader(http.StatusNoContent)
return nil
}

235
internal/remote/manager.go Normal file
View File

@ -0,0 +1,235 @@
package remote
import (
"fmt"
"time"
"github.com/kataras/go-events"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"n.eko.moe/neko/internal/gst"
"n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/types/config"
"n.eko.moe/neko/internal/xorg"
)
type RemoteManager struct {
logger zerolog.Logger
video *gst.Pipeline
audio *gst.Pipeline
config *config.Remote
broadcast types.BroadcastManager
cleanup *time.Ticker
shutdown chan bool
emmiter events.EventEmmiter
streaming bool
}
func New(config *config.Remote, broadcast types.BroadcastManager) *RemoteManager {
return &RemoteManager{
logger: log.With().Str("module", "remote").Logger(),
cleanup: time.NewTicker(1 * time.Second),
shutdown: make(chan bool),
emmiter: events.New(),
config: config,
broadcast: broadcast,
streaming: false,
}
}
func (manager *RemoteManager) VideoCodec() string {
return manager.config.VideoCodec
}
func (manager *RemoteManager) AudioCodec() string {
return manager.config.AudioCodec
}
func (manager *RemoteManager) Start() {
xorg.Display(manager.config.Display)
if !xorg.ValidScreenSize(manager.config.ScreenWidth, manager.config.ScreenHeight, manager.config.ScreenRate) {
manager.logger.Warn().Msgf("invalid screen option %dx%d@%d", manager.config.ScreenWidth, manager.config.ScreenHeight, manager.config.ScreenRate)
} else if err := xorg.ChangeScreenSize(manager.config.ScreenWidth, manager.config.ScreenHeight, manager.config.ScreenRate); err != nil {
manager.logger.Warn().Err(err).Msg("unable to change screen size")
}
manager.createPipelines()
manager.broadcast.Start()
go func() {
defer func() {
manager.logger.Info().Msg("shutdown")
}()
for {
select {
case <-manager.shutdown:
return
case sample := <-manager.video.Sample:
manager.emmiter.Emit("video", sample)
case sample := <-manager.audio.Sample:
manager.emmiter.Emit("audio", sample)
case <-manager.cleanup.C:
xorg.CheckKeys(time.Second * 10)
}
}
}()
}
func (manager *RemoteManager) Shutdown() error {
manager.logger.Info().Msgf("remote shutting down")
manager.video.Stop()
manager.audio.Stop()
manager.broadcast.Stop()
manager.cleanup.Stop()
manager.shutdown <- true
return nil
}
func (manager *RemoteManager) OnVideoFrame(listener func(sample types.Sample)) {
manager.emmiter.On("video", func(payload ...interface{}) {
listener(payload[0].(types.Sample))
})
}
func (manager *RemoteManager) OnAudioFrame(listener func(sample types.Sample)) {
manager.emmiter.On("audio", func(payload ...interface{}) {
listener(payload[0].(types.Sample))
})
}
func (manager *RemoteManager) StartStream() {
manager.createPipelines()
manager.logger.Info().
Str("video_display", manager.config.Display).
Str("video_codec", manager.config.VideoCodec).
Str("audio_device", manager.config.Device).
Str("audio_codec", manager.config.AudioCodec).
Str("audio_pipeline_src", manager.audio.Src).
Str("video_pipeline_src", manager.video.Src).
Str("screen_resolution", fmt.Sprintf("%dx%d@%d", manager.config.ScreenWidth, manager.config.ScreenHeight, manager.config.ScreenRate)).
Msgf("Pipelines starting...")
manager.video.Start()
manager.audio.Start()
manager.streaming = true
}
func (manager *RemoteManager) StopStream() {
manager.logger.Info().Msgf("Pipelines shutting down...")
manager.video.Stop()
manager.audio.Stop()
manager.streaming = false
}
func (manager *RemoteManager) Streaming() bool {
return manager.streaming
}
func (manager *RemoteManager) createPipelines() {
var err error
manager.video, err = gst.CreateAppPipeline(
manager.config.VideoCodec,
manager.config.Display,
manager.config.VideoParams,
)
if err != nil {
manager.logger.Panic().Err(err).Msg("unable to create video pipeline")
}
manager.audio, err = gst.CreateAppPipeline(
manager.config.AudioCodec,
manager.config.Device,
manager.config.AudioParams,
)
if err != nil {
manager.logger.Panic().Err(err).Msg("unable to create audio pipeline")
}
}
func (manager *RemoteManager) ChangeResolution(width int, height int, rate int) error {
if !xorg.ValidScreenSize(width, height, rate) {
return fmt.Errorf("unknown configuration")
}
manager.video.Stop()
manager.broadcast.Stop()
defer func() {
manager.video.Start()
manager.broadcast.Start()
manager.logger.Info().Msg("starting video pipeline...")
}()
if err := xorg.ChangeScreenSize(width, height, rate); err != nil {
return err
}
var err error
manager.video, err = gst.CreateAppPipeline(
manager.config.VideoCodec,
manager.config.Display,
manager.config.VideoParams,
)
if err != nil {
manager.logger.Panic().Err(err).Msg("unable to create new video pipeline")
}
return nil
}
func (manager *RemoteManager) Move(x, y int) {
xorg.Move(x, y)
}
func (manager *RemoteManager) Scroll(x, y int) {
xorg.Scroll(x, y)
}
func (manager *RemoteManager) ButtonDown(code int) error {
return xorg.ButtonDown(code)
}
func (manager *RemoteManager) KeyDown(code uint64) error {
return xorg.KeyDown(code)
}
func (manager *RemoteManager) ButtonUp(code int) error {
return xorg.ButtonUp(code)
}
func (manager *RemoteManager) KeyUp(code uint64) error {
return xorg.KeyUp(code)
}
func (manager *RemoteManager) ReadClipboard() string {
return xorg.ReadClipboard()
}
func (manager *RemoteManager) WriteClipboard(data string) {
xorg.WriteClipboard(data)
}
func (manager *RemoteManager) ResetKeys() {
xorg.ResetKeys()
}
func (manager *RemoteManager) ScreenConfigurations() map[int]types.ScreenConfiguration {
return xorg.ScreenConfigurations
}
func (manager *RemoteManager) GetScreenSize() *types.ScreenSize {
return xorg.GetScreenSize()
}
func (manager *RemoteManager) SetKeyboardLayout(layout string) {
xorg.SetKeyboardLayout(layout)
}
func (manager *RemoteManager) SetKeyboardModifiers(NumLock int, CapsLock int, ScrollLock int) {
xorg.SetKeyboardModifiers(NumLock, CapsLock, ScrollLock)
}

189
internal/session/manager.go Normal file
View File

@ -0,0 +1,189 @@
package session
import (
"fmt"
"github.com/kataras/go-events"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/utils"
)
func New(remote types.RemoteManager) *SessionManager {
return &SessionManager{
logger: log.With().Str("module", "session").Logger(),
host: "",
remote: remote,
members: make(map[string]*Session),
emmiter: events.New(),
}
}
type SessionManager struct {
logger zerolog.Logger
host string
remote types.RemoteManager
members map[string]*Session
emmiter events.EventEmmiter
}
func (manager *SessionManager) New(id string, admin bool, socket types.WebSocket) types.Session {
session := &Session{
id: id,
admin: admin,
manager: manager,
socket: socket,
logger: manager.logger.With().Str("id", id).Logger(),
connected: false,
}
manager.members[id] = session
manager.emmiter.Emit("created", id, session)
if manager.remote.Streaming() != true && len(manager.members) > 0 {
manager.remote.StartStream()
}
return session
}
func (manager *SessionManager) HasHost() bool {
return manager.host != ""
}
func (manager *SessionManager) IsHost(id string) bool {
return manager.host == id
}
func (manager *SessionManager) SetHost(id string) error {
_, ok := manager.members[id]
if ok {
manager.host = id
manager.emmiter.Emit("host", id)
return nil
}
return fmt.Errorf("invalid session id %s", id)
}
func (manager *SessionManager) GetHost() (types.Session, bool) {
host, ok := manager.members[manager.host]
return host, ok
}
func (manager *SessionManager) ClearHost() {
id := manager.host
manager.host = ""
manager.emmiter.Emit("host_cleared", id)
}
func (manager *SessionManager) Has(id string) bool {
_, ok := manager.members[id]
return ok
}
func (manager *SessionManager) Get(id string) (types.Session, bool) {
session, ok := manager.members[id]
return session, ok
}
func (manager *SessionManager) Admins() []*types.Member {
members := []*types.Member{}
for _, session := range manager.members {
if !session.connected || !session.admin {
continue
}
member := session.Member()
if member != nil {
members = append(members, member)
}
}
return members
}
func (manager *SessionManager) Members() []*types.Member {
members := []*types.Member{}
for _, session := range manager.members {
if !session.connected {
continue
}
member := session.Member()
if member != nil {
members = append(members, member)
}
}
return members
}
func (manager *SessionManager) Destroy(id string) error {
session, ok := manager.members[id]
if ok {
err := session.destroy()
delete(manager.members, id)
if manager.remote.Streaming() != false && len(manager.members) <= 0 {
manager.remote.StopStream()
}
manager.emmiter.Emit("destroyed", id, session)
return err
}
return nil
}
func (manager *SessionManager) Clear() error {
return nil
}
func (manager *SessionManager) Broadcast(v interface{}, exclude interface{}) error {
for id, session := range manager.members {
if !session.connected {
continue
}
if exclude != nil {
if in, _ := utils.ArrayIn(id, exclude); in {
continue
}
}
if err := session.Send(v); err != nil {
return err
}
}
return nil
}
func (manager *SessionManager) OnHost(listener func(id string)) {
manager.emmiter.On("host", func(payload ...interface{}) {
listener(payload[0].(string))
})
}
func (manager *SessionManager) OnHostCleared(listener func(id string)) {
manager.emmiter.On("host_cleared", func(payload ...interface{}) {
listener(payload[0].(string))
})
}
func (manager *SessionManager) OnDestroy(listener func(id string, session types.Session)) {
manager.emmiter.On("destroyed", func(payload ...interface{}) {
listener(payload[0].(string), payload[1].(*Session))
})
}
func (manager *SessionManager) OnCreated(listener func(id string, session types.Session)) {
manager.emmiter.On("created", func(payload ...interface{}) {
listener(payload[0].(string), payload[1].(*Session))
})
}
func (manager *SessionManager) OnConnected(listener func(id string, session types.Session)) {
manager.emmiter.On("connected", func(payload ...interface{}) {
listener(payload[0].(string), payload[1].(*Session))
})
}

137
internal/session/session.go Normal file
View File

@ -0,0 +1,137 @@
package session
import (
"sync"
"github.com/rs/zerolog"
"n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/types/event"
"n.eko.moe/neko/internal/types/message"
)
type Session struct {
logger zerolog.Logger
id string
name string
admin bool
muted bool
connected bool
manager *SessionManager
socket types.WebSocket
peer types.Peer
mu sync.Mutex
}
func (session *Session) ID() string {
return session.id
}
func (session *Session) Name() string {
return session.name
}
func (session *Session) Admin() bool {
return session.admin
}
func (session *Session) Muted() bool {
return session.muted
}
func (session *Session) Connected() bool {
return session.connected
}
func (session *Session) Address() string {
if session.socket == nil {
return ""
}
return session.socket.Address()
}
func (session *Session) Member() *types.Member {
return &types.Member{
ID: session.id,
Name: session.name,
Admin: session.admin,
Muted: session.muted,
}
}
func (session *Session) SetMuted(muted bool) {
session.muted = muted
}
func (session *Session) SetName(name string) error {
session.name = name
return nil
}
func (session *Session) SetSocket(socket types.WebSocket) error {
session.socket = socket
return nil
}
func (session *Session) SetPeer(peer types.Peer) error {
session.peer = peer
return nil
}
func (session *Session) SetConnected(connected bool) error {
session.connected = connected
if connected {
session.manager.emmiter.Emit("connected", session.id, session)
}
return nil
}
func (session *Session) Kick(reason string) error {
if session.socket == nil {
return nil
}
if err := session.socket.Send(&message.Disconnect{
Event: event.SYSTEM_DISCONNECT,
Message: reason,
}); err != nil {
return err
}
return session.destroy()
}
func (session *Session) Send(v interface{}) error {
if session.socket == nil {
return nil
}
return session.socket.Send(v)
}
func (session *Session) Write(v interface{}) error {
if session.socket == nil {
return nil
}
return session.socket.Send(v)
}
func (session *Session) SignalAnswer(sdp string) error {
if session.peer == nil {
return nil
}
return session.peer.SignalAnswer(sdp)
}
func (session *Session) destroy() error {
if session.socket != nil {
if err := session.socket.Destroy(); err != nil {
return err
}
}
if session.peer != nil {
if err := session.peer.Destroy(); err != nil {
return err
}
}
return nil
}

View File

@ -0,0 +1,10 @@
package types
type BroadcastManager interface {
Start()
Stop()
IsActive() bool
Create(url string)
Destroy()
GetUrl() string
}

View File

@ -0,0 +1,23 @@
package config
import (
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
type Broadcast struct {
Pipeline string
}
func (Broadcast) Init(cmd *cobra.Command) error {
cmd.PersistentFlags().String("broadcast_pipeline", "", "audio codec parameters to use for broadcasting")
if err := viper.BindPFlag("broadcast_pipeline", cmd.PersistentFlags().Lookup("broadcast_pipeline")); err != nil {
return err
}
return nil
}
func (s *Broadcast) Set() {
s.Pipeline = viper.GetString("broadcast_pipeline")
}

View File

@ -0,0 +1,8 @@
package config
import "github.com/spf13/cobra"
type Config interface {
Init(cmd *cobra.Command) error
Set()
}

View File

@ -0,0 +1,136 @@
package config
import (
"regexp"
"strconv"
"github.com/pion/webrtc/v2"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
type Remote struct {
Display string
Device string
AudioCodec string
AudioParams string
VideoCodec string
VideoParams string
ScreenWidth int
ScreenHeight int
ScreenRate int
}
func (Remote) Init(cmd *cobra.Command) error {
cmd.PersistentFlags().String("display", ":99.0", "XDisplay to capture")
if err := viper.BindPFlag("display", cmd.PersistentFlags().Lookup("display")); err != nil {
return err
}
cmd.PersistentFlags().String("device", "auto_null.monitor", "audio device to capture")
if err := viper.BindPFlag("device", cmd.PersistentFlags().Lookup("device")); err != nil {
return err
}
cmd.PersistentFlags().String("audio", "", "audio codec parameters to use for streaming")
if err := viper.BindPFlag("audio", cmd.PersistentFlags().Lookup("audio")); err != nil {
return err
}
cmd.PersistentFlags().String("video", "", "video codec parameters to use for streaming")
if err := viper.BindPFlag("video", cmd.PersistentFlags().Lookup("video")); err != nil {
return err
}
cmd.PersistentFlags().String("screen", "1280x720@30", "default screen resolution and framerate")
if err := viper.BindPFlag("screen", cmd.PersistentFlags().Lookup("screen")); err != nil {
return err
}
// video codecs
cmd.PersistentFlags().Bool("vp8", false, "use VP8 video codec")
if err := viper.BindPFlag("vp8", cmd.PersistentFlags().Lookup("vp8")); err != nil {
return err
}
cmd.PersistentFlags().Bool("vp9", false, "use VP9 video codec")
if err := viper.BindPFlag("vp9", cmd.PersistentFlags().Lookup("vp9")); err != nil {
return err
}
cmd.PersistentFlags().Bool("h264", false, "use H264 video codec")
if err := viper.BindPFlag("h264", cmd.PersistentFlags().Lookup("h264")); err != nil {
return err
}
// audio codecs
cmd.PersistentFlags().Bool("opus", false, "use Opus audio codec")
if err := viper.BindPFlag("opus", cmd.PersistentFlags().Lookup("opus")); err != nil {
return err
}
cmd.PersistentFlags().Bool("g722", false, "use G722 audio codec")
if err := viper.BindPFlag("g722", cmd.PersistentFlags().Lookup("g722")); err != nil {
return err
}
cmd.PersistentFlags().Bool("pcmu", false, "use PCMU audio codec")
if err := viper.BindPFlag("pcmu", cmd.PersistentFlags().Lookup("pcmu")); err != nil {
return err
}
cmd.PersistentFlags().Bool("pcma", false, "use PCMA audio codec")
if err := viper.BindPFlag("pcma", cmd.PersistentFlags().Lookup("pcma")); err != nil {
return err
}
return nil
}
func (s *Remote) Set() {
videoCodec := webrtc.VP8
if viper.GetBool("vp8") {
videoCodec = webrtc.VP8
} else if viper.GetBool("vp9") {
videoCodec = webrtc.VP9
} else if viper.GetBool("h264") {
videoCodec = webrtc.H264
}
audioCodec := webrtc.Opus
if viper.GetBool("opus") {
audioCodec = webrtc.Opus
} else if viper.GetBool("g722") {
audioCodec = webrtc.G722
} else if viper.GetBool("pcmu") {
audioCodec = webrtc.PCMU
} else if viper.GetBool("pcma") {
audioCodec = webrtc.PCMA
}
s.Device = viper.GetString("device")
s.AudioCodec = audioCodec
s.AudioParams = viper.GetString("audio")
s.Display = viper.GetString("display")
s.VideoCodec = videoCodec
s.VideoParams = viper.GetString("video")
s.ScreenWidth = 1280
s.ScreenHeight = 720
s.ScreenRate = 30
r := regexp.MustCompile(`([0-9]{1,4})x([0-9]{1,4})@([0-9]{1,3})`)
res := r.FindStringSubmatch(viper.GetString("screen"))
if len(res) > 0 {
width, err1 := strconv.ParseInt(res[1], 10, 64)
height, err2 := strconv.ParseInt(res[2], 10, 64)
rate, err3 := strconv.ParseInt(res[3], 10, 64)
if err1 == nil && err2 == nil && err3 == nil {
s.ScreenWidth = int(width)
s.ScreenHeight = int(height)
s.ScreenRate = int(rate)
}
}
}

View File

@ -0,0 +1,37 @@
package config
import (
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
type Root struct {
Debug bool
Logs bool
CfgFile string
}
func (Root) Init(cmd *cobra.Command) error {
cmd.PersistentFlags().BoolP("debug", "d", false, "enable debug mode")
if err := viper.BindPFlag("debug", cmd.PersistentFlags().Lookup("debug")); err != nil {
return err
}
cmd.PersistentFlags().BoolP("logs", "l", false, "save logs to file")
if err := viper.BindPFlag("logs", cmd.PersistentFlags().Lookup("logs")); err != nil {
return err
}
cmd.PersistentFlags().String("config", "", "configuration file path")
if err := viper.BindPFlag("config", cmd.PersistentFlags().Lookup("config")); err != nil {
return err
}
return nil
}
func (s *Root) Set() {
s.Logs = viper.GetBool("logs")
s.Debug = viper.GetBool("debug")
s.CfgFile = viper.GetString("config")
}

View File

@ -0,0 +1,44 @@
package config
import (
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
type Server struct {
Cert string
Key string
Bind string
Static string
}
func (Server) Init(cmd *cobra.Command) error {
cmd.PersistentFlags().String("bind", "127.0.0.1:8080", "address/port/socket to serve neko")
if err := viper.BindPFlag("bind", cmd.PersistentFlags().Lookup("bind")); err != nil {
return err
}
cmd.PersistentFlags().String("cert", "", "path to the SSL cert used to secure the neko server")
if err := viper.BindPFlag("cert", cmd.PersistentFlags().Lookup("cert")); err != nil {
return err
}
cmd.PersistentFlags().String("key", "", "path to the SSL key used to secure the neko server")
if err := viper.BindPFlag("key", cmd.PersistentFlags().Lookup("key")); err != nil {
return err
}
cmd.PersistentFlags().String("static", "./www", "path to neko client files to serve")
if err := viper.BindPFlag("static", cmd.PersistentFlags().Lookup("static")); err != nil {
return err
}
return nil
}
func (s *Server) Set() {
s.Cert = viper.GetString("cert")
s.Key = viper.GetString("key")
s.Bind = viper.GetString("bind")
s.Static = viper.GetString("static")
}

View File

@ -0,0 +1,79 @@
package config
import (
"strconv"
"strings"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"n.eko.moe/neko/internal/utils"
)
type WebRTC struct {
ICELite bool
ICEServers []string
EphemeralMin uint16
EphemeralMax uint16
NAT1To1IPs []string
}
func (WebRTC) Init(cmd *cobra.Command) error {
cmd.PersistentFlags().String("epr", "59000-59100", "limits the pool of ephemeral ports that ICE UDP connections can allocate from")
if err := viper.BindPFlag("epr", cmd.PersistentFlags().Lookup("epr")); err != nil {
return err
}
cmd.PersistentFlags().StringSlice("nat1to1", []string{}, "sets a list of external IP addresses of 1:1 (D)NAT and a candidate type for which the external IP address is used")
if err := viper.BindPFlag("nat1to1", cmd.PersistentFlags().Lookup("nat1to1")); err != nil {
return err
}
cmd.PersistentFlags().Bool("icelite", false, "configures whether or not the ice agent should be a lite agent")
if err := viper.BindPFlag("icelite", cmd.PersistentFlags().Lookup("icelite")); err != nil {
return err
}
cmd.PersistentFlags().StringSlice("iceserver", []string{"stun:stun.l.google.com:19302"}, "describes a single STUN and TURN server that can be used by the ICEAgent to establish a connection with a peer")
if err := viper.BindPFlag("iceserver", cmd.PersistentFlags().Lookup("iceserver")); err != nil {
return err
}
return nil
}
func (s *WebRTC) Set() {
s.ICELite = viper.GetBool("icelite")
s.ICEServers = viper.GetStringSlice("iceserver")
s.NAT1To1IPs = viper.GetStringSlice("nat1to1")
if len(s.NAT1To1IPs) == 0 {
ip, err := utils.GetIP()
if err == nil {
s.NAT1To1IPs = append(s.NAT1To1IPs, ip)
}
}
min := uint16(59000)
max := uint16(59100)
epr := viper.GetString("epr")
ports := strings.SplitN(epr, "-", -1)
if len(ports) > 1 {
start, err := strconv.ParseUint(ports[0], 10, 16)
if err == nil {
min = uint16(start)
}
end, err := strconv.ParseUint(ports[1], 10, 16)
if err == nil {
max = uint16(end)
}
}
if min > max {
s.EphemeralMin = max
s.EphemeralMax = min
} else {
s.EphemeralMin = min
s.EphemeralMax = max
}
}

View File

@ -0,0 +1,37 @@
package config
import (
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
type WebSocket struct {
Password string
AdminPassword string
Proxy bool
}
func (WebSocket) Init(cmd *cobra.Command) error {
cmd.PersistentFlags().String("password", "neko", "password for connecting to stream")
if err := viper.BindPFlag("password", cmd.PersistentFlags().Lookup("password")); err != nil {
return err
}
cmd.PersistentFlags().String("password_admin", "admin", "admin password for connecting to stream")
if err := viper.BindPFlag("password_admin", cmd.PersistentFlags().Lookup("password_admin")); err != nil {
return err
}
cmd.PersistentFlags().Bool("proxy", false, "enable reverse proxy mode")
if err := viper.BindPFlag("proxy", cmd.PersistentFlags().Lookup("proxy")); err != nil {
return err
}
return nil
}
func (s *WebSocket) Set() {
s.Password = viper.GetString("password")
s.AdminPassword = viper.GetString("password_admin")
s.Proxy = viper.GetBool("proxy")
}

View File

@ -0,0 +1,55 @@
package event
const (
SYSTEM_DISCONNECT = "system/disconnect"
)
const (
SIGNAL_ANSWER = "signal/answer"
SIGNAL_PROVIDE = "signal/provide"
)
const (
MEMBER_LIST = "member/list"
MEMBER_CONNECTED = "member/connected"
MEMBER_DISCONNECTED = "member/disconnected"
)
const (
CONTROL_LOCKED = "control/locked"
CONTROL_RELEASE = "control/release"
CONTROL_REQUEST = "control/request"
CONTROL_REQUESTING = "control/requesting"
CONTROL_GIVE = "control/give"
CONTROL_CLIPBOARD = "control/clipboard"
CONTROL_KEYBOARD = "control/keyboard"
)
const (
CHAT_MESSAGE = "chat/message"
CHAT_EMOTE = "chat/emote"
)
const (
SCREEN_CONFIGURATIONS = "screen/configurations"
SCREEN_RESOLUTION = "screen/resolution"
SCREEN_SET = "screen/set"
)
const (
BORADCAST_STATUS = "broadcast/status"
BORADCAST_CREATE = "broadcast/create"
BORADCAST_DESTROY = "broadcast/destroy"
)
const (
ADMIN_BAN = "admin/ban"
ADMIN_KICK = "admin/kick"
ADMIN_LOCK = "admin/lock"
ADMIN_MUTE = "admin/mute"
ADMIN_UNLOCK = "admin/unlock"
ADMIN_UNMUTE = "admin/unmute"
ADMIN_CONTROL = "admin/control"
ADMIN_RELEASE = "admin/release"
ADMIN_GIVE = "admin/give"
)

14
internal/types/keys.go Normal file
View File

@ -0,0 +1,14 @@
package types
type Button struct {
Name string
Code int
Keysym int
}
type Key struct {
Name string
Value string
Code int
Keysym int
}

View File

@ -0,0 +1,123 @@
package message
import (
"n.eko.moe/neko/internal/types"
)
type Message struct {
Event string `json:"event"`
}
type Disconnect struct {
Event string `json:"event"`
Message string `json:"message"`
}
type SignalProvide struct {
Event string `json:"event"`
ID string `json:"id"`
SDP string `json:"sdp"`
Lite bool `json:"lite"`
ICE []string `json:"ice"`
}
type SignalAnswer struct {
Event string `json:"event"`
DisplayName string `json:"displayname"`
SDP string `json:"sdp"`
}
type MembersList struct {
Event string `json:"event"`
Memebers []*types.Member `json:"members"`
}
type Member struct {
Event string `json:"event"`
*types.Member
}
type MemberDisconnected struct {
Event string `json:"event"`
ID string `json:"id"`
}
type Clipboard struct {
Event string `json:"event"`
Text string `json:"text"`
}
type Keyboard struct {
Event string `json:"event"`
Layout *string `json:"layout,omitempty"`
CapsLock *bool `json:"capsLock,omitempty"`
NumLock *bool `json:"numLock,omitempty"`
ScrollLock *bool `json:"scrollLock,omitempty"`
}
type Control struct {
Event string `json:"event"`
ID string `json:"id"`
}
type ControlTarget struct {
Event string `json:"event"`
ID string `json:"id"`
Target string `json:"target"`
}
type ChatReceive struct {
Event string `json:"event"`
Content string `json:"content"`
}
type ChatSend struct {
Event string `json:"event"`
ID string `json:"id"`
Content string `json:"content"`
}
type EmoteReceive struct {
Event string `json:"event"`
Emote string `json:"emote"`
}
type EmoteSend struct {
Event string `json:"event"`
ID string `json:"id"`
Emote string `json:"emote"`
}
type Admin struct {
Event string `json:"event"`
ID string `json:"id"`
}
type AdminTarget struct {
Event string `json:"event"`
Target string `json:"target"`
ID string `json:"id"`
}
type ScreenResolution struct {
Event string `json:"event"`
ID string `json:"id,omitempty"`
Width int `json:"width"`
Height int `json:"height"`
Rate int `json:"rate"`
}
type ScreenConfigurations struct {
Event string `json:"event"`
Configurations map[int]types.ScreenConfiguration `json:"configurations"`
}
type BroadcastStatus struct {
Event string `json:"event"`
URL string `json:"url"`
IsActive bool `json:"isActive"`
}
type BroadcastCreate struct {
Event string `json:"event"`
URL string `json:"url"`
}

27
internal/types/remote.go Normal file
View File

@ -0,0 +1,27 @@
package types
type RemoteManager interface {
VideoCodec() string
AudioCodec() string
Start()
Shutdown() error
OnVideoFrame(listener func(sample Sample))
OnAudioFrame(listener func(sample Sample))
StartStream()
StopStream()
Streaming() bool
ChangeResolution(width int, height int, rate int) error
GetScreenSize() *ScreenSize
ScreenConfigurations() map[int]ScreenConfiguration
Move(x, y int)
Scroll(x, y int)
ButtonDown(code int) error
KeyDown(code uint64) error
ButtonUp(code int) error
KeyUp(code uint64) error
ReadClipboard() string
WriteClipboard(data string)
ResetKeys()
SetKeyboardLayout(layout string)
SetKeyboardModifiers(NumLock int, CapsLock int, ScrollLock int)
}

48
internal/types/session.go Normal file
View File

@ -0,0 +1,48 @@
package types
type Member struct {
ID string `json:"id"`
Name string `json:"displayname"`
Admin bool `json:"admin"`
Muted bool `json:"muted"`
}
type Session interface {
ID() string
Name() string
Admin() bool
Muted() bool
Connected() bool
Member() *Member
SetMuted(muted bool)
SetName(name string) error
SetConnected(connected bool) error
SetSocket(socket WebSocket) error
SetPeer(peer Peer) error
Address() string
Kick(message string) error
Write(v interface{}) error
Send(v interface{}) error
SignalAnswer(sdp string) error
}
type SessionManager interface {
New(id string, admin bool, socket WebSocket) Session
HasHost() bool
IsHost(id string) bool
SetHost(id string) error
GetHost() (Session, bool)
ClearHost()
Has(id string) bool
Get(id string) (Session, bool)
Members() []*Member
Admins() []*Member
Destroy(id string) error
Clear() error
Broadcast(v interface{}, exclude interface{}) error
OnHost(listener func(id string))
OnHostCleared(listener func(id string))
OnDestroy(listener func(id string, session Session))
OnCreated(listener func(id string, session Session))
OnConnected(listener func(id string, session Session))
}

18
internal/types/webrtc.go Normal file
View File

@ -0,0 +1,18 @@
package types
type Sample struct {
Data []byte
Samples uint32
}
type WebRTCManager interface {
Start()
Shutdown() error
CreatePeer(id string, session Session) (string, bool, []string, error)
}
type Peer interface {
SignalAnswer(sdp string) error
WriteData(v interface{}) error
Destroy() error
}

View File

@ -0,0 +1,15 @@
package types
import "net/http"
type WebSocket interface {
Address() string
Send(v interface{}) error
Destroy() error
}
type WebSocketHandler interface {
Start() error
Shutdown() error
Upgrade(w http.ResponseWriter, r *http.Request) error
}

13
internal/types/xorg.go Normal file
View File

@ -0,0 +1,13 @@
package types
type ScreenSize struct {
Width int `json:"width"`
Height int `json:"height"`
Rate int16 `json:"rate"`
}
type ScreenConfiguration struct {
Width int `json:"width"`
Height int `json:"height"`
Rates map[int]int16 `json:"rates"`
}

24
internal/utils/array.go Normal file
View File

@ -0,0 +1,24 @@
package utils
import (
"reflect"
)
func ArrayIn(val interface{}, array interface{}) (exists bool, index int) {
exists = false
index = -1
switch reflect.TypeOf(array).Kind() {
case reflect.Slice:
s := reflect.ValueOf(array)
for i := 0; i < s.Len(); i++ {
if reflect.DeepEqual(val, s.Index(i).Interface()) == true {
index = i
exists = true
return
}
}
}
return
}

34
internal/utils/color.go Normal file
View File

@ -0,0 +1,34 @@
package utils
import (
"fmt"
"regexp"
)
const (
char = "&"
)
// Colors: http://www.lihaoyi.com/post/BuildyourownCommandLinewithANSIescapecodes.html
var re = regexp.MustCompile(char + `(?m)([0-9]{1,2};[0-9]{1,2}|[0-9]{1,2})`)
func Color(str string) string {
result := ""
lastIndex := 0
for _, v := range re.FindAllSubmatchIndex([]byte(str), -1) {
groups := []string{}
for i := 0; i < len(v); i += 2 {
groups = append(groups, str[v[i]:v[i+1]])
}
result += str[lastIndex:v[0]] + "\033[" + groups[1] + "m"
lastIndex = v[1]
}
return result + str[lastIndex:]
}
func Colorf(format string, a ...interface{}) string {
return fmt.Sprintf(Color(format), a...)
}

35
internal/utils/ip.go Normal file
View File

@ -0,0 +1,35 @@
package utils
import (
"bytes"
"io/ioutil"
"net/http"
)
// dig @resolver1.opendns.com ANY myip.opendns.com +short -4
func GetIP() (string, error) {
rsp, err := http.Get("http://checkip.amazonaws.com")
if err != nil {
return "", err
}
defer rsp.Body.Close()
buf, err := ioutil.ReadAll(rsp.Body)
if err != nil {
return "", err
}
return string(bytes.TrimSpace(buf)), nil
}
func ReadUserIP(r *http.Request) string {
IPAddress := r.Header.Get("X-Real-Ip")
if IPAddress == "" {
IPAddress = r.Header.Get("X-Forwarded-For")
}
if IPAddress == "" {
IPAddress = r.RemoteAddr
}
return IPAddress
}

10
internal/utils/json.go Normal file
View File

@ -0,0 +1,10 @@
package utils
import "encoding/json"
func Unmarshal(in interface{}, raw []byte, callback func() error) error {
if err := json.Unmarshal(raw, &in); err != nil {
return err
}
return callback()
}

98
internal/utils/uid.go Normal file
View File

@ -0,0 +1,98 @@
package utils
import (
"crypto/rand"
"fmt"
"math"
)
const (
defaultAlphabet = "_-0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" // len=64
defaultSize = 21
defaultMaskSize = 5
)
// Generator function
type Generator func([]byte) (int, error)
// BytesGenerator is the default bytes generator
var BytesGenerator Generator = rand.Read
func initMasks(params ...int) []uint {
var size int
if len(params) == 0 {
size = defaultMaskSize
} else {
size = params[0]
}
masks := make([]uint, size)
for i := 0; i < size; i++ {
shift := 3 + i
masks[i] = (2 << uint(shift)) - 1
}
return masks
}
func getMask(alphabet string, masks []uint) int {
for i := 0; i < len(masks); i++ {
curr := int(masks[i])
if curr >= len(alphabet)-1 {
return curr
}
}
return 0
}
// GenerateUID is a low-level function to change alphabet and ID size.
func GenerateUID(alphabet string, size int) (string, error) {
if len(alphabet) == 0 || len(alphabet) > 255 {
return "", fmt.Errorf("alphabet must not empty and contain no more than 255 chars. Current len is %d", len(alphabet))
}
if size <= 0 {
return "", fmt.Errorf("size must be positive integer")
}
masks := initMasks(size)
mask := getMask(alphabet, masks)
ceilArg := 1.6 * float64(mask*size) / float64(len(alphabet))
step := int(math.Ceil(ceilArg))
id := make([]byte, size)
bytes := make([]byte, step)
for j := 0; ; {
_, err := BytesGenerator(bytes)
if err != nil {
return "", err
}
for i := 0; i < step; i++ {
currByte := bytes[i] & byte(mask)
if currByte < byte(len(alphabet)) {
id[j] = alphabet[currByte]
j++
if j == size {
return string(id[:size]), nil
}
}
}
}
}
// NewUID generates secure URL-friendly unique ID.
func NewUID(param ...int) (string, error) {
var size int
if len(param) == 0 {
size = defaultSize
} else {
size = param[0]
}
bytes := make([]byte, size)
_, err := BytesGenerator(bytes)
if err != nil {
return "", err
}
id := make([]byte, size)
for i := 0; i < size; i++ {
id[i] = defaultAlphabet[bytes[i]&63]
}
return string(id[:size]), nil
}

137
internal/webrtc/handle.go Normal file
View File

@ -0,0 +1,137 @@
package webrtc
import (
"bytes"
"encoding/binary"
"strconv"
"github.com/pion/webrtc/v2"
)
const OP_MOVE = 0x01
const OP_SCROLL = 0x02
const OP_KEY_DOWN = 0x03
const OP_KEY_UP = 0x04
const OP_KEY_CLK = 0x05
type PayloadHeader struct {
Event uint8
Length uint16
}
type PayloadMove struct {
PayloadHeader
X uint16
Y uint16
}
type PayloadScroll struct {
PayloadHeader
X int16
Y int16
}
type PayloadKey struct {
PayloadHeader
Key uint64
}
func (manager *WebRTCManager) handle(id string, msg webrtc.DataChannelMessage) error {
if !manager.sessions.IsHost(id) {
return nil
}
buffer := bytes.NewBuffer(msg.Data)
header := &PayloadHeader{}
hbytes := make([]byte, 3)
if _, err := buffer.Read(hbytes); err != nil {
return err
}
if err := binary.Read(bytes.NewBuffer(hbytes), binary.LittleEndian, header); err != nil {
return err
}
buffer = bytes.NewBuffer(msg.Data)
switch header.Event {
case OP_MOVE:
payload := &PayloadMove{}
if err := binary.Read(buffer, binary.LittleEndian, payload); err != nil {
return err
}
manager.remote.Move(int(payload.X), int(payload.Y))
break
case OP_SCROLL:
payload := &PayloadScroll{}
if err := binary.Read(buffer, binary.LittleEndian, payload); err != nil {
return err
}
manager.logger.
Debug().
Str("x", strconv.Itoa(int(payload.X))).
Str("y", strconv.Itoa(int(payload.Y))).
Msg("scroll")
manager.remote.Scroll(int(payload.X), int(payload.Y))
break
case OP_KEY_DOWN:
payload := &PayloadKey{}
if err := binary.Read(buffer, binary.LittleEndian, payload); err != nil {
return err
}
if payload.Key < 8 {
err := manager.remote.ButtonDown(int(payload.Key))
if err != nil {
manager.logger.Warn().Err(err).Msg("button down failed")
return nil
}
manager.logger.Debug().Msgf("button down %d", payload.Key)
} else {
err := manager.remote.KeyDown(uint64(payload.Key))
if err != nil {
manager.logger.Warn().Err(err).Msg("key down failed")
return nil
}
manager.logger.Debug().Msgf("key down %d", payload.Key)
}
break
case OP_KEY_UP:
payload := &PayloadKey{}
err := binary.Read(buffer, binary.LittleEndian, payload)
if err != nil {
return err
}
if payload.Key < 8 {
err := manager.remote.ButtonUp(int(payload.Key))
if err != nil {
manager.logger.Warn().Err(err).Msg("button up failed")
return nil
}
manager.logger.Debug().Msgf("button up %d", payload.Key)
} else {
err := manager.remote.KeyUp(uint64(payload.Key))
if err != nil {
manager.logger.Warn().Err(err).Msg("key up failed")
return nil
}
manager.logger.Debug().Msgf("key up %d", payload.Key)
}
break
case OP_KEY_CLK:
// unused
break
}
return nil
}

66
internal/webrtc/logger.go Normal file
View File

@ -0,0 +1,66 @@
package webrtc
import (
"fmt"
"strings"
"github.com/pion/logging"
"github.com/rs/zerolog"
)
type nulllog struct{}
func (l nulllog) Trace(msg string) {}
func (l nulllog) Tracef(format string, args ...interface{}) {}
func (l nulllog) Debug(msg string) {}
func (l nulllog) Debugf(format string, args ...interface{}) {}
func (l nulllog) Info(msg string) {}
func (l nulllog) Infof(format string, args ...interface{}) {}
func (l nulllog) Warn(msg string) {}
func (l nulllog) Warnf(format string, args ...interface{}) {}
func (l nulllog) Error(msg string) {}
func (l nulllog) Errorf(format string, args ...interface{}) {}
type logger struct {
logger zerolog.Logger
subsystem string
}
func (l logger) Trace(msg string) { l.logger.Trace().Msg(msg) }
func (l logger) Tracef(format string, args ...interface{}) { l.logger.Trace().Msgf(format, args...) }
func (l logger) Debug(msg string) { l.logger.Debug().Msg(msg) }
func (l logger) Debugf(format string, args ...interface{}) { l.logger.Debug().Msgf(format, args...) }
func (l logger) Info(msg string) {
if strings.Contains(msg, "packetio.Buffer is full") {
//l.logger.Panic().Msg(msg)
return
}
l.logger.Info().Msg(msg)
}
func (l logger) Infof(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
if strings.Contains(msg, "packetio.Buffer is full") {
// l.logger.Panic().Msg(msg)
return
}
l.logger.Info().Msg(msg)
}
func (l logger) Warn(msg string) { l.logger.Warn().Msg(msg) }
func (l logger) Warnf(format string, args ...interface{}) { l.logger.Warn().Msgf(format, args...) }
func (l logger) Error(msg string) { l.logger.Error().Msg(msg) }
func (l logger) Errorf(format string, args ...interface{}) { l.logger.Error().Msgf(format, args...) }
type loggerFactory struct {
logger zerolog.Logger
}
func (l loggerFactory) NewLogger(subsystem string) logging.LeveledLogger {
if subsystem == "sctp" {
return nulllog{}
}
return logger{
subsystem: subsystem,
logger: l.logger.With().Str("subsystem", subsystem).Logger(),
}
}

38
internal/webrtc/peer.go Normal file
View File

@ -0,0 +1,38 @@
package webrtc
import (
"sync"
"github.com/pion/webrtc/v2"
)
type Peer struct {
id string
api *webrtc.API
engine *webrtc.MediaEngine
manager *WebRTCManager
settings *webrtc.SettingEngine
connection *webrtc.PeerConnection
configuration *webrtc.Configuration
mu sync.Mutex
}
func (peer *Peer) SignalAnswer(sdp string) error {
return peer.connection.SetRemoteDescription(webrtc.SessionDescription{SDP: sdp, Type: webrtc.SDPTypeAnswer})
}
func (peer *Peer) WriteData(v interface{}) error {
peer.mu.Lock()
defer peer.mu.Unlock()
return nil
}
func (peer *Peer) Destroy() error {
if peer.connection != nil && peer.connection.ConnectionState() == webrtc.PeerConnectionStateConnected {
if err := peer.connection.Close(); err != nil {
return err
}
}
return nil
}

201
internal/webrtc/webrtc.go Normal file
View File

@ -0,0 +1,201 @@
package webrtc
import (
"fmt"
"io"
"math/rand"
"strings"
"github.com/pion/webrtc/v2"
"github.com/pion/webrtc/v2/pkg/media"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/types/config"
)
func New(sessions types.SessionManager, remote types.RemoteManager, config *config.WebRTC) *WebRTCManager {
return &WebRTCManager{
logger: log.With().Str("module", "webrtc").Logger(),
remote: remote,
sessions: sessions,
config: config,
}
}
type WebRTCManager struct {
logger zerolog.Logger
videoTrack *webrtc.Track
audioTrack *webrtc.Track
videoCodec *webrtc.RTPCodec
audioCodec *webrtc.RTPCodec
sessions types.SessionManager
remote types.RemoteManager
config *config.WebRTC
}
func (manager *WebRTCManager) Start() {
var err error
manager.audioTrack, manager.audioCodec, err = manager.createTrack(manager.remote.AudioCodec())
if err != nil {
manager.logger.Panic().Err(err).Msg("unable to create audio track")
}
manager.remote.OnAudioFrame(func(sample types.Sample) {
if err := manager.audioTrack.WriteSample(media.Sample(sample)); err != nil && err != io.ErrClosedPipe {
manager.logger.Warn().Err(err).Msg("audio pipeline failed to write")
}
})
manager.videoTrack, manager.videoCodec, err = manager.createTrack(manager.remote.VideoCodec())
if err != nil {
manager.logger.Panic().Err(err).Msg("unable to create video track")
}
manager.remote.OnVideoFrame(func(sample types.Sample) {
if err := manager.videoTrack.WriteSample(media.Sample(sample)); err != nil && err != io.ErrClosedPipe {
manager.logger.Warn().Err(err).Msg("video pipeline failed to write")
}
})
manager.logger.Info().
Str("ice_lite", fmt.Sprintf("%t", manager.config.ICELite)).
Str("ice_servers", strings.Join(manager.config.ICEServers, ",")).
Str("ephemeral_port_range", fmt.Sprintf("%d-%d", manager.config.EphemeralMin, manager.config.EphemeralMax)).
Str("nat_ips", strings.Join(manager.config.NAT1To1IPs, ",")).
Msgf("webrtc starting")
}
func (manager *WebRTCManager) Shutdown() error {
manager.logger.Info().Msgf("webrtc shutting down")
return nil
}
func (manager *WebRTCManager) CreatePeer(id string, session types.Session) (string, bool, []string, error) {
configuration := &webrtc.Configuration{
ICEServers: []webrtc.ICEServer{
{
URLs: manager.config.ICEServers,
},
},
SDPSemantics: webrtc.SDPSemanticsUnifiedPlanWithFallback,
}
settings := webrtc.SettingEngine{
LoggerFactory: loggerFactory{
logger: manager.logger,
},
}
if manager.config.ICELite {
configuration = &webrtc.Configuration{
SDPSemantics: webrtc.SDPSemanticsUnifiedPlanWithFallback,
}
settings.SetLite(true)
}
settings.SetEphemeralUDPPortRange(manager.config.EphemeralMin, manager.config.EphemeralMax)
settings.SetNAT1To1IPs(manager.config.NAT1To1IPs, webrtc.ICECandidateTypeHost)
// Create MediaEngine based off sdp
engine := webrtc.MediaEngine{}
engine.RegisterCodec(manager.audioCodec)
engine.RegisterCodec(manager.videoCodec)
// Create API with MediaEngine and SettingEngine
api := webrtc.NewAPI(webrtc.WithMediaEngine(engine), webrtc.WithSettingEngine(settings))
// Create new peer connection
connection, err := api.NewPeerConnection(*configuration)
if err != nil {
return "", manager.config.ICELite, manager.config.ICEServers, err
}
if _, err = connection.AddTransceiverFromTrack(manager.videoTrack, webrtc.RtpTransceiverInit{
Direction: webrtc.RTPTransceiverDirectionSendonly,
}); err != nil {
return "", manager.config.ICELite, manager.config.ICEServers, err
}
if _, err = connection.AddTransceiverFromTrack(manager.audioTrack, webrtc.RtpTransceiverInit{
Direction: webrtc.RTPTransceiverDirectionSendonly,
}); err != nil {
return "", manager.config.ICELite, manager.config.ICEServers, err
}
description, err := connection.CreateOffer(nil)
if err != nil {
return "", manager.config.ICELite, manager.config.ICEServers, err
}
connection.OnDataChannel(func(d *webrtc.DataChannel) {
d.OnMessage(func(msg webrtc.DataChannelMessage) {
if err = manager.handle(id, msg); err != nil {
manager.logger.Warn().Err(err).Msg("data handle failed")
}
})
})
connection.SetLocalDescription(description)
connection.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
switch state {
case webrtc.PeerConnectionStateDisconnected:
case webrtc.PeerConnectionStateFailed:
manager.logger.Info().Str("id", id).Msg("peer disconnected")
manager.sessions.Destroy(id)
break
case webrtc.PeerConnectionStateConnected:
manager.logger.Info().Str("id", id).Msg("peer connected")
if err = session.SetConnected(true); err != nil {
manager.logger.Warn().Err(err).Msg("unable to set connected on peer")
manager.sessions.Destroy(id)
}
break
}
})
if err := session.SetPeer(&Peer{
id: id,
api: api,
engine: &engine,
manager: manager,
settings: &settings,
connection: connection,
configuration: configuration,
}); err != nil {
return "", manager.config.ICELite, manager.config.ICEServers, err
}
return description.SDP, manager.config.ICELite, manager.config.ICEServers, nil
}
func (m *WebRTCManager) createTrack(codecName string) (*webrtc.Track, *webrtc.RTPCodec, error) {
var codec *webrtc.RTPCodec
switch codecName {
case webrtc.VP8:
codec = webrtc.NewRTPVP8Codec(webrtc.DefaultPayloadTypeVP8, 90000)
case webrtc.VP9:
codec = webrtc.NewRTPVP9Codec(webrtc.DefaultPayloadTypeVP9, 90000)
case webrtc.H264:
codec = webrtc.NewRTPH264Codec(webrtc.DefaultPayloadTypeH264, 90000)
case webrtc.Opus:
codec = webrtc.NewRTPOpusCodec(webrtc.DefaultPayloadTypeOpus, 48000)
case webrtc.G722:
codec = webrtc.NewRTPG722Codec(webrtc.DefaultPayloadTypeG722, 8000)
case webrtc.PCMU:
codec = webrtc.NewRTPPCMUCodec(webrtc.DefaultPayloadTypePCMU, 8000)
case webrtc.PCMA:
codec = webrtc.NewRTPPCMACodec(webrtc.DefaultPayloadTypePCMA, 8000)
default:
return nil, nil, fmt.Errorf("unknown codec %s", codecName)
}
track, err := webrtc.NewTrack(codec.PayloadType, rand.Uint32(), "stream", "stream", codec)
if err != nil {
return nil, nil, err
}
return track, codec, nil
}

298
internal/websocket/admin.go Normal file
View File

@ -0,0 +1,298 @@
package websocket
import (
"strings"
"n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/types/event"
"n.eko.moe/neko/internal/types/message"
)
func (h *MessageHandler) adminLock(id string, session types.Session) error {
if !session.Admin() {
h.logger.Debug().Msg("user not admin")
return nil
}
if h.locked {
h.logger.Debug().Msg("server already locked...")
return nil
}
h.locked = true
if err := h.sessions.Broadcast(
message.Admin{
Event: event.ADMIN_LOCK,
ID: id,
}, nil); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.ADMIN_LOCK)
return err
}
return nil
}
func (h *MessageHandler) adminUnlock(id string, session types.Session) error {
if !session.Admin() {
h.logger.Debug().Msg("user not admin")
return nil
}
if !h.locked {
h.logger.Debug().Msg("server not locked...")
return nil
}
h.locked = false
if err := h.sessions.Broadcast(
message.Admin{
Event: event.ADMIN_UNLOCK,
ID: id,
}, nil); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.ADMIN_UNLOCK)
return err
}
return nil
}
func (h *MessageHandler) adminControl(id string, session types.Session) error {
if !session.Admin() {
h.logger.Debug().Msg("user not admin")
return nil
}
host, ok := h.sessions.GetHost()
h.sessions.SetHost(id)
if ok {
if err := h.sessions.Broadcast(
message.AdminTarget{
Event: event.ADMIN_CONTROL,
ID: id,
Target: host.ID(),
}, nil); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.ADMIN_CONTROL)
return err
}
} else {
if err := h.sessions.Broadcast(
message.Admin{
Event: event.ADMIN_CONTROL,
ID: id,
}, nil); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.ADMIN_CONTROL)
return err
}
}
return nil
}
func (h *MessageHandler) adminRelease(id string, session types.Session) error {
if !session.Admin() {
h.logger.Debug().Msg("user not admin")
return nil
}
host, ok := h.sessions.GetHost()
h.sessions.ClearHost()
if ok {
if err := h.sessions.Broadcast(
message.AdminTarget{
Event: event.ADMIN_RELEASE,
ID: id,
Target: host.ID(),
}, nil); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.ADMIN_RELEASE)
return err
}
} else {
if err := h.sessions.Broadcast(
message.Admin{
Event: event.ADMIN_RELEASE,
ID: id,
}, nil); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.ADMIN_RELEASE)
return err
}
}
return nil
}
func (h *MessageHandler) adminGive(id string, session types.Session, payload *message.Admin) error {
if !session.Admin() {
h.logger.Debug().Msg("user not admin")
return nil
}
if !h.sessions.Has(payload.ID) {
h.logger.Debug().Str("id", payload.ID).Msg("user does not exist")
return nil
}
// set host
h.sessions.SetHost(payload.ID)
// let everyone know
if err := h.sessions.Broadcast(
message.AdminTarget{
Event: event.CONTROL_GIVE,
ID: id,
Target: payload.ID,
}, nil); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.CONTROL_LOCKED)
return err
}
return nil
}
func (h *MessageHandler) adminMute(id string, session types.Session, payload *message.Admin) error {
if !session.Admin() {
h.logger.Debug().Msg("user not admin")
return nil
}
target, ok := h.sessions.Get(payload.ID)
if !ok {
h.logger.Debug().Str("id", payload.ID).Msg("can't find session id")
return nil
}
if target.Admin() {
h.logger.Debug().Msg("target is an admin, baling")
return nil
}
target.SetMuted(true)
if err := h.sessions.Broadcast(
message.AdminTarget{
Event: event.ADMIN_MUTE,
Target: target.ID(),
ID: id,
}, nil); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.ADMIN_UNMUTE)
return err
}
return nil
}
func (h *MessageHandler) adminUnmute(id string, session types.Session, payload *message.Admin) error {
if !session.Admin() {
h.logger.Debug().Msg("user not admin")
return nil
}
target, ok := h.sessions.Get(payload.ID)
if !ok {
h.logger.Debug().Str("id", payload.ID).Msg("can't find target session")
return nil
}
target.SetMuted(false)
if err := h.sessions.Broadcast(
message.AdminTarget{
Event: event.ADMIN_UNMUTE,
Target: target.ID(),
ID: id,
}, nil); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.ADMIN_UNMUTE)
return err
}
return nil
}
func (h *MessageHandler) adminKick(id string, session types.Session, payload *message.Admin) error {
if !session.Admin() {
h.logger.Debug().Msg("user not admin")
return nil
}
target, ok := h.sessions.Get(payload.ID)
if !ok {
h.logger.Debug().Str("id", payload.ID).Msg("can't find session id")
return nil
}
if target.Admin() {
h.logger.Debug().Msg("target is an admin, baling")
return nil
}
if err := target.Kick("kicked"); err != nil {
return err
}
if err := h.sessions.Broadcast(
message.AdminTarget{
Event: event.ADMIN_KICK,
Target: target.ID(),
ID: id,
}, []string{payload.ID}); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.ADMIN_KICK)
return err
}
return nil
}
func (h *MessageHandler) adminBan(id string, session types.Session, payload *message.Admin) error {
if !session.Admin() {
h.logger.Debug().Msg("user not admin")
return nil
}
target, ok := h.sessions.Get(payload.ID)
if !ok {
h.logger.Debug().Str("id", payload.ID).Msg("can't find session id")
return nil
}
if target.Admin() {
h.logger.Debug().Msg("target is an admin, baling")
return nil
}
remote := target.Address()
if remote == "" {
h.logger.Debug().Msg("no remote address, baling")
return nil
}
address := strings.SplitN(remote, ":", -1)
if len(address[0]) < 1 {
h.logger.Debug().Str("address", remote).Msg("no remote address, baling")
return nil
}
h.logger.Debug().Str("address", remote).Msg("adding address to banned")
h.banned[address[0]] = true
if err := target.Kick("banned"); err != nil {
return err
}
if err := h.sessions.Broadcast(
message.AdminTarget{
Event: event.ADMIN_BAN,
Target: target.ID(),
ID: id,
}, []string{payload.ID}); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.ADMIN_BAN)
return err
}
return nil
}

View File

@ -0,0 +1,56 @@
package websocket
import (
"n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/types/event"
"n.eko.moe/neko/internal/types/message"
)
func (h *MessageHandler) boradcastCreate(session types.Session, payload *message.BroadcastCreate) error {
if !session.Admin() {
h.logger.Debug().Msg("user not admin")
return nil
}
h.broadcast.Create(payload.URL)
if err := h.boradcastStatus(session); err != nil {
return err
}
return nil
}
func (h *MessageHandler) boradcastDestroy(session types.Session) error {
if !session.Admin() {
h.logger.Debug().Msg("user not admin")
return nil
}
h.broadcast.Destroy()
if err := h.boradcastStatus(session); err != nil {
return err
}
return nil
}
func (h *MessageHandler) boradcastStatus(session types.Session) error {
if !session.Admin() {
h.logger.Debug().Msg("user not admin")
return nil
}
if err := session.Send(
message.BroadcastStatus{
Event: event.BORADCAST_STATUS,
IsActive: h.broadcast.IsActive(),
URL: h.broadcast.GetUrl(),
}); err != nil {
h.logger.Warn().Err(err).Msgf("sending event %s has failed", event.BORADCAST_STATUS)
return err
}
return nil
}

View File

@ -0,0 +1,41 @@
package websocket
import (
"n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/types/event"
"n.eko.moe/neko/internal/types/message"
)
func (h *MessageHandler) chat(id string, session types.Session, payload *message.ChatReceive) error {
if session.Muted() {
return nil
}
if err := h.sessions.Broadcast(
message.ChatSend{
Event: event.CHAT_MESSAGE,
Content: payload.Content,
ID: id,
}, nil); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.CONTROL_RELEASE)
return err
}
return nil
}
func (h *MessageHandler) chatEmote(id string, session types.Session, payload *message.EmoteReceive) error {
if session.Muted() {
return nil
}
if err := h.sessions.Broadcast(
message.EmoteSend{
Event: event.CHAT_EMOTE,
Emote: payload.Emote,
ID: id,
}, nil); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.CONTROL_RELEASE)
return err
}
return nil
}

View File

@ -0,0 +1,163 @@
package websocket
import (
"n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/types/event"
"n.eko.moe/neko/internal/types/message"
)
func (h *MessageHandler) controlRelease(id string, session types.Session) error {
// check if session is host
if !h.sessions.IsHost(id) {
h.logger.Debug().Str("id", id).Msg("is not the host")
return nil
}
// release host
h.logger.Debug().Str("id", id).Msgf("host called %s", event.CONTROL_RELEASE)
h.sessions.ClearHost()
// tell everyone
if err := h.sessions.Broadcast(
message.Control{
Event: event.CONTROL_RELEASE,
ID: id,
}, nil); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.CONTROL_RELEASE)
return err
}
return nil
}
func (h *MessageHandler) controlRequest(id string, session types.Session) error {
// check for host
if !h.sessions.HasHost() {
// set host
h.sessions.SetHost(id)
// let everyone know
if err := h.sessions.Broadcast(
message.Control{
Event: event.CONTROL_LOCKED,
ID: id,
}, nil); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.CONTROL_LOCKED)
return err
}
return nil
}
// get host
host, ok := h.sessions.GetHost()
if ok {
// tell session there is a host
if err := session.Send(message.Control{
Event: event.CONTROL_REQUEST,
ID: host.ID(),
}); err != nil {
h.logger.Warn().Err(err).Str("id", id).Msgf("sending event %s has failed", event.CONTROL_REQUEST)
return err
}
// tell host session wants to be host
if err := host.Send(message.Control{
Event: event.CONTROL_REQUESTING,
ID: id,
}); err != nil {
h.logger.Warn().Err(err).Str("id", host.ID()).Msgf("sending event %s has failed", event.CONTROL_REQUESTING)
return err
}
}
return nil
}
func (h *MessageHandler) controlGive(id string, session types.Session, payload *message.Control) error {
// check if session is host
if !h.sessions.IsHost(id) {
h.logger.Debug().Str("id", id).Msg("is not the host")
return nil
}
if !h.sessions.Has(payload.ID) {
h.logger.Debug().Str("id", payload.ID).Msg("user does not exist")
return nil
}
// set host
h.sessions.SetHost(payload.ID)
// let everyone know
if err := h.sessions.Broadcast(
message.ControlTarget{
Event: event.CONTROL_GIVE,
ID: id,
Target: payload.ID,
}, nil); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.CONTROL_LOCKED)
return err
}
return nil
}
func (h *MessageHandler) controlClipboard(id string, session types.Session, payload *message.Clipboard) error {
// check if session is host
if !h.sessions.IsHost(id) {
h.logger.Debug().Str("id", id).Msg("is not the host")
return nil
}
h.remote.WriteClipboard(payload.Text)
return nil
}
func (h *MessageHandler) controlKeyboard(id string, session types.Session, payload *message.Keyboard) error {
// check if session is host
if !h.sessions.IsHost(id) {
h.logger.Debug().Str("id", id).Msg("is not the host")
return nil
}
// change layout
if payload.Layout != nil {
h.remote.SetKeyboardLayout(*payload.Layout)
}
// set num lock
var NumLock = 0
if payload.NumLock == nil {
NumLock = -1
} else if *payload.NumLock {
NumLock = 1
}
// set caps lock
var CapsLock = 0
if payload.CapsLock == nil {
CapsLock = -1
} else if *payload.CapsLock {
CapsLock = 1
}
// set scroll lock
var ScrollLock = 0
if payload.ScrollLock == nil {
ScrollLock = -1
} else if *payload.ScrollLock {
ScrollLock = 1
}
h.logger.Debug().
Int("NumLock", NumLock).
Int("CapsLock", CapsLock).
Int("ScrollLock", ScrollLock).
Msg("setting keyboard modifiers")
h.remote.SetKeyboardModifiers(NumLock, CapsLock, ScrollLock)
return nil
}

View File

@ -0,0 +1,179 @@
package websocket
import (
"encoding/json"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/types/event"
"n.eko.moe/neko/internal/types/message"
"n.eko.moe/neko/internal/utils"
)
type MessageHandler struct {
logger zerolog.Logger
sessions types.SessionManager
webrtc types.WebRTCManager
remote types.RemoteManager
broadcast types.BroadcastManager
banned map[string]bool
locked bool
}
func (h *MessageHandler) Connected(id string, socket *WebSocket) (bool, string, error) {
address := socket.Address()
if address == "" {
h.logger.Debug().Msg("no remote address")
} else {
ok, banned := h.banned[address]
if ok && banned {
h.logger.Debug().Str("address", address).Msg("banned")
return false, "banned", nil
}
}
if h.locked {
session, ok := h.sessions.Get(id)
if !ok || !session.Admin() {
h.logger.Debug().Msg("server locked")
return false, "locked", nil
}
}
return true, "", nil
}
func (h *MessageHandler) Disconnected(id string) error {
if h.locked && len(h.sessions.Admins()) == 0 {
h.locked = false
}
return h.sessions.Destroy(id)
}
func (h *MessageHandler) Message(id string, raw []byte) error {
header := message.Message{}
if err := json.Unmarshal(raw, &header); err != nil {
return err
}
session, ok := h.sessions.Get(id)
if !ok {
errors.Errorf("unknown session id %s", id)
}
switch header.Event {
// Signal Events
case event.SIGNAL_ANSWER:
payload := &message.SignalAnswer{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.signalAnswer(id, session, payload)
}), "%s failed", header.Event)
// Control Events
case event.CONTROL_RELEASE:
return errors.Wrapf(h.controlRelease(id, session), "%s failed", header.Event)
case event.CONTROL_REQUEST:
return errors.Wrapf(h.controlRequest(id, session), "%s failed", header.Event)
case event.CONTROL_GIVE:
payload := &message.Control{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.controlGive(id, session, payload)
}), "%s failed", header.Event)
case event.CONTROL_CLIPBOARD:
payload := &message.Clipboard{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.controlClipboard(id, session, payload)
}), "%s failed", header.Event)
case event.CONTROL_KEYBOARD:
payload := &message.Keyboard{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.controlKeyboard(id, session, payload)
}), "%s failed", header.Event)
// Chat Events
case event.CHAT_MESSAGE:
payload := &message.ChatReceive{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.chat(id, session, payload)
}), "%s failed", header.Event)
case event.CHAT_EMOTE:
payload := &message.EmoteReceive{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.chatEmote(id, session, payload)
}), "%s failed", header.Event)
// Screen Events
case event.SCREEN_RESOLUTION:
return errors.Wrapf(h.screenResolution(id, session), "%s failed", header.Event)
case event.SCREEN_CONFIGURATIONS:
return errors.Wrapf(h.screenConfigurations(id, session), "%s failed", header.Event)
case event.SCREEN_SET:
payload := &message.ScreenResolution{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.screenSet(id, session, payload)
}), "%s failed", header.Event)
// Boradcast Events
case event.BORADCAST_CREATE:
payload := &message.BroadcastCreate{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.boradcastCreate(session, payload)
}), "%s failed", header.Event)
case event.BORADCAST_DESTROY:
return errors.Wrapf(h.boradcastDestroy(session), "%s failed", header.Event)
// Admin Events
case event.ADMIN_LOCK:
return errors.Wrapf(h.adminLock(id, session), "%s failed", header.Event)
case event.ADMIN_UNLOCK:
return errors.Wrapf(h.adminUnlock(id, session), "%s failed", header.Event)
case event.ADMIN_CONTROL:
return errors.Wrapf(h.adminControl(id, session), "%s failed", header.Event)
case event.ADMIN_RELEASE:
return errors.Wrapf(h.adminRelease(id, session), "%s failed", header.Event)
case event.ADMIN_GIVE:
payload := &message.Admin{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.adminGive(id, session, payload)
}), "%s failed", header.Event)
case event.ADMIN_BAN:
payload := &message.Admin{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.adminBan(id, session, payload)
}), "%s failed", header.Event)
case event.ADMIN_KICK:
payload := &message.Admin{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.adminKick(id, session, payload)
}), "%s failed", header.Event)
case event.ADMIN_MUTE:
payload := &message.Admin{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.adminMute(id, session, payload)
}), "%s failed", header.Event)
case event.ADMIN_UNMUTE:
payload := &message.Admin{}
return errors.Wrapf(
utils.Unmarshal(payload, raw, func() error {
return h.adminUnmute(id, session, payload)
}), "%s failed", header.Event)
default:
return errors.Errorf("unknown message event %s", header.Event)
}
}

View File

@ -0,0 +1,66 @@
package websocket
import (
"n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/types/event"
"n.eko.moe/neko/internal/types/message"
)
func (h *MessageHandler) screenSet(id string, session types.Session, payload *message.ScreenResolution) error {
if !session.Admin() {
h.logger.Debug().Msg("user not admin")
return nil
}
if err := h.remote.ChangeResolution(payload.Width, payload.Height, payload.Rate); err != nil {
h.logger.Warn().Err(err).Msgf("unable to change screen size")
return err
}
if err := h.sessions.Broadcast(
message.ScreenResolution{
Event: event.SCREEN_RESOLUTION,
ID: id,
Width: payload.Width,
Height: payload.Height,
Rate: payload.Rate,
}, nil); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.SCREEN_RESOLUTION)
return err
}
return nil
}
func (h *MessageHandler) screenResolution(id string, session types.Session) error {
if size := h.remote.GetScreenSize(); size != nil {
if err := session.Send(message.ScreenResolution{
Event: event.SCREEN_RESOLUTION,
Width: size.Width,
Height: size.Height,
Rate: int(size.Rate),
}); err != nil {
h.logger.Warn().Err(err).Msgf("sending event %s has failed", event.SCREEN_RESOLUTION)
return err
}
}
return nil
}
func (h *MessageHandler) screenConfigurations(id string, session types.Session) error {
if !session.Admin() {
h.logger.Debug().Msg("user not admin")
return nil
}
if err := session.Send(message.ScreenConfigurations{
Event: event.SCREEN_CONFIGURATIONS,
Configurations: h.remote.ScreenConfigurations(),
}); err != nil {
h.logger.Warn().Err(err).Msgf("sending event %s has failed", event.SCREEN_CONFIGURATIONS)
return err
}
return nil
}

View File

@ -0,0 +1,93 @@
package websocket
import (
"n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/types/event"
"n.eko.moe/neko/internal/types/message"
)
func (h *MessageHandler) SessionCreated(id string, session types.Session) error {
// send sdp and id over to client
if err := h.signalProvide(id, session); err != nil {
return err
}
if session.Admin() {
// send screen configurations if admin
if err := h.screenConfigurations(id, session); err != nil {
return err
}
// send broadcast status if admin
if err := h.boradcastStatus(session); err != nil {
return err
}
}
return nil
}
func (h *MessageHandler) SessionConnected(id string, session types.Session) error {
// send list of members to session
if err := session.Send(message.MembersList{
Event: event.MEMBER_LIST,
Memebers: h.sessions.Members(),
}); err != nil {
h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.MEMBER_LIST)
return err
}
// send screen current resolution
if err := h.screenResolution(id, session); err != nil {
return err
}
// tell session there is a host
host, ok := h.sessions.GetHost()
if ok {
if err := session.Send(message.Control{
Event: event.CONTROL_LOCKED,
ID: host.ID(),
}); err != nil {
h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.CONTROL_LOCKED)
return err
}
}
// let everyone know there is a new session
if err := h.sessions.Broadcast(
message.Member{
Event: event.MEMBER_CONNECTED,
Member: session.Member(),
}, nil); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.CONTROL_RELEASE)
return err
}
return nil
}
func (h *MessageHandler) SessionDestroyed(id string) error {
// clear host if exists
if h.sessions.IsHost(id) {
h.sessions.ClearHost()
if err := h.sessions.Broadcast(message.Control{
Event: event.CONTROL_RELEASE,
ID: id,
}, nil); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.CONTROL_RELEASE)
}
}
// let everyone know session disconnected
if err := h.sessions.Broadcast(
message.MemberDisconnected{
Event: event.MEMBER_DISCONNECTED,
ID: id,
}, nil); err != nil {
h.logger.Warn().Err(err).Msgf("broadcasting event %s has failed", event.MEMBER_DISCONNECTED)
return err
}
return nil
}

View File

@ -0,0 +1,38 @@
package websocket
import (
"n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/types/event"
"n.eko.moe/neko/internal/types/message"
)
func (h *MessageHandler) signalProvide(id string, session types.Session) error {
sdp, lite, ice, err := h.webrtc.CreatePeer(id, session)
if err != nil {
return err
}
if err := session.Send(message.SignalProvide{
Event: event.SIGNAL_PROVIDE,
ID: id,
SDP: sdp,
Lite: lite,
ICE: ice,
}); err != nil {
return err
}
return nil
}
func (h *MessageHandler) signalAnswer(id string, session types.Session, payload *message.SignalAnswer) error {
if err := session.SetName(payload.DisplayName); err != nil {
return err
}
if err := session.SignalAnswer(payload.SDP); err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,55 @@
package websocket
import (
"encoding/json"
"strings"
"sync"
"github.com/gorilla/websocket"
)
type WebSocket struct {
id string
address string
ws *WebSocketHandler
connection *websocket.Conn
mu sync.Mutex
}
func (socket *WebSocket) Address() string {
//remote := socket.connection.RemoteAddr()
address := strings.SplitN(socket.address, ":", -1)
if len(address[0]) < 1 {
return socket.address
}
return address[0]
}
func (socket *WebSocket) Send(v interface{}) error {
socket.mu.Lock()
defer socket.mu.Unlock()
if socket.connection == nil {
return nil
}
raw, err := json.Marshal(v)
if err != nil {
return err
}
socket.ws.logger.Debug().
Str("session", socket.id).
Str("address", socket.connection.RemoteAddr().String()).
Str("raw", string(raw)).
Msg("sending message to client")
return socket.connection.WriteMessage(websocket.TextMessage, raw)
}
func (socket *WebSocket) Destroy() error {
if socket.connection == nil {
return nil
}
return socket.connection.Close()
}

View File

@ -0,0 +1,268 @@
package websocket
import (
"fmt"
"net/http"
"time"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"n.eko.moe/neko/internal/types"
"n.eko.moe/neko/internal/types/config"
"n.eko.moe/neko/internal/types/event"
"n.eko.moe/neko/internal/types/message"
"n.eko.moe/neko/internal/utils"
)
func New(sessions types.SessionManager, remote types.RemoteManager, broadcast types.BroadcastManager, webrtc types.WebRTCManager, conf *config.WebSocket) *WebSocketHandler {
logger := log.With().Str("module", "websocket").Logger()
return &WebSocketHandler{
logger: logger,
conf: conf,
sessions: sessions,
remote: remote,
upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
},
handler: &MessageHandler{
logger: logger.With().Str("subsystem", "handler").Logger(),
remote: remote,
broadcast: broadcast,
sessions: sessions,
webrtc: webrtc,
banned: make(map[string]bool),
locked: false,
},
}
}
// Send pings to peer with this period. Must be less than pongWait.
const pingPeriod = 60 * time.Second
type WebSocketHandler struct {
logger zerolog.Logger
upgrader websocket.Upgrader
sessions types.SessionManager
remote types.RemoteManager
conf *config.WebSocket
handler *MessageHandler
shutdown chan bool
}
func (ws *WebSocketHandler) Start() error {
ws.sessions.OnCreated(func(id string, session types.Session) {
if err := ws.handler.SessionCreated(id, session); err != nil {
ws.logger.Warn().Str("id", id).Err(err).Msg("session created with and error")
} else {
ws.logger.Debug().Str("id", id).Msg("session created")
}
})
ws.sessions.OnConnected(func(id string, session types.Session) {
if err := ws.handler.SessionConnected(id, session); err != nil {
ws.logger.Warn().Str("id", id).Err(err).Msg("session connected with and error")
} else {
ws.logger.Debug().Str("id", id).Msg("session connected")
}
})
ws.sessions.OnDestroy(func(id string, session types.Session) {
if err := ws.handler.SessionDestroyed(id); err != nil {
ws.logger.Warn().Str("id", id).Err(err).Msg("session destroyed with and error")
} else {
ws.logger.Debug().Str("id", id).Msg("session destroyed")
}
})
go func() {
defer func() {
ws.logger.Info().Msg("shutdown")
}()
current := ws.remote.ReadClipboard()
for {
select {
case <-ws.shutdown:
return
default:
if ws.sessions.HasHost() {
text := ws.remote.ReadClipboard()
if text != current {
session, ok := ws.sessions.GetHost()
if ok {
session.Send(message.Clipboard{
Event: event.CONTROL_CLIPBOARD,
Text: text,
})
}
current = text
}
}
time.Sleep(100 * time.Millisecond)
}
}
}()
return nil
}
func (ws *WebSocketHandler) Shutdown() error {
ws.shutdown <- true
return nil
}
func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) error {
ws.logger.Debug().Msg("attempting to upgrade connection")
connection, err := ws.upgrader.Upgrade(w, r, nil)
if err != nil {
ws.logger.Error().Err(err).Msg("failed to upgrade connection")
return err
}
id, ip, admin, err := ws.authenticate(r)
if err != nil {
ws.logger.Warn().Err(err).Msg("authentication failed")
if err = connection.WriteJSON(message.Disconnect{
Event: event.SYSTEM_DISCONNECT,
Message: "invalid_password",
}); err != nil {
ws.logger.Error().Err(err).Msg("failed to send disconnect")
}
if err = connection.Close(); err != nil {
return err
}
return nil
}
socket := &WebSocket{
id: id,
ws: ws,
address: ip,
connection: connection,
}
ok, reason, err := ws.handler.Connected(id, socket)
if err != nil {
ws.logger.Error().Err(err).Msg("connection failed")
return err
}
if !ok {
if err = connection.WriteJSON(message.Disconnect{
Event: event.SYSTEM_DISCONNECT,
Message: reason,
}); err != nil {
ws.logger.Error().Err(err).Msg("failed to send disconnect")
}
if err = connection.Close(); err != nil {
return err
}
return nil
}
ws.sessions.New(id, admin, socket)
ws.logger.
Debug().
Str("session", id).
Str("address", connection.RemoteAddr().String()).
Msg("new connection created")
defer func() {
ws.logger.
Debug().
Str("session", id).
Str("address", connection.RemoteAddr().String()).
Msg("session ended")
}()
ws.handle(connection, id)
return nil
}
func (ws *WebSocketHandler) authenticate(r *http.Request) (string, string, bool, error) {
ip := r.RemoteAddr
if ws.conf.Proxy {
ip = utils.ReadUserIP(r)
}
id, err := utils.NewUID(32)
if err != nil {
return "", ip, false, err
}
passwords, ok := r.URL.Query()["password"]
if !ok || len(passwords[0]) < 1 {
return "", ip, false, fmt.Errorf("no password provided")
}
if passwords[0] == ws.conf.AdminPassword {
return id, ip, true, nil
}
if passwords[0] == ws.conf.Password {
return id, ip, false, nil
}
return "", ip, false, fmt.Errorf("invalid password: %s", passwords[0])
}
func (ws *WebSocketHandler) handle(connection *websocket.Conn, id string) {
bytes := make(chan []byte)
cancel := make(chan struct{})
ticker := time.NewTicker(pingPeriod)
go func() {
defer func() {
ticker.Stop()
ws.logger.Debug().Str("address", connection.RemoteAddr().String()).Msg("handle socket ending")
ws.handler.Disconnected(id)
}()
for {
_, raw, err := connection.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
ws.logger.Warn().Err(err).Msg("read message error")
} else {
ws.logger.Debug().Err(err).Msg("read message error")
}
close(cancel)
break
}
bytes <- raw
}
}()
for {
select {
case raw := <-bytes:
ws.logger.Debug().
Str("session", id).
Str("address", connection.RemoteAddr().String()).
Str("raw", string(raw)).
Msg("received message from client")
if err := ws.handler.Message(id, raw); err != nil {
ws.logger.Error().Err(err).Msg("message handler has failed")
}
case <-cancel:
return
case _ = <-ticker.C:
if err := connection.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
}
}
}

195
internal/xorg/xorg.c Normal file
View File

@ -0,0 +1,195 @@
#include "xorg.h"
static clipboard_c *CLIPBOARD = NULL;
static Display *DISPLAY = NULL;
static char *NAME = ":0.0";
static int REGISTERED = 0;
static int DIRTY = 0;
Display *getXDisplay(void) {
/* Close the display if displayName has changed */
if (DIRTY) {
XDisplayClose();
DIRTY = 0;
}
if (DISPLAY == NULL) {
/* First try the user set displayName */
DISPLAY = XOpenDisplay(NAME);
/* Then try using environment variable DISPLAY */
if (DISPLAY == NULL) {
DISPLAY = XOpenDisplay(NULL);
}
if (DISPLAY == NULL) {
fputs("Could not open main display\n", stderr);
} else if (!REGISTERED) {
atexit(&XDisplayClose);
REGISTERED = 1;
}
}
return DISPLAY;
}
clipboard_c *getClipboard(void) {
if (CLIPBOARD == NULL) {
CLIPBOARD = clipboard_new(NULL);
}
return CLIPBOARD;
}
void XDisplayClose(void) {
if (DISPLAY != NULL) {
XCloseDisplay(DISPLAY);
DISPLAY = NULL;
}
}
void XDisplaySet(char *input) {
NAME = strdup(input);
DIRTY = 1;
}
void XMove(int x, int y) {
Display *display = getXDisplay();
XWarpPointer(display, None, DefaultRootWindow(display), 0, 0, 0, 0, x, y);
XSync(display, 0);
}
void XScroll(int x, int y) {
int ydir = 4; /* Button 4 is up, 5 is down. */
int xdir = 6;
Display *display = getXDisplay();
if (y < 0) {
ydir = 5;
}
if (x < 0) {
xdir = 7;
}
int xi;
int yi;
for (xi = 0; xi < abs(x); xi++) {
XTestFakeButtonEvent(display, xdir, 1, CurrentTime);
XTestFakeButtonEvent(display, xdir, 0, CurrentTime);
}
for (yi = 0; yi < abs(y); yi++) {
XTestFakeButtonEvent(display, ydir, 1, CurrentTime);
XTestFakeButtonEvent(display, ydir, 0, CurrentTime);
}
XSync(display, 0);
}
void XButton(unsigned int button, int down) {
if (button != 0) {
Display *display = getXDisplay();
XTestFakeButtonEvent(display, button, down, CurrentTime);
XSync(display, 0);
}
}
void XKey(unsigned long key, int down) {
if (key != 0) {
Display *display = getXDisplay();
KeyCode code = XKeysymToKeycode(display, key);
// Map non-existing keysyms to new keycodes
if(code == 0) {
int min, max, numcodes;
XDisplayKeycodes(display, &min, &max);
XGetKeyboardMapping(display, min, max-min, &numcodes);
code = (max-min+1)*numcodes;
KeySym keysym_list[numcodes];
for(int i=0;i<numcodes;i++) keysym_list[i] = key;
XChangeKeyboardMapping(display, code, numcodes, keysym_list, 1);
}
XTestFakeKeyEvent(display, code, down, CurrentTime);
XSync(display, 0);
}
}
void XClipboardSet(char *src) {
clipboard_c *cb = getClipboard();
clipboard_set_text_ex(cb, src, strlen(src), 0);
}
char *XClipboardGet() {
clipboard_c *cb = getClipboard();
return clipboard_text_ex(cb, NULL, 0);
}
void XGetScreenConfigurations() {
Display *display = getXDisplay();
Window root = RootWindow(display, 0);
XRRScreenSize *xrrs;
int num_sizes;
xrrs = XRRSizes(display, 0, &num_sizes);
for(int i = 0; i < num_sizes; i ++) {
short *rates;
int num_rates;
goCreateScreenSize(i, xrrs[i].width, xrrs[i].height, xrrs[i].mwidth, xrrs[i].mheight);
rates = XRRRates(display, 0, i, &num_rates);
for (int j = 0; j < num_rates; j ++) {
goSetScreenRates(i, j, rates[j]);
}
}
}
void XSetScreenConfiguration(int index, short rate) {
Display *display = getXDisplay();
Window root = RootWindow(display, 0);
XRRSetScreenConfigAndRate(display, XRRGetScreenInfo(display, root), root, index, RR_Rotate_0, rate, CurrentTime);
}
int XGetScreenSize() {
Display *display = getXDisplay();
XRRScreenConfiguration *conf = XRRGetScreenInfo(display, RootWindow(display, 0));
Rotation original_rotation;
return XRRConfigCurrentConfiguration(conf, &original_rotation);
}
short XGetScreenRate() {
Display *display = getXDisplay();
XRRScreenConfiguration *conf = XRRGetScreenInfo(display, RootWindow(display, 0));
return XRRConfigCurrentRate(conf);
}
void SetKeyboardLayout(char *layout) {
// TOOD: refactor, use native API.
char cmd[13] = "setxkbmap ";
strncat(cmd, layout, 2);
system(cmd);
}
void SetKeyboardModifiers(int num_lock, int caps_lock, int scroll_lock) {
Display *display = getXDisplay();
if (num_lock != -1) {
XkbLockModifiers(display, XkbUseCoreKbd, 16, num_lock * 16);
}
if (caps_lock != -1) {
XkbLockModifiers(display, XkbUseCoreKbd, 2, caps_lock * 2);
}
if (scroll_lock != -1) {
XKeyboardControl values;
values.led_mode = scroll_lock ? LedModeOn : LedModeOff;
values.led = 3;
XChangeKeyboardControl(display, KBLedMode, &values);
}
XFlush(display);
}

247
internal/xorg/xorg.go Normal file
View File

@ -0,0 +1,247 @@
package xorg
/*
#cgo linux CFLAGS: -I/usr/src -I/usr/local/include/
#cgo linux LDFLAGS: /usr/local/lib/libclipboard.a -L/usr/src -L/usr/local/lib -lX11 -lXtst -lXrandr -lxcb
#include "xorg.h"
*/
import "C"
import (
"fmt"
"sync"
"time"
"unsafe"
"regexp"
"n.eko.moe/neko/internal/types"
)
var ScreenConfigurations = make(map[int]types.ScreenConfiguration)
var debounce_button = make(map[int]time.Time)
var debounce_key = make(map[uint64]time.Time)
var mu = sync.Mutex{}
func init() {
C.XGetScreenConfigurations()
}
func Display(display string) {
mu.Lock()
defer mu.Unlock()
displayUnsafe := C.CString(display)
defer C.free(unsafe.Pointer(displayUnsafe))
C.XDisplaySet(displayUnsafe)
}
func Move(x, y int) {
mu.Lock()
defer mu.Unlock()
C.XMove(C.int(x), C.int(y))
}
func Scroll(x, y int) {
mu.Lock()
defer mu.Unlock()
C.XScroll(C.int(x), C.int(y))
}
func ButtonDown(code int) error {
mu.Lock()
defer mu.Unlock()
if _, ok := debounce_button[code]; ok {
return fmt.Errorf("debounced button %v", code)
}
debounce_button[code] = time.Now()
C.XButton(C.uint(code), C.int(1))
return nil
}
func KeyDown(code uint64) error {
mu.Lock()
defer mu.Unlock()
if _, ok := debounce_key[code]; ok {
return fmt.Errorf("debounced key %v", code)
}
debounce_key[code] = time.Now()
C.XKey(C.ulong(code), C.int(1))
return nil
}
func ButtonUp(code int) error {
mu.Lock()
defer mu.Unlock()
if _, ok := debounce_button[code]; !ok {
return fmt.Errorf("debounced button %v", code)
}
delete(debounce_button, code)
C.XButton(C.uint(code), C.int(0))
return nil
}
func KeyUp(code uint64) error {
mu.Lock()
defer mu.Unlock()
if _, ok := debounce_key[code]; !ok {
return fmt.Errorf("debounced key %v", code)
}
delete(debounce_key, code)
C.XKey(C.ulong(code), C.int(0))
return nil
}
func ReadClipboard() string {
mu.Lock()
defer mu.Unlock()
clipboardUnsafe := C.XClipboardGet()
defer C.free(unsafe.Pointer(clipboardUnsafe))
return C.GoString(clipboardUnsafe)
}
func WriteClipboard(data string) {
mu.Lock()
defer mu.Unlock()
clipboardUnsafe := C.CString(data)
defer C.free(unsafe.Pointer(clipboardUnsafe))
C.XClipboardSet(clipboardUnsafe)
}
func ResetKeys() {
for code := range debounce_button {
ButtonUp(code)
delete(debounce_button, code)
}
for code := range debounce_key {
KeyUp(code)
delete(debounce_key, code)
}
}
func CheckKeys(duration time.Duration) {
t := time.Now()
for code, start := range debounce_button {
if t.Sub(start) < duration {
continue
}
ButtonUp(code)
delete(debounce_button, code)
}
for code, start := range debounce_key {
if t.Sub(start) < duration {
continue
}
KeyUp(code)
delete(debounce_key, code)
}
}
func ValidScreenSize(width int, height int, rate int) bool {
for _, size := range ScreenConfigurations {
if size.Width == width && size.Height == height {
for _, fps := range size.Rates {
if int16(rate) == fps {
return true
}
}
}
}
return false
}
func ChangeScreenSize(width int, height int, rate int) error {
mu.Lock()
defer mu.Unlock()
for index, size := range ScreenConfigurations {
if size.Width == width && size.Height == height {
for _, fps := range size.Rates {
if int16(rate) == fps {
C.XSetScreenConfiguration(C.int(index), C.short(fps))
return nil
}
}
}
}
return fmt.Errorf("unknown configuration")
}
func GetScreenSize() *types.ScreenSize {
mu.Lock()
defer mu.Unlock()
index := int(C.XGetScreenSize())
rate := int16(C.XGetScreenRate())
if conf, ok := ScreenConfigurations[index]; ok {
return &types.ScreenSize{
Width: conf.Width,
Height: conf.Height,
Rate: rate,
}
}
return nil
}
func SetKeyboardLayout(layout string) {
mu.Lock()
defer mu.Unlock()
if !regexp.MustCompile(`^[a-zA-Z]+$`).MatchString(layout) {
return
}
layoutUnsafe := C.CString(layout)
defer C.free(unsafe.Pointer(layoutUnsafe))
C.SetKeyboardLayout(layoutUnsafe)
}
func SetKeyboardModifiers(num_lock int, caps_lock int, scroll_lock int) {
mu.Lock()
defer mu.Unlock()
C.SetKeyboardModifiers(C.int(num_lock), C.int(caps_lock), C.int(scroll_lock))
}
//export goCreateScreenSize
func goCreateScreenSize(index C.int, width C.int, height C.int, mwidth C.int, mheight C.int) {
ScreenConfigurations[int(index)] = types.ScreenConfiguration{
Width: int(width),
Height: int(height),
Rates: make(map[int]int16),
}
}
//export goSetScreenRates
func goSetScreenRates(index C.int, rate_index C.int, rate C.short) {
ScreenConfigurations[int(index)].Rates[int(rate_index)] = int16(rate)
}

46
internal/xorg/xorg.h Normal file
View File

@ -0,0 +1,46 @@
#pragma once
#ifndef XDISPLAY_H
#define XDISPLAY_H
#include <X11/Xlib.h>
#include <X11/XKBlib.h>
#include <X11/extensions/Xrandr.h>
#include <X11/extensions/XTest.h>
#include <libclipboard.h>
#include <stdint.h>
#include <stdlib.h>
#include <stdio.h> /* For fputs() */
#include <string.h> /* For strdup() */
extern void goCreateScreenSize(int index, int width, int height, int mwidth, int mheight);
extern void goSetScreenRates(int index, int rate_index, short rate);
/* Returns the main display, closed either on exit or when closeMainDisplay()
* is invoked. This removes a bit of the overhead of calling XOpenDisplay() &
* XCloseDisplay() everytime the main display needs to be used.
*
* Note that this is almost certainly not thread safe. */
Display *getXDisplay(void);
clipboard_c *getClipboard(void);
void XMove(int x, int y);
void XScroll(int x, int y);
void XButton(unsigned int button, int down);
void XKey(unsigned long key, int down);
void XClipboardSet(char *src);
char *XClipboardGet();
void XGetScreenConfigurations();
void XSetScreenConfiguration(int index, short rate);
int XGetScreenSize();
short XGetScreenRate();
void XDisplayClose(void);
void XDisplaySet(char *input);
void SetKeyboardLayout(char *layout);
void SetKeyboardModifiers(int num_lock, int caps_lock, int scroll_lock);
#endif