move server to server directory.

This commit is contained in:
Miroslav Šedivý
2024-06-23 17:48:14 +02:00
parent da45f62ca8
commit 5b98344205
211 changed files with 18 additions and 10 deletions

View File

@ -0,0 +1,84 @@
package members
import (
"encoding/json"
"io"
"net/http"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
type MemberBulkUpdatePayload struct {
IDs []string `json:"ids"`
Profile types.MemberProfile `json:"profile"`
}
func (h *MembersHandler) membersBulkUpdate(w http.ResponseWriter, r *http.Request) error {
bytes, err := io.ReadAll(r.Body)
if err != nil {
return utils.HttpBadRequest("unable to read post body").WithInternalErr(err)
}
header := &MemberBulkUpdatePayload{}
if err := json.Unmarshal(bytes, &header); err != nil {
return utils.HttpBadRequest("unable to unmarshal payload").WithInternalErr(err)
}
for _, memberId := range header.IDs {
// TODO: Bulk select?
profile, err := h.members.Select(memberId)
if err != nil {
return utils.HttpInternalServerError().
WithInternalErr(err).
WithInternalMsg("unable to select member profile").
Msgf("failed to update member %s", memberId)
}
body := &MemberBulkUpdatePayload{
Profile: profile,
}
if err := json.Unmarshal(bytes, &body); err != nil {
return utils.HttpBadRequest().
WithInternalErr(err).
Msgf("unable to unmarshal payload for member %s", memberId)
}
if err := h.members.UpdateProfile(memberId, body.Profile); err != nil {
return utils.HttpInternalServerError().
WithInternalErr(err).
WithInternalMsg("unable to update member profile").
Msgf("failed to update member %s", memberId)
}
}
return utils.HttpSuccess(w)
}
type MemberBulkDeletePayload struct {
IDs []string `json:"ids"`
}
func (h *MembersHandler) membersBulkDelete(w http.ResponseWriter, r *http.Request) error {
bytes, err := io.ReadAll(r.Body)
if err != nil {
return utils.HttpBadRequest("unable to read post body").WithInternalErr(err)
}
data := &MemberBulkDeletePayload{}
if err := json.Unmarshal(bytes, &data); err != nil {
return utils.HttpBadRequest("unable to unmarshal payload").WithInternalErr(err)
}
for _, memberId := range data.IDs {
if err := h.members.Delete(memberId); err != nil {
return utils.HttpInternalServerError().
WithInternalErr(err).
WithInternalMsg("unable to delete member").
Msgf("failed to delete member %s", memberId)
}
}
return utils.HttpSuccess(w)
}

View File

@ -0,0 +1,144 @@
package members
import (
"errors"
"net/http"
"strconv"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
type MemberDataPayload struct {
ID string `json:"id"`
Profile types.MemberProfile `json:"profile"`
}
type MemberCreatePayload struct {
Username string `json:"username"`
Password string `json:"password"`
Profile types.MemberProfile `json:"profile"`
}
type MemberPasswordPayload struct {
Password string `json:"password"`
}
func (h *MembersHandler) membersList(w http.ResponseWriter, r *http.Request) error {
limit, err := strconv.Atoi(r.URL.Query().Get("limit"))
if err != nil {
// TODO: Default zero.
limit = 0
}
offset, err := strconv.Atoi(r.URL.Query().Get("offset"))
if err != nil {
// TODO: Default zero.
offset = 0
}
entries, err := h.members.SelectAll(limit, offset)
if err != nil {
return utils.HttpInternalServerError().WithInternalErr(err)
}
members := []MemberDataPayload{}
for id, profile := range entries {
members = append(members, MemberDataPayload{
ID: id,
Profile: profile,
})
}
return utils.HttpSuccess(w, members)
}
func (h *MembersHandler) membersCreate(w http.ResponseWriter, r *http.Request) error {
data := &MemberCreatePayload{
// default values
Profile: types.MemberProfile{
IsAdmin: false,
CanLogin: true,
CanConnect: true,
CanWatch: true,
CanHost: true,
CanShareMedia: true,
CanAccessClipboard: true,
SendsInactiveCursor: true,
CanSeeInactiveCursors: true,
},
}
if err := utils.HttpJsonRequest(w, r, data); err != nil {
return err
}
if data.Username == "" {
return utils.HttpBadRequest("username cannot be empty")
}
if data.Password == "" {
return utils.HttpBadRequest("password cannot be empty")
}
id, err := h.members.Insert(data.Username, data.Password, data.Profile)
if err != nil {
if errors.Is(err, types.ErrMemberAlreadyExists) {
return utils.HttpUnprocessableEntity("member already exists")
}
return utils.HttpInternalServerError().WithInternalErr(err)
}
return utils.HttpSuccess(w, MemberDataPayload{
ID: id,
Profile: data.Profile,
})
}
func (h *MembersHandler) membersRead(w http.ResponseWriter, r *http.Request) error {
member := GetMember(r)
profile := member.Profile
return utils.HttpSuccess(w, profile)
}
func (h *MembersHandler) membersUpdateProfile(w http.ResponseWriter, r *http.Request) error {
member := GetMember(r)
data := &member.Profile
if err := utils.HttpJsonRequest(w, r, data); err != nil {
return err
}
if err := h.members.UpdateProfile(member.ID, *data); err != nil {
return utils.HttpInternalServerError().WithInternalErr(err)
}
return utils.HttpSuccess(w)
}
func (h *MembersHandler) membersUpdatePassword(w http.ResponseWriter, r *http.Request) error {
member := GetMember(r)
data := &MemberPasswordPayload{}
if err := utils.HttpJsonRequest(w, r, data); err != nil {
return err
}
if err := h.members.UpdatePassword(member.ID, data.Password); err != nil {
return utils.HttpInternalServerError().WithInternalErr(err)
}
return utils.HttpSuccess(w)
}
func (h *MembersHandler) membersDelete(w http.ResponseWriter, r *http.Request) error {
member := GetMember(r)
if err := h.members.Delete(member.ID); err != nil {
return utils.HttpInternalServerError().WithInternalErr(err)
}
return utils.HttpSuccess(w)
}

View File

@ -0,0 +1,83 @@
package members
import (
"context"
"errors"
"net/http"
"github.com/go-chi/chi"
"github.com/demodesk/neko/pkg/auth"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
type key int
const keyMemberCtx key = iota
type MembersHandler struct {
members types.MemberManager
}
func New(
members types.MemberManager,
) *MembersHandler {
// Init
return &MembersHandler{
members: members,
}
}
func (h *MembersHandler) Route(r types.Router) {
r.Get("/", h.membersList)
r.With(auth.AdminsOnly).Group(func(r types.Router) {
r.Post("/", h.membersCreate)
r.With(h.ExtractMember).Route("/{memberId}", func(r types.Router) {
r.Get("/", h.membersRead)
r.Post("/", h.membersUpdateProfile)
r.Post("/password", h.membersUpdatePassword)
r.Delete("/", h.membersDelete)
})
})
}
func (h *MembersHandler) RouteBulk(r types.Router) {
r.With(auth.AdminsOnly).Group(func(r types.Router) {
r.Post("/update", h.membersBulkUpdate)
r.Post("/delete", h.membersBulkDelete)
})
}
type MemberData struct {
ID string
Profile types.MemberProfile
}
func SetMember(r *http.Request, session MemberData) context.Context {
return context.WithValue(r.Context(), keyMemberCtx, session)
}
func GetMember(r *http.Request) MemberData {
return r.Context().Value(keyMemberCtx).(MemberData)
}
func (h *MembersHandler) ExtractMember(w http.ResponseWriter, r *http.Request) (context.Context, error) {
memberId := chi.URLParam(r, "memberId")
profile, err := h.members.Select(memberId)
if err != nil {
if errors.Is(err, types.ErrMemberDoesNotExist) {
return nil, utils.HttpNotFound("member not found")
}
return nil, utils.HttpInternalServerError().WithInternalErr(err)
}
return SetMember(r, MemberData{
ID: memberId,
Profile: profile,
}), nil
}

View File

@ -0,0 +1,70 @@
package room
import (
"net/http"
"github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message"
"github.com/demodesk/neko/pkg/utils"
)
type BroadcastStatusPayload struct {
URL string `json:"url,omitempty"`
IsActive bool `json:"is_active"`
}
func (h *RoomHandler) broadcastStatus(w http.ResponseWriter, r *http.Request) error {
broadcast := h.capture.Broadcast()
return utils.HttpSuccess(w, BroadcastStatusPayload{
IsActive: broadcast.Started(),
URL: broadcast.Url(),
})
}
func (h *RoomHandler) boradcastStart(w http.ResponseWriter, r *http.Request) error {
data := &BroadcastStatusPayload{}
if err := utils.HttpJsonRequest(w, r, data); err != nil {
return err
}
if data.URL == "" {
return utils.HttpBadRequest("missing broadcast URL")
}
broadcast := h.capture.Broadcast()
if broadcast.Started() {
return utils.HttpUnprocessableEntity("server is already broadcasting")
}
if err := broadcast.Start(data.URL); err != nil {
return utils.HttpInternalServerError().WithInternalErr(err)
}
h.sessions.AdminBroadcast(
event.BORADCAST_STATUS,
message.BroadcastStatus{
IsActive: broadcast.Started(),
URL: broadcast.Url(),
})
return utils.HttpSuccess(w)
}
func (h *RoomHandler) boradcastStop(w http.ResponseWriter, r *http.Request) error {
broadcast := h.capture.Broadcast()
if !broadcast.Started() {
return utils.HttpUnprocessableEntity("server is not broadcasting")
}
broadcast.Stop()
h.sessions.AdminBroadcast(
event.BORADCAST_STATUS,
message.BroadcastStatus{
IsActive: broadcast.Started(),
URL: broadcast.Url(),
})
return utils.HttpSuccess(w)
}

View File

@ -0,0 +1,107 @@
package room
import (
// TODO: Unused now.
//"bytes"
//"strings"
"net/http"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
type ClipboardPayload struct {
Text string `json:"text,omitempty"`
HTML string `json:"html,omitempty"`
}
func (h *RoomHandler) clipboardGetText(w http.ResponseWriter, r *http.Request) error {
data, err := h.desktop.ClipboardGetText()
if err != nil {
return utils.HttpInternalServerError().WithInternalErr(err)
}
return utils.HttpSuccess(w, ClipboardPayload{
Text: data.Text,
HTML: data.HTML,
})
}
func (h *RoomHandler) clipboardSetText(w http.ResponseWriter, r *http.Request) error {
data := &ClipboardPayload{}
if err := utils.HttpJsonRequest(w, r, data); err != nil {
return err
}
err := h.desktop.ClipboardSetText(types.ClipboardText{
Text: data.Text,
HTML: data.HTML,
})
if err != nil {
return utils.HttpInternalServerError().WithInternalErr(err)
}
return utils.HttpSuccess(w)
}
func (h *RoomHandler) clipboardGetImage(w http.ResponseWriter, r *http.Request) error {
bytes, err := h.desktop.ClipboardGetBinary("image/png")
if err != nil {
return utils.HttpInternalServerError().WithInternalErr(err)
}
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
w.Header().Set("Content-Type", "image/png")
_, err = w.Write(bytes)
return err
}
/* TODO: Unused now.
func (h *RoomHandler) clipboardSetImage(w http.ResponseWriter, r *http.Request) error {
err := r.ParseMultipartForm(MAX_UPLOAD_SIZE)
if err != nil {
return utils.HttpBadRequest("failed to parse multipart form").WithInternalErr(err)
}
//nolint
defer r.MultipartForm.RemoveAll()
file, header, err := r.FormFile("file")
if err != nil {
return utils.HttpBadRequest("no file received").WithInternalErr(err)
}
defer file.Close()
mime := header.Header.Get("Content-Type")
if !strings.HasPrefix(mime, "image/") {
return utils.HttpBadRequest("file must be image")
}
buffer := new(bytes.Buffer)
_, err = buffer.ReadFrom(file)
if err != nil {
return utils.HttpInternalServerError().WithInternalErr(err).WithInternalMsg("unable to read from uploaded file")
}
err = h.desktop.ClipboardSetBinary("image/png", buffer.Bytes())
if err != nil {
return utils.HttpInternalServerError().WithInternalErr(err).WithInternalMsg("unable set image to clipboard")
}
return utils.HttpSuccess(w)
}
func (h *RoomHandler) clipboardGetTargets(w http.ResponseWriter, r *http.Request) error {
targets, err := h.desktop.ClipboardGetTargets()
if err != nil {
return utils.HttpInternalServerError().WithInternalErr(err)
}
return utils.HttpSuccess(w, targets)
}
*/

View File

@ -0,0 +1,109 @@
package room
import (
"net/http"
"github.com/go-chi/chi"
"github.com/demodesk/neko/pkg/auth"
"github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message"
"github.com/demodesk/neko/pkg/utils"
)
type ControlStatusPayload struct {
HasHost bool `json:"has_host"`
HostId string `json:"host_id,omitempty"`
}
type ControlTargetPayload struct {
ID string `json:"id"`
}
func (h *RoomHandler) controlStatus(w http.ResponseWriter, r *http.Request) error {
host, hasHost := h.sessions.GetHost()
var hostId string
if hasHost {
hostId = host.ID()
}
return utils.HttpSuccess(w, ControlStatusPayload{
HasHost: hasHost,
HostId: hostId,
})
}
func (h *RoomHandler) controlRequest(w http.ResponseWriter, r *http.Request) error {
session, _ := auth.GetSession(r)
host, hasHost := h.sessions.GetHost()
if hasHost {
// TODO: Some throttling mechanism to prevent spamming.
// let host know that someone wants to take control
host.Send(
event.CONTROL_REQUEST,
message.SessionID{
ID: session.ID(),
})
return utils.HttpError(http.StatusAccepted, "control request sent")
}
if h.sessions.Settings().LockedControls && !session.Profile().IsAdmin {
return utils.HttpForbidden("controls are locked")
}
session.SetAsHost()
return utils.HttpSuccess(w)
}
func (h *RoomHandler) controlRelease(w http.ResponseWriter, r *http.Request) error {
session, _ := auth.GetSession(r)
if !session.IsHost() {
return utils.HttpUnprocessableEntity("session is not the host")
}
h.desktop.ResetKeys()
session.ClearHost()
return utils.HttpSuccess(w)
}
func (h *RoomHandler) controlTake(w http.ResponseWriter, r *http.Request) error {
session, _ := auth.GetSession(r)
session.SetAsHost()
return utils.HttpSuccess(w)
}
func (h *RoomHandler) controlGive(w http.ResponseWriter, r *http.Request) error {
session, _ := auth.GetSession(r)
sessionId := chi.URLParam(r, "sessionId")
target, ok := h.sessions.Get(sessionId)
if !ok {
return utils.HttpNotFound("target session was not found")
}
if !target.Profile().CanHost {
return utils.HttpBadRequest("target session is not allowed to host")
}
target.SetAsHostBy(session)
return utils.HttpSuccess(w)
}
func (h *RoomHandler) controlReset(w http.ResponseWriter, r *http.Request) error {
session, _ := auth.GetSession(r)
_, hasHost := h.sessions.GetHost()
if hasHost {
h.desktop.ResetKeys()
session.ClearHost()
}
return utils.HttpSuccess(w)
}

View File

@ -0,0 +1,126 @@
package room
import (
"context"
"net/http"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/pkg/auth"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
type RoomHandler struct {
sessions types.SessionManager
desktop types.DesktopManager
capture types.CaptureManager
privateModeImage []byte
}
func New(
sessions types.SessionManager,
desktop types.DesktopManager,
capture types.CaptureManager,
) *RoomHandler {
h := &RoomHandler{
sessions: sessions,
desktop: desktop,
capture: capture,
}
// generate fallback image for private mode when needed
sessions.OnSettingsChanged(func(session types.Session, new, old types.Settings) {
if old.PrivateMode && !new.PrivateMode {
log.Debug().Msg("clearing private mode fallback image")
h.privateModeImage = nil
return
}
if !old.PrivateMode && new.PrivateMode {
img := h.desktop.GetScreenshotImage()
bytes, err := utils.CreateJPGImage(img, 90)
if err != nil {
log.Err(err).Msg("could not generate private mode fallback image")
return
}
log.Debug().Msg("using private mode fallback image")
h.privateModeImage = bytes
}
})
return h
}
func (h *RoomHandler) Route(r types.Router) {
r.With(auth.AdminsOnly).Route("/settings", func(r types.Router) {
r.Post("/", h.settingsSet)
r.Get("/", h.settingsGet)
})
r.With(auth.AdminsOnly).Route("/broadcast", func(r types.Router) {
r.Get("/", h.broadcastStatus)
r.Post("/start", h.boradcastStart)
r.Post("/stop", h.boradcastStop)
})
r.With(auth.CanAccessClipboardOnly).With(auth.HostsOnly).Route("/clipboard", func(r types.Router) {
r.Get("/", h.clipboardGetText)
r.Post("/", h.clipboardSetText)
r.Get("/image.png", h.clipboardGetImage)
// TODO: Refactor. xclip is failing to set propper target type
// and this content is sent back to client as text in another
// clipboard update. Therefore endpoint is not usable!
//r.Post("/image", h.clipboardSetImage)
// TODO: Refactor. If there would be implemented custom target
// retrieval, this endpoint would be useful.
//r.Get("/targets", h.clipboardGetTargets)
})
r.With(auth.CanHostOnly).Route("/keyboard", func(r types.Router) {
r.Get("/map", h.keyboardMapGet)
r.With(auth.HostsOnly).Post("/map", h.keyboardMapSet)
r.Get("/modifiers", h.keyboardModifiersGet)
r.With(auth.HostsOnly).Post("/modifiers", h.keyboardModifiersSet)
})
r.With(auth.CanHostOnly).Route("/control", func(r types.Router) {
r.Get("/", h.controlStatus)
r.Post("/request", h.controlRequest)
r.Post("/release", h.controlRelease)
r.With(auth.AdminsOnly).Post("/take", h.controlTake)
r.With(auth.AdminsOnly).Post("/give/{sessionId}", h.controlGive)
r.With(auth.AdminsOnly).Post("/reset", h.controlReset)
})
r.With(auth.CanWatchOnly).Route("/screen", func(r types.Router) {
r.Get("/", h.screenConfiguration)
r.With(auth.AdminsOnly).Post("/", h.screenConfigurationChange)
r.With(auth.AdminsOnly).Get("/configurations", h.screenConfigurationsList)
r.Get("/cast.jpg", h.screenCastGet)
r.With(auth.AdminsOnly).Get("/shot.jpg", h.screenShotGet)
})
r.With(h.uploadMiddleware).Route("/upload", func(r types.Router) {
r.Post("/drop", h.uploadDrop)
r.Post("/dialog", h.uploadDialogPost)
r.Delete("/dialog", h.uploadDialogClose)
})
}
func (h *RoomHandler) uploadMiddleware(w http.ResponseWriter, r *http.Request) (context.Context, error) {
session, ok := auth.GetSession(r)
if !ok || (!session.IsHost() && (!session.Profile().CanHost || !h.sessions.Settings().ImplicitHosting)) {
return nil, utils.HttpForbidden("without implicit hosting, only host can upload files")
}
return nil, nil
}

View File

@ -0,0 +1,47 @@
package room
import (
"net/http"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
func (h *RoomHandler) keyboardMapSet(w http.ResponseWriter, r *http.Request) error {
keyboardMap := types.KeyboardMap{}
if err := utils.HttpJsonRequest(w, r, &keyboardMap); err != nil {
return err
}
err := h.desktop.SetKeyboardMap(keyboardMap)
if err != nil {
return utils.HttpInternalServerError().WithInternalErr(err)
}
return utils.HttpSuccess(w)
}
func (h *RoomHandler) keyboardMapGet(w http.ResponseWriter, r *http.Request) error {
keyboardMap, err := h.desktop.GetKeyboardMap()
if err != nil {
return utils.HttpInternalServerError().WithInternalErr(err)
}
return utils.HttpSuccess(w, keyboardMap)
}
func (h *RoomHandler) keyboardModifiersSet(w http.ResponseWriter, r *http.Request) error {
keyboardModifiers := types.KeyboardModifiers{}
if err := utils.HttpJsonRequest(w, r, &keyboardModifiers); err != nil {
return err
}
h.desktop.SetKeyboardModifiers(keyboardModifiers)
return utils.HttpSuccess(w)
}
func (h *RoomHandler) keyboardModifiersGet(w http.ResponseWriter, r *http.Request) error {
keyboardModifiers := h.desktop.GetKeyboardModifiers()
return utils.HttpSuccess(w, keyboardModifiers)
}

View File

@ -0,0 +1,101 @@
package room
import (
"net/http"
"strconv"
"github.com/demodesk/neko/pkg/auth"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message"
"github.com/demodesk/neko/pkg/utils"
)
func (h *RoomHandler) screenConfiguration(w http.ResponseWriter, r *http.Request) error {
screenSize := h.desktop.GetScreenSize()
return utils.HttpSuccess(w, screenSize)
}
func (h *RoomHandler) screenConfigurationChange(w http.ResponseWriter, r *http.Request) error {
auth, _ := auth.GetSession(r)
data := &types.ScreenSize{}
if err := utils.HttpJsonRequest(w, r, data); err != nil {
return err
}
size, err := h.desktop.SetScreenSize(types.ScreenSize{
Width: data.Width,
Height: data.Height,
Rate: data.Rate,
})
if err != nil {
return utils.HttpUnprocessableEntity("cannot set screen size").WithInternalErr(err)
}
h.sessions.Broadcast(event.SCREEN_UPDATED, message.ScreenSizeUpdate{
ID: auth.ID(),
ScreenSize: size,
})
return utils.HttpSuccess(w, data)
}
// TODO: remove.
func (h *RoomHandler) screenConfigurationsList(w http.ResponseWriter, r *http.Request) error {
configurations := h.desktop.ScreenConfigurations()
return utils.HttpSuccess(w, configurations)
}
func (h *RoomHandler) screenShotGet(w http.ResponseWriter, r *http.Request) error {
quality, err := strconv.Atoi(r.URL.Query().Get("quality"))
if err != nil {
quality = 90
}
img := h.desktop.GetScreenshotImage()
bytes, err := utils.CreateJPGImage(img, quality)
if err != nil {
return utils.HttpInternalServerError().WithInternalErr(err)
}
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
w.Header().Set("Content-Type", "image/jpeg")
_, err = w.Write(bytes)
return err
}
func (h *RoomHandler) screenCastGet(w http.ResponseWriter, r *http.Request) error {
// display fallback image when private mode is enabled even if screencast is not
if session, ok := auth.GetSession(r); ok && session.PrivateModeEnabled() {
if h.privateModeImage != nil {
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
w.Header().Set("Content-Type", "image/jpeg")
_, err := w.Write(h.privateModeImage)
return err
}
return utils.HttpBadRequest("private mode is enabled but no fallback image available")
}
screencast := h.capture.Screencast()
if !screencast.Enabled() {
return utils.HttpBadRequest("screencast pipeline is not enabled")
}
bytes, err := screencast.Image()
if err != nil {
return utils.HttpInternalServerError().WithInternalErr(err)
}
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
w.Header().Set("Content-Type", "image/jpeg")
_, err = w.Write(bytes)
return err
}

View File

@ -0,0 +1,38 @@
package room
import (
"encoding/json"
"io"
"net/http"
"github.com/demodesk/neko/pkg/auth"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
func (h *RoomHandler) settingsGet(w http.ResponseWriter, r *http.Request) error {
settings := h.sessions.Settings()
return utils.HttpSuccess(w, settings)
}
func (h *RoomHandler) settingsSet(w http.ResponseWriter, r *http.Request) error {
session, _ := auth.GetSession(r)
// We read the request body first and unmashal it inside the UpdateSettingsFunc
// to ensure atomicity of the operation.
body, err := io.ReadAll(r.Body)
if err != nil {
return utils.HttpBadRequest("unable to read request body").WithInternalErr(err)
}
h.sessions.UpdateSettingsFunc(session, func(settings *types.Settings) bool {
err = json.Unmarshal(body, settings)
return err == nil
})
if err != nil {
return utils.HttpBadRequest("unable to parse provided data").WithInternalErr(err)
}
return utils.HttpSuccess(w)
}

View File

@ -0,0 +1,172 @@
package room
import (
"io"
"net/http"
"os"
"path"
"strconv"
"github.com/demodesk/neko/pkg/utils"
)
// TODO: Extract file uploading to custom utility.
// maximum upload size of 32 MB
const maxUploadSize = 32 << 20
func (h *RoomHandler) uploadDrop(w http.ResponseWriter, r *http.Request) error {
if !h.desktop.IsUploadDropEnabled() {
return utils.HttpBadRequest("upload drop is disabled")
}
err := r.ParseMultipartForm(maxUploadSize)
if err != nil {
return utils.HttpBadRequest("failed to parse multipart form").WithInternalErr(err)
}
//nolint
defer r.MultipartForm.RemoveAll()
X, err := strconv.Atoi(r.FormValue("x"))
if err != nil {
return utils.HttpBadRequest("no X coordinate received").WithInternalErr(err)
}
Y, err := strconv.Atoi(r.FormValue("y"))
if err != nil {
return utils.HttpBadRequest("no Y coordinate received").WithInternalErr(err)
}
req_files := r.MultipartForm.File["files"]
if len(req_files) == 0 {
return utils.HttpBadRequest("no files received")
}
dir, err := os.MkdirTemp("", "neko-drop-*")
if err != nil {
return utils.HttpInternalServerError().
WithInternalErr(err).
WithInternalMsg("unable to create temporary directory")
}
files := []string{}
for _, req_file := range req_files {
path := path.Join(dir, req_file.Filename)
srcFile, err := req_file.Open()
if err != nil {
return utils.HttpInternalServerError().
WithInternalErr(err).
WithInternalMsg("unable to open uploaded file")
}
defer srcFile.Close()
dstFile, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return utils.HttpInternalServerError().
WithInternalErr(err).
WithInternalMsg("unable to open destination file")
}
defer dstFile.Close()
_, err = io.Copy(dstFile, srcFile)
if err != nil {
return utils.HttpInternalServerError().
WithInternalErr(err).
WithInternalMsg("unable to copy uploaded file to destination file")
}
files = append(files, path)
}
if !h.desktop.DropFiles(X, Y, files) {
return utils.HttpInternalServerError().
WithInternalMsg("unable to drop files")
}
return utils.HttpSuccess(w)
}
func (h *RoomHandler) uploadDialogPost(w http.ResponseWriter, r *http.Request) error {
if !h.desktop.IsFileChooserDialogEnabled() {
return utils.HttpBadRequest("file chooser dialog is disabled")
}
err := r.ParseMultipartForm(maxUploadSize)
if err != nil {
return utils.HttpBadRequest("failed to parse multipart form").WithInternalErr(err)
}
//nolint
defer r.MultipartForm.RemoveAll()
req_files := r.MultipartForm.File["files"]
if len(req_files) == 0 {
return utils.HttpBadRequest("no files received")
}
if !h.desktop.IsFileChooserDialogOpened() {
return utils.HttpUnprocessableEntity("file chooser dialog is not open")
}
dir, err := os.MkdirTemp("", "neko-dialog-*")
if err != nil {
return utils.HttpInternalServerError().
WithInternalErr(err).
WithInternalMsg("unable to create temporary directory")
}
for _, req_file := range req_files {
path := path.Join(dir, req_file.Filename)
srcFile, err := req_file.Open()
if err != nil {
return utils.HttpInternalServerError().
WithInternalErr(err).
WithInternalMsg("unable to open uploaded file")
}
defer srcFile.Close()
dstFile, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return utils.HttpInternalServerError().
WithInternalErr(err).
WithInternalMsg("unable to open destination file")
}
defer dstFile.Close()
_, err = io.Copy(dstFile, srcFile)
if err != nil {
return utils.HttpInternalServerError().
WithInternalErr(err).
WithInternalMsg("unable to copy uploaded file to destination file")
}
}
if err := h.desktop.HandleFileChooserDialog(dir); err != nil {
return utils.HttpInternalServerError().
WithInternalErr(err).
WithInternalMsg("unable to handle file chooser dialog")
}
return utils.HttpSuccess(w)
}
func (h *RoomHandler) uploadDialogClose(w http.ResponseWriter, r *http.Request) error {
if !h.desktop.IsFileChooserDialogEnabled() {
return utils.HttpBadRequest("file chooser dialog is disabled")
}
if !h.desktop.IsFileChooserDialogOpened() {
return utils.HttpUnprocessableEntity("file chooser dialog is not open")
}
h.desktop.CloseFileChooserDialog()
return utils.HttpSuccess(w)
}

View File

@ -0,0 +1,85 @@
package api
import (
"context"
"errors"
"net/http"
"github.com/demodesk/neko/internal/api/members"
"github.com/demodesk/neko/internal/api/room"
"github.com/demodesk/neko/internal/api/sessions"
"github.com/demodesk/neko/pkg/auth"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
type ApiManagerCtx struct {
sessions types.SessionManager
members types.MemberManager
desktop types.DesktopManager
capture types.CaptureManager
routers map[string]func(types.Router)
}
func New(
sessions types.SessionManager,
members types.MemberManager,
desktop types.DesktopManager,
capture types.CaptureManager,
) *ApiManagerCtx {
return &ApiManagerCtx{
sessions: sessions,
members: members,
desktop: desktop,
capture: capture,
routers: make(map[string]func(types.Router)),
}
}
func (api *ApiManagerCtx) Route(r types.Router) {
r.Post("/login", api.Login)
// Authenticated area
r.Group(func(r types.Router) {
r.Use(api.Authenticate)
r.Post("/logout", api.Logout)
r.Get("/whoami", api.Whoami)
sessionsHandler := sessions.New(api.sessions)
r.Route("/sessions", sessionsHandler.Route)
membersHandler := members.New(api.members)
r.Route("/members", membersHandler.Route)
r.Route("/members_bulk", membersHandler.RouteBulk)
roomHandler := room.New(api.sessions, api.desktop, api.capture)
r.Route("/room", roomHandler.Route)
for path, router := range api.routers {
r.Route(path, router)
}
})
}
func (api *ApiManagerCtx) Authenticate(w http.ResponseWriter, r *http.Request) (context.Context, error) {
session, err := api.sessions.Authenticate(r)
if err != nil {
if api.sessions.CookieEnabled() {
api.sessions.CookieClearToken(w, r)
}
if errors.Is(err, types.ErrSessionLoginDisabled) {
return nil, utils.HttpForbidden("login is disabled for this session")
}
return nil, utils.HttpUnauthorized().WithInternalErr(err)
}
return auth.SetSession(r, session), nil
}
func (api *ApiManagerCtx) AddRouter(path string, router func(types.Router)) {
api.routers[path] = router
}

View File

@ -0,0 +1,85 @@
package api
import (
"errors"
"net/http"
"github.com/demodesk/neko/pkg/auth"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
type SessionLoginPayload struct {
Username string `json:"username"`
Password string `json:"password"`
}
type SessionDataPayload struct {
ID string `json:"id"`
Token string `json:"token,omitempty"`
Profile types.MemberProfile `json:"profile"`
State types.SessionState `json:"state"`
}
func (api *ApiManagerCtx) Login(w http.ResponseWriter, r *http.Request) error {
data := &SessionLoginPayload{}
if err := utils.HttpJsonRequest(w, r, data); err != nil {
return err
}
session, token, err := api.members.Login(data.Username, data.Password)
if err != nil {
if errors.Is(err, types.ErrSessionAlreadyConnected) {
return utils.HttpUnprocessableEntity("session already connected")
} else if errors.Is(err, types.ErrMemberDoesNotExist) || errors.Is(err, types.ErrMemberInvalidPassword) {
return utils.HttpUnauthorized().WithInternalErr(err)
} else if errors.Is(err, types.ErrSessionLoginsLocked) {
return utils.HttpForbidden("logins are locked").WithInternalErr(err)
} else {
return utils.HttpInternalServerError().WithInternalErr(err)
}
}
sessionData := SessionDataPayload{
ID: session.ID(),
Profile: session.Profile(),
State: session.State(),
}
if api.sessions.CookieEnabled() {
api.sessions.CookieSetToken(w, token)
} else {
sessionData.Token = token
}
return utils.HttpSuccess(w, sessionData)
}
func (api *ApiManagerCtx) Logout(w http.ResponseWriter, r *http.Request) error {
session, _ := auth.GetSession(r)
err := api.members.Logout(session.ID())
if err != nil {
if errors.Is(err, types.ErrSessionNotFound) {
return utils.HttpBadRequest("session is not logged in")
} else {
return utils.HttpInternalServerError().WithInternalErr(err)
}
}
if api.sessions.CookieEnabled() {
api.sessions.CookieClearToken(w, r)
}
return utils.HttpSuccess(w, true)
}
func (api *ApiManagerCtx) Whoami(w http.ResponseWriter, r *http.Request) error {
session, _ := auth.GetSession(r)
return utils.HttpSuccess(w, SessionDataPayload{
ID: session.ID(),
Profile: session.Profile(),
State: session.State(),
})
}

View File

@ -0,0 +1,80 @@
package sessions
import (
"errors"
"net/http"
"github.com/demodesk/neko/pkg/auth"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
"github.com/go-chi/chi"
)
type SessionDataPayload struct {
ID string `json:"id"`
Profile types.MemberProfile `json:"profile"`
State types.SessionState `json:"state"`
}
func (h *SessionsHandler) sessionsList(w http.ResponseWriter, r *http.Request) error {
sessions := []SessionDataPayload{}
for _, session := range h.sessions.List() {
sessions = append(sessions, SessionDataPayload{
ID: session.ID(),
Profile: session.Profile(),
State: session.State(),
})
}
return utils.HttpSuccess(w, sessions)
}
func (h *SessionsHandler) sessionsRead(w http.ResponseWriter, r *http.Request) error {
sessionId := chi.URLParam(r, "sessionId")
session, ok := h.sessions.Get(sessionId)
if !ok {
return utils.HttpNotFound("session not found")
}
return utils.HttpSuccess(w, SessionDataPayload{
ID: session.ID(),
Profile: session.Profile(),
State: session.State(),
})
}
func (h *SessionsHandler) sessionsDelete(w http.ResponseWriter, r *http.Request) error {
session, _ := auth.GetSession(r)
sessionId := chi.URLParam(r, "sessionId")
if sessionId == session.ID() {
return utils.HttpBadRequest("cannot delete own session")
}
err := h.sessions.Delete(sessionId)
if err != nil {
if errors.Is(err, types.ErrSessionNotFound) {
return utils.HttpBadRequest("session not found")
} else {
return utils.HttpInternalServerError().WithInternalErr(err)
}
}
return utils.HttpSuccess(w)
}
func (h *SessionsHandler) sessionsDisconnect(w http.ResponseWriter, r *http.Request) error {
sessionId := chi.URLParam(r, "sessionId")
err := h.sessions.Disconnect(sessionId)
if err != nil {
if errors.Is(err, types.ErrSessionNotFound) {
return utils.HttpBadRequest("session not found")
} else {
return utils.HttpInternalServerError().WithInternalErr(err)
}
}
return utils.HttpSuccess(w)
}

View File

@ -0,0 +1,30 @@
package sessions
import (
"github.com/demodesk/neko/pkg/auth"
"github.com/demodesk/neko/pkg/types"
)
type SessionsHandler struct {
sessions types.SessionManager
}
func New(
sessions types.SessionManager,
) *SessionsHandler {
// Init
return &SessionsHandler{
sessions: sessions,
}
}
func (h *SessionsHandler) Route(r types.Router) {
r.Get("/", h.sessionsList)
r.With(auth.AdminsOnly).Route("/{sessionId}", func(r types.Router) {
r.Get("/", h.sessionsRead)
r.Delete("/", h.sessionsDelete)
r.Post("/disconnect", h.sessionsDisconnect)
})
}

View File

@ -0,0 +1,156 @@
package capture
import (
"sync"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/pkg/gst"
"github.com/demodesk/neko/pkg/types"
)
type BroacastManagerCtx struct {
logger zerolog.Logger
mu sync.Mutex
pipeline gst.Pipeline
pipelineMu sync.Mutex
pipelineFn func(url string) (string, error)
url string
started bool
// metrics
pipelinesCounter prometheus.Counter
pipelinesActive prometheus.Gauge
}
func broadcastNew(pipelineFn func(url string) (string, error), defaultUrl string) *BroacastManagerCtx {
logger := log.With().
Str("module", "capture").
Str("submodule", "broadcast").
Logger()
return &BroacastManagerCtx{
logger: logger,
pipelineFn: pipelineFn,
url: defaultUrl,
started: defaultUrl != "",
// metrics
pipelinesCounter: promauto.NewCounter(prometheus.CounterOpts{
Name: "pipelines_total",
Namespace: "neko",
Subsystem: "capture",
Help: "Total number of created pipelines.",
ConstLabels: map[string]string{
"submodule": "broadcast",
"video_id": "main",
"codec_name": "-",
"codec_type": "-",
},
}),
pipelinesActive: promauto.NewGauge(prometheus.GaugeOpts{
Name: "pipelines_active",
Namespace: "neko",
Subsystem: "capture",
Help: "Total number of active pipelines.",
ConstLabels: map[string]string{
"submodule": "broadcast",
"video_id": "main",
"codec_name": "-",
"codec_type": "-",
},
}),
}
}
func (manager *BroacastManagerCtx) shutdown() {
manager.logger.Info().Msgf("shutdown")
manager.destroyPipeline()
}
func (manager *BroacastManagerCtx) Start(url string) error {
manager.mu.Lock()
defer manager.mu.Unlock()
err := manager.createPipeline()
if err != nil {
return err
}
manager.url = url
manager.started = true
return nil
}
func (manager *BroacastManagerCtx) Stop() {
manager.mu.Lock()
defer manager.mu.Unlock()
manager.started = false
manager.destroyPipeline()
}
func (manager *BroacastManagerCtx) Started() bool {
manager.mu.Lock()
defer manager.mu.Unlock()
return manager.started
}
func (manager *BroacastManagerCtx) Url() string {
manager.mu.Lock()
defer manager.mu.Unlock()
return manager.url
}
func (manager *BroacastManagerCtx) createPipeline() error {
manager.pipelineMu.Lock()
defer manager.pipelineMu.Unlock()
if manager.pipeline != nil {
return types.ErrCapturePipelineAlreadyExists
}
pipelineStr, err := manager.pipelineFn(manager.url)
if err != nil {
return err
}
manager.logger.Info().
Str("url", manager.url).
Str("src", pipelineStr).
Msgf("starting pipeline")
manager.pipeline, err = gst.CreatePipeline(pipelineStr)
if err != nil {
return err
}
manager.pipeline.Play()
manager.pipelinesCounter.Inc()
manager.pipelinesActive.Set(1)
return nil
}
func (manager *BroacastManagerCtx) destroyPipeline() {
manager.pipelineMu.Lock()
defer manager.pipelineMu.Unlock()
if manager.pipeline == nil {
return
}
manager.pipeline.Destroy()
manager.logger.Info().Msgf("destroying pipeline")
manager.pipeline = nil
manager.pipelinesActive.Set(0)
}

View File

@ -0,0 +1,269 @@
package capture
import (
"errors"
"fmt"
"strings"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/internal/config"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
)
type CaptureManagerCtx struct {
logger zerolog.Logger
desktop types.DesktopManager
config *config.Capture
// sinks
broadcast *BroacastManagerCtx
screencast *ScreencastManagerCtx
audio *StreamSinkManagerCtx
video *StreamSelectorManagerCtx
// sources
webcam *StreamSrcManagerCtx
microphone *StreamSrcManagerCtx
}
func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCtx {
logger := log.With().Str("module", "capture").Logger()
videos := map[string]types.StreamSinkManager{}
for video_id, cnf := range config.VideoPipelines {
pipelineConf := cnf
createPipeline := func() (string, error) {
if pipelineConf.GstPipeline != "" {
// replace {display} with valid display
return strings.Replace(pipelineConf.GstPipeline, "{display}", config.Display, 1), nil
}
screen := desktop.GetScreenSize()
pipeline, err := pipelineConf.GetPipeline(screen)
if err != nil {
return "", err
}
return fmt.Sprintf(
"ximagesrc display-name=%s show-pointer=false use-damage=false "+
"%s ! appsink name=appsink", config.Display, pipeline,
), nil
}
// trigger function to catch evaluation errors at startup
pipeline, err := createPipeline()
if err != nil {
logger.Panic().Err(err).
Str("video_id", video_id).
Msg("failed to create video pipeline")
}
logger.Info().
Str("video_id", video_id).
Str("pipeline", pipeline).
Msg("syntax check for video stream pipeline passed")
// append to videos
videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id)
}
return &CaptureManagerCtx{
logger: logger,
desktop: desktop,
config: config,
// sinks
broadcast: broadcastNew(func(url string) (string, error) {
if config.BroadcastPipeline != "" {
var pipeline = config.BroadcastPipeline
// replace {display} with valid display
pipeline = strings.Replace(pipeline, "{display}", config.Display, 1)
// replace {device} with valid device
pipeline = strings.Replace(pipeline, "{device}", config.AudioDevice, 1)
// replace {url} with valid URL
return strings.Replace(pipeline, "{url}", url, 1), nil
}
return fmt.Sprintf(
"flvmux name=mux ! rtmpsink location='%s live=1' "+
"pulsesrc device=%s "+
"! audio/x-raw,channels=2 "+
"! audioconvert "+
"! queue "+
"! voaacenc bitrate=%d "+
"! mux. "+
"ximagesrc display-name=%s show-pointer=true use-damage=false "+
"! video/x-raw "+
"! videoconvert "+
"! queue "+
"! x264enc threads=4 bitrate=%d key-int-max=15 byte-stream=true tune=zerolatency speed-preset=%s "+
"! mux.", url, config.AudioDevice, config.BroadcastAudioBitrate*1000, config.Display, config.BroadcastVideoBitrate, config.BroadcastPreset,
), nil
}, config.BroadcastUrl),
screencast: screencastNew(config.ScreencastEnabled, func() string {
if config.ScreencastPipeline != "" {
// replace {display} with valid display
return strings.Replace(config.ScreencastPipeline, "{display}", config.Display, 1)
}
return fmt.Sprintf(
"ximagesrc display-name=%s show-pointer=true use-damage=false "+
"! video/x-raw,framerate=%s "+
"! videoconvert "+
"! queue "+
"! jpegenc quality=%s "+
"! appsink name=appsink", config.Display, config.ScreencastRate, config.ScreencastQuality,
)
}()),
audio: streamSinkNew(config.AudioCodec, func() (string, error) {
if config.AudioPipeline != "" {
// replace {device} with valid device
return strings.Replace(config.AudioPipeline, "{device}", config.AudioDevice, 1), nil
}
return fmt.Sprintf(
"pulsesrc device=%s "+
"! audio/x-raw,channels=2 "+
"! audioconvert "+
"! queue "+
"! %s "+
"! appsink name=appsink", config.AudioDevice, config.AudioCodec.Pipeline,
), nil
}, "audio"),
video: streamSelectorNew(config.VideoCodec, videos, config.VideoIDs),
// sources
webcam: streamSrcNew(config.WebcamEnabled, map[string]string{
codec.VP8().Name: "appsrc format=time is-live=true do-timestamp=true name=appsrc " +
fmt.Sprintf("! application/x-rtp, payload=%d, encoding-name=VP8-DRAFT-IETF-01 ", codec.VP8().PayloadType) +
"! rtpvp8depay " +
"! decodebin " +
"! videoconvert " +
"! videorate " +
"! videoscale " +
fmt.Sprintf("! video/x-raw,width=%d,height=%d ", config.WebcamWidth, config.WebcamHeight) +
"! identity drop-allocation=true " +
fmt.Sprintf("! v4l2sink sync=false device=%s", config.WebcamDevice),
// TODO: Test this pipeline.
codec.VP9().Name: "appsrc format=time is-live=true do-timestamp=true name=appsrc " +
"! application/x-rtp " +
"! rtpvp9depay " +
"! decodebin " +
"! videoconvert " +
"! videorate " +
"! videoscale " +
fmt.Sprintf("! video/x-raw,width=%d,height=%d ", config.WebcamWidth, config.WebcamHeight) +
"! identity drop-allocation=true " +
fmt.Sprintf("! v4l2sink sync=false device=%s", config.WebcamDevice),
// TODO: Test this pipeline.
codec.H264().Name: "appsrc format=time is-live=true do-timestamp=true name=appsrc " +
"! application/x-rtp " +
"! rtph264depay " +
"! decodebin " +
"! videoconvert " +
"! videorate " +
"! videoscale " +
fmt.Sprintf("! video/x-raw,width=%d,height=%d ", config.WebcamWidth, config.WebcamHeight) +
"! identity drop-allocation=true " +
fmt.Sprintf("! v4l2sink sync=false device=%s", config.WebcamDevice),
}, "webcam"),
microphone: streamSrcNew(config.MicrophoneEnabled, map[string]string{
codec.Opus().Name: "appsrc format=time is-live=true do-timestamp=true name=appsrc " +
fmt.Sprintf("! application/x-rtp, payload=%d, encoding-name=OPUS ", codec.Opus().PayloadType) +
"! rtpopusdepay " +
"! decodebin " +
fmt.Sprintf("! pulsesink device=%s", config.MicrophoneDevice),
// TODO: Test this pipeline.
codec.G722().Name: "appsrc format=time is-live=true do-timestamp=true name=appsrc " +
"! application/x-rtp clock-rate=8000 " +
"! rtpg722depay " +
"! decodebin " +
fmt.Sprintf("! pulsesink device=%s", config.MicrophoneDevice),
}, "microphone"),
}
}
func (manager *CaptureManagerCtx) Start() {
if manager.broadcast.Started() {
if err := manager.broadcast.createPipeline(); err != nil {
manager.logger.Panic().Err(err).Msg("unable to create broadcast pipeline")
}
}
manager.desktop.OnBeforeScreenSizeChange(func() {
manager.video.destroyPipelines()
if manager.broadcast.Started() {
manager.broadcast.destroyPipeline()
}
if manager.screencast.Started() {
manager.screencast.destroyPipeline()
}
})
manager.desktop.OnAfterScreenSizeChange(func() {
err := manager.video.recreatePipelines()
if err != nil {
manager.logger.Panic().Err(err).Msg("unable to recreate video pipelines")
}
if manager.broadcast.Started() {
err := manager.broadcast.createPipeline()
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
manager.logger.Panic().Err(err).Msg("unable to recreate broadcast pipeline")
}
}
if manager.screencast.Started() {
err := manager.screencast.createPipeline()
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
manager.logger.Panic().Err(err).Msg("unable to recreate screencast pipeline")
}
}
})
}
func (manager *CaptureManagerCtx) Shutdown() error {
manager.logger.Info().Msgf("shutdown")
manager.broadcast.shutdown()
manager.screencast.shutdown()
manager.audio.shutdown()
manager.video.shutdown()
manager.webcam.shutdown()
manager.microphone.shutdown()
return nil
}
func (manager *CaptureManagerCtx) Broadcast() types.BroadcastManager {
return manager.broadcast
}
func (manager *CaptureManagerCtx) Screencast() types.ScreencastManager {
return manager.screencast
}
func (manager *CaptureManagerCtx) Audio() types.StreamSinkManager {
return manager.audio
}
func (manager *CaptureManagerCtx) Video() types.StreamSelectorManager {
return manager.video
}
func (manager *CaptureManagerCtx) Webcam() types.StreamSrcManager {
return manager.webcam
}
func (manager *CaptureManagerCtx) Microphone() types.StreamSrcManager {
return manager.microphone
}

View File

@ -0,0 +1,257 @@
package capture
import (
"errors"
"sync"
"sync/atomic"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/pkg/gst"
"github.com/demodesk/neko/pkg/types"
)
// timeout between intervals, when screencast pipeline is checked
const screencastTimeout = 5 * time.Second
type ScreencastManagerCtx struct {
logger zerolog.Logger
mu sync.Mutex
wg sync.WaitGroup
pipeline gst.Pipeline
pipelineStr string
pipelineMu sync.Mutex
image types.Sample
imageMu sync.Mutex
tickerStop chan struct{}
enabled bool
started bool
expired int32
// metrics
imagesCounter prometheus.Counter
pipelinesCounter prometheus.Counter
pipelinesActive prometheus.Gauge
}
func screencastNew(enabled bool, pipelineStr string) *ScreencastManagerCtx {
logger := log.With().
Str("module", "capture").
Str("submodule", "screencast").
Logger()
manager := &ScreencastManagerCtx{
logger: logger,
pipelineStr: pipelineStr,
tickerStop: make(chan struct{}),
enabled: enabled,
started: false,
// metrics
imagesCounter: promauto.NewCounter(prometheus.CounterOpts{
Name: "screencast_images_total",
Namespace: "neko",
Subsystem: "capture",
Help: "Total number of created images.",
}),
pipelinesCounter: promauto.NewCounter(prometheus.CounterOpts{
Name: "pipelines_total",
Namespace: "neko",
Subsystem: "capture",
Help: "Total number of created pipelines.",
ConstLabels: map[string]string{
"submodule": "screencast",
"video_id": "main",
"codec_name": "-",
"codec_type": "-",
},
}),
pipelinesActive: promauto.NewGauge(prometheus.GaugeOpts{
Name: "pipelines_active",
Namespace: "neko",
Subsystem: "capture",
Help: "Total number of active pipelines.",
ConstLabels: map[string]string{
"submodule": "screencast",
"video_id": "main",
"codec_name": "-",
"codec_type": "-",
},
}),
}
manager.wg.Add(1)
go func() {
defer manager.wg.Done()
ticker := time.NewTicker(screencastTimeout)
defer ticker.Stop()
for {
select {
case <-manager.tickerStop:
return
case <-ticker.C:
if manager.Started() && !atomic.CompareAndSwapInt32(&manager.expired, 0, 1) {
manager.stop()
}
}
}
}()
return manager
}
func (manager *ScreencastManagerCtx) shutdown() {
manager.logger.Info().Msgf("shutdown")
manager.destroyPipeline()
close(manager.tickerStop)
manager.wg.Wait()
}
func (manager *ScreencastManagerCtx) Enabled() bool {
manager.mu.Lock()
defer manager.mu.Unlock()
return manager.enabled
}
func (manager *ScreencastManagerCtx) Started() bool {
manager.mu.Lock()
defer manager.mu.Unlock()
return manager.started
}
func (manager *ScreencastManagerCtx) Image() ([]byte, error) {
atomic.StoreInt32(&manager.expired, 0)
err := manager.start()
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
return nil, err
}
manager.imageMu.Lock()
defer manager.imageMu.Unlock()
if manager.image.Data == nil {
return nil, errors.New("image data not found")
}
return manager.image.Data, nil
}
func (manager *ScreencastManagerCtx) start() error {
manager.mu.Lock()
defer manager.mu.Unlock()
if !manager.enabled {
return errors.New("screencast not enabled")
}
err := manager.createPipeline()
if err != nil {
return err
}
manager.started = true
return nil
}
func (manager *ScreencastManagerCtx) stop() {
manager.mu.Lock()
defer manager.mu.Unlock()
manager.started = false
manager.destroyPipeline()
}
func (manager *ScreencastManagerCtx) createPipeline() error {
manager.pipelineMu.Lock()
defer manager.pipelineMu.Unlock()
if manager.pipeline != nil {
return types.ErrCapturePipelineAlreadyExists
}
var err error
manager.logger.Info().
Str("str", manager.pipelineStr).
Msgf("creating pipeline")
manager.pipeline, err = gst.CreatePipeline(manager.pipelineStr)
if err != nil {
return err
}
manager.pipeline.AttachAppsink("appsink")
manager.pipeline.Play()
manager.pipelinesCounter.Inc()
manager.pipelinesActive.Set(1)
// get first image
select {
case image, ok := <-manager.pipeline.Sample():
if !ok {
return errors.New("unable to get first image")
} else {
manager.setImage(image)
}
case <-time.After(1 * time.Second):
return errors.New("timeouted while waiting for first image")
}
manager.wg.Add(1)
pipeline := manager.pipeline
go func() {
manager.logger.Debug().Msg("started receiving images")
defer manager.wg.Done()
for {
image, ok := <-pipeline.Sample()
if !ok {
manager.logger.Debug().Msg("stopped receiving images")
return
}
manager.setImage(image)
}
}()
return nil
}
func (manager *ScreencastManagerCtx) setImage(image types.Sample) {
manager.imageMu.Lock()
manager.image = image
manager.imageMu.Unlock()
manager.imagesCounter.Inc()
}
func (manager *ScreencastManagerCtx) destroyPipeline() {
manager.pipelineMu.Lock()
defer manager.pipelineMu.Unlock()
if manager.pipeline == nil {
return
}
manager.pipeline.Destroy()
manager.logger.Info().Msgf("destroying pipeline")
manager.pipeline = nil
manager.pipelinesActive.Set(0)
}

View File

@ -0,0 +1,206 @@
package capture
import (
"errors"
"sort"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
)
type StreamSelectorManagerCtx struct {
logger zerolog.Logger
codec codec.RTPCodec
streams map[string]types.StreamSinkManager
streamIDs []string
}
func streamSelectorNew(codec codec.RTPCodec, streams map[string]types.StreamSinkManager, streamIDs []string) *StreamSelectorManagerCtx {
logger := log.With().
Str("module", "capture").
Str("submodule", "stream-selector").
Logger()
return &StreamSelectorManagerCtx{
logger: logger,
codec: codec,
streams: streams,
streamIDs: streamIDs,
}
}
func (manager *StreamSelectorManagerCtx) shutdown() {
manager.logger.Info().Msgf("shutdown")
manager.destroyPipelines()
}
func (manager *StreamSelectorManagerCtx) destroyPipelines() {
for _, stream := range manager.streams {
if stream.Started() {
stream.DestroyPipeline()
}
}
}
func (manager *StreamSelectorManagerCtx) recreatePipelines() error {
for _, stream := range manager.streams {
if stream.Started() {
err := stream.CreatePipeline()
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
return err
}
}
}
return nil
}
func (manager *StreamSelectorManagerCtx) IDs() []string {
return manager.streamIDs
}
func (manager *StreamSelectorManagerCtx) Codec() codec.RTPCodec {
return manager.codec
}
func (manager *StreamSelectorManagerCtx) GetStream(selector types.StreamSelector) (types.StreamSinkManager, bool) {
// select stream by ID
if selector.ID != "" {
// select lower stream
if selector.Type == types.StreamSelectorTypeLower {
var lastStream types.StreamSinkManager
for i := len(manager.streamIDs) - 1; i >= 0; i-- {
streamID := manager.streamIDs[i]
if streamID == selector.ID {
return lastStream, lastStream != nil
}
stream, ok := manager.streams[streamID]
if ok {
lastStream = stream
}
}
// we couldn't find a lower stream
return nil, false
}
// select higher stream
if selector.Type == types.StreamSelectorTypeHigher {
var lastStream types.StreamSinkManager
for _, streamID := range manager.streamIDs {
if streamID == selector.ID {
return lastStream, lastStream != nil
}
stream, ok := manager.streams[streamID]
if ok {
lastStream = stream
}
}
// we couldn't find a higher stream
return nil, false
}
// select exact stream
stream, ok := manager.streams[selector.ID]
return stream, ok
}
// select stream by bitrate
if selector.Bitrate != 0 {
// select stream by nearest bitrate
if selector.Type == types.StreamSelectorTypeNearest {
return manager.nearestBitrate(selector.Bitrate), true
}
// select lower stream
if selector.Type == types.StreamSelectorTypeLower {
// start from the highest stream, and go down, until we find a lower stream
for i := len(manager.streamIDs) - 1; i >= 0; i-- {
streamID := manager.streamIDs[i]
stream := manager.streams[streamID]
// if stream should be considered in calculation
considered := stream.Bitrate() != 0 && stream.Started()
if considered && stream.Bitrate() < selector.Bitrate {
return stream, true
}
}
// we couldn't find a lower stream
return nil, false
}
// select higher stream
if selector.Type == types.StreamSelectorTypeHigher {
// start from the lowest stream, and go up, until we find a higher stream
for _, streamID := range manager.streamIDs {
stream := manager.streams[streamID]
// if stream should be considered in calculation
considered := stream.Bitrate() != 0 && stream.Started()
if considered && stream.Bitrate() > selector.Bitrate {
return stream, true
}
}
// we couldn't find a higher stream
return nil, false
}
// select stream by exact bitrate
for _, stream := range manager.streams {
if stream.Bitrate() == selector.Bitrate {
return stream, true
}
}
}
// we couldn't find a stream
return nil, false
}
// TODO: This is a very naive implementation, we should use a binary search instead.
func (manager *StreamSelectorManagerCtx) nearestBitrate(bitrate uint64) types.StreamSinkManager {
type streamDiff struct {
id string
bitrateDiff int
}
sortDiff := func(a, b int) bool {
switch {
case a < 0 && b < 0:
return a > b
case a >= 0:
if b >= 0 {
return a <= b
}
return true
}
return false
}
var diffs []streamDiff
for _, stream := range manager.streams {
// if stream should be considered in calculation
considered := stream.Bitrate() != 0 && stream.Started()
if !considered {
continue
}
diffs = append(diffs, streamDiff{
id: stream.ID(),
bitrateDiff: int(bitrate) - int(stream.Bitrate()),
})
}
// no streams available
if len(diffs) == 0 {
// return first (lowest) stream
return manager.streams[manager.streamIDs[0]]
}
sort.Slice(diffs, func(i, j int) bool {
return sortDiff(diffs[i].bitrateDiff, diffs[j].bitrateDiff)
})
bestDiff := diffs[0]
return manager.streams[bestDiff.id]
}

View File

@ -0,0 +1,414 @@
package capture
import (
"errors"
"reflect"
"sync"
"sync/atomic"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/pkg/gst"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
)
var moveSinkListenerMu = sync.Mutex{}
type StreamSinkManagerCtx struct {
id string
// wait for a keyframe before sending samples
waitForKf bool
bitrate uint64 // atomic
brBuckets map[int]float64
logger zerolog.Logger
mu sync.Mutex
wg sync.WaitGroup
codec codec.RTPCodec
pipeline gst.Pipeline
pipelineMu sync.Mutex
pipelineFn func() (string, error)
listeners map[uintptr]types.SampleListener
listenersKf map[uintptr]types.SampleListener // keyframe lobby
listenersMu sync.Mutex
// metrics
currentListeners prometheus.Gauge
totalBytes prometheus.Counter
pipelinesCounter prometheus.Counter
pipelinesActive prometheus.Gauge
}
func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), id string) *StreamSinkManagerCtx {
logger := log.With().
Str("module", "capture").
Str("submodule", "stream-sink").
Str("id", id).Logger()
manager := &StreamSinkManagerCtx{
id: id,
// only wait for keyframes if the codec is video
waitForKf: codec.IsVideo(),
bitrate: 0,
brBuckets: map[int]float64{},
logger: logger,
codec: codec,
pipelineFn: pipelineFn,
listeners: map[uintptr]types.SampleListener{},
listenersKf: map[uintptr]types.SampleListener{},
// metrics
currentListeners: promauto.NewGauge(prometheus.GaugeOpts{
Name: "streamsink_listeners",
Namespace: "neko",
Subsystem: "capture",
Help: "Current number of listeners for a pipeline.",
ConstLabels: map[string]string{
"video_id": id,
"codec_name": codec.Name,
"codec_type": codec.Type.String(),
},
}),
totalBytes: promauto.NewCounter(prometheus.CounterOpts{
Name: "streamsink_bytes",
Namespace: "neko",
Subsystem: "capture",
Help: "Total number of bytes created by the pipeline.",
ConstLabels: map[string]string{
"video_id": id,
"codec_name": codec.Name,
"codec_type": codec.Type.String(),
},
}),
pipelinesCounter: promauto.NewCounter(prometheus.CounterOpts{
Name: "pipelines_total",
Namespace: "neko",
Subsystem: "capture",
Help: "Total number of created pipelines.",
ConstLabels: map[string]string{
"submodule": "streamsink",
"video_id": id,
"codec_name": codec.Name,
"codec_type": codec.Type.String(),
},
}),
pipelinesActive: promauto.NewGauge(prometheus.GaugeOpts{
Name: "pipelines_active",
Namespace: "neko",
Subsystem: "capture",
Help: "Total number of active pipelines.",
ConstLabels: map[string]string{
"submodule": "streamsink",
"video_id": id,
"codec_name": codec.Name,
"codec_type": codec.Type.String(),
},
}),
}
return manager
}
func (manager *StreamSinkManagerCtx) shutdown() {
manager.logger.Info().Msgf("shutdown")
manager.listenersMu.Lock()
for key := range manager.listeners {
delete(manager.listeners, key)
}
for key := range manager.listenersKf {
delete(manager.listenersKf, key)
}
manager.listenersMu.Unlock()
manager.DestroyPipeline()
manager.wg.Wait()
}
func (manager *StreamSinkManagerCtx) ID() string {
return manager.id
}
func (manager *StreamSinkManagerCtx) Bitrate() uint64 {
return atomic.LoadUint64(&manager.bitrate)
}
func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec {
return manager.codec
}
func (manager *StreamSinkManagerCtx) start() error {
if len(manager.listeners)+len(manager.listenersKf) == 0 {
err := manager.CreatePipeline()
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
return err
}
manager.logger.Info().Msgf("first listener, starting")
}
return nil
}
func (manager *StreamSinkManagerCtx) stop() {
if len(manager.listeners)+len(manager.listenersKf) == 0 {
manager.DestroyPipeline()
manager.logger.Info().Msgf("last listener, stopping")
}
}
func (manager *StreamSinkManagerCtx) addListener(listener types.SampleListener) {
ptr := reflect.ValueOf(listener).Pointer()
emitKeyframe := false
manager.listenersMu.Lock()
if manager.waitForKf {
// if this is the first listener, we need to emit a keyframe
emitKeyframe = len(manager.listenersKf) == 0
// if we're waiting for a keyframe, add it to the keyframe lobby
manager.listenersKf[ptr] = listener
} else {
// otherwise, add it as a regular listener
manager.listeners[ptr] = listener
}
manager.listenersMu.Unlock()
manager.logger.Debug().Interface("ptr", ptr).Msgf("adding listener")
manager.currentListeners.Set(float64(manager.ListenersCount()))
// if we will be waiting for a keyframe, emit one now
if manager.pipeline != nil && emitKeyframe {
manager.pipeline.EmitVideoKeyframe()
}
}
func (manager *StreamSinkManagerCtx) removeListener(listener types.SampleListener) {
ptr := reflect.ValueOf(listener).Pointer()
manager.listenersMu.Lock()
delete(manager.listeners, ptr)
delete(manager.listenersKf, ptr) // if it's a keyframe listener, remove it too
manager.listenersMu.Unlock()
manager.logger.Debug().Interface("ptr", ptr).Msgf("removing listener")
manager.currentListeners.Set(float64(manager.ListenersCount()))
}
func (manager *StreamSinkManagerCtx) AddListener(listener types.SampleListener) error {
manager.mu.Lock()
defer manager.mu.Unlock()
if listener == nil {
return errors.New("listener cannot be nil")
}
// start if stopped
if err := manager.start(); err != nil {
return err
}
// add listener
manager.addListener(listener)
return nil
}
func (manager *StreamSinkManagerCtx) RemoveListener(listener types.SampleListener) error {
manager.mu.Lock()
defer manager.mu.Unlock()
if listener == nil {
return errors.New("listener cannot be nil")
}
// remove listener
manager.removeListener(listener)
// stop if started
manager.stop()
return nil
}
// moving listeners between streams ensures, that target pipeline is running
// before listener is added, and stops source pipeline if there are 0 listeners
func (manager *StreamSinkManagerCtx) MoveListenerTo(listener types.SampleListener, stream types.StreamSinkManager) error {
if listener == nil {
return errors.New("listener cannot be nil")
}
targetStream, ok := stream.(*StreamSinkManagerCtx)
if !ok {
return errors.New("target stream manager does not support moving listeners")
}
// we need to acquire both mutextes, from source stream and from target stream
// in order to do that safely (without possibility of deadlock) we need third
// global mutex, that ensures atomic locking
// lock global mutex
moveSinkListenerMu.Lock()
// lock source stream
manager.mu.Lock()
defer manager.mu.Unlock()
// lock target stream
targetStream.mu.Lock()
defer targetStream.mu.Unlock()
// unlock global mutex
moveSinkListenerMu.Unlock()
// start if stopped
if err := targetStream.start(); err != nil {
return err
}
// swap listeners
manager.removeListener(listener)
targetStream.addListener(listener)
// stop if started
manager.stop()
return nil
}
func (manager *StreamSinkManagerCtx) ListenersCount() int {
manager.listenersMu.Lock()
defer manager.listenersMu.Unlock()
return len(manager.listeners) + len(manager.listenersKf)
}
func (manager *StreamSinkManagerCtx) Started() bool {
return manager.ListenersCount() > 0
}
func (manager *StreamSinkManagerCtx) CreatePipeline() error {
manager.pipelineMu.Lock()
defer manager.pipelineMu.Unlock()
if manager.pipeline != nil {
return types.ErrCapturePipelineAlreadyExists
}
pipelineStr, err := manager.pipelineFn()
if err != nil {
return err
}
manager.logger.Info().
Str("codec", manager.codec.Name).
Str("src", pipelineStr).
Msgf("creating pipeline")
manager.pipeline, err = gst.CreatePipeline(pipelineStr)
if err != nil {
return err
}
manager.pipeline.AttachAppsink("appsink")
manager.pipeline.Play()
manager.wg.Add(1)
pipeline := manager.pipeline
go func() {
manager.logger.Debug().Msg("started emitting samples")
defer manager.wg.Done()
for {
sample, ok := <-pipeline.Sample()
if !ok {
manager.logger.Debug().Msg("stopped emitting samples")
return
}
manager.onSample(sample)
}
}()
manager.pipelinesCounter.Inc()
manager.pipelinesActive.Set(1)
return nil
}
func (manager *StreamSinkManagerCtx) saveSampleBitrate(timestamp time.Time, delta float64) {
// get unix timestamp in seconds
sec := timestamp.Unix()
// last bucket is timestamp rounded to 3 seconds - 1 second
last := int((sec - 1) % 3)
// current bucket is timestamp rounded to 3 seconds
curr := int(sec % 3)
// next bucket is timestamp rounded to 3 seconds + 1 second
next := int((sec + 1) % 3)
if manager.brBuckets[next] != 0 {
// atomic update bitrate
atomic.StoreUint64(&manager.bitrate, uint64(manager.brBuckets[last]))
// empty next bucket
manager.brBuckets[next] = 0
}
// add rate to current bucket
manager.brBuckets[curr] += delta
}
func (manager *StreamSinkManagerCtx) onSample(sample types.Sample) {
manager.listenersMu.Lock()
defer manager.listenersMu.Unlock()
// save to metrics
length := float64(sample.Length)
manager.totalBytes.Add(length)
manager.saveSampleBitrate(sample.Timestamp, length)
// if is not delta unit -> it can be decoded independently -> it is a keyframe
if manager.waitForKf && !sample.DeltaUnit && len(manager.listenersKf) > 0 {
// if current sample is a keyframe, move listeners from
// keyframe lobby to actual listeners map and clear lobby
for k, v := range manager.listenersKf {
manager.listeners[k] = v
}
manager.listenersKf = make(map[uintptr]types.SampleListener)
}
for _, l := range manager.listeners {
l.WriteSample(sample)
}
}
func (manager *StreamSinkManagerCtx) DestroyPipeline() {
manager.pipelineMu.Lock()
defer manager.pipelineMu.Unlock()
if manager.pipeline == nil {
return
}
manager.pipeline.Destroy()
manager.logger.Info().Msgf("destroying pipeline")
manager.pipeline = nil
manager.pipelinesActive.Set(0)
manager.brBuckets = make(map[int]float64)
atomic.StoreUint64(&manager.bitrate, 0)
}

View File

@ -0,0 +1,197 @@
package capture
import (
"errors"
"sync"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/pkg/gst"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
)
type StreamSrcManagerCtx struct {
logger zerolog.Logger
enabled bool
codecPipeline map[string]string // codec -> pipeline
codec codec.RTPCodec
pipeline gst.Pipeline
pipelineMu sync.Mutex
pipelineStr string
// metrics
pushedData map[string]prometheus.Summary
pipelinesCounter map[string]prometheus.Counter
pipelinesActive map[string]prometheus.Gauge
}
func streamSrcNew(enabled bool, codecPipeline map[string]string, video_id string) *StreamSrcManagerCtx {
logger := log.With().
Str("module", "capture").
Str("submodule", "stream-src").
Str("video_id", video_id).Logger()
pushedData := map[string]prometheus.Summary{}
pipelinesCounter := map[string]prometheus.Counter{}
pipelinesActive := map[string]prometheus.Gauge{}
for codecName, pipeline := range codecPipeline {
codec, ok := codec.ParseStr(codecName)
if !ok {
logger.Fatal().
Str("codec", codecName).
Str("pipeline", pipeline).
Msg("unknown codec name")
}
pushedData[codecName] = promauto.NewSummary(prometheus.SummaryOpts{
Name: "streamsrc_data_bytes",
Namespace: "neko",
Subsystem: "capture",
Help: "Data pushed to a pipeline (in bytes).",
ConstLabels: map[string]string{
"video_id": video_id,
"codec_name": codec.Name,
"codec_type": codec.Type.String(),
},
})
pipelinesCounter[codecName] = promauto.NewCounter(prometheus.CounterOpts{
Name: "pipelines_total",
Namespace: "neko",
Subsystem: "capture",
Help: "Total number of created pipelines.",
ConstLabels: map[string]string{
"submodule": "streamsrc",
"video_id": video_id,
"codec_name": codec.Name,
"codec_type": codec.Type.String(),
},
})
pipelinesActive[codecName] = promauto.NewGauge(prometheus.GaugeOpts{
Name: "pipelines_active",
Namespace: "neko",
Subsystem: "capture",
Help: "Total number of active pipelines.",
ConstLabels: map[string]string{
"submodule": "streamsrc",
"video_id": video_id,
"codec_name": codec.Name,
"codec_type": codec.Type.String(),
},
})
}
return &StreamSrcManagerCtx{
logger: logger,
enabled: enabled,
codecPipeline: codecPipeline,
// metrics
pushedData: pushedData,
pipelinesCounter: pipelinesCounter,
pipelinesActive: pipelinesActive,
}
}
func (manager *StreamSrcManagerCtx) shutdown() {
manager.logger.Info().Msgf("shutdown")
manager.Stop()
}
func (manager *StreamSrcManagerCtx) Codec() codec.RTPCodec {
manager.pipelineMu.Lock()
defer manager.pipelineMu.Unlock()
return manager.codec
}
func (manager *StreamSrcManagerCtx) Start(codec codec.RTPCodec) error {
manager.pipelineMu.Lock()
defer manager.pipelineMu.Unlock()
if manager.pipeline != nil {
return types.ErrCapturePipelineAlreadyExists
}
if !manager.enabled {
return errors.New("stream-src not enabled")
}
found := false
for codecName, pipeline := range manager.codecPipeline {
if codecName == codec.Name {
manager.pipelineStr = pipeline
manager.codec = codec
found = true
break
}
}
if !found {
return errors.New("no pipeline found for a codec")
}
var err error
manager.logger.Info().
Str("codec", manager.codec.Name).
Str("src", manager.pipelineStr).
Msgf("creating pipeline")
manager.pipeline, err = gst.CreatePipeline(manager.pipelineStr)
if err != nil {
return err
}
manager.pipeline.AttachAppsrc("appsrc")
manager.pipeline.Play()
manager.pipelinesCounter[manager.codec.Name].Inc()
manager.pipelinesActive[manager.codec.Name].Set(1)
return nil
}
func (manager *StreamSrcManagerCtx) Stop() {
manager.pipelineMu.Lock()
defer manager.pipelineMu.Unlock()
if manager.pipeline == nil {
return
}
manager.pipeline.Destroy()
manager.pipeline = nil
manager.logger.Info().
Str("codec", manager.codec.Name).
Str("src", manager.pipelineStr).
Msgf("destroying pipeline")
manager.pipelinesActive[manager.codec.Name].Set(0)
}
func (manager *StreamSrcManagerCtx) Push(bytes []byte) {
manager.pipelineMu.Lock()
defer manager.pipelineMu.Unlock()
if manager.pipeline == nil {
return
}
manager.pipeline.Push(bytes)
manager.pushedData[manager.codec.Name].Observe(float64(len(bytes)))
}
func (manager *StreamSrcManagerCtx) Started() bool {
manager.pipelineMu.Lock()
defer manager.pipelineMu.Unlock()
return manager.pipeline != nil
}

View File

@ -0,0 +1,246 @@
package config
import (
"os"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
"github.com/demodesk/neko/pkg/utils"
)
type Capture struct {
Display string
VideoCodec codec.RTPCodec
VideoIDs []string
VideoPipelines map[string]types.VideoConfig
AudioDevice string
AudioCodec codec.RTPCodec
AudioPipeline string
BroadcastAudioBitrate int
BroadcastVideoBitrate int
BroadcastPreset string
BroadcastPipeline string
BroadcastUrl string
ScreencastEnabled bool
ScreencastRate string
ScreencastQuality string
ScreencastPipeline string
WebcamEnabled bool
WebcamDevice string
WebcamWidth int
WebcamHeight int
MicrophoneEnabled bool
MicrophoneDevice string
}
func (Capture) Init(cmd *cobra.Command) error {
// audio
cmd.PersistentFlags().String("capture.audio.device", "audio_output.monitor", "pulseaudio device to capture")
if err := viper.BindPFlag("capture.audio.device", cmd.PersistentFlags().Lookup("capture.audio.device")); err != nil {
return err
}
cmd.PersistentFlags().String("capture.audio.codec", "opus", "audio codec to be used")
if err := viper.BindPFlag("capture.audio.codec", cmd.PersistentFlags().Lookup("capture.audio.codec")); err != nil {
return err
}
cmd.PersistentFlags().String("capture.audio.pipeline", "", "gstreamer pipeline used for audio streaming")
if err := viper.BindPFlag("capture.audio.pipeline", cmd.PersistentFlags().Lookup("capture.audio.pipeline")); err != nil {
return err
}
// videos
cmd.PersistentFlags().String("capture.video.codec", "vp8", "video codec to be used")
if err := viper.BindPFlag("capture.video.codec", cmd.PersistentFlags().Lookup("capture.video.codec")); err != nil {
return err
}
cmd.PersistentFlags().StringSlice("capture.video.ids", []string{}, "ordered list of video ids")
if err := viper.BindPFlag("capture.video.ids", cmd.PersistentFlags().Lookup("capture.video.ids")); err != nil {
return err
}
cmd.PersistentFlags().String("capture.video.pipelines", "[]", "pipelines config in JSON used for video streaming")
if err := viper.BindPFlag("capture.video.pipelines", cmd.PersistentFlags().Lookup("capture.video.pipelines")); err != nil {
return err
}
// broadcast
cmd.PersistentFlags().Int("capture.broadcast.audio_bitrate", 128, "broadcast audio bitrate in KB/s")
if err := viper.BindPFlag("capture.broadcast.audio_bitrate", cmd.PersistentFlags().Lookup("capture.broadcast.audio_bitrate")); err != nil {
return err
}
cmd.PersistentFlags().Int("capture.broadcast.video_bitrate", 4096, "broadcast video bitrate in KB/s")
if err := viper.BindPFlag("capture.broadcast.video_bitrate", cmd.PersistentFlags().Lookup("capture.broadcast.video_bitrate")); err != nil {
return err
}
cmd.PersistentFlags().String("capture.broadcast.preset", "veryfast", "broadcast speed preset for h264 encoding")
if err := viper.BindPFlag("capture.broadcast.preset", cmd.PersistentFlags().Lookup("capture.broadcast.preset")); err != nil {
return err
}
cmd.PersistentFlags().String("capture.broadcast.pipeline", "", "gstreamer pipeline used for broadcasting")
if err := viper.BindPFlag("capture.broadcast.pipeline", cmd.PersistentFlags().Lookup("capture.broadcast.pipeline")); err != nil {
return err
}
cmd.PersistentFlags().String("capture.broadcast.url", "", "initial URL for broadcasting, setting this value will automatically start broadcasting")
if err := viper.BindPFlag("capture.broadcast.url", cmd.PersistentFlags().Lookup("capture.broadcast.url")); err != nil {
return err
}
// screencast
cmd.PersistentFlags().Bool("capture.screencast.enabled", false, "enable screencast")
if err := viper.BindPFlag("capture.screencast.enabled", cmd.PersistentFlags().Lookup("capture.screencast.enabled")); err != nil {
return err
}
cmd.PersistentFlags().String("capture.screencast.rate", "10/1", "screencast frame rate")
if err := viper.BindPFlag("capture.screencast.rate", cmd.PersistentFlags().Lookup("capture.screencast.rate")); err != nil {
return err
}
cmd.PersistentFlags().String("capture.screencast.quality", "60", "screencast JPEG quality")
if err := viper.BindPFlag("capture.screencast.quality", cmd.PersistentFlags().Lookup("capture.screencast.quality")); err != nil {
return err
}
cmd.PersistentFlags().String("capture.screencast.pipeline", "", "gstreamer pipeline used for screencasting")
if err := viper.BindPFlag("capture.screencast.pipeline", cmd.PersistentFlags().Lookup("capture.screencast.pipeline")); err != nil {
return err
}
// webcam
cmd.PersistentFlags().Bool("capture.webcam.enabled", false, "enable webcam stream")
if err := viper.BindPFlag("capture.webcam.enabled", cmd.PersistentFlags().Lookup("capture.webcam.enabled")); err != nil {
return err
}
// sudo apt install v4l2loopback-dkms v4l2loopback-utils
// sudo apt-get install linux-headers-`uname -r` linux-modules-extra-`uname -r`
// sudo modprobe v4l2loopback exclusive_caps=1
cmd.PersistentFlags().String("capture.webcam.device", "/dev/video0", "v4l2sink device used for webcam")
if err := viper.BindPFlag("capture.webcam.device", cmd.PersistentFlags().Lookup("capture.webcam.device")); err != nil {
return err
}
cmd.PersistentFlags().Int("capture.webcam.width", 1280, "webcam stream width")
if err := viper.BindPFlag("capture.webcam.width", cmd.PersistentFlags().Lookup("capture.webcam.width")); err != nil {
return err
}
cmd.PersistentFlags().Int("capture.webcam.height", 720, "webcam stream height")
if err := viper.BindPFlag("capture.webcam.height", cmd.PersistentFlags().Lookup("capture.webcam.height")); err != nil {
return err
}
// microphone
cmd.PersistentFlags().Bool("capture.microphone.enabled", true, "enable microphone stream")
if err := viper.BindPFlag("capture.microphone.enabled", cmd.PersistentFlags().Lookup("capture.microphone.enabled")); err != nil {
return err
}
cmd.PersistentFlags().String("capture.microphone.device", "audio_input", "pulseaudio device used for microphone")
if err := viper.BindPFlag("capture.microphone.device", cmd.PersistentFlags().Lookup("capture.microphone.device")); err != nil {
return err
}
return nil
}
func (s *Capture) Set() {
var ok bool
// Display is provided by env variable
s.Display = os.Getenv("DISPLAY")
// video
videoCodec := viper.GetString("capture.video.codec")
s.VideoCodec, ok = codec.ParseStr(videoCodec)
if !ok || !s.VideoCodec.IsVideo() {
log.Warn().Str("codec", videoCodec).Msgf("unknown video codec, using Vp8")
s.VideoCodec = codec.VP8()
}
s.VideoIDs = viper.GetStringSlice("capture.video.ids")
if err := viper.UnmarshalKey("capture.video.pipelines", &s.VideoPipelines, viper.DecodeHook(
utils.JsonStringAutoDecode(s.VideoPipelines),
)); err != nil {
log.Warn().Err(err).Msgf("unable to parse video pipelines")
}
// default video
if len(s.VideoPipelines) == 0 {
log.Warn().Msgf("no video pipelines specified, using defaults")
s.VideoCodec = codec.VP8()
s.VideoPipelines = map[string]types.VideoConfig{
"main": {
Fps: "25",
GstEncoder: "vp8enc",
GstParams: map[string]string{
"target-bitrate": "round(3072 * 650)",
"cpu-used": "4",
"end-usage": "cbr",
"threads": "4",
"deadline": "1",
"undershoot": "95",
"buffer-size": "(3072 * 4)",
"buffer-initial-size": "(3072 * 2)",
"buffer-optimal-size": "(3072 * 3)",
"keyframe-max-dist": "25",
"min-quantizer": "4",
"max-quantizer": "20",
},
},
}
s.VideoIDs = []string{"main"}
}
// audio
s.AudioDevice = viper.GetString("capture.audio.device")
s.AudioPipeline = viper.GetString("capture.audio.pipeline")
audioCodec := viper.GetString("capture.audio.codec")
s.AudioCodec, ok = codec.ParseStr(audioCodec)
if !ok || !s.AudioCodec.IsAudio() {
log.Warn().Str("codec", audioCodec).Msgf("unknown audio codec, using Opus")
s.AudioCodec = codec.Opus()
}
// broadcast
s.BroadcastAudioBitrate = viper.GetInt("capture.broadcast.audio_bitrate")
s.BroadcastVideoBitrate = viper.GetInt("capture.broadcast.video_bitrate")
s.BroadcastPreset = viper.GetString("capture.broadcast.preset")
s.BroadcastPipeline = viper.GetString("capture.broadcast.pipeline")
s.BroadcastUrl = viper.GetString("capture.broadcast.url")
// screencast
s.ScreencastEnabled = viper.GetBool("capture.screencast.enabled")
s.ScreencastRate = viper.GetString("capture.screencast.rate")
s.ScreencastQuality = viper.GetString("capture.screencast.quality")
s.ScreencastPipeline = viper.GetString("capture.screencast.pipeline")
// webcam
s.WebcamEnabled = viper.GetBool("capture.webcam.enabled")
s.WebcamDevice = viper.GetString("capture.webcam.device")
s.WebcamWidth = viper.GetInt("capture.webcam.width")
s.WebcamHeight = viper.GetInt("capture.webcam.height")
// microphone
s.MicrophoneEnabled = viper.GetBool("capture.microphone.enabled")
s.MicrophoneDevice = viper.GetString("capture.microphone.device")
}

View File

@ -0,0 +1,8 @@
package config
import "github.com/spf13/cobra"
type Config interface {
Init(cmd *cobra.Command) error
Set()
}

View File

@ -0,0 +1,91 @@
package config
import (
"os"
"regexp"
"strconv"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/demodesk/neko/pkg/types"
)
type Desktop struct {
Display string
ScreenSize types.ScreenSize
UseInputDriver bool
InputSocket string
Unminimize bool
UploadDrop bool
FileChooserDialog bool
}
func (Desktop) Init(cmd *cobra.Command) error {
cmd.PersistentFlags().String("desktop.screen", "1280x720@30", "default screen size and framerate")
if err := viper.BindPFlag("desktop.screen", cmd.PersistentFlags().Lookup("desktop.screen")); err != nil {
return err
}
cmd.PersistentFlags().Bool("desktop.input.enabled", true, "whether custom xf86 input driver should be used to handle touchscreen")
if err := viper.BindPFlag("desktop.input.enabled", cmd.PersistentFlags().Lookup("desktop.input.enabled")); err != nil {
return err
}
cmd.PersistentFlags().String("desktop.input.socket", "/tmp/xf86-input-neko.sock", "socket path for custom xf86 input driver connection")
if err := viper.BindPFlag("desktop.input.socket", cmd.PersistentFlags().Lookup("desktop.input.socket")); err != nil {
return err
}
cmd.PersistentFlags().Bool("desktop.unminimize", true, "automatically unminimize window when it is minimized")
if err := viper.BindPFlag("desktop.unminimize", cmd.PersistentFlags().Lookup("desktop.unminimize")); err != nil {
return err
}
cmd.PersistentFlags().Bool("desktop.upload_drop", true, "whether drop upload is enabled")
if err := viper.BindPFlag("desktop.upload_drop", cmd.PersistentFlags().Lookup("desktop.upload_drop")); err != nil {
return err
}
cmd.PersistentFlags().Bool("desktop.file_chooser_dialog", false, "whether to handle file chooser dialog externally")
if err := viper.BindPFlag("desktop.file_chooser_dialog", cmd.PersistentFlags().Lookup("desktop.file_chooser_dialog")); err != nil {
return err
}
return nil
}
func (s *Desktop) Set() {
// Display is provided by env variable
s.Display = os.Getenv("DISPLAY")
s.ScreenSize = types.ScreenSize{
Width: 1280,
Height: 720,
Rate: 30,
}
r := regexp.MustCompile(`([0-9]{1,4})x([0-9]{1,4})@([0-9]{1,3})`)
res := r.FindStringSubmatch(viper.GetString("desktop.screen"))
if len(res) > 0 {
width, err1 := strconv.ParseInt(res[1], 10, 64)
height, err2 := strconv.ParseInt(res[2], 10, 64)
rate, err3 := strconv.ParseInt(res[3], 10, 64)
if err1 == nil && err2 == nil && err3 == nil {
s.ScreenSize.Width = int(width)
s.ScreenSize.Height = int(height)
s.ScreenSize.Rate = int16(rate)
}
}
s.UseInputDriver = viper.GetBool("desktop.input.enabled")
s.InputSocket = viper.GetString("desktop.input.socket")
s.Unminimize = viper.GetBool("desktop.unminimize")
s.UploadDrop = viper.GetBool("desktop.upload_drop")
s.FileChooserDialog = viper.GetBool("desktop.file_chooser_dialog")
}

View File

@ -0,0 +1,128 @@
package config
import (
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/demodesk/neko/internal/member/file"
"github.com/demodesk/neko/internal/member/multiuser"
"github.com/demodesk/neko/internal/member/object"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
type Member struct {
Provider string
// providers
File file.Config
Object object.Config
Multiuser multiuser.Config
}
func (Member) Init(cmd *cobra.Command) error {
cmd.PersistentFlags().String("member.provider", "multiuser", "choose member provider")
if err := viper.BindPFlag("member.provider", cmd.PersistentFlags().Lookup("member.provider")); err != nil {
return err
}
// file provider
cmd.PersistentFlags().String("member.file.path", "", "member file provider: storage path")
if err := viper.BindPFlag("member.file.path", cmd.PersistentFlags().Lookup("member.file.path")); err != nil {
return err
}
cmd.PersistentFlags().Bool("member.file.hash", true, "member file provider: whether to hash passwords using sha256 (recommended)")
if err := viper.BindPFlag("member.file.hash", cmd.PersistentFlags().Lookup("member.file.hash")); err != nil {
return err
}
// object provider
cmd.PersistentFlags().String("member.object.users", "[]", "member object provider: users in JSON format")
if err := viper.BindPFlag("member.object.users", cmd.PersistentFlags().Lookup("member.object.users")); err != nil {
return err
}
// multiuser provider
cmd.PersistentFlags().String("member.multiuser.user_password", "neko", "member multiuser provider: user password")
if err := viper.BindPFlag("member.multiuser.user_password", cmd.PersistentFlags().Lookup("member.multiuser.user_password")); err != nil {
return err
}
cmd.PersistentFlags().String("member.multiuser.admin_password", "admin", "member multiuser provider: admin password")
if err := viper.BindPFlag("member.multiuser.admin_password", cmd.PersistentFlags().Lookup("member.multiuser.admin_password")); err != nil {
return err
}
cmd.PersistentFlags().String("member.multiuser.user_profile", "{}", "member multiuser provider: user profile in JSON format")
if err := viper.BindPFlag("member.multiuser.user_profile", cmd.PersistentFlags().Lookup("member.multiuser.user_profile")); err != nil {
return err
}
cmd.PersistentFlags().String("member.multiuser.admin_profile", "{}", "member multiuser provider: admin profile in JSON format")
if err := viper.BindPFlag("member.multiuser.admin_profile", cmd.PersistentFlags().Lookup("member.multiuser.admin_profile")); err != nil {
return err
}
return nil
}
func (s *Member) Set() {
s.Provider = viper.GetString("member.provider")
// file provider
s.File.Path = viper.GetString("member.file.path")
s.File.Hash = viper.GetBool("member.file.hash")
// object provider
if err := viper.UnmarshalKey("member.object.users", &s.Object.Users, viper.DecodeHook(
utils.JsonStringAutoDecode(s.Object.Users),
)); err != nil {
log.Warn().Err(err).Msgf("unable to parse member object users")
}
// multiuser provider
s.Multiuser.UserPassword = viper.GetString("member.multiuser.user_password")
s.Multiuser.AdminPassword = viper.GetString("member.multiuser.admin_password")
// default user profile
s.Multiuser.UserProfile = types.MemberProfile{
IsAdmin: false,
CanLogin: true,
CanConnect: true,
CanWatch: true,
CanHost: true,
CanShareMedia: true,
CanAccessClipboard: true,
SendsInactiveCursor: true,
CanSeeInactiveCursors: false,
}
// override user profile
if err := viper.UnmarshalKey("member.multiuser.user_profile", &s.Multiuser.UserProfile, viper.DecodeHook(
utils.JsonStringAutoDecode(s.Multiuser.UserProfile),
)); err != nil {
log.Warn().Err(err).Msgf("unable to parse member multiuser user profile")
}
// default admin profile
s.Multiuser.AdminProfile = types.MemberProfile{
IsAdmin: true,
CanLogin: true,
CanConnect: true,
CanWatch: true,
CanHost: true,
CanShareMedia: true,
CanAccessClipboard: true,
SendsInactiveCursor: true,
CanSeeInactiveCursors: true,
}
// override admin profile
if err := viper.UnmarshalKey("member.multiuser.admin_profile", &s.Multiuser.AdminProfile, viper.DecodeHook(
utils.JsonStringAutoDecode(s.Multiuser.AdminProfile),
)); err != nil {
log.Warn().Err(err).Msgf("unable to parse member multiuser admin profile")
}
}

View File

@ -0,0 +1,37 @@
package config
import (
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
type Plugins struct {
Enabled bool
Dir string
Required bool
}
func (Plugins) Init(cmd *cobra.Command) error {
cmd.PersistentFlags().Bool("plugins.enabled", false, "load plugins in runtime")
if err := viper.BindPFlag("plugins.enabled", cmd.PersistentFlags().Lookup("plugins.enabled")); err != nil {
return err
}
cmd.PersistentFlags().String("plugins.dir", "./bin/plugins", "path to neko plugins to load")
if err := viper.BindPFlag("plugins.dir", cmd.PersistentFlags().Lookup("plugins.dir")); err != nil {
return err
}
cmd.PersistentFlags().Bool("plugins.required", false, "if true, neko will exit if there is an error when loading a plugin")
if err := viper.BindPFlag("plugins.required", cmd.PersistentFlags().Lookup("plugins.required")); err != nil {
return err
}
return nil
}
func (s *Plugins) Set() {
s.Enabled = viper.GetBool("plugins.enabled")
s.Dir = viper.GetString("plugins.dir")
s.Required = viper.GetBool("plugins.required")
}

View File

@ -0,0 +1,97 @@
package config
import (
"os"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
type Root struct {
Config string
LogLevel zerolog.Level
LogTime string
LogJson bool
LogNocolor bool
LogDir string
}
func (Root) Init(cmd *cobra.Command) error {
cmd.PersistentFlags().StringP("config", "c", "", "configuration file path")
if err := viper.BindPFlag("config", cmd.PersistentFlags().Lookup("config")); err != nil {
return err
}
// just a shortcut
cmd.PersistentFlags().BoolP("debug", "d", false, "enable debug mode")
if err := viper.BindPFlag("debug", cmd.PersistentFlags().Lookup("debug")); err != nil {
return err
}
cmd.PersistentFlags().String("log.level", "info", "set log level (trace, debug, info, warn, error, fatal, panic, disabled)")
if err := viper.BindPFlag("log.level", cmd.PersistentFlags().Lookup("log.level")); err != nil {
return err
}
cmd.PersistentFlags().String("log.time", "unix", "time format used in logs (unix, unixms, unixmicro)")
if err := viper.BindPFlag("log.time", cmd.PersistentFlags().Lookup("log.time")); err != nil {
return err
}
cmd.PersistentFlags().Bool("log.json", false, "logs in JSON format")
if err := viper.BindPFlag("log.json", cmd.PersistentFlags().Lookup("log.json")); err != nil {
return err
}
cmd.PersistentFlags().Bool("log.nocolor", false, "no ANSI colors in non-JSON output")
if err := viper.BindPFlag("log.nocolor", cmd.PersistentFlags().Lookup("log.nocolor")); err != nil {
return err
}
cmd.PersistentFlags().String("log.dir", "", "logging directory to store logs")
if err := viper.BindPFlag("log.dir", cmd.PersistentFlags().Lookup("log.dir")); err != nil {
return err
}
return nil
}
func (s *Root) Set() {
s.Config = viper.GetString("config")
logLevel := viper.GetString("log.level")
level, err := zerolog.ParseLevel(logLevel)
if err != nil {
log.Warn().Msgf("unknown log level %s", logLevel)
} else {
s.LogLevel = level
}
logTime := viper.GetString("log.time")
switch logTime {
case "unix":
s.LogTime = zerolog.TimeFormatUnix
case "unixms":
s.LogTime = zerolog.TimeFormatUnixMs
case "unixmicro":
s.LogTime = zerolog.TimeFormatUnixMicro
default:
log.Warn().Msgf("unknown log time %s", logTime)
}
s.LogJson = viper.GetBool("log.json")
s.LogNocolor = viper.GetBool("log.nocolor")
s.LogDir = viper.GetString("log.dir")
if viper.GetBool("debug") && s.LogLevel != zerolog.TraceLevel {
s.LogLevel = zerolog.DebugLevel
}
// support for NO_COLOR env variable: https://no-color.org/
if os.Getenv("NO_COLOR") != "" {
s.LogNocolor = true
}
}

View File

@ -0,0 +1,103 @@
package config
import (
"path"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/demodesk/neko/pkg/utils"
)
type Server struct {
Cert string
Key string
Bind string
Proxy bool
Static string
PathPrefix string
PProf bool
Metrics bool
CORS []string
}
func (Server) Init(cmd *cobra.Command) error {
cmd.PersistentFlags().String("server.bind", "127.0.0.1:8080", "address/port/socket to serve neko")
if err := viper.BindPFlag("server.bind", cmd.PersistentFlags().Lookup("server.bind")); err != nil {
return err
}
cmd.PersistentFlags().String("server.cert", "", "path to the SSL cert used to secure the neko server")
if err := viper.BindPFlag("server.cert", cmd.PersistentFlags().Lookup("server.cert")); err != nil {
return err
}
cmd.PersistentFlags().String("server.key", "", "path to the SSL key used to secure the neko server")
if err := viper.BindPFlag("server.key", cmd.PersistentFlags().Lookup("server.key")); err != nil {
return err
}
cmd.PersistentFlags().Bool("server.proxy", false, "trust reverse proxy headers")
if err := viper.BindPFlag("server.proxy", cmd.PersistentFlags().Lookup("server.proxy")); err != nil {
return err
}
cmd.PersistentFlags().String("server.static", "", "path to neko client files to serve")
if err := viper.BindPFlag("server.static", cmd.PersistentFlags().Lookup("server.static")); err != nil {
return err
}
cmd.PersistentFlags().String("server.path_prefix", "/", "path prefix for HTTP requests")
if err := viper.BindPFlag("server.path_prefix", cmd.PersistentFlags().Lookup("server.path_prefix")); err != nil {
return err
}
cmd.PersistentFlags().Bool("server.pprof", false, "enable pprof endpoint available at /debug/pprof")
if err := viper.BindPFlag("server.pprof", cmd.PersistentFlags().Lookup("server.pprof")); err != nil {
return err
}
cmd.PersistentFlags().Bool("server.metrics", true, "enable prometheus metrics available at /metrics")
if err := viper.BindPFlag("server.metrics", cmd.PersistentFlags().Lookup("server.metrics")); err != nil {
return err
}
cmd.PersistentFlags().StringSlice("server.cors", []string{}, "list of allowed origins for CORS, if empty CORS is disabled, if '*' is present all origins are allowed")
if err := viper.BindPFlag("server.cors", cmd.PersistentFlags().Lookup("server.cors")); err != nil {
return err
}
return nil
}
func (s *Server) Set() {
s.Cert = viper.GetString("server.cert")
s.Key = viper.GetString("server.key")
s.Bind = viper.GetString("server.bind")
s.Proxy = viper.GetBool("server.proxy")
s.Static = viper.GetString("server.static")
s.PathPrefix = path.Join("/", path.Clean(viper.GetString("server.path_prefix")))
s.PProf = viper.GetBool("server.pprof")
s.Metrics = viper.GetBool("server.metrics")
s.CORS = viper.GetStringSlice("server.cors")
in, _ := utils.ArrayIn("*", s.CORS)
if len(s.CORS) == 0 || in {
s.CORS = []string{"*"}
}
}
func (s *Server) HasCors() bool {
return len(s.CORS) > 0
}
func (s *Server) AllowOrigin(origin string) bool {
// if CORS is disabled, allow all origins
if len(s.CORS) == 0 {
return true
}
// if CORS is enabled, allow only origins in the list
in, _ := utils.ArrayIn(origin, s.CORS)
return in || s.CORS[0] == "*"
}

View File

@ -0,0 +1,114 @@
package config
import (
"time"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
type Session struct {
File string
PrivateMode bool
LockedLogins bool
LockedControls bool
ControlProtection bool
ImplicitHosting bool
InactiveCursors bool
MercifulReconnect bool
APIToken string
CookieEnabled bool
CookieName string
CookieExpiration time.Duration
CookieSecure bool
}
func (Session) Init(cmd *cobra.Command) error {
cmd.PersistentFlags().String("session.file", "", "if sessions should be stored in a file, otherwise they will be stored only in memory")
if err := viper.BindPFlag("session.file", cmd.PersistentFlags().Lookup("session.file")); err != nil {
return err
}
cmd.PersistentFlags().Bool("session.private_mode", false, "whether private mode should be enabled initially")
if err := viper.BindPFlag("session.private_mode", cmd.PersistentFlags().Lookup("session.private_mode")); err != nil {
return err
}
cmd.PersistentFlags().Bool("session.locked_logins", false, "whether logins should be locked for users initially")
if err := viper.BindPFlag("session.locked_logins", cmd.PersistentFlags().Lookup("session.locked_logins")); err != nil {
return err
}
cmd.PersistentFlags().Bool("session.locked_controls", false, "whether controls should be locked for users initially")
if err := viper.BindPFlag("session.locked_controls", cmd.PersistentFlags().Lookup("session.locked_controls")); err != nil {
return err
}
cmd.PersistentFlags().Bool("session.control_protection", false, "users can gain control only if at least one admin is in the room")
if err := viper.BindPFlag("session.control_protection", cmd.PersistentFlags().Lookup("session.control_protection")); err != nil {
return err
}
cmd.PersistentFlags().Bool("session.implicit_hosting", true, "allow implicit control switching")
if err := viper.BindPFlag("session.implicit_hosting", cmd.PersistentFlags().Lookup("session.implicit_hosting")); err != nil {
return err
}
cmd.PersistentFlags().Bool("session.inactive_cursors", false, "show inactive cursors on the screen")
if err := viper.BindPFlag("session.inactive_cursors", cmd.PersistentFlags().Lookup("session.inactive_cursors")); err != nil {
return err
}
cmd.PersistentFlags().Bool("session.merciful_reconnect", true, "allow reconnecting to websocket even if previous connection was not closed")
if err := viper.BindPFlag("session.merciful_reconnect", cmd.PersistentFlags().Lookup("session.merciful_reconnect")); err != nil {
return err
}
cmd.PersistentFlags().String("session.api_token", "", "API token for interacting with external services")
if err := viper.BindPFlag("session.api_token", cmd.PersistentFlags().Lookup("session.api_token")); err != nil {
return err
}
// cookie
cmd.PersistentFlags().Bool("session.cookie.enabled", true, "whether cookies authentication should be enabled")
if err := viper.BindPFlag("session.cookie.enabled", cmd.PersistentFlags().Lookup("session.cookie.enabled")); err != nil {
return err
}
cmd.PersistentFlags().String("session.cookie.name", "NEKO_SESSION", "name of the cookie that holds token")
if err := viper.BindPFlag("session.cookie.name", cmd.PersistentFlags().Lookup("session.cookie.name")); err != nil {
return err
}
cmd.PersistentFlags().Int("session.cookie.expiration", 365*24, "expiration of the cookie in hours")
if err := viper.BindPFlag("session.cookie.expiration", cmd.PersistentFlags().Lookup("session.cookie.expiration")); err != nil {
return err
}
cmd.PersistentFlags().Bool("session.cookie.secure", true, "use secure cookies")
if err := viper.BindPFlag("session.cookie.secure", cmd.PersistentFlags().Lookup("session.cookie.secure")); err != nil {
return err
}
return nil
}
func (s *Session) Set() {
s.File = viper.GetString("session.file")
s.PrivateMode = viper.GetBool("session.private_mode")
s.LockedLogins = viper.GetBool("session.locked_logins")
s.LockedControls = viper.GetBool("session.locked_controls")
s.ControlProtection = viper.GetBool("session.control_protection")
s.ImplicitHosting = viper.GetBool("session.implicit_hosting")
s.InactiveCursors = viper.GetBool("session.inactive_cursors")
s.MercifulReconnect = viper.GetBool("session.merciful_reconnect")
s.APIToken = viper.GetString("session.api_token")
s.CookieEnabled = viper.GetBool("session.cookie.enabled")
s.CookieName = viper.GetString("session.cookie.name")
s.CookieExpiration = time.Duration(viper.GetInt("session.cookie.expiration")) * time.Hour
s.CookieSecure = viper.GetBool("session.cookie.secure")
}

View File

@ -0,0 +1,273 @@
package config
import (
"strconv"
"strings"
"time"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
// default stun server
const defStunSrv = "stun:stun.l.google.com:19302"
type WebRTCEstimator struct {
Enabled bool
Passive bool
Debug bool
InitialBitrate int
// how often to read and process bandwidth estimation reports
ReadInterval time.Duration
// how long to wait for stable connection (only neutral or upward trend) before upgrading
StableDuration time.Duration
// how long to wait for unstable connection (downward trend) before downgrading
UnstableDuration time.Duration
// how long to wait for stalled connection (neutral trend with low bandwidth) before downgrading
StalledDuration time.Duration
// how long to wait before downgrading again after previous downgrade
DowngradeBackoff time.Duration
// how long to wait before upgrading again after previous upgrade
UpgradeBackoff time.Duration
// how bigger the difference between estimated and stream bitrate must be to trigger upgrade/downgrade
DiffThreshold float64
}
type WebRTC struct {
ICELite bool
ICETrickle bool
ICEServersFrontend []types.ICEServer
ICEServersBackend []types.ICEServer
EphemeralMin uint16
EphemeralMax uint16
TCPMux int
UDPMux int
NAT1To1IPs []string
IpRetrievalUrl string
Estimator WebRTCEstimator
}
func (WebRTC) Init(cmd *cobra.Command) error {
cmd.PersistentFlags().Bool("webrtc.icelite", false, "configures whether or not the ICE agent should be a lite agent")
if err := viper.BindPFlag("webrtc.icelite", cmd.PersistentFlags().Lookup("webrtc.icelite")); err != nil {
return err
}
cmd.PersistentFlags().Bool("webrtc.icetrickle", true, "configures whether cadidates should be sent asynchronously using Trickle ICE")
if err := viper.BindPFlag("webrtc.icetrickle", cmd.PersistentFlags().Lookup("webrtc.icetrickle")); err != nil {
return err
}
// Looks like this is conflicting with the frontend and backend ICE servers since latest versions
//cmd.PersistentFlags().String("webrtc.iceservers", "[]", "Global STUN and TURN servers in JSON format with `urls`, `username` and `credential` keys")
//if err := viper.BindPFlag("webrtc.iceservers", cmd.PersistentFlags().Lookup("webrtc.iceservers")); err != nil {
// return err
//}
cmd.PersistentFlags().String("webrtc.iceservers.frontend", "[]", "Frontend only STUN and TURN servers in JSON format with `urls`, `username` and `credential` keys")
if err := viper.BindPFlag("webrtc.iceservers.frontend", cmd.PersistentFlags().Lookup("webrtc.iceservers.frontend")); err != nil {
return err
}
cmd.PersistentFlags().String("webrtc.iceservers.backend", "[]", "Backend only STUN and TURN servers in JSON format with `urls`, `username` and `credential` keys")
if err := viper.BindPFlag("webrtc.iceservers.backend", cmd.PersistentFlags().Lookup("webrtc.iceservers.backend")); err != nil {
return err
}
cmd.PersistentFlags().String("webrtc.epr", "", "limits the pool of ephemeral ports that ICE UDP connections can allocate from")
if err := viper.BindPFlag("webrtc.epr", cmd.PersistentFlags().Lookup("webrtc.epr")); err != nil {
return err
}
cmd.PersistentFlags().Int("webrtc.tcpmux", 0, "single TCP mux port for all peers")
if err := viper.BindPFlag("webrtc.tcpmux", cmd.PersistentFlags().Lookup("webrtc.tcpmux")); err != nil {
return err
}
cmd.PersistentFlags().Int("webrtc.udpmux", 0, "single UDP mux port for all peers, replaces EPR")
if err := viper.BindPFlag("webrtc.udpmux", cmd.PersistentFlags().Lookup("webrtc.udpmux")); err != nil {
return err
}
cmd.PersistentFlags().StringSlice("webrtc.nat1to1", []string{}, "sets a list of external IP addresses of 1:1 (D)NAT and a candidate type for which the external IP address is used")
if err := viper.BindPFlag("webrtc.nat1to1", cmd.PersistentFlags().Lookup("webrtc.nat1to1")); err != nil {
return err
}
cmd.PersistentFlags().String("webrtc.ip_retrieval_url", "https://checkip.amazonaws.com", "URL address used for retrieval of the external IP address")
if err := viper.BindPFlag("webrtc.ip_retrieval_url", cmd.PersistentFlags().Lookup("webrtc.ip_retrieval_url")); err != nil {
return err
}
// bandwidth estimator
cmd.PersistentFlags().Bool("webrtc.estimator.enabled", false, "enables the bandwidth estimator")
if err := viper.BindPFlag("webrtc.estimator.enabled", cmd.PersistentFlags().Lookup("webrtc.estimator.enabled")); err != nil {
return err
}
cmd.PersistentFlags().Bool("webrtc.estimator.passive", false, "passive estimator mode, when it does not switch pipelines, only estimates")
if err := viper.BindPFlag("webrtc.estimator.passive", cmd.PersistentFlags().Lookup("webrtc.estimator.passive")); err != nil {
return err
}
cmd.PersistentFlags().Bool("webrtc.estimator.debug", false, "enables debug logging for the bandwidth estimator")
if err := viper.BindPFlag("webrtc.estimator.debug", cmd.PersistentFlags().Lookup("webrtc.estimator.debug")); err != nil {
return err
}
cmd.PersistentFlags().Int("webrtc.estimator.initial_bitrate", 1_000_000, "initial bitrate for the bandwidth estimator")
if err := viper.BindPFlag("webrtc.estimator.initial_bitrate", cmd.PersistentFlags().Lookup("webrtc.estimator.initial_bitrate")); err != nil {
return err
}
cmd.PersistentFlags().Duration("webrtc.estimator.read_interval", 2*time.Second, "how often to read and process bandwidth estimation reports")
if err := viper.BindPFlag("webrtc.estimator.read_interval", cmd.PersistentFlags().Lookup("webrtc.estimator.read_interval")); err != nil {
return err
}
cmd.PersistentFlags().Duration("webrtc.estimator.stable_duration", 12*time.Second, "how long to wait for stable connection (upward or neutral trend) before upgrading")
if err := viper.BindPFlag("webrtc.estimator.stable_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.stable_duration")); err != nil {
return err
}
cmd.PersistentFlags().Duration("webrtc.estimator.unstable_duration", 6*time.Second, "how long to wait for stalled connection (neutral trend with low bandwidth) before downgrading")
if err := viper.BindPFlag("webrtc.estimator.unstable_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.unstable_duration")); err != nil {
return err
}
cmd.PersistentFlags().Duration("webrtc.estimator.stalled_duration", 24*time.Second, "how long to wait for stalled bandwidth estimation before downgrading")
if err := viper.BindPFlag("webrtc.estimator.stalled_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.stalled_duration")); err != nil {
return err
}
cmd.PersistentFlags().Duration("webrtc.estimator.downgrade_backoff", 10*time.Second, "how long to wait before downgrading again after previous downgrade")
if err := viper.BindPFlag("webrtc.estimator.downgrade_backoff", cmd.PersistentFlags().Lookup("webrtc.estimator.downgrade_backoff")); err != nil {
return err
}
cmd.PersistentFlags().Duration("webrtc.estimator.upgrade_backoff", 5*time.Second, "how long to wait before upgrading again after previous upgrade")
if err := viper.BindPFlag("webrtc.estimator.upgrade_backoff", cmd.PersistentFlags().Lookup("webrtc.estimator.upgrade_backoff")); err != nil {
return err
}
cmd.PersistentFlags().Float64("webrtc.estimator.diff_threshold", 0.15, "how bigger the difference between estimated and stream bitrate must be to trigger upgrade/downgrade")
if err := viper.BindPFlag("webrtc.estimator.diff_threshold", cmd.PersistentFlags().Lookup("webrtc.estimator.diff_threshold")); err != nil {
return err
}
return nil
}
func (s *WebRTC) Set() {
s.ICELite = viper.GetBool("webrtc.icelite")
s.ICETrickle = viper.GetBool("webrtc.icetrickle")
// parse frontend ice servers
if err := viper.UnmarshalKey("webrtc.iceservers.frontend", &s.ICEServersFrontend, viper.DecodeHook(
utils.JsonStringAutoDecode([]types.ICEServer{}),
)); err != nil {
log.Warn().Err(err).Msgf("unable to parse frontend ICE servers")
}
// parse backend ice servers
if err := viper.UnmarshalKey("webrtc.iceservers.backend", &s.ICEServersBackend, viper.DecodeHook(
utils.JsonStringAutoDecode([]types.ICEServer{}),
)); err != nil {
log.Warn().Err(err).Msgf("unable to parse backend ICE servers")
}
if s.ICELite && len(s.ICEServersBackend) > 0 {
log.Warn().Msgf("ICE Lite is enabled, but backend ICE servers are configured. Backend ICE servers will be ignored.")
}
// if no frontend or backend ice servers are configured
if len(s.ICEServersFrontend) == 0 && len(s.ICEServersBackend) == 0 {
// parse global ice servers
var iceServers []types.ICEServer
if err := viper.UnmarshalKey("webrtc.iceservers", &iceServers, viper.DecodeHook(
utils.JsonStringAutoDecode([]types.ICEServer{}),
)); err != nil {
log.Warn().Err(err).Msgf("unable to parse global ICE servers")
}
// add default stun server if none are configured
if len(iceServers) == 0 {
iceServers = append(iceServers, types.ICEServer{
URLs: []string{defStunSrv},
})
}
s.ICEServersFrontend = append(s.ICEServersFrontend, iceServers...)
s.ICEServersBackend = append(s.ICEServersBackend, iceServers...)
}
s.TCPMux = viper.GetInt("webrtc.tcpmux")
s.UDPMux = viper.GetInt("webrtc.udpmux")
epr := viper.GetString("webrtc.epr")
if epr != "" {
ports := strings.SplitN(epr, "-", -1)
if len(ports) > 1 {
min, err := strconv.ParseUint(ports[0], 10, 16)
if err != nil {
log.Panic().Err(err).Msgf("unable to parse ephemeral min port")
}
max, err := strconv.ParseUint(ports[1], 10, 16)
if err != nil {
log.Panic().Err(err).Msgf("unable to parse ephemeral max port")
}
s.EphemeralMin = uint16(min)
s.EphemeralMax = uint16(max)
}
if s.EphemeralMin > s.EphemeralMax {
log.Panic().Msgf("ephemeral min port cannot be bigger than max")
}
}
if epr == "" && s.TCPMux == 0 && s.UDPMux == 0 {
// using default epr range
s.EphemeralMin = 59000
s.EphemeralMax = 59100
log.Warn().
Uint16("min", s.EphemeralMin).
Uint16("max", s.EphemeralMax).
Msgf("no TCP, UDP mux or epr specified, using default epr range")
}
s.NAT1To1IPs = viper.GetStringSlice("webrtc.nat1to1")
s.IpRetrievalUrl = viper.GetString("webrtc.ip_retrieval_url")
if s.IpRetrievalUrl != "" && len(s.NAT1To1IPs) == 0 {
ip, err := utils.HttpRequestGET(s.IpRetrievalUrl)
if err == nil {
s.NAT1To1IPs = append(s.NAT1To1IPs, ip)
} else {
log.Warn().Err(err).Msgf("IP retrieval failed")
}
}
// bandwidth estimator
s.Estimator.Enabled = viper.GetBool("webrtc.estimator.enabled")
s.Estimator.Passive = viper.GetBool("webrtc.estimator.passive")
s.Estimator.Debug = viper.GetBool("webrtc.estimator.debug")
s.Estimator.InitialBitrate = viper.GetInt("webrtc.estimator.initial_bitrate")
s.Estimator.ReadInterval = viper.GetDuration("webrtc.estimator.read_interval")
s.Estimator.StableDuration = viper.GetDuration("webrtc.estimator.stable_duration")
s.Estimator.UnstableDuration = viper.GetDuration("webrtc.estimator.unstable_duration")
s.Estimator.StalledDuration = viper.GetDuration("webrtc.estimator.stalled_duration")
s.Estimator.DowngradeBackoff = viper.GetDuration("webrtc.estimator.downgrade_backoff")
s.Estimator.UpgradeBackoff = viper.GetDuration("webrtc.estimator.upgrade_backoff")
s.Estimator.DiffThreshold = viper.GetFloat64("webrtc.estimator.diff_threshold")
}

View File

@ -0,0 +1,122 @@
package desktop
import (
"bytes"
"fmt"
"os/exec"
"strings"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/xevent"
)
func (manager *DesktopManagerCtx) ClipboardGetText() (*types.ClipboardText, error) {
text, err := manager.ClipboardGetBinary("STRING")
if err != nil {
return nil, err
}
// Rich text must not always be available, can fail silently.
html, _ := manager.ClipboardGetBinary("text/html")
return &types.ClipboardText{
Text: string(text),
HTML: string(html),
}, nil
}
func (manager *DesktopManagerCtx) ClipboardSetText(data types.ClipboardText) error {
// TODO: Refactor.
// Current implementation is unable to set multiple targets. HTML
// is set, if available. Otherwise plain text.
if data.HTML != "" {
return manager.ClipboardSetBinary("text/html", []byte(data.HTML))
}
return manager.ClipboardSetBinary("STRING", []byte(data.Text))
}
func (manager *DesktopManagerCtx) ClipboardGetBinary(mime string) ([]byte, error) {
cmd := exec.Command("xclip", "-selection", "clipboard", "-out", "-target", mime)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
msg := strings.TrimSpace(stderr.String())
return nil, fmt.Errorf("%s", msg)
}
return stdout.Bytes(), nil
}
func (manager *DesktopManagerCtx) ClipboardSetBinary(mime string, data []byte) error {
cmd := exec.Command("xclip", "-selection", "clipboard", "-in", "-target", mime)
var stderr bytes.Buffer
cmd.Stderr = &stderr
stdin, err := cmd.StdinPipe()
if err != nil {
return err
}
// TODO: Refactor.
// We need to wait until the data came to the clipboard.
wait := make(chan struct{})
xevent.Emmiter.Once("clipboard-updated", func(payload ...any) {
wait <- struct{}{}
})
err = cmd.Start()
if err != nil {
msg := strings.TrimSpace(stderr.String())
return fmt.Errorf("%s", msg)
}
_, err = stdin.Write(data)
if err != nil {
return err
}
stdin.Close()
// TODO: Refactor.
// cmd.Wait()
<-wait
return nil
}
func (manager *DesktopManagerCtx) ClipboardGetTargets() ([]string, error) {
cmd := exec.Command("xclip", "-selection", "clipboard", "-out", "-target", "TARGETS")
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
msg := strings.TrimSpace(stderr.String())
return nil, fmt.Errorf("%s", msg)
}
var response []string
targets := strings.Split(stdout.String(), "\n")
for _, target := range targets {
if target == "" {
continue
}
if !strings.Contains(target, "/") {
continue
}
response = append(response, target)
}
return response, nil
}

View File

@ -0,0 +1,68 @@
package desktop
import (
"time"
"github.com/demodesk/neko/pkg/drop"
)
// repeat move event multiple times
const dropMoveRepeat = 4
// wait after each repeated move event
const dropMoveDelay = 100 * time.Millisecond
func (manager *DesktopManagerCtx) DropFiles(x int, y int, files []string) bool {
mu.Lock()
defer mu.Unlock()
drop.Emmiter.Clear()
drop.Emmiter.Once("create", func(payload ...any) {
manager.Move(0, 0)
})
drop.Emmiter.Once("cursor-enter", func(payload ...any) {
//nolint
manager.ButtonDown(1)
})
drop.Emmiter.Once("button-press", func(payload ...any) {
manager.Move(x, y)
})
drop.Emmiter.Once("begin", func(payload ...any) {
for i := 0; i < dropMoveRepeat; i++ {
manager.Move(x, y)
time.Sleep(dropMoveDelay)
}
//nolint
manager.ButtonUp(1)
})
finished := make(chan bool)
drop.Emmiter.Once("finish", func(payload ...any) {
b, ok := payload[0].(bool)
// workaround until https://github.com/kataras/go-events/pull/8 is merged
if !ok {
b = (payload[0].([]any))[0].(bool)
}
finished <- b
})
manager.ResetKeys()
go drop.OpenWindow(files)
select {
case succeeded := <-finished:
return succeeded
case <-time.After(1 * time.Second):
drop.CloseWindow()
return false
}
}
func (manager *DesktopManagerCtx) IsUploadDropEnabled() bool {
return manager.config.UploadDrop
}

View File

@ -0,0 +1,102 @@
package desktop
import (
"errors"
"os/exec"
"github.com/demodesk/neko/pkg/xorg"
)
// name of the window that is being controlled
const fileChooserDialogName = "Open File"
// short sleep value between fake user interactions
const fileChooserDialogShortSleep = "0.2"
// long sleep value between fake user interactions
const fileChooserDialogLongSleep = "0.4"
func (manager *DesktopManagerCtx) HandleFileChooserDialog(uri string) error {
mu.Lock()
defer mu.Unlock()
// TODO: Use native API.
err1 := exec.Command(
"xdotool",
"search", "--name", fileChooserDialogName, "windowfocus",
"sleep", fileChooserDialogShortSleep,
"key", "--clearmodifiers", "ctrl+l",
"type", "--args", "1", uri+"//",
"sleep", fileChooserDialogShortSleep,
"key", "Delete", // remove autocomplete results
"sleep", fileChooserDialogShortSleep,
"key", "Return",
"sleep", fileChooserDialogLongSleep,
"key", "Down",
"key", "--clearmodifiers", "ctrl+a",
"key", "Return",
"sleep", fileChooserDialogLongSleep,
).Run()
if err1 != nil {
return err1
}
// TODO: Use native API.
err2 := exec.Command(
"xdotool",
"search", "--name", fileChooserDialogName,
).Run()
// if last command didn't return error, consider dialog as still open
if err2 == nil {
return errors.New("unable to select files in dialog")
}
return nil
}
func (manager *DesktopManagerCtx) CloseFileChooserDialog() {
for i := 0; i < 5; i++ {
mu.Lock()
manager.logger.Debug().Msg("attempting to close file chooser dialog")
// TODO: Use native API.
err := exec.Command(
"xdotool",
"search", "--name", fileChooserDialogName, "windowfocus",
).Run()
if err != nil {
mu.Unlock()
manager.logger.Info().Msg("file chooser dialog is closed")
return
}
// custom press Alt + F4
// because xdotool is failing to send proper Alt+F4
//nolint
manager.KeyPress(xorg.XK_Alt_L, xorg.XK_F4)
mu.Unlock()
}
}
func (manager *DesktopManagerCtx) IsFileChooserDialogEnabled() bool {
return manager.config.FileChooserDialog
}
func (manager *DesktopManagerCtx) IsFileChooserDialogOpened() bool {
mu.Lock()
defer mu.Unlock()
// TODO: Use native API.
err := exec.Command(
"xdotool",
"search", "--name", fileChooserDialogName,
).Run()
return err == nil
}

View File

@ -0,0 +1,138 @@
package desktop
import (
"sync"
"time"
"github.com/kataras/go-events"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/internal/config"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/xevent"
"github.com/demodesk/neko/pkg/xinput"
"github.com/demodesk/neko/pkg/xorg"
)
var mu = sync.Mutex{}
type DesktopManagerCtx struct {
logger zerolog.Logger
wg sync.WaitGroup
shutdown chan struct{}
emmiter events.EventEmmiter
config *config.Desktop
screenSize types.ScreenSize // cached screen size
input xinput.Driver
}
func New(config *config.Desktop) *DesktopManagerCtx {
var input xinput.Driver
if config.UseInputDriver {
input = xinput.NewDriver(config.InputSocket)
} else {
input = xinput.NewDummy()
}
return &DesktopManagerCtx{
logger: log.With().Str("module", "desktop").Logger(),
shutdown: make(chan struct{}),
emmiter: events.New(),
config: config,
screenSize: config.ScreenSize,
input: input,
}
}
func (manager *DesktopManagerCtx) Start() {
if xorg.DisplayOpen(manager.config.Display) {
manager.logger.Panic().Str("display", manager.config.Display).Msg("unable to open display")
}
// X11 can throw errors below, and the default error handler exits
xevent.SetupErrorHandler()
xorg.GetScreenConfigurations()
screenSize, err := xorg.ChangeScreenSize(manager.config.ScreenSize)
if err != nil {
manager.logger.Err(err).
Str("screen_size", screenSize.String()).
Msgf("unable to set initial screen size")
} else {
// cache screen size
manager.screenSize = screenSize
manager.logger.Info().
Str("screen_size", screenSize.String()).
Msgf("setting initial screen size")
}
err = manager.input.Connect()
if err != nil {
// TODO: fail silently to dummy driver?
manager.logger.Panic().Err(err).Msg("unable to connect to input driver")
}
// set up event listeners
xevent.Unminimize = manager.config.Unminimize
xevent.FileChooserDialog = manager.config.FileChooserDialog
go xevent.EventLoop(manager.config.Display)
// in case it was opened
if manager.config.FileChooserDialog {
go manager.CloseFileChooserDialog()
}
manager.OnEventError(func(error_code uint8, message string, request_code uint8, minor_code uint8) {
manager.logger.Warn().
Uint8("error_code", error_code).
Str("message", message).
Uint8("request_code", request_code).
Uint8("minor_code", minor_code).
Msg("X event error occured")
})
manager.wg.Add(1)
go func() {
defer manager.wg.Done()
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
const debounceDuration = 10 * time.Second
for {
select {
case <-manager.shutdown:
return
case <-ticker.C:
xorg.CheckKeys(debounceDuration)
manager.input.Debounce(debounceDuration)
}
}
}()
}
func (manager *DesktopManagerCtx) OnBeforeScreenSizeChange(listener func()) {
manager.emmiter.On("before_screen_size_change", func(payload ...any) {
listener()
})
}
func (manager *DesktopManagerCtx) OnAfterScreenSizeChange(listener func()) {
manager.emmiter.On("after_screen_size_change", func(payload ...any) {
listener()
})
}
func (manager *DesktopManagerCtx) Shutdown() error {
manager.logger.Info().Msgf("shutdown")
close(manager.shutdown)
manager.wg.Wait()
xorg.DisplayClose()
return nil
}

View File

@ -0,0 +1,35 @@
package desktop
import (
"github.com/demodesk/neko/pkg/xevent"
)
func (manager *DesktopManagerCtx) OnCursorChanged(listener func(serial uint64)) {
xevent.Emmiter.On("cursor-changed", func(payload ...any) {
listener(payload[0].(uint64))
})
}
func (manager *DesktopManagerCtx) OnClipboardUpdated(listener func()) {
xevent.Emmiter.On("clipboard-updated", func(payload ...any) {
listener()
})
}
func (manager *DesktopManagerCtx) OnFileChooserDialogOpened(listener func()) {
xevent.Emmiter.On("file-chooser-dialog-opened", func(payload ...any) {
listener()
})
}
func (manager *DesktopManagerCtx) OnFileChooserDialogClosed(listener func()) {
xevent.Emmiter.On("file-chooser-dialog-closed", func(payload ...any) {
listener()
})
}
func (manager *DesktopManagerCtx) OnEventError(listener func(error_code uint8, message string, request_code uint8, minor_code uint8)) {
xevent.Emmiter.On("event-error", func(payload ...any) {
listener(payload[0].(uint8), payload[1].(string), payload[2].(uint8), payload[3].(uint8))
})
}

View File

@ -0,0 +1,36 @@
package desktop
import "github.com/demodesk/neko/pkg/xinput"
func (manager *DesktopManagerCtx) inputRelToAbs(x, y int) (int, int) {
return (x * xinput.AbsX) / manager.screenSize.Width, (y * xinput.AbsY) / manager.screenSize.Height
}
func (manager *DesktopManagerCtx) HasTouchSupport() bool {
// we assume now, that if the input driver is enabled, we have touch support
return manager.config.UseInputDriver
}
func (manager *DesktopManagerCtx) TouchBegin(touchId uint32, x, y int, pressure uint8) error {
mu.Lock()
defer mu.Unlock()
x, y = manager.inputRelToAbs(x, y)
return manager.input.TouchBegin(touchId, x, y, pressure)
}
func (manager *DesktopManagerCtx) TouchUpdate(touchId uint32, x, y int, pressure uint8) error {
mu.Lock()
defer mu.Unlock()
x, y = manager.inputRelToAbs(x, y)
return manager.input.TouchUpdate(touchId, x, y, pressure)
}
func (manager *DesktopManagerCtx) TouchEnd(touchId uint32, x, y int, pressure uint8) error {
mu.Lock()
defer mu.Unlock()
x, y = manager.inputRelToAbs(x, y)
return manager.input.TouchEnd(touchId, x, y, pressure)
}

View File

@ -0,0 +1,202 @@
package desktop
import (
"image"
"os/exec"
"regexp"
"time"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/xorg"
)
func (manager *DesktopManagerCtx) Move(x, y int) {
xorg.Move(x, y)
}
func (manager *DesktopManagerCtx) GetCursorPosition() (int, int) {
return xorg.GetCursorPosition()
}
func (manager *DesktopManagerCtx) Scroll(deltaX, deltaY int, controlKey bool) {
xorg.Scroll(deltaX, deltaY, controlKey)
}
func (manager *DesktopManagerCtx) ButtonDown(code uint32) error {
return xorg.ButtonDown(code)
}
func (manager *DesktopManagerCtx) KeyDown(code uint32) error {
return xorg.KeyDown(code)
}
func (manager *DesktopManagerCtx) ButtonUp(code uint32) error {
return xorg.ButtonUp(code)
}
func (manager *DesktopManagerCtx) KeyUp(code uint32) error {
return xorg.KeyUp(code)
}
func (manager *DesktopManagerCtx) ButtonPress(code uint32) error {
xorg.ResetKeys()
defer xorg.ResetKeys()
return xorg.ButtonDown(code)
}
func (manager *DesktopManagerCtx) KeyPress(codes ...uint32) error {
xorg.ResetKeys()
defer xorg.ResetKeys()
for _, code := range codes {
if err := xorg.KeyDown(code); err != nil {
return err
}
}
if len(codes) > 1 {
time.Sleep(10 * time.Millisecond)
}
return nil
}
func (manager *DesktopManagerCtx) ResetKeys() {
xorg.ResetKeys()
}
func (manager *DesktopManagerCtx) ScreenConfigurations() []types.ScreenSize {
var configs []types.ScreenSize
for _, size := range xorg.ScreenConfigurations {
for _, fps := range size.Rates {
// filter out all irrelevant rates
if fps > 60 || (fps > 30 && fps%10 != 0) {
continue
}
configs = append(configs, types.ScreenSize{
Width: size.Width,
Height: size.Height,
Rate: fps,
})
}
}
return configs
}
func (manager *DesktopManagerCtx) SetScreenSize(screenSize types.ScreenSize) (types.ScreenSize, error) {
mu.Lock()
manager.emmiter.Emit("before_screen_size_change")
defer func() {
manager.emmiter.Emit("after_screen_size_change")
mu.Unlock()
}()
screenSize, err := xorg.ChangeScreenSize(screenSize)
if err == nil {
// cache the new screen size
manager.screenSize = screenSize
}
return screenSize, err
}
func (manager *DesktopManagerCtx) GetScreenSize() types.ScreenSize {
return xorg.GetScreenSize()
}
func (manager *DesktopManagerCtx) SetKeyboardMap(kbd types.KeyboardMap) error {
// TOOD: Use native API.
cmd := exec.Command("setxkbmap", "-layout", kbd.Layout, "-variant", kbd.Variant)
_, err := cmd.Output()
return err
}
func (manager *DesktopManagerCtx) GetKeyboardMap() (*types.KeyboardMap, error) {
// TOOD: Use native API.
cmd := exec.Command("setxkbmap", "-query")
res, err := cmd.Output()
if err != nil {
return nil, err
}
kbd := types.KeyboardMap{}
re := regexp.MustCompile(`layout:\s+(.*)\n`)
arr := re.FindStringSubmatch(string(res))
if len(arr) > 1 {
kbd.Layout = arr[1]
}
re = regexp.MustCompile(`variant:\s+(.*)\n`)
arr = re.FindStringSubmatch(string(res))
if len(arr) > 1 {
kbd.Variant = arr[1]
}
return &kbd, nil
}
func (manager *DesktopManagerCtx) SetKeyboardModifiers(mod types.KeyboardModifiers) {
if mod.Shift != nil {
xorg.SetKeyboardModifier(xorg.KbdModShift, *mod.Shift)
}
if mod.CapsLock != nil {
xorg.SetKeyboardModifier(xorg.KbdModCapsLock, *mod.CapsLock)
}
if mod.Control != nil {
xorg.SetKeyboardModifier(xorg.KbdModControl, *mod.Control)
}
if mod.Alt != nil {
xorg.SetKeyboardModifier(xorg.KbdModAlt, *mod.Alt)
}
if mod.NumLock != nil {
xorg.SetKeyboardModifier(xorg.KbdModNumLock, *mod.NumLock)
}
if mod.Meta != nil {
xorg.SetKeyboardModifier(xorg.KbdModMeta, *mod.Meta)
}
if mod.Super != nil {
xorg.SetKeyboardModifier(xorg.KbdModSuper, *mod.Super)
}
if mod.AltGr != nil {
xorg.SetKeyboardModifier(xorg.KbdModAltGr, *mod.AltGr)
}
}
func (manager *DesktopManagerCtx) GetKeyboardModifiers() types.KeyboardModifiers {
modifiers := xorg.GetKeyboardModifiers()
isset := func(mod xorg.KbdMod) *bool {
x := modifiers&mod != 0
return &x
}
return types.KeyboardModifiers{
Shift: isset(xorg.KbdModShift),
CapsLock: isset(xorg.KbdModCapsLock),
Control: isset(xorg.KbdModControl),
Alt: isset(xorg.KbdModAlt),
NumLock: isset(xorg.KbdModNumLock),
Meta: isset(xorg.KbdModMeta),
Super: isset(xorg.KbdModSuper),
AltGr: isset(xorg.KbdModAltGr),
}
}
func (manager *DesktopManagerCtx) GetCursorImage() *types.CursorImage {
return xorg.GetCursorImage()
}
func (manager *DesktopManagerCtx) GetScreenshotImage() *image.RGBA {
return xorg.GetScreenshotImage()
}

View File

@ -0,0 +1,123 @@
package http
import (
"bytes"
"encoding/json"
"io"
"net/http"
"strings"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
type BatchRequest struct {
Path string `json:"path"`
Method string `json:"method"`
Body json.RawMessage `json:"body,omitempty"`
}
type BatchResponse struct {
Path string `json:"path"`
Method string `json:"method"`
Body json.RawMessage `json:"body,omitempty"`
Status int `json:"status"`
}
func (b *BatchResponse) Error(httpErr *utils.HTTPError) (err error) {
b.Body, err = json.Marshal(httpErr)
b.Status = httpErr.Code
return
}
type batchHandler struct {
Router types.Router
PathPrefix string
Excluded []string
}
func (b *batchHandler) Handle(w http.ResponseWriter, r *http.Request) error {
var requests []BatchRequest
if err := json.NewDecoder(r.Body).Decode(&requests); err != nil {
return err
}
responses := make([]BatchResponse, len(requests))
for i, request := range requests {
res := BatchResponse{
Path: request.Path,
Method: request.Method,
}
if !strings.HasPrefix(request.Path, b.PathPrefix) {
res.Error(utils.HttpBadRequest("this path is not allowed in batch requests"))
responses[i] = res
continue
}
if exists, _ := utils.ArrayIn(request.Path, b.Excluded); exists {
res.Error(utils.HttpBadRequest("this path is excluded from batch requests"))
responses[i] = res
continue
}
// prepare request
req, err := http.NewRequest(request.Method, request.Path, bytes.NewBuffer(request.Body))
if err != nil {
return err
}
// copy headers
for k, vv := range r.Header {
for _, v := range vv {
req.Header.Add(k, v)
}
}
// execute request
rr := newResponseRecorder()
b.Router.ServeHTTP(rr, req)
// read response
body, err := io.ReadAll(rr.Body)
if err != nil {
return err
}
// write response
responses[i] = BatchResponse{
Path: request.Path,
Method: request.Method,
Body: body,
Status: rr.Code,
}
}
return utils.HttpSuccess(w, responses)
}
type responseRecorder struct {
Code int
HeaderMap http.Header
Body *bytes.Buffer
}
func newResponseRecorder() *responseRecorder {
return &responseRecorder{
Code: http.StatusOK,
HeaderMap: make(http.Header),
Body: new(bytes.Buffer),
}
}
func (w *responseRecorder) Header() http.Header {
return w.HeaderMap
}
func (w *responseRecorder) Write(b []byte) (int, error) {
return w.Body.Write(b)
}
func (w *responseRecorder) WriteHeader(code int) {
w.Code = code
}

View File

@ -0,0 +1,36 @@
package http
import (
"net/http"
"net/http/pprof"
"github.com/go-chi/chi"
"github.com/demodesk/neko/pkg/types"
)
func pprofHandler(r types.Router) {
r.Get("/debug/pprof/", func(w http.ResponseWriter, r *http.Request) error {
pprof.Index(w, r)
return nil
})
r.Get("/debug/pprof/{action}", func(w http.ResponseWriter, r *http.Request) error {
action := chi.URLParam(r, "action")
switch action {
case "cmdline":
pprof.Cmdline(w, r)
case "profile":
pprof.Profile(w, r)
case "symbol":
pprof.Symbol(w, r)
case "trace":
pprof.Trace(w, r)
default:
pprof.Handler(action).ServeHTTP(w, r)
}
return nil
})
}

View File

@ -0,0 +1,135 @@
package http
import (
"fmt"
"net/http"
"time"
"github.com/go-chi/chi/middleware"
"github.com/rs/zerolog"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
type logFormatter struct {
logger zerolog.Logger
}
func (l *logFormatter) NewLogEntry(r *http.Request) middleware.LogEntry {
// exclude health & metrics from logs
if r.RequestURI == "/health" || r.RequestURI == "/metrics" {
return &nulllog{}
}
req := map[string]any{}
if reqID := middleware.GetReqID(r.Context()); reqID != "" {
req["id"] = reqID
}
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
req["scheme"] = scheme
req["proto"] = r.Proto
req["method"] = r.Method
req["remote"] = r.RemoteAddr
req["agent"] = r.UserAgent()
req["uri"] = fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI)
return &logEntry{
logger: l.logger.With().Interface("req", req).Logger(),
}
}
type logEntry struct {
logger zerolog.Logger
err error
panic *logPanic
session types.Session
}
type logPanic struct {
message string
stack string
}
func (e *logEntry) Panic(v any, stack []byte) {
e.panic = &logPanic{
message: fmt.Sprintf("%+v", v),
stack: string(stack),
}
}
func (e *logEntry) Error(err error) {
e.err = err
}
func (e *logEntry) SetSession(session types.Session) {
e.session = session
}
func (e *logEntry) Write(status, bytes int, header http.Header, elapsed time.Duration, extra any) {
res := map[string]any{}
res["time"] = time.Now().UTC().Format(time.RFC1123)
res["status"] = status
res["bytes"] = bytes
res["elapsed"] = float64(elapsed.Nanoseconds()) / 1000000.0
logger := e.logger.With().Interface("res", res).Logger()
// add session ID to logs (if exists)
if e.session != nil {
logger = logger.With().Str("session_id", e.session.ID()).Logger()
}
// handle panic error message
if e.panic != nil {
logger.WithLevel(zerolog.PanicLevel).
Err(e.err).
Str("stack", e.panic.stack).
Msgf("request failed (%d): %s", status, e.panic.message)
return
}
// handle panic error message
if e.err != nil {
httpErr, ok := e.err.(*utils.HTTPError)
if !ok {
logger.Err(e.err).Msgf("request failed (%d)", status)
return
}
if httpErr.Message == "" {
httpErr.Message = http.StatusText(httpErr.Code)
}
var logLevel zerolog.Level
if httpErr.Code < 500 {
logLevel = zerolog.WarnLevel
} else {
logLevel = zerolog.ErrorLevel
}
message := httpErr.Message
if httpErr.InternalMsg != "" {
message = httpErr.InternalMsg
}
logger.WithLevel(logLevel).Err(httpErr.InternalErr).Msgf("request failed (%d): %s", status, message)
return
}
logger.Debug().Msgf("request complete (%d)", status)
}
type nulllog struct{}
func (e *nulllog) Panic(v any, stack []byte) {}
func (e *nulllog) Error(err error) {}
func (e *nulllog) SetSession(session types.Session) {}
func (e *nulllog) Write(status, bytes int, header http.Header, elapsed time.Duration, extra any) {
}

View File

@ -0,0 +1,132 @@
package http
import (
"context"
"net/http"
"os"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/internal/config"
"github.com/demodesk/neko/pkg/types"
)
type HttpManagerCtx struct {
logger zerolog.Logger
config *config.Server
router types.Router
http *http.Server
}
func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, config *config.Server) *HttpManagerCtx {
logger := log.With().Str("module", "http").Logger()
opts := []RouterOption{
WithRequestID(), // create a request id for each request
}
// use real ip if behind proxy
// before logger so it can log the real ip
if config.Proxy {
opts = append(opts, WithRealIP())
}
opts = append(opts,
WithLogger(logger),
WithRecoverer(), // recover from panics without crashing server
)
if config.HasCors() {
opts = append(opts, WithCORS(config.AllowOrigin))
}
if config.PathPrefix != "/" {
opts = append(opts, WithPathPrefix(config.PathPrefix))
}
router := newRouter(opts...)
router.Route("/api", ApiManager.Route)
router.Get("/api/ws", WebSocketManager.Upgrade(func(r *http.Request) bool {
return config.AllowOrigin(r.Header.Get("Origin"))
}))
batch := batchHandler{
Router: router,
PathPrefix: "/api",
Excluded: []string{
"/api/batch", // do not allow batchception
"/api/ws",
},
}
router.Post("/api/batch", batch.Handle)
router.Get("/health", func(w http.ResponseWriter, r *http.Request) error {
_, err := w.Write([]byte("true"))
return err
})
if config.Metrics {
router.Get("/metrics", func(w http.ResponseWriter, r *http.Request) error {
promhttp.Handler().ServeHTTP(w, r)
return nil
})
}
if config.Static != "" {
fs := http.FileServer(http.Dir(config.Static))
router.Get("/*", func(w http.ResponseWriter, r *http.Request) error {
_, err := os.Stat(config.Static + r.URL.Path)
if err == nil {
fs.ServeHTTP(w, r)
return nil
}
if os.IsNotExist(err) {
http.NotFound(w, r)
return nil
}
return err
})
}
if config.PProf {
pprofHandler(router)
}
return &HttpManagerCtx{
logger: logger,
config: config,
router: router,
http: &http.Server{
Addr: config.Bind,
Handler: router,
},
}
}
func (manager *HttpManagerCtx) Start() {
if manager.config.Cert != "" && manager.config.Key != "" {
go func() {
if err := manager.http.ListenAndServeTLS(manager.config.Cert, manager.config.Key); err != http.ErrServerClosed {
manager.logger.Panic().Err(err).Msg("unable to start https server")
}
}()
manager.logger.Info().Msgf("https listening on %s", manager.http.Addr)
} else {
go func() {
if err := manager.http.ListenAndServe(); err != http.ErrServerClosed {
manager.logger.Panic().Err(err).Msg("unable to start http server")
}
}()
manager.logger.Info().Msgf("http listening on %s", manager.http.Addr)
}
}
func (manager *HttpManagerCtx) Shutdown() error {
manager.logger.Info().Msg("shutdown")
return manager.http.Shutdown(context.Background())
}

View File

@ -0,0 +1,172 @@
package http
import (
"net/http"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"github.com/go-chi/cors"
"github.com/rs/zerolog"
"github.com/demodesk/neko/pkg/auth"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
type RouterOption func(*router)
func WithRequestID() RouterOption {
return func(r *router) {
r.chi.Use(middleware.RequestID)
}
}
func WithLogger(logger zerolog.Logger) RouterOption {
return func(r *router) {
r.chi.Use(middleware.RequestLogger(&logFormatter{logger}))
}
}
func WithRecoverer() RouterOption {
return func(r *router) {
r.chi.Use(middleware.Recoverer)
}
}
func WithCORS(allowOrigin func(origin string) bool) RouterOption {
return func(r *router) {
r.chi.Use(cors.Handler(cors.Options{
AllowOriginFunc: func(r *http.Request, origin string) bool {
return allowOrigin(origin)
},
AllowedMethods: []string{"GET", "POST", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
ExposedHeaders: []string{"Link"},
AllowCredentials: true,
MaxAge: 300, // Maximum value not ignored by any of major browsers
}))
}
}
func WithPathPrefix(prefix string) RouterOption {
return func(r *router) {
r.chi.Use(func(h http.Handler) http.Handler {
return http.StripPrefix(prefix, h)
})
}
}
func WithRealIP() RouterOption {
return func(r *router) {
r.chi.Use(middleware.RealIP)
}
}
type router struct {
chi chi.Router
}
func newRouter(opts ...RouterOption) types.Router {
r := &router{chi.NewRouter()}
for _, opt := range opts {
opt(r)
}
return r
}
func (r *router) Group(fn func(types.Router)) {
r.chi.Group(func(c chi.Router) {
fn(&router{c})
})
}
func (r *router) Route(pattern string, fn func(types.Router)) {
r.chi.Route(pattern, func(c chi.Router) {
fn(&router{c})
})
}
func (r *router) Get(pattern string, fn types.RouterHandler) {
r.chi.Get(pattern, routeHandler(fn))
}
func (r *router) Post(pattern string, fn types.RouterHandler) {
r.chi.Post(pattern, routeHandler(fn))
}
func (r *router) Put(pattern string, fn types.RouterHandler) {
r.chi.Put(pattern, routeHandler(fn))
}
func (r *router) Patch(pattern string, fn types.RouterHandler) {
r.chi.Patch(pattern, routeHandler(fn))
}
func (r *router) Delete(pattern string, fn types.RouterHandler) {
r.chi.Delete(pattern, routeHandler(fn))
}
func (r *router) With(fn types.MiddlewareHandler) types.Router {
c := r.chi.With(middlewareHandler(fn))
return &router{c}
}
func (r *router) Use(fn types.MiddlewareHandler) {
r.chi.Use(middlewareHandler(fn))
}
func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
r.chi.ServeHTTP(w, req)
}
func errorHandler(err error, w http.ResponseWriter, r *http.Request) {
httpErr, ok := err.(*utils.HTTPError)
if !ok {
httpErr = utils.HttpInternalServerError().WithInternalErr(err)
}
utils.HttpJsonResponse(w, httpErr.Code, httpErr)
}
func routeHandler(fn types.RouterHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// get custom log entry pointer from context
logEntry, _ := r.Context().Value(middleware.LogEntryCtxKey).(*logEntry)
if err := fn(w, r); err != nil {
logEntry.Error(err)
errorHandler(err, w, r)
}
// set session if exits
if session, ok := auth.GetSession(r); ok {
logEntry.SetSession(session)
}
}
}
func middlewareHandler(fn types.MiddlewareHandler) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// get custom log entry pointer from context
logEntry, _ := r.Context().Value(middleware.LogEntryCtxKey).(*logEntry)
ctx, err := fn(w, r)
if err != nil {
logEntry.Error(err)
errorHandler(err, w, r)
// set session if exits
if session, ok := auth.GetSession(r); ok {
logEntry.SetSession(session)
}
return
}
if ctx != nil {
r = r.WithContext(ctx)
}
next.ServeHTTP(w, r)
})
}
}

View File

@ -0,0 +1,204 @@
package file
import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
"io"
"os"
"github.com/demodesk/neko/pkg/types"
)
func New(config Config) types.MemberProvider {
return &MemberProviderCtx{
config: config,
}
}
type MemberProviderCtx struct {
config Config
}
func (provider *MemberProviderCtx) hash(password string) string {
// if hash is disabled, return password as plain text
if !provider.config.Hash {
return password
}
sha256 := sha256.New()
sha256.Write([]byte(password))
hashedPassword := sha256.Sum(nil)
return base64.StdEncoding.EncodeToString(hashedPassword)
}
func (provider *MemberProviderCtx) Connect() error {
return nil
}
func (provider *MemberProviderCtx) Disconnect() error {
return nil
}
func (provider *MemberProviderCtx) Authenticate(username string, password string) (string, types.MemberProfile, error) {
// id will be also username
id := username
entry, err := provider.getEntry(id)
if err != nil {
return "", types.MemberProfile{}, err
}
if entry.Password != provider.hash(password) {
return "", types.MemberProfile{}, types.ErrMemberInvalidPassword
}
return id, entry.Profile, nil
}
func (provider *MemberProviderCtx) Insert(username string, password string, profile types.MemberProfile) (string, error) {
// id will be also username
id := username
entries, err := provider.deserialize()
if err != nil {
return "", err
}
_, ok := entries[id]
if ok {
return "", types.ErrMemberAlreadyExists
}
entries[id] = MemberEntry{
Password: provider.hash(password),
Profile: profile,
}
return id, provider.serialize(entries)
}
func (provider *MemberProviderCtx) UpdateProfile(id string, profile types.MemberProfile) error {
entries, err := provider.deserialize()
if err != nil {
return err
}
entry, ok := entries[id]
if !ok {
return types.ErrMemberDoesNotExist
}
entry.Profile = profile
entries[id] = entry
return provider.serialize(entries)
}
func (provider *MemberProviderCtx) UpdatePassword(id string, password string) error {
entries, err := provider.deserialize()
if err != nil {
return err
}
entry, ok := entries[id]
if !ok {
return types.ErrMemberDoesNotExist
}
entry.Password = provider.hash(password)
entries[id] = entry
return provider.serialize(entries)
}
func (provider *MemberProviderCtx) Select(id string) (types.MemberProfile, error) {
entry, err := provider.getEntry(id)
if err != nil {
return types.MemberProfile{}, err
}
return entry.Profile, nil
}
func (provider *MemberProviderCtx) SelectAll(limit int, offset int) (map[string]types.MemberProfile, error) {
profiles := map[string]types.MemberProfile{}
entries, err := provider.deserialize()
if err != nil {
return profiles, err
}
i := 0
for id, entry := range entries {
if i >= offset && (limit == 0 || i < offset+limit) {
profiles[id] = entry.Profile
}
i = i + 1
}
return profiles, nil
}
func (provider *MemberProviderCtx) Delete(id string) error {
entries, err := provider.deserialize()
if err != nil {
return err
}
_, ok := entries[id]
if !ok {
return types.ErrMemberDoesNotExist
}
delete(entries, id)
return provider.serialize(entries)
}
func (provider *MemberProviderCtx) deserialize() (map[string]MemberEntry, error) {
file, err := os.OpenFile(provider.config.Path, os.O_RDONLY|os.O_CREATE, os.ModePerm)
if err != nil {
return nil, err
}
raw, err := io.ReadAll(file)
if err != nil {
return nil, err
}
if len(raw) == 0 {
return map[string]MemberEntry{}, nil
}
var entries map[string]MemberEntry
if err := json.Unmarshal([]byte(raw), &entries); err != nil {
return nil, err
}
return entries, nil
}
func (provider *MemberProviderCtx) getEntry(id string) (MemberEntry, error) {
entries, err := provider.deserialize()
if err != nil {
return MemberEntry{}, err
}
entry, ok := entries[id]
if !ok {
return MemberEntry{}, types.ErrMemberDoesNotExist
}
return entry, nil
}
func (provider *MemberProviderCtx) serialize(data map[string]MemberEntry) error {
raw, err := json.Marshal(data)
if err != nil {
return err
}
return os.WriteFile(provider.config.Path, raw, os.ModePerm)
}

View File

@ -0,0 +1,48 @@
package file
import (
"encoding/json"
"testing"
"github.com/demodesk/neko/pkg/utils"
)
// Ensure that hashes are the same after encoding and decoding using json
func TestMemberProviderCtx_hash(t *testing.T) {
provider := &MemberProviderCtx{
config: Config{
Hash: true,
},
}
// generate random strings
passwords := []string{}
for i := 0; i < 10; i++ {
password, err := utils.NewUID(32)
if err != nil {
t.Errorf("utils.NewUID() returned error: %s", err)
}
passwords = append(passwords, password)
}
for _, password := range passwords {
hashedPassword := provider.hash(password)
// json encode password hash
hashedPasswordJSON, err := json.Marshal(hashedPassword)
if err != nil {
t.Errorf("json.Marshal() returned error: %s", err)
}
// json decode password hash json
var hashedPasswordStr string
err = json.Unmarshal(hashedPasswordJSON, &hashedPasswordStr)
if err != nil {
t.Errorf("json.Unmarshal() returned error: %s", err)
}
if hashedPasswordStr != hashedPassword {
t.Errorf("hashedPasswordStr: %s != hashedPassword: %s", hashedPasswordStr, hashedPassword)
}
}
}

View File

@ -0,0 +1,15 @@
package file
import (
"github.com/demodesk/neko/pkg/types"
)
type MemberEntry struct {
Password string `json:"password"`
Profile types.MemberProfile `json:"profile"`
}
type Config struct {
Path string
Hash bool
}

View File

@ -0,0 +1,168 @@
package member
import (
"errors"
"sync"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/internal/config"
"github.com/demodesk/neko/internal/member/file"
"github.com/demodesk/neko/internal/member/multiuser"
"github.com/demodesk/neko/internal/member/noauth"
"github.com/demodesk/neko/internal/member/object"
"github.com/demodesk/neko/pkg/types"
)
func New(sessions types.SessionManager, config *config.Member) *MemberManagerCtx {
manager := &MemberManagerCtx{
logger: log.With().Str("module", "member").Logger(),
sessions: sessions,
config: config,
}
switch config.Provider {
case "file":
manager.provider = file.New(config.File)
case "object":
manager.provider = object.New(config.Object)
case "multiuser":
manager.provider = multiuser.New(config.Multiuser)
case "noauth":
fallthrough
default:
manager.provider = noauth.New()
}
return manager
}
type MemberManagerCtx struct {
logger zerolog.Logger
sessions types.SessionManager
config *config.Member
providerMu sync.Mutex
provider types.MemberProvider
loginMu sync.Mutex
}
func (manager *MemberManagerCtx) Connect() error {
manager.providerMu.Lock()
defer manager.providerMu.Unlock()
return manager.provider.Connect()
}
func (manager *MemberManagerCtx) Disconnect() error {
manager.providerMu.Lock()
defer manager.providerMu.Unlock()
return manager.provider.Disconnect()
}
func (manager *MemberManagerCtx) Authenticate(username string, password string) (string, types.MemberProfile, error) {
manager.providerMu.Lock()
defer manager.providerMu.Unlock()
return manager.provider.Authenticate(username, password)
}
func (manager *MemberManagerCtx) Insert(username string, password string, profile types.MemberProfile) (string, error) {
manager.providerMu.Lock()
defer manager.providerMu.Unlock()
return manager.provider.Insert(username, password, profile)
}
func (manager *MemberManagerCtx) Select(id string) (types.MemberProfile, error) {
manager.providerMu.Lock()
defer manager.providerMu.Unlock()
// get primarily from corresponding session, if exists
session, ok := manager.sessions.Get(id)
if ok {
return session.Profile(), nil
}
return manager.provider.Select(id)
}
func (manager *MemberManagerCtx) SelectAll(limit int, offset int) (map[string]types.MemberProfile, error) {
manager.providerMu.Lock()
defer manager.providerMu.Unlock()
return manager.provider.SelectAll(limit, offset)
}
func (manager *MemberManagerCtx) UpdateProfile(id string, profile types.MemberProfile) error {
manager.providerMu.Lock()
defer manager.providerMu.Unlock()
// update corresponding session, if exists
err := manager.sessions.Update(id, profile)
if err != nil && !errors.Is(err, types.ErrSessionNotFound) {
manager.logger.Err(err).Msg("error while updating session")
}
return manager.provider.UpdateProfile(id, profile)
}
func (manager *MemberManagerCtx) UpdatePassword(id string, password string) error {
manager.providerMu.Lock()
defer manager.providerMu.Unlock()
return manager.provider.UpdatePassword(id, password)
}
func (manager *MemberManagerCtx) Delete(id string) error {
manager.providerMu.Lock()
defer manager.providerMu.Unlock()
// destroy corresponding session, if exists
err := manager.sessions.Delete(id)
if err != nil && !errors.Is(err, types.ErrSessionNotFound) {
manager.logger.Err(err).Msg("error while deleting session")
}
return manager.provider.Delete(id)
}
//
// member -> session
//
func (manager *MemberManagerCtx) Login(username string, password string) (types.Session, string, error) {
manager.loginMu.Lock()
defer manager.loginMu.Unlock()
id, profile, err := manager.provider.Authenticate(username, password)
if err != nil {
return nil, "", err
}
if !profile.IsAdmin && manager.sessions.Settings().LockedLogins {
return nil, "", types.ErrSessionLoginsLocked
}
session, ok := manager.sessions.Get(id)
if ok {
if session.State().IsConnected {
return nil, "", types.ErrSessionAlreadyConnected
}
// TODO: Replace session.
if err := manager.sessions.Delete(id); err != nil {
return nil, "", err
}
}
return manager.sessions.Create(id, profile)
}
func (manager *MemberManagerCtx) Logout(id string) error {
manager.loginMu.Lock()
defer manager.loginMu.Unlock()
return manager.sessions.Delete(id)
}

View File

@ -0,0 +1,82 @@
package multiuser
import (
"errors"
"fmt"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
func New(config Config) types.MemberProvider {
return &MemberProviderCtx{
config: config,
}
}
type MemberProviderCtx struct {
config Config
}
func (provider *MemberProviderCtx) Connect() error {
return nil
}
func (provider *MemberProviderCtx) Disconnect() error {
return nil
}
func (provider *MemberProviderCtx) Authenticate(username string, password string) (string, types.MemberProfile, error) {
// generate random token
token, err := utils.NewUID(5)
if err != nil {
return "", types.MemberProfile{}, err
}
// id is username with token
id := fmt.Sprintf("%s-%s", username, token)
// if logged in as administrator
if provider.config.AdminPassword == password {
profile := provider.config.AdminProfile
if profile.Name == "" {
profile.Name = username
}
return id, profile, nil
}
// if logged in as user
if provider.config.UserPassword == password {
profile := provider.config.UserProfile
if profile.Name == "" {
profile.Name = username
}
return id, profile, nil
}
return "", types.MemberProfile{}, types.ErrMemberInvalidPassword
}
func (provider *MemberProviderCtx) Insert(username string, password string, profile types.MemberProfile) (string, error) {
return "", errors.New("new user is created on first login in multiuser mode")
}
func (provider *MemberProviderCtx) UpdateProfile(id string, profile types.MemberProfile) error {
return nil
}
func (provider *MemberProviderCtx) UpdatePassword(id string, password string) error {
return errors.New("password can only be modified in config while in multiuser mode")
}
func (provider *MemberProviderCtx) Select(id string) (types.MemberProfile, error) {
return types.MemberProfile{}, errors.New("cannot select user in multiuser mode")
}
func (provider *MemberProviderCtx) SelectAll(limit int, offset int) (map[string]types.MemberProfile, error) {
return map[string]types.MemberProfile{}, nil
}
func (provider *MemberProviderCtx) Delete(id string) error {
return errors.New("cannot delete user in multiuser mode")
}

View File

@ -0,0 +1,10 @@
package multiuser
import "github.com/demodesk/neko/pkg/types"
type Config struct {
AdminPassword string
UserPassword string
AdminProfile types.MemberProfile
UserProfile types.MemberProfile
}

View File

@ -0,0 +1,75 @@
package noauth
import (
"errors"
"fmt"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
func New() types.MemberProvider {
return &MemberProviderCtx{
profile: types.MemberProfile{
IsAdmin: true,
CanLogin: true,
CanConnect: true,
CanWatch: true,
CanHost: true,
CanShareMedia: true,
CanAccessClipboard: true,
SendsInactiveCursor: true,
CanSeeInactiveCursors: true,
},
}
}
type MemberProviderCtx struct {
profile types.MemberProfile
}
func (provider *MemberProviderCtx) Connect() error {
return nil
}
func (provider *MemberProviderCtx) Disconnect() error {
return nil
}
func (provider *MemberProviderCtx) Authenticate(username string, password string) (string, types.MemberProfile, error) {
// generate random token
token, err := utils.NewUID(5)
if err != nil {
return "", types.MemberProfile{}, err
}
// id is username with token
id := fmt.Sprintf("%s-%s", username, token)
provider.profile.Name = username
return id, provider.profile, nil
}
func (provider *MemberProviderCtx) Insert(username string, password string, profile types.MemberProfile) (string, error) {
return "", errors.New("new user is created on first login in noauth mode")
}
func (provider *MemberProviderCtx) UpdateProfile(id string, profile types.MemberProfile) error {
return nil
}
func (provider *MemberProviderCtx) UpdatePassword(id string, password string) error {
return errors.New("noauth mode does not have password")
}
func (provider *MemberProviderCtx) Select(id string) (types.MemberProfile, error) {
return types.MemberProfile{}, errors.New("cannot select user in noauth mode")
}
func (provider *MemberProviderCtx) SelectAll(limit int, offset int) (map[string]types.MemberProfile, error) {
return map[string]types.MemberProfile{}, nil
}
func (provider *MemberProviderCtx) Delete(id string) error {
return errors.New("cannot delete user in noauth mode")
}

View File

@ -0,0 +1,124 @@
package object
import (
"github.com/demodesk/neko/pkg/types"
)
func New(config Config) types.MemberProvider {
return &MemberProviderCtx{
config: config,
entries: make(map[string]*memberEntry),
}
}
type MemberProviderCtx struct {
config Config
entries map[string]*memberEntry
}
func (provider *MemberProviderCtx) Connect() error {
var err error
for _, entry := range provider.config.Users {
_, err = provider.Insert(entry.Username, entry.Password, entry.Profile)
}
return err
}
func (provider *MemberProviderCtx) Disconnect() error {
return nil
}
func (provider *MemberProviderCtx) Authenticate(username string, password string) (string, types.MemberProfile, error) {
// id will be also username
id := username
entry, ok := provider.entries[id]
if !ok {
return "", types.MemberProfile{}, types.ErrMemberDoesNotExist
}
// TODO: Use hash function.
if !entry.CheckPassword(password) {
return "", types.MemberProfile{}, types.ErrMemberInvalidPassword
}
return id, entry.profile, nil
}
func (provider *MemberProviderCtx) Insert(username string, password string, profile types.MemberProfile) (string, error) {
// id will be also username
id := username
_, ok := provider.entries[id]
if ok {
return "", types.ErrMemberAlreadyExists
}
provider.entries[id] = &memberEntry{
// TODO: Use hash function.
password: password,
profile: profile,
}
return id, nil
}
func (provider *MemberProviderCtx) UpdateProfile(id string, profile types.MemberProfile) error {
entry, ok := provider.entries[id]
if !ok {
return types.ErrMemberDoesNotExist
}
entry.profile = profile
return nil
}
func (provider *MemberProviderCtx) UpdatePassword(id string, password string) error {
entry, ok := provider.entries[id]
if !ok {
return types.ErrMemberDoesNotExist
}
// TODO: Use hash function.
entry.password = password
return nil
}
func (provider *MemberProviderCtx) Select(id string) (types.MemberProfile, error) {
entry, ok := provider.entries[id]
if !ok {
return types.MemberProfile{}, types.ErrMemberDoesNotExist
}
return entry.profile, nil
}
func (provider *MemberProviderCtx) SelectAll(limit int, offset int) (map[string]types.MemberProfile, error) {
profiles := make(map[string]types.MemberProfile)
i := 0
for id, entry := range provider.entries {
if i >= offset && (limit == 0 || i < offset+limit) {
profiles[id] = entry.profile
}
i = i + 1
}
return profiles, nil
}
func (provider *MemberProviderCtx) Delete(id string) error {
_, ok := provider.entries[id]
if !ok {
return types.ErrMemberDoesNotExist
}
delete(provider.entries, id)
return nil
}

View File

@ -0,0 +1,24 @@
package object
import (
"github.com/demodesk/neko/pkg/types"
)
type memberEntry struct {
password string
profile types.MemberProfile
}
func (m *memberEntry) CheckPassword(password string) bool {
return m.password == password
}
type User struct {
Username string
Password string
Profile types.MemberProfile
}
type Config struct {
Users []User
}

View File

@ -0,0 +1,23 @@
package chat
import (
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
type Config struct {
Enabled bool
}
func (Config) Init(cmd *cobra.Command) error {
cmd.PersistentFlags().Bool("chat.enabled", true, "whether to enable chat plugin")
if err := viper.BindPFlag("chat.enabled", cmd.PersistentFlags().Lookup("chat.enabled")); err != nil {
return err
}
return nil
}
func (s *Config) Set() {
s.Enabled = viper.GetBool("chat.enabled")
}

View File

@ -0,0 +1,162 @@
package chat
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/pkg/auth"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
func NewManager(
sessions types.SessionManager,
config *Config,
) *Manager {
logger := log.With().Str("module", "chat").Logger()
return &Manager{
logger: logger,
config: config,
sessions: sessions,
}
}
type Manager struct {
logger zerolog.Logger
config *Config
sessions types.SessionManager
}
type Settings struct {
CanSend bool `json:"can_send" mapstructure:"can_send"`
CanReceive bool `json:"can_receive" mapstructure:"can_receive"`
}
func (m *Manager) settingsForSession(session types.Session) (Settings, error) {
settings := Settings{
CanSend: true, // defaults to true
CanReceive: true, // defaults to true
}
err := m.sessions.Settings().Plugins.Unmarshal(PluginName, &settings)
if err != nil && !errors.Is(err, types.ErrPluginSettingsNotFound) {
return Settings{}, fmt.Errorf("unable to unmarshal %s plugin settings from global settings: %w", PluginName, err)
}
profile := Settings{
CanSend: true, // defaults to true
CanReceive: true, // defaults to true
}
err = session.Profile().Plugins.Unmarshal(PluginName, &profile)
if err != nil && !errors.Is(err, types.ErrPluginSettingsNotFound) {
return Settings{}, fmt.Errorf("unable to unmarshal %s plugin settings from profile: %w", PluginName, err)
}
return Settings{
CanSend: m.config.Enabled && (settings.CanSend || session.Profile().IsAdmin) && profile.CanSend,
CanReceive: m.config.Enabled && (settings.CanReceive || session.Profile().IsAdmin) && profile.CanReceive,
}, nil
}
func (m *Manager) sendMessage(session types.Session, content Content) {
now := time.Now()
// get all sessions that have chat enabled
var sessions []types.Session
m.sessions.Range(func(s types.Session) bool {
if settings, err := m.settingsForSession(s); err == nil && settings.CanReceive {
sessions = append(sessions, s)
}
// continue iteration over all sessions
return true
})
// send content to all sessions
for _, s := range sessions {
s.Send(CHAT_MESSAGE, Message{
ID: session.ID(),
Created: now,
Content: content,
})
}
}
func (m *Manager) Start() error {
// send init message once a user connects
m.sessions.OnConnected(func(session types.Session) {
session.Send(CHAT_INIT, Init{
Enabled: m.config.Enabled,
})
})
return nil
}
func (m *Manager) Shutdown() error {
return nil
}
func (m *Manager) Route(r types.Router) {
r.With(auth.AdminsOnly).Post("/", m.sendMessageHandler)
}
func (m *Manager) WebSocketHandler(session types.Session, msg types.WebSocketMessage) bool {
switch msg.Event {
case CHAT_MESSAGE:
var content Content
if err := json.Unmarshal(msg.Payload, &content); err != nil {
m.logger.Error().Err(err).Msg("failed to unmarshal chat message")
// we processed the message, return true
return true
}
settings, err := m.settingsForSession(session)
if err != nil {
m.logger.Error().Err(err).Msg("error checking chat permissions for this session")
// we processed the message, return true
return true
}
if !settings.CanSend {
m.logger.Warn().Msg("not allowed to send chat messages")
// we processed the message, return true
return true
}
m.sendMessage(session, content)
return true
}
return false
}
func (m *Manager) sendMessageHandler(w http.ResponseWriter, r *http.Request) error {
session, ok := auth.GetSession(r)
if !ok {
return utils.HttpUnauthorized("session not found")
}
settings, err := m.settingsForSession(session)
if err != nil {
return utils.HttpInternalServerError().
WithInternalErr(err).
Msg("error checking chat permissions for this session")
}
if !settings.CanSend {
return utils.HttpForbidden("not allowed to send chat messages")
}
content := Content{}
if err := utils.HttpJsonRequest(w, r, &content); err != nil {
return err
}
m.sendMessage(session, content)
return utils.HttpSuccess(w)
}

View File

@ -0,0 +1,35 @@
package chat
import (
"github.com/demodesk/neko/pkg/types"
)
type Plugin struct {
config *Config
manager *Manager
}
func NewPlugin() *Plugin {
return &Plugin{
config: &Config{},
}
}
func (p *Plugin) Name() string {
return PluginName
}
func (p *Plugin) Config() types.PluginConfig {
return p.config
}
func (p *Plugin) Start(m types.PluginManagers) error {
p.manager = NewManager(m.SessionManager, p.config)
m.ApiManager.AddRouter("/chat", p.manager.Route)
m.WebSocketManager.AddHandler(p.manager.WebSocketHandler)
return p.manager.Start()
}
func (p *Plugin) Shutdown() error {
return p.manager.Shutdown()
}

View File

@ -0,0 +1,24 @@
package chat
import "time"
const PluginName = "chat"
const (
CHAT_INIT = "chat/init"
CHAT_MESSAGE = "chat/message"
)
type Init struct {
Enabled bool `json:"enabled"`
}
type Content struct {
Text string `json:"text"`
}
type Message struct {
ID string `json:"id"`
Created time.Time `json:"created"`
Content Content `json:"content"`
}

View File

@ -0,0 +1,133 @@
package plugins
import (
"fmt"
"github.com/rs/zerolog"
"github.com/demodesk/neko/pkg/types"
)
type dependency struct {
plugin types.Plugin
dependsOn []*dependency
invoked bool
logger zerolog.Logger
}
func (a *dependency) findPlugin(name string) (*dependency, bool) {
if a == nil {
return nil, false
}
if a.plugin.Name() == name {
return a, true
}
for _, dep := range a.dependsOn {
plug, ok := dep.findPlugin(name)
if ok {
return plug, true
}
}
return nil, false
}
func (a *dependency) startPlugin(pm types.PluginManagers) error {
if a.invoked {
return nil
}
a.invoked = true
for _, do := range a.dependsOn {
if err := do.startPlugin(pm); err != nil {
return fmt.Errorf("plugin's '%s' dependency: %w", a.plugin.Name(), err)
}
}
err := a.plugin.Start(pm)
if err != nil {
return fmt.Errorf("plugin '%s' failed to start: %w", a.plugin.Name(), err)
}
a.logger.Info().Str("plugin", a.plugin.Name()).Msg("plugin started")
return nil
}
type dependiencies struct {
deps map[string]*dependency
logger zerolog.Logger
}
func (d *dependiencies) addPlugin(plugin types.Plugin) error {
pluginName := plugin.Name()
plug, ok := d.deps[pluginName]
if !ok {
plug = &dependency{}
} else if plug.plugin != nil {
return fmt.Errorf("plugin '%s' already added", pluginName)
}
plug.plugin = plugin
plug.logger = d.logger
d.deps[pluginName] = plug
dplug, ok := plugin.(types.DependablePlugin)
if !ok {
return nil
}
for _, depName := range dplug.DependsOn() {
dependsOn, ok := d.deps[depName]
if !ok {
dependsOn = &dependency{}
} else if dependsOn.plugin != nil {
// if there is a cyclical dependency, break it and return error
if tdep, ok := dependsOn.findPlugin(pluginName); ok {
dependsOn.dependsOn = nil
delete(d.deps, pluginName)
return fmt.Errorf("cyclical dependency detected: '%s' <-> '%s'", pluginName, tdep.plugin.Name())
}
}
plug.dependsOn = append(plug.dependsOn, dependsOn)
d.deps[depName] = dependsOn
}
return nil
}
func (d *dependiencies) findPlugin(name string) (*dependency, bool) {
for _, dep := range d.deps {
plug, ok := dep.findPlugin(name)
if ok {
return plug, true
}
}
return nil, false
}
func (d *dependiencies) start(pm types.PluginManagers) error {
for _, dep := range d.deps {
if err := dep.startPlugin(pm); err != nil {
return err
}
}
return nil
}
func (d *dependiencies) forEach(f func(*dependency) error) error {
for _, dep := range d.deps {
if err := f(dep); err != nil {
return err
}
}
return nil
}
func (d *dependiencies) len() int {
return len(d.deps)
}

View File

@ -0,0 +1,630 @@
package plugins
import (
"reflect"
"testing"
"github.com/demodesk/neko/pkg/types"
)
func Test_deps_addPlugin(t *testing.T) {
type args struct {
p []types.Plugin
}
tests := []struct {
name string
args args
want map[string]*dependency
skipRun bool
wantErr1 bool
wantErr2 bool
}{
{
name: "three plugins - no dependencies",
args: args{
p: []types.Plugin{
&dummyPlugin{name: "first"},
&dummyPlugin{name: "second"},
&dummyPlugin{name: "third"},
},
},
want: map[string]*dependency{
"first": {
plugin: &dummyPlugin{name: "first", idx: 0},
invoked: true,
dependsOn: nil,
},
"second": {
plugin: &dummyPlugin{name: "second", idx: 0},
invoked: true,
dependsOn: nil,
},
"third": {
plugin: &dummyPlugin{name: "third", idx: 0},
invoked: true,
dependsOn: nil,
},
},
}, {
name: "three plugins - one dependency",
args: args{
p: []types.Plugin{
&dummyPlugin{name: "third", dep: []string{"second"}},
&dummyPlugin{name: "first"},
&dummyPlugin{name: "second"},
},
},
want: map[string]*dependency{
"first": {
plugin: &dummyPlugin{name: "first", idx: 0},
invoked: true,
dependsOn: nil,
},
"second": {
plugin: &dummyPlugin{name: "second", idx: 0},
invoked: true,
dependsOn: nil,
},
"third": {
plugin: &dummyPlugin{name: "third", dep: []string{"second"}, idx: 1},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "second", idx: 0},
invoked: true,
dependsOn: nil,
},
},
},
},
}, {
name: "three plugins - one double dependency",
args: args{
p: []types.Plugin{
&dummyPlugin{name: "third", dep: []string{"first", "second"}},
&dummyPlugin{name: "first"},
&dummyPlugin{name: "second"},
},
},
want: map[string]*dependency{
"first": {
plugin: &dummyPlugin{name: "first", idx: 0},
invoked: true,
dependsOn: nil,
},
"second": {
plugin: &dummyPlugin{name: "second", idx: 0},
invoked: true,
dependsOn: nil,
},
"third": {
plugin: &dummyPlugin{name: "third", dep: []string{"first", "second"}, idx: 1},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "first", idx: 0},
invoked: true,
dependsOn: nil,
},
{
plugin: &dummyPlugin{name: "second", idx: 0},
invoked: true,
dependsOn: nil,
},
},
},
},
}, {
name: "three plugins - two dependencies",
args: args{
p: []types.Plugin{
&dummyPlugin{name: "third", dep: []string{"first"}},
&dummyPlugin{name: "first"},
&dummyPlugin{name: "second", dep: []string{"first"}},
},
},
want: map[string]*dependency{
"first": {
plugin: &dummyPlugin{name: "first"},
invoked: false,
dependsOn: nil,
},
"third": {
plugin: &dummyPlugin{name: "third", dep: []string{"first"}},
invoked: false,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "first"},
invoked: false,
dependsOn: nil,
},
},
},
"second": {
plugin: &dummyPlugin{name: "second", dep: []string{"first"}},
invoked: false,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "first"},
invoked: false,
dependsOn: nil,
},
},
},
},
skipRun: true,
}, {
name: "three plugins - three dependencies",
args: args{
p: []types.Plugin{
&dummyPlugin{name: "third", dep: []string{"second"}},
&dummyPlugin{name: "first"},
&dummyPlugin{name: "second", dep: []string{"first"}},
},
},
want: map[string]*dependency{
"first": {
plugin: &dummyPlugin{name: "first", idx: 0},
invoked: true,
dependsOn: nil,
},
"second": {
plugin: &dummyPlugin{name: "second", dep: []string{"first"}, idx: 1},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "first", idx: 0},
invoked: true,
dependsOn: nil,
},
},
},
"third": {
plugin: &dummyPlugin{name: "third", dep: []string{"second"}, idx: 2},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "second", dep: []string{"first"}, idx: 1},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "first", idx: 0},
invoked: true,
dependsOn: nil,
},
},
},
},
},
},
}, {
name: "four plugins - added in reverse order, with dependencies",
args: args{
p: []types.Plugin{
&dummyPlugin{name: "forth", dep: []string{"third"}},
&dummyPlugin{name: "third", dep: []string{"second"}},
&dummyPlugin{name: "second", dep: []string{"first"}},
&dummyPlugin{name: "first"},
},
},
want: map[string]*dependency{
"first": {
plugin: &dummyPlugin{name: "first", idx: 0},
invoked: false,
dependsOn: nil,
},
"second": {
plugin: &dummyPlugin{name: "second", dep: []string{"first"}, idx: 0},
invoked: false,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "first", idx: 0},
invoked: false,
dependsOn: nil,
},
},
},
"third": {
plugin: &dummyPlugin{name: "third", dep: []string{"second"}, idx: 0},
invoked: false,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "second", dep: []string{"first"}, idx: 0},
invoked: false,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "first", idx: 0},
invoked: false,
dependsOn: nil,
},
},
},
},
},
"forth": {
plugin: &dummyPlugin{name: "forth", dep: []string{"third"}, idx: 0},
invoked: false,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "third", dep: []string{"second"}, idx: 0},
invoked: false,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "second", dep: []string{"first"}, idx: 0},
invoked: false,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "first", idx: 0},
invoked: false,
dependsOn: nil,
},
},
},
},
},
},
},
},
skipRun: true,
}, {
name: "four plugins - two double dependencies",
args: args{
p: []types.Plugin{
&dummyPlugin{name: "forth", dep: []string{"first", "third"}},
&dummyPlugin{name: "third", dep: []string{"first", "second"}},
&dummyPlugin{name: "second"},
&dummyPlugin{name: "first"},
},
},
want: map[string]*dependency{
"first": {
plugin: &dummyPlugin{name: "first", idx: 0},
invoked: true,
dependsOn: nil,
},
"second": {
plugin: &dummyPlugin{name: "second", idx: 0},
invoked: true,
dependsOn: nil,
},
"third": {
plugin: &dummyPlugin{name: "third", dep: []string{"first", "second"}, idx: 1},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "first", idx: 0},
invoked: true,
dependsOn: nil,
},
{
plugin: &dummyPlugin{name: "second", idx: 0},
invoked: true,
dependsOn: nil,
},
},
},
"forth": {
plugin: &dummyPlugin{name: "forth", dep: []string{"first", "third"}, idx: 2},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "first", idx: 0},
invoked: true,
dependsOn: nil,
},
{
plugin: &dummyPlugin{name: "third", dep: []string{"first", "second"}, idx: 1},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "first", idx: 0},
invoked: true,
dependsOn: nil,
},
{
plugin: &dummyPlugin{name: "second", idx: 0},
invoked: true,
dependsOn: nil,
},
},
},
},
},
},
}, {
// So, when we have plugin A in the list and want to add plugin C we can't determine the proper order without
// resolving their direct dependiencies first:
//
// Can be C->D->A->B if D depends on A
//
// So to do it properly I would imagine tht we need to resolve all direct dependiencies first and build multiple lists:
//
// i.e. A->B->C D F->G
//
// and then join these lists in any order.
name: "add indirect dependency CDAB",
args: args{
p: []types.Plugin{
&dummyPlugin{name: "A", dep: []string{"B"}},
&dummyPlugin{name: "C", dep: []string{"D"}},
&dummyPlugin{name: "B"},
&dummyPlugin{name: "D", dep: []string{"A"}},
},
},
want: map[string]*dependency{
"B": {
plugin: &dummyPlugin{name: "B", idx: 0},
invoked: true,
dependsOn: nil,
},
"A": {
plugin: &dummyPlugin{name: "A", dep: []string{"B"}, idx: 1},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "B", idx: 0},
invoked: true,
dependsOn: nil,
},
},
},
"D": {
plugin: &dummyPlugin{name: "D", dep: []string{"A"}, idx: 2},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "A", dep: []string{"B"}, idx: 1},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "B", idx: 0},
invoked: true,
dependsOn: nil,
},
},
},
},
},
"C": {
plugin: &dummyPlugin{name: "C", dep: []string{"D"}, idx: 3},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "D", dep: []string{"A"}, idx: 2},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "A", dep: []string{"B"}, idx: 1},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "B", idx: 0},
invoked: true,
dependsOn: nil,
},
},
},
},
},
},
},
},
}, {
// So, when we have plugin A in the list and want to add plugin C we can't determine the proper order without
// resolving their direct dependiencies first:
//
// Can be A->B->C->D (in this test) if B depends on C
//
// So to do it properly I would imagine tht we need to resolve all direct dependiencies first and build multiple lists:
//
// i.e. A->B->C D F->G
//
// and then join these lists in any order.
name: "add indirect dependency ABCD",
args: args{
p: []types.Plugin{
&dummyPlugin{name: "C", dep: []string{"D"}},
&dummyPlugin{name: "D"},
&dummyPlugin{name: "B", dep: []string{"C"}},
&dummyPlugin{name: "A", dep: []string{"B"}},
},
},
want: map[string]*dependency{
"D": {
plugin: &dummyPlugin{name: "D", idx: 0},
invoked: true,
dependsOn: nil,
},
"C": {
plugin: &dummyPlugin{name: "C", dep: []string{"D"}, idx: 1},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "D", idx: 0},
invoked: true,
dependsOn: nil,
},
},
},
"B": {
plugin: &dummyPlugin{name: "B", dep: []string{"C"}, idx: 2},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "C", dep: []string{"D"}, idx: 1},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "D", idx: 0},
invoked: true,
dependsOn: nil,
},
},
},
},
},
"A": {
plugin: &dummyPlugin{name: "A", dep: []string{"B"}, idx: 3},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "B", dep: []string{"C"}, idx: 2},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "C", dep: []string{"D"}, idx: 1},
invoked: true,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "D", idx: 0},
invoked: true,
dependsOn: nil,
},
},
},
},
},
},
},
},
}, {
name: "add duplicate plugin",
args: args{
p: []types.Plugin{
&dummyPlugin{name: "first"},
&dummyPlugin{name: "first"},
},
},
want: map[string]*dependency{
"first": {plugin: &dummyPlugin{name: "first", idx: 0}, invoked: true},
},
wantErr1: true,
}, {
name: "cyclical dependency",
args: args{
p: []types.Plugin{
&dummyPlugin{name: "first", dep: []string{"second"}},
&dummyPlugin{name: "second", dep: []string{"first"}},
},
},
want: map[string]*dependency{
"first": {
plugin: &dummyPlugin{name: "first", dep: []string{"second"}, idx: 1},
invoked: true,
},
},
wantErr1: true,
}, {
name: "four plugins - cyclical transitive dependencies in reverse order",
args: args{
p: []types.Plugin{
&dummyPlugin{name: "forth", dep: []string{"third"}},
&dummyPlugin{name: "third", dep: []string{"second"}},
&dummyPlugin{name: "second", dep: []string{"first"}},
&dummyPlugin{name: "first", dep: []string{"forth"}},
},
},
want: map[string]*dependency{
"second": {
plugin: &dummyPlugin{name: "second", dep: []string{"first"}, idx: 0},
invoked: false,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "first", dep: []string{"forth"}, idx: 0},
invoked: false,
dependsOn: nil,
},
},
},
"third": {
plugin: &dummyPlugin{name: "third", dep: []string{"second"}, idx: 0},
invoked: false,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "second", dep: []string{"first"}, idx: 0},
invoked: false,
dependsOn: []*dependency{
{
plugin: &dummyPlugin{name: "first", dep: []string{"forth"}, idx: 0},
invoked: false,
dependsOn: nil,
},
},
},
},
},
"forth": {
plugin: &dummyPlugin{name: "forth", dep: []string{"third"}, idx: 0},
invoked: false,
dependsOn: nil,
},
},
wantErr1: true,
skipRun: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
d := &dependiencies{deps: make(map[string]*dependency)}
var (
err error
counter int
)
for _, p := range tt.args.p {
if !tt.skipRun {
p.(*dummyPlugin).counter = &counter
}
if err = d.addPlugin(p); err != nil {
break
}
}
if err != nil != tt.wantErr1 {
t.Errorf("dependiencies.addPlugin() error = %v, wantErr1 %v", err, tt.wantErr1)
return
}
if !tt.skipRun {
if err := d.start(types.PluginManagers{}); (err != nil) != tt.wantErr2 {
t.Errorf("dependiencies.start() error = %v, wantErr1 %v", err, tt.wantErr2)
}
}
if !reflect.DeepEqual(d.deps, tt.want) {
t.Errorf("deps = %v, want %v", d.deps, tt.want)
}
})
}
}
type dummyPlugin struct {
name string
dep []string
idx int
counter *int
}
func (d dummyPlugin) Name() string {
return d.name
}
func (d dummyPlugin) DependsOn() []string {
return d.dep
}
func (d dummyPlugin) Config() types.PluginConfig {
return nil
}
func (d *dummyPlugin) Start(types.PluginManagers) error {
if len(d.dep) > 0 {
*d.counter++
d.idx = *d.counter
}
d.counter = nil
return nil
}
func (d dummyPlugin) Shutdown() error {
return nil
}

View File

@ -0,0 +1,41 @@
package filetransfer
import (
"path/filepath"
"time"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
type Config struct {
Enabled bool
RootDir string
RefreshInterval time.Duration
}
func (Config) Init(cmd *cobra.Command) error {
cmd.PersistentFlags().Bool("filetransfer.enabled", false, "whether file transfer is enabled")
if err := viper.BindPFlag("filetransfer.enabled", cmd.PersistentFlags().Lookup("filetransfer.enabled")); err != nil {
return err
}
cmd.PersistentFlags().String("filetransfer.dir", "/home/neko/Downloads", "root directory for file transfer")
if err := viper.BindPFlag("filetransfer.dir", cmd.PersistentFlags().Lookup("filetransfer.dir")); err != nil {
return err
}
cmd.PersistentFlags().Duration("filetransfer.refresh_interval", 30*time.Second, "interval to refresh file list")
if err := viper.BindPFlag("filetransfer.refresh_interval", cmd.PersistentFlags().Lookup("filetransfer.refresh_interval")); err != nil {
return err
}
return nil
}
func (s *Config) Set() {
s.Enabled = viper.GetBool("filetransfer.enabled")
rootDir := viper.GetString("filetransfer.dir")
s.RootDir = filepath.Clean(rootDir)
s.RefreshInterval = viper.GetDuration("filetransfer.refresh_interval")
}

View File

@ -0,0 +1,332 @@
package filetransfer
import (
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"regexp"
"sync"
"time"
"github.com/demodesk/neko/pkg/auth"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
"github.com/fsnotify/fsnotify"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
const MULTIPART_FORM_MAX_MEMORY = 32 << 20
func NewManager(
sessions types.SessionManager,
config *Config,
) *Manager {
logger := log.With().Str("module", "filetransfer").Logger()
return &Manager{
logger: logger,
config: config,
sessions: sessions,
shutdown: make(chan struct{}),
}
}
type Manager struct {
logger zerolog.Logger
config *Config
sessions types.SessionManager
shutdown chan struct{}
mu sync.RWMutex
fileList []Item
}
func (m *Manager) isEnabledForSession(session types.Session) (bool, error) {
settings := Settings{
Enabled: true, // defaults to true
}
err := m.sessions.Settings().Plugins.Unmarshal(PluginName, &settings)
if err != nil && !errors.Is(err, types.ErrPluginSettingsNotFound) {
return false, fmt.Errorf("unable to unmarshal %s plugin settings from global settings: %w", PluginName, err)
}
profile := Settings{
Enabled: true, // defaults to true
}
err = session.Profile().Plugins.Unmarshal(PluginName, &profile)
if err != nil && !errors.Is(err, types.ErrPluginSettingsNotFound) {
return false, fmt.Errorf("unable to unmarshal %s plugin settings from profile: %w", PluginName, err)
}
return m.config.Enabled && (settings.Enabled || session.Profile().IsAdmin) && profile.Enabled, nil
}
func (m *Manager) refresh() (error, bool) {
// if file transfer is disabled, return immediately without refreshing
if !m.config.Enabled {
return nil, false
}
files, err := ListFiles(m.config.RootDir)
if err != nil {
return err, false
}
m.mu.Lock()
defer m.mu.Unlock()
// check if file list has changed (todo: use hash instead of comparing all fields)
changed := false
if len(files) == len(m.fileList) {
for i, file := range files {
if file.Name != m.fileList[i].Name || file.Size != m.fileList[i].Size {
changed = true
break
}
}
} else {
changed = true
}
m.fileList = files
return nil, changed
}
func (m *Manager) broadcastUpdate() {
m.mu.RLock()
fileList := m.fileList
m.mu.RUnlock()
m.sessions.Broadcast(FILETRANSFER_UPDATE, Message{
Enabled: m.config.Enabled,
RootDir: m.config.RootDir,
Files: fileList,
})
}
func (m *Manager) sendUpdate(session types.Session) {
m.mu.RLock()
fileList := m.fileList
m.mu.RUnlock()
session.Send(FILETRANSFER_UPDATE, Message{
Enabled: m.config.Enabled,
RootDir: m.config.RootDir,
Files: fileList,
})
}
func (m *Manager) Start() error {
// send init message once a user connects
m.sessions.OnConnected(func(session types.Session) {
m.sendUpdate(session)
})
// if file transfer is disabled, return immediately without starting the watcher
if !m.config.Enabled {
return nil
}
if _, err := os.Stat(m.config.RootDir); os.IsNotExist(err) {
err = os.Mkdir(m.config.RootDir, os.ModePerm)
m.logger.Err(err).Msg("creating file transfer directory")
}
watcher, err := fsnotify.NewWatcher()
if err != nil {
return fmt.Errorf("unable to start file transfer dir watcher: %w", err)
}
go func() {
defer watcher.Close()
// periodically refresh file list
ticker := time.NewTicker(m.config.RefreshInterval)
defer ticker.Stop()
for {
select {
case <-m.shutdown:
m.logger.Info().Msg("shutting down file transfer manager")
return
case <-ticker.C:
err, changed := m.refresh()
if err != nil {
m.logger.Err(err).Msg("unable to refresh file transfer list")
}
if changed {
m.broadcastUpdate()
}
case e, ok := <-watcher.Events:
if !ok {
m.logger.Info().Msg("file transfer dir watcher closed")
return
}
if e.Has(fsnotify.Create) || e.Has(fsnotify.Remove) || e.Has(fsnotify.Rename) {
m.logger.Debug().Str("event", e.String()).Msg("file transfer dir watcher event")
err, changed := m.refresh()
if err != nil {
m.logger.Err(err).Msg("unable to refresh file transfer list")
}
if changed {
m.broadcastUpdate()
}
}
case err := <-watcher.Errors:
m.logger.Err(err).Msg("error in file transfer dir watcher")
}
}
}()
if err := watcher.Add(m.config.RootDir); err != nil {
return fmt.Errorf("unable to watch file transfer dir: %w", err)
}
// initial refresh
err, changed := m.refresh()
if err != nil {
return fmt.Errorf("unable to refresh file transfer list: %w", err)
}
if changed {
m.broadcastUpdate()
}
return nil
}
func (m *Manager) Shutdown() error {
close(m.shutdown)
return nil
}
func (m *Manager) Route(r types.Router) {
r.With(auth.AdminsOnly).Get("/", m.downloadFileHandler)
r.With(auth.AdminsOnly).Post("/", m.uploadFileHandler)
}
func (m *Manager) WebSocketHandler(session types.Session, msg types.WebSocketMessage) bool {
switch msg.Event {
case FILETRANSFER_UPDATE:
err, changed := m.refresh()
if err != nil {
m.logger.Err(err).Msg("unable to refresh file transfer list")
}
if changed {
// broadcast update message to all clients
m.broadcastUpdate()
} else {
// send update message to this client only
m.sendUpdate(session)
}
return true
}
// not handled by this plugin
return false
}
func (m *Manager) downloadFileHandler(w http.ResponseWriter, r *http.Request) error {
session, ok := auth.GetSession(r)
if !ok {
return utils.HttpUnauthorized("session not found")
}
enabled, err := m.isEnabledForSession(session)
if err != nil {
return utils.HttpInternalServerError().
WithInternalErr(err).
Msg("error checking file transfer permissions")
}
if !enabled {
return utils.HttpForbidden("file transfer is disabled")
}
filename := r.URL.Query().Get("filename")
badChars, err := regexp.MatchString(`(?m)\.\.(?:\/|$)`, filename)
if filename == "" || badChars || err != nil {
return utils.HttpBadRequest().
WithInternalErr(err).
Msg("bad filename")
}
// ensure filename is clean and only contains the basename
filename = filepath.Clean(filename)
filename = filepath.Base(filename)
filePath := filepath.Join(m.config.RootDir, filename)
http.ServeFile(w, r, filePath)
return nil
}
func (m *Manager) uploadFileHandler(w http.ResponseWriter, r *http.Request) error {
session, ok := auth.GetSession(r)
if !ok {
return utils.HttpUnauthorized("session not found")
}
enabled, err := m.isEnabledForSession(session)
if err != nil {
return utils.HttpInternalServerError().
WithInternalErr(err).
Msg("error checking file transfer permissions")
}
if !enabled {
return utils.HttpForbidden("file transfer is disabled")
}
err = r.ParseMultipartForm(MULTIPART_FORM_MAX_MEMORY)
if err != nil || r.MultipartForm == nil {
return utils.HttpBadRequest().
WithInternalErr(err).
Msg("error parsing form")
}
defer func() {
err = r.MultipartForm.RemoveAll()
if err != nil {
m.logger.Warn().Err(err).Msg("failed to clean up multipart form")
}
}()
for _, formheader := range r.MultipartForm.File["files"] {
// ensure filename is clean and only contains the basename
filename := filepath.Clean(formheader.Filename)
filename = filepath.Base(filename)
filePath := filepath.Join(m.config.RootDir, filename)
formfile, err := formheader.Open()
if err != nil {
return utils.HttpBadRequest().
WithInternalErr(err).
Msg("error opening formdata file")
}
defer formfile.Close()
f, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
return utils.HttpInternalServerError().
WithInternalErr(err).
Msg("error opening file for writing")
}
defer f.Close()
_, err = io.Copy(f, formfile)
if err != nil {
return utils.HttpInternalServerError().
WithInternalErr(err).
Msg("error writing file")
}
}
return nil
}

View File

@ -0,0 +1,35 @@
package filetransfer
import (
"github.com/demodesk/neko/pkg/types"
)
type Plugin struct {
config *Config
manager *Manager
}
func NewPlugin() *Plugin {
return &Plugin{
config: &Config{},
}
}
func (p *Plugin) Name() string {
return PluginName
}
func (p *Plugin) Config() types.PluginConfig {
return p.config
}
func (p *Plugin) Start(m types.PluginManagers) error {
p.manager = NewManager(m.SessionManager, p.config)
m.ApiManager.AddRouter("/filetransfer", p.manager.Route)
m.WebSocketManager.AddHandler(p.manager.WebSocketHandler)
return p.manager.Start()
}
func (p *Plugin) Shutdown() error {
return p.manager.Shutdown()
}

View File

@ -0,0 +1,30 @@
package filetransfer
const PluginName = "filetransfer"
type Settings struct {
Enabled bool `json:"enabled" mapstructure:"enabled"`
}
const (
FILETRANSFER_UPDATE = "filetransfer/update"
)
type Message struct {
Enabled bool `json:"enabled"`
RootDir string `json:"root_dir"`
Files []Item `json:"files"`
}
type ItemType string
const (
ItemTypeFile ItemType = "file"
ItemTypeDir ItemType = "dir"
)
type Item struct {
Name string `json:"name"`
Type ItemType `json:"type"`
Size int64 `json:"size,omitempty"`
}

View File

@ -0,0 +1,32 @@
package filetransfer
import "os"
func ListFiles(path string) ([]Item, error) {
items, err := os.ReadDir(path)
if err != nil {
return nil, err
}
out := make([]Item, len(items))
for i, item := range items {
var itemType ItemType
var size int64 = 0
if item.IsDir() {
itemType = ItemTypeDir
} else {
itemType = ItemTypeFile
info, err := item.Info()
if err == nil {
size = info.Size()
}
}
out[i] = Item{
Name: item.Name(),
Type: itemType,
Size: size,
}
}
return out, nil
}

View File

@ -0,0 +1,183 @@
package plugins
import (
"fmt"
"os"
"path/filepath"
"plugin"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"github.com/demodesk/neko/internal/config"
"github.com/demodesk/neko/internal/plugins/chat"
"github.com/demodesk/neko/internal/plugins/filetransfer"
"github.com/demodesk/neko/pkg/types"
)
type ManagerCtx struct {
logger zerolog.Logger
config *config.Plugins
plugins dependiencies
}
func New(config *config.Plugins) *ManagerCtx {
manager := &ManagerCtx{
logger: log.With().Str("module", "plugins").Logger(),
config: config,
plugins: dependiencies{
deps: make(map[string]*dependency),
},
}
manager.plugins.logger = manager.logger
if config.Enabled {
err := manager.loadDir(config.Dir)
// only log error if plugin is not required
if err != nil && config.Required {
manager.logger.Fatal().Err(err).Msg("error loading plugins")
}
manager.logger.Info().Msgf("loading finished, total %d plugins", manager.plugins.len())
}
// add built-in plugins
manager.plugins.addPlugin(filetransfer.NewPlugin())
manager.plugins.addPlugin(chat.NewPlugin())
return manager
}
func (manager *ManagerCtx) loadDir(dir string) error {
return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() {
return nil
}
err = manager.load(path)
// return error if plugin is required
if err != nil && manager.config.Required {
return err
}
// otherwise only log error if plugin is not required
manager.logger.Err(err).Str("plugin", path).Msg("loading a plugin")
return nil
})
}
func (manager *ManagerCtx) load(path string) error {
pl, err := plugin.Open(path)
if err != nil {
return err
}
sym, err := pl.Lookup("Plugin")
if err != nil {
return err
}
p, ok := sym.(types.Plugin)
if !ok {
return fmt.Errorf("not a valid plugin")
}
if err = manager.plugins.addPlugin(p); err != nil {
return fmt.Errorf("failed to add plugin: %w", err)
}
return nil
}
func (manager *ManagerCtx) InitConfigs(cmd *cobra.Command) {
_ = manager.plugins.forEach(func(d *dependency) error {
if err := d.plugin.Config().Init(cmd); err != nil {
log.Err(err).Str("plugin", d.plugin.Name()).Msg("unable to initialize configuration")
}
return nil
})
}
func (manager *ManagerCtx) SetConfigs() {
_ = manager.plugins.forEach(func(d *dependency) error {
d.plugin.Config().Set()
return nil
})
}
func (manager *ManagerCtx) Start(
sessionManager types.SessionManager,
webSocketManager types.WebSocketManager,
apiManager types.ApiManager,
) {
err := manager.plugins.start(types.PluginManagers{
SessionManager: sessionManager,
WebSocketManager: webSocketManager,
ApiManager: apiManager,
LoadServiceFromPlugin: manager.LookupService,
})
if err != nil {
if manager.config.Required {
manager.logger.Fatal().Err(err).Msg("failed to start plugins, exiting...")
} else {
manager.logger.Err(err).Msg("failed to start plugins, skipping...")
}
}
}
func (manager *ManagerCtx) Shutdown() error {
_ = manager.plugins.forEach(func(d *dependency) error {
err := d.plugin.Shutdown()
manager.logger.Err(err).Str("plugin", d.plugin.Name()).Msg("plugin shutdown")
return nil
})
return nil
}
func (manager *ManagerCtx) LookupService(pluginName string) (any, error) {
plug, ok := manager.plugins.findPlugin(pluginName)
if !ok {
return nil, fmt.Errorf("plugin '%s' not found", pluginName)
}
expPlug, ok := plug.plugin.(types.ExposablePlugin)
if !ok {
return nil, fmt.Errorf("plugin '%s' is not exposable", pluginName)
}
return expPlug.ExposeService(), nil
}
func (manager *ManagerCtx) Metadata() []types.PluginMetadata {
var plugins []types.PluginMetadata
_ = manager.plugins.forEach(func(d *dependency) error {
dependsOn := make([]string, 0)
deps, isDependalbe := d.plugin.(types.DependablePlugin)
if isDependalbe {
dependsOn = deps.DependsOn()
}
_, isExposable := d.plugin.(types.ExposablePlugin)
plugins = append(plugins, types.PluginMetadata{
Name: d.plugin.Name(),
IsDependable: isDependalbe,
IsExposable: isExposable,
DependsOn: dependsOn,
})
return nil
})
return plugins
}

View File

@ -0,0 +1,80 @@
package session
import (
"errors"
"net/http"
"strings"
"time"
"github.com/demodesk/neko/pkg/types"
)
func (manager *SessionManagerCtx) CookieSetToken(w http.ResponseWriter, token string) {
sameSite := http.SameSiteDefaultMode
if manager.config.CookieSecure {
sameSite = http.SameSiteNoneMode
}
http.SetCookie(w, &http.Cookie{
Name: manager.config.CookieName,
Value: token,
Expires: time.Now().Add(manager.config.CookieExpiration),
Secure: manager.config.CookieSecure,
SameSite: sameSite,
HttpOnly: true,
})
}
func (manager *SessionManagerCtx) CookieClearToken(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(manager.config.CookieName)
if err != nil {
return
}
cookie.Value = ""
cookie.Expires = time.Unix(0, 0)
http.SetCookie(w, cookie)
}
func (manager *SessionManagerCtx) Authenticate(r *http.Request) (types.Session, error) {
token, ok := manager.getToken(r)
if !ok {
return nil, errors.New("no authentication provided")
}
session, ok := manager.GetByToken(token)
if !ok {
return nil, types.ErrSessionNotFound
}
if !session.Profile().CanLogin {
return nil, types.ErrSessionLoginDisabled
}
return session, nil
}
func (manager *SessionManagerCtx) getToken(r *http.Request) (string, bool) {
if manager.CookieEnabled() {
// get from Cookie
cookie, err := r.Cookie(manager.config.CookieName)
if err == nil {
return cookie.Value, true
}
}
// get from Header
reqToken := r.Header.Get("Authorization")
splitToken := strings.Split(reqToken, "Bearer ")
if len(splitToken) == 2 {
return strings.TrimSpace(splitToken[1]), true
}
// get from URL
token := r.URL.Query().Get("token")
if token != "" {
return token, true
}
return "", false
}

View File

@ -0,0 +1,470 @@
package session
import (
"errors"
"sync"
"sync/atomic"
"github.com/kataras/go-events"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/internal/config"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
func New(config *config.Session) *SessionManagerCtx {
manager := &SessionManagerCtx{
logger: log.With().Str("module", "session").Logger(),
config: config,
settings: types.Settings{
PrivateMode: config.PrivateMode,
LockedLogins: config.LockedLogins,
LockedControls: config.LockedControls || config.ControlProtection,
ControlProtection: config.ControlProtection,
ImplicitHosting: config.ImplicitHosting,
InactiveCursors: config.InactiveCursors,
MercifulReconnect: config.MercifulReconnect,
},
tokens: make(map[string]string),
sessions: make(map[string]*SessionCtx),
cursors: make(map[types.Session][]types.Cursor),
emmiter: events.New(),
}
// create API session
if config.APIToken != "" {
manager.apiSession = &SessionCtx{
id: "API",
token: config.APIToken,
manager: manager,
logger: manager.logger.With().Str("session_id", "API").Logger(),
profile: types.MemberProfile{
Name: "API Session",
IsAdmin: true,
CanLogin: true,
CanConnect: false,
CanWatch: true,
CanHost: true,
CanAccessClipboard: true,
},
}
}
// try to load sessions from file
manager.load()
return manager
}
type SessionManagerCtx struct {
logger zerolog.Logger
config *config.Session
settings types.Settings
settingsMu sync.Mutex
tokens map[string]string
sessions map[string]*SessionCtx
sessionsMu sync.Mutex
hostId atomic.Value
cursors map[types.Session][]types.Cursor
cursorsMu sync.Mutex
emmiter events.EventEmmiter
apiSession *SessionCtx
}
func (manager *SessionManagerCtx) Create(id string, profile types.MemberProfile) (types.Session, string, error) {
token, err := utils.NewUID(64)
if err != nil {
return nil, "", err
}
manager.sessionsMu.Lock()
if _, ok := manager.sessions[id]; ok {
manager.sessionsMu.Unlock()
return nil, "", types.ErrSessionAlreadyExists
}
if _, ok := manager.tokens[token]; ok {
manager.sessionsMu.Unlock()
return nil, "", errors.New("session token already exists")
}
session := &SessionCtx{
id: id,
token: token,
manager: manager,
logger: manager.logger.With().Str("session_id", id).Logger(),
profile: profile,
}
manager.tokens[token] = id
manager.sessions[id] = session
manager.sessionsMu.Unlock()
manager.emmiter.Emit("created", session)
manager.save()
return session, token, nil
}
func (manager *SessionManagerCtx) Update(id string, profile types.MemberProfile) error {
manager.sessionsMu.Lock()
session, ok := manager.sessions[id]
if !ok {
manager.sessionsMu.Unlock()
return types.ErrSessionNotFound
}
old := session.profile
session.profile = profile
manager.sessionsMu.Unlock()
manager.emmiter.Emit("profile_changed", session, profile, old)
manager.save()
session.profileChanged()
return nil
}
func (manager *SessionManagerCtx) Delete(id string) error {
manager.sessionsMu.Lock()
session, ok := manager.sessions[id]
if !ok {
manager.sessionsMu.Unlock()
return types.ErrSessionNotFound
}
delete(manager.tokens, session.token)
delete(manager.sessions, id)
manager.sessionsMu.Unlock()
if session.State().IsConnected {
session.DestroyWebSocketPeer("session deleted")
}
if session.State().IsWatching {
session.GetWebRTCPeer().Destroy()
}
manager.emmiter.Emit("deleted", session)
manager.save()
return nil
}
func (manager *SessionManagerCtx) Disconnect(id string) error {
manager.sessionsMu.Lock()
session, ok := manager.sessions[id]
if !ok {
manager.sessionsMu.Unlock()
return types.ErrSessionNotFound
}
manager.sessionsMu.Unlock()
if session.State().IsConnected {
session.DestroyWebSocketPeer("session disconnected")
}
if session.State().IsWatching {
session.GetWebRTCPeer().Destroy()
}
return nil
}
func (manager *SessionManagerCtx) Get(id string) (types.Session, bool) {
manager.sessionsMu.Lock()
defer manager.sessionsMu.Unlock()
session, ok := manager.sessions[id]
return session, ok
}
func (manager *SessionManagerCtx) GetByToken(token string) (types.Session, bool) {
manager.sessionsMu.Lock()
id, ok := manager.tokens[token]
manager.sessionsMu.Unlock()
if ok {
return manager.Get(id)
}
// is API session
if manager.apiSession != nil && manager.apiSession.token == token {
return manager.apiSession, true
}
return nil, false
}
func (manager *SessionManagerCtx) List() []types.Session {
manager.sessionsMu.Lock()
defer manager.sessionsMu.Unlock()
var sessions []types.Session
for _, session := range manager.sessions {
sessions = append(sessions, session)
}
return sessions
}
func (manager *SessionManagerCtx) Range(f func(session types.Session) bool) {
manager.sessionsMu.Lock()
defer manager.sessionsMu.Unlock()
for _, session := range manager.sessions {
if !f(session) {
return
}
}
}
// ---
// host
// ---
func (manager *SessionManagerCtx) setHost(session, host types.Session) {
var hostId string
if host != nil {
hostId = host.ID()
}
manager.hostId.Store(hostId)
manager.emmiter.Emit("host_changed", session, host)
}
func (manager *SessionManagerCtx) GetHost() (types.Session, bool) {
hostId, ok := manager.hostId.Load().(string)
if !ok || hostId == "" {
return nil, false
}
return manager.Get(hostId)
}
func (manager *SessionManagerCtx) isHost(host types.Session) bool {
hostId, ok := manager.hostId.Load().(string)
return ok && hostId == host.ID()
}
// ---
// cursors
// ---
func (manager *SessionManagerCtx) SetCursor(cursor types.Cursor, session types.Session) {
manager.cursorsMu.Lock()
defer manager.cursorsMu.Unlock()
list, ok := manager.cursors[session]
if !ok {
list = []types.Cursor{}
}
list = append(list, cursor)
manager.cursors[session] = list
}
func (manager *SessionManagerCtx) PopCursors() map[types.Session][]types.Cursor {
manager.cursorsMu.Lock()
defer manager.cursorsMu.Unlock()
cursors := manager.cursors
manager.cursors = make(map[types.Session][]types.Cursor)
return cursors
}
// ---
// broadcasts
// ---
func (manager *SessionManagerCtx) Broadcast(event string, payload any, exclude ...string) {
for _, session := range manager.List() {
if !session.State().IsConnected {
continue
}
if len(exclude) > 0 {
if in, _ := utils.ArrayIn(session.ID(), exclude); in {
continue
}
}
session.Send(event, payload)
}
}
func (manager *SessionManagerCtx) AdminBroadcast(event string, payload any, exclude ...string) {
for _, session := range manager.List() {
if !session.State().IsConnected || !session.Profile().IsAdmin {
continue
}
if len(exclude) > 0 {
if in, _ := utils.ArrayIn(session.ID(), exclude); in {
continue
}
}
session.Send(event, payload)
}
}
func (manager *SessionManagerCtx) InactiveCursorsBroadcast(event string, payload any, exclude ...string) {
for _, session := range manager.List() {
if !session.State().IsConnected || !session.Profile().CanSeeInactiveCursors {
continue
}
if len(exclude) > 0 {
if in, _ := utils.ArrayIn(session.ID(), exclude); in {
continue
}
}
session.Send(event, payload)
}
}
// ---
// events
// ---
func (manager *SessionManagerCtx) OnCreated(listener func(session types.Session)) {
manager.emmiter.On("created", func(payload ...any) {
listener(payload[0].(*SessionCtx))
})
}
func (manager *SessionManagerCtx) OnDeleted(listener func(session types.Session)) {
manager.emmiter.On("deleted", func(payload ...any) {
listener(payload[0].(*SessionCtx))
})
}
func (manager *SessionManagerCtx) OnConnected(listener func(session types.Session)) {
manager.emmiter.On("connected", func(payload ...any) {
listener(payload[0].(*SessionCtx))
})
}
func (manager *SessionManagerCtx) OnDisconnected(listener func(session types.Session)) {
manager.emmiter.On("disconnected", func(payload ...any) {
listener(payload[0].(*SessionCtx))
})
}
func (manager *SessionManagerCtx) OnProfileChanged(listener func(session types.Session, new, old types.MemberProfile)) {
manager.emmiter.On("profile_changed", func(payload ...any) {
listener(payload[0].(*SessionCtx), payload[1].(types.MemberProfile), payload[2].(types.MemberProfile))
})
}
func (manager *SessionManagerCtx) OnStateChanged(listener func(session types.Session)) {
manager.emmiter.On("state_changed", func(payload ...any) {
listener(payload[0].(*SessionCtx))
})
}
func (manager *SessionManagerCtx) OnHostChanged(listener func(session, host types.Session)) {
manager.emmiter.On("host_changed", func(payload ...any) {
if payload[1] == nil {
listener(payload[0].(*SessionCtx), nil)
} else {
listener(payload[0].(*SessionCtx), payload[1].(*SessionCtx))
}
})
}
func (manager *SessionManagerCtx) OnSettingsChanged(listener func(session types.Session, new, old types.Settings)) {
manager.emmiter.On("settings_changed", func(payload ...any) {
listener(payload[0].(types.Session), payload[1].(types.Settings), payload[2].(types.Settings))
})
}
// ---
// settings
// ---
func (manager *SessionManagerCtx) UpdateSettingsFunc(session types.Session, f func(settings *types.Settings) bool) {
manager.settingsMu.Lock()
new := manager.settings
if f(&new) {
old := manager.settings
manager.settings = new
manager.settingsMu.Unlock()
manager.updateSettings(session, new, old)
return
}
manager.settingsMu.Unlock()
}
func (manager *SessionManagerCtx) updateSettings(session types.Session, new, old types.Settings) {
// if private mode changed
if old.PrivateMode != new.PrivateMode {
// update webrtc paused state for all sessions
for _, s := range manager.List() {
enabled := s.PrivateModeEnabled()
// if session had control, it must release it
if enabled && s.IsHost() {
session.ClearHost()
}
// its webrtc connection will be paused or unpaused
if webrtcPeer := s.GetWebRTCPeer(); webrtcPeer != nil {
webrtcPeer.SetPaused(enabled)
}
}
}
// if control protection changed and controls are not locked
if old.ControlProtection != new.ControlProtection && new.ControlProtection && !new.LockedControls {
// if there is no admin, lock controls
hasAdmin := false
manager.Range(func(session types.Session) bool {
if session.Profile().IsAdmin && session.State().IsConnected {
hasAdmin = true
return false
}
return true
})
if !hasAdmin {
manager.settingsMu.Lock()
manager.settings.LockedControls = true
new.LockedControls = true
manager.settingsMu.Unlock()
}
}
// if contols have been locked
if old.LockedControls != new.LockedControls && new.LockedControls {
// if the host is not admin, it must release controls
host, hasHost := manager.GetHost()
if hasHost && !host.Profile().IsAdmin {
session.ClearHost()
}
}
manager.emmiter.Emit("settings_changed", session, new, old)
}
func (manager *SessionManagerCtx) Settings() types.Settings {
manager.settingsMu.Lock()
defer manager.settingsMu.Unlock()
return manager.settings
}
func (manager *SessionManagerCtx) CookieEnabled() bool {
return manager.config.CookieEnabled
}

View File

@ -0,0 +1,97 @@
package session
import (
"encoding/json"
"errors"
"os"
"github.com/demodesk/neko/pkg/types"
)
func (manager *SessionManagerCtx) save() {
if manager.config.File == "" {
return
}
// serialize sessions
sessions := make([]types.SessionProfile, 0, len(manager.sessions))
for _, session := range manager.sessions {
sessions = append(sessions, types.SessionProfile{
Id: session.id,
Token: session.token,
Profile: session.profile,
})
}
// convert to json
data, err := json.Marshal(sessions)
if err != nil {
manager.logger.Error().Err(err).Msg("failed to marshal sessions")
return
}
// write to file
err = os.WriteFile(manager.config.File, data, 0644)
if err != nil {
manager.logger.Error().Err(err).
Str("file", manager.config.File).
Msg("failed to write sessions to a file")
}
}
func (manager *SessionManagerCtx) load() {
if manager.config.File == "" {
return
}
// read file
data, err := os.ReadFile(manager.config.File)
if err != nil {
// if file does not exist
if errors.Is(err, os.ErrNotExist) {
manager.logger.Info().
Str("file", manager.config.File).
Msg("sessions file does not exist")
return
}
manager.logger.Error().Err(err).
Str("file", manager.config.File).
Msg("failed to read sessions from a file")
return
}
// if file is empty
if len(data) == 0 {
manager.logger.Info().
Str("file", manager.config.File).
Msg("sessions file is empty")
return
}
// deserialize sessions
sessions := make([]types.SessionProfile, 0)
err = json.Unmarshal(data, &sessions)
if err != nil {
manager.logger.Error().Err(err).Msg("failed to unmarshal sessions")
return
}
// create sessions
manager.sessionsMu.Lock()
for _, session := range sessions {
manager.tokens[session.Token] = session.Id
manager.sessions[session.Id] = &SessionCtx{
id: session.Id,
token: session.Token,
manager: manager,
logger: manager.logger.With().Str("session_id", session.Id).Logger(),
profile: session.Profile,
}
}
manager.sessionsMu.Unlock()
manager.logger.Info().
Int("sessions", len(sessions)).
Str("file", manager.config.File).
Msg("loaded sessions from a file")
}

View File

@ -0,0 +1,283 @@
package session
import (
"sync"
"time"
"github.com/rs/zerolog"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/event"
)
// client is expected to reconnect within 5 second
// if some unexpected websocket disconnect happens
const WS_DELAYED_DURATION = 5 * time.Second
type SessionCtx struct {
id string
token string
logger zerolog.Logger
manager *SessionManagerCtx
profile types.MemberProfile
state types.SessionState
websocketPeer types.WebSocketPeer
websocketMu sync.Mutex
// websocket delayed set connected events
wsDelayedMu sync.Mutex
wsDelayedTimer *time.Timer
webrtcPeer types.WebRTCPeer
webrtcMu sync.Mutex
}
func (session *SessionCtx) ID() string {
return session.id
}
func (session *SessionCtx) Profile() types.MemberProfile {
return session.profile
}
func (session *SessionCtx) profileChanged() {
if !session.profile.CanHost && session.IsHost() {
session.ClearHost()
}
if (!session.profile.CanConnect || !session.profile.CanLogin || !session.profile.CanWatch) && session.state.IsWatching {
session.GetWebRTCPeer().Destroy()
}
if (!session.profile.CanConnect || !session.profile.CanLogin) && session.state.IsConnected {
session.DestroyWebSocketPeer("profile changed")
}
// update webrtc paused state
if webrtcPeer := session.GetWebRTCPeer(); webrtcPeer != nil {
webrtcPeer.SetPaused(session.PrivateModeEnabled())
}
}
func (session *SessionCtx) State() types.SessionState {
return session.state
}
func (session *SessionCtx) IsHost() bool {
return session.manager.isHost(session)
}
func (session *SessionCtx) SetAsHost() {
session.manager.setHost(session, session)
}
func (session *SessionCtx) SetAsHostBy(host types.Session) {
session.manager.setHost(session, host)
}
func (session *SessionCtx) ClearHost() {
session.manager.setHost(session, nil)
}
func (session *SessionCtx) PrivateModeEnabled() bool {
return session.manager.Settings().PrivateMode && !session.profile.IsAdmin
}
func (session *SessionCtx) SetCursor(cursor types.Cursor) {
if session.manager.Settings().InactiveCursors && session.profile.SendsInactiveCursor {
session.manager.SetCursor(cursor, session)
}
}
// ---
// websocket
// ---
// Connect WebSocket peer sets current peer and emits connected event. It also destroys the
// previous peer, if there was one. If the peer is already set, it will be ignored.
func (session *SessionCtx) ConnectWebSocketPeer(websocketPeer types.WebSocketPeer) {
session.websocketMu.Lock()
isCurrentPeer := websocketPeer == session.websocketPeer
session.websocketPeer, websocketPeer = websocketPeer, session.websocketPeer
session.websocketMu.Unlock()
// ignore if already set
if isCurrentPeer {
return
}
session.logger.Info().Msg("set websocket connected")
// update state
now := time.Now()
session.state.IsConnected = true
session.state.ConnectedSince = &now
session.state.NotConnectedSince = nil
session.manager.emmiter.Emit("connected", session)
// if there is a previous peer, destroy it
if websocketPeer != nil {
websocketPeer.Destroy("connection replaced")
}
}
// Disconnect WebSocket peer sets current peer to nil and emits disconnected event. It also
// allows for a delayed disconnect. That means, the peer will not be disconnected immediately,
// but after a delay. If the peer is connected again before the delay, the disconnect will be
// cancelled.
//
// If the peer is not the current peer or the peer is nil, it will be ignored.
func (session *SessionCtx) DisconnectWebSocketPeer(websocketPeer types.WebSocketPeer, delayed bool) {
session.websocketMu.Lock()
isCurrentPeer := websocketPeer == session.websocketPeer && websocketPeer != nil
session.websocketMu.Unlock()
// ignore if not current peer
if !isCurrentPeer {
return
}
//
// ws delayed
//
var wsDelayedTimer *time.Timer
if delayed {
wsDelayedTimer = time.AfterFunc(WS_DELAYED_DURATION, func() {
session.DisconnectWebSocketPeer(websocketPeer, false)
})
}
session.wsDelayedMu.Lock()
if session.wsDelayedTimer != nil {
session.wsDelayedTimer.Stop()
}
session.wsDelayedTimer = wsDelayedTimer
session.wsDelayedMu.Unlock()
if delayed {
session.logger.Info().Msg("delayed websocket disconnected")
return
}
//
// not delayed
//
session.logger.Info().Msg("set websocket disconnected")
now := time.Now()
session.state.IsConnected = false
session.state.ConnectedSince = nil
session.state.NotConnectedSince = &now
session.manager.emmiter.Emit("disconnected", session)
session.websocketMu.Lock()
if websocketPeer == session.websocketPeer {
session.websocketPeer = nil
}
session.websocketMu.Unlock()
}
// Destroy WebSocket peer disconnects the peer and destroys it. It ensures that the peer is
// disconnected immediately even though normal flow would be to disconnect it delayed.
func (session *SessionCtx) DestroyWebSocketPeer(reason string) {
session.websocketMu.Lock()
peer := session.websocketPeer
session.websocketMu.Unlock()
if peer == nil {
return
}
// disconnect peer first, so that it is not used anymore
session.DisconnectWebSocketPeer(peer, false)
// destroy it afterwards
peer.Destroy(reason)
}
// Send event to websocket peer.
func (session *SessionCtx) Send(event string, payload any) {
session.websocketMu.Lock()
peer := session.websocketPeer
session.websocketMu.Unlock()
if peer != nil {
peer.Send(event, payload)
}
}
// ---
// webrtc
// ---
// Set webrtc peer and destroy the old one, if there is old one.
func (session *SessionCtx) SetWebRTCPeer(webrtcPeer types.WebRTCPeer) {
session.webrtcMu.Lock()
session.webrtcPeer, webrtcPeer = webrtcPeer, session.webrtcPeer
session.webrtcMu.Unlock()
if webrtcPeer != nil && webrtcPeer != session.webrtcPeer {
webrtcPeer.Destroy()
}
}
// Set if current webrtc peer is connected or not. Since there might be lefover calls from
// webrtc peer, that are not used anymore, we need to check if the webrtc peer is still the
// same as the one we are setting the connected state for.
//
// If webrtc peer is disconnected, we don't expect it to be reconnected, so we set it to nil
// and send a signal close to the client. New connection is expected to use a new webrtc peer.
func (session *SessionCtx) SetWebRTCConnected(webrtcPeer types.WebRTCPeer, connected bool) {
session.webrtcMu.Lock()
isCurrentPeer := webrtcPeer == session.webrtcPeer
session.webrtcMu.Unlock()
if !isCurrentPeer {
return
}
session.logger.Info().
Bool("connected", connected).
Msg("set webrtc connected")
// update state
session.state.IsWatching = connected
if now := time.Now(); connected {
session.state.WatchingSince = &now
session.state.NotWatchingSince = nil
} else {
session.state.WatchingSince = nil
session.state.NotWatchingSince = &now
}
session.manager.emmiter.Emit("state_changed", session)
if connected {
return
}
session.webrtcMu.Lock()
isCurrentPeer = webrtcPeer == session.webrtcPeer
if isCurrentPeer {
session.webrtcPeer = nil
}
session.webrtcMu.Unlock()
if isCurrentPeer {
session.Send(event.SIGNAL_CLOSE, nil)
}
}
// Get current WebRTC peer. Nil if not connected.
func (session *SessionCtx) GetWebRTCPeer() types.WebRTCPeer {
session.webrtcMu.Lock()
defer session.webrtcMu.Unlock()
return session.webrtcPeer
}

View File

@ -0,0 +1,168 @@
package cursor
import (
"reflect"
"sync"
"github.com/rs/zerolog"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/utils"
)
type ImageListener interface {
SendCursorImage(cur *types.CursorImage, img []byte) error
}
type Image interface {
Start()
Shutdown()
GetCurrent() (cur *types.CursorImage, img []byte, err error)
AddListener(listener ImageListener)
RemoveListener(listener ImageListener)
}
type imageEntry struct {
*types.CursorImage
ImagePNG []byte
}
type image struct {
logger zerolog.Logger
desktop types.DesktopManager
listeners map[uintptr]ImageListener
listenersMu sync.RWMutex
cache map[uint64]*imageEntry
cacheMu sync.RWMutex
current *imageEntry
maxSerial uint64
}
func NewImage(logger zerolog.Logger, desktop types.DesktopManager) *image {
return &image{
logger: logger.With().Str("submodule", "cursor-image").Logger(),
desktop: desktop,
listeners: map[uintptr]ImageListener{},
cache: map[uint64]*imageEntry{},
maxSerial: 300, // TODO: Cleanup?
}
}
func (manager *image) Start() {
manager.desktop.OnCursorChanged(func(serial uint64) {
entry, err := manager.getCached(serial)
if err != nil {
manager.logger.Err(err).Msg("failed to get cursor image")
return
}
manager.current = entry
manager.listenersMu.RLock()
for _, l := range manager.listeners {
if err := l.SendCursorImage(entry.CursorImage, entry.ImagePNG); err != nil {
manager.logger.Err(err).Msg("failed to set cursor image")
}
}
manager.listenersMu.RUnlock()
})
manager.logger.Info().Msg("starting")
}
func (manager *image) Shutdown() {
manager.logger.Info().Msg("shutdown")
manager.listenersMu.Lock()
for key := range manager.listeners {
delete(manager.listeners, key)
}
manager.listenersMu.Unlock()
}
func (manager *image) getCached(serial uint64) (*imageEntry, error) {
// zero means no serial available
if serial == 0 || serial > manager.maxSerial {
manager.logger.Debug().Uint64("serial", serial).Msg("cache bypass")
return manager.fetchEntry()
}
manager.cacheMu.RLock()
entry, ok := manager.cache[serial]
manager.cacheMu.RUnlock()
if ok {
return entry, nil
}
manager.logger.Debug().Uint64("serial", serial).Msg("cache miss")
entry, err := manager.fetchEntry()
if err != nil {
return nil, err
}
manager.cacheMu.Lock()
manager.cache[entry.Serial] = entry
manager.cacheMu.Unlock()
if entry.Serial != serial {
manager.logger.Warn().
Uint64("expected-serial", serial).
Uint64("received-serial", entry.Serial).
Msg("serial mismatch")
}
return entry, nil
}
func (manager *image) GetCurrent() (cur *types.CursorImage, img []byte, err error) {
if manager.current != nil {
return manager.current.CursorImage, manager.current.ImagePNG, nil
}
entry, err := manager.fetchEntry()
if err != nil {
return nil, nil, err
}
manager.current = entry
return entry.CursorImage, entry.ImagePNG, nil
}
func (manager *image) AddListener(listener ImageListener) {
manager.listenersMu.Lock()
defer manager.listenersMu.Unlock()
if listener != nil {
ptr := reflect.ValueOf(listener).Pointer()
manager.listeners[ptr] = listener
}
}
func (manager *image) RemoveListener(listener ImageListener) {
manager.listenersMu.Lock()
defer manager.listenersMu.Unlock()
if listener != nil {
ptr := reflect.ValueOf(listener).Pointer()
delete(manager.listeners, ptr)
}
}
func (manager *image) fetchEntry() (*imageEntry, error) {
cur := manager.desktop.GetCursorImage()
img, err := utils.CreatePNGImage(cur.Image)
if err != nil {
return nil, err
}
cur.Image = nil // free memory
return &imageEntry{
CursorImage: cur,
ImagePNG: img,
}, nil
}

View File

@ -0,0 +1,74 @@
package cursor
import (
"reflect"
"sync"
"github.com/rs/zerolog"
)
type PositionListener interface {
SendCursorPosition(x, y int) error
}
type Position interface {
Shutdown()
Set(x, y int)
AddListener(listener PositionListener)
RemoveListener(listener PositionListener)
}
type position struct {
logger zerolog.Logger
listeners map[uintptr]PositionListener
listenersMu sync.RWMutex
}
func NewPosition(logger zerolog.Logger) *position {
return &position{
logger: logger.With().Str("submodule", "cursor-position").Logger(),
listeners: map[uintptr]PositionListener{},
}
}
func (manager *position) Shutdown() {
manager.logger.Info().Msg("shutdown")
manager.listenersMu.Lock()
for key := range manager.listeners {
delete(manager.listeners, key)
}
manager.listenersMu.Unlock()
}
func (manager *position) Set(x, y int) {
manager.listenersMu.RLock()
defer manager.listenersMu.RUnlock()
for _, l := range manager.listeners {
if err := l.SendCursorPosition(x, y); err != nil {
manager.logger.Err(err).Msg("failed to set cursor position")
}
}
}
func (manager *position) AddListener(listener PositionListener) {
manager.listenersMu.Lock()
defer manager.listenersMu.Unlock()
if listener != nil {
ptr := reflect.ValueOf(listener).Pointer()
manager.listeners[ptr] = listener
}
}
func (manager *position) RemoveListener(listener PositionListener) {
manager.listenersMu.Lock()
defer manager.listenersMu.Unlock()
if listener != nil {
ptr := reflect.ValueOf(listener).Pointer()
delete(manager.listeners, ptr)
}
}

View File

@ -0,0 +1,205 @@
package webrtc
import (
"bytes"
"encoding/binary"
"math"
"time"
"github.com/demodesk/neko/internal/webrtc/payload"
"github.com/demodesk/neko/pkg/types"
"github.com/pion/webrtc/v3"
"github.com/rs/zerolog"
)
func (manager *WebRTCManagerCtx) handle(
logger zerolog.Logger, data []byte,
dataChannel *webrtc.DataChannel,
session types.Session,
) error {
isHost := session.IsHost()
//
// parse header
//
buffer := bytes.NewBuffer(data)
header := &payload.Header{}
if err := binary.Read(buffer, binary.BigEndian, header); err != nil {
return err
}
//
// parse body
//
// handle cursor move event
if header.Event == payload.OP_MOVE {
payload := &payload.Move{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
x, y := int(payload.X), int(payload.Y)
if isHost {
// handle active cursor movement
manager.desktop.Move(x, y)
manager.curPosition.Set(x, y)
} else {
// handle inactive cursor movement
session.SetCursor(types.Cursor{
X: x,
Y: y,
})
}
return nil
} else if header.Event == payload.OP_PING {
ping := &payload.Ping{}
if err := binary.Read(buffer, binary.BigEndian, ping); err != nil {
return err
}
// create pong header
header := payload.Header{
Event: payload.OP_PONG,
Length: 19,
}
// generate server timestamp
serverTs := uint64(time.Now().UnixMilli())
// generate pong payload
pong := payload.Pong{
Ping: *ping,
ServerTs1: uint32(serverTs / math.MaxUint32),
ServerTs2: uint32(serverTs % math.MaxUint32),
}
buffer := &bytes.Buffer{}
if err := binary.Write(buffer, binary.BigEndian, header); err != nil {
return err
}
if err := binary.Write(buffer, binary.BigEndian, pong); err != nil {
return err
}
return dataChannel.Send(buffer.Bytes())
}
// continue only if session is host
if !isHost {
return nil
}
switch header.Event {
case payload.OP_SCROLL:
// TODO: remove this once the client is fixed
if header.Length == 4 {
payload := &payload.Scroll_Old{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
manager.desktop.Scroll(int(payload.X), int(payload.Y), false)
logger.Trace().
Int16("x", payload.X).
Int16("y", payload.Y).
Msg("scroll")
} else {
payload := &payload.Scroll{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
manager.desktop.Scroll(int(payload.DeltaX), int(payload.DeltaY), payload.ControlKey)
logger.Trace().
Int16("deltaX", payload.DeltaX).
Int16("deltaY", payload.DeltaY).
Bool("controlKey", payload.ControlKey).
Msg("scroll")
}
case payload.OP_KEY_DOWN:
payload := &payload.Key{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
if err := manager.desktop.KeyDown(payload.Key); err != nil {
logger.Warn().Err(err).Uint32("key", payload.Key).Msg("key down failed")
} else {
logger.Trace().Uint32("key", payload.Key).Msg("key down")
}
case payload.OP_KEY_UP:
payload := &payload.Key{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
if err := manager.desktop.KeyUp(payload.Key); err != nil {
logger.Warn().Err(err).Uint32("key", payload.Key).Msg("key up failed")
} else {
logger.Trace().Uint32("key", payload.Key).Msg("key up")
}
case payload.OP_BTN_DOWN:
payload := &payload.Key{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
if err := manager.desktop.ButtonDown(payload.Key); err != nil {
logger.Warn().Err(err).Uint32("key", payload.Key).Msg("button down failed")
} else {
logger.Trace().Uint32("key", payload.Key).Msg("button down")
}
case payload.OP_BTN_UP:
payload := &payload.Key{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
if err := manager.desktop.ButtonUp(payload.Key); err != nil {
logger.Warn().Err(err).Uint32("key", payload.Key).Msg("button up failed")
} else {
logger.Trace().Uint32("key", payload.Key).Msg("button up")
}
case payload.OP_TOUCH_BEGIN:
payload := &payload.Touch{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
if err := manager.desktop.TouchBegin(payload.TouchId, int(payload.X), int(payload.Y), payload.Pressure); err != nil {
logger.Warn().Err(err).Uint32("touchId", payload.TouchId).Msg("touch begin failed")
} else {
logger.Trace().Uint32("touchId", payload.TouchId).Msg("touch begin")
}
case payload.OP_TOUCH_UPDATE:
payload := &payload.Touch{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
if err := manager.desktop.TouchUpdate(payload.TouchId, int(payload.X), int(payload.Y), payload.Pressure); err != nil {
logger.Warn().Err(err).Uint32("touchId", payload.TouchId).Msg("touch update failed")
} else {
logger.Trace().Uint32("touchId", payload.TouchId).Msg("touch update")
}
case payload.OP_TOUCH_END:
payload := &payload.Touch{}
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
return err
}
if err := manager.desktop.TouchEnd(payload.TouchId, int(payload.X), int(payload.Y), payload.Pressure); err != nil {
logger.Warn().Err(err).Uint32("touchId", payload.TouchId).Msg("touch end failed")
} else {
logger.Trace().Uint32("touchId", payload.TouchId).Msg("touch end")
}
}
return nil
}

View File

@ -0,0 +1,576 @@
package webrtc
import (
"fmt"
"net"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/pion/ice/v2"
"github.com/pion/interceptor"
"github.com/pion/interceptor/pkg/cc"
"github.com/pion/interceptor/pkg/gcc"
"github.com/pion/rtcp"
"github.com/pion/webrtc/v3"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/internal/config"
"github.com/demodesk/neko/internal/webrtc/cursor"
"github.com/demodesk/neko/internal/webrtc/pionlog"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
"github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message"
"github.com/demodesk/neko/pkg/utils"
)
const (
// size of receiving channel used to buffer incoming TCP packets
tcpReadChanBufferSize = 50
// size of buffer used to buffer outgoing TCP packets. Default is 4MB
tcpWriteBufferSizeInBytes = 4 * 1024 * 1024
// the duration without network activity before a Agent is considered disconnected. Default is 5 Seconds
disconnectedTimeout = 4 * time.Second
// the duration without network activity before a Agent is considered failed after disconnected. Default is 25 Seconds
failedTimeout = 6 * time.Second
// how often the ICE Agent sends extra traffic if there is no activity, if media is flowing no traffic will be sent. Default is 2 seconds
keepAliveInterval = 2 * time.Second
// send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval
rtcpPLIInterval = 3 * time.Second
)
func New(desktop types.DesktopManager, capture types.CaptureManager, config *config.WebRTC) *WebRTCManagerCtx {
logger := log.With().Str("module", "webrtc").Logger()
configuration := webrtc.Configuration{
SDPSemantics: webrtc.SDPSemanticsUnifiedPlan,
}
if !config.ICELite {
ICEServers := []webrtc.ICEServer{}
for _, server := range config.ICEServersBackend {
var credential any
if server.Credential != "" {
credential = server.Credential
} else {
credential = false
}
ICEServers = append(ICEServers, webrtc.ICEServer{
URLs: server.URLs,
Username: server.Username,
Credential: credential,
})
}
configuration.ICEServers = ICEServers
}
return &WebRTCManagerCtx{
logger: logger,
config: config,
metrics: newMetricsManager(),
webrtcConfiguration: configuration,
desktop: desktop,
capture: capture,
curImage: cursor.NewImage(logger, desktop),
curPosition: cursor.NewPosition(logger),
}
}
type WebRTCManagerCtx struct {
logger zerolog.Logger
config *config.WebRTC
metrics *metricsManager
peerId int32
desktop types.DesktopManager
capture types.CaptureManager
curImage cursor.Image
curPosition cursor.Position
webrtcConfiguration webrtc.Configuration
tcpMux ice.TCPMux
udpMux ice.UDPMux
camStop, micStop *func()
}
func (manager *WebRTCManagerCtx) Start() {
manager.curImage.Start()
logger := pionlog.New(manager.logger)
// add TCP Mux listener
if manager.config.TCPMux > 0 {
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: net.IP{0, 0, 0, 0},
Port: manager.config.TCPMux,
})
if err != nil {
manager.logger.Fatal().Err(err).Msg("unable to setup ice TCP mux")
}
manager.tcpMux = ice.NewTCPMuxDefault(ice.TCPMuxParams{
Listener: tcpListener,
Logger: logger.NewLogger("ice-tcp"),
ReadBufferSize: tcpReadChanBufferSize,
WriteBufferSize: tcpWriteBufferSizeInBytes,
})
}
// add UDP Mux listener
if manager.config.UDPMux > 0 {
var err error
manager.udpMux, err = ice.NewMultiUDPMuxFromPort(manager.config.UDPMux,
ice.UDPMuxFromPortWithLogger(logger.NewLogger("ice-udp")),
)
if err != nil {
manager.logger.Fatal().Err(err).Msg("unable to setup ice UDP mux")
}
}
manager.logger.Info().
Bool("icelite", manager.config.ICELite).
Bool("icetrickle", manager.config.ICETrickle).
Interface("iceservers-frontend", manager.config.ICEServersFrontend).
Interface("iceservers-backend", manager.config.ICEServersBackend).
Str("nat1to1", strings.Join(manager.config.NAT1To1IPs, ",")).
Str("epr", fmt.Sprintf("%d-%d", manager.config.EphemeralMin, manager.config.EphemeralMax)).
Int("tcpmux", manager.config.TCPMux).
Int("udpmux", manager.config.UDPMux).
Msg("webrtc starting")
}
func (manager *WebRTCManagerCtx) Shutdown() error {
manager.logger.Info().Msg("shutdown")
manager.curImage.Shutdown()
manager.curPosition.Shutdown()
return nil
}
func (manager *WebRTCManagerCtx) ICEServers() []types.ICEServer {
return manager.config.ICEServersFrontend
}
func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs []codec.RTPCodec) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) {
// create media engine
engine := &webrtc.MediaEngine{}
for _, codec := range codecs {
if err := codec.Register(engine); err != nil {
return nil, nil, err
}
}
// create setting engine
settings := webrtc.SettingEngine{
LoggerFactory: pionlog.New(logger),
}
settings.DisableMediaEngineCopy(true)
settings.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
settings.SetNAT1To1IPs(manager.config.NAT1To1IPs, webrtc.ICECandidateTypeHost)
settings.SetLite(manager.config.ICELite)
// make sure server answer sdp setup as passive, to not force DTLS renegotiation
// otherwise iOS renegotiation fails with: Failed to set SSL role for the transport.
settings.SetAnsweringDTLSRole(webrtc.DTLSRoleServer)
var networkType []webrtc.NetworkType
// udp candidates
if manager.udpMux != nil {
settings.SetICEUDPMux(manager.udpMux)
networkType = append(networkType,
webrtc.NetworkTypeUDP4,
webrtc.NetworkTypeUDP6,
)
} else if manager.config.EphemeralMax != 0 {
_ = settings.SetEphemeralUDPPortRange(manager.config.EphemeralMin, manager.config.EphemeralMax)
networkType = append(networkType,
webrtc.NetworkTypeUDP4,
webrtc.NetworkTypeUDP6,
)
}
// tcp candidates
if manager.tcpMux != nil {
settings.SetICETCPMux(manager.tcpMux)
networkType = append(networkType,
webrtc.NetworkTypeTCP4,
webrtc.NetworkTypeTCP6,
)
}
// enable support for TCP and UDP ICE candidates
settings.SetNetworkTypes(networkType)
// create interceptor registry
registry := &interceptor.Registry{}
// create bandwidth estimator
estimatorChan := make(chan cc.BandwidthEstimator, 1)
if manager.config.Estimator.Enabled {
congestionController, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) {
return gcc.NewSendSideBWE(
gcc.SendSideBWEInitialBitrate(manager.config.Estimator.InitialBitrate),
gcc.SendSideBWEPacer(gcc.NewNoOpPacer()),
)
})
if err != nil {
return nil, nil, err
}
congestionController.OnNewPeerConnection(func(id string, estimator cc.BandwidthEstimator) {
estimatorChan <- estimator
})
registry.Add(congestionController)
if err = webrtc.ConfigureTWCCHeaderExtensionSender(engine, registry); err != nil {
return nil, nil, err
}
} else {
// no estimator, send nil
estimatorChan <- nil
}
if err := webrtc.RegisterDefaultInterceptors(engine, registry); err != nil {
return nil, nil, err
}
// create new API
api := webrtc.NewAPI(
webrtc.WithMediaEngine(engine),
webrtc.WithSettingEngine(settings),
webrtc.WithInterceptorRegistry(registry),
)
// create new peer connection
configuration := manager.webrtcConfiguration
connection, err := api.NewPeerConnection(configuration)
return connection, <-estimatorChan, err
}
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session) (*webrtc.SessionDescription, types.WebRTCPeer, error) {
id := atomic.AddInt32(&manager.peerId, 1)
// get metrics for session
metrics := manager.metrics.getBySession(session)
metrics.NewConnection()
// add session id to logger context
logger := manager.logger.With().Str("session_id", session.ID()).Int32("peer_id", id).Logger()
logger.Info().Msg("creating webrtc peer")
// all audios must have the same codec
audio := manager.capture.Audio()
audioCodec := audio.Codec()
// all videos must have the same codec
video := manager.capture.Video()
videoCodec := video.Codec()
connection, estimator, err := manager.newPeerConnection(
logger, []codec.RTPCodec{audioCodec, videoCodec})
if err != nil {
return nil, nil, err
}
// asynchronously send local ICE Candidates
if manager.config.ICETrickle {
connection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
if candidate == nil {
logger.Debug().Msg("all local ice candidates sent")
return
}
session.Send(
event.SIGNAL_CANDIDATE,
message.SignalCandidate{
ICECandidateInit: candidate.ToJSON(),
})
})
}
// audio track
audioTrack, err := NewTrack(logger, audioCodec, connection)
if err != nil {
return nil, nil, err
}
// we disable audio by default manually
audioTrack.SetPaused(true)
// set stream for audio track
_, err = audioTrack.SetStream(audio)
if err != nil {
return nil, nil, err
}
// video track
videoRtcp := make(chan []rtcp.Packet, 1)
videoTrack, err := NewTrack(logger, videoCodec, connection, WithRtcpChan(videoRtcp))
if err != nil {
return nil, nil, err
}
//
// stream for video track will be set later
//
// data channel
dataChannel, err := connection.CreateDataChannel("data", nil)
if err != nil {
return nil, nil, err
}
peer := &WebRTCPeerCtx{
logger: logger,
session: session,
metrics: metrics,
connection: connection,
// bandwidth estimator
estimator: estimator,
estimateTrend: utils.NewTrendDetector(
utils.TrendDetectorParams{
// Probing
//RequiredSamples: 3,
//DownwardTrendThreshold: 0.0,
//CollapseValues: false,
// Non-Probing
RequiredSamples: 8,
DownwardTrendThreshold: -0.5,
CollapseValues: true,
}),
// stream selectors
video: video,
audio: audio,
// tracks & channels
audioTrack: audioTrack,
videoTrack: videoTrack,
dataChannel: dataChannel,
rtcpChannel: videoRtcp,
// config
iceTrickle: manager.config.ICETrickle,
estimatorConfig: manager.config.Estimator,
audioDisabled: true, // we disable audio by default manually
}
connection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
logger := logger.With().
Str("kind", track.Kind().String()).
Str("mime", track.Codec().RTPCodecCapability.MimeType).
Logger()
logger.Info().Msgf("received new remote track")
if !session.Profile().CanShareMedia {
err := receiver.Stop()
logger.Warn().Err(err).Msg("media sharing is disabled for this session")
return
}
// parse codec from remote track
codec, ok := codec.ParseRTC(track.Codec())
if !ok {
err := receiver.Stop()
logger.Warn().Err(err).Msg("remote track with unknown codec")
return
}
var srcManager types.StreamSrcManager
stopped := false
stopFn := func() {
if stopped {
return
}
stopped = true
err := receiver.Stop()
srcManager.Stop()
logger.Err(err).Msg("remote track stopped")
}
if track.Kind() == webrtc.RTPCodecTypeAudio {
// audio -> microphone
srcManager = manager.capture.Microphone()
defer stopFn()
if manager.micStop != nil {
(*manager.micStop)()
}
manager.micStop = &stopFn
} else if track.Kind() == webrtc.RTPCodecTypeVideo {
// video -> webcam
srcManager = manager.capture.Webcam()
defer stopFn()
if manager.camStop != nil {
(*manager.camStop)()
}
manager.camStop = &stopFn
} else {
err := receiver.Stop()
logger.Warn().Err(err).Msg("remote track with unsupported codec type")
return
}
err := srcManager.Start(codec)
if err != nil {
logger.Err(err).Msg("failed to start pipeline")
return
}
ticker := time.NewTicker(rtcpPLIInterval)
defer ticker.Stop()
go func() {
for range ticker.C {
err := connection.WriteRTCP([]rtcp.Packet{
&rtcp.PictureLossIndication{
MediaSSRC: uint32(track.SSRC()),
},
})
if err != nil {
logger.Err(err).Msg("remote track rtcp send err")
}
}
}()
buf := make([]byte, 1400)
for {
i, _, err := track.Read(buf)
if err != nil {
logger.Warn().Err(err).Msg("failed read from remote track")
break
}
srcManager.Push(buf[:i])
}
logger.Info().Msg("remote track data finished")
})
connection.OnDataChannel(func(dc *webrtc.DataChannel) {
logger.Info().Interface("data_channel", dc).Msg("got remote data channel")
})
var once sync.Once
connection.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
switch state {
case webrtc.PeerConnectionStateConnected:
session.SetWebRTCConnected(peer, true)
case webrtc.PeerConnectionStateDisconnected,
webrtc.PeerConnectionStateFailed:
peer.Destroy()
case webrtc.PeerConnectionStateClosed:
// ensure we only run this once
once.Do(func() {
session.SetWebRTCConnected(peer, false)
//
// TODO: Shutdown peer?
//
audioTrack.Shutdown()
videoTrack.Shutdown()
close(videoRtcp)
})
}
metrics.SetState(state)
})
dataChannel.OnOpen(func() {
manager.curImage.AddListener(peer)
manager.curPosition.AddListener(peer)
// send initial cursor image
cur, img, err := manager.curImage.GetCurrent()
if err == nil {
err := peer.SendCursorImage(cur, img)
if err != nil {
logger.Err(err).Msg("failed to set cursor image")
}
} else {
logger.Err(err).Msg("failed to get cursor image")
}
// send initial cursor position
x, y := manager.desktop.GetCursorPosition()
err = peer.SendCursorPosition(x, y)
if err != nil {
logger.Err(err).Msg("failed to set cursor position")
}
})
dataChannel.OnClose(func() {
manager.curImage.RemoveListener(peer)
manager.curPosition.RemoveListener(peer)
})
dataChannel.OnMessage(func(message webrtc.DataChannelMessage) {
if err := manager.handle(logger, message.Data, dataChannel, session); err != nil {
logger.Err(err).Msg("data handle failed")
}
})
session.SetWebRTCPeer(peer)
offer, err := peer.CreateOffer(false)
if err != nil {
return nil, nil, err
}
// on negotiation needed handler must be registered after creating initial
// offer, otherwise it can fire and intercept sucessful negotiation
connection.OnNegotiationNeeded(func() {
logger.Warn().Msg("negotiation is needed")
if connection.SignalingState() != webrtc.SignalingStateStable {
logger.Warn().Msg("connection isn't stable yet; postponing...")
return
}
offer, err := peer.CreateOffer(false)
if err != nil {
logger.Err(err).Msg("sdp offer failed")
return
}
session.Send(
event.SIGNAL_OFFER,
message.SignalDescription{
SDP: offer.SDP,
})
})
// start metrics collectors
go metrics.rtcpReceiver(videoRtcp)
go metrics.connectionStats(connection)
// start estimator reader
go peer.estimatorReader()
return offer, peer, nil
}
func (manager *WebRTCManagerCtx) SetCursorPosition(x, y int) {
manager.curPosition.Set(x, y)
}

View File

@ -0,0 +1,458 @@
package webrtc
import (
"sync"
"time"
"github.com/demodesk/neko/pkg/types"
"github.com/pion/rtcp"
"github.com/pion/webrtc/v3"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
const (
// how often to read and process webrtc connection stats
connectionStatsInterval = 5 * time.Second
)
type metricsManager struct {
mu sync.Mutex
sessions map[string]*metrics
}
func newMetricsManager() *metricsManager {
return &metricsManager{
sessions: map[string]*metrics{},
}
}
func (m *metricsManager) getBySession(session types.Session) *metrics {
m.mu.Lock()
defer m.mu.Unlock()
sessionId := session.ID()
met, ok := m.sessions[sessionId]
if ok {
return met
}
met = &metrics{
sessionId: sessionId,
connectionState: promauto.NewGauge(prometheus.GaugeOpts{
Name: "connection_state",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Connection state of session.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
connectionStateCount: promauto.NewCounter(prometheus.CounterOpts{
Name: "connection_state_count",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Count of connection state changes for a session.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
connectionCount: promauto.NewCounter(prometheus.CounterOpts{
Name: "connection_count",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Connection count of a session.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
iceCandidates: map[string]struct{}{},
iceCandidatesMu: &sync.Mutex{},
iceCandidatesUdpCount: promauto.NewCounter(prometheus.CounterOpts{
Name: "ice_candidates_count",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Count of ICE candidates sent by a remote client.",
ConstLabels: map[string]string{
"session_id": sessionId,
"protocol": "udp",
},
}),
iceCandidatesTcpCount: promauto.NewCounter(prometheus.CounterOpts{
Name: "ice_candidates_count",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Count of ICE candidates sent by a remote client.",
ConstLabels: map[string]string{
"session_id": sessionId,
"protocol": "tcp",
},
}),
iceCandidatesUsedUdp: promauto.NewGauge(prometheus.GaugeOpts{
Name: "ice_candidates_used",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Used ICE candidates that are currently in use.",
ConstLabels: map[string]string{
"session_id": sessionId,
"protocol": "udp",
},
}),
iceCandidatesUsedTcp: promauto.NewGauge(prometheus.GaugeOpts{
Name: "ice_candidates_used",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Used ICE candidates that are currently in use.",
ConstLabels: map[string]string{
"session_id": sessionId,
"protocol": "tcp",
},
}),
videoIds: map[string]prometheus.Gauge{},
videoIdsMu: &sync.Mutex{},
receiverEstimatedMaximumBitrate: promauto.NewGauge(prometheus.GaugeOpts{
Name: "receiver_estimated_maximum_bitrate",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Receiver Estimated Maximum Bitrate from RTCP.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
receiverEstimatedTargetBitrate: promauto.NewGauge(prometheus.GaugeOpts{
Name: "receiver_estimated_target_bitrate",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Receiver Estimated Target Bitrate using Google's congestion control.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
receiverReportDelay: promauto.NewGauge(prometheus.GaugeOpts{
Name: "receiver_report_delay",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Receiver Report Delay from RTCP, expressed in units of 1/65536 seconds.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
receiverReportJitter: promauto.NewGauge(prometheus.GaugeOpts{
Name: "receiver_report_jitter",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Receiver Report Jitter from RTCP.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
receiverReportTotalLost: promauto.NewGauge(prometheus.GaugeOpts{
Name: "receiver_report_total_lost",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Receiver Report Total Lost from RTCP.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
transportLayerNacks: promauto.NewCounter(prometheus.CounterOpts{
Name: "transport_layer_nacks",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Transport Layer NACKs from RTCP.",
ConstLabels: map[string]string{
"session_id": sessionId,
},
}),
iceBytesSent: promauto.NewGauge(prometheus.GaugeOpts{
Name: "bytes_sent",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Sent bytes to a session.",
ConstLabels: map[string]string{
"session_id": sessionId,
"transport": "ice",
},
}),
iceBytesReceived: promauto.NewGauge(prometheus.GaugeOpts{
Name: "bytes_received",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Received bytes from a session.",
ConstLabels: map[string]string{
"session_id": sessionId,
"transport": "ice",
},
}),
sctpBytesSent: promauto.NewGauge(prometheus.GaugeOpts{
Name: "bytes_sent",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Sent bytes to a session.",
ConstLabels: map[string]string{
"session_id": sessionId,
"transport": "sctp",
},
}),
sctpBytesReceived: promauto.NewGauge(prometheus.GaugeOpts{
Name: "bytes_received",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Received bytes from a session.",
ConstLabels: map[string]string{
"session_id": sessionId,
"transport": "sctp",
},
}),
}
m.sessions[sessionId] = met
return met
}
type metrics struct {
sessionId string
connectionState prometheus.Gauge
connectionStateCount prometheus.Counter
connectionCount prometheus.Counter
iceCandidates map[string]struct{}
iceCandidatesMu *sync.Mutex
iceCandidatesUdpCount prometheus.Counter
iceCandidatesTcpCount prometheus.Counter
iceCandidatesUsedUdp prometheus.Gauge
iceCandidatesUsedTcp prometheus.Gauge
videoIds map[string]prometheus.Gauge
videoIdsMu *sync.Mutex
receiverEstimatedMaximumBitrate prometheus.Gauge
receiverEstimatedTargetBitrate prometheus.Gauge
receiverReportDelay prometheus.Gauge
receiverReportJitter prometheus.Gauge
receiverReportTotalLost prometheus.Gauge
transportLayerNacks prometheus.Counter
iceBytesSent prometheus.Gauge
iceBytesReceived prometheus.Gauge
sctpBytesSent prometheus.Gauge
sctpBytesReceived prometheus.Gauge
}
func (met *metrics) reset() {
met.videoIdsMu.Lock()
for _, entry := range met.videoIds {
entry.Set(0)
}
met.videoIdsMu.Unlock()
met.iceCandidatesUsedUdp.Set(float64(0))
met.iceCandidatesUsedTcp.Set(float64(0))
met.receiverEstimatedMaximumBitrate.Set(0)
met.receiverReportDelay.Set(0)
met.receiverReportJitter.Set(0)
}
func (met *metrics) NewConnection() {
met.connectionCount.Add(1)
}
func (met *metrics) NewICECandidate(candidate webrtc.ICECandidateStats) {
met.iceCandidatesMu.Lock()
defer met.iceCandidatesMu.Unlock()
if _, found := met.iceCandidates[candidate.ID]; found {
return
}
met.iceCandidates[candidate.ID] = struct{}{}
if candidate.Protocol == "udp" {
met.iceCandidatesUdpCount.Add(1)
} else if candidate.Protocol == "tcp" {
met.iceCandidatesTcpCount.Add(1)
}
}
func (met *metrics) SetICECandidatesUsed(candidates []webrtc.ICECandidateStats) {
udp, tcp := 0, 0
for _, candidate := range candidates {
if candidate.Protocol == "udp" {
udp++
} else if candidate.Protocol == "tcp" {
tcp++
}
}
met.iceCandidatesUsedUdp.Set(float64(udp))
met.iceCandidatesUsedTcp.Set(float64(tcp))
}
func (met *metrics) SetState(state webrtc.PeerConnectionState) {
switch state {
case webrtc.PeerConnectionStateNew:
met.connectionState.Set(0)
case webrtc.PeerConnectionStateConnecting:
met.connectionState.Set(4)
case webrtc.PeerConnectionStateConnected:
met.connectionState.Set(5)
case webrtc.PeerConnectionStateDisconnected:
met.connectionState.Set(3)
case webrtc.PeerConnectionStateFailed:
met.connectionState.Set(2)
case webrtc.PeerConnectionStateClosed:
met.connectionState.Set(1)
met.reset()
default:
met.connectionState.Set(-1)
}
met.connectionStateCount.Add(1)
}
func (met *metrics) SetVideoID(videoId string) {
met.videoIdsMu.Lock()
defer met.videoIdsMu.Unlock()
if _, found := met.videoIds[videoId]; !found {
met.videoIds[videoId] = promauto.NewGauge(prometheus.GaugeOpts{
Name: "video_listeners",
Namespace: "neko",
Subsystem: "webrtc",
Help: "Listeners for Video pipelines by a session.",
ConstLabels: map[string]string{
"session_id": met.sessionId,
"video_id": videoId,
},
})
}
for id, entry := range met.videoIds {
if id == videoId {
entry.Set(1)
} else {
entry.Set(0)
}
}
}
func (met *metrics) SetReceiverEstimatedMaximumBitrate(bitrate float32) {
met.receiverEstimatedMaximumBitrate.Set(float64(bitrate))
}
func (met *metrics) SetReceiverEstimatedTargetBitrate(bitrate float64) {
met.receiverEstimatedTargetBitrate.Set(bitrate)
}
func (met *metrics) SetReceiverReport(report rtcp.ReceptionReport) {
met.receiverReportDelay.Set(float64(report.Delay))
met.receiverReportJitter.Set(float64(report.Jitter))
met.receiverReportTotalLost.Set(float64(report.TotalLost))
}
func (met *metrics) SetIceTransportStats(data webrtc.TransportStats) {
met.iceBytesSent.Set(float64(data.BytesSent))
met.iceBytesReceived.Set(float64(data.BytesReceived))
}
func (met *metrics) SetSctpTransportStats(data webrtc.TransportStats) {
met.sctpBytesSent.Set(float64(data.BytesSent))
met.sctpBytesReceived.Set(float64(data.BytesReceived))
}
//
// collectors
//
func (met *metrics) rtcpReceiver(rtcpCh chan []rtcp.Packet) {
for {
packets, ok := <-rtcpCh
if !ok {
break
}
for _, p := range packets {
switch rtcpPacket := p.(type) {
case *rtcp.ReceiverEstimatedMaximumBitrate: // TODO: Deprecated.
met.SetReceiverEstimatedMaximumBitrate(rtcpPacket.Bitrate)
case *rtcp.ReceiverReport:
l := len(rtcpPacket.Reports)
if l > 0 {
// use only last report
met.SetReceiverReport(rtcpPacket.Reports[l-1])
}
case *rtcp.TransportLayerNack:
for _, pair := range rtcpPacket.Nacks {
packetList := pair.PacketList()
met.transportLayerNacks.Add(float64(len(packetList)))
}
}
}
}
}
func (met *metrics) connectionStats(connection *webrtc.PeerConnection) {
ticker := time.NewTicker(connectionStatsInterval)
defer ticker.Stop()
for range ticker.C {
if connection.ConnectionState() == webrtc.PeerConnectionStateClosed {
break
}
stats := connection.GetStats()
data, ok := stats["iceTransport"].(webrtc.TransportStats)
if ok {
met.SetIceTransportStats(data)
}
data, ok = stats["sctpTransport"].(webrtc.TransportStats)
if ok {
met.SetSctpTransportStats(data)
}
remoteCandidates := map[string]webrtc.ICECandidateStats{}
nominatedRemoteCandidates := map[string]struct{}{}
for _, entry := range stats {
// only remote ice candidate stats
candidate, ok := entry.(webrtc.ICECandidateStats)
if ok && candidate.Type == webrtc.StatsTypeRemoteCandidate {
met.NewICECandidate(candidate)
remoteCandidates[candidate.ID] = candidate
}
// only nominated ice candidate pair stats
pair, ok := entry.(webrtc.ICECandidatePairStats)
if ok && pair.Nominated {
nominatedRemoteCandidates[pair.RemoteCandidateID] = struct{}{}
}
}
iceCandidatesUsed := []webrtc.ICECandidateStats{}
for id := range nominatedRemoteCandidates {
if candidate, ok := remoteCandidates[id]; ok {
iceCandidatesUsed = append(iceCandidatesUsed, candidate)
}
}
met.SetICECandidatesUsed(iceCandidatesUsed)
}
}

View File

@ -0,0 +1,55 @@
package payload
import "math"
const (
OP_MOVE = 0x01
OP_SCROLL = 0x02
OP_KEY_DOWN = 0x03
OP_KEY_UP = 0x04
OP_BTN_DOWN = 0x05
OP_BTN_UP = 0x06
OP_PING = 0x07
// touch events
OP_TOUCH_BEGIN = 0x08
OP_TOUCH_UPDATE = 0x09
OP_TOUCH_END = 0x0a
)
type Move struct {
X uint16
Y uint16
}
// TODO: remove this once the client is fixed
type Scroll_Old struct {
X int16
Y int16
}
type Scroll struct {
DeltaX int16
DeltaY int16
ControlKey bool
}
type Key struct {
Key uint32
}
type Ping struct {
// client's timestamp split into two uint32
ClientTs1 uint32
ClientTs2 uint32
}
func (p Ping) ClientTs() uint64 {
return (uint64(p.ClientTs1) * uint64(math.MaxUint32)) + uint64(p.ClientTs2)
}
type Touch struct {
TouchId uint32
X int32
Y int32
Pressure uint8
}

View File

@ -0,0 +1,33 @@
package payload
import "math"
const (
OP_CURSOR_POSITION = 0x01
OP_CURSOR_IMAGE = 0x02
OP_PONG = 0x03
)
type CursorPosition struct {
X uint16
Y uint16
}
type CursorImage struct {
Width uint16
Height uint16
Xhot uint16
Yhot uint16
}
type Pong struct {
Ping
// server's timestamp split into two uint32
ServerTs1 uint32
ServerTs2 uint32
}
func (p Pong) ServerTs() uint64 {
return (uint64(p.ServerTs1) * uint64(math.MaxUint32)) + uint64(p.ServerTs2)
}

View File

@ -0,0 +1,6 @@
package payload
type Header struct {
Event uint8
Length uint16
}

View File

@ -0,0 +1,543 @@
package webrtc
import (
"bytes"
"encoding/binary"
"sync"
"time"
"github.com/pion/interceptor/pkg/cc"
"github.com/pion/rtcp"
"github.com/pion/webrtc/v3"
"github.com/rs/zerolog"
"github.com/demodesk/neko/internal/config"
"github.com/demodesk/neko/internal/webrtc/payload"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/utils"
)
type WebRTCPeerCtx struct {
mu sync.Mutex
logger zerolog.Logger
session types.Session
metrics *metrics
connection *webrtc.PeerConnection
// bandwidth estimator
estimator cc.BandwidthEstimator
estimateTrend *utils.TrendDetector
// stream selectors
video types.StreamSelectorManager
audio types.StreamSinkManager
// tracks & channels
audioTrack *Track
videoTrack *Track
dataChannel *webrtc.DataChannel
rtcpChannel chan []rtcp.Packet
// config
iceTrickle bool
estimatorConfig config.WebRTCEstimator
paused bool
videoAuto bool
videoDisabled bool
audioDisabled bool
}
//
// connection
//
func (peer *WebRTCPeerCtx) CreateOffer(ICERestart bool) (*webrtc.SessionDescription, error) {
peer.mu.Lock()
defer peer.mu.Unlock()
offer, err := peer.connection.CreateOffer(&webrtc.OfferOptions{
ICERestart: ICERestart,
})
if err != nil {
return nil, err
}
return peer.setLocalDescription(offer)
}
func (peer *WebRTCPeerCtx) CreateAnswer() (*webrtc.SessionDescription, error) {
peer.mu.Lock()
defer peer.mu.Unlock()
answer, err := peer.connection.CreateAnswer(nil)
if err != nil {
return nil, err
}
return peer.setLocalDescription(answer)
}
func (peer *WebRTCPeerCtx) setLocalDescription(description webrtc.SessionDescription) (*webrtc.SessionDescription, error) {
if !peer.iceTrickle {
// Create channel that is blocked until ICE Gathering is complete
gatherComplete := webrtc.GatheringCompletePromise(peer.connection)
if err := peer.connection.SetLocalDescription(description); err != nil {
return nil, err
}
<-gatherComplete
} else {
if err := peer.connection.SetLocalDescription(description); err != nil {
return nil, err
}
}
return peer.connection.LocalDescription(), nil
}
func (peer *WebRTCPeerCtx) SetRemoteDescription(desc webrtc.SessionDescription) error {
peer.mu.Lock()
defer peer.mu.Unlock()
return peer.connection.SetRemoteDescription(desc)
}
func (peer *WebRTCPeerCtx) SetCandidate(candidate webrtc.ICECandidateInit) error {
peer.mu.Lock()
defer peer.mu.Unlock()
return peer.connection.AddICECandidate(candidate)
}
// TODO: Add shutdown function?
func (peer *WebRTCPeerCtx) Destroy() {
peer.mu.Lock()
defer peer.mu.Unlock()
var err error
// if peer connection is not closed, close it
if peer.connection.ConnectionState() != webrtc.PeerConnectionStateClosed {
err = peer.connection.Close()
}
peer.logger.Err(err).Msg("peer connection destroyed")
}
func (peer *WebRTCPeerCtx) estimatorReader() {
conf := peer.estimatorConfig
// if estimator is not in debug mode, use a nop logger
var debugLogger zerolog.Logger
if conf.Debug {
debugLogger = peer.logger.With().Str("component", "estimator").Logger().Level(zerolog.DebugLevel)
} else {
debugLogger = zerolog.Nop()
}
// if estimator is disabled, do nothing
if peer.estimator == nil {
return
}
// use a ticker to get current client target bitrate
ticker := time.NewTicker(conf.ReadInterval)
defer ticker.Stop()
// since when is the estimate stable/unstable
stableSince := time.Now() // we asume stable at start
unstableSince := time.Time{}
// since when are we neutral but cannot accomodate current bitrate
// we migt be stalled or estimator just reached zer (very bad connection)
stalledSince := time.Time{}
// when was the last upgrade/downgrade
lastUpgradeTime := time.Time{}
lastDowngradeTime := time.Time{}
for range ticker.C {
targetBitrate := peer.estimator.GetTargetBitrate()
peer.metrics.SetReceiverEstimatedTargetBitrate(float64(targetBitrate))
// if peer connection is closed, stop reading
if peer.connection.ConnectionState() == webrtc.PeerConnectionStateClosed {
break
}
// if estimation or video is disabled, do nothing
if !peer.videoAuto || peer.videoDisabled || peer.paused || conf.Passive {
continue
}
// get trend direction to decide if we should upgrade or downgrade
peer.estimateTrend.AddValue(int64(targetBitrate))
direction := peer.estimateTrend.GetDirection()
// get current stream bitrate
stream, ok := peer.videoTrack.Stream()
if !ok {
debugLogger.Warn().Msg("looks like we don't have a stream yet, skipping bitrate estimation")
continue
}
// if stream bitrate is 0, we need to wait for some time until we get a valid value
streamId, streamBitrate := stream.ID(), stream.Bitrate()
if streamBitrate == 0 {
debugLogger.Warn().Msg("looks like stream bitrate is 0, we need to wait for some time")
continue
}
// check whats the difference between target and stream bitrate
diff := float64(targetBitrate) / float64(streamBitrate)
debugLogger.Info().
Float64("diff", diff).
Int("target_bitrate", targetBitrate).
Uint64("stream_bitrate", streamBitrate).
Str("direction", direction.String()).
Msg("got bitrate from estimator")
// if we can accomodate current stream or we are not netural anymore,
// we are not stalled so we reset the stalled time
if direction != utils.TrendDirectionNeutral || diff > 1+conf.DiffThreshold {
stalledSince = time.Now()
}
// if we are neutral and stalled for too long, we might be congesting
stalled := direction == utils.TrendDirectionNeutral && time.Since(stalledSince) > conf.StalledDuration
if stalled {
debugLogger.Warn().
Time("stalled_since", stalledSince).
Msgf("it looks like we are stalled")
}
// if we have an downward trend or are stalled, we might be congesting
if direction == utils.TrendDirectionDownward || stalled {
// we reset the stable time because we are congesting
stableSince = time.Now()
// if we downgraded recently, we wait for some more time
if time.Since(lastDowngradeTime) < conf.DowngradeBackoff {
debugLogger.Debug().
Time("last_downgrade", lastDowngradeTime).
Msgf("downgraded recently, waiting for at least %v", conf.DowngradeBackoff)
continue
}
// if we are not unstable but we fluctuate we should wait for some more time
if time.Since(unstableSince) < conf.UnstableDuration {
debugLogger.Debug().
Time("unstable_since", unstableSince).
Msgf("we are not unstable long enough, waiting for at least %v", conf.UnstableDuration)
continue
}
// if we still have a big difference between target and stream bitrate, we wait for some more time
if conf.DiffThreshold >= 0 && diff > 1+conf.DiffThreshold {
debugLogger.Debug().
Float64("diff", diff).
Float64("threshold", conf.DiffThreshold).
Msgf("we still have a big difference between target and stream bitrate, " +
"therefore we still should be able to accomodate current stream")
continue
}
err := peer.SetVideo(types.PeerVideoRequest{
Selector: &types.StreamSelector{
ID: streamId,
Type: types.StreamSelectorTypeLower,
},
})
if err != nil && err != types.ErrWebRTCStreamNotFound {
peer.logger.Warn().Err(err).Msg("failed to downgrade video stream")
}
lastDowngradeTime = time.Now()
if err == types.ErrWebRTCStreamNotFound {
debugLogger.Info().Msg("looks like we are already on the lowest stream")
} else {
debugLogger.Info().Msg("downgraded video stream")
}
continue
}
// we reset the unstable time because we are not congesting
unstableSince = time.Now()
// if we have a neutral or upward trend, that means our estimate is stable
// if we are on the highest stream, we don't need to do anything
// but if there is a higher stream, we should try to upgrade and see if it works
// if we upgraded recently, we wait for some more time
if time.Since(lastUpgradeTime) < conf.UpgradeBackoff {
debugLogger.Debug().
Time("last_upgrade", lastUpgradeTime).
Msgf("upgraded recently, waiting for at least %v", conf.UpgradeBackoff)
continue
}
// if we are not stable for long enough, we wait for some more time
// because bandwidth estimation might fluctuate
if time.Since(stableSince) < conf.StableDuration {
debugLogger.Debug().
Time("stable_since", stableSince).
Msgf("we are not stable long enough, waiting for at least %v", conf.StableDuration)
continue
}
// upgrade only if estimated bitrate passed the threshold
if conf.DiffThreshold >= 0 && diff < 1+conf.DiffThreshold {
debugLogger.Debug().
Float64("diff", diff).
Float64("threshold", conf.DiffThreshold).
Msgf("looks like we don't have enough bitrate to accomodate higher stream, " +
"therefore we should wait for some more time")
continue
}
err := peer.SetVideo(types.PeerVideoRequest{
Selector: &types.StreamSelector{
ID: streamId,
Type: types.StreamSelectorTypeHigher,
},
})
if err != nil && err != types.ErrWebRTCStreamNotFound {
peer.logger.Warn().Err(err).Msg("failed to upgrade video stream")
}
lastUpgradeTime = time.Now()
if err == types.ErrWebRTCStreamNotFound {
debugLogger.Info().Msg("looks like we are already on the highest stream")
} else {
debugLogger.Info().Msg("upgraded video stream")
}
}
}
func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error {
peer.mu.Lock()
defer peer.mu.Unlock()
peer.videoTrack.SetPaused(isPaused || peer.videoDisabled)
peer.audioTrack.SetPaused(isPaused || peer.audioDisabled)
peer.logger.Info().Bool("is_paused", isPaused).Msg("set paused")
peer.paused = isPaused
return nil
}
func (peer *WebRTCPeerCtx) Paused() bool {
peer.mu.Lock()
defer peer.mu.Unlock()
return peer.paused
}
//
// video
//
func (peer *WebRTCPeerCtx) SetVideo(r types.PeerVideoRequest) error {
peer.mu.Lock()
defer peer.mu.Unlock()
modified := false
// video disabled
if r.Disabled != nil {
disabled := *r.Disabled
// update only if changed
if peer.videoDisabled != disabled {
peer.videoDisabled = disabled
peer.videoTrack.SetPaused(disabled || peer.paused)
peer.logger.Info().Bool("disabled", disabled).Msg("set video disabled")
modified = true
}
}
// video selector
if r.Selector != nil {
selector := *r.Selector
// get requested video stream from selector
stream, ok := peer.video.GetStream(selector)
if !ok {
return types.ErrWebRTCStreamNotFound
}
// set video stream to track
changed, err := peer.videoTrack.SetStream(stream)
if err != nil {
return err
}
// update only if stream changed
if changed {
videoID := stream.ID()
peer.metrics.SetVideoID(videoID)
peer.logger.Info().Str("video_id", videoID).Msg("set video")
modified = true
}
}
// video auto
if r.Auto != nil {
videoAuto := *r.Auto
if peer.estimator == nil || peer.estimatorConfig.Passive {
peer.logger.Warn().Msg("estimator is disabled or in passive mode, cannot change video auto")
videoAuto = false // ensure video auto is disabled
}
// update only if video auto changed
if peer.videoAuto != videoAuto {
peer.videoAuto = videoAuto
peer.logger.Info().Bool("video_auto", videoAuto).Msg("set video auto")
modified = true
}
}
// send video signal if modified
if modified {
go func() {
// in goroutine because of mutex and we don't want to block
peer.session.Send(event.SIGNAL_VIDEO, peer.Video())
}()
}
return nil
}
func (peer *WebRTCPeerCtx) Video() types.PeerVideo {
peer.mu.Lock()
defer peer.mu.Unlock()
// get current video stream ID
ID := ""
stream, ok := peer.videoTrack.Stream()
if ok {
ID = stream.ID()
}
return types.PeerVideo{
Disabled: peer.videoDisabled,
ID: ID,
Video: ID, // TODO: Remove, used for backward compatibility
Auto: peer.videoAuto,
}
}
//
// audio
//
func (peer *WebRTCPeerCtx) SetAudio(r types.PeerAudioRequest) error {
peer.mu.Lock()
defer peer.mu.Unlock()
modified := false
// audio disabled
if r.Disabled != nil {
disabled := *r.Disabled
// update only if changed
if peer.audioDisabled != disabled {
peer.audioDisabled = disabled
peer.audioTrack.SetPaused(disabled || peer.paused)
peer.logger.Info().Bool("disabled", disabled).Msg("set audio disabled")
modified = true
}
}
// send video signal if modified
if modified {
go func() {
// in goroutine because of mutex and we don't want to block
peer.session.Send(event.SIGNAL_AUDIO, peer.Audio())
}()
}
return nil
}
func (peer *WebRTCPeerCtx) Audio() types.PeerAudio {
peer.mu.Lock()
defer peer.mu.Unlock()
return types.PeerAudio{
Disabled: peer.audioDisabled,
}
}
//
// data channel
//
func (peer *WebRTCPeerCtx) SendCursorPosition(x, y int) error {
peer.mu.Lock()
defer peer.mu.Unlock()
// do not send cursor position to host
if peer.session.IsHost() {
return nil
}
header := payload.Header{
Event: payload.OP_CURSOR_POSITION,
Length: 7,
}
data := payload.CursorPosition{
X: uint16(x),
Y: uint16(y),
}
buffer := &bytes.Buffer{}
if err := binary.Write(buffer, binary.BigEndian, header); err != nil {
return err
}
if err := binary.Write(buffer, binary.BigEndian, data); err != nil {
return err
}
return peer.dataChannel.Send(buffer.Bytes())
}
func (peer *WebRTCPeerCtx) SendCursorImage(cur *types.CursorImage, img []byte) error {
peer.mu.Lock()
defer peer.mu.Unlock()
header := payload.Header{
Event: payload.OP_CURSOR_IMAGE,
Length: uint16(11 + len(img)),
}
data := payload.CursorImage{
Width: cur.Width,
Height: cur.Height,
Xhot: cur.Xhot,
Yhot: cur.Yhot,
}
buffer := &bytes.Buffer{}
if err := binary.Write(buffer, binary.BigEndian, header); err != nil {
return err
}
if err := binary.Write(buffer, binary.BigEndian, data); err != nil {
return err
}
if err := binary.Write(buffer, binary.BigEndian, img); err != nil {
return err
}
return peer.dataChannel.Send(buffer.Bytes())
}

View File

@ -0,0 +1,27 @@
package pionlog
import (
"github.com/pion/logging"
"github.com/rs/zerolog"
)
func New(logger zerolog.Logger) Factory {
return Factory{
Logger: logger.With().Str("submodule", "pion").Logger(),
}
}
type Factory struct {
Logger zerolog.Logger
}
func (l Factory) NewLogger(subsystem string) logging.LeveledLogger {
if subsystem == "sctp" {
return nulllog{}
}
return logger{
subsystem: subsystem,
logger: l.Logger.With().Str("subsystem", subsystem).Logger(),
}
}

View File

@ -0,0 +1,66 @@
package pionlog
import (
"fmt"
"strings"
"github.com/rs/zerolog"
)
type logger struct {
logger zerolog.Logger
subsystem string
}
func (l logger) Trace(msg string) {
l.logger.Trace().Msg(strings.TrimSpace(msg))
}
func (l logger) Tracef(format string, args ...any) {
msg := fmt.Sprintf(format, args...)
l.logger.Trace().Msg(strings.TrimSpace(msg))
}
func (l logger) Debug(msg string) {
l.logger.Debug().Msg(strings.TrimSpace(msg))
}
func (l logger) Debugf(format string, args ...any) {
msg := fmt.Sprintf(format, args...)
l.logger.Debug().Msg(strings.TrimSpace(msg))
}
func (l logger) Info(msg string) {
if strings.Contains(msg, "duplicated packet") {
return
}
l.logger.Info().Msg(strings.TrimSpace(msg))
}
func (l logger) Infof(format string, args ...any) {
msg := fmt.Sprintf(format, args...)
if strings.Contains(msg, "duplicated packet") {
return
}
l.logger.Info().Msg(strings.TrimSpace(msg))
}
func (l logger) Warn(msg string) {
l.logger.Warn().Msg(strings.TrimSpace(msg))
}
func (l logger) Warnf(format string, args ...any) {
msg := fmt.Sprintf(format, args...)
l.logger.Warn().Msg(strings.TrimSpace(msg))
}
func (l logger) Error(msg string) {
l.logger.Error().Msg(strings.TrimSpace(msg))
}
func (l logger) Errorf(format string, args ...any) {
msg := fmt.Sprintf(format, args...)
l.logger.Error().Msg(strings.TrimSpace(msg))
}

View File

@ -0,0 +1,14 @@
package pionlog
type nulllog struct{}
func (l nulllog) Trace(msg string) {}
func (l nulllog) Tracef(format string, args ...any) {}
func (l nulllog) Debug(msg string) {}
func (l nulllog) Debugf(format string, args ...any) {}
func (l nulllog) Info(msg string) {}
func (l nulllog) Infof(format string, args ...any) {}
func (l nulllog) Warn(msg string) {}
func (l nulllog) Warnf(format string, args ...any) {}
func (l nulllog) Error(msg string) {}
func (l nulllog) Errorf(format string, args ...any) {}

View File

@ -0,0 +1,203 @@
package webrtc
import (
"errors"
"io"
"sync"
"github.com/pion/rtcp"
"github.com/pion/webrtc/v3"
"github.com/pion/webrtc/v3/pkg/media"
"github.com/rs/zerolog"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/codec"
)
type Track struct {
logger zerolog.Logger
track *webrtc.TrackLocalStaticSample
rtcpCh chan []rtcp.Packet
sample chan types.Sample
paused bool
stream types.StreamSinkManager
streamMu sync.Mutex
}
type trackOption func(*Track)
func WithRtcpChan(rtcp chan []rtcp.Packet) trackOption {
return func(t *Track) {
t.rtcpCh = rtcp
}
}
func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.PeerConnection, opts ...trackOption) (*Track, error) {
id := codec.Type.String()
track, err := webrtc.NewTrackLocalStaticSample(codec.Capability, id, "stream")
if err != nil {
return nil, err
}
t := &Track{
logger: logger.With().Str("id", id).Logger(),
track: track,
rtcpCh: nil,
sample: make(chan types.Sample),
}
for _, opt := range opts {
opt(t)
}
sender, err := connection.AddTrack(t.track)
if err != nil {
return nil, err
}
go t.rtcpReader(sender)
go t.sampleReader()
return t, nil
}
func (t *Track) Shutdown() {
t.RemoveStream()
close(t.sample)
}
func (t *Track) rtcpReader(sender *webrtc.RTPSender) {
for {
packets, _, err := sender.ReadRTCP()
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
t.logger.Debug().Msg("track rtcp reader closed")
return
}
t.logger.Warn().Err(err).Msg("failed to read track rtcp")
continue
}
if t.rtcpCh != nil {
t.rtcpCh <- packets
}
}
}
// --- sample ---
func (t *Track) sampleReader() {
for {
sample, ok := <-t.sample
if !ok {
t.logger.Debug().Msg("track sample reader closed")
return
}
err := t.track.WriteSample(media.Sample{
Data: sample.Data,
Duration: sample.Duration,
Timestamp: sample.Timestamp,
})
if err != nil && !errors.Is(err, io.ErrClosedPipe) {
t.logger.Warn().Err(err).Msg("failed to write sample to track")
}
}
}
func (t *Track) WriteSample(sample types.Sample) {
t.sample <- sample
}
// --- stream ---
func (t *Track) SetStream(stream types.StreamSinkManager) (bool, error) {
t.streamMu.Lock()
defer t.streamMu.Unlock()
// if we already listen to the stream, do nothing
if t.stream == stream {
return false, nil
}
// if paused, we switch the stream but don't add the listener
if t.paused {
t.stream = stream
return true, nil
}
var err error
if t.stream != nil {
err = t.stream.MoveListenerTo(t, stream)
} else {
err = stream.AddListener(t)
}
if err != nil {
return false, err
}
t.stream = stream
return true, nil
}
func (t *Track) RemoveStream() {
t.streamMu.Lock()
defer t.streamMu.Unlock()
// if there is no stream, or paused we don't need to remove the listener
if t.stream == nil || t.paused {
t.stream = nil
return
}
err := t.stream.RemoveListener(t)
if err != nil {
t.logger.Warn().Err(err).Msg("failed to remove listener from stream")
}
t.stream = nil
}
func (t *Track) Stream() (types.StreamSinkManager, bool) {
t.streamMu.Lock()
defer t.streamMu.Unlock()
return t.stream, t.stream != nil
}
// --- paused ---
func (t *Track) SetPaused(paused bool) {
t.streamMu.Lock()
defer t.streamMu.Unlock()
// if there is no state change or no stream, do nothing
if t.paused == paused || t.stream == nil {
t.paused = paused
return
}
var err error
if paused {
err = t.stream.RemoveListener(t)
} else {
err = t.stream.AddListener(t)
}
if err != nil {
t.logger.Warn().Err(err).Msg("failed to change listener state")
return
}
t.paused = paused
}
func (t *Track) Paused() bool {
t.streamMu.Lock()
defer t.streamMu.Unlock()
return t.paused
}

View File

@ -0,0 +1,67 @@
package websocket
import (
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message"
)
func (manager *WebSocketManagerCtx) fileChooserDialogEvents() {
var activeSession types.Session
// when dialog opens, everyone should be notified.
manager.desktop.OnFileChooserDialogOpened(func() {
manager.logger.Info().Msg("file chooser dialog opened")
host, hasHost := manager.sessions.GetHost()
if !hasHost {
manager.logger.Warn().Msg("no host for file chooser dialog found, closing")
go manager.desktop.CloseFileChooserDialog()
return
}
activeSession = host
go manager.sessions.Broadcast(
event.FILE_CHOOSER_DIALOG_OPENED,
message.SessionID{
ID: host.ID(),
})
})
// when dialog closes, everyone should be notified.
manager.desktop.OnFileChooserDialogClosed(func() {
manager.logger.Info().Msg("file chooser dialog closed")
activeSession = nil
go manager.sessions.Broadcast(
event.FILE_CHOOSER_DIALOG_CLOSED,
message.SessionID{})
})
// when new user joins, and someone holds dialog, he shouldd be notified about it.
manager.sessions.OnConnected(func(session types.Session) {
if activeSession == nil {
return
}
manager.logger.Debug().Str("session_id", session.ID()).Msg("sending file chooser dialog status to a new session")
session.Send(
event.FILE_CHOOSER_DIALOG_OPENED,
message.SessionID{
ID: activeSession.ID(),
})
})
// when user, that holds dialog, disconnects, it should be closed.
manager.sessions.OnDisconnected(func(session types.Session) {
if activeSession == nil || activeSession != session {
return
}
manager.logger.Info().Msg("file chooser dialog owner left, closing")
manager.desktop.CloseFileChooserDialog()
})
}

View File

@ -0,0 +1,23 @@
package handler
import (
"errors"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/message"
)
func (h *MessageHandlerCtx) clipboardSet(session types.Session, payload *message.ClipboardData) error {
if !session.Profile().CanAccessClipboard {
return errors.New("cannot access clipboard")
}
if !session.IsHost() {
return errors.New("is not the host")
}
return h.desktop.ClipboardSetText(types.ClipboardText{
Text: payload.Text,
// TODO: Send HTML?
})
}

View File

@ -0,0 +1,228 @@
package handler
import (
"errors"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message"
"github.com/demodesk/neko/pkg/xorg"
)
var (
ErrIsNotAllowedToHost = errors.New("is not allowed to host")
ErrIsNotTheHost = errors.New("is not the host")
ErrIsAlreadyTheHost = errors.New("is already the host")
ErrIsAlreadyHosted = errors.New("is already hosted")
)
func (h *MessageHandlerCtx) controlRelease(session types.Session) error {
if !session.Profile().CanHost || session.PrivateModeEnabled() {
return ErrIsNotAllowedToHost
}
if !session.IsHost() {
return ErrIsNotTheHost
}
h.desktop.ResetKeys()
session.ClearHost()
return nil
}
func (h *MessageHandlerCtx) controlRequest(session types.Session) error {
if !session.Profile().CanHost || session.PrivateModeEnabled() {
return ErrIsNotAllowedToHost
}
if session.IsHost() {
return ErrIsAlreadyTheHost
}
if h.sessions.Settings().LockedControls && !session.Profile().IsAdmin {
return ErrIsNotAllowedToHost
}
// if implicit hosting is enabled, set session as host without asking
if h.sessions.Settings().ImplicitHosting {
session.SetAsHost()
return nil
}
// if there is no host, set session as host
host, hasHost := h.sessions.GetHost()
if !hasHost {
session.SetAsHost()
return nil
}
// TODO: Some throttling mechanism to prevent spamming.
// let host know that someone wants to take control
host.Send(
event.CONTROL_REQUEST,
message.SessionID{
ID: session.ID(),
})
return ErrIsAlreadyHosted
}
func (h *MessageHandlerCtx) controlMove(session types.Session, payload *message.ControlPos) error {
if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
return err
}
// handle active cursor movement
h.desktop.Move(payload.X, payload.Y)
h.webrtc.SetCursorPosition(payload.X, payload.Y)
return nil
}
func (h *MessageHandlerCtx) controlScroll(session types.Session, payload *message.ControlScroll) error {
if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
return err
}
// TOOD: remove this once the client is fixed
if payload.DeltaX == 0 && payload.DeltaY == 0 {
payload.DeltaX = payload.X
payload.DeltaY = payload.Y
}
h.desktop.Scroll(payload.DeltaX, payload.DeltaY, payload.ControlKey)
return nil
}
func (h *MessageHandlerCtx) controlButtonPress(session types.Session, payload *message.ControlButton) error {
if payload.ControlPos != nil {
if err := h.controlMove(session, payload.ControlPos); err != nil {
return err
}
} else if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
return err
}
return h.desktop.ButtonPress(payload.Code)
}
func (h *MessageHandlerCtx) controlButtonDown(session types.Session, payload *message.ControlButton) error {
if payload.ControlPos != nil {
if err := h.controlMove(session, payload.ControlPos); err != nil {
return err
}
} else if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
return err
}
return h.desktop.ButtonDown(payload.Code)
}
func (h *MessageHandlerCtx) controlButtonUp(session types.Session, payload *message.ControlButton) error {
if payload.ControlPos != nil {
if err := h.controlMove(session, payload.ControlPos); err != nil {
return err
}
} else if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
return err
}
return h.desktop.ButtonUp(payload.Code)
}
func (h *MessageHandlerCtx) controlKeyPress(session types.Session, payload *message.ControlKey) error {
if payload.ControlPos != nil {
if err := h.controlMove(session, payload.ControlPos); err != nil {
return err
}
} else if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
return err
}
return h.desktop.KeyPress(payload.Keysym)
}
func (h *MessageHandlerCtx) controlKeyDown(session types.Session, payload *message.ControlKey) error {
if payload.ControlPos != nil {
if err := h.controlMove(session, payload.ControlPos); err != nil {
return err
}
} else if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
return err
}
return h.desktop.KeyDown(payload.Keysym)
}
func (h *MessageHandlerCtx) controlKeyUp(session types.Session, payload *message.ControlKey) error {
if payload.ControlPos != nil {
if err := h.controlMove(session, payload.ControlPos); err != nil {
return err
}
} else if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
return err
}
return h.desktop.KeyUp(payload.Keysym)
}
func (h *MessageHandlerCtx) controlTouchBegin(session types.Session, payload *message.ControlTouch) error {
if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
return err
}
return h.desktop.TouchBegin(payload.TouchId, payload.X, payload.Y, payload.Pressure)
}
func (h *MessageHandlerCtx) controlTouchUpdate(session types.Session, payload *message.ControlTouch) error {
if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
return err
}
return h.desktop.TouchUpdate(payload.TouchId, payload.X, payload.Y, payload.Pressure)
}
func (h *MessageHandlerCtx) controlTouchEnd(session types.Session, payload *message.ControlTouch) error {
if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
return err
}
return h.desktop.TouchEnd(payload.TouchId, payload.X, payload.Y, payload.Pressure)
}
func (h *MessageHandlerCtx) controlCut(session types.Session) error {
if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
return err
}
return h.desktop.KeyPress(xorg.XK_Control_L, xorg.XK_x)
}
func (h *MessageHandlerCtx) controlCopy(session types.Session) error {
if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
return err
}
return h.desktop.KeyPress(xorg.XK_Control_L, xorg.XK_c)
}
func (h *MessageHandlerCtx) controlPaste(session types.Session, payload *message.ClipboardData) error {
if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
return err
}
// if there have been set clipboard data, set them first
if payload != nil && payload.Text != "" {
if err := h.clipboardSet(session, payload); err != nil {
return err
}
}
return h.desktop.KeyPress(xorg.XK_Control_L, xorg.XK_v)
}
func (h *MessageHandlerCtx) controlSelectAll(session types.Session) error {
if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
return err
}
return h.desktop.KeyPress(xorg.XK_Control_L, xorg.XK_a)
}

View File

@ -0,0 +1,203 @@
package handler
import (
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message"
"github.com/demodesk/neko/pkg/utils"
)
func New(
sessions types.SessionManager,
desktop types.DesktopManager,
capture types.CaptureManager,
webrtc types.WebRTCManager,
) *MessageHandlerCtx {
return &MessageHandlerCtx{
logger: log.With().Str("module", "websocket").Str("submodule", "handler").Logger(),
sessions: sessions,
desktop: desktop,
capture: capture,
webrtc: webrtc,
}
}
type MessageHandlerCtx struct {
logger zerolog.Logger
sessions types.SessionManager
webrtc types.WebRTCManager
desktop types.DesktopManager
capture types.CaptureManager
}
func (h *MessageHandlerCtx) Message(session types.Session, data types.WebSocketMessage) bool {
var err error
switch data.Event {
// System Events
case event.SYSTEM_LOGS:
payload := &message.SystemLogs{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.systemLogs(session, payload)
})
// Signal Events
case event.SIGNAL_REQUEST:
payload := &message.SignalRequest{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.signalRequest(session, payload)
})
case event.SIGNAL_RESTART:
err = h.signalRestart(session)
case event.SIGNAL_OFFER:
payload := &message.SignalDescription{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.signalOffer(session, payload)
})
case event.SIGNAL_ANSWER:
payload := &message.SignalDescription{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.signalAnswer(session, payload)
})
case event.SIGNAL_CANDIDATE:
payload := &message.SignalCandidate{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.signalCandidate(session, payload)
})
case event.SIGNAL_VIDEO:
payload := &message.SignalVideo{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.signalVideo(session, payload)
})
case event.SIGNAL_AUDIO:
payload := &message.SignalAudio{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.signalAudio(session, payload)
})
// Control Events
case event.CONTROL_RELEASE:
err = h.controlRelease(session)
case event.CONTROL_REQUEST:
err = h.controlRequest(session)
case event.CONTROL_MOVE:
payload := &message.ControlPos{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.controlMove(session, payload)
})
case event.CONTROL_SCROLL:
payload := &message.ControlScroll{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.controlScroll(session, payload)
})
case event.CONTROL_BUTTONPRESS:
payload := &message.ControlButton{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.controlButtonPress(session, payload)
})
case event.CONTROL_BUTTONDOWN:
payload := &message.ControlButton{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.controlButtonDown(session, payload)
})
case event.CONTROL_BUTTONUP:
payload := &message.ControlButton{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.controlButtonUp(session, payload)
})
case event.CONTROL_KEYPRESS:
payload := &message.ControlKey{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.controlKeyPress(session, payload)
})
case event.CONTROL_KEYDOWN:
payload := &message.ControlKey{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.controlKeyDown(session, payload)
})
case event.CONTROL_KEYUP:
payload := &message.ControlKey{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.controlKeyUp(session, payload)
})
// touch
case event.CONTROL_TOUCHBEGIN:
payload := &message.ControlTouch{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.controlTouchBegin(session, payload)
})
case event.CONTROL_TOUCHUPDATE:
payload := &message.ControlTouch{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.controlTouchUpdate(session, payload)
})
case event.CONTROL_TOUCHEND:
payload := &message.ControlTouch{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.controlTouchEnd(session, payload)
})
// actions
case event.CONTROL_CUT:
err = h.controlCut(session)
case event.CONTROL_COPY:
err = h.controlCopy(session)
case event.CONTROL_PASTE:
payload := &message.ClipboardData{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.controlPaste(session, payload)
})
case event.CONTROL_SELECT_ALL:
err = h.controlSelectAll(session)
// Screen Events
case event.SCREEN_SET:
payload := &message.ScreenSize{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.screenSet(session, payload)
})
// Clipboard Events
case event.CLIPBOARD_SET:
payload := &message.ClipboardData{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.clipboardSet(session, payload)
})
// Keyboard Events
case event.KEYBOARD_MAP:
payload := &message.KeyboardMap{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.keyboardMap(session, payload)
})
case event.KEYBOARD_MODIFIERS:
payload := &message.KeyboardModifiers{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.keyboardModifiers(session, payload)
})
// Send Events
case event.SEND_UNICAST:
payload := &message.SendUnicast{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.sendUnicast(session, payload)
})
case event.SEND_BROADCAST:
payload := &message.SendBroadcast{}
err = utils.Unmarshal(payload, data.Payload, func() error {
return h.sendBroadcast(session, payload)
})
default:
return false
}
if err != nil {
h.logger.Warn().Err(err).
Str("event", data.Event).
Str("session_id", session.ID()).
Msg("message handler has failed")
}
return true
}

View File

@ -0,0 +1,25 @@
package handler
import (
"errors"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/message"
)
func (h *MessageHandlerCtx) keyboardMap(session types.Session, payload *message.KeyboardMap) error {
if !session.IsHost() {
return errors.New("is not the host")
}
return h.desktop.SetKeyboardMap(payload.KeyboardMap)
}
func (h *MessageHandlerCtx) keyboardModifiers(session types.Session, payload *message.KeyboardModifiers) error {
if !session.IsHost() {
return errors.New("is not the host")
}
h.desktop.SetKeyboardModifiers(payload.KeyboardModifiers)
return nil
}

View File

@ -0,0 +1,26 @@
package handler
import (
"errors"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message"
)
func (h *MessageHandlerCtx) screenSet(session types.Session, payload *message.ScreenSize) error {
if !session.Profile().IsAdmin {
return errors.New("is not the admin")
}
size, err := h.desktop.SetScreenSize(payload.ScreenSize)
if err != nil {
return err
}
h.sessions.Broadcast(event.SCREEN_UPDATED, message.ScreenSizeUpdate{
ID: session.ID(),
ScreenSize: size,
})
return nil
}

View File

@ -0,0 +1,39 @@
package handler
import (
"errors"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message"
)
func (h *MessageHandlerCtx) sendUnicast(session types.Session, payload *message.SendUnicast) error {
receiver, ok := h.sessions.Get(payload.Receiver)
if !ok {
return errors.New("receiver session ID not found")
}
receiver.Send(
event.SEND_UNICAST,
message.SendUnicast{
Sender: session.ID(),
Receiver: receiver.ID(),
Subject: payload.Subject,
Body: payload.Body,
})
return nil
}
func (h *MessageHandlerCtx) sendBroadcast(session types.Session, payload *message.SendBroadcast) error {
h.sessions.Broadcast(
event.SEND_BROADCAST,
message.SendBroadcast{
Sender: session.ID(),
Subject: payload.Subject,
Body: payload.Body,
}, session.ID())
return nil
}

View File

@ -0,0 +1,106 @@
package handler
import (
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message"
)
func (h *MessageHandlerCtx) SessionCreated(session types.Session) error {
h.sessions.Broadcast(
event.SESSION_CREATED,
message.SessionData{
ID: session.ID(),
Profile: session.Profile(),
State: session.State(),
})
return nil
}
func (h *MessageHandlerCtx) SessionDeleted(session types.Session) error {
h.sessions.Broadcast(
event.SESSION_DELETED,
message.SessionID{
ID: session.ID(),
})
return nil
}
func (h *MessageHandlerCtx) SessionConnected(session types.Session) error {
if err := h.systemInit(session); err != nil {
return err
}
if session.Profile().IsAdmin {
if err := h.systemAdmin(session); err != nil {
return err
}
// update settings in atomic way
h.sessions.UpdateSettingsFunc(session, func(settings *types.Settings) bool {
// if control protection & locked controls: unlock controls
if settings.LockedControls && settings.ControlProtection {
settings.LockedControls = false
return true // update settings
}
return false // do not update settings
})
}
return h.SessionStateChanged(session)
}
func (h *MessageHandlerCtx) SessionDisconnected(session types.Session) error {
// clear host if exists
if session.IsHost() {
h.desktop.ResetKeys()
session.ClearHost()
}
if session.Profile().IsAdmin {
hasAdmin := false
h.sessions.Range(func(s types.Session) bool {
if s.Profile().IsAdmin && s.ID() != session.ID() && s.State().IsConnected {
hasAdmin = true
return false
}
return true
})
// update settings in atomic way
h.sessions.UpdateSettingsFunc(session, func(settings *types.Settings) bool {
// if control protection & not locked controls & no admin: lock controls
if !settings.LockedControls && settings.ControlProtection && !hasAdmin {
settings.LockedControls = true
return true // update settings
}
return false // do not update settings
})
}
return h.SessionStateChanged(session)
}
func (h *MessageHandlerCtx) SessionProfileChanged(session types.Session, new, old types.MemberProfile) error {
h.sessions.Broadcast(
event.SESSION_PROFILE,
message.MemberProfile{
ID: session.ID(),
MemberProfile: new,
})
return nil
}
func (h *MessageHandlerCtx) SessionStateChanged(session types.Session) error {
h.sessions.Broadcast(
event.SESSION_STATE,
message.SessionState{
ID: session.ID(),
SessionState: session.State(),
})
return nil
}

View File

@ -0,0 +1,162 @@
package handler
import (
"errors"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message"
"github.com/pion/webrtc/v3"
)
func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *message.SignalRequest) error {
if !session.Profile().CanWatch {
return errors.New("not allowed to watch")
}
offer, peer, err := h.webrtc.CreatePeer(session)
if err != nil {
return err
}
// set webrtc as paused if session has private mode enabled
if session.PrivateModeEnabled() {
peer.SetPaused(true)
}
video := payload.Video
// use default first video, if not provided
if video.Selector == nil {
videos := h.capture.Video().IDs()
video.Selector = &types.StreamSelector{
ID: videos[0],
Type: types.StreamSelectorTypeExact,
}
}
// TODO: Remove, used for compatibility with old clients.
if video.Auto == nil {
video.Auto = &payload.Auto
}
// set video stream
err = peer.SetVideo(video)
if err != nil {
return err
}
audio := payload.Audio
// enable by default if not requested otherwise
if audio.Disabled == nil {
disabled := false
audio.Disabled = &disabled
}
// set audio stream
err = peer.SetAudio(audio)
if err != nil {
return err
}
session.Send(
event.SIGNAL_PROVIDE,
message.SignalProvide{
SDP: offer.SDP,
ICEServers: h.webrtc.ICEServers(),
Video: peer.Video(),
Audio: peer.Audio(),
})
return nil
}
func (h *MessageHandlerCtx) signalRestart(session types.Session) error {
peer := session.GetWebRTCPeer()
if peer == nil {
return errors.New("webRTC peer does not exist")
}
offer, err := peer.CreateOffer(true)
if err != nil {
return err
}
// TODO: Use offer event instead.
session.Send(
event.SIGNAL_RESTART,
message.SignalDescription{
SDP: offer.SDP,
})
return nil
}
func (h *MessageHandlerCtx) signalOffer(session types.Session, payload *message.SignalDescription) error {
peer := session.GetWebRTCPeer()
if peer == nil {
return errors.New("webRTC peer does not exist")
}
err := peer.SetRemoteDescription(webrtc.SessionDescription{
SDP: payload.SDP,
Type: webrtc.SDPTypeOffer,
})
if err != nil {
return err
}
answer, err := peer.CreateAnswer()
if err != nil {
return err
}
session.Send(
event.SIGNAL_ANSWER,
message.SignalDescription{
SDP: answer.SDP,
})
return nil
}
func (h *MessageHandlerCtx) signalAnswer(session types.Session, payload *message.SignalDescription) error {
peer := session.GetWebRTCPeer()
if peer == nil {
return errors.New("webRTC peer does not exist")
}
return peer.SetRemoteDescription(webrtc.SessionDescription{
SDP: payload.SDP,
Type: webrtc.SDPTypeAnswer,
})
}
func (h *MessageHandlerCtx) signalCandidate(session types.Session, payload *message.SignalCandidate) error {
peer := session.GetWebRTCPeer()
if peer == nil {
return errors.New("webRTC peer does not exist")
}
return peer.SetCandidate(payload.ICECandidateInit)
}
func (h *MessageHandlerCtx) signalVideo(session types.Session, payload *message.SignalVideo) error {
peer := session.GetWebRTCPeer()
if peer == nil {
return errors.New("webRTC peer does not exist")
}
return peer.SetVideo(payload.PeerVideoRequest)
}
func (h *MessageHandlerCtx) signalAudio(session types.Session, payload *message.SignalAudio) error {
peer := session.GetWebRTCPeer()
if peer == nil {
return errors.New("webRTC peer does not exist")
}
return peer.SetAudio(payload.PeerAudioRequest)
}

View File

@ -0,0 +1,96 @@
package handler
import (
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message"
)
func (h *MessageHandlerCtx) systemInit(session types.Session) error {
host, hasHost := h.sessions.GetHost()
var hostID string
if hasHost {
hostID = host.ID()
}
controlHost := message.ControlHost{
HasHost: hasHost,
HostID: hostID,
}
sessions := map[string]message.SessionData{}
for _, session := range h.sessions.List() {
sessionId := session.ID()
sessions[sessionId] = message.SessionData{
ID: sessionId,
Profile: session.Profile(),
State: session.State(),
}
}
session.Send(
event.SYSTEM_INIT,
message.SystemInit{
SessionId: session.ID(),
ControlHost: controlHost,
ScreenSize: h.desktop.GetScreenSize(),
Sessions: sessions,
Settings: h.sessions.Settings(),
TouchEvents: h.desktop.HasTouchSupport(),
ScreencastEnabled: h.capture.Screencast().Enabled(),
WebRTC: message.SystemWebRTC{
Videos: h.capture.Video().IDs(),
},
})
return nil
}
func (h *MessageHandlerCtx) systemAdmin(session types.Session) error {
configurations := h.desktop.ScreenConfigurations()
list := make([]types.ScreenSize, 0, len(configurations))
for _, conf := range configurations {
list = append(list, types.ScreenSize{
Width: conf.Width,
Height: conf.Height,
Rate: conf.Rate,
})
}
broadcast := h.capture.Broadcast()
session.Send(
event.SYSTEM_ADMIN,
message.SystemAdmin{
ScreenSizesList: list, // TODO: remove
BroadcastStatus: message.BroadcastStatus{
IsActive: broadcast.Started(),
URL: broadcast.Url(),
},
})
return nil
}
func (h *MessageHandlerCtx) systemLogs(session types.Session, payload *message.SystemLogs) error {
for _, msg := range *payload {
level, _ := zerolog.ParseLevel(msg.Level)
if level < zerolog.DebugLevel || level > zerolog.ErrorLevel {
level = zerolog.NoLevel
}
// do not use handler logger context
log.WithLevel(level).
Fields(msg.Fields).
Str("module", "client").
Str("session_id", session.ID()).
Msg(msg.Message)
}
return nil
}

View File

@ -0,0 +1,431 @@
package websocket
import (
"encoding/json"
"errors"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/demodesk/neko/internal/websocket/handler"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message"
"github.com/demodesk/neko/pkg/utils"
)
// send pings to peer with this period - must be less than pongWait
const pingPeriod = 10 * time.Second
// period for sending inactive cursor messages
const inactiveCursorsPeriod = 750 * time.Millisecond
// maximum payload length for logging
const maxPayloadLogLength = 10_000
// events that are not logged in debug mode
var nologEvents = []string{
// don't log twice
event.SYSTEM_LOGS,
// don't log heartbeat
event.SYSTEM_HEARTBEAT,
// don't log every cursor update
event.SESSION_CURSORS,
}
func New(
sessions types.SessionManager,
desktop types.DesktopManager,
capture types.CaptureManager,
webrtc types.WebRTCManager,
) *WebSocketManagerCtx {
logger := log.With().Str("module", "websocket").Logger()
return &WebSocketManagerCtx{
logger: logger,
shutdown: make(chan struct{}),
sessions: sessions,
desktop: desktop,
handler: handler.New(sessions, desktop, capture, webrtc),
handlers: []types.WebSocketHandler{},
}
}
type WebSocketManagerCtx struct {
logger zerolog.Logger
wg sync.WaitGroup
shutdown chan struct{}
sessions types.SessionManager
desktop types.DesktopManager
handler *handler.MessageHandlerCtx
handlers []types.WebSocketHandler
shutdownInactiveCursors chan struct{}
}
func (manager *WebSocketManagerCtx) Start() {
manager.sessions.OnCreated(func(session types.Session) {
err := manager.handler.SessionCreated(session)
manager.logger.Err(err).
Str("session_id", session.ID()).
Msg("session created")
})
manager.sessions.OnDeleted(func(session types.Session) {
err := manager.handler.SessionDeleted(session)
manager.logger.Err(err).
Str("session_id", session.ID()).
Msg("session deleted")
})
manager.sessions.OnConnected(func(session types.Session) {
err := manager.handler.SessionConnected(session)
manager.logger.Err(err).
Str("session_id", session.ID()).
Msg("session connected")
})
manager.sessions.OnDisconnected(func(session types.Session) {
err := manager.handler.SessionDisconnected(session)
manager.logger.Err(err).
Str("session_id", session.ID()).
Msg("session disconnected")
})
manager.sessions.OnProfileChanged(func(session types.Session, new, old types.MemberProfile) {
err := manager.handler.SessionProfileChanged(session, new, old)
manager.logger.Err(err).
Str("session_id", session.ID()).
Interface("new", new).
Interface("old", old).
Msg("session profile changed")
})
manager.sessions.OnStateChanged(func(session types.Session) {
err := manager.handler.SessionStateChanged(session)
manager.logger.Err(err).
Str("session_id", session.ID()).
Msg("session state changed")
})
manager.sessions.OnHostChanged(func(session, host types.Session) {
payload := message.ControlHost{
ID: session.ID(),
HasHost: host != nil,
}
if payload.HasHost {
payload.HostID = host.ID()
}
manager.sessions.Broadcast(event.CONTROL_HOST, payload)
manager.logger.Info().
Str("session_id", session.ID()).
Bool("has_host", payload.HasHost).
Str("host_id", payload.HostID).
Msg("session host changed")
})
manager.sessions.OnSettingsChanged(func(session types.Session, new, old types.Settings) {
// start inactive cursors
if new.InactiveCursors && !old.InactiveCursors {
manager.startInactiveCursors()
}
// stop inactive cursors
if !new.InactiveCursors && old.InactiveCursors {
manager.stopInactiveCursors()
}
manager.sessions.Broadcast(event.SYSTEM_SETTINGS, message.SystemSettingsUpdate{
ID: session.ID(),
Settings: new,
})
manager.logger.Info().
Str("session_id", session.ID()).
Interface("new", new).
Interface("old", old).
Msg("settings changed")
})
manager.desktop.OnClipboardUpdated(func() {
host, hasHost := manager.sessions.GetHost()
if !hasHost || !host.Profile().CanAccessClipboard {
return
}
manager.logger.Info().Msg("sync clipboard")
data, err := manager.desktop.ClipboardGetText()
if err != nil {
manager.logger.Err(err).Msg("could not get clipboard content")
return
}
host.Send(
event.CLIPBOARD_UPDATED,
message.ClipboardData{
Text: data.Text,
// TODO: Send HTML?
})
})
if manager.desktop.IsFileChooserDialogEnabled() {
manager.fileChooserDialogEvents()
}
if manager.sessions.Settings().InactiveCursors {
manager.startInactiveCursors()
}
manager.logger.Info().Msg("websocket starting")
}
func (manager *WebSocketManagerCtx) Shutdown() error {
manager.logger.Info().Msg("shutdown")
close(manager.shutdown)
manager.stopInactiveCursors()
manager.wg.Wait()
return nil
}
func (manager *WebSocketManagerCtx) AddHandler(handler types.WebSocketHandler) {
manager.handlers = append(manager.handlers, handler)
}
func (manager *WebSocketManagerCtx) Upgrade(checkOrigin types.CheckOrigin) types.RouterHandler {
return func(w http.ResponseWriter, r *http.Request) error {
upgrader := websocket.Upgrader{
CheckOrigin: checkOrigin,
// Do not return any error while handshake
Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {},
}
connection, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return utils.HttpBadRequest().WithInternalErr(err)
}
// Cannot write HTTP response after connection upgrade
manager.connect(connection, r)
return nil
}
}
func (manager *WebSocketManagerCtx) connect(connection *websocket.Conn, r *http.Request) {
session, err := manager.sessions.Authenticate(r)
if err != nil {
manager.logger.Warn().Err(err).Msg("authentication failed")
newPeer(manager.logger, connection).Destroy(err.Error())
return
}
// add session id to all log messages
logger := manager.logger.With().Str("session_id", session.ID()).Logger()
// create new peer
peer := newPeer(logger, connection)
if !session.Profile().CanConnect {
logger.Warn().Msg("connection disabled")
peer.Destroy("connection disabled")
return
}
if session.State().IsConnected {
logger.Warn().Msg("already connected")
if !manager.sessions.Settings().MercifulReconnect {
peer.Destroy("already connected")
return
}
logger.Info().Msg("replacing peer connection")
}
logger.Info().
Str("address", connection.RemoteAddr().String()).
Str("agent", r.UserAgent()).
Msg("connection started")
session.ConnectWebSocketPeer(peer)
// this is a blocking function that lives
// throughout whole websocket connection
err = manager.handle(connection, peer, session)
logger.Info().
Str("address", connection.RemoteAddr().String()).
Str("agent", r.UserAgent()).
Msg("connection ended")
if err == nil {
logger.Debug().Msg("websocket close")
session.DisconnectWebSocketPeer(peer, false)
return
}
delayedDisconnect := false
e, ok := err.(*websocket.CloseError)
if !ok {
err = errors.Unwrap(err) // unwrap if possible
logger.Warn().Err(err).Msg("read message error")
// client is expected to reconnect soon
delayedDisconnect = true
} else {
switch e.Code {
case websocket.CloseNormalClosure:
logger.Debug().Str("reason", e.Text).Msg("websocket close")
case websocket.CloseGoingAway:
logger.Debug().Str("reason", "going away").Msg("websocket close")
default:
logger.Warn().Err(err).Msg("websocket close")
// abnormal websocket closure:
// client is expected to reconnect soon
delayedDisconnect = true
}
}
session.DisconnectWebSocketPeer(peer, delayedDisconnect)
}
func (manager *WebSocketManagerCtx) handle(connection *websocket.Conn, peer types.WebSocketPeer, session types.Session) error {
// add session id to logger context
logger := manager.logger.With().Str("session_id", session.ID()).Logger()
bytes := make(chan []byte)
cancel := make(chan error)
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
manager.wg.Add(1)
go func() {
defer manager.wg.Done()
for {
_, raw, err := connection.ReadMessage()
if err != nil {
cancel <- err
break
}
bytes <- raw
}
}()
for {
select {
case raw := <-bytes:
data := types.WebSocketMessage{}
if err := json.Unmarshal(raw, &data); err != nil {
logger.Err(err).Msg("message unmarshalling has failed")
break
}
// log events if not ignored
if ok, _ := utils.ArrayIn(data.Event, nologEvents); !ok {
payload := data.Payload
if len(payload) > maxPayloadLogLength {
payload = []byte("<truncated>")
}
logger.Debug().
Str("address", connection.RemoteAddr().String()).
Str("event", data.Event).
Str("payload", string(payload)).
Msg("received message from client")
}
handled := manager.handler.Message(session, data)
for _, handler := range manager.handlers {
if handled {
break
}
handled = handler(session, data)
}
if !handled {
logger.Warn().Str("event", data.Event).Msg("unhandled message")
}
case err := <-cancel:
return err
case <-manager.shutdown:
peer.Destroy("connection shutdown")
return nil
case <-ticker.C:
if err := peer.Ping(); err != nil {
return err
}
}
}
}
func (manager *WebSocketManagerCtx) startInactiveCursors() {
if manager.shutdownInactiveCursors != nil {
manager.logger.Warn().Msg("inactive cursors handler already running")
return
}
manager.logger.Info().Msg("starting inactive cursors handler")
manager.shutdownInactiveCursors = make(chan struct{})
manager.wg.Add(1)
go func() {
defer manager.wg.Done()
ticker := time.NewTicker(inactiveCursorsPeriod)
defer ticker.Stop()
var currentEmpty bool
var lastEmpty = false
for {
select {
case <-manager.shutdownInactiveCursors:
manager.logger.Info().Msg("stopping inactive cursors handler")
manager.shutdownInactiveCursors = nil
// remove last cursor entries and send empty message
_ = manager.sessions.PopCursors()
manager.sessions.InactiveCursorsBroadcast(event.SESSION_CURSORS, []message.SessionCursors{})
return
case <-ticker.C:
cursorsMap := manager.sessions.PopCursors()
currentEmpty = len(cursorsMap) == 0
if currentEmpty && lastEmpty {
continue
}
lastEmpty = currentEmpty
sessionCursors := []message.SessionCursors{}
for session, cursors := range cursorsMap {
sessionCursors = append(
sessionCursors,
message.SessionCursors{
ID: session.ID(),
Cursors: cursors,
},
)
}
manager.sessions.InactiveCursorsBroadcast(event.SESSION_CURSORS, sessionCursors)
}
}
}()
}
func (manager *WebSocketManagerCtx) stopInactiveCursors() {
if manager.shutdownInactiveCursors != nil {
close(manager.shutdownInactiveCursors)
}
}

View File

@ -0,0 +1,91 @@
package websocket
import (
"encoding/json"
"errors"
"sync"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
"github.com/demodesk/neko/pkg/types"
"github.com/demodesk/neko/pkg/types/event"
"github.com/demodesk/neko/pkg/types/message"
"github.com/demodesk/neko/pkg/utils"
)
type WebSocketPeerCtx struct {
mu sync.Mutex
logger zerolog.Logger
connection *websocket.Conn
}
func newPeer(logger zerolog.Logger, connection *websocket.Conn) *WebSocketPeerCtx {
return &WebSocketPeerCtx{
logger: logger.With().Str("submodule", "peer").Logger(),
connection: connection,
}
}
func (peer *WebSocketPeerCtx) Send(event string, payload any) {
peer.mu.Lock()
defer peer.mu.Unlock()
raw, err := json.Marshal(payload)
if err != nil {
peer.logger.Err(err).Str("event", event).Msg("message marshalling has failed")
return
}
err = peer.connection.WriteJSON(types.WebSocketMessage{
Event: event,
Payload: raw,
})
if err != nil {
err = errors.Unwrap(err) // unwrap if possible
peer.logger.Warn().Err(err).Str("event", event).Msg("send message error")
return
}
// log events if not ignored
if ok, _ := utils.ArrayIn(event, nologEvents); !ok {
if len(raw) > maxPayloadLogLength {
raw = []byte("<truncated>")
}
peer.logger.Debug().
Str("address", peer.connection.RemoteAddr().String()).
Str("event", event).
Str("payload", string(raw)).
Msg("sending message to client")
}
}
func (peer *WebSocketPeerCtx) Ping() error {
peer.mu.Lock()
defer peer.mu.Unlock()
// application level heartbeat
if err := peer.connection.WriteJSON(types.WebSocketMessage{
Event: event.SYSTEM_HEARTBEAT,
}); err != nil {
return err
}
return peer.connection.WriteMessage(websocket.PingMessage, nil)
}
func (peer *WebSocketPeerCtx) Destroy(reason string) {
peer.Send(
event.SYSTEM_DISCONNECT,
message.SystemDisconnect{
Message: reason,
})
peer.mu.Lock()
defer peer.mu.Unlock()
err := peer.connection.Close()
peer.logger.Err(err).Msg("peer connection destroyed")
}