mirror of
https://github.com/m1k1o/neko.git
synced 2024-07-24 14:40:50 +12:00
move server to server directory.
This commit is contained in:
84
server/internal/api/members/bluk.go
Normal file
84
server/internal/api/members/bluk.go
Normal file
@ -0,0 +1,84 @@
|
||||
package members
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type MemberBulkUpdatePayload struct {
|
||||
IDs []string `json:"ids"`
|
||||
Profile types.MemberProfile `json:"profile"`
|
||||
}
|
||||
|
||||
func (h *MembersHandler) membersBulkUpdate(w http.ResponseWriter, r *http.Request) error {
|
||||
bytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return utils.HttpBadRequest("unable to read post body").WithInternalErr(err)
|
||||
}
|
||||
|
||||
header := &MemberBulkUpdatePayload{}
|
||||
if err := json.Unmarshal(bytes, &header); err != nil {
|
||||
return utils.HttpBadRequest("unable to unmarshal payload").WithInternalErr(err)
|
||||
}
|
||||
|
||||
for _, memberId := range header.IDs {
|
||||
// TODO: Bulk select?
|
||||
profile, err := h.members.Select(memberId)
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().
|
||||
WithInternalErr(err).
|
||||
WithInternalMsg("unable to select member profile").
|
||||
Msgf("failed to update member %s", memberId)
|
||||
}
|
||||
|
||||
body := &MemberBulkUpdatePayload{
|
||||
Profile: profile,
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bytes, &body); err != nil {
|
||||
return utils.HttpBadRequest().
|
||||
WithInternalErr(err).
|
||||
Msgf("unable to unmarshal payload for member %s", memberId)
|
||||
}
|
||||
|
||||
if err := h.members.UpdateProfile(memberId, body.Profile); err != nil {
|
||||
return utils.HttpInternalServerError().
|
||||
WithInternalErr(err).
|
||||
WithInternalMsg("unable to update member profile").
|
||||
Msgf("failed to update member %s", memberId)
|
||||
}
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
||||
|
||||
type MemberBulkDeletePayload struct {
|
||||
IDs []string `json:"ids"`
|
||||
}
|
||||
|
||||
func (h *MembersHandler) membersBulkDelete(w http.ResponseWriter, r *http.Request) error {
|
||||
bytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return utils.HttpBadRequest("unable to read post body").WithInternalErr(err)
|
||||
}
|
||||
|
||||
data := &MemberBulkDeletePayload{}
|
||||
if err := json.Unmarshal(bytes, &data); err != nil {
|
||||
return utils.HttpBadRequest("unable to unmarshal payload").WithInternalErr(err)
|
||||
}
|
||||
|
||||
for _, memberId := range data.IDs {
|
||||
if err := h.members.Delete(memberId); err != nil {
|
||||
return utils.HttpInternalServerError().
|
||||
WithInternalErr(err).
|
||||
WithInternalMsg("unable to delete member").
|
||||
Msgf("failed to delete member %s", memberId)
|
||||
}
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
144
server/internal/api/members/controler.go
Normal file
144
server/internal/api/members/controler.go
Normal file
@ -0,0 +1,144 @@
|
||||
package members
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type MemberDataPayload struct {
|
||||
ID string `json:"id"`
|
||||
Profile types.MemberProfile `json:"profile"`
|
||||
}
|
||||
|
||||
type MemberCreatePayload struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Profile types.MemberProfile `json:"profile"`
|
||||
}
|
||||
|
||||
type MemberPasswordPayload struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
func (h *MembersHandler) membersList(w http.ResponseWriter, r *http.Request) error {
|
||||
limit, err := strconv.Atoi(r.URL.Query().Get("limit"))
|
||||
if err != nil {
|
||||
// TODO: Default zero.
|
||||
limit = 0
|
||||
}
|
||||
|
||||
offset, err := strconv.Atoi(r.URL.Query().Get("offset"))
|
||||
if err != nil {
|
||||
// TODO: Default zero.
|
||||
offset = 0
|
||||
}
|
||||
|
||||
entries, err := h.members.SelectAll(limit, offset)
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
|
||||
members := []MemberDataPayload{}
|
||||
for id, profile := range entries {
|
||||
members = append(members, MemberDataPayload{
|
||||
ID: id,
|
||||
Profile: profile,
|
||||
})
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w, members)
|
||||
}
|
||||
|
||||
func (h *MembersHandler) membersCreate(w http.ResponseWriter, r *http.Request) error {
|
||||
data := &MemberCreatePayload{
|
||||
// default values
|
||||
Profile: types.MemberProfile{
|
||||
IsAdmin: false,
|
||||
CanLogin: true,
|
||||
CanConnect: true,
|
||||
CanWatch: true,
|
||||
CanHost: true,
|
||||
CanShareMedia: true,
|
||||
CanAccessClipboard: true,
|
||||
SendsInactiveCursor: true,
|
||||
CanSeeInactiveCursors: true,
|
||||
},
|
||||
}
|
||||
|
||||
if err := utils.HttpJsonRequest(w, r, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if data.Username == "" {
|
||||
return utils.HttpBadRequest("username cannot be empty")
|
||||
}
|
||||
|
||||
if data.Password == "" {
|
||||
return utils.HttpBadRequest("password cannot be empty")
|
||||
}
|
||||
|
||||
id, err := h.members.Insert(data.Username, data.Password, data.Profile)
|
||||
if err != nil {
|
||||
if errors.Is(err, types.ErrMemberAlreadyExists) {
|
||||
return utils.HttpUnprocessableEntity("member already exists")
|
||||
}
|
||||
|
||||
return utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w, MemberDataPayload{
|
||||
ID: id,
|
||||
Profile: data.Profile,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *MembersHandler) membersRead(w http.ResponseWriter, r *http.Request) error {
|
||||
member := GetMember(r)
|
||||
profile := member.Profile
|
||||
|
||||
return utils.HttpSuccess(w, profile)
|
||||
}
|
||||
|
||||
func (h *MembersHandler) membersUpdateProfile(w http.ResponseWriter, r *http.Request) error {
|
||||
member := GetMember(r)
|
||||
data := &member.Profile
|
||||
|
||||
if err := utils.HttpJsonRequest(w, r, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.members.UpdateProfile(member.ID, *data); err != nil {
|
||||
return utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
||||
|
||||
func (h *MembersHandler) membersUpdatePassword(w http.ResponseWriter, r *http.Request) error {
|
||||
member := GetMember(r)
|
||||
data := &MemberPasswordPayload{}
|
||||
|
||||
if err := utils.HttpJsonRequest(w, r, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.members.UpdatePassword(member.ID, data.Password); err != nil {
|
||||
return utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
||||
|
||||
func (h *MembersHandler) membersDelete(w http.ResponseWriter, r *http.Request) error {
|
||||
member := GetMember(r)
|
||||
|
||||
if err := h.members.Delete(member.ID); err != nil {
|
||||
return utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
83
server/internal/api/members/handler.go
Normal file
83
server/internal/api/members/handler.go
Normal file
@ -0,0 +1,83 @@
|
||||
package members
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"github.com/demodesk/neko/pkg/auth"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type key int
|
||||
|
||||
const keyMemberCtx key = iota
|
||||
|
||||
type MembersHandler struct {
|
||||
members types.MemberManager
|
||||
}
|
||||
|
||||
func New(
|
||||
members types.MemberManager,
|
||||
) *MembersHandler {
|
||||
// Init
|
||||
|
||||
return &MembersHandler{
|
||||
members: members,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *MembersHandler) Route(r types.Router) {
|
||||
r.Get("/", h.membersList)
|
||||
|
||||
r.With(auth.AdminsOnly).Group(func(r types.Router) {
|
||||
r.Post("/", h.membersCreate)
|
||||
r.With(h.ExtractMember).Route("/{memberId}", func(r types.Router) {
|
||||
r.Get("/", h.membersRead)
|
||||
r.Post("/", h.membersUpdateProfile)
|
||||
r.Post("/password", h.membersUpdatePassword)
|
||||
r.Delete("/", h.membersDelete)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (h *MembersHandler) RouteBulk(r types.Router) {
|
||||
r.With(auth.AdminsOnly).Group(func(r types.Router) {
|
||||
r.Post("/update", h.membersBulkUpdate)
|
||||
r.Post("/delete", h.membersBulkDelete)
|
||||
})
|
||||
}
|
||||
|
||||
type MemberData struct {
|
||||
ID string
|
||||
Profile types.MemberProfile
|
||||
}
|
||||
|
||||
func SetMember(r *http.Request, session MemberData) context.Context {
|
||||
return context.WithValue(r.Context(), keyMemberCtx, session)
|
||||
}
|
||||
|
||||
func GetMember(r *http.Request) MemberData {
|
||||
return r.Context().Value(keyMemberCtx).(MemberData)
|
||||
}
|
||||
|
||||
func (h *MembersHandler) ExtractMember(w http.ResponseWriter, r *http.Request) (context.Context, error) {
|
||||
memberId := chi.URLParam(r, "memberId")
|
||||
|
||||
profile, err := h.members.Select(memberId)
|
||||
if err != nil {
|
||||
if errors.Is(err, types.ErrMemberDoesNotExist) {
|
||||
return nil, utils.HttpNotFound("member not found")
|
||||
}
|
||||
|
||||
return nil, utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
|
||||
return SetMember(r, MemberData{
|
||||
ID: memberId,
|
||||
Profile: profile,
|
||||
}), nil
|
||||
}
|
70
server/internal/api/room/broadcast.go
Normal file
70
server/internal/api/room/broadcast.go
Normal file
@ -0,0 +1,70 @@
|
||||
package room
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
"github.com/demodesk/neko/pkg/types/message"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type BroadcastStatusPayload struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
IsActive bool `json:"is_active"`
|
||||
}
|
||||
|
||||
func (h *RoomHandler) broadcastStatus(w http.ResponseWriter, r *http.Request) error {
|
||||
broadcast := h.capture.Broadcast()
|
||||
|
||||
return utils.HttpSuccess(w, BroadcastStatusPayload{
|
||||
IsActive: broadcast.Started(),
|
||||
URL: broadcast.Url(),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *RoomHandler) boradcastStart(w http.ResponseWriter, r *http.Request) error {
|
||||
data := &BroadcastStatusPayload{}
|
||||
if err := utils.HttpJsonRequest(w, r, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if data.URL == "" {
|
||||
return utils.HttpBadRequest("missing broadcast URL")
|
||||
}
|
||||
|
||||
broadcast := h.capture.Broadcast()
|
||||
if broadcast.Started() {
|
||||
return utils.HttpUnprocessableEntity("server is already broadcasting")
|
||||
}
|
||||
|
||||
if err := broadcast.Start(data.URL); err != nil {
|
||||
return utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
|
||||
h.sessions.AdminBroadcast(
|
||||
event.BORADCAST_STATUS,
|
||||
message.BroadcastStatus{
|
||||
IsActive: broadcast.Started(),
|
||||
URL: broadcast.Url(),
|
||||
})
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
||||
|
||||
func (h *RoomHandler) boradcastStop(w http.ResponseWriter, r *http.Request) error {
|
||||
broadcast := h.capture.Broadcast()
|
||||
if !broadcast.Started() {
|
||||
return utils.HttpUnprocessableEntity("server is not broadcasting")
|
||||
}
|
||||
|
||||
broadcast.Stop()
|
||||
|
||||
h.sessions.AdminBroadcast(
|
||||
event.BORADCAST_STATUS,
|
||||
message.BroadcastStatus{
|
||||
IsActive: broadcast.Started(),
|
||||
URL: broadcast.Url(),
|
||||
})
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
107
server/internal/api/room/clipboard.go
Normal file
107
server/internal/api/room/clipboard.go
Normal file
@ -0,0 +1,107 @@
|
||||
package room
|
||||
|
||||
import (
|
||||
// TODO: Unused now.
|
||||
//"bytes"
|
||||
//"strings"
|
||||
|
||||
"net/http"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type ClipboardPayload struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
HTML string `json:"html,omitempty"`
|
||||
}
|
||||
|
||||
func (h *RoomHandler) clipboardGetText(w http.ResponseWriter, r *http.Request) error {
|
||||
data, err := h.desktop.ClipboardGetText()
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w, ClipboardPayload{
|
||||
Text: data.Text,
|
||||
HTML: data.HTML,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *RoomHandler) clipboardSetText(w http.ResponseWriter, r *http.Request) error {
|
||||
data := &ClipboardPayload{}
|
||||
if err := utils.HttpJsonRequest(w, r, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := h.desktop.ClipboardSetText(types.ClipboardText{
|
||||
Text: data.Text,
|
||||
HTML: data.HTML,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
||||
|
||||
func (h *RoomHandler) clipboardGetImage(w http.ResponseWriter, r *http.Request) error {
|
||||
bytes, err := h.desktop.ClipboardGetBinary("image/png")
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
|
||||
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||||
w.Header().Set("Content-Type", "image/png")
|
||||
|
||||
_, err = w.Write(bytes)
|
||||
return err
|
||||
}
|
||||
|
||||
/* TODO: Unused now.
|
||||
func (h *RoomHandler) clipboardSetImage(w http.ResponseWriter, r *http.Request) error {
|
||||
err := r.ParseMultipartForm(MAX_UPLOAD_SIZE)
|
||||
if err != nil {
|
||||
return utils.HttpBadRequest("failed to parse multipart form").WithInternalErr(err)
|
||||
}
|
||||
|
||||
//nolint
|
||||
defer r.MultipartForm.RemoveAll()
|
||||
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
return utils.HttpBadRequest("no file received").WithInternalErr(err)
|
||||
}
|
||||
|
||||
defer file.Close()
|
||||
|
||||
mime := header.Header.Get("Content-Type")
|
||||
if !strings.HasPrefix(mime, "image/") {
|
||||
return utils.HttpBadRequest("file must be image")
|
||||
}
|
||||
|
||||
buffer := new(bytes.Buffer)
|
||||
_, err = buffer.ReadFrom(file)
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().WithInternalErr(err).WithInternalMsg("unable to read from uploaded file")
|
||||
}
|
||||
|
||||
err = h.desktop.ClipboardSetBinary("image/png", buffer.Bytes())
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().WithInternalErr(err).WithInternalMsg("unable set image to clipboard")
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
||||
|
||||
func (h *RoomHandler) clipboardGetTargets(w http.ResponseWriter, r *http.Request) error {
|
||||
targets, err := h.desktop.ClipboardGetTargets()
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w, targets)
|
||||
}
|
||||
|
||||
*/
|
109
server/internal/api/room/control.go
Normal file
109
server/internal/api/room/control.go
Normal file
@ -0,0 +1,109 @@
|
||||
package room
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"github.com/demodesk/neko/pkg/auth"
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
"github.com/demodesk/neko/pkg/types/message"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type ControlStatusPayload struct {
|
||||
HasHost bool `json:"has_host"`
|
||||
HostId string `json:"host_id,omitempty"`
|
||||
}
|
||||
|
||||
type ControlTargetPayload struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func (h *RoomHandler) controlStatus(w http.ResponseWriter, r *http.Request) error {
|
||||
host, hasHost := h.sessions.GetHost()
|
||||
|
||||
var hostId string
|
||||
if hasHost {
|
||||
hostId = host.ID()
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w, ControlStatusPayload{
|
||||
HasHost: hasHost,
|
||||
HostId: hostId,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *RoomHandler) controlRequest(w http.ResponseWriter, r *http.Request) error {
|
||||
session, _ := auth.GetSession(r)
|
||||
host, hasHost := h.sessions.GetHost()
|
||||
if hasHost {
|
||||
// TODO: Some throttling mechanism to prevent spamming.
|
||||
|
||||
// let host know that someone wants to take control
|
||||
host.Send(
|
||||
event.CONTROL_REQUEST,
|
||||
message.SessionID{
|
||||
ID: session.ID(),
|
||||
})
|
||||
|
||||
return utils.HttpError(http.StatusAccepted, "control request sent")
|
||||
}
|
||||
|
||||
if h.sessions.Settings().LockedControls && !session.Profile().IsAdmin {
|
||||
return utils.HttpForbidden("controls are locked")
|
||||
}
|
||||
|
||||
session.SetAsHost()
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
||||
|
||||
func (h *RoomHandler) controlRelease(w http.ResponseWriter, r *http.Request) error {
|
||||
session, _ := auth.GetSession(r)
|
||||
if !session.IsHost() {
|
||||
return utils.HttpUnprocessableEntity("session is not the host")
|
||||
}
|
||||
|
||||
h.desktop.ResetKeys()
|
||||
session.ClearHost()
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
||||
|
||||
func (h *RoomHandler) controlTake(w http.ResponseWriter, r *http.Request) error {
|
||||
session, _ := auth.GetSession(r)
|
||||
session.SetAsHost()
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
||||
|
||||
func (h *RoomHandler) controlGive(w http.ResponseWriter, r *http.Request) error {
|
||||
session, _ := auth.GetSession(r)
|
||||
sessionId := chi.URLParam(r, "sessionId")
|
||||
|
||||
target, ok := h.sessions.Get(sessionId)
|
||||
if !ok {
|
||||
return utils.HttpNotFound("target session was not found")
|
||||
}
|
||||
|
||||
if !target.Profile().CanHost {
|
||||
return utils.HttpBadRequest("target session is not allowed to host")
|
||||
}
|
||||
|
||||
target.SetAsHostBy(session)
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
||||
|
||||
func (h *RoomHandler) controlReset(w http.ResponseWriter, r *http.Request) error {
|
||||
session, _ := auth.GetSession(r)
|
||||
_, hasHost := h.sessions.GetHost()
|
||||
|
||||
if hasHost {
|
||||
h.desktop.ResetKeys()
|
||||
session.ClearHost()
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
126
server/internal/api/room/handler.go
Normal file
126
server/internal/api/room/handler.go
Normal file
@ -0,0 +1,126 @@
|
||||
package room
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/pkg/auth"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type RoomHandler struct {
|
||||
sessions types.SessionManager
|
||||
desktop types.DesktopManager
|
||||
capture types.CaptureManager
|
||||
|
||||
privateModeImage []byte
|
||||
}
|
||||
|
||||
func New(
|
||||
sessions types.SessionManager,
|
||||
desktop types.DesktopManager,
|
||||
capture types.CaptureManager,
|
||||
) *RoomHandler {
|
||||
h := &RoomHandler{
|
||||
sessions: sessions,
|
||||
desktop: desktop,
|
||||
capture: capture,
|
||||
}
|
||||
|
||||
// generate fallback image for private mode when needed
|
||||
sessions.OnSettingsChanged(func(session types.Session, new, old types.Settings) {
|
||||
if old.PrivateMode && !new.PrivateMode {
|
||||
log.Debug().Msg("clearing private mode fallback image")
|
||||
h.privateModeImage = nil
|
||||
return
|
||||
}
|
||||
|
||||
if !old.PrivateMode && new.PrivateMode {
|
||||
img := h.desktop.GetScreenshotImage()
|
||||
bytes, err := utils.CreateJPGImage(img, 90)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("could not generate private mode fallback image")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().Msg("using private mode fallback image")
|
||||
h.privateModeImage = bytes
|
||||
}
|
||||
})
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *RoomHandler) Route(r types.Router) {
|
||||
r.With(auth.AdminsOnly).Route("/settings", func(r types.Router) {
|
||||
r.Post("/", h.settingsSet)
|
||||
r.Get("/", h.settingsGet)
|
||||
})
|
||||
|
||||
r.With(auth.AdminsOnly).Route("/broadcast", func(r types.Router) {
|
||||
r.Get("/", h.broadcastStatus)
|
||||
r.Post("/start", h.boradcastStart)
|
||||
r.Post("/stop", h.boradcastStop)
|
||||
})
|
||||
|
||||
r.With(auth.CanAccessClipboardOnly).With(auth.HostsOnly).Route("/clipboard", func(r types.Router) {
|
||||
r.Get("/", h.clipboardGetText)
|
||||
r.Post("/", h.clipboardSetText)
|
||||
r.Get("/image.png", h.clipboardGetImage)
|
||||
|
||||
// TODO: Refactor. xclip is failing to set propper target type
|
||||
// and this content is sent back to client as text in another
|
||||
// clipboard update. Therefore endpoint is not usable!
|
||||
//r.Post("/image", h.clipboardSetImage)
|
||||
|
||||
// TODO: Refactor. If there would be implemented custom target
|
||||
// retrieval, this endpoint would be useful.
|
||||
//r.Get("/targets", h.clipboardGetTargets)
|
||||
})
|
||||
|
||||
r.With(auth.CanHostOnly).Route("/keyboard", func(r types.Router) {
|
||||
r.Get("/map", h.keyboardMapGet)
|
||||
r.With(auth.HostsOnly).Post("/map", h.keyboardMapSet)
|
||||
|
||||
r.Get("/modifiers", h.keyboardModifiersGet)
|
||||
r.With(auth.HostsOnly).Post("/modifiers", h.keyboardModifiersSet)
|
||||
})
|
||||
|
||||
r.With(auth.CanHostOnly).Route("/control", func(r types.Router) {
|
||||
r.Get("/", h.controlStatus)
|
||||
r.Post("/request", h.controlRequest)
|
||||
r.Post("/release", h.controlRelease)
|
||||
|
||||
r.With(auth.AdminsOnly).Post("/take", h.controlTake)
|
||||
r.With(auth.AdminsOnly).Post("/give/{sessionId}", h.controlGive)
|
||||
r.With(auth.AdminsOnly).Post("/reset", h.controlReset)
|
||||
})
|
||||
|
||||
r.With(auth.CanWatchOnly).Route("/screen", func(r types.Router) {
|
||||
r.Get("/", h.screenConfiguration)
|
||||
r.With(auth.AdminsOnly).Post("/", h.screenConfigurationChange)
|
||||
r.With(auth.AdminsOnly).Get("/configurations", h.screenConfigurationsList)
|
||||
|
||||
r.Get("/cast.jpg", h.screenCastGet)
|
||||
r.With(auth.AdminsOnly).Get("/shot.jpg", h.screenShotGet)
|
||||
})
|
||||
|
||||
r.With(h.uploadMiddleware).Route("/upload", func(r types.Router) {
|
||||
r.Post("/drop", h.uploadDrop)
|
||||
r.Post("/dialog", h.uploadDialogPost)
|
||||
r.Delete("/dialog", h.uploadDialogClose)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func (h *RoomHandler) uploadMiddleware(w http.ResponseWriter, r *http.Request) (context.Context, error) {
|
||||
session, ok := auth.GetSession(r)
|
||||
if !ok || (!session.IsHost() && (!session.Profile().CanHost || !h.sessions.Settings().ImplicitHosting)) {
|
||||
return nil, utils.HttpForbidden("without implicit hosting, only host can upload files")
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
47
server/internal/api/room/keyboard.go
Normal file
47
server/internal/api/room/keyboard.go
Normal file
@ -0,0 +1,47 @@
|
||||
package room
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
func (h *RoomHandler) keyboardMapSet(w http.ResponseWriter, r *http.Request) error {
|
||||
keyboardMap := types.KeyboardMap{}
|
||||
if err := utils.HttpJsonRequest(w, r, &keyboardMap); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := h.desktop.SetKeyboardMap(keyboardMap)
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
||||
|
||||
func (h *RoomHandler) keyboardMapGet(w http.ResponseWriter, r *http.Request) error {
|
||||
keyboardMap, err := h.desktop.GetKeyboardMap()
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w, keyboardMap)
|
||||
}
|
||||
|
||||
func (h *RoomHandler) keyboardModifiersSet(w http.ResponseWriter, r *http.Request) error {
|
||||
keyboardModifiers := types.KeyboardModifiers{}
|
||||
if err := utils.HttpJsonRequest(w, r, &keyboardModifiers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.desktop.SetKeyboardModifiers(keyboardModifiers)
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
||||
|
||||
func (h *RoomHandler) keyboardModifiersGet(w http.ResponseWriter, r *http.Request) error {
|
||||
keyboardModifiers := h.desktop.GetKeyboardModifiers()
|
||||
|
||||
return utils.HttpSuccess(w, keyboardModifiers)
|
||||
}
|
101
server/internal/api/room/screen.go
Normal file
101
server/internal/api/room/screen.go
Normal file
@ -0,0 +1,101 @@
|
||||
package room
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/demodesk/neko/pkg/auth"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
"github.com/demodesk/neko/pkg/types/message"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
func (h *RoomHandler) screenConfiguration(w http.ResponseWriter, r *http.Request) error {
|
||||
screenSize := h.desktop.GetScreenSize()
|
||||
|
||||
return utils.HttpSuccess(w, screenSize)
|
||||
}
|
||||
|
||||
func (h *RoomHandler) screenConfigurationChange(w http.ResponseWriter, r *http.Request) error {
|
||||
auth, _ := auth.GetSession(r)
|
||||
|
||||
data := &types.ScreenSize{}
|
||||
if err := utils.HttpJsonRequest(w, r, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
size, err := h.desktop.SetScreenSize(types.ScreenSize{
|
||||
Width: data.Width,
|
||||
Height: data.Height,
|
||||
Rate: data.Rate,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return utils.HttpUnprocessableEntity("cannot set screen size").WithInternalErr(err)
|
||||
}
|
||||
|
||||
h.sessions.Broadcast(event.SCREEN_UPDATED, message.ScreenSizeUpdate{
|
||||
ID: auth.ID(),
|
||||
ScreenSize: size,
|
||||
})
|
||||
|
||||
return utils.HttpSuccess(w, data)
|
||||
}
|
||||
|
||||
// TODO: remove.
|
||||
func (h *RoomHandler) screenConfigurationsList(w http.ResponseWriter, r *http.Request) error {
|
||||
configurations := h.desktop.ScreenConfigurations()
|
||||
|
||||
return utils.HttpSuccess(w, configurations)
|
||||
}
|
||||
|
||||
func (h *RoomHandler) screenShotGet(w http.ResponseWriter, r *http.Request) error {
|
||||
quality, err := strconv.Atoi(r.URL.Query().Get("quality"))
|
||||
if err != nil {
|
||||
quality = 90
|
||||
}
|
||||
|
||||
img := h.desktop.GetScreenshotImage()
|
||||
bytes, err := utils.CreateJPGImage(img, quality)
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
|
||||
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||||
w.Header().Set("Content-Type", "image/jpeg")
|
||||
|
||||
_, err = w.Write(bytes)
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *RoomHandler) screenCastGet(w http.ResponseWriter, r *http.Request) error {
|
||||
// display fallback image when private mode is enabled even if screencast is not
|
||||
if session, ok := auth.GetSession(r); ok && session.PrivateModeEnabled() {
|
||||
if h.privateModeImage != nil {
|
||||
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||||
w.Header().Set("Content-Type", "image/jpeg")
|
||||
|
||||
_, err := w.Write(h.privateModeImage)
|
||||
return err
|
||||
}
|
||||
|
||||
return utils.HttpBadRequest("private mode is enabled but no fallback image available")
|
||||
}
|
||||
|
||||
screencast := h.capture.Screencast()
|
||||
if !screencast.Enabled() {
|
||||
return utils.HttpBadRequest("screencast pipeline is not enabled")
|
||||
}
|
||||
|
||||
bytes, err := screencast.Image()
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
|
||||
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||||
w.Header().Set("Content-Type", "image/jpeg")
|
||||
|
||||
_, err = w.Write(bytes)
|
||||
return err
|
||||
}
|
38
server/internal/api/room/settings.go
Normal file
38
server/internal/api/room/settings.go
Normal file
@ -0,0 +1,38 @@
|
||||
package room
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/demodesk/neko/pkg/auth"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
func (h *RoomHandler) settingsGet(w http.ResponseWriter, r *http.Request) error {
|
||||
settings := h.sessions.Settings()
|
||||
return utils.HttpSuccess(w, settings)
|
||||
}
|
||||
|
||||
func (h *RoomHandler) settingsSet(w http.ResponseWriter, r *http.Request) error {
|
||||
session, _ := auth.GetSession(r)
|
||||
|
||||
// We read the request body first and unmashal it inside the UpdateSettingsFunc
|
||||
// to ensure atomicity of the operation.
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return utils.HttpBadRequest("unable to read request body").WithInternalErr(err)
|
||||
}
|
||||
|
||||
h.sessions.UpdateSettingsFunc(session, func(settings *types.Settings) bool {
|
||||
err = json.Unmarshal(body, settings)
|
||||
return err == nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return utils.HttpBadRequest("unable to parse provided data").WithInternalErr(err)
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
172
server/internal/api/room/upload.go
Normal file
172
server/internal/api/room/upload.go
Normal file
@ -0,0 +1,172 @@
|
||||
package room
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
// TODO: Extract file uploading to custom utility.
|
||||
|
||||
// maximum upload size of 32 MB
|
||||
const maxUploadSize = 32 << 20
|
||||
|
||||
func (h *RoomHandler) uploadDrop(w http.ResponseWriter, r *http.Request) error {
|
||||
if !h.desktop.IsUploadDropEnabled() {
|
||||
return utils.HttpBadRequest("upload drop is disabled")
|
||||
}
|
||||
|
||||
err := r.ParseMultipartForm(maxUploadSize)
|
||||
if err != nil {
|
||||
return utils.HttpBadRequest("failed to parse multipart form").WithInternalErr(err)
|
||||
}
|
||||
|
||||
//nolint
|
||||
defer r.MultipartForm.RemoveAll()
|
||||
|
||||
X, err := strconv.Atoi(r.FormValue("x"))
|
||||
if err != nil {
|
||||
return utils.HttpBadRequest("no X coordinate received").WithInternalErr(err)
|
||||
}
|
||||
|
||||
Y, err := strconv.Atoi(r.FormValue("y"))
|
||||
if err != nil {
|
||||
return utils.HttpBadRequest("no Y coordinate received").WithInternalErr(err)
|
||||
}
|
||||
|
||||
req_files := r.MultipartForm.File["files"]
|
||||
if len(req_files) == 0 {
|
||||
return utils.HttpBadRequest("no files received")
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp("", "neko-drop-*")
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().
|
||||
WithInternalErr(err).
|
||||
WithInternalMsg("unable to create temporary directory")
|
||||
}
|
||||
|
||||
files := []string{}
|
||||
for _, req_file := range req_files {
|
||||
path := path.Join(dir, req_file.Filename)
|
||||
|
||||
srcFile, err := req_file.Open()
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().
|
||||
WithInternalErr(err).
|
||||
WithInternalMsg("unable to open uploaded file")
|
||||
}
|
||||
|
||||
defer srcFile.Close()
|
||||
|
||||
dstFile, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().
|
||||
WithInternalErr(err).
|
||||
WithInternalMsg("unable to open destination file")
|
||||
}
|
||||
|
||||
defer dstFile.Close()
|
||||
|
||||
_, err = io.Copy(dstFile, srcFile)
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().
|
||||
WithInternalErr(err).
|
||||
WithInternalMsg("unable to copy uploaded file to destination file")
|
||||
}
|
||||
|
||||
files = append(files, path)
|
||||
}
|
||||
|
||||
if !h.desktop.DropFiles(X, Y, files) {
|
||||
return utils.HttpInternalServerError().
|
||||
WithInternalMsg("unable to drop files")
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
||||
|
||||
func (h *RoomHandler) uploadDialogPost(w http.ResponseWriter, r *http.Request) error {
|
||||
if !h.desktop.IsFileChooserDialogEnabled() {
|
||||
return utils.HttpBadRequest("file chooser dialog is disabled")
|
||||
}
|
||||
|
||||
err := r.ParseMultipartForm(maxUploadSize)
|
||||
if err != nil {
|
||||
return utils.HttpBadRequest("failed to parse multipart form").WithInternalErr(err)
|
||||
}
|
||||
|
||||
//nolint
|
||||
defer r.MultipartForm.RemoveAll()
|
||||
|
||||
req_files := r.MultipartForm.File["files"]
|
||||
if len(req_files) == 0 {
|
||||
return utils.HttpBadRequest("no files received")
|
||||
}
|
||||
|
||||
if !h.desktop.IsFileChooserDialogOpened() {
|
||||
return utils.HttpUnprocessableEntity("file chooser dialog is not open")
|
||||
}
|
||||
|
||||
dir, err := os.MkdirTemp("", "neko-dialog-*")
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().
|
||||
WithInternalErr(err).
|
||||
WithInternalMsg("unable to create temporary directory")
|
||||
}
|
||||
|
||||
for _, req_file := range req_files {
|
||||
path := path.Join(dir, req_file.Filename)
|
||||
|
||||
srcFile, err := req_file.Open()
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().
|
||||
WithInternalErr(err).
|
||||
WithInternalMsg("unable to open uploaded file")
|
||||
}
|
||||
|
||||
defer srcFile.Close()
|
||||
|
||||
dstFile, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().
|
||||
WithInternalErr(err).
|
||||
WithInternalMsg("unable to open destination file")
|
||||
}
|
||||
|
||||
defer dstFile.Close()
|
||||
|
||||
_, err = io.Copy(dstFile, srcFile)
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().
|
||||
WithInternalErr(err).
|
||||
WithInternalMsg("unable to copy uploaded file to destination file")
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.desktop.HandleFileChooserDialog(dir); err != nil {
|
||||
return utils.HttpInternalServerError().
|
||||
WithInternalErr(err).
|
||||
WithInternalMsg("unable to handle file chooser dialog")
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
||||
|
||||
func (h *RoomHandler) uploadDialogClose(w http.ResponseWriter, r *http.Request) error {
|
||||
if !h.desktop.IsFileChooserDialogEnabled() {
|
||||
return utils.HttpBadRequest("file chooser dialog is disabled")
|
||||
}
|
||||
|
||||
if !h.desktop.IsFileChooserDialogOpened() {
|
||||
return utils.HttpUnprocessableEntity("file chooser dialog is not open")
|
||||
}
|
||||
|
||||
h.desktop.CloseFileChooserDialog()
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
85
server/internal/api/router.go
Normal file
85
server/internal/api/router.go
Normal file
@ -0,0 +1,85 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/demodesk/neko/internal/api/members"
|
||||
"github.com/demodesk/neko/internal/api/room"
|
||||
"github.com/demodesk/neko/internal/api/sessions"
|
||||
"github.com/demodesk/neko/pkg/auth"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type ApiManagerCtx struct {
|
||||
sessions types.SessionManager
|
||||
members types.MemberManager
|
||||
desktop types.DesktopManager
|
||||
capture types.CaptureManager
|
||||
routers map[string]func(types.Router)
|
||||
}
|
||||
|
||||
func New(
|
||||
sessions types.SessionManager,
|
||||
members types.MemberManager,
|
||||
desktop types.DesktopManager,
|
||||
capture types.CaptureManager,
|
||||
) *ApiManagerCtx {
|
||||
|
||||
return &ApiManagerCtx{
|
||||
sessions: sessions,
|
||||
members: members,
|
||||
desktop: desktop,
|
||||
capture: capture,
|
||||
routers: make(map[string]func(types.Router)),
|
||||
}
|
||||
}
|
||||
|
||||
func (api *ApiManagerCtx) Route(r types.Router) {
|
||||
r.Post("/login", api.Login)
|
||||
|
||||
// Authenticated area
|
||||
r.Group(func(r types.Router) {
|
||||
r.Use(api.Authenticate)
|
||||
|
||||
r.Post("/logout", api.Logout)
|
||||
r.Get("/whoami", api.Whoami)
|
||||
|
||||
sessionsHandler := sessions.New(api.sessions)
|
||||
r.Route("/sessions", sessionsHandler.Route)
|
||||
|
||||
membersHandler := members.New(api.members)
|
||||
r.Route("/members", membersHandler.Route)
|
||||
r.Route("/members_bulk", membersHandler.RouteBulk)
|
||||
|
||||
roomHandler := room.New(api.sessions, api.desktop, api.capture)
|
||||
r.Route("/room", roomHandler.Route)
|
||||
|
||||
for path, router := range api.routers {
|
||||
r.Route(path, router)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (api *ApiManagerCtx) Authenticate(w http.ResponseWriter, r *http.Request) (context.Context, error) {
|
||||
session, err := api.sessions.Authenticate(r)
|
||||
if err != nil {
|
||||
if api.sessions.CookieEnabled() {
|
||||
api.sessions.CookieClearToken(w, r)
|
||||
}
|
||||
|
||||
if errors.Is(err, types.ErrSessionLoginDisabled) {
|
||||
return nil, utils.HttpForbidden("login is disabled for this session")
|
||||
}
|
||||
|
||||
return nil, utils.HttpUnauthorized().WithInternalErr(err)
|
||||
}
|
||||
|
||||
return auth.SetSession(r, session), nil
|
||||
}
|
||||
|
||||
func (api *ApiManagerCtx) AddRouter(path string, router func(types.Router)) {
|
||||
api.routers[path] = router
|
||||
}
|
85
server/internal/api/session.go
Normal file
85
server/internal/api/session.go
Normal file
@ -0,0 +1,85 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/demodesk/neko/pkg/auth"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type SessionLoginPayload struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type SessionDataPayload struct {
|
||||
ID string `json:"id"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Profile types.MemberProfile `json:"profile"`
|
||||
State types.SessionState `json:"state"`
|
||||
}
|
||||
|
||||
func (api *ApiManagerCtx) Login(w http.ResponseWriter, r *http.Request) error {
|
||||
data := &SessionLoginPayload{}
|
||||
if err := utils.HttpJsonRequest(w, r, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session, token, err := api.members.Login(data.Username, data.Password)
|
||||
if err != nil {
|
||||
if errors.Is(err, types.ErrSessionAlreadyConnected) {
|
||||
return utils.HttpUnprocessableEntity("session already connected")
|
||||
} else if errors.Is(err, types.ErrMemberDoesNotExist) || errors.Is(err, types.ErrMemberInvalidPassword) {
|
||||
return utils.HttpUnauthorized().WithInternalErr(err)
|
||||
} else if errors.Is(err, types.ErrSessionLoginsLocked) {
|
||||
return utils.HttpForbidden("logins are locked").WithInternalErr(err)
|
||||
} else {
|
||||
return utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
}
|
||||
|
||||
sessionData := SessionDataPayload{
|
||||
ID: session.ID(),
|
||||
Profile: session.Profile(),
|
||||
State: session.State(),
|
||||
}
|
||||
|
||||
if api.sessions.CookieEnabled() {
|
||||
api.sessions.CookieSetToken(w, token)
|
||||
} else {
|
||||
sessionData.Token = token
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w, sessionData)
|
||||
}
|
||||
|
||||
func (api *ApiManagerCtx) Logout(w http.ResponseWriter, r *http.Request) error {
|
||||
session, _ := auth.GetSession(r)
|
||||
|
||||
err := api.members.Logout(session.ID())
|
||||
if err != nil {
|
||||
if errors.Is(err, types.ErrSessionNotFound) {
|
||||
return utils.HttpBadRequest("session is not logged in")
|
||||
} else {
|
||||
return utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
}
|
||||
|
||||
if api.sessions.CookieEnabled() {
|
||||
api.sessions.CookieClearToken(w, r)
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w, true)
|
||||
}
|
||||
|
||||
func (api *ApiManagerCtx) Whoami(w http.ResponseWriter, r *http.Request) error {
|
||||
session, _ := auth.GetSession(r)
|
||||
|
||||
return utils.HttpSuccess(w, SessionDataPayload{
|
||||
ID: session.ID(),
|
||||
Profile: session.Profile(),
|
||||
State: session.State(),
|
||||
})
|
||||
}
|
80
server/internal/api/sessions/controller.go
Normal file
80
server/internal/api/sessions/controller.go
Normal file
@ -0,0 +1,80 @@
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/demodesk/neko/pkg/auth"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
"github.com/go-chi/chi"
|
||||
)
|
||||
|
||||
type SessionDataPayload struct {
|
||||
ID string `json:"id"`
|
||||
Profile types.MemberProfile `json:"profile"`
|
||||
State types.SessionState `json:"state"`
|
||||
}
|
||||
|
||||
func (h *SessionsHandler) sessionsList(w http.ResponseWriter, r *http.Request) error {
|
||||
sessions := []SessionDataPayload{}
|
||||
for _, session := range h.sessions.List() {
|
||||
sessions = append(sessions, SessionDataPayload{
|
||||
ID: session.ID(),
|
||||
Profile: session.Profile(),
|
||||
State: session.State(),
|
||||
})
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w, sessions)
|
||||
}
|
||||
|
||||
func (h *SessionsHandler) sessionsRead(w http.ResponseWriter, r *http.Request) error {
|
||||
sessionId := chi.URLParam(r, "sessionId")
|
||||
|
||||
session, ok := h.sessions.Get(sessionId)
|
||||
if !ok {
|
||||
return utils.HttpNotFound("session not found")
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w, SessionDataPayload{
|
||||
ID: session.ID(),
|
||||
Profile: session.Profile(),
|
||||
State: session.State(),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *SessionsHandler) sessionsDelete(w http.ResponseWriter, r *http.Request) error {
|
||||
session, _ := auth.GetSession(r)
|
||||
|
||||
sessionId := chi.URLParam(r, "sessionId")
|
||||
if sessionId == session.ID() {
|
||||
return utils.HttpBadRequest("cannot delete own session")
|
||||
}
|
||||
|
||||
err := h.sessions.Delete(sessionId)
|
||||
if err != nil {
|
||||
if errors.Is(err, types.ErrSessionNotFound) {
|
||||
return utils.HttpBadRequest("session not found")
|
||||
} else {
|
||||
return utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
||||
|
||||
func (h *SessionsHandler) sessionsDisconnect(w http.ResponseWriter, r *http.Request) error {
|
||||
sessionId := chi.URLParam(r, "sessionId")
|
||||
|
||||
err := h.sessions.Disconnect(sessionId)
|
||||
if err != nil {
|
||||
if errors.Is(err, types.ErrSessionNotFound) {
|
||||
return utils.HttpBadRequest("session not found")
|
||||
} else {
|
||||
return utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
30
server/internal/api/sessions/handler.go
Normal file
30
server/internal/api/sessions/handler.go
Normal file
@ -0,0 +1,30 @@
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"github.com/demodesk/neko/pkg/auth"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
)
|
||||
|
||||
type SessionsHandler struct {
|
||||
sessions types.SessionManager
|
||||
}
|
||||
|
||||
func New(
|
||||
sessions types.SessionManager,
|
||||
) *SessionsHandler {
|
||||
// Init
|
||||
|
||||
return &SessionsHandler{
|
||||
sessions: sessions,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *SessionsHandler) Route(r types.Router) {
|
||||
r.Get("/", h.sessionsList)
|
||||
|
||||
r.With(auth.AdminsOnly).Route("/{sessionId}", func(r types.Router) {
|
||||
r.Get("/", h.sessionsRead)
|
||||
r.Delete("/", h.sessionsDelete)
|
||||
r.Post("/disconnect", h.sessionsDisconnect)
|
||||
})
|
||||
}
|
156
server/internal/capture/broadcast.go
Normal file
156
server/internal/capture/broadcast.go
Normal file
@ -0,0 +1,156 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/pkg/gst"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
)
|
||||
|
||||
type BroacastManagerCtx struct {
|
||||
logger zerolog.Logger
|
||||
mu sync.Mutex
|
||||
|
||||
pipeline gst.Pipeline
|
||||
pipelineMu sync.Mutex
|
||||
pipelineFn func(url string) (string, error)
|
||||
|
||||
url string
|
||||
started bool
|
||||
|
||||
// metrics
|
||||
pipelinesCounter prometheus.Counter
|
||||
pipelinesActive prometheus.Gauge
|
||||
}
|
||||
|
||||
func broadcastNew(pipelineFn func(url string) (string, error), defaultUrl string) *BroacastManagerCtx {
|
||||
logger := log.With().
|
||||
Str("module", "capture").
|
||||
Str("submodule", "broadcast").
|
||||
Logger()
|
||||
|
||||
return &BroacastManagerCtx{
|
||||
logger: logger,
|
||||
pipelineFn: pipelineFn,
|
||||
url: defaultUrl,
|
||||
started: defaultUrl != "",
|
||||
|
||||
// metrics
|
||||
pipelinesCounter: promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "pipelines_total",
|
||||
Namespace: "neko",
|
||||
Subsystem: "capture",
|
||||
Help: "Total number of created pipelines.",
|
||||
ConstLabels: map[string]string{
|
||||
"submodule": "broadcast",
|
||||
"video_id": "main",
|
||||
"codec_name": "-",
|
||||
"codec_type": "-",
|
||||
},
|
||||
}),
|
||||
pipelinesActive: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "pipelines_active",
|
||||
Namespace: "neko",
|
||||
Subsystem: "capture",
|
||||
Help: "Total number of active pipelines.",
|
||||
ConstLabels: map[string]string{
|
||||
"submodule": "broadcast",
|
||||
"video_id": "main",
|
||||
"codec_name": "-",
|
||||
"codec_type": "-",
|
||||
},
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *BroacastManagerCtx) shutdown() {
|
||||
manager.logger.Info().Msgf("shutdown")
|
||||
|
||||
manager.destroyPipeline()
|
||||
}
|
||||
|
||||
func (manager *BroacastManagerCtx) Start(url string) error {
|
||||
manager.mu.Lock()
|
||||
defer manager.mu.Unlock()
|
||||
|
||||
err := manager.createPipeline()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manager.url = url
|
||||
manager.started = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *BroacastManagerCtx) Stop() {
|
||||
manager.mu.Lock()
|
||||
defer manager.mu.Unlock()
|
||||
|
||||
manager.started = false
|
||||
manager.destroyPipeline()
|
||||
}
|
||||
|
||||
func (manager *BroacastManagerCtx) Started() bool {
|
||||
manager.mu.Lock()
|
||||
defer manager.mu.Unlock()
|
||||
|
||||
return manager.started
|
||||
}
|
||||
|
||||
func (manager *BroacastManagerCtx) Url() string {
|
||||
manager.mu.Lock()
|
||||
defer manager.mu.Unlock()
|
||||
|
||||
return manager.url
|
||||
}
|
||||
|
||||
func (manager *BroacastManagerCtx) createPipeline() error {
|
||||
manager.pipelineMu.Lock()
|
||||
defer manager.pipelineMu.Unlock()
|
||||
|
||||
if manager.pipeline != nil {
|
||||
return types.ErrCapturePipelineAlreadyExists
|
||||
}
|
||||
|
||||
pipelineStr, err := manager.pipelineFn(manager.url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manager.logger.Info().
|
||||
Str("url", manager.url).
|
||||
Str("src", pipelineStr).
|
||||
Msgf("starting pipeline")
|
||||
|
||||
manager.pipeline, err = gst.CreatePipeline(pipelineStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manager.pipeline.Play()
|
||||
manager.pipelinesCounter.Inc()
|
||||
manager.pipelinesActive.Set(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *BroacastManagerCtx) destroyPipeline() {
|
||||
manager.pipelineMu.Lock()
|
||||
defer manager.pipelineMu.Unlock()
|
||||
|
||||
if manager.pipeline == nil {
|
||||
return
|
||||
}
|
||||
|
||||
manager.pipeline.Destroy()
|
||||
manager.logger.Info().Msgf("destroying pipeline")
|
||||
manager.pipeline = nil
|
||||
|
||||
manager.pipelinesActive.Set(0)
|
||||
}
|
269
server/internal/capture/manager.go
Normal file
269
server/internal/capture/manager.go
Normal file
@ -0,0 +1,269 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/internal/config"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/codec"
|
||||
)
|
||||
|
||||
type CaptureManagerCtx struct {
|
||||
logger zerolog.Logger
|
||||
desktop types.DesktopManager
|
||||
config *config.Capture
|
||||
|
||||
// sinks
|
||||
broadcast *BroacastManagerCtx
|
||||
screencast *ScreencastManagerCtx
|
||||
audio *StreamSinkManagerCtx
|
||||
video *StreamSelectorManagerCtx
|
||||
|
||||
// sources
|
||||
webcam *StreamSrcManagerCtx
|
||||
microphone *StreamSrcManagerCtx
|
||||
}
|
||||
|
||||
func New(desktop types.DesktopManager, config *config.Capture) *CaptureManagerCtx {
|
||||
logger := log.With().Str("module", "capture").Logger()
|
||||
|
||||
videos := map[string]types.StreamSinkManager{}
|
||||
for video_id, cnf := range config.VideoPipelines {
|
||||
pipelineConf := cnf
|
||||
|
||||
createPipeline := func() (string, error) {
|
||||
if pipelineConf.GstPipeline != "" {
|
||||
// replace {display} with valid display
|
||||
return strings.Replace(pipelineConf.GstPipeline, "{display}", config.Display, 1), nil
|
||||
}
|
||||
|
||||
screen := desktop.GetScreenSize()
|
||||
pipeline, err := pipelineConf.GetPipeline(screen)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"ximagesrc display-name=%s show-pointer=false use-damage=false "+
|
||||
"%s ! appsink name=appsink", config.Display, pipeline,
|
||||
), nil
|
||||
}
|
||||
|
||||
// trigger function to catch evaluation errors at startup
|
||||
pipeline, err := createPipeline()
|
||||
if err != nil {
|
||||
logger.Panic().Err(err).
|
||||
Str("video_id", video_id).
|
||||
Msg("failed to create video pipeline")
|
||||
}
|
||||
|
||||
logger.Info().
|
||||
Str("video_id", video_id).
|
||||
Str("pipeline", pipeline).
|
||||
Msg("syntax check for video stream pipeline passed")
|
||||
|
||||
// append to videos
|
||||
videos[video_id] = streamSinkNew(config.VideoCodec, createPipeline, video_id)
|
||||
}
|
||||
|
||||
return &CaptureManagerCtx{
|
||||
logger: logger,
|
||||
desktop: desktop,
|
||||
config: config,
|
||||
|
||||
// sinks
|
||||
broadcast: broadcastNew(func(url string) (string, error) {
|
||||
if config.BroadcastPipeline != "" {
|
||||
var pipeline = config.BroadcastPipeline
|
||||
// replace {display} with valid display
|
||||
pipeline = strings.Replace(pipeline, "{display}", config.Display, 1)
|
||||
// replace {device} with valid device
|
||||
pipeline = strings.Replace(pipeline, "{device}", config.AudioDevice, 1)
|
||||
// replace {url} with valid URL
|
||||
return strings.Replace(pipeline, "{url}", url, 1), nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"flvmux name=mux ! rtmpsink location='%s live=1' "+
|
||||
"pulsesrc device=%s "+
|
||||
"! audio/x-raw,channels=2 "+
|
||||
"! audioconvert "+
|
||||
"! queue "+
|
||||
"! voaacenc bitrate=%d "+
|
||||
"! mux. "+
|
||||
"ximagesrc display-name=%s show-pointer=true use-damage=false "+
|
||||
"! video/x-raw "+
|
||||
"! videoconvert "+
|
||||
"! queue "+
|
||||
"! x264enc threads=4 bitrate=%d key-int-max=15 byte-stream=true tune=zerolatency speed-preset=%s "+
|
||||
"! mux.", url, config.AudioDevice, config.BroadcastAudioBitrate*1000, config.Display, config.BroadcastVideoBitrate, config.BroadcastPreset,
|
||||
), nil
|
||||
}, config.BroadcastUrl),
|
||||
screencast: screencastNew(config.ScreencastEnabled, func() string {
|
||||
if config.ScreencastPipeline != "" {
|
||||
// replace {display} with valid display
|
||||
return strings.Replace(config.ScreencastPipeline, "{display}", config.Display, 1)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"ximagesrc display-name=%s show-pointer=true use-damage=false "+
|
||||
"! video/x-raw,framerate=%s "+
|
||||
"! videoconvert "+
|
||||
"! queue "+
|
||||
"! jpegenc quality=%s "+
|
||||
"! appsink name=appsink", config.Display, config.ScreencastRate, config.ScreencastQuality,
|
||||
)
|
||||
}()),
|
||||
|
||||
audio: streamSinkNew(config.AudioCodec, func() (string, error) {
|
||||
if config.AudioPipeline != "" {
|
||||
// replace {device} with valid device
|
||||
return strings.Replace(config.AudioPipeline, "{device}", config.AudioDevice, 1), nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"pulsesrc device=%s "+
|
||||
"! audio/x-raw,channels=2 "+
|
||||
"! audioconvert "+
|
||||
"! queue "+
|
||||
"! %s "+
|
||||
"! appsink name=appsink", config.AudioDevice, config.AudioCodec.Pipeline,
|
||||
), nil
|
||||
}, "audio"),
|
||||
video: streamSelectorNew(config.VideoCodec, videos, config.VideoIDs),
|
||||
|
||||
// sources
|
||||
webcam: streamSrcNew(config.WebcamEnabled, map[string]string{
|
||||
codec.VP8().Name: "appsrc format=time is-live=true do-timestamp=true name=appsrc " +
|
||||
fmt.Sprintf("! application/x-rtp, payload=%d, encoding-name=VP8-DRAFT-IETF-01 ", codec.VP8().PayloadType) +
|
||||
"! rtpvp8depay " +
|
||||
"! decodebin " +
|
||||
"! videoconvert " +
|
||||
"! videorate " +
|
||||
"! videoscale " +
|
||||
fmt.Sprintf("! video/x-raw,width=%d,height=%d ", config.WebcamWidth, config.WebcamHeight) +
|
||||
"! identity drop-allocation=true " +
|
||||
fmt.Sprintf("! v4l2sink sync=false device=%s", config.WebcamDevice),
|
||||
// TODO: Test this pipeline.
|
||||
codec.VP9().Name: "appsrc format=time is-live=true do-timestamp=true name=appsrc " +
|
||||
"! application/x-rtp " +
|
||||
"! rtpvp9depay " +
|
||||
"! decodebin " +
|
||||
"! videoconvert " +
|
||||
"! videorate " +
|
||||
"! videoscale " +
|
||||
fmt.Sprintf("! video/x-raw,width=%d,height=%d ", config.WebcamWidth, config.WebcamHeight) +
|
||||
"! identity drop-allocation=true " +
|
||||
fmt.Sprintf("! v4l2sink sync=false device=%s", config.WebcamDevice),
|
||||
// TODO: Test this pipeline.
|
||||
codec.H264().Name: "appsrc format=time is-live=true do-timestamp=true name=appsrc " +
|
||||
"! application/x-rtp " +
|
||||
"! rtph264depay " +
|
||||
"! decodebin " +
|
||||
"! videoconvert " +
|
||||
"! videorate " +
|
||||
"! videoscale " +
|
||||
fmt.Sprintf("! video/x-raw,width=%d,height=%d ", config.WebcamWidth, config.WebcamHeight) +
|
||||
"! identity drop-allocation=true " +
|
||||
fmt.Sprintf("! v4l2sink sync=false device=%s", config.WebcamDevice),
|
||||
}, "webcam"),
|
||||
microphone: streamSrcNew(config.MicrophoneEnabled, map[string]string{
|
||||
codec.Opus().Name: "appsrc format=time is-live=true do-timestamp=true name=appsrc " +
|
||||
fmt.Sprintf("! application/x-rtp, payload=%d, encoding-name=OPUS ", codec.Opus().PayloadType) +
|
||||
"! rtpopusdepay " +
|
||||
"! decodebin " +
|
||||
fmt.Sprintf("! pulsesink device=%s", config.MicrophoneDevice),
|
||||
// TODO: Test this pipeline.
|
||||
codec.G722().Name: "appsrc format=time is-live=true do-timestamp=true name=appsrc " +
|
||||
"! application/x-rtp clock-rate=8000 " +
|
||||
"! rtpg722depay " +
|
||||
"! decodebin " +
|
||||
fmt.Sprintf("! pulsesink device=%s", config.MicrophoneDevice),
|
||||
}, "microphone"),
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *CaptureManagerCtx) Start() {
|
||||
if manager.broadcast.Started() {
|
||||
if err := manager.broadcast.createPipeline(); err != nil {
|
||||
manager.logger.Panic().Err(err).Msg("unable to create broadcast pipeline")
|
||||
}
|
||||
}
|
||||
|
||||
manager.desktop.OnBeforeScreenSizeChange(func() {
|
||||
manager.video.destroyPipelines()
|
||||
|
||||
if manager.broadcast.Started() {
|
||||
manager.broadcast.destroyPipeline()
|
||||
}
|
||||
|
||||
if manager.screencast.Started() {
|
||||
manager.screencast.destroyPipeline()
|
||||
}
|
||||
})
|
||||
|
||||
manager.desktop.OnAfterScreenSizeChange(func() {
|
||||
err := manager.video.recreatePipelines()
|
||||
if err != nil {
|
||||
manager.logger.Panic().Err(err).Msg("unable to recreate video pipelines")
|
||||
}
|
||||
|
||||
if manager.broadcast.Started() {
|
||||
err := manager.broadcast.createPipeline()
|
||||
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
|
||||
manager.logger.Panic().Err(err).Msg("unable to recreate broadcast pipeline")
|
||||
}
|
||||
}
|
||||
|
||||
if manager.screencast.Started() {
|
||||
err := manager.screencast.createPipeline()
|
||||
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
|
||||
manager.logger.Panic().Err(err).Msg("unable to recreate screencast pipeline")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *CaptureManagerCtx) Shutdown() error {
|
||||
manager.logger.Info().Msgf("shutdown")
|
||||
|
||||
manager.broadcast.shutdown()
|
||||
manager.screencast.shutdown()
|
||||
|
||||
manager.audio.shutdown()
|
||||
manager.video.shutdown()
|
||||
|
||||
manager.webcam.shutdown()
|
||||
manager.microphone.shutdown()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *CaptureManagerCtx) Broadcast() types.BroadcastManager {
|
||||
return manager.broadcast
|
||||
}
|
||||
|
||||
func (manager *CaptureManagerCtx) Screencast() types.ScreencastManager {
|
||||
return manager.screencast
|
||||
}
|
||||
|
||||
func (manager *CaptureManagerCtx) Audio() types.StreamSinkManager {
|
||||
return manager.audio
|
||||
}
|
||||
|
||||
func (manager *CaptureManagerCtx) Video() types.StreamSelectorManager {
|
||||
return manager.video
|
||||
}
|
||||
|
||||
func (manager *CaptureManagerCtx) Webcam() types.StreamSrcManager {
|
||||
return manager.webcam
|
||||
}
|
||||
|
||||
func (manager *CaptureManagerCtx) Microphone() types.StreamSrcManager {
|
||||
return manager.microphone
|
||||
}
|
257
server/internal/capture/screencast.go
Normal file
257
server/internal/capture/screencast.go
Normal file
@ -0,0 +1,257 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/pkg/gst"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
)
|
||||
|
||||
// timeout between intervals, when screencast pipeline is checked
|
||||
const screencastTimeout = 5 * time.Second
|
||||
|
||||
type ScreencastManagerCtx struct {
|
||||
logger zerolog.Logger
|
||||
mu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
|
||||
pipeline gst.Pipeline
|
||||
pipelineStr string
|
||||
pipelineMu sync.Mutex
|
||||
|
||||
image types.Sample
|
||||
imageMu sync.Mutex
|
||||
tickerStop chan struct{}
|
||||
|
||||
enabled bool
|
||||
started bool
|
||||
expired int32
|
||||
|
||||
// metrics
|
||||
imagesCounter prometheus.Counter
|
||||
pipelinesCounter prometheus.Counter
|
||||
pipelinesActive prometheus.Gauge
|
||||
}
|
||||
|
||||
func screencastNew(enabled bool, pipelineStr string) *ScreencastManagerCtx {
|
||||
logger := log.With().
|
||||
Str("module", "capture").
|
||||
Str("submodule", "screencast").
|
||||
Logger()
|
||||
|
||||
manager := &ScreencastManagerCtx{
|
||||
logger: logger,
|
||||
pipelineStr: pipelineStr,
|
||||
tickerStop: make(chan struct{}),
|
||||
enabled: enabled,
|
||||
started: false,
|
||||
|
||||
// metrics
|
||||
imagesCounter: promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "screencast_images_total",
|
||||
Namespace: "neko",
|
||||
Subsystem: "capture",
|
||||
Help: "Total number of created images.",
|
||||
}),
|
||||
pipelinesCounter: promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "pipelines_total",
|
||||
Namespace: "neko",
|
||||
Subsystem: "capture",
|
||||
Help: "Total number of created pipelines.",
|
||||
ConstLabels: map[string]string{
|
||||
"submodule": "screencast",
|
||||
"video_id": "main",
|
||||
"codec_name": "-",
|
||||
"codec_type": "-",
|
||||
},
|
||||
}),
|
||||
pipelinesActive: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "pipelines_active",
|
||||
Namespace: "neko",
|
||||
Subsystem: "capture",
|
||||
Help: "Total number of active pipelines.",
|
||||
ConstLabels: map[string]string{
|
||||
"submodule": "screencast",
|
||||
"video_id": "main",
|
||||
"codec_name": "-",
|
||||
"codec_type": "-",
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
manager.wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer manager.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(screencastTimeout)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-manager.tickerStop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
if manager.Started() && !atomic.CompareAndSwapInt32(&manager.expired, 0, 1) {
|
||||
manager.stop()
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return manager
|
||||
}
|
||||
|
||||
func (manager *ScreencastManagerCtx) shutdown() {
|
||||
manager.logger.Info().Msgf("shutdown")
|
||||
|
||||
manager.destroyPipeline()
|
||||
|
||||
close(manager.tickerStop)
|
||||
manager.wg.Wait()
|
||||
}
|
||||
|
||||
func (manager *ScreencastManagerCtx) Enabled() bool {
|
||||
manager.mu.Lock()
|
||||
defer manager.mu.Unlock()
|
||||
|
||||
return manager.enabled
|
||||
}
|
||||
|
||||
func (manager *ScreencastManagerCtx) Started() bool {
|
||||
manager.mu.Lock()
|
||||
defer manager.mu.Unlock()
|
||||
|
||||
return manager.started
|
||||
}
|
||||
|
||||
func (manager *ScreencastManagerCtx) Image() ([]byte, error) {
|
||||
atomic.StoreInt32(&manager.expired, 0)
|
||||
|
||||
err := manager.start()
|
||||
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
manager.imageMu.Lock()
|
||||
defer manager.imageMu.Unlock()
|
||||
|
||||
if manager.image.Data == nil {
|
||||
return nil, errors.New("image data not found")
|
||||
}
|
||||
|
||||
return manager.image.Data, nil
|
||||
}
|
||||
|
||||
func (manager *ScreencastManagerCtx) start() error {
|
||||
manager.mu.Lock()
|
||||
defer manager.mu.Unlock()
|
||||
|
||||
if !manager.enabled {
|
||||
return errors.New("screencast not enabled")
|
||||
}
|
||||
|
||||
err := manager.createPipeline()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manager.started = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *ScreencastManagerCtx) stop() {
|
||||
manager.mu.Lock()
|
||||
defer manager.mu.Unlock()
|
||||
|
||||
manager.started = false
|
||||
manager.destroyPipeline()
|
||||
}
|
||||
|
||||
func (manager *ScreencastManagerCtx) createPipeline() error {
|
||||
manager.pipelineMu.Lock()
|
||||
defer manager.pipelineMu.Unlock()
|
||||
|
||||
if manager.pipeline != nil {
|
||||
return types.ErrCapturePipelineAlreadyExists
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
manager.logger.Info().
|
||||
Str("str", manager.pipelineStr).
|
||||
Msgf("creating pipeline")
|
||||
|
||||
manager.pipeline, err = gst.CreatePipeline(manager.pipelineStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manager.pipeline.AttachAppsink("appsink")
|
||||
manager.pipeline.Play()
|
||||
manager.pipelinesCounter.Inc()
|
||||
manager.pipelinesActive.Set(1)
|
||||
|
||||
// get first image
|
||||
select {
|
||||
case image, ok := <-manager.pipeline.Sample():
|
||||
if !ok {
|
||||
return errors.New("unable to get first image")
|
||||
} else {
|
||||
manager.setImage(image)
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
return errors.New("timeouted while waiting for first image")
|
||||
}
|
||||
|
||||
manager.wg.Add(1)
|
||||
pipeline := manager.pipeline
|
||||
|
||||
go func() {
|
||||
manager.logger.Debug().Msg("started receiving images")
|
||||
defer manager.wg.Done()
|
||||
|
||||
for {
|
||||
image, ok := <-pipeline.Sample()
|
||||
if !ok {
|
||||
manager.logger.Debug().Msg("stopped receiving images")
|
||||
return
|
||||
}
|
||||
|
||||
manager.setImage(image)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *ScreencastManagerCtx) setImage(image types.Sample) {
|
||||
manager.imageMu.Lock()
|
||||
manager.image = image
|
||||
manager.imageMu.Unlock()
|
||||
|
||||
manager.imagesCounter.Inc()
|
||||
}
|
||||
|
||||
func (manager *ScreencastManagerCtx) destroyPipeline() {
|
||||
manager.pipelineMu.Lock()
|
||||
defer manager.pipelineMu.Unlock()
|
||||
|
||||
if manager.pipeline == nil {
|
||||
return
|
||||
}
|
||||
|
||||
manager.pipeline.Destroy()
|
||||
manager.logger.Info().Msgf("destroying pipeline")
|
||||
manager.pipeline = nil
|
||||
|
||||
manager.pipelinesActive.Set(0)
|
||||
}
|
206
server/internal/capture/streamselector.go
Normal file
206
server/internal/capture/streamselector.go
Normal file
@ -0,0 +1,206 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/codec"
|
||||
)
|
||||
|
||||
type StreamSelectorManagerCtx struct {
|
||||
logger zerolog.Logger
|
||||
codec codec.RTPCodec
|
||||
streams map[string]types.StreamSinkManager
|
||||
streamIDs []string
|
||||
}
|
||||
|
||||
func streamSelectorNew(codec codec.RTPCodec, streams map[string]types.StreamSinkManager, streamIDs []string) *StreamSelectorManagerCtx {
|
||||
logger := log.With().
|
||||
Str("module", "capture").
|
||||
Str("submodule", "stream-selector").
|
||||
Logger()
|
||||
|
||||
return &StreamSelectorManagerCtx{
|
||||
logger: logger,
|
||||
codec: codec,
|
||||
streams: streams,
|
||||
streamIDs: streamIDs,
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *StreamSelectorManagerCtx) shutdown() {
|
||||
manager.logger.Info().Msgf("shutdown")
|
||||
|
||||
manager.destroyPipelines()
|
||||
}
|
||||
|
||||
func (manager *StreamSelectorManagerCtx) destroyPipelines() {
|
||||
for _, stream := range manager.streams {
|
||||
if stream.Started() {
|
||||
stream.DestroyPipeline()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *StreamSelectorManagerCtx) recreatePipelines() error {
|
||||
for _, stream := range manager.streams {
|
||||
if stream.Started() {
|
||||
err := stream.CreatePipeline()
|
||||
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *StreamSelectorManagerCtx) IDs() []string {
|
||||
return manager.streamIDs
|
||||
}
|
||||
|
||||
func (manager *StreamSelectorManagerCtx) Codec() codec.RTPCodec {
|
||||
return manager.codec
|
||||
}
|
||||
|
||||
func (manager *StreamSelectorManagerCtx) GetStream(selector types.StreamSelector) (types.StreamSinkManager, bool) {
|
||||
// select stream by ID
|
||||
if selector.ID != "" {
|
||||
// select lower stream
|
||||
if selector.Type == types.StreamSelectorTypeLower {
|
||||
var lastStream types.StreamSinkManager
|
||||
for i := len(manager.streamIDs) - 1; i >= 0; i-- {
|
||||
streamID := manager.streamIDs[i]
|
||||
if streamID == selector.ID {
|
||||
return lastStream, lastStream != nil
|
||||
}
|
||||
stream, ok := manager.streams[streamID]
|
||||
if ok {
|
||||
lastStream = stream
|
||||
}
|
||||
}
|
||||
// we couldn't find a lower stream
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// select higher stream
|
||||
if selector.Type == types.StreamSelectorTypeHigher {
|
||||
var lastStream types.StreamSinkManager
|
||||
for _, streamID := range manager.streamIDs {
|
||||
if streamID == selector.ID {
|
||||
return lastStream, lastStream != nil
|
||||
}
|
||||
stream, ok := manager.streams[streamID]
|
||||
if ok {
|
||||
lastStream = stream
|
||||
}
|
||||
}
|
||||
// we couldn't find a higher stream
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// select exact stream
|
||||
stream, ok := manager.streams[selector.ID]
|
||||
return stream, ok
|
||||
}
|
||||
|
||||
// select stream by bitrate
|
||||
if selector.Bitrate != 0 {
|
||||
// select stream by nearest bitrate
|
||||
if selector.Type == types.StreamSelectorTypeNearest {
|
||||
return manager.nearestBitrate(selector.Bitrate), true
|
||||
}
|
||||
|
||||
// select lower stream
|
||||
if selector.Type == types.StreamSelectorTypeLower {
|
||||
// start from the highest stream, and go down, until we find a lower stream
|
||||
for i := len(manager.streamIDs) - 1; i >= 0; i-- {
|
||||
streamID := manager.streamIDs[i]
|
||||
stream := manager.streams[streamID]
|
||||
// if stream should be considered in calculation
|
||||
considered := stream.Bitrate() != 0 && stream.Started()
|
||||
if considered && stream.Bitrate() < selector.Bitrate {
|
||||
return stream, true
|
||||
}
|
||||
}
|
||||
// we couldn't find a lower stream
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// select higher stream
|
||||
if selector.Type == types.StreamSelectorTypeHigher {
|
||||
// start from the lowest stream, and go up, until we find a higher stream
|
||||
for _, streamID := range manager.streamIDs {
|
||||
stream := manager.streams[streamID]
|
||||
// if stream should be considered in calculation
|
||||
considered := stream.Bitrate() != 0 && stream.Started()
|
||||
if considered && stream.Bitrate() > selector.Bitrate {
|
||||
return stream, true
|
||||
}
|
||||
}
|
||||
// we couldn't find a higher stream
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// select stream by exact bitrate
|
||||
for _, stream := range manager.streams {
|
||||
if stream.Bitrate() == selector.Bitrate {
|
||||
return stream, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// we couldn't find a stream
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// TODO: This is a very naive implementation, we should use a binary search instead.
|
||||
func (manager *StreamSelectorManagerCtx) nearestBitrate(bitrate uint64) types.StreamSinkManager {
|
||||
type streamDiff struct {
|
||||
id string
|
||||
bitrateDiff int
|
||||
}
|
||||
|
||||
sortDiff := func(a, b int) bool {
|
||||
switch {
|
||||
case a < 0 && b < 0:
|
||||
return a > b
|
||||
case a >= 0:
|
||||
if b >= 0 {
|
||||
return a <= b
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var diffs []streamDiff
|
||||
|
||||
for _, stream := range manager.streams {
|
||||
// if stream should be considered in calculation
|
||||
considered := stream.Bitrate() != 0 && stream.Started()
|
||||
if !considered {
|
||||
continue
|
||||
}
|
||||
diffs = append(diffs, streamDiff{
|
||||
id: stream.ID(),
|
||||
bitrateDiff: int(bitrate) - int(stream.Bitrate()),
|
||||
})
|
||||
}
|
||||
|
||||
// no streams available
|
||||
if len(diffs) == 0 {
|
||||
// return first (lowest) stream
|
||||
return manager.streams[manager.streamIDs[0]]
|
||||
}
|
||||
|
||||
sort.Slice(diffs, func(i, j int) bool {
|
||||
return sortDiff(diffs[i].bitrateDiff, diffs[j].bitrateDiff)
|
||||
})
|
||||
|
||||
bestDiff := diffs[0]
|
||||
return manager.streams[bestDiff.id]
|
||||
}
|
414
server/internal/capture/streamsink.go
Normal file
414
server/internal/capture/streamsink.go
Normal file
@ -0,0 +1,414 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/pkg/gst"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/codec"
|
||||
)
|
||||
|
||||
var moveSinkListenerMu = sync.Mutex{}
|
||||
|
||||
type StreamSinkManagerCtx struct {
|
||||
id string
|
||||
|
||||
// wait for a keyframe before sending samples
|
||||
waitForKf bool
|
||||
|
||||
bitrate uint64 // atomic
|
||||
brBuckets map[int]float64
|
||||
|
||||
logger zerolog.Logger
|
||||
mu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
|
||||
codec codec.RTPCodec
|
||||
pipeline gst.Pipeline
|
||||
pipelineMu sync.Mutex
|
||||
pipelineFn func() (string, error)
|
||||
|
||||
listeners map[uintptr]types.SampleListener
|
||||
listenersKf map[uintptr]types.SampleListener // keyframe lobby
|
||||
listenersMu sync.Mutex
|
||||
|
||||
// metrics
|
||||
currentListeners prometheus.Gauge
|
||||
totalBytes prometheus.Counter
|
||||
pipelinesCounter prometheus.Counter
|
||||
pipelinesActive prometheus.Gauge
|
||||
}
|
||||
|
||||
func streamSinkNew(codec codec.RTPCodec, pipelineFn func() (string, error), id string) *StreamSinkManagerCtx {
|
||||
logger := log.With().
|
||||
Str("module", "capture").
|
||||
Str("submodule", "stream-sink").
|
||||
Str("id", id).Logger()
|
||||
|
||||
manager := &StreamSinkManagerCtx{
|
||||
id: id,
|
||||
|
||||
// only wait for keyframes if the codec is video
|
||||
waitForKf: codec.IsVideo(),
|
||||
|
||||
bitrate: 0,
|
||||
brBuckets: map[int]float64{},
|
||||
|
||||
logger: logger,
|
||||
codec: codec,
|
||||
pipelineFn: pipelineFn,
|
||||
|
||||
listeners: map[uintptr]types.SampleListener{},
|
||||
listenersKf: map[uintptr]types.SampleListener{},
|
||||
|
||||
// metrics
|
||||
currentListeners: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "streamsink_listeners",
|
||||
Namespace: "neko",
|
||||
Subsystem: "capture",
|
||||
Help: "Current number of listeners for a pipeline.",
|
||||
ConstLabels: map[string]string{
|
||||
"video_id": id,
|
||||
"codec_name": codec.Name,
|
||||
"codec_type": codec.Type.String(),
|
||||
},
|
||||
}),
|
||||
totalBytes: promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "streamsink_bytes",
|
||||
Namespace: "neko",
|
||||
Subsystem: "capture",
|
||||
Help: "Total number of bytes created by the pipeline.",
|
||||
ConstLabels: map[string]string{
|
||||
"video_id": id,
|
||||
"codec_name": codec.Name,
|
||||
"codec_type": codec.Type.String(),
|
||||
},
|
||||
}),
|
||||
pipelinesCounter: promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "pipelines_total",
|
||||
Namespace: "neko",
|
||||
Subsystem: "capture",
|
||||
Help: "Total number of created pipelines.",
|
||||
ConstLabels: map[string]string{
|
||||
"submodule": "streamsink",
|
||||
"video_id": id,
|
||||
"codec_name": codec.Name,
|
||||
"codec_type": codec.Type.String(),
|
||||
},
|
||||
}),
|
||||
pipelinesActive: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "pipelines_active",
|
||||
Namespace: "neko",
|
||||
Subsystem: "capture",
|
||||
Help: "Total number of active pipelines.",
|
||||
ConstLabels: map[string]string{
|
||||
"submodule": "streamsink",
|
||||
"video_id": id,
|
||||
"codec_name": codec.Name,
|
||||
"codec_type": codec.Type.String(),
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
return manager
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) shutdown() {
|
||||
manager.logger.Info().Msgf("shutdown")
|
||||
|
||||
manager.listenersMu.Lock()
|
||||
for key := range manager.listeners {
|
||||
delete(manager.listeners, key)
|
||||
}
|
||||
for key := range manager.listenersKf {
|
||||
delete(manager.listenersKf, key)
|
||||
}
|
||||
manager.listenersMu.Unlock()
|
||||
|
||||
manager.DestroyPipeline()
|
||||
manager.wg.Wait()
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) ID() string {
|
||||
return manager.id
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) Bitrate() uint64 {
|
||||
return atomic.LoadUint64(&manager.bitrate)
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) Codec() codec.RTPCodec {
|
||||
return manager.codec
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) start() error {
|
||||
if len(manager.listeners)+len(manager.listenersKf) == 0 {
|
||||
err := manager.CreatePipeline()
|
||||
if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) {
|
||||
return err
|
||||
}
|
||||
|
||||
manager.logger.Info().Msgf("first listener, starting")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) stop() {
|
||||
if len(manager.listeners)+len(manager.listenersKf) == 0 {
|
||||
manager.DestroyPipeline()
|
||||
manager.logger.Info().Msgf("last listener, stopping")
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) addListener(listener types.SampleListener) {
|
||||
ptr := reflect.ValueOf(listener).Pointer()
|
||||
emitKeyframe := false
|
||||
|
||||
manager.listenersMu.Lock()
|
||||
if manager.waitForKf {
|
||||
// if this is the first listener, we need to emit a keyframe
|
||||
emitKeyframe = len(manager.listenersKf) == 0
|
||||
// if we're waiting for a keyframe, add it to the keyframe lobby
|
||||
manager.listenersKf[ptr] = listener
|
||||
} else {
|
||||
// otherwise, add it as a regular listener
|
||||
manager.listeners[ptr] = listener
|
||||
}
|
||||
manager.listenersMu.Unlock()
|
||||
|
||||
manager.logger.Debug().Interface("ptr", ptr).Msgf("adding listener")
|
||||
manager.currentListeners.Set(float64(manager.ListenersCount()))
|
||||
|
||||
// if we will be waiting for a keyframe, emit one now
|
||||
if manager.pipeline != nil && emitKeyframe {
|
||||
manager.pipeline.EmitVideoKeyframe()
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) removeListener(listener types.SampleListener) {
|
||||
ptr := reflect.ValueOf(listener).Pointer()
|
||||
|
||||
manager.listenersMu.Lock()
|
||||
delete(manager.listeners, ptr)
|
||||
delete(manager.listenersKf, ptr) // if it's a keyframe listener, remove it too
|
||||
manager.listenersMu.Unlock()
|
||||
|
||||
manager.logger.Debug().Interface("ptr", ptr).Msgf("removing listener")
|
||||
manager.currentListeners.Set(float64(manager.ListenersCount()))
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) AddListener(listener types.SampleListener) error {
|
||||
manager.mu.Lock()
|
||||
defer manager.mu.Unlock()
|
||||
|
||||
if listener == nil {
|
||||
return errors.New("listener cannot be nil")
|
||||
}
|
||||
|
||||
// start if stopped
|
||||
if err := manager.start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// add listener
|
||||
manager.addListener(listener)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) RemoveListener(listener types.SampleListener) error {
|
||||
manager.mu.Lock()
|
||||
defer manager.mu.Unlock()
|
||||
|
||||
if listener == nil {
|
||||
return errors.New("listener cannot be nil")
|
||||
}
|
||||
|
||||
// remove listener
|
||||
manager.removeListener(listener)
|
||||
|
||||
// stop if started
|
||||
manager.stop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// moving listeners between streams ensures, that target pipeline is running
|
||||
// before listener is added, and stops source pipeline if there are 0 listeners
|
||||
func (manager *StreamSinkManagerCtx) MoveListenerTo(listener types.SampleListener, stream types.StreamSinkManager) error {
|
||||
if listener == nil {
|
||||
return errors.New("listener cannot be nil")
|
||||
}
|
||||
|
||||
targetStream, ok := stream.(*StreamSinkManagerCtx)
|
||||
if !ok {
|
||||
return errors.New("target stream manager does not support moving listeners")
|
||||
}
|
||||
|
||||
// we need to acquire both mutextes, from source stream and from target stream
|
||||
// in order to do that safely (without possibility of deadlock) we need third
|
||||
// global mutex, that ensures atomic locking
|
||||
|
||||
// lock global mutex
|
||||
moveSinkListenerMu.Lock()
|
||||
|
||||
// lock source stream
|
||||
manager.mu.Lock()
|
||||
defer manager.mu.Unlock()
|
||||
|
||||
// lock target stream
|
||||
targetStream.mu.Lock()
|
||||
defer targetStream.mu.Unlock()
|
||||
|
||||
// unlock global mutex
|
||||
moveSinkListenerMu.Unlock()
|
||||
|
||||
// start if stopped
|
||||
if err := targetStream.start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// swap listeners
|
||||
manager.removeListener(listener)
|
||||
targetStream.addListener(listener)
|
||||
|
||||
// stop if started
|
||||
manager.stop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) ListenersCount() int {
|
||||
manager.listenersMu.Lock()
|
||||
defer manager.listenersMu.Unlock()
|
||||
|
||||
return len(manager.listeners) + len(manager.listenersKf)
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) Started() bool {
|
||||
return manager.ListenersCount() > 0
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) CreatePipeline() error {
|
||||
manager.pipelineMu.Lock()
|
||||
defer manager.pipelineMu.Unlock()
|
||||
|
||||
if manager.pipeline != nil {
|
||||
return types.ErrCapturePipelineAlreadyExists
|
||||
}
|
||||
|
||||
pipelineStr, err := manager.pipelineFn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manager.logger.Info().
|
||||
Str("codec", manager.codec.Name).
|
||||
Str("src", pipelineStr).
|
||||
Msgf("creating pipeline")
|
||||
|
||||
manager.pipeline, err = gst.CreatePipeline(pipelineStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manager.pipeline.AttachAppsink("appsink")
|
||||
manager.pipeline.Play()
|
||||
|
||||
manager.wg.Add(1)
|
||||
pipeline := manager.pipeline
|
||||
|
||||
go func() {
|
||||
manager.logger.Debug().Msg("started emitting samples")
|
||||
defer manager.wg.Done()
|
||||
|
||||
for {
|
||||
sample, ok := <-pipeline.Sample()
|
||||
if !ok {
|
||||
manager.logger.Debug().Msg("stopped emitting samples")
|
||||
return
|
||||
}
|
||||
|
||||
manager.onSample(sample)
|
||||
}
|
||||
}()
|
||||
|
||||
manager.pipelinesCounter.Inc()
|
||||
manager.pipelinesActive.Set(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) saveSampleBitrate(timestamp time.Time, delta float64) {
|
||||
// get unix timestamp in seconds
|
||||
sec := timestamp.Unix()
|
||||
// last bucket is timestamp rounded to 3 seconds - 1 second
|
||||
last := int((sec - 1) % 3)
|
||||
// current bucket is timestamp rounded to 3 seconds
|
||||
curr := int(sec % 3)
|
||||
// next bucket is timestamp rounded to 3 seconds + 1 second
|
||||
next := int((sec + 1) % 3)
|
||||
|
||||
if manager.brBuckets[next] != 0 {
|
||||
// atomic update bitrate
|
||||
atomic.StoreUint64(&manager.bitrate, uint64(manager.brBuckets[last]))
|
||||
// empty next bucket
|
||||
manager.brBuckets[next] = 0
|
||||
}
|
||||
|
||||
// add rate to current bucket
|
||||
manager.brBuckets[curr] += delta
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) onSample(sample types.Sample) {
|
||||
manager.listenersMu.Lock()
|
||||
defer manager.listenersMu.Unlock()
|
||||
|
||||
// save to metrics
|
||||
length := float64(sample.Length)
|
||||
manager.totalBytes.Add(length)
|
||||
manager.saveSampleBitrate(sample.Timestamp, length)
|
||||
|
||||
// if is not delta unit -> it can be decoded independently -> it is a keyframe
|
||||
if manager.waitForKf && !sample.DeltaUnit && len(manager.listenersKf) > 0 {
|
||||
// if current sample is a keyframe, move listeners from
|
||||
// keyframe lobby to actual listeners map and clear lobby
|
||||
for k, v := range manager.listenersKf {
|
||||
manager.listeners[k] = v
|
||||
}
|
||||
manager.listenersKf = make(map[uintptr]types.SampleListener)
|
||||
}
|
||||
|
||||
for _, l := range manager.listeners {
|
||||
l.WriteSample(sample)
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *StreamSinkManagerCtx) DestroyPipeline() {
|
||||
manager.pipelineMu.Lock()
|
||||
defer manager.pipelineMu.Unlock()
|
||||
|
||||
if manager.pipeline == nil {
|
||||
return
|
||||
}
|
||||
|
||||
manager.pipeline.Destroy()
|
||||
manager.logger.Info().Msgf("destroying pipeline")
|
||||
manager.pipeline = nil
|
||||
|
||||
manager.pipelinesActive.Set(0)
|
||||
|
||||
manager.brBuckets = make(map[int]float64)
|
||||
atomic.StoreUint64(&manager.bitrate, 0)
|
||||
}
|
197
server/internal/capture/streamsrc.go
Normal file
197
server/internal/capture/streamsrc.go
Normal file
@ -0,0 +1,197 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/pkg/gst"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/codec"
|
||||
)
|
||||
|
||||
type StreamSrcManagerCtx struct {
|
||||
logger zerolog.Logger
|
||||
enabled bool
|
||||
codecPipeline map[string]string // codec -> pipeline
|
||||
|
||||
codec codec.RTPCodec
|
||||
pipeline gst.Pipeline
|
||||
pipelineMu sync.Mutex
|
||||
pipelineStr string
|
||||
|
||||
// metrics
|
||||
pushedData map[string]prometheus.Summary
|
||||
pipelinesCounter map[string]prometheus.Counter
|
||||
pipelinesActive map[string]prometheus.Gauge
|
||||
}
|
||||
|
||||
func streamSrcNew(enabled bool, codecPipeline map[string]string, video_id string) *StreamSrcManagerCtx {
|
||||
logger := log.With().
|
||||
Str("module", "capture").
|
||||
Str("submodule", "stream-src").
|
||||
Str("video_id", video_id).Logger()
|
||||
|
||||
pushedData := map[string]prometheus.Summary{}
|
||||
pipelinesCounter := map[string]prometheus.Counter{}
|
||||
pipelinesActive := map[string]prometheus.Gauge{}
|
||||
|
||||
for codecName, pipeline := range codecPipeline {
|
||||
codec, ok := codec.ParseStr(codecName)
|
||||
if !ok {
|
||||
logger.Fatal().
|
||||
Str("codec", codecName).
|
||||
Str("pipeline", pipeline).
|
||||
Msg("unknown codec name")
|
||||
}
|
||||
|
||||
pushedData[codecName] = promauto.NewSummary(prometheus.SummaryOpts{
|
||||
Name: "streamsrc_data_bytes",
|
||||
Namespace: "neko",
|
||||
Subsystem: "capture",
|
||||
Help: "Data pushed to a pipeline (in bytes).",
|
||||
ConstLabels: map[string]string{
|
||||
"video_id": video_id,
|
||||
"codec_name": codec.Name,
|
||||
"codec_type": codec.Type.String(),
|
||||
},
|
||||
})
|
||||
pipelinesCounter[codecName] = promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "pipelines_total",
|
||||
Namespace: "neko",
|
||||
Subsystem: "capture",
|
||||
Help: "Total number of created pipelines.",
|
||||
ConstLabels: map[string]string{
|
||||
"submodule": "streamsrc",
|
||||
"video_id": video_id,
|
||||
"codec_name": codec.Name,
|
||||
"codec_type": codec.Type.String(),
|
||||
},
|
||||
})
|
||||
pipelinesActive[codecName] = promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "pipelines_active",
|
||||
Namespace: "neko",
|
||||
Subsystem: "capture",
|
||||
Help: "Total number of active pipelines.",
|
||||
ConstLabels: map[string]string{
|
||||
"submodule": "streamsrc",
|
||||
"video_id": video_id,
|
||||
"codec_name": codec.Name,
|
||||
"codec_type": codec.Type.String(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return &StreamSrcManagerCtx{
|
||||
logger: logger,
|
||||
enabled: enabled,
|
||||
codecPipeline: codecPipeline,
|
||||
|
||||
// metrics
|
||||
pushedData: pushedData,
|
||||
pipelinesCounter: pipelinesCounter,
|
||||
pipelinesActive: pipelinesActive,
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *StreamSrcManagerCtx) shutdown() {
|
||||
manager.logger.Info().Msgf("shutdown")
|
||||
|
||||
manager.Stop()
|
||||
}
|
||||
|
||||
func (manager *StreamSrcManagerCtx) Codec() codec.RTPCodec {
|
||||
manager.pipelineMu.Lock()
|
||||
defer manager.pipelineMu.Unlock()
|
||||
|
||||
return manager.codec
|
||||
}
|
||||
|
||||
func (manager *StreamSrcManagerCtx) Start(codec codec.RTPCodec) error {
|
||||
manager.pipelineMu.Lock()
|
||||
defer manager.pipelineMu.Unlock()
|
||||
|
||||
if manager.pipeline != nil {
|
||||
return types.ErrCapturePipelineAlreadyExists
|
||||
}
|
||||
|
||||
if !manager.enabled {
|
||||
return errors.New("stream-src not enabled")
|
||||
}
|
||||
|
||||
found := false
|
||||
for codecName, pipeline := range manager.codecPipeline {
|
||||
if codecName == codec.Name {
|
||||
manager.pipelineStr = pipeline
|
||||
manager.codec = codec
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return errors.New("no pipeline found for a codec")
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
manager.logger.Info().
|
||||
Str("codec", manager.codec.Name).
|
||||
Str("src", manager.pipelineStr).
|
||||
Msgf("creating pipeline")
|
||||
|
||||
manager.pipeline, err = gst.CreatePipeline(manager.pipelineStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manager.pipeline.AttachAppsrc("appsrc")
|
||||
manager.pipeline.Play()
|
||||
|
||||
manager.pipelinesCounter[manager.codec.Name].Inc()
|
||||
manager.pipelinesActive[manager.codec.Name].Set(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *StreamSrcManagerCtx) Stop() {
|
||||
manager.pipelineMu.Lock()
|
||||
defer manager.pipelineMu.Unlock()
|
||||
|
||||
if manager.pipeline == nil {
|
||||
return
|
||||
}
|
||||
|
||||
manager.pipeline.Destroy()
|
||||
manager.pipeline = nil
|
||||
|
||||
manager.logger.Info().
|
||||
Str("codec", manager.codec.Name).
|
||||
Str("src", manager.pipelineStr).
|
||||
Msgf("destroying pipeline")
|
||||
|
||||
manager.pipelinesActive[manager.codec.Name].Set(0)
|
||||
}
|
||||
|
||||
func (manager *StreamSrcManagerCtx) Push(bytes []byte) {
|
||||
manager.pipelineMu.Lock()
|
||||
defer manager.pipelineMu.Unlock()
|
||||
|
||||
if manager.pipeline == nil {
|
||||
return
|
||||
}
|
||||
|
||||
manager.pipeline.Push(bytes)
|
||||
manager.pushedData[manager.codec.Name].Observe(float64(len(bytes)))
|
||||
}
|
||||
|
||||
func (manager *StreamSrcManagerCtx) Started() bool {
|
||||
manager.pipelineMu.Lock()
|
||||
defer manager.pipelineMu.Unlock()
|
||||
|
||||
return manager.pipeline != nil
|
||||
}
|
246
server/internal/config/capture.go
Normal file
246
server/internal/config/capture.go
Normal file
@ -0,0 +1,246 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/codec"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type Capture struct {
|
||||
Display string
|
||||
|
||||
VideoCodec codec.RTPCodec
|
||||
VideoIDs []string
|
||||
VideoPipelines map[string]types.VideoConfig
|
||||
|
||||
AudioDevice string
|
||||
AudioCodec codec.RTPCodec
|
||||
AudioPipeline string
|
||||
|
||||
BroadcastAudioBitrate int
|
||||
BroadcastVideoBitrate int
|
||||
BroadcastPreset string
|
||||
BroadcastPipeline string
|
||||
BroadcastUrl string
|
||||
|
||||
ScreencastEnabled bool
|
||||
ScreencastRate string
|
||||
ScreencastQuality string
|
||||
ScreencastPipeline string
|
||||
|
||||
WebcamEnabled bool
|
||||
WebcamDevice string
|
||||
WebcamWidth int
|
||||
WebcamHeight int
|
||||
|
||||
MicrophoneEnabled bool
|
||||
MicrophoneDevice string
|
||||
}
|
||||
|
||||
func (Capture) Init(cmd *cobra.Command) error {
|
||||
// audio
|
||||
cmd.PersistentFlags().String("capture.audio.device", "audio_output.monitor", "pulseaudio device to capture")
|
||||
if err := viper.BindPFlag("capture.audio.device", cmd.PersistentFlags().Lookup("capture.audio.device")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("capture.audio.codec", "opus", "audio codec to be used")
|
||||
if err := viper.BindPFlag("capture.audio.codec", cmd.PersistentFlags().Lookup("capture.audio.codec")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("capture.audio.pipeline", "", "gstreamer pipeline used for audio streaming")
|
||||
if err := viper.BindPFlag("capture.audio.pipeline", cmd.PersistentFlags().Lookup("capture.audio.pipeline")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// videos
|
||||
cmd.PersistentFlags().String("capture.video.codec", "vp8", "video codec to be used")
|
||||
if err := viper.BindPFlag("capture.video.codec", cmd.PersistentFlags().Lookup("capture.video.codec")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().StringSlice("capture.video.ids", []string{}, "ordered list of video ids")
|
||||
if err := viper.BindPFlag("capture.video.ids", cmd.PersistentFlags().Lookup("capture.video.ids")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("capture.video.pipelines", "[]", "pipelines config in JSON used for video streaming")
|
||||
if err := viper.BindPFlag("capture.video.pipelines", cmd.PersistentFlags().Lookup("capture.video.pipelines")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// broadcast
|
||||
cmd.PersistentFlags().Int("capture.broadcast.audio_bitrate", 128, "broadcast audio bitrate in KB/s")
|
||||
if err := viper.BindPFlag("capture.broadcast.audio_bitrate", cmd.PersistentFlags().Lookup("capture.broadcast.audio_bitrate")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Int("capture.broadcast.video_bitrate", 4096, "broadcast video bitrate in KB/s")
|
||||
if err := viper.BindPFlag("capture.broadcast.video_bitrate", cmd.PersistentFlags().Lookup("capture.broadcast.video_bitrate")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("capture.broadcast.preset", "veryfast", "broadcast speed preset for h264 encoding")
|
||||
if err := viper.BindPFlag("capture.broadcast.preset", cmd.PersistentFlags().Lookup("capture.broadcast.preset")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("capture.broadcast.pipeline", "", "gstreamer pipeline used for broadcasting")
|
||||
if err := viper.BindPFlag("capture.broadcast.pipeline", cmd.PersistentFlags().Lookup("capture.broadcast.pipeline")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("capture.broadcast.url", "", "initial URL for broadcasting, setting this value will automatically start broadcasting")
|
||||
if err := viper.BindPFlag("capture.broadcast.url", cmd.PersistentFlags().Lookup("capture.broadcast.url")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// screencast
|
||||
cmd.PersistentFlags().Bool("capture.screencast.enabled", false, "enable screencast")
|
||||
if err := viper.BindPFlag("capture.screencast.enabled", cmd.PersistentFlags().Lookup("capture.screencast.enabled")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("capture.screencast.rate", "10/1", "screencast frame rate")
|
||||
if err := viper.BindPFlag("capture.screencast.rate", cmd.PersistentFlags().Lookup("capture.screencast.rate")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("capture.screencast.quality", "60", "screencast JPEG quality")
|
||||
if err := viper.BindPFlag("capture.screencast.quality", cmd.PersistentFlags().Lookup("capture.screencast.quality")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("capture.screencast.pipeline", "", "gstreamer pipeline used for screencasting")
|
||||
if err := viper.BindPFlag("capture.screencast.pipeline", cmd.PersistentFlags().Lookup("capture.screencast.pipeline")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// webcam
|
||||
cmd.PersistentFlags().Bool("capture.webcam.enabled", false, "enable webcam stream")
|
||||
if err := viper.BindPFlag("capture.webcam.enabled", cmd.PersistentFlags().Lookup("capture.webcam.enabled")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// sudo apt install v4l2loopback-dkms v4l2loopback-utils
|
||||
// sudo apt-get install linux-headers-`uname -r` linux-modules-extra-`uname -r`
|
||||
// sudo modprobe v4l2loopback exclusive_caps=1
|
||||
cmd.PersistentFlags().String("capture.webcam.device", "/dev/video0", "v4l2sink device used for webcam")
|
||||
if err := viper.BindPFlag("capture.webcam.device", cmd.PersistentFlags().Lookup("capture.webcam.device")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Int("capture.webcam.width", 1280, "webcam stream width")
|
||||
if err := viper.BindPFlag("capture.webcam.width", cmd.PersistentFlags().Lookup("capture.webcam.width")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Int("capture.webcam.height", 720, "webcam stream height")
|
||||
if err := viper.BindPFlag("capture.webcam.height", cmd.PersistentFlags().Lookup("capture.webcam.height")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// microphone
|
||||
cmd.PersistentFlags().Bool("capture.microphone.enabled", true, "enable microphone stream")
|
||||
if err := viper.BindPFlag("capture.microphone.enabled", cmd.PersistentFlags().Lookup("capture.microphone.enabled")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("capture.microphone.device", "audio_input", "pulseaudio device used for microphone")
|
||||
if err := viper.BindPFlag("capture.microphone.device", cmd.PersistentFlags().Lookup("capture.microphone.device")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Capture) Set() {
|
||||
var ok bool
|
||||
|
||||
// Display is provided by env variable
|
||||
s.Display = os.Getenv("DISPLAY")
|
||||
|
||||
// video
|
||||
videoCodec := viper.GetString("capture.video.codec")
|
||||
s.VideoCodec, ok = codec.ParseStr(videoCodec)
|
||||
if !ok || !s.VideoCodec.IsVideo() {
|
||||
log.Warn().Str("codec", videoCodec).Msgf("unknown video codec, using Vp8")
|
||||
s.VideoCodec = codec.VP8()
|
||||
}
|
||||
|
||||
s.VideoIDs = viper.GetStringSlice("capture.video.ids")
|
||||
if err := viper.UnmarshalKey("capture.video.pipelines", &s.VideoPipelines, viper.DecodeHook(
|
||||
utils.JsonStringAutoDecode(s.VideoPipelines),
|
||||
)); err != nil {
|
||||
log.Warn().Err(err).Msgf("unable to parse video pipelines")
|
||||
}
|
||||
|
||||
// default video
|
||||
if len(s.VideoPipelines) == 0 {
|
||||
log.Warn().Msgf("no video pipelines specified, using defaults")
|
||||
|
||||
s.VideoCodec = codec.VP8()
|
||||
s.VideoPipelines = map[string]types.VideoConfig{
|
||||
"main": {
|
||||
Fps: "25",
|
||||
GstEncoder: "vp8enc",
|
||||
GstParams: map[string]string{
|
||||
"target-bitrate": "round(3072 * 650)",
|
||||
"cpu-used": "4",
|
||||
"end-usage": "cbr",
|
||||
"threads": "4",
|
||||
"deadline": "1",
|
||||
"undershoot": "95",
|
||||
"buffer-size": "(3072 * 4)",
|
||||
"buffer-initial-size": "(3072 * 2)",
|
||||
"buffer-optimal-size": "(3072 * 3)",
|
||||
"keyframe-max-dist": "25",
|
||||
"min-quantizer": "4",
|
||||
"max-quantizer": "20",
|
||||
},
|
||||
},
|
||||
}
|
||||
s.VideoIDs = []string{"main"}
|
||||
}
|
||||
|
||||
// audio
|
||||
s.AudioDevice = viper.GetString("capture.audio.device")
|
||||
s.AudioPipeline = viper.GetString("capture.audio.pipeline")
|
||||
|
||||
audioCodec := viper.GetString("capture.audio.codec")
|
||||
s.AudioCodec, ok = codec.ParseStr(audioCodec)
|
||||
if !ok || !s.AudioCodec.IsAudio() {
|
||||
log.Warn().Str("codec", audioCodec).Msgf("unknown audio codec, using Opus")
|
||||
s.AudioCodec = codec.Opus()
|
||||
}
|
||||
|
||||
// broadcast
|
||||
s.BroadcastAudioBitrate = viper.GetInt("capture.broadcast.audio_bitrate")
|
||||
s.BroadcastVideoBitrate = viper.GetInt("capture.broadcast.video_bitrate")
|
||||
s.BroadcastPreset = viper.GetString("capture.broadcast.preset")
|
||||
s.BroadcastPipeline = viper.GetString("capture.broadcast.pipeline")
|
||||
s.BroadcastUrl = viper.GetString("capture.broadcast.url")
|
||||
|
||||
// screencast
|
||||
s.ScreencastEnabled = viper.GetBool("capture.screencast.enabled")
|
||||
s.ScreencastRate = viper.GetString("capture.screencast.rate")
|
||||
s.ScreencastQuality = viper.GetString("capture.screencast.quality")
|
||||
s.ScreencastPipeline = viper.GetString("capture.screencast.pipeline")
|
||||
|
||||
// webcam
|
||||
s.WebcamEnabled = viper.GetBool("capture.webcam.enabled")
|
||||
s.WebcamDevice = viper.GetString("capture.webcam.device")
|
||||
s.WebcamWidth = viper.GetInt("capture.webcam.width")
|
||||
s.WebcamHeight = viper.GetInt("capture.webcam.height")
|
||||
|
||||
// microphone
|
||||
s.MicrophoneEnabled = viper.GetBool("capture.microphone.enabled")
|
||||
s.MicrophoneDevice = viper.GetString("capture.microphone.device")
|
||||
}
|
8
server/internal/config/config.go
Normal file
8
server/internal/config/config.go
Normal file
@ -0,0 +1,8 @@
|
||||
package config
|
||||
|
||||
import "github.com/spf13/cobra"
|
||||
|
||||
type Config interface {
|
||||
Init(cmd *cobra.Command) error
|
||||
Set()
|
||||
}
|
91
server/internal/config/desktop.go
Normal file
91
server/internal/config/desktop.go
Normal file
@ -0,0 +1,91 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
)
|
||||
|
||||
type Desktop struct {
|
||||
Display string
|
||||
|
||||
ScreenSize types.ScreenSize
|
||||
|
||||
UseInputDriver bool
|
||||
InputSocket string
|
||||
|
||||
Unminimize bool
|
||||
UploadDrop bool
|
||||
FileChooserDialog bool
|
||||
}
|
||||
|
||||
func (Desktop) Init(cmd *cobra.Command) error {
|
||||
cmd.PersistentFlags().String("desktop.screen", "1280x720@30", "default screen size and framerate")
|
||||
if err := viper.BindPFlag("desktop.screen", cmd.PersistentFlags().Lookup("desktop.screen")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("desktop.input.enabled", true, "whether custom xf86 input driver should be used to handle touchscreen")
|
||||
if err := viper.BindPFlag("desktop.input.enabled", cmd.PersistentFlags().Lookup("desktop.input.enabled")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("desktop.input.socket", "/tmp/xf86-input-neko.sock", "socket path for custom xf86 input driver connection")
|
||||
if err := viper.BindPFlag("desktop.input.socket", cmd.PersistentFlags().Lookup("desktop.input.socket")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("desktop.unminimize", true, "automatically unminimize window when it is minimized")
|
||||
if err := viper.BindPFlag("desktop.unminimize", cmd.PersistentFlags().Lookup("desktop.unminimize")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("desktop.upload_drop", true, "whether drop upload is enabled")
|
||||
if err := viper.BindPFlag("desktop.upload_drop", cmd.PersistentFlags().Lookup("desktop.upload_drop")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("desktop.file_chooser_dialog", false, "whether to handle file chooser dialog externally")
|
||||
if err := viper.BindPFlag("desktop.file_chooser_dialog", cmd.PersistentFlags().Lookup("desktop.file_chooser_dialog")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Desktop) Set() {
|
||||
// Display is provided by env variable
|
||||
s.Display = os.Getenv("DISPLAY")
|
||||
|
||||
s.ScreenSize = types.ScreenSize{
|
||||
Width: 1280,
|
||||
Height: 720,
|
||||
Rate: 30,
|
||||
}
|
||||
|
||||
r := regexp.MustCompile(`([0-9]{1,4})x([0-9]{1,4})@([0-9]{1,3})`)
|
||||
res := r.FindStringSubmatch(viper.GetString("desktop.screen"))
|
||||
|
||||
if len(res) > 0 {
|
||||
width, err1 := strconv.ParseInt(res[1], 10, 64)
|
||||
height, err2 := strconv.ParseInt(res[2], 10, 64)
|
||||
rate, err3 := strconv.ParseInt(res[3], 10, 64)
|
||||
|
||||
if err1 == nil && err2 == nil && err3 == nil {
|
||||
s.ScreenSize.Width = int(width)
|
||||
s.ScreenSize.Height = int(height)
|
||||
s.ScreenSize.Rate = int16(rate)
|
||||
}
|
||||
}
|
||||
|
||||
s.UseInputDriver = viper.GetBool("desktop.input.enabled")
|
||||
s.InputSocket = viper.GetString("desktop.input.socket")
|
||||
s.Unminimize = viper.GetBool("desktop.unminimize")
|
||||
s.UploadDrop = viper.GetBool("desktop.upload_drop")
|
||||
s.FileChooserDialog = viper.GetBool("desktop.file_chooser_dialog")
|
||||
}
|
128
server/internal/config/member.go
Normal file
128
server/internal/config/member.go
Normal file
@ -0,0 +1,128 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/demodesk/neko/internal/member/file"
|
||||
"github.com/demodesk/neko/internal/member/multiuser"
|
||||
"github.com/demodesk/neko/internal/member/object"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type Member struct {
|
||||
Provider string
|
||||
|
||||
// providers
|
||||
File file.Config
|
||||
Object object.Config
|
||||
Multiuser multiuser.Config
|
||||
}
|
||||
|
||||
func (Member) Init(cmd *cobra.Command) error {
|
||||
cmd.PersistentFlags().String("member.provider", "multiuser", "choose member provider")
|
||||
if err := viper.BindPFlag("member.provider", cmd.PersistentFlags().Lookup("member.provider")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// file provider
|
||||
cmd.PersistentFlags().String("member.file.path", "", "member file provider: storage path")
|
||||
if err := viper.BindPFlag("member.file.path", cmd.PersistentFlags().Lookup("member.file.path")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("member.file.hash", true, "member file provider: whether to hash passwords using sha256 (recommended)")
|
||||
if err := viper.BindPFlag("member.file.hash", cmd.PersistentFlags().Lookup("member.file.hash")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// object provider
|
||||
cmd.PersistentFlags().String("member.object.users", "[]", "member object provider: users in JSON format")
|
||||
if err := viper.BindPFlag("member.object.users", cmd.PersistentFlags().Lookup("member.object.users")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// multiuser provider
|
||||
cmd.PersistentFlags().String("member.multiuser.user_password", "neko", "member multiuser provider: user password")
|
||||
if err := viper.BindPFlag("member.multiuser.user_password", cmd.PersistentFlags().Lookup("member.multiuser.user_password")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("member.multiuser.admin_password", "admin", "member multiuser provider: admin password")
|
||||
if err := viper.BindPFlag("member.multiuser.admin_password", cmd.PersistentFlags().Lookup("member.multiuser.admin_password")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("member.multiuser.user_profile", "{}", "member multiuser provider: user profile in JSON format")
|
||||
if err := viper.BindPFlag("member.multiuser.user_profile", cmd.PersistentFlags().Lookup("member.multiuser.user_profile")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("member.multiuser.admin_profile", "{}", "member multiuser provider: admin profile in JSON format")
|
||||
if err := viper.BindPFlag("member.multiuser.admin_profile", cmd.PersistentFlags().Lookup("member.multiuser.admin_profile")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Member) Set() {
|
||||
s.Provider = viper.GetString("member.provider")
|
||||
|
||||
// file provider
|
||||
s.File.Path = viper.GetString("member.file.path")
|
||||
s.File.Hash = viper.GetBool("member.file.hash")
|
||||
|
||||
// object provider
|
||||
if err := viper.UnmarshalKey("member.object.users", &s.Object.Users, viper.DecodeHook(
|
||||
utils.JsonStringAutoDecode(s.Object.Users),
|
||||
)); err != nil {
|
||||
log.Warn().Err(err).Msgf("unable to parse member object users")
|
||||
}
|
||||
|
||||
// multiuser provider
|
||||
s.Multiuser.UserPassword = viper.GetString("member.multiuser.user_password")
|
||||
s.Multiuser.AdminPassword = viper.GetString("member.multiuser.admin_password")
|
||||
|
||||
// default user profile
|
||||
s.Multiuser.UserProfile = types.MemberProfile{
|
||||
IsAdmin: false,
|
||||
CanLogin: true,
|
||||
CanConnect: true,
|
||||
CanWatch: true,
|
||||
CanHost: true,
|
||||
CanShareMedia: true,
|
||||
CanAccessClipboard: true,
|
||||
SendsInactiveCursor: true,
|
||||
CanSeeInactiveCursors: false,
|
||||
}
|
||||
|
||||
// override user profile
|
||||
if err := viper.UnmarshalKey("member.multiuser.user_profile", &s.Multiuser.UserProfile, viper.DecodeHook(
|
||||
utils.JsonStringAutoDecode(s.Multiuser.UserProfile),
|
||||
)); err != nil {
|
||||
log.Warn().Err(err).Msgf("unable to parse member multiuser user profile")
|
||||
}
|
||||
|
||||
// default admin profile
|
||||
s.Multiuser.AdminProfile = types.MemberProfile{
|
||||
IsAdmin: true,
|
||||
CanLogin: true,
|
||||
CanConnect: true,
|
||||
CanWatch: true,
|
||||
CanHost: true,
|
||||
CanShareMedia: true,
|
||||
CanAccessClipboard: true,
|
||||
SendsInactiveCursor: true,
|
||||
CanSeeInactiveCursors: true,
|
||||
}
|
||||
|
||||
// override admin profile
|
||||
if err := viper.UnmarshalKey("member.multiuser.admin_profile", &s.Multiuser.AdminProfile, viper.DecodeHook(
|
||||
utils.JsonStringAutoDecode(s.Multiuser.AdminProfile),
|
||||
)); err != nil {
|
||||
log.Warn().Err(err).Msgf("unable to parse member multiuser admin profile")
|
||||
}
|
||||
}
|
37
server/internal/config/plugins.go
Normal file
37
server/internal/config/plugins.go
Normal file
@ -0,0 +1,37 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Plugins struct {
|
||||
Enabled bool
|
||||
Dir string
|
||||
Required bool
|
||||
}
|
||||
|
||||
func (Plugins) Init(cmd *cobra.Command) error {
|
||||
cmd.PersistentFlags().Bool("plugins.enabled", false, "load plugins in runtime")
|
||||
if err := viper.BindPFlag("plugins.enabled", cmd.PersistentFlags().Lookup("plugins.enabled")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("plugins.dir", "./bin/plugins", "path to neko plugins to load")
|
||||
if err := viper.BindPFlag("plugins.dir", cmd.PersistentFlags().Lookup("plugins.dir")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("plugins.required", false, "if true, neko will exit if there is an error when loading a plugin")
|
||||
if err := viper.BindPFlag("plugins.required", cmd.PersistentFlags().Lookup("plugins.required")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Plugins) Set() {
|
||||
s.Enabled = viper.GetBool("plugins.enabled")
|
||||
s.Dir = viper.GetString("plugins.dir")
|
||||
s.Required = viper.GetBool("plugins.required")
|
||||
}
|
97
server/internal/config/root.go
Normal file
97
server/internal/config/root.go
Normal file
@ -0,0 +1,97 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Root struct {
|
||||
Config string
|
||||
|
||||
LogLevel zerolog.Level
|
||||
LogTime string
|
||||
LogJson bool
|
||||
LogNocolor bool
|
||||
LogDir string
|
||||
}
|
||||
|
||||
func (Root) Init(cmd *cobra.Command) error {
|
||||
cmd.PersistentFlags().StringP("config", "c", "", "configuration file path")
|
||||
if err := viper.BindPFlag("config", cmd.PersistentFlags().Lookup("config")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// just a shortcut
|
||||
cmd.PersistentFlags().BoolP("debug", "d", false, "enable debug mode")
|
||||
if err := viper.BindPFlag("debug", cmd.PersistentFlags().Lookup("debug")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("log.level", "info", "set log level (trace, debug, info, warn, error, fatal, panic, disabled)")
|
||||
if err := viper.BindPFlag("log.level", cmd.PersistentFlags().Lookup("log.level")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("log.time", "unix", "time format used in logs (unix, unixms, unixmicro)")
|
||||
if err := viper.BindPFlag("log.time", cmd.PersistentFlags().Lookup("log.time")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("log.json", false, "logs in JSON format")
|
||||
if err := viper.BindPFlag("log.json", cmd.PersistentFlags().Lookup("log.json")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("log.nocolor", false, "no ANSI colors in non-JSON output")
|
||||
if err := viper.BindPFlag("log.nocolor", cmd.PersistentFlags().Lookup("log.nocolor")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("log.dir", "", "logging directory to store logs")
|
||||
if err := viper.BindPFlag("log.dir", cmd.PersistentFlags().Lookup("log.dir")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Root) Set() {
|
||||
s.Config = viper.GetString("config")
|
||||
|
||||
logLevel := viper.GetString("log.level")
|
||||
level, err := zerolog.ParseLevel(logLevel)
|
||||
if err != nil {
|
||||
log.Warn().Msgf("unknown log level %s", logLevel)
|
||||
} else {
|
||||
s.LogLevel = level
|
||||
}
|
||||
|
||||
logTime := viper.GetString("log.time")
|
||||
switch logTime {
|
||||
case "unix":
|
||||
s.LogTime = zerolog.TimeFormatUnix
|
||||
case "unixms":
|
||||
s.LogTime = zerolog.TimeFormatUnixMs
|
||||
case "unixmicro":
|
||||
s.LogTime = zerolog.TimeFormatUnixMicro
|
||||
default:
|
||||
log.Warn().Msgf("unknown log time %s", logTime)
|
||||
}
|
||||
|
||||
s.LogJson = viper.GetBool("log.json")
|
||||
s.LogNocolor = viper.GetBool("log.nocolor")
|
||||
s.LogDir = viper.GetString("log.dir")
|
||||
|
||||
if viper.GetBool("debug") && s.LogLevel != zerolog.TraceLevel {
|
||||
s.LogLevel = zerolog.DebugLevel
|
||||
}
|
||||
|
||||
// support for NO_COLOR env variable: https://no-color.org/
|
||||
if os.Getenv("NO_COLOR") != "" {
|
||||
s.LogNocolor = true
|
||||
}
|
||||
}
|
103
server/internal/config/server.go
Normal file
103
server/internal/config/server.go
Normal file
@ -0,0 +1,103 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"path"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
Cert string
|
||||
Key string
|
||||
Bind string
|
||||
Proxy bool
|
||||
Static string
|
||||
PathPrefix string
|
||||
PProf bool
|
||||
Metrics bool
|
||||
CORS []string
|
||||
}
|
||||
|
||||
func (Server) Init(cmd *cobra.Command) error {
|
||||
cmd.PersistentFlags().String("server.bind", "127.0.0.1:8080", "address/port/socket to serve neko")
|
||||
if err := viper.BindPFlag("server.bind", cmd.PersistentFlags().Lookup("server.bind")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("server.cert", "", "path to the SSL cert used to secure the neko server")
|
||||
if err := viper.BindPFlag("server.cert", cmd.PersistentFlags().Lookup("server.cert")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("server.key", "", "path to the SSL key used to secure the neko server")
|
||||
if err := viper.BindPFlag("server.key", cmd.PersistentFlags().Lookup("server.key")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("server.proxy", false, "trust reverse proxy headers")
|
||||
if err := viper.BindPFlag("server.proxy", cmd.PersistentFlags().Lookup("server.proxy")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("server.static", "", "path to neko client files to serve")
|
||||
if err := viper.BindPFlag("server.static", cmd.PersistentFlags().Lookup("server.static")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("server.path_prefix", "/", "path prefix for HTTP requests")
|
||||
if err := viper.BindPFlag("server.path_prefix", cmd.PersistentFlags().Lookup("server.path_prefix")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("server.pprof", false, "enable pprof endpoint available at /debug/pprof")
|
||||
if err := viper.BindPFlag("server.pprof", cmd.PersistentFlags().Lookup("server.pprof")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("server.metrics", true, "enable prometheus metrics available at /metrics")
|
||||
if err := viper.BindPFlag("server.metrics", cmd.PersistentFlags().Lookup("server.metrics")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().StringSlice("server.cors", []string{}, "list of allowed origins for CORS, if empty CORS is disabled, if '*' is present all origins are allowed")
|
||||
if err := viper.BindPFlag("server.cors", cmd.PersistentFlags().Lookup("server.cors")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Set() {
|
||||
s.Cert = viper.GetString("server.cert")
|
||||
s.Key = viper.GetString("server.key")
|
||||
s.Bind = viper.GetString("server.bind")
|
||||
s.Proxy = viper.GetBool("server.proxy")
|
||||
s.Static = viper.GetString("server.static")
|
||||
s.PathPrefix = path.Join("/", path.Clean(viper.GetString("server.path_prefix")))
|
||||
s.PProf = viper.GetBool("server.pprof")
|
||||
s.Metrics = viper.GetBool("server.metrics")
|
||||
|
||||
s.CORS = viper.GetStringSlice("server.cors")
|
||||
in, _ := utils.ArrayIn("*", s.CORS)
|
||||
if len(s.CORS) == 0 || in {
|
||||
s.CORS = []string{"*"}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) HasCors() bool {
|
||||
return len(s.CORS) > 0
|
||||
}
|
||||
|
||||
func (s *Server) AllowOrigin(origin string) bool {
|
||||
// if CORS is disabled, allow all origins
|
||||
if len(s.CORS) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// if CORS is enabled, allow only origins in the list
|
||||
in, _ := utils.ArrayIn(origin, s.CORS)
|
||||
return in || s.CORS[0] == "*"
|
||||
}
|
114
server/internal/config/session.go
Normal file
114
server/internal/config/session.go
Normal file
@ -0,0 +1,114 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
File string
|
||||
|
||||
PrivateMode bool
|
||||
LockedLogins bool
|
||||
LockedControls bool
|
||||
ControlProtection bool
|
||||
ImplicitHosting bool
|
||||
InactiveCursors bool
|
||||
MercifulReconnect bool
|
||||
APIToken string
|
||||
|
||||
CookieEnabled bool
|
||||
CookieName string
|
||||
CookieExpiration time.Duration
|
||||
CookieSecure bool
|
||||
}
|
||||
|
||||
func (Session) Init(cmd *cobra.Command) error {
|
||||
cmd.PersistentFlags().String("session.file", "", "if sessions should be stored in a file, otherwise they will be stored only in memory")
|
||||
if err := viper.BindPFlag("session.file", cmd.PersistentFlags().Lookup("session.file")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("session.private_mode", false, "whether private mode should be enabled initially")
|
||||
if err := viper.BindPFlag("session.private_mode", cmd.PersistentFlags().Lookup("session.private_mode")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("session.locked_logins", false, "whether logins should be locked for users initially")
|
||||
if err := viper.BindPFlag("session.locked_logins", cmd.PersistentFlags().Lookup("session.locked_logins")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("session.locked_controls", false, "whether controls should be locked for users initially")
|
||||
if err := viper.BindPFlag("session.locked_controls", cmd.PersistentFlags().Lookup("session.locked_controls")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("session.control_protection", false, "users can gain control only if at least one admin is in the room")
|
||||
if err := viper.BindPFlag("session.control_protection", cmd.PersistentFlags().Lookup("session.control_protection")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("session.implicit_hosting", true, "allow implicit control switching")
|
||||
if err := viper.BindPFlag("session.implicit_hosting", cmd.PersistentFlags().Lookup("session.implicit_hosting")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("session.inactive_cursors", false, "show inactive cursors on the screen")
|
||||
if err := viper.BindPFlag("session.inactive_cursors", cmd.PersistentFlags().Lookup("session.inactive_cursors")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("session.merciful_reconnect", true, "allow reconnecting to websocket even if previous connection was not closed")
|
||||
if err := viper.BindPFlag("session.merciful_reconnect", cmd.PersistentFlags().Lookup("session.merciful_reconnect")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("session.api_token", "", "API token for interacting with external services")
|
||||
if err := viper.BindPFlag("session.api_token", cmd.PersistentFlags().Lookup("session.api_token")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// cookie
|
||||
cmd.PersistentFlags().Bool("session.cookie.enabled", true, "whether cookies authentication should be enabled")
|
||||
if err := viper.BindPFlag("session.cookie.enabled", cmd.PersistentFlags().Lookup("session.cookie.enabled")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("session.cookie.name", "NEKO_SESSION", "name of the cookie that holds token")
|
||||
if err := viper.BindPFlag("session.cookie.name", cmd.PersistentFlags().Lookup("session.cookie.name")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Int("session.cookie.expiration", 365*24, "expiration of the cookie in hours")
|
||||
if err := viper.BindPFlag("session.cookie.expiration", cmd.PersistentFlags().Lookup("session.cookie.expiration")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("session.cookie.secure", true, "use secure cookies")
|
||||
if err := viper.BindPFlag("session.cookie.secure", cmd.PersistentFlags().Lookup("session.cookie.secure")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) Set() {
|
||||
s.File = viper.GetString("session.file")
|
||||
|
||||
s.PrivateMode = viper.GetBool("session.private_mode")
|
||||
s.LockedLogins = viper.GetBool("session.locked_logins")
|
||||
s.LockedControls = viper.GetBool("session.locked_controls")
|
||||
s.ControlProtection = viper.GetBool("session.control_protection")
|
||||
s.ImplicitHosting = viper.GetBool("session.implicit_hosting")
|
||||
s.InactiveCursors = viper.GetBool("session.inactive_cursors")
|
||||
s.MercifulReconnect = viper.GetBool("session.merciful_reconnect")
|
||||
s.APIToken = viper.GetString("session.api_token")
|
||||
|
||||
s.CookieEnabled = viper.GetBool("session.cookie.enabled")
|
||||
s.CookieName = viper.GetString("session.cookie.name")
|
||||
s.CookieExpiration = time.Duration(viper.GetInt("session.cookie.expiration")) * time.Hour
|
||||
s.CookieSecure = viper.GetBool("session.cookie.secure")
|
||||
}
|
273
server/internal/config/webrtc.go
Normal file
273
server/internal/config/webrtc.go
Normal file
@ -0,0 +1,273 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
// default stun server
|
||||
const defStunSrv = "stun:stun.l.google.com:19302"
|
||||
|
||||
type WebRTCEstimator struct {
|
||||
Enabled bool
|
||||
Passive bool
|
||||
Debug bool
|
||||
InitialBitrate int
|
||||
|
||||
// how often to read and process bandwidth estimation reports
|
||||
ReadInterval time.Duration
|
||||
// how long to wait for stable connection (only neutral or upward trend) before upgrading
|
||||
StableDuration time.Duration
|
||||
// how long to wait for unstable connection (downward trend) before downgrading
|
||||
UnstableDuration time.Duration
|
||||
// how long to wait for stalled connection (neutral trend with low bandwidth) before downgrading
|
||||
StalledDuration time.Duration
|
||||
// how long to wait before downgrading again after previous downgrade
|
||||
DowngradeBackoff time.Duration
|
||||
// how long to wait before upgrading again after previous upgrade
|
||||
UpgradeBackoff time.Duration
|
||||
// how bigger the difference between estimated and stream bitrate must be to trigger upgrade/downgrade
|
||||
DiffThreshold float64
|
||||
}
|
||||
|
||||
type WebRTC struct {
|
||||
ICELite bool
|
||||
ICETrickle bool
|
||||
ICEServersFrontend []types.ICEServer
|
||||
ICEServersBackend []types.ICEServer
|
||||
EphemeralMin uint16
|
||||
EphemeralMax uint16
|
||||
TCPMux int
|
||||
UDPMux int
|
||||
|
||||
NAT1To1IPs []string
|
||||
IpRetrievalUrl string
|
||||
|
||||
Estimator WebRTCEstimator
|
||||
}
|
||||
|
||||
func (WebRTC) Init(cmd *cobra.Command) error {
|
||||
cmd.PersistentFlags().Bool("webrtc.icelite", false, "configures whether or not the ICE agent should be a lite agent")
|
||||
if err := viper.BindPFlag("webrtc.icelite", cmd.PersistentFlags().Lookup("webrtc.icelite")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("webrtc.icetrickle", true, "configures whether cadidates should be sent asynchronously using Trickle ICE")
|
||||
if err := viper.BindPFlag("webrtc.icetrickle", cmd.PersistentFlags().Lookup("webrtc.icetrickle")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Looks like this is conflicting with the frontend and backend ICE servers since latest versions
|
||||
//cmd.PersistentFlags().String("webrtc.iceservers", "[]", "Global STUN and TURN servers in JSON format with `urls`, `username` and `credential` keys")
|
||||
//if err := viper.BindPFlag("webrtc.iceservers", cmd.PersistentFlags().Lookup("webrtc.iceservers")); err != nil {
|
||||
// return err
|
||||
//}
|
||||
|
||||
cmd.PersistentFlags().String("webrtc.iceservers.frontend", "[]", "Frontend only STUN and TURN servers in JSON format with `urls`, `username` and `credential` keys")
|
||||
if err := viper.BindPFlag("webrtc.iceservers.frontend", cmd.PersistentFlags().Lookup("webrtc.iceservers.frontend")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("webrtc.iceservers.backend", "[]", "Backend only STUN and TURN servers in JSON format with `urls`, `username` and `credential` keys")
|
||||
if err := viper.BindPFlag("webrtc.iceservers.backend", cmd.PersistentFlags().Lookup("webrtc.iceservers.backend")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("webrtc.epr", "", "limits the pool of ephemeral ports that ICE UDP connections can allocate from")
|
||||
if err := viper.BindPFlag("webrtc.epr", cmd.PersistentFlags().Lookup("webrtc.epr")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Int("webrtc.tcpmux", 0, "single TCP mux port for all peers")
|
||||
if err := viper.BindPFlag("webrtc.tcpmux", cmd.PersistentFlags().Lookup("webrtc.tcpmux")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Int("webrtc.udpmux", 0, "single UDP mux port for all peers, replaces EPR")
|
||||
if err := viper.BindPFlag("webrtc.udpmux", cmd.PersistentFlags().Lookup("webrtc.udpmux")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().StringSlice("webrtc.nat1to1", []string{}, "sets a list of external IP addresses of 1:1 (D)NAT and a candidate type for which the external IP address is used")
|
||||
if err := viper.BindPFlag("webrtc.nat1to1", cmd.PersistentFlags().Lookup("webrtc.nat1to1")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("webrtc.ip_retrieval_url", "https://checkip.amazonaws.com", "URL address used for retrieval of the external IP address")
|
||||
if err := viper.BindPFlag("webrtc.ip_retrieval_url", cmd.PersistentFlags().Lookup("webrtc.ip_retrieval_url")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// bandwidth estimator
|
||||
|
||||
cmd.PersistentFlags().Bool("webrtc.estimator.enabled", false, "enables the bandwidth estimator")
|
||||
if err := viper.BindPFlag("webrtc.estimator.enabled", cmd.PersistentFlags().Lookup("webrtc.estimator.enabled")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("webrtc.estimator.passive", false, "passive estimator mode, when it does not switch pipelines, only estimates")
|
||||
if err := viper.BindPFlag("webrtc.estimator.passive", cmd.PersistentFlags().Lookup("webrtc.estimator.passive")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Bool("webrtc.estimator.debug", false, "enables debug logging for the bandwidth estimator")
|
||||
if err := viper.BindPFlag("webrtc.estimator.debug", cmd.PersistentFlags().Lookup("webrtc.estimator.debug")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Int("webrtc.estimator.initial_bitrate", 1_000_000, "initial bitrate for the bandwidth estimator")
|
||||
if err := viper.BindPFlag("webrtc.estimator.initial_bitrate", cmd.PersistentFlags().Lookup("webrtc.estimator.initial_bitrate")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Duration("webrtc.estimator.read_interval", 2*time.Second, "how often to read and process bandwidth estimation reports")
|
||||
if err := viper.BindPFlag("webrtc.estimator.read_interval", cmd.PersistentFlags().Lookup("webrtc.estimator.read_interval")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Duration("webrtc.estimator.stable_duration", 12*time.Second, "how long to wait for stable connection (upward or neutral trend) before upgrading")
|
||||
if err := viper.BindPFlag("webrtc.estimator.stable_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.stable_duration")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Duration("webrtc.estimator.unstable_duration", 6*time.Second, "how long to wait for stalled connection (neutral trend with low bandwidth) before downgrading")
|
||||
if err := viper.BindPFlag("webrtc.estimator.unstable_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.unstable_duration")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Duration("webrtc.estimator.stalled_duration", 24*time.Second, "how long to wait for stalled bandwidth estimation before downgrading")
|
||||
if err := viper.BindPFlag("webrtc.estimator.stalled_duration", cmd.PersistentFlags().Lookup("webrtc.estimator.stalled_duration")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Duration("webrtc.estimator.downgrade_backoff", 10*time.Second, "how long to wait before downgrading again after previous downgrade")
|
||||
if err := viper.BindPFlag("webrtc.estimator.downgrade_backoff", cmd.PersistentFlags().Lookup("webrtc.estimator.downgrade_backoff")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Duration("webrtc.estimator.upgrade_backoff", 5*time.Second, "how long to wait before upgrading again after previous upgrade")
|
||||
if err := viper.BindPFlag("webrtc.estimator.upgrade_backoff", cmd.PersistentFlags().Lookup("webrtc.estimator.upgrade_backoff")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Float64("webrtc.estimator.diff_threshold", 0.15, "how bigger the difference between estimated and stream bitrate must be to trigger upgrade/downgrade")
|
||||
if err := viper.BindPFlag("webrtc.estimator.diff_threshold", cmd.PersistentFlags().Lookup("webrtc.estimator.diff_threshold")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WebRTC) Set() {
|
||||
s.ICELite = viper.GetBool("webrtc.icelite")
|
||||
s.ICETrickle = viper.GetBool("webrtc.icetrickle")
|
||||
|
||||
// parse frontend ice servers
|
||||
if err := viper.UnmarshalKey("webrtc.iceservers.frontend", &s.ICEServersFrontend, viper.DecodeHook(
|
||||
utils.JsonStringAutoDecode([]types.ICEServer{}),
|
||||
)); err != nil {
|
||||
log.Warn().Err(err).Msgf("unable to parse frontend ICE servers")
|
||||
}
|
||||
|
||||
// parse backend ice servers
|
||||
if err := viper.UnmarshalKey("webrtc.iceservers.backend", &s.ICEServersBackend, viper.DecodeHook(
|
||||
utils.JsonStringAutoDecode([]types.ICEServer{}),
|
||||
)); err != nil {
|
||||
log.Warn().Err(err).Msgf("unable to parse backend ICE servers")
|
||||
}
|
||||
|
||||
if s.ICELite && len(s.ICEServersBackend) > 0 {
|
||||
log.Warn().Msgf("ICE Lite is enabled, but backend ICE servers are configured. Backend ICE servers will be ignored.")
|
||||
}
|
||||
|
||||
// if no frontend or backend ice servers are configured
|
||||
if len(s.ICEServersFrontend) == 0 && len(s.ICEServersBackend) == 0 {
|
||||
// parse global ice servers
|
||||
var iceServers []types.ICEServer
|
||||
if err := viper.UnmarshalKey("webrtc.iceservers", &iceServers, viper.DecodeHook(
|
||||
utils.JsonStringAutoDecode([]types.ICEServer{}),
|
||||
)); err != nil {
|
||||
log.Warn().Err(err).Msgf("unable to parse global ICE servers")
|
||||
}
|
||||
|
||||
// add default stun server if none are configured
|
||||
if len(iceServers) == 0 {
|
||||
iceServers = append(iceServers, types.ICEServer{
|
||||
URLs: []string{defStunSrv},
|
||||
})
|
||||
}
|
||||
|
||||
s.ICEServersFrontend = append(s.ICEServersFrontend, iceServers...)
|
||||
s.ICEServersBackend = append(s.ICEServersBackend, iceServers...)
|
||||
}
|
||||
|
||||
s.TCPMux = viper.GetInt("webrtc.tcpmux")
|
||||
s.UDPMux = viper.GetInt("webrtc.udpmux")
|
||||
|
||||
epr := viper.GetString("webrtc.epr")
|
||||
if epr != "" {
|
||||
ports := strings.SplitN(epr, "-", -1)
|
||||
if len(ports) > 1 {
|
||||
min, err := strconv.ParseUint(ports[0], 10, 16)
|
||||
if err != nil {
|
||||
log.Panic().Err(err).Msgf("unable to parse ephemeral min port")
|
||||
}
|
||||
|
||||
max, err := strconv.ParseUint(ports[1], 10, 16)
|
||||
if err != nil {
|
||||
log.Panic().Err(err).Msgf("unable to parse ephemeral max port")
|
||||
}
|
||||
|
||||
s.EphemeralMin = uint16(min)
|
||||
s.EphemeralMax = uint16(max)
|
||||
}
|
||||
|
||||
if s.EphemeralMin > s.EphemeralMax {
|
||||
log.Panic().Msgf("ephemeral min port cannot be bigger than max")
|
||||
}
|
||||
}
|
||||
|
||||
if epr == "" && s.TCPMux == 0 && s.UDPMux == 0 {
|
||||
// using default epr range
|
||||
s.EphemeralMin = 59000
|
||||
s.EphemeralMax = 59100
|
||||
|
||||
log.Warn().
|
||||
Uint16("min", s.EphemeralMin).
|
||||
Uint16("max", s.EphemeralMax).
|
||||
Msgf("no TCP, UDP mux or epr specified, using default epr range")
|
||||
}
|
||||
|
||||
s.NAT1To1IPs = viper.GetStringSlice("webrtc.nat1to1")
|
||||
s.IpRetrievalUrl = viper.GetString("webrtc.ip_retrieval_url")
|
||||
if s.IpRetrievalUrl != "" && len(s.NAT1To1IPs) == 0 {
|
||||
ip, err := utils.HttpRequestGET(s.IpRetrievalUrl)
|
||||
if err == nil {
|
||||
s.NAT1To1IPs = append(s.NAT1To1IPs, ip)
|
||||
} else {
|
||||
log.Warn().Err(err).Msgf("IP retrieval failed")
|
||||
}
|
||||
}
|
||||
|
||||
// bandwidth estimator
|
||||
|
||||
s.Estimator.Enabled = viper.GetBool("webrtc.estimator.enabled")
|
||||
s.Estimator.Passive = viper.GetBool("webrtc.estimator.passive")
|
||||
s.Estimator.Debug = viper.GetBool("webrtc.estimator.debug")
|
||||
s.Estimator.InitialBitrate = viper.GetInt("webrtc.estimator.initial_bitrate")
|
||||
s.Estimator.ReadInterval = viper.GetDuration("webrtc.estimator.read_interval")
|
||||
s.Estimator.StableDuration = viper.GetDuration("webrtc.estimator.stable_duration")
|
||||
s.Estimator.UnstableDuration = viper.GetDuration("webrtc.estimator.unstable_duration")
|
||||
s.Estimator.StalledDuration = viper.GetDuration("webrtc.estimator.stalled_duration")
|
||||
s.Estimator.DowngradeBackoff = viper.GetDuration("webrtc.estimator.downgrade_backoff")
|
||||
s.Estimator.UpgradeBackoff = viper.GetDuration("webrtc.estimator.upgrade_backoff")
|
||||
s.Estimator.DiffThreshold = viper.GetFloat64("webrtc.estimator.diff_threshold")
|
||||
}
|
122
server/internal/desktop/clipboard.go
Normal file
122
server/internal/desktop/clipboard.go
Normal file
@ -0,0 +1,122 @@
|
||||
package desktop
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/xevent"
|
||||
)
|
||||
|
||||
func (manager *DesktopManagerCtx) ClipboardGetText() (*types.ClipboardText, error) {
|
||||
text, err := manager.ClipboardGetBinary("STRING")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Rich text must not always be available, can fail silently.
|
||||
html, _ := manager.ClipboardGetBinary("text/html")
|
||||
|
||||
return &types.ClipboardText{
|
||||
Text: string(text),
|
||||
HTML: string(html),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) ClipboardSetText(data types.ClipboardText) error {
|
||||
// TODO: Refactor.
|
||||
// Current implementation is unable to set multiple targets. HTML
|
||||
// is set, if available. Otherwise plain text.
|
||||
|
||||
if data.HTML != "" {
|
||||
return manager.ClipboardSetBinary("text/html", []byte(data.HTML))
|
||||
}
|
||||
|
||||
return manager.ClipboardSetBinary("STRING", []byte(data.Text))
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) ClipboardGetBinary(mime string) ([]byte, error) {
|
||||
cmd := exec.Command("xclip", "-selection", "clipboard", "-out", "-target", mime)
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
msg := strings.TrimSpace(stderr.String())
|
||||
return nil, fmt.Errorf("%s", msg)
|
||||
}
|
||||
|
||||
return stdout.Bytes(), nil
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) ClipboardSetBinary(mime string, data []byte) error {
|
||||
cmd := exec.Command("xclip", "-selection", "clipboard", "-in", "-target", mime)
|
||||
|
||||
var stderr bytes.Buffer
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: Refactor.
|
||||
// We need to wait until the data came to the clipboard.
|
||||
wait := make(chan struct{})
|
||||
xevent.Emmiter.Once("clipboard-updated", func(payload ...any) {
|
||||
wait <- struct{}{}
|
||||
})
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
msg := strings.TrimSpace(stderr.String())
|
||||
return fmt.Errorf("%s", msg)
|
||||
}
|
||||
|
||||
_, err = stdin.Write(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stdin.Close()
|
||||
|
||||
// TODO: Refactor.
|
||||
// cmd.Wait()
|
||||
<-wait
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) ClipboardGetTargets() ([]string, error) {
|
||||
cmd := exec.Command("xclip", "-selection", "clipboard", "-out", "-target", "TARGETS")
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
msg := strings.TrimSpace(stderr.String())
|
||||
return nil, fmt.Errorf("%s", msg)
|
||||
}
|
||||
|
||||
var response []string
|
||||
targets := strings.Split(stdout.String(), "\n")
|
||||
for _, target := range targets {
|
||||
if target == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if !strings.Contains(target, "/") {
|
||||
continue
|
||||
}
|
||||
|
||||
response = append(response, target)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
68
server/internal/desktop/drop.go
Normal file
68
server/internal/desktop/drop.go
Normal file
@ -0,0 +1,68 @@
|
||||
package desktop
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/demodesk/neko/pkg/drop"
|
||||
)
|
||||
|
||||
// repeat move event multiple times
|
||||
const dropMoveRepeat = 4
|
||||
|
||||
// wait after each repeated move event
|
||||
const dropMoveDelay = 100 * time.Millisecond
|
||||
|
||||
func (manager *DesktopManagerCtx) DropFiles(x int, y int, files []string) bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
drop.Emmiter.Clear()
|
||||
|
||||
drop.Emmiter.Once("create", func(payload ...any) {
|
||||
manager.Move(0, 0)
|
||||
})
|
||||
|
||||
drop.Emmiter.Once("cursor-enter", func(payload ...any) {
|
||||
//nolint
|
||||
manager.ButtonDown(1)
|
||||
})
|
||||
|
||||
drop.Emmiter.Once("button-press", func(payload ...any) {
|
||||
manager.Move(x, y)
|
||||
})
|
||||
|
||||
drop.Emmiter.Once("begin", func(payload ...any) {
|
||||
for i := 0; i < dropMoveRepeat; i++ {
|
||||
manager.Move(x, y)
|
||||
time.Sleep(dropMoveDelay)
|
||||
}
|
||||
|
||||
//nolint
|
||||
manager.ButtonUp(1)
|
||||
})
|
||||
|
||||
finished := make(chan bool)
|
||||
drop.Emmiter.Once("finish", func(payload ...any) {
|
||||
b, ok := payload[0].(bool)
|
||||
// workaround until https://github.com/kataras/go-events/pull/8 is merged
|
||||
if !ok {
|
||||
b = (payload[0].([]any))[0].(bool)
|
||||
}
|
||||
finished <- b
|
||||
})
|
||||
|
||||
manager.ResetKeys()
|
||||
go drop.OpenWindow(files)
|
||||
|
||||
select {
|
||||
case succeeded := <-finished:
|
||||
return succeeded
|
||||
case <-time.After(1 * time.Second):
|
||||
drop.CloseWindow()
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) IsUploadDropEnabled() bool {
|
||||
return manager.config.UploadDrop
|
||||
}
|
102
server/internal/desktop/filechooserdialog.go
Normal file
102
server/internal/desktop/filechooserdialog.go
Normal file
@ -0,0 +1,102 @@
|
||||
package desktop
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os/exec"
|
||||
|
||||
"github.com/demodesk/neko/pkg/xorg"
|
||||
)
|
||||
|
||||
// name of the window that is being controlled
|
||||
const fileChooserDialogName = "Open File"
|
||||
|
||||
// short sleep value between fake user interactions
|
||||
const fileChooserDialogShortSleep = "0.2"
|
||||
|
||||
// long sleep value between fake user interactions
|
||||
const fileChooserDialogLongSleep = "0.4"
|
||||
|
||||
func (manager *DesktopManagerCtx) HandleFileChooserDialog(uri string) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// TODO: Use native API.
|
||||
err1 := exec.Command(
|
||||
"xdotool",
|
||||
"search", "--name", fileChooserDialogName, "windowfocus",
|
||||
"sleep", fileChooserDialogShortSleep,
|
||||
"key", "--clearmodifiers", "ctrl+l",
|
||||
"type", "--args", "1", uri+"//",
|
||||
"sleep", fileChooserDialogShortSleep,
|
||||
"key", "Delete", // remove autocomplete results
|
||||
"sleep", fileChooserDialogShortSleep,
|
||||
"key", "Return",
|
||||
"sleep", fileChooserDialogLongSleep,
|
||||
"key", "Down",
|
||||
"key", "--clearmodifiers", "ctrl+a",
|
||||
"key", "Return",
|
||||
"sleep", fileChooserDialogLongSleep,
|
||||
).Run()
|
||||
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
|
||||
// TODO: Use native API.
|
||||
err2 := exec.Command(
|
||||
"xdotool",
|
||||
"search", "--name", fileChooserDialogName,
|
||||
).Run()
|
||||
|
||||
// if last command didn't return error, consider dialog as still open
|
||||
if err2 == nil {
|
||||
return errors.New("unable to select files in dialog")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) CloseFileChooserDialog() {
|
||||
for i := 0; i < 5; i++ {
|
||||
mu.Lock()
|
||||
|
||||
manager.logger.Debug().Msg("attempting to close file chooser dialog")
|
||||
|
||||
// TODO: Use native API.
|
||||
err := exec.Command(
|
||||
"xdotool",
|
||||
"search", "--name", fileChooserDialogName, "windowfocus",
|
||||
).Run()
|
||||
|
||||
if err != nil {
|
||||
mu.Unlock()
|
||||
manager.logger.Info().Msg("file chooser dialog is closed")
|
||||
return
|
||||
}
|
||||
|
||||
// custom press Alt + F4
|
||||
// because xdotool is failing to send proper Alt+F4
|
||||
|
||||
//nolint
|
||||
manager.KeyPress(xorg.XK_Alt_L, xorg.XK_F4)
|
||||
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) IsFileChooserDialogEnabled() bool {
|
||||
return manager.config.FileChooserDialog
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) IsFileChooserDialogOpened() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// TODO: Use native API.
|
||||
err := exec.Command(
|
||||
"xdotool",
|
||||
"search", "--name", fileChooserDialogName,
|
||||
).Run()
|
||||
|
||||
return err == nil
|
||||
}
|
138
server/internal/desktop/manager.go
Normal file
138
server/internal/desktop/manager.go
Normal file
@ -0,0 +1,138 @@
|
||||
package desktop
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/kataras/go-events"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/internal/config"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/xevent"
|
||||
"github.com/demodesk/neko/pkg/xinput"
|
||||
"github.com/demodesk/neko/pkg/xorg"
|
||||
)
|
||||
|
||||
var mu = sync.Mutex{}
|
||||
|
||||
type DesktopManagerCtx struct {
|
||||
logger zerolog.Logger
|
||||
wg sync.WaitGroup
|
||||
shutdown chan struct{}
|
||||
emmiter events.EventEmmiter
|
||||
config *config.Desktop
|
||||
screenSize types.ScreenSize // cached screen size
|
||||
input xinput.Driver
|
||||
}
|
||||
|
||||
func New(config *config.Desktop) *DesktopManagerCtx {
|
||||
var input xinput.Driver
|
||||
if config.UseInputDriver {
|
||||
input = xinput.NewDriver(config.InputSocket)
|
||||
} else {
|
||||
input = xinput.NewDummy()
|
||||
}
|
||||
|
||||
return &DesktopManagerCtx{
|
||||
logger: log.With().Str("module", "desktop").Logger(),
|
||||
shutdown: make(chan struct{}),
|
||||
emmiter: events.New(),
|
||||
config: config,
|
||||
screenSize: config.ScreenSize,
|
||||
input: input,
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) Start() {
|
||||
if xorg.DisplayOpen(manager.config.Display) {
|
||||
manager.logger.Panic().Str("display", manager.config.Display).Msg("unable to open display")
|
||||
}
|
||||
|
||||
// X11 can throw errors below, and the default error handler exits
|
||||
xevent.SetupErrorHandler()
|
||||
|
||||
xorg.GetScreenConfigurations()
|
||||
|
||||
screenSize, err := xorg.ChangeScreenSize(manager.config.ScreenSize)
|
||||
if err != nil {
|
||||
manager.logger.Err(err).
|
||||
Str("screen_size", screenSize.String()).
|
||||
Msgf("unable to set initial screen size")
|
||||
} else {
|
||||
// cache screen size
|
||||
manager.screenSize = screenSize
|
||||
manager.logger.Info().
|
||||
Str("screen_size", screenSize.String()).
|
||||
Msgf("setting initial screen size")
|
||||
}
|
||||
|
||||
err = manager.input.Connect()
|
||||
if err != nil {
|
||||
// TODO: fail silently to dummy driver?
|
||||
manager.logger.Panic().Err(err).Msg("unable to connect to input driver")
|
||||
}
|
||||
|
||||
// set up event listeners
|
||||
xevent.Unminimize = manager.config.Unminimize
|
||||
xevent.FileChooserDialog = manager.config.FileChooserDialog
|
||||
go xevent.EventLoop(manager.config.Display)
|
||||
|
||||
// in case it was opened
|
||||
if manager.config.FileChooserDialog {
|
||||
go manager.CloseFileChooserDialog()
|
||||
}
|
||||
|
||||
manager.OnEventError(func(error_code uint8, message string, request_code uint8, minor_code uint8) {
|
||||
manager.logger.Warn().
|
||||
Uint8("error_code", error_code).
|
||||
Str("message", message).
|
||||
Uint8("request_code", request_code).
|
||||
Uint8("minor_code", minor_code).
|
||||
Msg("X event error occured")
|
||||
})
|
||||
|
||||
manager.wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer manager.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
const debounceDuration = 10 * time.Second
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-manager.shutdown:
|
||||
return
|
||||
case <-ticker.C:
|
||||
xorg.CheckKeys(debounceDuration)
|
||||
manager.input.Debounce(debounceDuration)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) OnBeforeScreenSizeChange(listener func()) {
|
||||
manager.emmiter.On("before_screen_size_change", func(payload ...any) {
|
||||
listener()
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) OnAfterScreenSizeChange(listener func()) {
|
||||
manager.emmiter.On("after_screen_size_change", func(payload ...any) {
|
||||
listener()
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) Shutdown() error {
|
||||
manager.logger.Info().Msgf("shutdown")
|
||||
|
||||
close(manager.shutdown)
|
||||
manager.wg.Wait()
|
||||
|
||||
xorg.DisplayClose()
|
||||
return nil
|
||||
}
|
35
server/internal/desktop/xevent.go
Normal file
35
server/internal/desktop/xevent.go
Normal file
@ -0,0 +1,35 @@
|
||||
package desktop
|
||||
|
||||
import (
|
||||
"github.com/demodesk/neko/pkg/xevent"
|
||||
)
|
||||
|
||||
func (manager *DesktopManagerCtx) OnCursorChanged(listener func(serial uint64)) {
|
||||
xevent.Emmiter.On("cursor-changed", func(payload ...any) {
|
||||
listener(payload[0].(uint64))
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) OnClipboardUpdated(listener func()) {
|
||||
xevent.Emmiter.On("clipboard-updated", func(payload ...any) {
|
||||
listener()
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) OnFileChooserDialogOpened(listener func()) {
|
||||
xevent.Emmiter.On("file-chooser-dialog-opened", func(payload ...any) {
|
||||
listener()
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) OnFileChooserDialogClosed(listener func()) {
|
||||
xevent.Emmiter.On("file-chooser-dialog-closed", func(payload ...any) {
|
||||
listener()
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) OnEventError(listener func(error_code uint8, message string, request_code uint8, minor_code uint8)) {
|
||||
xevent.Emmiter.On("event-error", func(payload ...any) {
|
||||
listener(payload[0].(uint8), payload[1].(string), payload[2].(uint8), payload[3].(uint8))
|
||||
})
|
||||
}
|
36
server/internal/desktop/xinput.go
Normal file
36
server/internal/desktop/xinput.go
Normal file
@ -0,0 +1,36 @@
|
||||
package desktop
|
||||
|
||||
import "github.com/demodesk/neko/pkg/xinput"
|
||||
|
||||
func (manager *DesktopManagerCtx) inputRelToAbs(x, y int) (int, int) {
|
||||
return (x * xinput.AbsX) / manager.screenSize.Width, (y * xinput.AbsY) / manager.screenSize.Height
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) HasTouchSupport() bool {
|
||||
// we assume now, that if the input driver is enabled, we have touch support
|
||||
return manager.config.UseInputDriver
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) TouchBegin(touchId uint32, x, y int, pressure uint8) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
x, y = manager.inputRelToAbs(x, y)
|
||||
return manager.input.TouchBegin(touchId, x, y, pressure)
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) TouchUpdate(touchId uint32, x, y int, pressure uint8) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
x, y = manager.inputRelToAbs(x, y)
|
||||
return manager.input.TouchUpdate(touchId, x, y, pressure)
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) TouchEnd(touchId uint32, x, y int, pressure uint8) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
x, y = manager.inputRelToAbs(x, y)
|
||||
return manager.input.TouchEnd(touchId, x, y, pressure)
|
||||
}
|
202
server/internal/desktop/xorg.go
Normal file
202
server/internal/desktop/xorg.go
Normal file
@ -0,0 +1,202 @@
|
||||
package desktop
|
||||
|
||||
import (
|
||||
"image"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/xorg"
|
||||
)
|
||||
|
||||
func (manager *DesktopManagerCtx) Move(x, y int) {
|
||||
xorg.Move(x, y)
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) GetCursorPosition() (int, int) {
|
||||
return xorg.GetCursorPosition()
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) Scroll(deltaX, deltaY int, controlKey bool) {
|
||||
xorg.Scroll(deltaX, deltaY, controlKey)
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) ButtonDown(code uint32) error {
|
||||
return xorg.ButtonDown(code)
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) KeyDown(code uint32) error {
|
||||
return xorg.KeyDown(code)
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) ButtonUp(code uint32) error {
|
||||
return xorg.ButtonUp(code)
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) KeyUp(code uint32) error {
|
||||
return xorg.KeyUp(code)
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) ButtonPress(code uint32) error {
|
||||
xorg.ResetKeys()
|
||||
defer xorg.ResetKeys()
|
||||
|
||||
return xorg.ButtonDown(code)
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) KeyPress(codes ...uint32) error {
|
||||
xorg.ResetKeys()
|
||||
defer xorg.ResetKeys()
|
||||
|
||||
for _, code := range codes {
|
||||
if err := xorg.KeyDown(code); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(codes) > 1 {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) ResetKeys() {
|
||||
xorg.ResetKeys()
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) ScreenConfigurations() []types.ScreenSize {
|
||||
var configs []types.ScreenSize
|
||||
for _, size := range xorg.ScreenConfigurations {
|
||||
for _, fps := range size.Rates {
|
||||
// filter out all irrelevant rates
|
||||
if fps > 60 || (fps > 30 && fps%10 != 0) {
|
||||
continue
|
||||
}
|
||||
|
||||
configs = append(configs, types.ScreenSize{
|
||||
Width: size.Width,
|
||||
Height: size.Height,
|
||||
Rate: fps,
|
||||
})
|
||||
}
|
||||
}
|
||||
return configs
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) SetScreenSize(screenSize types.ScreenSize) (types.ScreenSize, error) {
|
||||
mu.Lock()
|
||||
manager.emmiter.Emit("before_screen_size_change")
|
||||
|
||||
defer func() {
|
||||
manager.emmiter.Emit("after_screen_size_change")
|
||||
mu.Unlock()
|
||||
}()
|
||||
|
||||
screenSize, err := xorg.ChangeScreenSize(screenSize)
|
||||
if err == nil {
|
||||
// cache the new screen size
|
||||
manager.screenSize = screenSize
|
||||
}
|
||||
|
||||
return screenSize, err
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) GetScreenSize() types.ScreenSize {
|
||||
return xorg.GetScreenSize()
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) SetKeyboardMap(kbd types.KeyboardMap) error {
|
||||
// TOOD: Use native API.
|
||||
cmd := exec.Command("setxkbmap", "-layout", kbd.Layout, "-variant", kbd.Variant)
|
||||
_, err := cmd.Output()
|
||||
return err
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) GetKeyboardMap() (*types.KeyboardMap, error) {
|
||||
// TOOD: Use native API.
|
||||
cmd := exec.Command("setxkbmap", "-query")
|
||||
res, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
kbd := types.KeyboardMap{}
|
||||
|
||||
re := regexp.MustCompile(`layout:\s+(.*)\n`)
|
||||
arr := re.FindStringSubmatch(string(res))
|
||||
if len(arr) > 1 {
|
||||
kbd.Layout = arr[1]
|
||||
}
|
||||
|
||||
re = regexp.MustCompile(`variant:\s+(.*)\n`)
|
||||
arr = re.FindStringSubmatch(string(res))
|
||||
if len(arr) > 1 {
|
||||
kbd.Variant = arr[1]
|
||||
}
|
||||
|
||||
return &kbd, nil
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) SetKeyboardModifiers(mod types.KeyboardModifiers) {
|
||||
if mod.Shift != nil {
|
||||
xorg.SetKeyboardModifier(xorg.KbdModShift, *mod.Shift)
|
||||
}
|
||||
|
||||
if mod.CapsLock != nil {
|
||||
xorg.SetKeyboardModifier(xorg.KbdModCapsLock, *mod.CapsLock)
|
||||
}
|
||||
|
||||
if mod.Control != nil {
|
||||
xorg.SetKeyboardModifier(xorg.KbdModControl, *mod.Control)
|
||||
}
|
||||
|
||||
if mod.Alt != nil {
|
||||
xorg.SetKeyboardModifier(xorg.KbdModAlt, *mod.Alt)
|
||||
}
|
||||
|
||||
if mod.NumLock != nil {
|
||||
xorg.SetKeyboardModifier(xorg.KbdModNumLock, *mod.NumLock)
|
||||
}
|
||||
|
||||
if mod.Meta != nil {
|
||||
xorg.SetKeyboardModifier(xorg.KbdModMeta, *mod.Meta)
|
||||
}
|
||||
|
||||
if mod.Super != nil {
|
||||
xorg.SetKeyboardModifier(xorg.KbdModSuper, *mod.Super)
|
||||
}
|
||||
|
||||
if mod.AltGr != nil {
|
||||
xorg.SetKeyboardModifier(xorg.KbdModAltGr, *mod.AltGr)
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) GetKeyboardModifiers() types.KeyboardModifiers {
|
||||
modifiers := xorg.GetKeyboardModifiers()
|
||||
|
||||
isset := func(mod xorg.KbdMod) *bool {
|
||||
x := modifiers&mod != 0
|
||||
return &x
|
||||
}
|
||||
|
||||
return types.KeyboardModifiers{
|
||||
Shift: isset(xorg.KbdModShift),
|
||||
CapsLock: isset(xorg.KbdModCapsLock),
|
||||
Control: isset(xorg.KbdModControl),
|
||||
Alt: isset(xorg.KbdModAlt),
|
||||
NumLock: isset(xorg.KbdModNumLock),
|
||||
Meta: isset(xorg.KbdModMeta),
|
||||
Super: isset(xorg.KbdModSuper),
|
||||
AltGr: isset(xorg.KbdModAltGr),
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) GetCursorImage() *types.CursorImage {
|
||||
return xorg.GetCursorImage()
|
||||
}
|
||||
|
||||
func (manager *DesktopManagerCtx) GetScreenshotImage() *image.RGBA {
|
||||
return xorg.GetScreenshotImage()
|
||||
}
|
123
server/internal/http/batch.go
Normal file
123
server/internal/http/batch.go
Normal file
@ -0,0 +1,123 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type BatchRequest struct {
|
||||
Path string `json:"path"`
|
||||
Method string `json:"method"`
|
||||
Body json.RawMessage `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
type BatchResponse struct {
|
||||
Path string `json:"path"`
|
||||
Method string `json:"method"`
|
||||
Body json.RawMessage `json:"body,omitempty"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
func (b *BatchResponse) Error(httpErr *utils.HTTPError) (err error) {
|
||||
b.Body, err = json.Marshal(httpErr)
|
||||
b.Status = httpErr.Code
|
||||
return
|
||||
}
|
||||
|
||||
type batchHandler struct {
|
||||
Router types.Router
|
||||
PathPrefix string
|
||||
Excluded []string
|
||||
}
|
||||
|
||||
func (b *batchHandler) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
var requests []BatchRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&requests); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
responses := make([]BatchResponse, len(requests))
|
||||
for i, request := range requests {
|
||||
res := BatchResponse{
|
||||
Path: request.Path,
|
||||
Method: request.Method,
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(request.Path, b.PathPrefix) {
|
||||
res.Error(utils.HttpBadRequest("this path is not allowed in batch requests"))
|
||||
responses[i] = res
|
||||
continue
|
||||
}
|
||||
|
||||
if exists, _ := utils.ArrayIn(request.Path, b.Excluded); exists {
|
||||
res.Error(utils.HttpBadRequest("this path is excluded from batch requests"))
|
||||
responses[i] = res
|
||||
continue
|
||||
}
|
||||
|
||||
// prepare request
|
||||
req, err := http.NewRequest(request.Method, request.Path, bytes.NewBuffer(request.Body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// copy headers
|
||||
for k, vv := range r.Header {
|
||||
for _, v := range vv {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// execute request
|
||||
rr := newResponseRecorder()
|
||||
b.Router.ServeHTTP(rr, req)
|
||||
|
||||
// read response
|
||||
body, err := io.ReadAll(rr.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// write response
|
||||
responses[i] = BatchResponse{
|
||||
Path: request.Path,
|
||||
Method: request.Method,
|
||||
Body: body,
|
||||
Status: rr.Code,
|
||||
}
|
||||
}
|
||||
|
||||
return utils.HttpSuccess(w, responses)
|
||||
}
|
||||
|
||||
type responseRecorder struct {
|
||||
Code int
|
||||
HeaderMap http.Header
|
||||
Body *bytes.Buffer
|
||||
}
|
||||
|
||||
func newResponseRecorder() *responseRecorder {
|
||||
return &responseRecorder{
|
||||
Code: http.StatusOK,
|
||||
HeaderMap: make(http.Header),
|
||||
Body: new(bytes.Buffer),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseRecorder) Header() http.Header {
|
||||
return w.HeaderMap
|
||||
}
|
||||
|
||||
func (w *responseRecorder) Write(b []byte) (int, error) {
|
||||
return w.Body.Write(b)
|
||||
}
|
||||
|
||||
func (w *responseRecorder) WriteHeader(code int) {
|
||||
w.Code = code
|
||||
}
|
36
server/internal/http/debug.go
Normal file
36
server/internal/http/debug.go
Normal file
@ -0,0 +1,36 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/pprof"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
)
|
||||
|
||||
func pprofHandler(r types.Router) {
|
||||
r.Get("/debug/pprof/", func(w http.ResponseWriter, r *http.Request) error {
|
||||
pprof.Index(w, r)
|
||||
return nil
|
||||
})
|
||||
|
||||
r.Get("/debug/pprof/{action}", func(w http.ResponseWriter, r *http.Request) error {
|
||||
action := chi.URLParam(r, "action")
|
||||
|
||||
switch action {
|
||||
case "cmdline":
|
||||
pprof.Cmdline(w, r)
|
||||
case "profile":
|
||||
pprof.Profile(w, r)
|
||||
case "symbol":
|
||||
pprof.Symbol(w, r)
|
||||
case "trace":
|
||||
pprof.Trace(w, r)
|
||||
default:
|
||||
pprof.Handler(action).ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
135
server/internal/http/logger.go
Normal file
135
server/internal/http/logger.go
Normal file
@ -0,0 +1,135 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/middleware"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type logFormatter struct {
|
||||
logger zerolog.Logger
|
||||
}
|
||||
|
||||
func (l *logFormatter) NewLogEntry(r *http.Request) middleware.LogEntry {
|
||||
// exclude health & metrics from logs
|
||||
if r.RequestURI == "/health" || r.RequestURI == "/metrics" {
|
||||
return &nulllog{}
|
||||
}
|
||||
|
||||
req := map[string]any{}
|
||||
|
||||
if reqID := middleware.GetReqID(r.Context()); reqID != "" {
|
||||
req["id"] = reqID
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
req["scheme"] = scheme
|
||||
req["proto"] = r.Proto
|
||||
req["method"] = r.Method
|
||||
req["remote"] = r.RemoteAddr
|
||||
req["agent"] = r.UserAgent()
|
||||
req["uri"] = fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI)
|
||||
|
||||
return &logEntry{
|
||||
logger: l.logger.With().Interface("req", req).Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
type logEntry struct {
|
||||
logger zerolog.Logger
|
||||
err error
|
||||
panic *logPanic
|
||||
session types.Session
|
||||
}
|
||||
|
||||
type logPanic struct {
|
||||
message string
|
||||
stack string
|
||||
}
|
||||
|
||||
func (e *logEntry) Panic(v any, stack []byte) {
|
||||
e.panic = &logPanic{
|
||||
message: fmt.Sprintf("%+v", v),
|
||||
stack: string(stack),
|
||||
}
|
||||
}
|
||||
|
||||
func (e *logEntry) Error(err error) {
|
||||
e.err = err
|
||||
}
|
||||
|
||||
func (e *logEntry) SetSession(session types.Session) {
|
||||
e.session = session
|
||||
}
|
||||
|
||||
func (e *logEntry) Write(status, bytes int, header http.Header, elapsed time.Duration, extra any) {
|
||||
res := map[string]any{}
|
||||
res["time"] = time.Now().UTC().Format(time.RFC1123)
|
||||
res["status"] = status
|
||||
res["bytes"] = bytes
|
||||
res["elapsed"] = float64(elapsed.Nanoseconds()) / 1000000.0
|
||||
|
||||
logger := e.logger.With().Interface("res", res).Logger()
|
||||
|
||||
// add session ID to logs (if exists)
|
||||
if e.session != nil {
|
||||
logger = logger.With().Str("session_id", e.session.ID()).Logger()
|
||||
}
|
||||
|
||||
// handle panic error message
|
||||
if e.panic != nil {
|
||||
logger.WithLevel(zerolog.PanicLevel).
|
||||
Err(e.err).
|
||||
Str("stack", e.panic.stack).
|
||||
Msgf("request failed (%d): %s", status, e.panic.message)
|
||||
return
|
||||
}
|
||||
|
||||
// handle panic error message
|
||||
if e.err != nil {
|
||||
httpErr, ok := e.err.(*utils.HTTPError)
|
||||
if !ok {
|
||||
logger.Err(e.err).Msgf("request failed (%d)", status)
|
||||
return
|
||||
}
|
||||
|
||||
if httpErr.Message == "" {
|
||||
httpErr.Message = http.StatusText(httpErr.Code)
|
||||
}
|
||||
|
||||
var logLevel zerolog.Level
|
||||
if httpErr.Code < 500 {
|
||||
logLevel = zerolog.WarnLevel
|
||||
} else {
|
||||
logLevel = zerolog.ErrorLevel
|
||||
}
|
||||
|
||||
message := httpErr.Message
|
||||
if httpErr.InternalMsg != "" {
|
||||
message = httpErr.InternalMsg
|
||||
}
|
||||
|
||||
logger.WithLevel(logLevel).Err(httpErr.InternalErr).Msgf("request failed (%d): %s", status, message)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug().Msgf("request complete (%d)", status)
|
||||
}
|
||||
|
||||
type nulllog struct{}
|
||||
|
||||
func (e *nulllog) Panic(v any, stack []byte) {}
|
||||
func (e *nulllog) Error(err error) {}
|
||||
func (e *nulllog) SetSession(session types.Session) {}
|
||||
func (e *nulllog) Write(status, bytes int, header http.Header, elapsed time.Duration, extra any) {
|
||||
}
|
132
server/internal/http/manager.go
Normal file
132
server/internal/http/manager.go
Normal file
@ -0,0 +1,132 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/internal/config"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
)
|
||||
|
||||
type HttpManagerCtx struct {
|
||||
logger zerolog.Logger
|
||||
config *config.Server
|
||||
router types.Router
|
||||
http *http.Server
|
||||
}
|
||||
|
||||
func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, config *config.Server) *HttpManagerCtx {
|
||||
logger := log.With().Str("module", "http").Logger()
|
||||
|
||||
opts := []RouterOption{
|
||||
WithRequestID(), // create a request id for each request
|
||||
}
|
||||
|
||||
// use real ip if behind proxy
|
||||
// before logger so it can log the real ip
|
||||
if config.Proxy {
|
||||
opts = append(opts, WithRealIP())
|
||||
}
|
||||
|
||||
opts = append(opts,
|
||||
WithLogger(logger),
|
||||
WithRecoverer(), // recover from panics without crashing server
|
||||
)
|
||||
|
||||
if config.HasCors() {
|
||||
opts = append(opts, WithCORS(config.AllowOrigin))
|
||||
}
|
||||
|
||||
if config.PathPrefix != "/" {
|
||||
opts = append(opts, WithPathPrefix(config.PathPrefix))
|
||||
}
|
||||
|
||||
router := newRouter(opts...)
|
||||
|
||||
router.Route("/api", ApiManager.Route)
|
||||
|
||||
router.Get("/api/ws", WebSocketManager.Upgrade(func(r *http.Request) bool {
|
||||
return config.AllowOrigin(r.Header.Get("Origin"))
|
||||
}))
|
||||
|
||||
batch := batchHandler{
|
||||
Router: router,
|
||||
PathPrefix: "/api",
|
||||
Excluded: []string{
|
||||
"/api/batch", // do not allow batchception
|
||||
"/api/ws",
|
||||
},
|
||||
}
|
||||
router.Post("/api/batch", batch.Handle)
|
||||
|
||||
router.Get("/health", func(w http.ResponseWriter, r *http.Request) error {
|
||||
_, err := w.Write([]byte("true"))
|
||||
return err
|
||||
})
|
||||
|
||||
if config.Metrics {
|
||||
router.Get("/metrics", func(w http.ResponseWriter, r *http.Request) error {
|
||||
promhttp.Handler().ServeHTTP(w, r)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if config.Static != "" {
|
||||
fs := http.FileServer(http.Dir(config.Static))
|
||||
router.Get("/*", func(w http.ResponseWriter, r *http.Request) error {
|
||||
_, err := os.Stat(config.Static + r.URL.Path)
|
||||
if err == nil {
|
||||
fs.ServeHTTP(w, r)
|
||||
return nil
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
http.NotFound(w, r)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
if config.PProf {
|
||||
pprofHandler(router)
|
||||
}
|
||||
|
||||
return &HttpManagerCtx{
|
||||
logger: logger,
|
||||
config: config,
|
||||
router: router,
|
||||
http: &http.Server{
|
||||
Addr: config.Bind,
|
||||
Handler: router,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *HttpManagerCtx) Start() {
|
||||
if manager.config.Cert != "" && manager.config.Key != "" {
|
||||
go func() {
|
||||
if err := manager.http.ListenAndServeTLS(manager.config.Cert, manager.config.Key); err != http.ErrServerClosed {
|
||||
manager.logger.Panic().Err(err).Msg("unable to start https server")
|
||||
}
|
||||
}()
|
||||
manager.logger.Info().Msgf("https listening on %s", manager.http.Addr)
|
||||
} else {
|
||||
go func() {
|
||||
if err := manager.http.ListenAndServe(); err != http.ErrServerClosed {
|
||||
manager.logger.Panic().Err(err).Msg("unable to start http server")
|
||||
}
|
||||
}()
|
||||
manager.logger.Info().Msgf("http listening on %s", manager.http.Addr)
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *HttpManagerCtx) Shutdown() error {
|
||||
manager.logger.Info().Msg("shutdown")
|
||||
|
||||
return manager.http.Shutdown(context.Background())
|
||||
}
|
172
server/internal/http/router.go
Normal file
172
server/internal/http/router.go
Normal file
@ -0,0 +1,172 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/chi/middleware"
|
||||
"github.com/go-chi/cors"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/demodesk/neko/pkg/auth"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type RouterOption func(*router)
|
||||
|
||||
func WithRequestID() RouterOption {
|
||||
return func(r *router) {
|
||||
r.chi.Use(middleware.RequestID)
|
||||
}
|
||||
}
|
||||
|
||||
func WithLogger(logger zerolog.Logger) RouterOption {
|
||||
return func(r *router) {
|
||||
r.chi.Use(middleware.RequestLogger(&logFormatter{logger}))
|
||||
}
|
||||
}
|
||||
|
||||
func WithRecoverer() RouterOption {
|
||||
return func(r *router) {
|
||||
r.chi.Use(middleware.Recoverer)
|
||||
}
|
||||
}
|
||||
|
||||
func WithCORS(allowOrigin func(origin string) bool) RouterOption {
|
||||
return func(r *router) {
|
||||
r.chi.Use(cors.Handler(cors.Options{
|
||||
AllowOriginFunc: func(r *http.Request, origin string) bool {
|
||||
return allowOrigin(origin)
|
||||
},
|
||||
AllowedMethods: []string{"GET", "POST", "DELETE", "OPTIONS"},
|
||||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
|
||||
ExposedHeaders: []string{"Link"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 300, // Maximum value not ignored by any of major browsers
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
func WithPathPrefix(prefix string) RouterOption {
|
||||
return func(r *router) {
|
||||
r.chi.Use(func(h http.Handler) http.Handler {
|
||||
return http.StripPrefix(prefix, h)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func WithRealIP() RouterOption {
|
||||
return func(r *router) {
|
||||
r.chi.Use(middleware.RealIP)
|
||||
}
|
||||
}
|
||||
|
||||
type router struct {
|
||||
chi chi.Router
|
||||
}
|
||||
|
||||
func newRouter(opts ...RouterOption) types.Router {
|
||||
r := &router{chi.NewRouter()}
|
||||
for _, opt := range opts {
|
||||
opt(r)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *router) Group(fn func(types.Router)) {
|
||||
r.chi.Group(func(c chi.Router) {
|
||||
fn(&router{c})
|
||||
})
|
||||
}
|
||||
|
||||
func (r *router) Route(pattern string, fn func(types.Router)) {
|
||||
r.chi.Route(pattern, func(c chi.Router) {
|
||||
fn(&router{c})
|
||||
})
|
||||
}
|
||||
|
||||
func (r *router) Get(pattern string, fn types.RouterHandler) {
|
||||
r.chi.Get(pattern, routeHandler(fn))
|
||||
}
|
||||
|
||||
func (r *router) Post(pattern string, fn types.RouterHandler) {
|
||||
r.chi.Post(pattern, routeHandler(fn))
|
||||
}
|
||||
|
||||
func (r *router) Put(pattern string, fn types.RouterHandler) {
|
||||
r.chi.Put(pattern, routeHandler(fn))
|
||||
}
|
||||
|
||||
func (r *router) Patch(pattern string, fn types.RouterHandler) {
|
||||
r.chi.Patch(pattern, routeHandler(fn))
|
||||
}
|
||||
|
||||
func (r *router) Delete(pattern string, fn types.RouterHandler) {
|
||||
r.chi.Delete(pattern, routeHandler(fn))
|
||||
}
|
||||
|
||||
func (r *router) With(fn types.MiddlewareHandler) types.Router {
|
||||
c := r.chi.With(middlewareHandler(fn))
|
||||
return &router{c}
|
||||
}
|
||||
|
||||
func (r *router) Use(fn types.MiddlewareHandler) {
|
||||
r.chi.Use(middlewareHandler(fn))
|
||||
}
|
||||
|
||||
func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
r.chi.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
func errorHandler(err error, w http.ResponseWriter, r *http.Request) {
|
||||
httpErr, ok := err.(*utils.HTTPError)
|
||||
if !ok {
|
||||
httpErr = utils.HttpInternalServerError().WithInternalErr(err)
|
||||
}
|
||||
|
||||
utils.HttpJsonResponse(w, httpErr.Code, httpErr)
|
||||
}
|
||||
|
||||
func routeHandler(fn types.RouterHandler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// get custom log entry pointer from context
|
||||
logEntry, _ := r.Context().Value(middleware.LogEntryCtxKey).(*logEntry)
|
||||
|
||||
if err := fn(w, r); err != nil {
|
||||
logEntry.Error(err)
|
||||
errorHandler(err, w, r)
|
||||
}
|
||||
|
||||
// set session if exits
|
||||
if session, ok := auth.GetSession(r); ok {
|
||||
logEntry.SetSession(session)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func middlewareHandler(fn types.MiddlewareHandler) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// get custom log entry pointer from context
|
||||
logEntry, _ := r.Context().Value(middleware.LogEntryCtxKey).(*logEntry)
|
||||
|
||||
ctx, err := fn(w, r)
|
||||
if err != nil {
|
||||
logEntry.Error(err)
|
||||
errorHandler(err, w, r)
|
||||
|
||||
// set session if exits
|
||||
if session, ok := auth.GetSession(r); ok {
|
||||
logEntry.SetSession(session)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
if ctx != nil {
|
||||
r = r.WithContext(ctx)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
204
server/internal/member/file/provider.go
Normal file
204
server/internal/member/file/provider.go
Normal file
@ -0,0 +1,204 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
)
|
||||
|
||||
func New(config Config) types.MemberProvider {
|
||||
return &MemberProviderCtx{
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
type MemberProviderCtx struct {
|
||||
config Config
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) hash(password string) string {
|
||||
// if hash is disabled, return password as plain text
|
||||
if !provider.config.Hash {
|
||||
return password
|
||||
}
|
||||
|
||||
sha256 := sha256.New()
|
||||
sha256.Write([]byte(password))
|
||||
hashedPassword := sha256.Sum(nil)
|
||||
return base64.StdEncoding.EncodeToString(hashedPassword)
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Connect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Disconnect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Authenticate(username string, password string) (string, types.MemberProfile, error) {
|
||||
// id will be also username
|
||||
id := username
|
||||
|
||||
entry, err := provider.getEntry(id)
|
||||
if err != nil {
|
||||
return "", types.MemberProfile{}, err
|
||||
}
|
||||
|
||||
if entry.Password != provider.hash(password) {
|
||||
return "", types.MemberProfile{}, types.ErrMemberInvalidPassword
|
||||
}
|
||||
|
||||
return id, entry.Profile, nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Insert(username string, password string, profile types.MemberProfile) (string, error) {
|
||||
// id will be also username
|
||||
id := username
|
||||
|
||||
entries, err := provider.deserialize()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
_, ok := entries[id]
|
||||
if ok {
|
||||
return "", types.ErrMemberAlreadyExists
|
||||
}
|
||||
|
||||
entries[id] = MemberEntry{
|
||||
Password: provider.hash(password),
|
||||
Profile: profile,
|
||||
}
|
||||
|
||||
return id, provider.serialize(entries)
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) UpdateProfile(id string, profile types.MemberProfile) error {
|
||||
entries, err := provider.deserialize()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
entry, ok := entries[id]
|
||||
if !ok {
|
||||
return types.ErrMemberDoesNotExist
|
||||
}
|
||||
|
||||
entry.Profile = profile
|
||||
entries[id] = entry
|
||||
|
||||
return provider.serialize(entries)
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) UpdatePassword(id string, password string) error {
|
||||
entries, err := provider.deserialize()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
entry, ok := entries[id]
|
||||
if !ok {
|
||||
return types.ErrMemberDoesNotExist
|
||||
}
|
||||
|
||||
entry.Password = provider.hash(password)
|
||||
entries[id] = entry
|
||||
|
||||
return provider.serialize(entries)
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Select(id string) (types.MemberProfile, error) {
|
||||
entry, err := provider.getEntry(id)
|
||||
if err != nil {
|
||||
return types.MemberProfile{}, err
|
||||
}
|
||||
|
||||
return entry.Profile, nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) SelectAll(limit int, offset int) (map[string]types.MemberProfile, error) {
|
||||
profiles := map[string]types.MemberProfile{}
|
||||
|
||||
entries, err := provider.deserialize()
|
||||
if err != nil {
|
||||
return profiles, err
|
||||
}
|
||||
|
||||
i := 0
|
||||
for id, entry := range entries {
|
||||
if i >= offset && (limit == 0 || i < offset+limit) {
|
||||
profiles[id] = entry.Profile
|
||||
}
|
||||
|
||||
i = i + 1
|
||||
}
|
||||
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Delete(id string) error {
|
||||
entries, err := provider.deserialize()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, ok := entries[id]
|
||||
if !ok {
|
||||
return types.ErrMemberDoesNotExist
|
||||
}
|
||||
|
||||
delete(entries, id)
|
||||
|
||||
return provider.serialize(entries)
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) deserialize() (map[string]MemberEntry, error) {
|
||||
file, err := os.OpenFile(provider.config.Path, os.O_RDONLY|os.O_CREATE, os.ModePerm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
raw, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(raw) == 0 {
|
||||
return map[string]MemberEntry{}, nil
|
||||
}
|
||||
|
||||
var entries map[string]MemberEntry
|
||||
if err := json.Unmarshal([]byte(raw), &entries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) getEntry(id string) (MemberEntry, error) {
|
||||
entries, err := provider.deserialize()
|
||||
if err != nil {
|
||||
return MemberEntry{}, err
|
||||
}
|
||||
|
||||
entry, ok := entries[id]
|
||||
if !ok {
|
||||
return MemberEntry{}, types.ErrMemberDoesNotExist
|
||||
}
|
||||
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) serialize(data map[string]MemberEntry) error {
|
||||
raw, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(provider.config.Path, raw, os.ModePerm)
|
||||
}
|
48
server/internal/member/file/provider_test.go
Normal file
48
server/internal/member/file/provider_test.go
Normal file
@ -0,0 +1,48 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
// Ensure that hashes are the same after encoding and decoding using json
|
||||
func TestMemberProviderCtx_hash(t *testing.T) {
|
||||
provider := &MemberProviderCtx{
|
||||
config: Config{
|
||||
Hash: true,
|
||||
},
|
||||
}
|
||||
|
||||
// generate random strings
|
||||
passwords := []string{}
|
||||
for i := 0; i < 10; i++ {
|
||||
password, err := utils.NewUID(32)
|
||||
if err != nil {
|
||||
t.Errorf("utils.NewUID() returned error: %s", err)
|
||||
}
|
||||
passwords = append(passwords, password)
|
||||
}
|
||||
|
||||
for _, password := range passwords {
|
||||
hashedPassword := provider.hash(password)
|
||||
|
||||
// json encode password hash
|
||||
hashedPasswordJSON, err := json.Marshal(hashedPassword)
|
||||
if err != nil {
|
||||
t.Errorf("json.Marshal() returned error: %s", err)
|
||||
}
|
||||
|
||||
// json decode password hash json
|
||||
var hashedPasswordStr string
|
||||
err = json.Unmarshal(hashedPasswordJSON, &hashedPasswordStr)
|
||||
if err != nil {
|
||||
t.Errorf("json.Unmarshal() returned error: %s", err)
|
||||
}
|
||||
|
||||
if hashedPasswordStr != hashedPassword {
|
||||
t.Errorf("hashedPasswordStr: %s != hashedPassword: %s", hashedPasswordStr, hashedPassword)
|
||||
}
|
||||
}
|
||||
}
|
15
server/internal/member/file/types.go
Normal file
15
server/internal/member/file/types.go
Normal file
@ -0,0 +1,15 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
)
|
||||
|
||||
type MemberEntry struct {
|
||||
Password string `json:"password"`
|
||||
Profile types.MemberProfile `json:"profile"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Path string
|
||||
Hash bool
|
||||
}
|
168
server/internal/member/manager.go
Normal file
168
server/internal/member/manager.go
Normal file
@ -0,0 +1,168 @@
|
||||
package member
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/internal/config"
|
||||
"github.com/demodesk/neko/internal/member/file"
|
||||
"github.com/demodesk/neko/internal/member/multiuser"
|
||||
"github.com/demodesk/neko/internal/member/noauth"
|
||||
"github.com/demodesk/neko/internal/member/object"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
)
|
||||
|
||||
func New(sessions types.SessionManager, config *config.Member) *MemberManagerCtx {
|
||||
manager := &MemberManagerCtx{
|
||||
logger: log.With().Str("module", "member").Logger(),
|
||||
sessions: sessions,
|
||||
config: config,
|
||||
}
|
||||
|
||||
switch config.Provider {
|
||||
case "file":
|
||||
manager.provider = file.New(config.File)
|
||||
case "object":
|
||||
manager.provider = object.New(config.Object)
|
||||
case "multiuser":
|
||||
manager.provider = multiuser.New(config.Multiuser)
|
||||
case "noauth":
|
||||
fallthrough
|
||||
default:
|
||||
manager.provider = noauth.New()
|
||||
}
|
||||
|
||||
return manager
|
||||
}
|
||||
|
||||
type MemberManagerCtx struct {
|
||||
logger zerolog.Logger
|
||||
sessions types.SessionManager
|
||||
config *config.Member
|
||||
providerMu sync.Mutex
|
||||
provider types.MemberProvider
|
||||
loginMu sync.Mutex
|
||||
}
|
||||
|
||||
func (manager *MemberManagerCtx) Connect() error {
|
||||
manager.providerMu.Lock()
|
||||
defer manager.providerMu.Unlock()
|
||||
|
||||
return manager.provider.Connect()
|
||||
}
|
||||
|
||||
func (manager *MemberManagerCtx) Disconnect() error {
|
||||
manager.providerMu.Lock()
|
||||
defer manager.providerMu.Unlock()
|
||||
|
||||
return manager.provider.Disconnect()
|
||||
}
|
||||
|
||||
func (manager *MemberManagerCtx) Authenticate(username string, password string) (string, types.MemberProfile, error) {
|
||||
manager.providerMu.Lock()
|
||||
defer manager.providerMu.Unlock()
|
||||
|
||||
return manager.provider.Authenticate(username, password)
|
||||
}
|
||||
|
||||
func (manager *MemberManagerCtx) Insert(username string, password string, profile types.MemberProfile) (string, error) {
|
||||
manager.providerMu.Lock()
|
||||
defer manager.providerMu.Unlock()
|
||||
|
||||
return manager.provider.Insert(username, password, profile)
|
||||
}
|
||||
|
||||
func (manager *MemberManagerCtx) Select(id string) (types.MemberProfile, error) {
|
||||
manager.providerMu.Lock()
|
||||
defer manager.providerMu.Unlock()
|
||||
|
||||
// get primarily from corresponding session, if exists
|
||||
session, ok := manager.sessions.Get(id)
|
||||
if ok {
|
||||
return session.Profile(), nil
|
||||
}
|
||||
|
||||
return manager.provider.Select(id)
|
||||
}
|
||||
|
||||
func (manager *MemberManagerCtx) SelectAll(limit int, offset int) (map[string]types.MemberProfile, error) {
|
||||
manager.providerMu.Lock()
|
||||
defer manager.providerMu.Unlock()
|
||||
|
||||
return manager.provider.SelectAll(limit, offset)
|
||||
}
|
||||
|
||||
func (manager *MemberManagerCtx) UpdateProfile(id string, profile types.MemberProfile) error {
|
||||
manager.providerMu.Lock()
|
||||
defer manager.providerMu.Unlock()
|
||||
|
||||
// update corresponding session, if exists
|
||||
err := manager.sessions.Update(id, profile)
|
||||
if err != nil && !errors.Is(err, types.ErrSessionNotFound) {
|
||||
manager.logger.Err(err).Msg("error while updating session")
|
||||
}
|
||||
|
||||
return manager.provider.UpdateProfile(id, profile)
|
||||
}
|
||||
|
||||
func (manager *MemberManagerCtx) UpdatePassword(id string, password string) error {
|
||||
manager.providerMu.Lock()
|
||||
defer manager.providerMu.Unlock()
|
||||
|
||||
return manager.provider.UpdatePassword(id, password)
|
||||
}
|
||||
|
||||
func (manager *MemberManagerCtx) Delete(id string) error {
|
||||
manager.providerMu.Lock()
|
||||
defer manager.providerMu.Unlock()
|
||||
|
||||
// destroy corresponding session, if exists
|
||||
err := manager.sessions.Delete(id)
|
||||
if err != nil && !errors.Is(err, types.ErrSessionNotFound) {
|
||||
manager.logger.Err(err).Msg("error while deleting session")
|
||||
}
|
||||
|
||||
return manager.provider.Delete(id)
|
||||
}
|
||||
|
||||
//
|
||||
// member -> session
|
||||
//
|
||||
|
||||
func (manager *MemberManagerCtx) Login(username string, password string) (types.Session, string, error) {
|
||||
manager.loginMu.Lock()
|
||||
defer manager.loginMu.Unlock()
|
||||
|
||||
id, profile, err := manager.provider.Authenticate(username, password)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
if !profile.IsAdmin && manager.sessions.Settings().LockedLogins {
|
||||
return nil, "", types.ErrSessionLoginsLocked
|
||||
}
|
||||
|
||||
session, ok := manager.sessions.Get(id)
|
||||
if ok {
|
||||
if session.State().IsConnected {
|
||||
return nil, "", types.ErrSessionAlreadyConnected
|
||||
}
|
||||
|
||||
// TODO: Replace session.
|
||||
if err := manager.sessions.Delete(id); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
}
|
||||
|
||||
return manager.sessions.Create(id, profile)
|
||||
}
|
||||
|
||||
func (manager *MemberManagerCtx) Logout(id string) error {
|
||||
manager.loginMu.Lock()
|
||||
defer manager.loginMu.Unlock()
|
||||
|
||||
return manager.sessions.Delete(id)
|
||||
}
|
82
server/internal/member/multiuser/provider.go
Normal file
82
server/internal/member/multiuser/provider.go
Normal file
@ -0,0 +1,82 @@
|
||||
package multiuser
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
func New(config Config) types.MemberProvider {
|
||||
return &MemberProviderCtx{
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
type MemberProviderCtx struct {
|
||||
config Config
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Connect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Disconnect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Authenticate(username string, password string) (string, types.MemberProfile, error) {
|
||||
// generate random token
|
||||
token, err := utils.NewUID(5)
|
||||
if err != nil {
|
||||
return "", types.MemberProfile{}, err
|
||||
}
|
||||
|
||||
// id is username with token
|
||||
id := fmt.Sprintf("%s-%s", username, token)
|
||||
|
||||
// if logged in as administrator
|
||||
if provider.config.AdminPassword == password {
|
||||
profile := provider.config.AdminProfile
|
||||
if profile.Name == "" {
|
||||
profile.Name = username
|
||||
}
|
||||
return id, profile, nil
|
||||
}
|
||||
|
||||
// if logged in as user
|
||||
if provider.config.UserPassword == password {
|
||||
profile := provider.config.UserProfile
|
||||
if profile.Name == "" {
|
||||
profile.Name = username
|
||||
}
|
||||
return id, profile, nil
|
||||
}
|
||||
|
||||
return "", types.MemberProfile{}, types.ErrMemberInvalidPassword
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Insert(username string, password string, profile types.MemberProfile) (string, error) {
|
||||
return "", errors.New("new user is created on first login in multiuser mode")
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) UpdateProfile(id string, profile types.MemberProfile) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) UpdatePassword(id string, password string) error {
|
||||
return errors.New("password can only be modified in config while in multiuser mode")
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Select(id string) (types.MemberProfile, error) {
|
||||
return types.MemberProfile{}, errors.New("cannot select user in multiuser mode")
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) SelectAll(limit int, offset int) (map[string]types.MemberProfile, error) {
|
||||
return map[string]types.MemberProfile{}, nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Delete(id string) error {
|
||||
return errors.New("cannot delete user in multiuser mode")
|
||||
}
|
10
server/internal/member/multiuser/types.go
Normal file
10
server/internal/member/multiuser/types.go
Normal file
@ -0,0 +1,10 @@
|
||||
package multiuser
|
||||
|
||||
import "github.com/demodesk/neko/pkg/types"
|
||||
|
||||
type Config struct {
|
||||
AdminPassword string
|
||||
UserPassword string
|
||||
AdminProfile types.MemberProfile
|
||||
UserProfile types.MemberProfile
|
||||
}
|
75
server/internal/member/noauth/provider.go
Normal file
75
server/internal/member/noauth/provider.go
Normal file
@ -0,0 +1,75 @@
|
||||
package noauth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
func New() types.MemberProvider {
|
||||
return &MemberProviderCtx{
|
||||
profile: types.MemberProfile{
|
||||
IsAdmin: true,
|
||||
CanLogin: true,
|
||||
CanConnect: true,
|
||||
CanWatch: true,
|
||||
CanHost: true,
|
||||
CanShareMedia: true,
|
||||
CanAccessClipboard: true,
|
||||
SendsInactiveCursor: true,
|
||||
CanSeeInactiveCursors: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type MemberProviderCtx struct {
|
||||
profile types.MemberProfile
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Connect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Disconnect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Authenticate(username string, password string) (string, types.MemberProfile, error) {
|
||||
// generate random token
|
||||
token, err := utils.NewUID(5)
|
||||
if err != nil {
|
||||
return "", types.MemberProfile{}, err
|
||||
}
|
||||
|
||||
// id is username with token
|
||||
id := fmt.Sprintf("%s-%s", username, token)
|
||||
|
||||
provider.profile.Name = username
|
||||
return id, provider.profile, nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Insert(username string, password string, profile types.MemberProfile) (string, error) {
|
||||
return "", errors.New("new user is created on first login in noauth mode")
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) UpdateProfile(id string, profile types.MemberProfile) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) UpdatePassword(id string, password string) error {
|
||||
return errors.New("noauth mode does not have password")
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Select(id string) (types.MemberProfile, error) {
|
||||
return types.MemberProfile{}, errors.New("cannot select user in noauth mode")
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) SelectAll(limit int, offset int) (map[string]types.MemberProfile, error) {
|
||||
return map[string]types.MemberProfile{}, nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Delete(id string) error {
|
||||
return errors.New("cannot delete user in noauth mode")
|
||||
}
|
124
server/internal/member/object/provider.go
Normal file
124
server/internal/member/object/provider.go
Normal file
@ -0,0 +1,124 @@
|
||||
package object
|
||||
|
||||
import (
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
)
|
||||
|
||||
func New(config Config) types.MemberProvider {
|
||||
return &MemberProviderCtx{
|
||||
config: config,
|
||||
entries: make(map[string]*memberEntry),
|
||||
}
|
||||
}
|
||||
|
||||
type MemberProviderCtx struct {
|
||||
config Config
|
||||
entries map[string]*memberEntry
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Connect() error {
|
||||
var err error
|
||||
|
||||
for _, entry := range provider.config.Users {
|
||||
_, err = provider.Insert(entry.Username, entry.Password, entry.Profile)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Disconnect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Authenticate(username string, password string) (string, types.MemberProfile, error) {
|
||||
// id will be also username
|
||||
id := username
|
||||
|
||||
entry, ok := provider.entries[id]
|
||||
if !ok {
|
||||
return "", types.MemberProfile{}, types.ErrMemberDoesNotExist
|
||||
}
|
||||
|
||||
// TODO: Use hash function.
|
||||
if !entry.CheckPassword(password) {
|
||||
return "", types.MemberProfile{}, types.ErrMemberInvalidPassword
|
||||
}
|
||||
|
||||
return id, entry.profile, nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Insert(username string, password string, profile types.MemberProfile) (string, error) {
|
||||
// id will be also username
|
||||
id := username
|
||||
|
||||
_, ok := provider.entries[id]
|
||||
if ok {
|
||||
return "", types.ErrMemberAlreadyExists
|
||||
}
|
||||
|
||||
provider.entries[id] = &memberEntry{
|
||||
// TODO: Use hash function.
|
||||
password: password,
|
||||
profile: profile,
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) UpdateProfile(id string, profile types.MemberProfile) error {
|
||||
entry, ok := provider.entries[id]
|
||||
if !ok {
|
||||
return types.ErrMemberDoesNotExist
|
||||
}
|
||||
|
||||
entry.profile = profile
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) UpdatePassword(id string, password string) error {
|
||||
entry, ok := provider.entries[id]
|
||||
if !ok {
|
||||
return types.ErrMemberDoesNotExist
|
||||
}
|
||||
|
||||
// TODO: Use hash function.
|
||||
entry.password = password
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Select(id string) (types.MemberProfile, error) {
|
||||
entry, ok := provider.entries[id]
|
||||
if !ok {
|
||||
return types.MemberProfile{}, types.ErrMemberDoesNotExist
|
||||
}
|
||||
|
||||
return entry.profile, nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) SelectAll(limit int, offset int) (map[string]types.MemberProfile, error) {
|
||||
profiles := make(map[string]types.MemberProfile)
|
||||
|
||||
i := 0
|
||||
for id, entry := range provider.entries {
|
||||
if i >= offset && (limit == 0 || i < offset+limit) {
|
||||
profiles[id] = entry.profile
|
||||
}
|
||||
|
||||
i = i + 1
|
||||
}
|
||||
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
func (provider *MemberProviderCtx) Delete(id string) error {
|
||||
_, ok := provider.entries[id]
|
||||
if !ok {
|
||||
return types.ErrMemberDoesNotExist
|
||||
}
|
||||
|
||||
delete(provider.entries, id)
|
||||
|
||||
return nil
|
||||
}
|
24
server/internal/member/object/types.go
Normal file
24
server/internal/member/object/types.go
Normal file
@ -0,0 +1,24 @@
|
||||
package object
|
||||
|
||||
import (
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
)
|
||||
|
||||
type memberEntry struct {
|
||||
password string
|
||||
profile types.MemberProfile
|
||||
}
|
||||
|
||||
func (m *memberEntry) CheckPassword(password string) bool {
|
||||
return m.password == password
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Username string
|
||||
Password string
|
||||
Profile types.MemberProfile
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Users []User
|
||||
}
|
23
server/internal/plugins/chat/config.go
Normal file
23
server/internal/plugins/chat/config.go
Normal file
@ -0,0 +1,23 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
func (Config) Init(cmd *cobra.Command) error {
|
||||
cmd.PersistentFlags().Bool("chat.enabled", true, "whether to enable chat plugin")
|
||||
if err := viper.BindPFlag("chat.enabled", cmd.PersistentFlags().Lookup("chat.enabled")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Config) Set() {
|
||||
s.Enabled = viper.GetBool("chat.enabled")
|
||||
}
|
162
server/internal/plugins/chat/manager.go
Normal file
162
server/internal/plugins/chat/manager.go
Normal file
@ -0,0 +1,162 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/pkg/auth"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
func NewManager(
|
||||
sessions types.SessionManager,
|
||||
config *Config,
|
||||
) *Manager {
|
||||
logger := log.With().Str("module", "chat").Logger()
|
||||
|
||||
return &Manager{
|
||||
logger: logger,
|
||||
config: config,
|
||||
sessions: sessions,
|
||||
}
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
logger zerolog.Logger
|
||||
config *Config
|
||||
sessions types.SessionManager
|
||||
}
|
||||
|
||||
type Settings struct {
|
||||
CanSend bool `json:"can_send" mapstructure:"can_send"`
|
||||
CanReceive bool `json:"can_receive" mapstructure:"can_receive"`
|
||||
}
|
||||
|
||||
func (m *Manager) settingsForSession(session types.Session) (Settings, error) {
|
||||
settings := Settings{
|
||||
CanSend: true, // defaults to true
|
||||
CanReceive: true, // defaults to true
|
||||
}
|
||||
err := m.sessions.Settings().Plugins.Unmarshal(PluginName, &settings)
|
||||
if err != nil && !errors.Is(err, types.ErrPluginSettingsNotFound) {
|
||||
return Settings{}, fmt.Errorf("unable to unmarshal %s plugin settings from global settings: %w", PluginName, err)
|
||||
}
|
||||
|
||||
profile := Settings{
|
||||
CanSend: true, // defaults to true
|
||||
CanReceive: true, // defaults to true
|
||||
}
|
||||
|
||||
err = session.Profile().Plugins.Unmarshal(PluginName, &profile)
|
||||
if err != nil && !errors.Is(err, types.ErrPluginSettingsNotFound) {
|
||||
return Settings{}, fmt.Errorf("unable to unmarshal %s plugin settings from profile: %w", PluginName, err)
|
||||
}
|
||||
|
||||
return Settings{
|
||||
CanSend: m.config.Enabled && (settings.CanSend || session.Profile().IsAdmin) && profile.CanSend,
|
||||
CanReceive: m.config.Enabled && (settings.CanReceive || session.Profile().IsAdmin) && profile.CanReceive,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *Manager) sendMessage(session types.Session, content Content) {
|
||||
now := time.Now()
|
||||
|
||||
// get all sessions that have chat enabled
|
||||
var sessions []types.Session
|
||||
m.sessions.Range(func(s types.Session) bool {
|
||||
if settings, err := m.settingsForSession(s); err == nil && settings.CanReceive {
|
||||
sessions = append(sessions, s)
|
||||
}
|
||||
// continue iteration over all sessions
|
||||
return true
|
||||
})
|
||||
|
||||
// send content to all sessions
|
||||
for _, s := range sessions {
|
||||
s.Send(CHAT_MESSAGE, Message{
|
||||
ID: session.ID(),
|
||||
Created: now,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Start() error {
|
||||
// send init message once a user connects
|
||||
m.sessions.OnConnected(func(session types.Session) {
|
||||
session.Send(CHAT_INIT, Init{
|
||||
Enabled: m.config.Enabled,
|
||||
})
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Shutdown() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Route(r types.Router) {
|
||||
r.With(auth.AdminsOnly).Post("/", m.sendMessageHandler)
|
||||
}
|
||||
|
||||
func (m *Manager) WebSocketHandler(session types.Session, msg types.WebSocketMessage) bool {
|
||||
switch msg.Event {
|
||||
case CHAT_MESSAGE:
|
||||
var content Content
|
||||
if err := json.Unmarshal(msg.Payload, &content); err != nil {
|
||||
m.logger.Error().Err(err).Msg("failed to unmarshal chat message")
|
||||
// we processed the message, return true
|
||||
return true
|
||||
}
|
||||
|
||||
settings, err := m.settingsForSession(session)
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Msg("error checking chat permissions for this session")
|
||||
// we processed the message, return true
|
||||
return true
|
||||
}
|
||||
if !settings.CanSend {
|
||||
m.logger.Warn().Msg("not allowed to send chat messages")
|
||||
// we processed the message, return true
|
||||
return true
|
||||
}
|
||||
|
||||
m.sendMessage(session, content)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) sendMessageHandler(w http.ResponseWriter, r *http.Request) error {
|
||||
session, ok := auth.GetSession(r)
|
||||
if !ok {
|
||||
return utils.HttpUnauthorized("session not found")
|
||||
}
|
||||
|
||||
settings, err := m.settingsForSession(session)
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().
|
||||
WithInternalErr(err).
|
||||
Msg("error checking chat permissions for this session")
|
||||
}
|
||||
|
||||
if !settings.CanSend {
|
||||
return utils.HttpForbidden("not allowed to send chat messages")
|
||||
}
|
||||
|
||||
content := Content{}
|
||||
if err := utils.HttpJsonRequest(w, r, &content); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.sendMessage(session, content)
|
||||
return utils.HttpSuccess(w)
|
||||
}
|
35
server/internal/plugins/chat/plugin.go
Normal file
35
server/internal/plugins/chat/plugin.go
Normal file
@ -0,0 +1,35 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
)
|
||||
|
||||
type Plugin struct {
|
||||
config *Config
|
||||
manager *Manager
|
||||
}
|
||||
|
||||
func NewPlugin() *Plugin {
|
||||
return &Plugin{
|
||||
config: &Config{},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Plugin) Name() string {
|
||||
return PluginName
|
||||
}
|
||||
|
||||
func (p *Plugin) Config() types.PluginConfig {
|
||||
return p.config
|
||||
}
|
||||
|
||||
func (p *Plugin) Start(m types.PluginManagers) error {
|
||||
p.manager = NewManager(m.SessionManager, p.config)
|
||||
m.ApiManager.AddRouter("/chat", p.manager.Route)
|
||||
m.WebSocketManager.AddHandler(p.manager.WebSocketHandler)
|
||||
return p.manager.Start()
|
||||
}
|
||||
|
||||
func (p *Plugin) Shutdown() error {
|
||||
return p.manager.Shutdown()
|
||||
}
|
24
server/internal/plugins/chat/types.go
Normal file
24
server/internal/plugins/chat/types.go
Normal file
@ -0,0 +1,24 @@
|
||||
package chat
|
||||
|
||||
import "time"
|
||||
|
||||
const PluginName = "chat"
|
||||
|
||||
const (
|
||||
CHAT_INIT = "chat/init"
|
||||
CHAT_MESSAGE = "chat/message"
|
||||
)
|
||||
|
||||
type Init struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type Content struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
Created time.Time `json:"created"`
|
||||
Content Content `json:"content"`
|
||||
}
|
133
server/internal/plugins/dependency.go
Normal file
133
server/internal/plugins/dependency.go
Normal file
@ -0,0 +1,133 @@
|
||||
package plugins
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
)
|
||||
|
||||
type dependency struct {
|
||||
plugin types.Plugin
|
||||
dependsOn []*dependency
|
||||
invoked bool
|
||||
logger zerolog.Logger
|
||||
}
|
||||
|
||||
func (a *dependency) findPlugin(name string) (*dependency, bool) {
|
||||
if a == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if a.plugin.Name() == name {
|
||||
return a, true
|
||||
}
|
||||
|
||||
for _, dep := range a.dependsOn {
|
||||
plug, ok := dep.findPlugin(name)
|
||||
if ok {
|
||||
return plug, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (a *dependency) startPlugin(pm types.PluginManagers) error {
|
||||
if a.invoked {
|
||||
return nil
|
||||
}
|
||||
|
||||
a.invoked = true
|
||||
|
||||
for _, do := range a.dependsOn {
|
||||
if err := do.startPlugin(pm); err != nil {
|
||||
return fmt.Errorf("plugin's '%s' dependency: %w", a.plugin.Name(), err)
|
||||
}
|
||||
}
|
||||
|
||||
err := a.plugin.Start(pm)
|
||||
if err != nil {
|
||||
return fmt.Errorf("plugin '%s' failed to start: %w", a.plugin.Name(), err)
|
||||
}
|
||||
|
||||
a.logger.Info().Str("plugin", a.plugin.Name()).Msg("plugin started")
|
||||
return nil
|
||||
}
|
||||
|
||||
type dependiencies struct {
|
||||
deps map[string]*dependency
|
||||
logger zerolog.Logger
|
||||
}
|
||||
|
||||
func (d *dependiencies) addPlugin(plugin types.Plugin) error {
|
||||
pluginName := plugin.Name()
|
||||
|
||||
plug, ok := d.deps[pluginName]
|
||||
if !ok {
|
||||
plug = &dependency{}
|
||||
} else if plug.plugin != nil {
|
||||
return fmt.Errorf("plugin '%s' already added", pluginName)
|
||||
}
|
||||
|
||||
plug.plugin = plugin
|
||||
plug.logger = d.logger
|
||||
d.deps[pluginName] = plug
|
||||
|
||||
dplug, ok := plugin.(types.DependablePlugin)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, depName := range dplug.DependsOn() {
|
||||
dependsOn, ok := d.deps[depName]
|
||||
if !ok {
|
||||
dependsOn = &dependency{}
|
||||
} else if dependsOn.plugin != nil {
|
||||
// if there is a cyclical dependency, break it and return error
|
||||
if tdep, ok := dependsOn.findPlugin(pluginName); ok {
|
||||
dependsOn.dependsOn = nil
|
||||
delete(d.deps, pluginName)
|
||||
return fmt.Errorf("cyclical dependency detected: '%s' <-> '%s'", pluginName, tdep.plugin.Name())
|
||||
}
|
||||
}
|
||||
|
||||
plug.dependsOn = append(plug.dependsOn, dependsOn)
|
||||
d.deps[depName] = dependsOn
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dependiencies) findPlugin(name string) (*dependency, bool) {
|
||||
for _, dep := range d.deps {
|
||||
plug, ok := dep.findPlugin(name)
|
||||
if ok {
|
||||
return plug, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (d *dependiencies) start(pm types.PluginManagers) error {
|
||||
for _, dep := range d.deps {
|
||||
if err := dep.startPlugin(pm); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dependiencies) forEach(f func(*dependency) error) error {
|
||||
for _, dep := range d.deps {
|
||||
if err := f(dep); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dependiencies) len() int {
|
||||
return len(d.deps)
|
||||
}
|
630
server/internal/plugins/dependency_test.go
Normal file
630
server/internal/plugins/dependency_test.go
Normal file
@ -0,0 +1,630 @@
|
||||
package plugins
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
)
|
||||
|
||||
func Test_deps_addPlugin(t *testing.T) {
|
||||
type args struct {
|
||||
p []types.Plugin
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want map[string]*dependency
|
||||
skipRun bool
|
||||
wantErr1 bool
|
||||
wantErr2 bool
|
||||
}{
|
||||
{
|
||||
name: "three plugins - no dependencies",
|
||||
args: args{
|
||||
p: []types.Plugin{
|
||||
&dummyPlugin{name: "first"},
|
||||
&dummyPlugin{name: "second"},
|
||||
&dummyPlugin{name: "third"},
|
||||
},
|
||||
},
|
||||
want: map[string]*dependency{
|
||||
"first": {
|
||||
plugin: &dummyPlugin{name: "first", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
"second": {
|
||||
plugin: &dummyPlugin{name: "second", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
"third": {
|
||||
plugin: &dummyPlugin{name: "third", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
}, {
|
||||
name: "three plugins - one dependency",
|
||||
args: args{
|
||||
p: []types.Plugin{
|
||||
&dummyPlugin{name: "third", dep: []string{"second"}},
|
||||
&dummyPlugin{name: "first"},
|
||||
&dummyPlugin{name: "second"},
|
||||
},
|
||||
},
|
||||
want: map[string]*dependency{
|
||||
"first": {
|
||||
plugin: &dummyPlugin{name: "first", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
"second": {
|
||||
plugin: &dummyPlugin{name: "second", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
"third": {
|
||||
plugin: &dummyPlugin{name: "third", dep: []string{"second"}, idx: 1},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "second", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, {
|
||||
name: "three plugins - one double dependency",
|
||||
args: args{
|
||||
p: []types.Plugin{
|
||||
&dummyPlugin{name: "third", dep: []string{"first", "second"}},
|
||||
&dummyPlugin{name: "first"},
|
||||
&dummyPlugin{name: "second"},
|
||||
},
|
||||
},
|
||||
want: map[string]*dependency{
|
||||
"first": {
|
||||
plugin: &dummyPlugin{name: "first", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
"second": {
|
||||
plugin: &dummyPlugin{name: "second", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
"third": {
|
||||
plugin: &dummyPlugin{name: "third", dep: []string{"first", "second"}, idx: 1},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "first", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
{
|
||||
plugin: &dummyPlugin{name: "second", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, {
|
||||
name: "three plugins - two dependencies",
|
||||
args: args{
|
||||
p: []types.Plugin{
|
||||
&dummyPlugin{name: "third", dep: []string{"first"}},
|
||||
&dummyPlugin{name: "first"},
|
||||
&dummyPlugin{name: "second", dep: []string{"first"}},
|
||||
},
|
||||
},
|
||||
want: map[string]*dependency{
|
||||
"first": {
|
||||
plugin: &dummyPlugin{name: "first"},
|
||||
invoked: false,
|
||||
dependsOn: nil,
|
||||
},
|
||||
"third": {
|
||||
plugin: &dummyPlugin{name: "third", dep: []string{"first"}},
|
||||
invoked: false,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "first"},
|
||||
invoked: false,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
"second": {
|
||||
plugin: &dummyPlugin{name: "second", dep: []string{"first"}},
|
||||
invoked: false,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "first"},
|
||||
invoked: false,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
skipRun: true,
|
||||
}, {
|
||||
name: "three plugins - three dependencies",
|
||||
args: args{
|
||||
p: []types.Plugin{
|
||||
&dummyPlugin{name: "third", dep: []string{"second"}},
|
||||
&dummyPlugin{name: "first"},
|
||||
&dummyPlugin{name: "second", dep: []string{"first"}},
|
||||
},
|
||||
},
|
||||
want: map[string]*dependency{
|
||||
"first": {
|
||||
plugin: &dummyPlugin{name: "first", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
"second": {
|
||||
plugin: &dummyPlugin{name: "second", dep: []string{"first"}, idx: 1},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "first", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
"third": {
|
||||
plugin: &dummyPlugin{name: "third", dep: []string{"second"}, idx: 2},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "second", dep: []string{"first"}, idx: 1},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "first", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, {
|
||||
name: "four plugins - added in reverse order, with dependencies",
|
||||
args: args{
|
||||
p: []types.Plugin{
|
||||
&dummyPlugin{name: "forth", dep: []string{"third"}},
|
||||
&dummyPlugin{name: "third", dep: []string{"second"}},
|
||||
&dummyPlugin{name: "second", dep: []string{"first"}},
|
||||
&dummyPlugin{name: "first"},
|
||||
},
|
||||
},
|
||||
want: map[string]*dependency{
|
||||
"first": {
|
||||
plugin: &dummyPlugin{name: "first", idx: 0},
|
||||
invoked: false,
|
||||
dependsOn: nil,
|
||||
},
|
||||
"second": {
|
||||
plugin: &dummyPlugin{name: "second", dep: []string{"first"}, idx: 0},
|
||||
invoked: false,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "first", idx: 0},
|
||||
invoked: false,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
"third": {
|
||||
plugin: &dummyPlugin{name: "third", dep: []string{"second"}, idx: 0},
|
||||
invoked: false,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "second", dep: []string{"first"}, idx: 0},
|
||||
invoked: false,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "first", idx: 0},
|
||||
invoked: false,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"forth": {
|
||||
plugin: &dummyPlugin{name: "forth", dep: []string{"third"}, idx: 0},
|
||||
invoked: false,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "third", dep: []string{"second"}, idx: 0},
|
||||
invoked: false,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "second", dep: []string{"first"}, idx: 0},
|
||||
invoked: false,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "first", idx: 0},
|
||||
invoked: false,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
skipRun: true,
|
||||
}, {
|
||||
name: "four plugins - two double dependencies",
|
||||
args: args{
|
||||
p: []types.Plugin{
|
||||
&dummyPlugin{name: "forth", dep: []string{"first", "third"}},
|
||||
&dummyPlugin{name: "third", dep: []string{"first", "second"}},
|
||||
&dummyPlugin{name: "second"},
|
||||
&dummyPlugin{name: "first"},
|
||||
},
|
||||
},
|
||||
want: map[string]*dependency{
|
||||
"first": {
|
||||
plugin: &dummyPlugin{name: "first", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
"second": {
|
||||
plugin: &dummyPlugin{name: "second", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
"third": {
|
||||
plugin: &dummyPlugin{name: "third", dep: []string{"first", "second"}, idx: 1},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "first", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
{
|
||||
plugin: &dummyPlugin{name: "second", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
"forth": {
|
||||
plugin: &dummyPlugin{name: "forth", dep: []string{"first", "third"}, idx: 2},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "first", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
{
|
||||
plugin: &dummyPlugin{name: "third", dep: []string{"first", "second"}, idx: 1},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "first", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
{
|
||||
plugin: &dummyPlugin{name: "second", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, {
|
||||
// So, when we have plugin A in the list and want to add plugin C we can't determine the proper order without
|
||||
// resolving their direct dependiencies first:
|
||||
//
|
||||
// Can be C->D->A->B if D depends on A
|
||||
//
|
||||
// So to do it properly I would imagine tht we need to resolve all direct dependiencies first and build multiple lists:
|
||||
//
|
||||
// i.e. A->B->C D F->G
|
||||
//
|
||||
// and then join these lists in any order.
|
||||
name: "add indirect dependency CDAB",
|
||||
args: args{
|
||||
p: []types.Plugin{
|
||||
&dummyPlugin{name: "A", dep: []string{"B"}},
|
||||
&dummyPlugin{name: "C", dep: []string{"D"}},
|
||||
&dummyPlugin{name: "B"},
|
||||
&dummyPlugin{name: "D", dep: []string{"A"}},
|
||||
},
|
||||
},
|
||||
want: map[string]*dependency{
|
||||
"B": {
|
||||
plugin: &dummyPlugin{name: "B", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
"A": {
|
||||
plugin: &dummyPlugin{name: "A", dep: []string{"B"}, idx: 1},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "B", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
"D": {
|
||||
plugin: &dummyPlugin{name: "D", dep: []string{"A"}, idx: 2},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "A", dep: []string{"B"}, idx: 1},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "B", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"C": {
|
||||
plugin: &dummyPlugin{name: "C", dep: []string{"D"}, idx: 3},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "D", dep: []string{"A"}, idx: 2},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "A", dep: []string{"B"}, idx: 1},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "B", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, {
|
||||
// So, when we have plugin A in the list and want to add plugin C we can't determine the proper order without
|
||||
// resolving their direct dependiencies first:
|
||||
//
|
||||
// Can be A->B->C->D (in this test) if B depends on C
|
||||
//
|
||||
// So to do it properly I would imagine tht we need to resolve all direct dependiencies first and build multiple lists:
|
||||
//
|
||||
// i.e. A->B->C D F->G
|
||||
//
|
||||
// and then join these lists in any order.
|
||||
name: "add indirect dependency ABCD",
|
||||
args: args{
|
||||
p: []types.Plugin{
|
||||
&dummyPlugin{name: "C", dep: []string{"D"}},
|
||||
&dummyPlugin{name: "D"},
|
||||
&dummyPlugin{name: "B", dep: []string{"C"}},
|
||||
&dummyPlugin{name: "A", dep: []string{"B"}},
|
||||
},
|
||||
},
|
||||
want: map[string]*dependency{
|
||||
"D": {
|
||||
plugin: &dummyPlugin{name: "D", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
"C": {
|
||||
plugin: &dummyPlugin{name: "C", dep: []string{"D"}, idx: 1},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "D", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
"B": {
|
||||
plugin: &dummyPlugin{name: "B", dep: []string{"C"}, idx: 2},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "C", dep: []string{"D"}, idx: 1},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "D", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"A": {
|
||||
plugin: &dummyPlugin{name: "A", dep: []string{"B"}, idx: 3},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "B", dep: []string{"C"}, idx: 2},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "C", dep: []string{"D"}, idx: 1},
|
||||
invoked: true,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "D", idx: 0},
|
||||
invoked: true,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, {
|
||||
name: "add duplicate plugin",
|
||||
args: args{
|
||||
p: []types.Plugin{
|
||||
&dummyPlugin{name: "first"},
|
||||
&dummyPlugin{name: "first"},
|
||||
},
|
||||
},
|
||||
want: map[string]*dependency{
|
||||
"first": {plugin: &dummyPlugin{name: "first", idx: 0}, invoked: true},
|
||||
},
|
||||
wantErr1: true,
|
||||
}, {
|
||||
name: "cyclical dependency",
|
||||
args: args{
|
||||
p: []types.Plugin{
|
||||
&dummyPlugin{name: "first", dep: []string{"second"}},
|
||||
&dummyPlugin{name: "second", dep: []string{"first"}},
|
||||
},
|
||||
},
|
||||
want: map[string]*dependency{
|
||||
"first": {
|
||||
plugin: &dummyPlugin{name: "first", dep: []string{"second"}, idx: 1},
|
||||
invoked: true,
|
||||
},
|
||||
},
|
||||
wantErr1: true,
|
||||
}, {
|
||||
name: "four plugins - cyclical transitive dependencies in reverse order",
|
||||
args: args{
|
||||
p: []types.Plugin{
|
||||
&dummyPlugin{name: "forth", dep: []string{"third"}},
|
||||
&dummyPlugin{name: "third", dep: []string{"second"}},
|
||||
&dummyPlugin{name: "second", dep: []string{"first"}},
|
||||
&dummyPlugin{name: "first", dep: []string{"forth"}},
|
||||
},
|
||||
},
|
||||
want: map[string]*dependency{
|
||||
"second": {
|
||||
plugin: &dummyPlugin{name: "second", dep: []string{"first"}, idx: 0},
|
||||
invoked: false,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "first", dep: []string{"forth"}, idx: 0},
|
||||
invoked: false,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
"third": {
|
||||
plugin: &dummyPlugin{name: "third", dep: []string{"second"}, idx: 0},
|
||||
invoked: false,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "second", dep: []string{"first"}, idx: 0},
|
||||
invoked: false,
|
||||
dependsOn: []*dependency{
|
||||
{
|
||||
plugin: &dummyPlugin{name: "first", dep: []string{"forth"}, idx: 0},
|
||||
invoked: false,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"forth": {
|
||||
plugin: &dummyPlugin{name: "forth", dep: []string{"third"}, idx: 0},
|
||||
invoked: false,
|
||||
dependsOn: nil,
|
||||
},
|
||||
},
|
||||
wantErr1: true,
|
||||
skipRun: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
d := &dependiencies{deps: make(map[string]*dependency)}
|
||||
|
||||
var (
|
||||
err error
|
||||
counter int
|
||||
)
|
||||
for _, p := range tt.args.p {
|
||||
if !tt.skipRun {
|
||||
p.(*dummyPlugin).counter = &counter
|
||||
}
|
||||
if err = d.addPlugin(p); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil != tt.wantErr1 {
|
||||
t.Errorf("dependiencies.addPlugin() error = %v, wantErr1 %v", err, tt.wantErr1)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.skipRun {
|
||||
if err := d.start(types.PluginManagers{}); (err != nil) != tt.wantErr2 {
|
||||
t.Errorf("dependiencies.start() error = %v, wantErr1 %v", err, tt.wantErr2)
|
||||
}
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(d.deps, tt.want) {
|
||||
t.Errorf("deps = %v, want %v", d.deps, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type dummyPlugin struct {
|
||||
name string
|
||||
dep []string
|
||||
idx int
|
||||
counter *int
|
||||
}
|
||||
|
||||
func (d dummyPlugin) Name() string {
|
||||
return d.name
|
||||
}
|
||||
|
||||
func (d dummyPlugin) DependsOn() []string {
|
||||
return d.dep
|
||||
}
|
||||
|
||||
func (d dummyPlugin) Config() types.PluginConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dummyPlugin) Start(types.PluginManagers) error {
|
||||
if len(d.dep) > 0 {
|
||||
*d.counter++
|
||||
d.idx = *d.counter
|
||||
}
|
||||
d.counter = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummyPlugin) Shutdown() error {
|
||||
return nil
|
||||
}
|
41
server/internal/plugins/filetransfer/config.go
Normal file
41
server/internal/plugins/filetransfer/config.go
Normal file
@ -0,0 +1,41 @@
|
||||
package filetransfer
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
RootDir string
|
||||
RefreshInterval time.Duration
|
||||
}
|
||||
|
||||
func (Config) Init(cmd *cobra.Command) error {
|
||||
cmd.PersistentFlags().Bool("filetransfer.enabled", false, "whether file transfer is enabled")
|
||||
if err := viper.BindPFlag("filetransfer.enabled", cmd.PersistentFlags().Lookup("filetransfer.enabled")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("filetransfer.dir", "/home/neko/Downloads", "root directory for file transfer")
|
||||
if err := viper.BindPFlag("filetransfer.dir", cmd.PersistentFlags().Lookup("filetransfer.dir")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().Duration("filetransfer.refresh_interval", 30*time.Second, "interval to refresh file list")
|
||||
if err := viper.BindPFlag("filetransfer.refresh_interval", cmd.PersistentFlags().Lookup("filetransfer.refresh_interval")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Config) Set() {
|
||||
s.Enabled = viper.GetBool("filetransfer.enabled")
|
||||
rootDir := viper.GetString("filetransfer.dir")
|
||||
s.RootDir = filepath.Clean(rootDir)
|
||||
s.RefreshInterval = viper.GetDuration("filetransfer.refresh_interval")
|
||||
}
|
332
server/internal/plugins/filetransfer/manager.go
Normal file
332
server/internal/plugins/filetransfer/manager.go
Normal file
@ -0,0 +1,332 @@
|
||||
package filetransfer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/demodesk/neko/pkg/auth"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const MULTIPART_FORM_MAX_MEMORY = 32 << 20
|
||||
|
||||
func NewManager(
|
||||
sessions types.SessionManager,
|
||||
config *Config,
|
||||
) *Manager {
|
||||
logger := log.With().Str("module", "filetransfer").Logger()
|
||||
|
||||
return &Manager{
|
||||
logger: logger,
|
||||
config: config,
|
||||
sessions: sessions,
|
||||
shutdown: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
logger zerolog.Logger
|
||||
config *Config
|
||||
sessions types.SessionManager
|
||||
shutdown chan struct{}
|
||||
mu sync.RWMutex
|
||||
fileList []Item
|
||||
}
|
||||
|
||||
func (m *Manager) isEnabledForSession(session types.Session) (bool, error) {
|
||||
settings := Settings{
|
||||
Enabled: true, // defaults to true
|
||||
}
|
||||
err := m.sessions.Settings().Plugins.Unmarshal(PluginName, &settings)
|
||||
if err != nil && !errors.Is(err, types.ErrPluginSettingsNotFound) {
|
||||
return false, fmt.Errorf("unable to unmarshal %s plugin settings from global settings: %w", PluginName, err)
|
||||
}
|
||||
|
||||
profile := Settings{
|
||||
Enabled: true, // defaults to true
|
||||
}
|
||||
|
||||
err = session.Profile().Plugins.Unmarshal(PluginName, &profile)
|
||||
if err != nil && !errors.Is(err, types.ErrPluginSettingsNotFound) {
|
||||
return false, fmt.Errorf("unable to unmarshal %s plugin settings from profile: %w", PluginName, err)
|
||||
}
|
||||
|
||||
return m.config.Enabled && (settings.Enabled || session.Profile().IsAdmin) && profile.Enabled, nil
|
||||
}
|
||||
|
||||
func (m *Manager) refresh() (error, bool) {
|
||||
// if file transfer is disabled, return immediately without refreshing
|
||||
if !m.config.Enabled {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
files, err := ListFiles(m.config.RootDir)
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// check if file list has changed (todo: use hash instead of comparing all fields)
|
||||
changed := false
|
||||
if len(files) == len(m.fileList) {
|
||||
for i, file := range files {
|
||||
if file.Name != m.fileList[i].Name || file.Size != m.fileList[i].Size {
|
||||
changed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
changed = true
|
||||
}
|
||||
|
||||
m.fileList = files
|
||||
return nil, changed
|
||||
}
|
||||
|
||||
func (m *Manager) broadcastUpdate() {
|
||||
m.mu.RLock()
|
||||
fileList := m.fileList
|
||||
m.mu.RUnlock()
|
||||
|
||||
m.sessions.Broadcast(FILETRANSFER_UPDATE, Message{
|
||||
Enabled: m.config.Enabled,
|
||||
RootDir: m.config.RootDir,
|
||||
Files: fileList,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) sendUpdate(session types.Session) {
|
||||
m.mu.RLock()
|
||||
fileList := m.fileList
|
||||
m.mu.RUnlock()
|
||||
|
||||
session.Send(FILETRANSFER_UPDATE, Message{
|
||||
Enabled: m.config.Enabled,
|
||||
RootDir: m.config.RootDir,
|
||||
Files: fileList,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) Start() error {
|
||||
// send init message once a user connects
|
||||
m.sessions.OnConnected(func(session types.Session) {
|
||||
m.sendUpdate(session)
|
||||
})
|
||||
|
||||
// if file transfer is disabled, return immediately without starting the watcher
|
||||
if !m.config.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := os.Stat(m.config.RootDir); os.IsNotExist(err) {
|
||||
err = os.Mkdir(m.config.RootDir, os.ModePerm)
|
||||
m.logger.Err(err).Msg("creating file transfer directory")
|
||||
}
|
||||
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to start file transfer dir watcher: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer watcher.Close()
|
||||
|
||||
// periodically refresh file list
|
||||
ticker := time.NewTicker(m.config.RefreshInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.shutdown:
|
||||
m.logger.Info().Msg("shutting down file transfer manager")
|
||||
return
|
||||
case <-ticker.C:
|
||||
err, changed := m.refresh()
|
||||
if err != nil {
|
||||
m.logger.Err(err).Msg("unable to refresh file transfer list")
|
||||
}
|
||||
if changed {
|
||||
m.broadcastUpdate()
|
||||
}
|
||||
case e, ok := <-watcher.Events:
|
||||
if !ok {
|
||||
m.logger.Info().Msg("file transfer dir watcher closed")
|
||||
return
|
||||
}
|
||||
|
||||
if e.Has(fsnotify.Create) || e.Has(fsnotify.Remove) || e.Has(fsnotify.Rename) {
|
||||
m.logger.Debug().Str("event", e.String()).Msg("file transfer dir watcher event")
|
||||
|
||||
err, changed := m.refresh()
|
||||
if err != nil {
|
||||
m.logger.Err(err).Msg("unable to refresh file transfer list")
|
||||
}
|
||||
|
||||
if changed {
|
||||
m.broadcastUpdate()
|
||||
}
|
||||
}
|
||||
case err := <-watcher.Errors:
|
||||
m.logger.Err(err).Msg("error in file transfer dir watcher")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if err := watcher.Add(m.config.RootDir); err != nil {
|
||||
return fmt.Errorf("unable to watch file transfer dir: %w", err)
|
||||
}
|
||||
|
||||
// initial refresh
|
||||
err, changed := m.refresh()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to refresh file transfer list: %w", err)
|
||||
}
|
||||
if changed {
|
||||
m.broadcastUpdate()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Shutdown() error {
|
||||
close(m.shutdown)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Route(r types.Router) {
|
||||
r.With(auth.AdminsOnly).Get("/", m.downloadFileHandler)
|
||||
r.With(auth.AdminsOnly).Post("/", m.uploadFileHandler)
|
||||
}
|
||||
|
||||
func (m *Manager) WebSocketHandler(session types.Session, msg types.WebSocketMessage) bool {
|
||||
switch msg.Event {
|
||||
case FILETRANSFER_UPDATE:
|
||||
err, changed := m.refresh()
|
||||
if err != nil {
|
||||
m.logger.Err(err).Msg("unable to refresh file transfer list")
|
||||
}
|
||||
|
||||
if changed {
|
||||
// broadcast update message to all clients
|
||||
m.broadcastUpdate()
|
||||
} else {
|
||||
// send update message to this client only
|
||||
m.sendUpdate(session)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// not handled by this plugin
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) downloadFileHandler(w http.ResponseWriter, r *http.Request) error {
|
||||
session, ok := auth.GetSession(r)
|
||||
if !ok {
|
||||
return utils.HttpUnauthorized("session not found")
|
||||
}
|
||||
|
||||
enabled, err := m.isEnabledForSession(session)
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().
|
||||
WithInternalErr(err).
|
||||
Msg("error checking file transfer permissions")
|
||||
}
|
||||
|
||||
if !enabled {
|
||||
return utils.HttpForbidden("file transfer is disabled")
|
||||
}
|
||||
|
||||
filename := r.URL.Query().Get("filename")
|
||||
badChars, err := regexp.MatchString(`(?m)\.\.(?:\/|$)`, filename)
|
||||
if filename == "" || badChars || err != nil {
|
||||
return utils.HttpBadRequest().
|
||||
WithInternalErr(err).
|
||||
Msg("bad filename")
|
||||
}
|
||||
|
||||
// ensure filename is clean and only contains the basename
|
||||
filename = filepath.Clean(filename)
|
||||
filename = filepath.Base(filename)
|
||||
filePath := filepath.Join(m.config.RootDir, filename)
|
||||
|
||||
http.ServeFile(w, r, filePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) uploadFileHandler(w http.ResponseWriter, r *http.Request) error {
|
||||
session, ok := auth.GetSession(r)
|
||||
if !ok {
|
||||
return utils.HttpUnauthorized("session not found")
|
||||
}
|
||||
|
||||
enabled, err := m.isEnabledForSession(session)
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().
|
||||
WithInternalErr(err).
|
||||
Msg("error checking file transfer permissions")
|
||||
}
|
||||
|
||||
if !enabled {
|
||||
return utils.HttpForbidden("file transfer is disabled")
|
||||
}
|
||||
|
||||
err = r.ParseMultipartForm(MULTIPART_FORM_MAX_MEMORY)
|
||||
if err != nil || r.MultipartForm == nil {
|
||||
return utils.HttpBadRequest().
|
||||
WithInternalErr(err).
|
||||
Msg("error parsing form")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = r.MultipartForm.RemoveAll()
|
||||
if err != nil {
|
||||
m.logger.Warn().Err(err).Msg("failed to clean up multipart form")
|
||||
}
|
||||
}()
|
||||
|
||||
for _, formheader := range r.MultipartForm.File["files"] {
|
||||
// ensure filename is clean and only contains the basename
|
||||
filename := filepath.Clean(formheader.Filename)
|
||||
filename = filepath.Base(filename)
|
||||
filePath := filepath.Join(m.config.RootDir, filename)
|
||||
|
||||
formfile, err := formheader.Open()
|
||||
if err != nil {
|
||||
return utils.HttpBadRequest().
|
||||
WithInternalErr(err).
|
||||
Msg("error opening formdata file")
|
||||
}
|
||||
defer formfile.Close()
|
||||
|
||||
f, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE, 0644)
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().
|
||||
WithInternalErr(err).
|
||||
Msg("error opening file for writing")
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
_, err = io.Copy(f, formfile)
|
||||
if err != nil {
|
||||
return utils.HttpInternalServerError().
|
||||
WithInternalErr(err).
|
||||
Msg("error writing file")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
35
server/internal/plugins/filetransfer/plugin.go
Normal file
35
server/internal/plugins/filetransfer/plugin.go
Normal file
@ -0,0 +1,35 @@
|
||||
package filetransfer
|
||||
|
||||
import (
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
)
|
||||
|
||||
type Plugin struct {
|
||||
config *Config
|
||||
manager *Manager
|
||||
}
|
||||
|
||||
func NewPlugin() *Plugin {
|
||||
return &Plugin{
|
||||
config: &Config{},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Plugin) Name() string {
|
||||
return PluginName
|
||||
}
|
||||
|
||||
func (p *Plugin) Config() types.PluginConfig {
|
||||
return p.config
|
||||
}
|
||||
|
||||
func (p *Plugin) Start(m types.PluginManagers) error {
|
||||
p.manager = NewManager(m.SessionManager, p.config)
|
||||
m.ApiManager.AddRouter("/filetransfer", p.manager.Route)
|
||||
m.WebSocketManager.AddHandler(p.manager.WebSocketHandler)
|
||||
return p.manager.Start()
|
||||
}
|
||||
|
||||
func (p *Plugin) Shutdown() error {
|
||||
return p.manager.Shutdown()
|
||||
}
|
30
server/internal/plugins/filetransfer/types.go
Normal file
30
server/internal/plugins/filetransfer/types.go
Normal file
@ -0,0 +1,30 @@
|
||||
package filetransfer
|
||||
|
||||
const PluginName = "filetransfer"
|
||||
|
||||
type Settings struct {
|
||||
Enabled bool `json:"enabled" mapstructure:"enabled"`
|
||||
}
|
||||
|
||||
const (
|
||||
FILETRANSFER_UPDATE = "filetransfer/update"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
RootDir string `json:"root_dir"`
|
||||
Files []Item `json:"files"`
|
||||
}
|
||||
|
||||
type ItemType string
|
||||
|
||||
const (
|
||||
ItemTypeFile ItemType = "file"
|
||||
ItemTypeDir ItemType = "dir"
|
||||
)
|
||||
|
||||
type Item struct {
|
||||
Name string `json:"name"`
|
||||
Type ItemType `json:"type"`
|
||||
Size int64 `json:"size,omitempty"`
|
||||
}
|
32
server/internal/plugins/filetransfer/utils.go
Normal file
32
server/internal/plugins/filetransfer/utils.go
Normal file
@ -0,0 +1,32 @@
|
||||
package filetransfer
|
||||
|
||||
import "os"
|
||||
|
||||
func ListFiles(path string) ([]Item, error) {
|
||||
items, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make([]Item, len(items))
|
||||
for i, item := range items {
|
||||
var itemType ItemType
|
||||
var size int64 = 0
|
||||
if item.IsDir() {
|
||||
itemType = ItemTypeDir
|
||||
} else {
|
||||
itemType = ItemTypeFile
|
||||
info, err := item.Info()
|
||||
if err == nil {
|
||||
size = info.Size()
|
||||
}
|
||||
}
|
||||
out[i] = Item{
|
||||
Name: item.Name(),
|
||||
Type: itemType,
|
||||
Size: size,
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
183
server/internal/plugins/manager.go
Normal file
183
server/internal/plugins/manager.go
Normal file
@ -0,0 +1,183 @@
|
||||
package plugins
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"plugin"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/demodesk/neko/internal/config"
|
||||
"github.com/demodesk/neko/internal/plugins/chat"
|
||||
"github.com/demodesk/neko/internal/plugins/filetransfer"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
)
|
||||
|
||||
type ManagerCtx struct {
|
||||
logger zerolog.Logger
|
||||
config *config.Plugins
|
||||
plugins dependiencies
|
||||
}
|
||||
|
||||
func New(config *config.Plugins) *ManagerCtx {
|
||||
manager := &ManagerCtx{
|
||||
logger: log.With().Str("module", "plugins").Logger(),
|
||||
config: config,
|
||||
plugins: dependiencies{
|
||||
deps: make(map[string]*dependency),
|
||||
},
|
||||
}
|
||||
|
||||
manager.plugins.logger = manager.logger
|
||||
|
||||
if config.Enabled {
|
||||
err := manager.loadDir(config.Dir)
|
||||
|
||||
// only log error if plugin is not required
|
||||
if err != nil && config.Required {
|
||||
manager.logger.Fatal().Err(err).Msg("error loading plugins")
|
||||
}
|
||||
|
||||
manager.logger.Info().Msgf("loading finished, total %d plugins", manager.plugins.len())
|
||||
}
|
||||
|
||||
// add built-in plugins
|
||||
manager.plugins.addPlugin(filetransfer.NewPlugin())
|
||||
manager.plugins.addPlugin(chat.NewPlugin())
|
||||
|
||||
return manager
|
||||
}
|
||||
|
||||
func (manager *ManagerCtx) loadDir(dir string) error {
|
||||
return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = manager.load(path)
|
||||
|
||||
// return error if plugin is required
|
||||
if err != nil && manager.config.Required {
|
||||
return err
|
||||
}
|
||||
|
||||
// otherwise only log error if plugin is not required
|
||||
manager.logger.Err(err).Str("plugin", path).Msg("loading a plugin")
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *ManagerCtx) load(path string) error {
|
||||
pl, err := plugin.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sym, err := pl.Lookup("Plugin")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p, ok := sym.(types.Plugin)
|
||||
if !ok {
|
||||
return fmt.Errorf("not a valid plugin")
|
||||
}
|
||||
|
||||
if err = manager.plugins.addPlugin(p); err != nil {
|
||||
return fmt.Errorf("failed to add plugin: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *ManagerCtx) InitConfigs(cmd *cobra.Command) {
|
||||
_ = manager.plugins.forEach(func(d *dependency) error {
|
||||
if err := d.plugin.Config().Init(cmd); err != nil {
|
||||
log.Err(err).Str("plugin", d.plugin.Name()).Msg("unable to initialize configuration")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *ManagerCtx) SetConfigs() {
|
||||
_ = manager.plugins.forEach(func(d *dependency) error {
|
||||
d.plugin.Config().Set()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *ManagerCtx) Start(
|
||||
sessionManager types.SessionManager,
|
||||
webSocketManager types.WebSocketManager,
|
||||
apiManager types.ApiManager,
|
||||
) {
|
||||
err := manager.plugins.start(types.PluginManagers{
|
||||
SessionManager: sessionManager,
|
||||
WebSocketManager: webSocketManager,
|
||||
ApiManager: apiManager,
|
||||
LoadServiceFromPlugin: manager.LookupService,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if manager.config.Required {
|
||||
manager.logger.Fatal().Err(err).Msg("failed to start plugins, exiting...")
|
||||
} else {
|
||||
manager.logger.Err(err).Msg("failed to start plugins, skipping...")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *ManagerCtx) Shutdown() error {
|
||||
_ = manager.plugins.forEach(func(d *dependency) error {
|
||||
err := d.plugin.Shutdown()
|
||||
manager.logger.Err(err).Str("plugin", d.plugin.Name()).Msg("plugin shutdown")
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *ManagerCtx) LookupService(pluginName string) (any, error) {
|
||||
plug, ok := manager.plugins.findPlugin(pluginName)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("plugin '%s' not found", pluginName)
|
||||
}
|
||||
|
||||
expPlug, ok := plug.plugin.(types.ExposablePlugin)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("plugin '%s' is not exposable", pluginName)
|
||||
}
|
||||
|
||||
return expPlug.ExposeService(), nil
|
||||
}
|
||||
|
||||
func (manager *ManagerCtx) Metadata() []types.PluginMetadata {
|
||||
var plugins []types.PluginMetadata
|
||||
|
||||
_ = manager.plugins.forEach(func(d *dependency) error {
|
||||
dependsOn := make([]string, 0)
|
||||
deps, isDependalbe := d.plugin.(types.DependablePlugin)
|
||||
if isDependalbe {
|
||||
dependsOn = deps.DependsOn()
|
||||
}
|
||||
|
||||
_, isExposable := d.plugin.(types.ExposablePlugin)
|
||||
|
||||
plugins = append(plugins, types.PluginMetadata{
|
||||
Name: d.plugin.Name(),
|
||||
IsDependable: isDependalbe,
|
||||
IsExposable: isExposable,
|
||||
DependsOn: dependsOn,
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return plugins
|
||||
}
|
80
server/internal/session/auth.go
Normal file
80
server/internal/session/auth.go
Normal file
@ -0,0 +1,80 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
)
|
||||
|
||||
func (manager *SessionManagerCtx) CookieSetToken(w http.ResponseWriter, token string) {
|
||||
sameSite := http.SameSiteDefaultMode
|
||||
if manager.config.CookieSecure {
|
||||
sameSite = http.SameSiteNoneMode
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: manager.config.CookieName,
|
||||
Value: token,
|
||||
Expires: time.Now().Add(manager.config.CookieExpiration),
|
||||
Secure: manager.config.CookieSecure,
|
||||
SameSite: sameSite,
|
||||
HttpOnly: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) CookieClearToken(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie(manager.config.CookieName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cookie.Value = ""
|
||||
cookie.Expires = time.Unix(0, 0)
|
||||
http.SetCookie(w, cookie)
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) Authenticate(r *http.Request) (types.Session, error) {
|
||||
token, ok := manager.getToken(r)
|
||||
if !ok {
|
||||
return nil, errors.New("no authentication provided")
|
||||
}
|
||||
|
||||
session, ok := manager.GetByToken(token)
|
||||
if !ok {
|
||||
return nil, types.ErrSessionNotFound
|
||||
}
|
||||
|
||||
if !session.Profile().CanLogin {
|
||||
return nil, types.ErrSessionLoginDisabled
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) getToken(r *http.Request) (string, bool) {
|
||||
if manager.CookieEnabled() {
|
||||
// get from Cookie
|
||||
cookie, err := r.Cookie(manager.config.CookieName)
|
||||
if err == nil {
|
||||
return cookie.Value, true
|
||||
}
|
||||
}
|
||||
|
||||
// get from Header
|
||||
reqToken := r.Header.Get("Authorization")
|
||||
splitToken := strings.Split(reqToken, "Bearer ")
|
||||
if len(splitToken) == 2 {
|
||||
return strings.TrimSpace(splitToken[1]), true
|
||||
}
|
||||
|
||||
// get from URL
|
||||
token := r.URL.Query().Get("token")
|
||||
if token != "" {
|
||||
return token, true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
470
server/internal/session/manager.go
Normal file
470
server/internal/session/manager.go
Normal file
@ -0,0 +1,470 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/kataras/go-events"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/internal/config"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
func New(config *config.Session) *SessionManagerCtx {
|
||||
manager := &SessionManagerCtx{
|
||||
logger: log.With().Str("module", "session").Logger(),
|
||||
config: config,
|
||||
settings: types.Settings{
|
||||
PrivateMode: config.PrivateMode,
|
||||
LockedLogins: config.LockedLogins,
|
||||
LockedControls: config.LockedControls || config.ControlProtection,
|
||||
ControlProtection: config.ControlProtection,
|
||||
ImplicitHosting: config.ImplicitHosting,
|
||||
InactiveCursors: config.InactiveCursors,
|
||||
MercifulReconnect: config.MercifulReconnect,
|
||||
},
|
||||
tokens: make(map[string]string),
|
||||
sessions: make(map[string]*SessionCtx),
|
||||
cursors: make(map[types.Session][]types.Cursor),
|
||||
emmiter: events.New(),
|
||||
}
|
||||
|
||||
// create API session
|
||||
if config.APIToken != "" {
|
||||
manager.apiSession = &SessionCtx{
|
||||
id: "API",
|
||||
token: config.APIToken,
|
||||
manager: manager,
|
||||
logger: manager.logger.With().Str("session_id", "API").Logger(),
|
||||
profile: types.MemberProfile{
|
||||
Name: "API Session",
|
||||
IsAdmin: true,
|
||||
CanLogin: true,
|
||||
CanConnect: false,
|
||||
CanWatch: true,
|
||||
CanHost: true,
|
||||
CanAccessClipboard: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// try to load sessions from file
|
||||
manager.load()
|
||||
|
||||
return manager
|
||||
}
|
||||
|
||||
type SessionManagerCtx struct {
|
||||
logger zerolog.Logger
|
||||
config *config.Session
|
||||
|
||||
settings types.Settings
|
||||
settingsMu sync.Mutex
|
||||
|
||||
tokens map[string]string
|
||||
sessions map[string]*SessionCtx
|
||||
sessionsMu sync.Mutex
|
||||
|
||||
hostId atomic.Value
|
||||
|
||||
cursors map[types.Session][]types.Cursor
|
||||
cursorsMu sync.Mutex
|
||||
|
||||
emmiter events.EventEmmiter
|
||||
apiSession *SessionCtx
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) Create(id string, profile types.MemberProfile) (types.Session, string, error) {
|
||||
token, err := utils.NewUID(64)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
manager.sessionsMu.Lock()
|
||||
if _, ok := manager.sessions[id]; ok {
|
||||
manager.sessionsMu.Unlock()
|
||||
return nil, "", types.ErrSessionAlreadyExists
|
||||
}
|
||||
|
||||
if _, ok := manager.tokens[token]; ok {
|
||||
manager.sessionsMu.Unlock()
|
||||
return nil, "", errors.New("session token already exists")
|
||||
}
|
||||
|
||||
session := &SessionCtx{
|
||||
id: id,
|
||||
token: token,
|
||||
manager: manager,
|
||||
logger: manager.logger.With().Str("session_id", id).Logger(),
|
||||
profile: profile,
|
||||
}
|
||||
|
||||
manager.tokens[token] = id
|
||||
manager.sessions[id] = session
|
||||
manager.sessionsMu.Unlock()
|
||||
|
||||
manager.emmiter.Emit("created", session)
|
||||
manager.save()
|
||||
|
||||
return session, token, nil
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) Update(id string, profile types.MemberProfile) error {
|
||||
manager.sessionsMu.Lock()
|
||||
|
||||
session, ok := manager.sessions[id]
|
||||
if !ok {
|
||||
manager.sessionsMu.Unlock()
|
||||
return types.ErrSessionNotFound
|
||||
}
|
||||
|
||||
old := session.profile
|
||||
session.profile = profile
|
||||
manager.sessionsMu.Unlock()
|
||||
|
||||
manager.emmiter.Emit("profile_changed", session, profile, old)
|
||||
manager.save()
|
||||
|
||||
session.profileChanged()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) Delete(id string) error {
|
||||
manager.sessionsMu.Lock()
|
||||
session, ok := manager.sessions[id]
|
||||
if !ok {
|
||||
manager.sessionsMu.Unlock()
|
||||
return types.ErrSessionNotFound
|
||||
}
|
||||
|
||||
delete(manager.tokens, session.token)
|
||||
delete(manager.sessions, id)
|
||||
manager.sessionsMu.Unlock()
|
||||
|
||||
if session.State().IsConnected {
|
||||
session.DestroyWebSocketPeer("session deleted")
|
||||
}
|
||||
|
||||
if session.State().IsWatching {
|
||||
session.GetWebRTCPeer().Destroy()
|
||||
}
|
||||
|
||||
manager.emmiter.Emit("deleted", session)
|
||||
manager.save()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) Disconnect(id string) error {
|
||||
manager.sessionsMu.Lock()
|
||||
session, ok := manager.sessions[id]
|
||||
if !ok {
|
||||
manager.sessionsMu.Unlock()
|
||||
return types.ErrSessionNotFound
|
||||
}
|
||||
manager.sessionsMu.Unlock()
|
||||
|
||||
if session.State().IsConnected {
|
||||
session.DestroyWebSocketPeer("session disconnected")
|
||||
}
|
||||
|
||||
if session.State().IsWatching {
|
||||
session.GetWebRTCPeer().Destroy()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) Get(id string) (types.Session, bool) {
|
||||
manager.sessionsMu.Lock()
|
||||
defer manager.sessionsMu.Unlock()
|
||||
|
||||
session, ok := manager.sessions[id]
|
||||
return session, ok
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) GetByToken(token string) (types.Session, bool) {
|
||||
manager.sessionsMu.Lock()
|
||||
id, ok := manager.tokens[token]
|
||||
manager.sessionsMu.Unlock()
|
||||
|
||||
if ok {
|
||||
return manager.Get(id)
|
||||
}
|
||||
|
||||
// is API session
|
||||
if manager.apiSession != nil && manager.apiSession.token == token {
|
||||
return manager.apiSession, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) List() []types.Session {
|
||||
manager.sessionsMu.Lock()
|
||||
defer manager.sessionsMu.Unlock()
|
||||
|
||||
var sessions []types.Session
|
||||
for _, session := range manager.sessions {
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
|
||||
return sessions
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) Range(f func(session types.Session) bool) {
|
||||
manager.sessionsMu.Lock()
|
||||
defer manager.sessionsMu.Unlock()
|
||||
|
||||
for _, session := range manager.sessions {
|
||||
if !f(session) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---
|
||||
// host
|
||||
// ---
|
||||
|
||||
func (manager *SessionManagerCtx) setHost(session, host types.Session) {
|
||||
var hostId string
|
||||
if host != nil {
|
||||
hostId = host.ID()
|
||||
}
|
||||
|
||||
manager.hostId.Store(hostId)
|
||||
manager.emmiter.Emit("host_changed", session, host)
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) GetHost() (types.Session, bool) {
|
||||
hostId, ok := manager.hostId.Load().(string)
|
||||
if !ok || hostId == "" {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return manager.Get(hostId)
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) isHost(host types.Session) bool {
|
||||
hostId, ok := manager.hostId.Load().(string)
|
||||
return ok && hostId == host.ID()
|
||||
}
|
||||
|
||||
// ---
|
||||
// cursors
|
||||
// ---
|
||||
|
||||
func (manager *SessionManagerCtx) SetCursor(cursor types.Cursor, session types.Session) {
|
||||
manager.cursorsMu.Lock()
|
||||
defer manager.cursorsMu.Unlock()
|
||||
|
||||
list, ok := manager.cursors[session]
|
||||
if !ok {
|
||||
list = []types.Cursor{}
|
||||
}
|
||||
|
||||
list = append(list, cursor)
|
||||
manager.cursors[session] = list
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) PopCursors() map[types.Session][]types.Cursor {
|
||||
manager.cursorsMu.Lock()
|
||||
defer manager.cursorsMu.Unlock()
|
||||
|
||||
cursors := manager.cursors
|
||||
manager.cursors = make(map[types.Session][]types.Cursor)
|
||||
|
||||
return cursors
|
||||
}
|
||||
|
||||
// ---
|
||||
// broadcasts
|
||||
// ---
|
||||
|
||||
func (manager *SessionManagerCtx) Broadcast(event string, payload any, exclude ...string) {
|
||||
for _, session := range manager.List() {
|
||||
if !session.State().IsConnected {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(exclude) > 0 {
|
||||
if in, _ := utils.ArrayIn(session.ID(), exclude); in {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
session.Send(event, payload)
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) AdminBroadcast(event string, payload any, exclude ...string) {
|
||||
for _, session := range manager.List() {
|
||||
if !session.State().IsConnected || !session.Profile().IsAdmin {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(exclude) > 0 {
|
||||
if in, _ := utils.ArrayIn(session.ID(), exclude); in {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
session.Send(event, payload)
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) InactiveCursorsBroadcast(event string, payload any, exclude ...string) {
|
||||
for _, session := range manager.List() {
|
||||
if !session.State().IsConnected || !session.Profile().CanSeeInactiveCursors {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(exclude) > 0 {
|
||||
if in, _ := utils.ArrayIn(session.ID(), exclude); in {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
session.Send(event, payload)
|
||||
}
|
||||
}
|
||||
|
||||
// ---
|
||||
// events
|
||||
// ---
|
||||
|
||||
func (manager *SessionManagerCtx) OnCreated(listener func(session types.Session)) {
|
||||
manager.emmiter.On("created", func(payload ...any) {
|
||||
listener(payload[0].(*SessionCtx))
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) OnDeleted(listener func(session types.Session)) {
|
||||
manager.emmiter.On("deleted", func(payload ...any) {
|
||||
listener(payload[0].(*SessionCtx))
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) OnConnected(listener func(session types.Session)) {
|
||||
manager.emmiter.On("connected", func(payload ...any) {
|
||||
listener(payload[0].(*SessionCtx))
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) OnDisconnected(listener func(session types.Session)) {
|
||||
manager.emmiter.On("disconnected", func(payload ...any) {
|
||||
listener(payload[0].(*SessionCtx))
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) OnProfileChanged(listener func(session types.Session, new, old types.MemberProfile)) {
|
||||
manager.emmiter.On("profile_changed", func(payload ...any) {
|
||||
listener(payload[0].(*SessionCtx), payload[1].(types.MemberProfile), payload[2].(types.MemberProfile))
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) OnStateChanged(listener func(session types.Session)) {
|
||||
manager.emmiter.On("state_changed", func(payload ...any) {
|
||||
listener(payload[0].(*SessionCtx))
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) OnHostChanged(listener func(session, host types.Session)) {
|
||||
manager.emmiter.On("host_changed", func(payload ...any) {
|
||||
if payload[1] == nil {
|
||||
listener(payload[0].(*SessionCtx), nil)
|
||||
} else {
|
||||
listener(payload[0].(*SessionCtx), payload[1].(*SessionCtx))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) OnSettingsChanged(listener func(session types.Session, new, old types.Settings)) {
|
||||
manager.emmiter.On("settings_changed", func(payload ...any) {
|
||||
listener(payload[0].(types.Session), payload[1].(types.Settings), payload[2].(types.Settings))
|
||||
})
|
||||
}
|
||||
|
||||
// ---
|
||||
// settings
|
||||
// ---
|
||||
|
||||
func (manager *SessionManagerCtx) UpdateSettingsFunc(session types.Session, f func(settings *types.Settings) bool) {
|
||||
manager.settingsMu.Lock()
|
||||
new := manager.settings
|
||||
if f(&new) {
|
||||
old := manager.settings
|
||||
manager.settings = new
|
||||
manager.settingsMu.Unlock()
|
||||
manager.updateSettings(session, new, old)
|
||||
return
|
||||
}
|
||||
manager.settingsMu.Unlock()
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) updateSettings(session types.Session, new, old types.Settings) {
|
||||
// if private mode changed
|
||||
if old.PrivateMode != new.PrivateMode {
|
||||
// update webrtc paused state for all sessions
|
||||
for _, s := range manager.List() {
|
||||
enabled := s.PrivateModeEnabled()
|
||||
|
||||
// if session had control, it must release it
|
||||
if enabled && s.IsHost() {
|
||||
session.ClearHost()
|
||||
}
|
||||
|
||||
// its webrtc connection will be paused or unpaused
|
||||
if webrtcPeer := s.GetWebRTCPeer(); webrtcPeer != nil {
|
||||
webrtcPeer.SetPaused(enabled)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if control protection changed and controls are not locked
|
||||
if old.ControlProtection != new.ControlProtection && new.ControlProtection && !new.LockedControls {
|
||||
// if there is no admin, lock controls
|
||||
hasAdmin := false
|
||||
manager.Range(func(session types.Session) bool {
|
||||
if session.Profile().IsAdmin && session.State().IsConnected {
|
||||
hasAdmin = true
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if !hasAdmin {
|
||||
manager.settingsMu.Lock()
|
||||
manager.settings.LockedControls = true
|
||||
new.LockedControls = true
|
||||
manager.settingsMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// if contols have been locked
|
||||
if old.LockedControls != new.LockedControls && new.LockedControls {
|
||||
// if the host is not admin, it must release controls
|
||||
host, hasHost := manager.GetHost()
|
||||
if hasHost && !host.Profile().IsAdmin {
|
||||
session.ClearHost()
|
||||
}
|
||||
}
|
||||
|
||||
manager.emmiter.Emit("settings_changed", session, new, old)
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) Settings() types.Settings {
|
||||
manager.settingsMu.Lock()
|
||||
defer manager.settingsMu.Unlock()
|
||||
|
||||
return manager.settings
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) CookieEnabled() bool {
|
||||
return manager.config.CookieEnabled
|
||||
}
|
97
server/internal/session/serialize.go
Normal file
97
server/internal/session/serialize.go
Normal file
@ -0,0 +1,97 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
)
|
||||
|
||||
func (manager *SessionManagerCtx) save() {
|
||||
if manager.config.File == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// serialize sessions
|
||||
sessions := make([]types.SessionProfile, 0, len(manager.sessions))
|
||||
for _, session := range manager.sessions {
|
||||
sessions = append(sessions, types.SessionProfile{
|
||||
Id: session.id,
|
||||
Token: session.token,
|
||||
Profile: session.profile,
|
||||
})
|
||||
}
|
||||
|
||||
// convert to json
|
||||
data, err := json.Marshal(sessions)
|
||||
if err != nil {
|
||||
manager.logger.Error().Err(err).Msg("failed to marshal sessions")
|
||||
return
|
||||
}
|
||||
|
||||
// write to file
|
||||
err = os.WriteFile(manager.config.File, data, 0644)
|
||||
if err != nil {
|
||||
manager.logger.Error().Err(err).
|
||||
Str("file", manager.config.File).
|
||||
Msg("failed to write sessions to a file")
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *SessionManagerCtx) load() {
|
||||
if manager.config.File == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// read file
|
||||
data, err := os.ReadFile(manager.config.File)
|
||||
if err != nil {
|
||||
// if file does not exist
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
manager.logger.Info().
|
||||
Str("file", manager.config.File).
|
||||
Msg("sessions file does not exist")
|
||||
return
|
||||
}
|
||||
manager.logger.Error().Err(err).
|
||||
Str("file", manager.config.File).
|
||||
Msg("failed to read sessions from a file")
|
||||
return
|
||||
}
|
||||
|
||||
// if file is empty
|
||||
if len(data) == 0 {
|
||||
manager.logger.Info().
|
||||
Str("file", manager.config.File).
|
||||
Msg("sessions file is empty")
|
||||
return
|
||||
}
|
||||
|
||||
// deserialize sessions
|
||||
sessions := make([]types.SessionProfile, 0)
|
||||
err = json.Unmarshal(data, &sessions)
|
||||
if err != nil {
|
||||
manager.logger.Error().Err(err).Msg("failed to unmarshal sessions")
|
||||
return
|
||||
}
|
||||
|
||||
// create sessions
|
||||
manager.sessionsMu.Lock()
|
||||
for _, session := range sessions {
|
||||
manager.tokens[session.Token] = session.Id
|
||||
manager.sessions[session.Id] = &SessionCtx{
|
||||
id: session.Id,
|
||||
token: session.Token,
|
||||
manager: manager,
|
||||
logger: manager.logger.With().Str("session_id", session.Id).Logger(),
|
||||
profile: session.Profile,
|
||||
}
|
||||
}
|
||||
manager.sessionsMu.Unlock()
|
||||
|
||||
manager.logger.Info().
|
||||
Int("sessions", len(sessions)).
|
||||
Str("file", manager.config.File).
|
||||
Msg("loaded sessions from a file")
|
||||
}
|
283
server/internal/session/session.go
Normal file
283
server/internal/session/session.go
Normal file
@ -0,0 +1,283 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
)
|
||||
|
||||
// client is expected to reconnect within 5 second
|
||||
// if some unexpected websocket disconnect happens
|
||||
const WS_DELAYED_DURATION = 5 * time.Second
|
||||
|
||||
type SessionCtx struct {
|
||||
id string
|
||||
token string
|
||||
logger zerolog.Logger
|
||||
manager *SessionManagerCtx
|
||||
profile types.MemberProfile
|
||||
state types.SessionState
|
||||
|
||||
websocketPeer types.WebSocketPeer
|
||||
websocketMu sync.Mutex
|
||||
|
||||
// websocket delayed set connected events
|
||||
wsDelayedMu sync.Mutex
|
||||
wsDelayedTimer *time.Timer
|
||||
|
||||
webrtcPeer types.WebRTCPeer
|
||||
webrtcMu sync.Mutex
|
||||
}
|
||||
|
||||
func (session *SessionCtx) ID() string {
|
||||
return session.id
|
||||
}
|
||||
|
||||
func (session *SessionCtx) Profile() types.MemberProfile {
|
||||
return session.profile
|
||||
}
|
||||
|
||||
func (session *SessionCtx) profileChanged() {
|
||||
if !session.profile.CanHost && session.IsHost() {
|
||||
session.ClearHost()
|
||||
}
|
||||
|
||||
if (!session.profile.CanConnect || !session.profile.CanLogin || !session.profile.CanWatch) && session.state.IsWatching {
|
||||
session.GetWebRTCPeer().Destroy()
|
||||
}
|
||||
|
||||
if (!session.profile.CanConnect || !session.profile.CanLogin) && session.state.IsConnected {
|
||||
session.DestroyWebSocketPeer("profile changed")
|
||||
}
|
||||
|
||||
// update webrtc paused state
|
||||
if webrtcPeer := session.GetWebRTCPeer(); webrtcPeer != nil {
|
||||
webrtcPeer.SetPaused(session.PrivateModeEnabled())
|
||||
}
|
||||
}
|
||||
|
||||
func (session *SessionCtx) State() types.SessionState {
|
||||
return session.state
|
||||
}
|
||||
|
||||
func (session *SessionCtx) IsHost() bool {
|
||||
return session.manager.isHost(session)
|
||||
}
|
||||
|
||||
func (session *SessionCtx) SetAsHost() {
|
||||
session.manager.setHost(session, session)
|
||||
}
|
||||
|
||||
func (session *SessionCtx) SetAsHostBy(host types.Session) {
|
||||
session.manager.setHost(session, host)
|
||||
}
|
||||
|
||||
func (session *SessionCtx) ClearHost() {
|
||||
session.manager.setHost(session, nil)
|
||||
}
|
||||
|
||||
func (session *SessionCtx) PrivateModeEnabled() bool {
|
||||
return session.manager.Settings().PrivateMode && !session.profile.IsAdmin
|
||||
}
|
||||
|
||||
func (session *SessionCtx) SetCursor(cursor types.Cursor) {
|
||||
if session.manager.Settings().InactiveCursors && session.profile.SendsInactiveCursor {
|
||||
session.manager.SetCursor(cursor, session)
|
||||
}
|
||||
}
|
||||
|
||||
// ---
|
||||
// websocket
|
||||
// ---
|
||||
|
||||
// Connect WebSocket peer sets current peer and emits connected event. It also destroys the
|
||||
// previous peer, if there was one. If the peer is already set, it will be ignored.
|
||||
func (session *SessionCtx) ConnectWebSocketPeer(websocketPeer types.WebSocketPeer) {
|
||||
session.websocketMu.Lock()
|
||||
isCurrentPeer := websocketPeer == session.websocketPeer
|
||||
session.websocketPeer, websocketPeer = websocketPeer, session.websocketPeer
|
||||
session.websocketMu.Unlock()
|
||||
|
||||
// ignore if already set
|
||||
if isCurrentPeer {
|
||||
return
|
||||
}
|
||||
|
||||
session.logger.Info().Msg("set websocket connected")
|
||||
|
||||
// update state
|
||||
now := time.Now()
|
||||
session.state.IsConnected = true
|
||||
session.state.ConnectedSince = &now
|
||||
session.state.NotConnectedSince = nil
|
||||
|
||||
session.manager.emmiter.Emit("connected", session)
|
||||
|
||||
// if there is a previous peer, destroy it
|
||||
if websocketPeer != nil {
|
||||
websocketPeer.Destroy("connection replaced")
|
||||
}
|
||||
}
|
||||
|
||||
// Disconnect WebSocket peer sets current peer to nil and emits disconnected event. It also
|
||||
// allows for a delayed disconnect. That means, the peer will not be disconnected immediately,
|
||||
// but after a delay. If the peer is connected again before the delay, the disconnect will be
|
||||
// cancelled.
|
||||
//
|
||||
// If the peer is not the current peer or the peer is nil, it will be ignored.
|
||||
func (session *SessionCtx) DisconnectWebSocketPeer(websocketPeer types.WebSocketPeer, delayed bool) {
|
||||
session.websocketMu.Lock()
|
||||
isCurrentPeer := websocketPeer == session.websocketPeer && websocketPeer != nil
|
||||
session.websocketMu.Unlock()
|
||||
|
||||
// ignore if not current peer
|
||||
if !isCurrentPeer {
|
||||
return
|
||||
}
|
||||
|
||||
//
|
||||
// ws delayed
|
||||
//
|
||||
|
||||
var wsDelayedTimer *time.Timer
|
||||
|
||||
if delayed {
|
||||
wsDelayedTimer = time.AfterFunc(WS_DELAYED_DURATION, func() {
|
||||
session.DisconnectWebSocketPeer(websocketPeer, false)
|
||||
})
|
||||
}
|
||||
|
||||
session.wsDelayedMu.Lock()
|
||||
if session.wsDelayedTimer != nil {
|
||||
session.wsDelayedTimer.Stop()
|
||||
}
|
||||
session.wsDelayedTimer = wsDelayedTimer
|
||||
session.wsDelayedMu.Unlock()
|
||||
|
||||
if delayed {
|
||||
session.logger.Info().Msg("delayed websocket disconnected")
|
||||
return
|
||||
}
|
||||
|
||||
//
|
||||
// not delayed
|
||||
//
|
||||
|
||||
session.logger.Info().Msg("set websocket disconnected")
|
||||
|
||||
now := time.Now()
|
||||
session.state.IsConnected = false
|
||||
session.state.ConnectedSince = nil
|
||||
session.state.NotConnectedSince = &now
|
||||
|
||||
session.manager.emmiter.Emit("disconnected", session)
|
||||
|
||||
session.websocketMu.Lock()
|
||||
if websocketPeer == session.websocketPeer {
|
||||
session.websocketPeer = nil
|
||||
}
|
||||
session.websocketMu.Unlock()
|
||||
}
|
||||
|
||||
// Destroy WebSocket peer disconnects the peer and destroys it. It ensures that the peer is
|
||||
// disconnected immediately even though normal flow would be to disconnect it delayed.
|
||||
func (session *SessionCtx) DestroyWebSocketPeer(reason string) {
|
||||
session.websocketMu.Lock()
|
||||
peer := session.websocketPeer
|
||||
session.websocketMu.Unlock()
|
||||
|
||||
if peer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// disconnect peer first, so that it is not used anymore
|
||||
session.DisconnectWebSocketPeer(peer, false)
|
||||
|
||||
// destroy it afterwards
|
||||
peer.Destroy(reason)
|
||||
}
|
||||
|
||||
// Send event to websocket peer.
|
||||
func (session *SessionCtx) Send(event string, payload any) {
|
||||
session.websocketMu.Lock()
|
||||
peer := session.websocketPeer
|
||||
session.websocketMu.Unlock()
|
||||
|
||||
if peer != nil {
|
||||
peer.Send(event, payload)
|
||||
}
|
||||
}
|
||||
|
||||
// ---
|
||||
// webrtc
|
||||
// ---
|
||||
|
||||
// Set webrtc peer and destroy the old one, if there is old one.
|
||||
func (session *SessionCtx) SetWebRTCPeer(webrtcPeer types.WebRTCPeer) {
|
||||
session.webrtcMu.Lock()
|
||||
session.webrtcPeer, webrtcPeer = webrtcPeer, session.webrtcPeer
|
||||
session.webrtcMu.Unlock()
|
||||
|
||||
if webrtcPeer != nil && webrtcPeer != session.webrtcPeer {
|
||||
webrtcPeer.Destroy()
|
||||
}
|
||||
}
|
||||
|
||||
// Set if current webrtc peer is connected or not. Since there might be lefover calls from
|
||||
// webrtc peer, that are not used anymore, we need to check if the webrtc peer is still the
|
||||
// same as the one we are setting the connected state for.
|
||||
//
|
||||
// If webrtc peer is disconnected, we don't expect it to be reconnected, so we set it to nil
|
||||
// and send a signal close to the client. New connection is expected to use a new webrtc peer.
|
||||
func (session *SessionCtx) SetWebRTCConnected(webrtcPeer types.WebRTCPeer, connected bool) {
|
||||
session.webrtcMu.Lock()
|
||||
isCurrentPeer := webrtcPeer == session.webrtcPeer
|
||||
session.webrtcMu.Unlock()
|
||||
|
||||
if !isCurrentPeer {
|
||||
return
|
||||
}
|
||||
|
||||
session.logger.Info().
|
||||
Bool("connected", connected).
|
||||
Msg("set webrtc connected")
|
||||
|
||||
// update state
|
||||
session.state.IsWatching = connected
|
||||
if now := time.Now(); connected {
|
||||
session.state.WatchingSince = &now
|
||||
session.state.NotWatchingSince = nil
|
||||
} else {
|
||||
session.state.WatchingSince = nil
|
||||
session.state.NotWatchingSince = &now
|
||||
}
|
||||
|
||||
session.manager.emmiter.Emit("state_changed", session)
|
||||
|
||||
if connected {
|
||||
return
|
||||
}
|
||||
|
||||
session.webrtcMu.Lock()
|
||||
isCurrentPeer = webrtcPeer == session.webrtcPeer
|
||||
if isCurrentPeer {
|
||||
session.webrtcPeer = nil
|
||||
}
|
||||
session.webrtcMu.Unlock()
|
||||
|
||||
if isCurrentPeer {
|
||||
session.Send(event.SIGNAL_CLOSE, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// Get current WebRTC peer. Nil if not connected.
|
||||
func (session *SessionCtx) GetWebRTCPeer() types.WebRTCPeer {
|
||||
session.webrtcMu.Lock()
|
||||
defer session.webrtcMu.Unlock()
|
||||
|
||||
return session.webrtcPeer
|
||||
}
|
168
server/internal/webrtc/cursor/image.go
Normal file
168
server/internal/webrtc/cursor/image.go
Normal file
@ -0,0 +1,168 @@
|
||||
package cursor
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type ImageListener interface {
|
||||
SendCursorImage(cur *types.CursorImage, img []byte) error
|
||||
}
|
||||
|
||||
type Image interface {
|
||||
Start()
|
||||
Shutdown()
|
||||
GetCurrent() (cur *types.CursorImage, img []byte, err error)
|
||||
AddListener(listener ImageListener)
|
||||
RemoveListener(listener ImageListener)
|
||||
}
|
||||
|
||||
type imageEntry struct {
|
||||
*types.CursorImage
|
||||
ImagePNG []byte
|
||||
}
|
||||
|
||||
type image struct {
|
||||
logger zerolog.Logger
|
||||
desktop types.DesktopManager
|
||||
|
||||
listeners map[uintptr]ImageListener
|
||||
listenersMu sync.RWMutex
|
||||
|
||||
cache map[uint64]*imageEntry
|
||||
cacheMu sync.RWMutex
|
||||
current *imageEntry
|
||||
maxSerial uint64
|
||||
}
|
||||
|
||||
func NewImage(logger zerolog.Logger, desktop types.DesktopManager) *image {
|
||||
return &image{
|
||||
logger: logger.With().Str("submodule", "cursor-image").Logger(),
|
||||
desktop: desktop,
|
||||
listeners: map[uintptr]ImageListener{},
|
||||
cache: map[uint64]*imageEntry{},
|
||||
maxSerial: 300, // TODO: Cleanup?
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *image) Start() {
|
||||
manager.desktop.OnCursorChanged(func(serial uint64) {
|
||||
entry, err := manager.getCached(serial)
|
||||
if err != nil {
|
||||
manager.logger.Err(err).Msg("failed to get cursor image")
|
||||
return
|
||||
}
|
||||
|
||||
manager.current = entry
|
||||
|
||||
manager.listenersMu.RLock()
|
||||
for _, l := range manager.listeners {
|
||||
if err := l.SendCursorImage(entry.CursorImage, entry.ImagePNG); err != nil {
|
||||
manager.logger.Err(err).Msg("failed to set cursor image")
|
||||
}
|
||||
}
|
||||
manager.listenersMu.RUnlock()
|
||||
})
|
||||
|
||||
manager.logger.Info().Msg("starting")
|
||||
}
|
||||
|
||||
func (manager *image) Shutdown() {
|
||||
manager.logger.Info().Msg("shutdown")
|
||||
|
||||
manager.listenersMu.Lock()
|
||||
for key := range manager.listeners {
|
||||
delete(manager.listeners, key)
|
||||
}
|
||||
manager.listenersMu.Unlock()
|
||||
}
|
||||
|
||||
func (manager *image) getCached(serial uint64) (*imageEntry, error) {
|
||||
// zero means no serial available
|
||||
if serial == 0 || serial > manager.maxSerial {
|
||||
manager.logger.Debug().Uint64("serial", serial).Msg("cache bypass")
|
||||
return manager.fetchEntry()
|
||||
}
|
||||
|
||||
manager.cacheMu.RLock()
|
||||
entry, ok := manager.cache[serial]
|
||||
manager.cacheMu.RUnlock()
|
||||
|
||||
if ok {
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
manager.logger.Debug().Uint64("serial", serial).Msg("cache miss")
|
||||
|
||||
entry, err := manager.fetchEntry()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
manager.cacheMu.Lock()
|
||||
manager.cache[entry.Serial] = entry
|
||||
manager.cacheMu.Unlock()
|
||||
|
||||
if entry.Serial != serial {
|
||||
manager.logger.Warn().
|
||||
Uint64("expected-serial", serial).
|
||||
Uint64("received-serial", entry.Serial).
|
||||
Msg("serial mismatch")
|
||||
}
|
||||
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func (manager *image) GetCurrent() (cur *types.CursorImage, img []byte, err error) {
|
||||
if manager.current != nil {
|
||||
return manager.current.CursorImage, manager.current.ImagePNG, nil
|
||||
}
|
||||
|
||||
entry, err := manager.fetchEntry()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
manager.current = entry
|
||||
return entry.CursorImage, entry.ImagePNG, nil
|
||||
}
|
||||
|
||||
func (manager *image) AddListener(listener ImageListener) {
|
||||
manager.listenersMu.Lock()
|
||||
defer manager.listenersMu.Unlock()
|
||||
|
||||
if listener != nil {
|
||||
ptr := reflect.ValueOf(listener).Pointer()
|
||||
manager.listeners[ptr] = listener
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *image) RemoveListener(listener ImageListener) {
|
||||
manager.listenersMu.Lock()
|
||||
defer manager.listenersMu.Unlock()
|
||||
|
||||
if listener != nil {
|
||||
ptr := reflect.ValueOf(listener).Pointer()
|
||||
delete(manager.listeners, ptr)
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *image) fetchEntry() (*imageEntry, error) {
|
||||
cur := manager.desktop.GetCursorImage()
|
||||
|
||||
img, err := utils.CreatePNGImage(cur.Image)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cur.Image = nil // free memory
|
||||
|
||||
return &imageEntry{
|
||||
CursorImage: cur,
|
||||
ImagePNG: img,
|
||||
}, nil
|
||||
}
|
74
server/internal/webrtc/cursor/position.go
Normal file
74
server/internal/webrtc/cursor/position.go
Normal file
@ -0,0 +1,74 @@
|
||||
package cursor
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type PositionListener interface {
|
||||
SendCursorPosition(x, y int) error
|
||||
}
|
||||
|
||||
type Position interface {
|
||||
Shutdown()
|
||||
Set(x, y int)
|
||||
AddListener(listener PositionListener)
|
||||
RemoveListener(listener PositionListener)
|
||||
}
|
||||
|
||||
type position struct {
|
||||
logger zerolog.Logger
|
||||
|
||||
listeners map[uintptr]PositionListener
|
||||
listenersMu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewPosition(logger zerolog.Logger) *position {
|
||||
return &position{
|
||||
logger: logger.With().Str("submodule", "cursor-position").Logger(),
|
||||
listeners: map[uintptr]PositionListener{},
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *position) Shutdown() {
|
||||
manager.logger.Info().Msg("shutdown")
|
||||
|
||||
manager.listenersMu.Lock()
|
||||
for key := range manager.listeners {
|
||||
delete(manager.listeners, key)
|
||||
}
|
||||
manager.listenersMu.Unlock()
|
||||
}
|
||||
|
||||
func (manager *position) Set(x, y int) {
|
||||
manager.listenersMu.RLock()
|
||||
defer manager.listenersMu.RUnlock()
|
||||
|
||||
for _, l := range manager.listeners {
|
||||
if err := l.SendCursorPosition(x, y); err != nil {
|
||||
manager.logger.Err(err).Msg("failed to set cursor position")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *position) AddListener(listener PositionListener) {
|
||||
manager.listenersMu.Lock()
|
||||
defer manager.listenersMu.Unlock()
|
||||
|
||||
if listener != nil {
|
||||
ptr := reflect.ValueOf(listener).Pointer()
|
||||
manager.listeners[ptr] = listener
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *position) RemoveListener(listener PositionListener) {
|
||||
manager.listenersMu.Lock()
|
||||
defer manager.listenersMu.Unlock()
|
||||
|
||||
if listener != nil {
|
||||
ptr := reflect.ValueOf(listener).Pointer()
|
||||
delete(manager.listeners, ptr)
|
||||
}
|
||||
}
|
205
server/internal/webrtc/handler.go
Normal file
205
server/internal/webrtc/handler.go
Normal file
@ -0,0 +1,205 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/demodesk/neko/internal/webrtc/payload"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func (manager *WebRTCManagerCtx) handle(
|
||||
logger zerolog.Logger, data []byte,
|
||||
dataChannel *webrtc.DataChannel,
|
||||
session types.Session,
|
||||
) error {
|
||||
isHost := session.IsHost()
|
||||
|
||||
//
|
||||
// parse header
|
||||
//
|
||||
|
||||
buffer := bytes.NewBuffer(data)
|
||||
|
||||
header := &payload.Header{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//
|
||||
// parse body
|
||||
//
|
||||
|
||||
// handle cursor move event
|
||||
if header.Event == payload.OP_MOVE {
|
||||
payload := &payload.Move{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
x, y := int(payload.X), int(payload.Y)
|
||||
if isHost {
|
||||
// handle active cursor movement
|
||||
manager.desktop.Move(x, y)
|
||||
manager.curPosition.Set(x, y)
|
||||
} else {
|
||||
// handle inactive cursor movement
|
||||
session.SetCursor(types.Cursor{
|
||||
X: x,
|
||||
Y: y,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
} else if header.Event == payload.OP_PING {
|
||||
ping := &payload.Ping{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, ping); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// create pong header
|
||||
header := payload.Header{
|
||||
Event: payload.OP_PONG,
|
||||
Length: 19,
|
||||
}
|
||||
|
||||
// generate server timestamp
|
||||
serverTs := uint64(time.Now().UnixMilli())
|
||||
|
||||
// generate pong payload
|
||||
pong := payload.Pong{
|
||||
Ping: *ping,
|
||||
ServerTs1: uint32(serverTs / math.MaxUint32),
|
||||
ServerTs2: uint32(serverTs % math.MaxUint32),
|
||||
}
|
||||
|
||||
buffer := &bytes.Buffer{}
|
||||
|
||||
if err := binary.Write(buffer, binary.BigEndian, header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(buffer, binary.BigEndian, pong); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return dataChannel.Send(buffer.Bytes())
|
||||
}
|
||||
|
||||
// continue only if session is host
|
||||
if !isHost {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch header.Event {
|
||||
case payload.OP_SCROLL:
|
||||
// TODO: remove this once the client is fixed
|
||||
if header.Length == 4 {
|
||||
payload := &payload.Scroll_Old{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manager.desktop.Scroll(int(payload.X), int(payload.Y), false)
|
||||
logger.Trace().
|
||||
Int16("x", payload.X).
|
||||
Int16("y", payload.Y).
|
||||
Msg("scroll")
|
||||
} else {
|
||||
payload := &payload.Scroll{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manager.desktop.Scroll(int(payload.DeltaX), int(payload.DeltaY), payload.ControlKey)
|
||||
logger.Trace().
|
||||
Int16("deltaX", payload.DeltaX).
|
||||
Int16("deltaY", payload.DeltaY).
|
||||
Bool("controlKey", payload.ControlKey).
|
||||
Msg("scroll")
|
||||
}
|
||||
case payload.OP_KEY_DOWN:
|
||||
payload := &payload.Key{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := manager.desktop.KeyDown(payload.Key); err != nil {
|
||||
logger.Warn().Err(err).Uint32("key", payload.Key).Msg("key down failed")
|
||||
} else {
|
||||
logger.Trace().Uint32("key", payload.Key).Msg("key down")
|
||||
}
|
||||
case payload.OP_KEY_UP:
|
||||
payload := &payload.Key{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := manager.desktop.KeyUp(payload.Key); err != nil {
|
||||
logger.Warn().Err(err).Uint32("key", payload.Key).Msg("key up failed")
|
||||
} else {
|
||||
logger.Trace().Uint32("key", payload.Key).Msg("key up")
|
||||
}
|
||||
case payload.OP_BTN_DOWN:
|
||||
payload := &payload.Key{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := manager.desktop.ButtonDown(payload.Key); err != nil {
|
||||
logger.Warn().Err(err).Uint32("key", payload.Key).Msg("button down failed")
|
||||
} else {
|
||||
logger.Trace().Uint32("key", payload.Key).Msg("button down")
|
||||
}
|
||||
case payload.OP_BTN_UP:
|
||||
payload := &payload.Key{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := manager.desktop.ButtonUp(payload.Key); err != nil {
|
||||
logger.Warn().Err(err).Uint32("key", payload.Key).Msg("button up failed")
|
||||
} else {
|
||||
logger.Trace().Uint32("key", payload.Key).Msg("button up")
|
||||
}
|
||||
case payload.OP_TOUCH_BEGIN:
|
||||
payload := &payload.Touch{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := manager.desktop.TouchBegin(payload.TouchId, int(payload.X), int(payload.Y), payload.Pressure); err != nil {
|
||||
logger.Warn().Err(err).Uint32("touchId", payload.TouchId).Msg("touch begin failed")
|
||||
} else {
|
||||
logger.Trace().Uint32("touchId", payload.TouchId).Msg("touch begin")
|
||||
}
|
||||
case payload.OP_TOUCH_UPDATE:
|
||||
payload := &payload.Touch{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := manager.desktop.TouchUpdate(payload.TouchId, int(payload.X), int(payload.Y), payload.Pressure); err != nil {
|
||||
logger.Warn().Err(err).Uint32("touchId", payload.TouchId).Msg("touch update failed")
|
||||
} else {
|
||||
logger.Trace().Uint32("touchId", payload.TouchId).Msg("touch update")
|
||||
}
|
||||
case payload.OP_TOUCH_END:
|
||||
payload := &payload.Touch{}
|
||||
if err := binary.Read(buffer, binary.BigEndian, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := manager.desktop.TouchEnd(payload.TouchId, int(payload.X), int(payload.Y), payload.Pressure); err != nil {
|
||||
logger.Warn().Err(err).Uint32("touchId", payload.TouchId).Msg("touch end failed")
|
||||
} else {
|
||||
logger.Trace().Uint32("touchId", payload.TouchId).Msg("touch end")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
576
server/internal/webrtc/manager.go
Normal file
576
server/internal/webrtc/manager.go
Normal file
@ -0,0 +1,576 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/pion/ice/v2"
|
||||
"github.com/pion/interceptor"
|
||||
"github.com/pion/interceptor/pkg/cc"
|
||||
"github.com/pion/interceptor/pkg/gcc"
|
||||
"github.com/pion/rtcp"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/internal/config"
|
||||
"github.com/demodesk/neko/internal/webrtc/cursor"
|
||||
"github.com/demodesk/neko/internal/webrtc/pionlog"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/codec"
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
"github.com/demodesk/neko/pkg/types/message"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
// size of receiving channel used to buffer incoming TCP packets
|
||||
tcpReadChanBufferSize = 50
|
||||
|
||||
// size of buffer used to buffer outgoing TCP packets. Default is 4MB
|
||||
tcpWriteBufferSizeInBytes = 4 * 1024 * 1024
|
||||
|
||||
// the duration without network activity before a Agent is considered disconnected. Default is 5 Seconds
|
||||
disconnectedTimeout = 4 * time.Second
|
||||
|
||||
// the duration without network activity before a Agent is considered failed after disconnected. Default is 25 Seconds
|
||||
failedTimeout = 6 * time.Second
|
||||
|
||||
// how often the ICE Agent sends extra traffic if there is no activity, if media is flowing no traffic will be sent. Default is 2 seconds
|
||||
keepAliveInterval = 2 * time.Second
|
||||
|
||||
// send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval
|
||||
rtcpPLIInterval = 3 * time.Second
|
||||
)
|
||||
|
||||
func New(desktop types.DesktopManager, capture types.CaptureManager, config *config.WebRTC) *WebRTCManagerCtx {
|
||||
logger := log.With().Str("module", "webrtc").Logger()
|
||||
|
||||
configuration := webrtc.Configuration{
|
||||
SDPSemantics: webrtc.SDPSemanticsUnifiedPlan,
|
||||
}
|
||||
|
||||
if !config.ICELite {
|
||||
ICEServers := []webrtc.ICEServer{}
|
||||
for _, server := range config.ICEServersBackend {
|
||||
var credential any
|
||||
if server.Credential != "" {
|
||||
credential = server.Credential
|
||||
} else {
|
||||
credential = false
|
||||
}
|
||||
|
||||
ICEServers = append(ICEServers, webrtc.ICEServer{
|
||||
URLs: server.URLs,
|
||||
Username: server.Username,
|
||||
Credential: credential,
|
||||
})
|
||||
}
|
||||
|
||||
configuration.ICEServers = ICEServers
|
||||
}
|
||||
|
||||
return &WebRTCManagerCtx{
|
||||
logger: logger,
|
||||
config: config,
|
||||
metrics: newMetricsManager(),
|
||||
|
||||
webrtcConfiguration: configuration,
|
||||
|
||||
desktop: desktop,
|
||||
capture: capture,
|
||||
curImage: cursor.NewImage(logger, desktop),
|
||||
curPosition: cursor.NewPosition(logger),
|
||||
}
|
||||
}
|
||||
|
||||
type WebRTCManagerCtx struct {
|
||||
logger zerolog.Logger
|
||||
config *config.WebRTC
|
||||
metrics *metricsManager
|
||||
peerId int32
|
||||
|
||||
desktop types.DesktopManager
|
||||
capture types.CaptureManager
|
||||
curImage cursor.Image
|
||||
curPosition cursor.Position
|
||||
|
||||
webrtcConfiguration webrtc.Configuration
|
||||
|
||||
tcpMux ice.TCPMux
|
||||
udpMux ice.UDPMux
|
||||
|
||||
camStop, micStop *func()
|
||||
}
|
||||
|
||||
func (manager *WebRTCManagerCtx) Start() {
|
||||
manager.curImage.Start()
|
||||
|
||||
logger := pionlog.New(manager.logger)
|
||||
|
||||
// add TCP Mux listener
|
||||
if manager.config.TCPMux > 0 {
|
||||
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||
IP: net.IP{0, 0, 0, 0},
|
||||
Port: manager.config.TCPMux,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
manager.logger.Fatal().Err(err).Msg("unable to setup ice TCP mux")
|
||||
}
|
||||
|
||||
manager.tcpMux = ice.NewTCPMuxDefault(ice.TCPMuxParams{
|
||||
Listener: tcpListener,
|
||||
Logger: logger.NewLogger("ice-tcp"),
|
||||
ReadBufferSize: tcpReadChanBufferSize,
|
||||
WriteBufferSize: tcpWriteBufferSizeInBytes,
|
||||
})
|
||||
}
|
||||
|
||||
// add UDP Mux listener
|
||||
if manager.config.UDPMux > 0 {
|
||||
var err error
|
||||
manager.udpMux, err = ice.NewMultiUDPMuxFromPort(manager.config.UDPMux,
|
||||
ice.UDPMuxFromPortWithLogger(logger.NewLogger("ice-udp")),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
manager.logger.Fatal().Err(err).Msg("unable to setup ice UDP mux")
|
||||
}
|
||||
}
|
||||
|
||||
manager.logger.Info().
|
||||
Bool("icelite", manager.config.ICELite).
|
||||
Bool("icetrickle", manager.config.ICETrickle).
|
||||
Interface("iceservers-frontend", manager.config.ICEServersFrontend).
|
||||
Interface("iceservers-backend", manager.config.ICEServersBackend).
|
||||
Str("nat1to1", strings.Join(manager.config.NAT1To1IPs, ",")).
|
||||
Str("epr", fmt.Sprintf("%d-%d", manager.config.EphemeralMin, manager.config.EphemeralMax)).
|
||||
Int("tcpmux", manager.config.TCPMux).
|
||||
Int("udpmux", manager.config.UDPMux).
|
||||
Msg("webrtc starting")
|
||||
}
|
||||
|
||||
func (manager *WebRTCManagerCtx) Shutdown() error {
|
||||
manager.logger.Info().Msg("shutdown")
|
||||
|
||||
manager.curImage.Shutdown()
|
||||
manager.curPosition.Shutdown()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *WebRTCManagerCtx) ICEServers() []types.ICEServer {
|
||||
return manager.config.ICEServersFrontend
|
||||
}
|
||||
|
||||
func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs []codec.RTPCodec) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) {
|
||||
// create media engine
|
||||
engine := &webrtc.MediaEngine{}
|
||||
for _, codec := range codecs {
|
||||
if err := codec.Register(engine); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// create setting engine
|
||||
settings := webrtc.SettingEngine{
|
||||
LoggerFactory: pionlog.New(logger),
|
||||
}
|
||||
|
||||
settings.DisableMediaEngineCopy(true)
|
||||
settings.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
|
||||
settings.SetNAT1To1IPs(manager.config.NAT1To1IPs, webrtc.ICECandidateTypeHost)
|
||||
settings.SetLite(manager.config.ICELite)
|
||||
// make sure server answer sdp setup as passive, to not force DTLS renegotiation
|
||||
// otherwise iOS renegotiation fails with: Failed to set SSL role for the transport.
|
||||
settings.SetAnsweringDTLSRole(webrtc.DTLSRoleServer)
|
||||
|
||||
var networkType []webrtc.NetworkType
|
||||
|
||||
// udp candidates
|
||||
if manager.udpMux != nil {
|
||||
settings.SetICEUDPMux(manager.udpMux)
|
||||
networkType = append(networkType,
|
||||
webrtc.NetworkTypeUDP4,
|
||||
webrtc.NetworkTypeUDP6,
|
||||
)
|
||||
} else if manager.config.EphemeralMax != 0 {
|
||||
_ = settings.SetEphemeralUDPPortRange(manager.config.EphemeralMin, manager.config.EphemeralMax)
|
||||
networkType = append(networkType,
|
||||
webrtc.NetworkTypeUDP4,
|
||||
webrtc.NetworkTypeUDP6,
|
||||
)
|
||||
}
|
||||
|
||||
// tcp candidates
|
||||
if manager.tcpMux != nil {
|
||||
settings.SetICETCPMux(manager.tcpMux)
|
||||
networkType = append(networkType,
|
||||
webrtc.NetworkTypeTCP4,
|
||||
webrtc.NetworkTypeTCP6,
|
||||
)
|
||||
}
|
||||
|
||||
// enable support for TCP and UDP ICE candidates
|
||||
settings.SetNetworkTypes(networkType)
|
||||
|
||||
// create interceptor registry
|
||||
registry := &interceptor.Registry{}
|
||||
|
||||
// create bandwidth estimator
|
||||
estimatorChan := make(chan cc.BandwidthEstimator, 1)
|
||||
if manager.config.Estimator.Enabled {
|
||||
congestionController, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) {
|
||||
return gcc.NewSendSideBWE(
|
||||
gcc.SendSideBWEInitialBitrate(manager.config.Estimator.InitialBitrate),
|
||||
gcc.SendSideBWEPacer(gcc.NewNoOpPacer()),
|
||||
)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
congestionController.OnNewPeerConnection(func(id string, estimator cc.BandwidthEstimator) {
|
||||
estimatorChan <- estimator
|
||||
})
|
||||
|
||||
registry.Add(congestionController)
|
||||
if err = webrtc.ConfigureTWCCHeaderExtensionSender(engine, registry); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
} else {
|
||||
// no estimator, send nil
|
||||
estimatorChan <- nil
|
||||
}
|
||||
|
||||
if err := webrtc.RegisterDefaultInterceptors(engine, registry); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// create new API
|
||||
api := webrtc.NewAPI(
|
||||
webrtc.WithMediaEngine(engine),
|
||||
webrtc.WithSettingEngine(settings),
|
||||
webrtc.WithInterceptorRegistry(registry),
|
||||
)
|
||||
|
||||
// create new peer connection
|
||||
configuration := manager.webrtcConfiguration
|
||||
connection, err := api.NewPeerConnection(configuration)
|
||||
return connection, <-estimatorChan, err
|
||||
}
|
||||
|
||||
func (manager *WebRTCManagerCtx) CreatePeer(session types.Session) (*webrtc.SessionDescription, types.WebRTCPeer, error) {
|
||||
id := atomic.AddInt32(&manager.peerId, 1)
|
||||
|
||||
// get metrics for session
|
||||
metrics := manager.metrics.getBySession(session)
|
||||
metrics.NewConnection()
|
||||
|
||||
// add session id to logger context
|
||||
logger := manager.logger.With().Str("session_id", session.ID()).Int32("peer_id", id).Logger()
|
||||
logger.Info().Msg("creating webrtc peer")
|
||||
|
||||
// all audios must have the same codec
|
||||
audio := manager.capture.Audio()
|
||||
audioCodec := audio.Codec()
|
||||
|
||||
// all videos must have the same codec
|
||||
video := manager.capture.Video()
|
||||
videoCodec := video.Codec()
|
||||
|
||||
connection, estimator, err := manager.newPeerConnection(
|
||||
logger, []codec.RTPCodec{audioCodec, videoCodec})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// asynchronously send local ICE Candidates
|
||||
if manager.config.ICETrickle {
|
||||
connection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
|
||||
if candidate == nil {
|
||||
logger.Debug().Msg("all local ice candidates sent")
|
||||
return
|
||||
}
|
||||
|
||||
session.Send(
|
||||
event.SIGNAL_CANDIDATE,
|
||||
message.SignalCandidate{
|
||||
ICECandidateInit: candidate.ToJSON(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// audio track
|
||||
audioTrack, err := NewTrack(logger, audioCodec, connection)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// we disable audio by default manually
|
||||
audioTrack.SetPaused(true)
|
||||
|
||||
// set stream for audio track
|
||||
_, err = audioTrack.SetStream(audio)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// video track
|
||||
videoRtcp := make(chan []rtcp.Packet, 1)
|
||||
videoTrack, err := NewTrack(logger, videoCodec, connection, WithRtcpChan(videoRtcp))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
//
|
||||
// stream for video track will be set later
|
||||
//
|
||||
|
||||
// data channel
|
||||
|
||||
dataChannel, err := connection.CreateDataChannel("data", nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
peer := &WebRTCPeerCtx{
|
||||
logger: logger,
|
||||
session: session,
|
||||
metrics: metrics,
|
||||
connection: connection,
|
||||
// bandwidth estimator
|
||||
estimator: estimator,
|
||||
estimateTrend: utils.NewTrendDetector(
|
||||
utils.TrendDetectorParams{
|
||||
// Probing
|
||||
//RequiredSamples: 3,
|
||||
//DownwardTrendThreshold: 0.0,
|
||||
//CollapseValues: false,
|
||||
// Non-Probing
|
||||
RequiredSamples: 8,
|
||||
DownwardTrendThreshold: -0.5,
|
||||
CollapseValues: true,
|
||||
}),
|
||||
// stream selectors
|
||||
video: video,
|
||||
audio: audio,
|
||||
// tracks & channels
|
||||
audioTrack: audioTrack,
|
||||
videoTrack: videoTrack,
|
||||
dataChannel: dataChannel,
|
||||
rtcpChannel: videoRtcp,
|
||||
// config
|
||||
iceTrickle: manager.config.ICETrickle,
|
||||
estimatorConfig: manager.config.Estimator,
|
||||
audioDisabled: true, // we disable audio by default manually
|
||||
}
|
||||
|
||||
connection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
|
||||
logger := logger.With().
|
||||
Str("kind", track.Kind().String()).
|
||||
Str("mime", track.Codec().RTPCodecCapability.MimeType).
|
||||
Logger()
|
||||
|
||||
logger.Info().Msgf("received new remote track")
|
||||
|
||||
if !session.Profile().CanShareMedia {
|
||||
err := receiver.Stop()
|
||||
logger.Warn().Err(err).Msg("media sharing is disabled for this session")
|
||||
return
|
||||
}
|
||||
|
||||
// parse codec from remote track
|
||||
codec, ok := codec.ParseRTC(track.Codec())
|
||||
if !ok {
|
||||
err := receiver.Stop()
|
||||
logger.Warn().Err(err).Msg("remote track with unknown codec")
|
||||
return
|
||||
}
|
||||
|
||||
var srcManager types.StreamSrcManager
|
||||
|
||||
stopped := false
|
||||
stopFn := func() {
|
||||
if stopped {
|
||||
return
|
||||
}
|
||||
|
||||
stopped = true
|
||||
err := receiver.Stop()
|
||||
srcManager.Stop()
|
||||
logger.Err(err).Msg("remote track stopped")
|
||||
}
|
||||
|
||||
if track.Kind() == webrtc.RTPCodecTypeAudio {
|
||||
// audio -> microphone
|
||||
srcManager = manager.capture.Microphone()
|
||||
defer stopFn()
|
||||
|
||||
if manager.micStop != nil {
|
||||
(*manager.micStop)()
|
||||
}
|
||||
manager.micStop = &stopFn
|
||||
} else if track.Kind() == webrtc.RTPCodecTypeVideo {
|
||||
// video -> webcam
|
||||
srcManager = manager.capture.Webcam()
|
||||
defer stopFn()
|
||||
|
||||
if manager.camStop != nil {
|
||||
(*manager.camStop)()
|
||||
}
|
||||
manager.camStop = &stopFn
|
||||
} else {
|
||||
err := receiver.Stop()
|
||||
logger.Warn().Err(err).Msg("remote track with unsupported codec type")
|
||||
return
|
||||
}
|
||||
|
||||
err := srcManager.Start(codec)
|
||||
if err != nil {
|
||||
logger.Err(err).Msg("failed to start pipeline")
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(rtcpPLIInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
err := connection.WriteRTCP([]rtcp.Packet{
|
||||
&rtcp.PictureLossIndication{
|
||||
MediaSSRC: uint32(track.SSRC()),
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logger.Err(err).Msg("remote track rtcp send err")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1400)
|
||||
for {
|
||||
i, _, err := track.Read(buf)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("failed read from remote track")
|
||||
break
|
||||
}
|
||||
|
||||
srcManager.Push(buf[:i])
|
||||
}
|
||||
|
||||
logger.Info().Msg("remote track data finished")
|
||||
})
|
||||
|
||||
connection.OnDataChannel(func(dc *webrtc.DataChannel) {
|
||||
logger.Info().Interface("data_channel", dc).Msg("got remote data channel")
|
||||
})
|
||||
|
||||
var once sync.Once
|
||||
connection.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
|
||||
switch state {
|
||||
case webrtc.PeerConnectionStateConnected:
|
||||
session.SetWebRTCConnected(peer, true)
|
||||
case webrtc.PeerConnectionStateDisconnected,
|
||||
webrtc.PeerConnectionStateFailed:
|
||||
peer.Destroy()
|
||||
case webrtc.PeerConnectionStateClosed:
|
||||
// ensure we only run this once
|
||||
once.Do(func() {
|
||||
session.SetWebRTCConnected(peer, false)
|
||||
//
|
||||
// TODO: Shutdown peer?
|
||||
//
|
||||
audioTrack.Shutdown()
|
||||
videoTrack.Shutdown()
|
||||
close(videoRtcp)
|
||||
})
|
||||
}
|
||||
|
||||
metrics.SetState(state)
|
||||
})
|
||||
|
||||
dataChannel.OnOpen(func() {
|
||||
manager.curImage.AddListener(peer)
|
||||
manager.curPosition.AddListener(peer)
|
||||
|
||||
// send initial cursor image
|
||||
cur, img, err := manager.curImage.GetCurrent()
|
||||
if err == nil {
|
||||
err := peer.SendCursorImage(cur, img)
|
||||
if err != nil {
|
||||
logger.Err(err).Msg("failed to set cursor image")
|
||||
}
|
||||
} else {
|
||||
logger.Err(err).Msg("failed to get cursor image")
|
||||
}
|
||||
|
||||
// send initial cursor position
|
||||
x, y := manager.desktop.GetCursorPosition()
|
||||
err = peer.SendCursorPosition(x, y)
|
||||
if err != nil {
|
||||
logger.Err(err).Msg("failed to set cursor position")
|
||||
}
|
||||
})
|
||||
|
||||
dataChannel.OnClose(func() {
|
||||
manager.curImage.RemoveListener(peer)
|
||||
manager.curPosition.RemoveListener(peer)
|
||||
})
|
||||
|
||||
dataChannel.OnMessage(func(message webrtc.DataChannelMessage) {
|
||||
if err := manager.handle(logger, message.Data, dataChannel, session); err != nil {
|
||||
logger.Err(err).Msg("data handle failed")
|
||||
}
|
||||
})
|
||||
|
||||
session.SetWebRTCPeer(peer)
|
||||
|
||||
offer, err := peer.CreateOffer(false)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// on negotiation needed handler must be registered after creating initial
|
||||
// offer, otherwise it can fire and intercept sucessful negotiation
|
||||
|
||||
connection.OnNegotiationNeeded(func() {
|
||||
logger.Warn().Msg("negotiation is needed")
|
||||
|
||||
if connection.SignalingState() != webrtc.SignalingStateStable {
|
||||
logger.Warn().Msg("connection isn't stable yet; postponing...")
|
||||
return
|
||||
}
|
||||
|
||||
offer, err := peer.CreateOffer(false)
|
||||
if err != nil {
|
||||
logger.Err(err).Msg("sdp offer failed")
|
||||
return
|
||||
}
|
||||
|
||||
session.Send(
|
||||
event.SIGNAL_OFFER,
|
||||
message.SignalDescription{
|
||||
SDP: offer.SDP,
|
||||
})
|
||||
})
|
||||
|
||||
// start metrics collectors
|
||||
go metrics.rtcpReceiver(videoRtcp)
|
||||
go metrics.connectionStats(connection)
|
||||
|
||||
// start estimator reader
|
||||
go peer.estimatorReader()
|
||||
|
||||
return offer, peer, nil
|
||||
}
|
||||
|
||||
func (manager *WebRTCManagerCtx) SetCursorPosition(x, y int) {
|
||||
manager.curPosition.Set(x, y)
|
||||
}
|
458
server/internal/webrtc/metrics.go
Normal file
458
server/internal/webrtc/metrics.go
Normal file
@ -0,0 +1,458 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/pion/rtcp"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
)
|
||||
|
||||
const (
|
||||
// how often to read and process webrtc connection stats
|
||||
connectionStatsInterval = 5 * time.Second
|
||||
)
|
||||
|
||||
type metricsManager struct {
|
||||
mu sync.Mutex
|
||||
|
||||
sessions map[string]*metrics
|
||||
}
|
||||
|
||||
func newMetricsManager() *metricsManager {
|
||||
return &metricsManager{
|
||||
sessions: map[string]*metrics{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *metricsManager) getBySession(session types.Session) *metrics {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
sessionId := session.ID()
|
||||
|
||||
met, ok := m.sessions[sessionId]
|
||||
if ok {
|
||||
return met
|
||||
}
|
||||
|
||||
met = &metrics{
|
||||
sessionId: sessionId,
|
||||
|
||||
connectionState: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "connection_state",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Connection state of session.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
}),
|
||||
connectionStateCount: promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "connection_state_count",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Count of connection state changes for a session.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
}),
|
||||
connectionCount: promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "connection_count",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Connection count of a session.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
}),
|
||||
|
||||
iceCandidates: map[string]struct{}{},
|
||||
iceCandidatesMu: &sync.Mutex{},
|
||||
iceCandidatesUdpCount: promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "ice_candidates_count",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Count of ICE candidates sent by a remote client.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
"protocol": "udp",
|
||||
},
|
||||
}),
|
||||
iceCandidatesTcpCount: promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "ice_candidates_count",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Count of ICE candidates sent by a remote client.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
"protocol": "tcp",
|
||||
},
|
||||
}),
|
||||
|
||||
iceCandidatesUsedUdp: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "ice_candidates_used",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Used ICE candidates that are currently in use.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
"protocol": "udp",
|
||||
},
|
||||
}),
|
||||
iceCandidatesUsedTcp: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "ice_candidates_used",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Used ICE candidates that are currently in use.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
"protocol": "tcp",
|
||||
},
|
||||
}),
|
||||
|
||||
videoIds: map[string]prometheus.Gauge{},
|
||||
videoIdsMu: &sync.Mutex{},
|
||||
|
||||
receiverEstimatedMaximumBitrate: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "receiver_estimated_maximum_bitrate",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Receiver Estimated Maximum Bitrate from RTCP.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
}),
|
||||
receiverEstimatedTargetBitrate: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "receiver_estimated_target_bitrate",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Receiver Estimated Target Bitrate using Google's congestion control.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
}),
|
||||
|
||||
receiverReportDelay: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "receiver_report_delay",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Receiver Report Delay from RTCP, expressed in units of 1/65536 seconds.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
}),
|
||||
receiverReportJitter: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "receiver_report_jitter",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Receiver Report Jitter from RTCP.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
}),
|
||||
receiverReportTotalLost: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "receiver_report_total_lost",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Receiver Report Total Lost from RTCP.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
}),
|
||||
|
||||
transportLayerNacks: promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "transport_layer_nacks",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Transport Layer NACKs from RTCP.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
},
|
||||
}),
|
||||
|
||||
iceBytesSent: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "bytes_sent",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Sent bytes to a session.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
"transport": "ice",
|
||||
},
|
||||
}),
|
||||
iceBytesReceived: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "bytes_received",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Received bytes from a session.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
"transport": "ice",
|
||||
},
|
||||
}),
|
||||
|
||||
sctpBytesSent: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "bytes_sent",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Sent bytes to a session.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
"transport": "sctp",
|
||||
},
|
||||
}),
|
||||
sctpBytesReceived: promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "bytes_received",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Received bytes from a session.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": sessionId,
|
||||
"transport": "sctp",
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
m.sessions[sessionId] = met
|
||||
return met
|
||||
}
|
||||
|
||||
type metrics struct {
|
||||
sessionId string
|
||||
|
||||
connectionState prometheus.Gauge
|
||||
connectionStateCount prometheus.Counter
|
||||
connectionCount prometheus.Counter
|
||||
|
||||
iceCandidates map[string]struct{}
|
||||
iceCandidatesMu *sync.Mutex
|
||||
iceCandidatesUdpCount prometheus.Counter
|
||||
iceCandidatesTcpCount prometheus.Counter
|
||||
|
||||
iceCandidatesUsedUdp prometheus.Gauge
|
||||
iceCandidatesUsedTcp prometheus.Gauge
|
||||
|
||||
videoIds map[string]prometheus.Gauge
|
||||
videoIdsMu *sync.Mutex
|
||||
|
||||
receiverEstimatedMaximumBitrate prometheus.Gauge
|
||||
receiverEstimatedTargetBitrate prometheus.Gauge
|
||||
|
||||
receiverReportDelay prometheus.Gauge
|
||||
receiverReportJitter prometheus.Gauge
|
||||
receiverReportTotalLost prometheus.Gauge
|
||||
|
||||
transportLayerNacks prometheus.Counter
|
||||
|
||||
iceBytesSent prometheus.Gauge
|
||||
iceBytesReceived prometheus.Gauge
|
||||
sctpBytesSent prometheus.Gauge
|
||||
sctpBytesReceived prometheus.Gauge
|
||||
}
|
||||
|
||||
func (met *metrics) reset() {
|
||||
met.videoIdsMu.Lock()
|
||||
for _, entry := range met.videoIds {
|
||||
entry.Set(0)
|
||||
}
|
||||
met.videoIdsMu.Unlock()
|
||||
|
||||
met.iceCandidatesUsedUdp.Set(float64(0))
|
||||
met.iceCandidatesUsedTcp.Set(float64(0))
|
||||
|
||||
met.receiverEstimatedMaximumBitrate.Set(0)
|
||||
|
||||
met.receiverReportDelay.Set(0)
|
||||
met.receiverReportJitter.Set(0)
|
||||
}
|
||||
|
||||
func (met *metrics) NewConnection() {
|
||||
met.connectionCount.Add(1)
|
||||
}
|
||||
|
||||
func (met *metrics) NewICECandidate(candidate webrtc.ICECandidateStats) {
|
||||
met.iceCandidatesMu.Lock()
|
||||
defer met.iceCandidatesMu.Unlock()
|
||||
|
||||
if _, found := met.iceCandidates[candidate.ID]; found {
|
||||
return
|
||||
}
|
||||
|
||||
met.iceCandidates[candidate.ID] = struct{}{}
|
||||
if candidate.Protocol == "udp" {
|
||||
met.iceCandidatesUdpCount.Add(1)
|
||||
} else if candidate.Protocol == "tcp" {
|
||||
met.iceCandidatesTcpCount.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
func (met *metrics) SetICECandidatesUsed(candidates []webrtc.ICECandidateStats) {
|
||||
udp, tcp := 0, 0
|
||||
for _, candidate := range candidates {
|
||||
if candidate.Protocol == "udp" {
|
||||
udp++
|
||||
} else if candidate.Protocol == "tcp" {
|
||||
tcp++
|
||||
}
|
||||
}
|
||||
|
||||
met.iceCandidatesUsedUdp.Set(float64(udp))
|
||||
met.iceCandidatesUsedTcp.Set(float64(tcp))
|
||||
}
|
||||
|
||||
func (met *metrics) SetState(state webrtc.PeerConnectionState) {
|
||||
switch state {
|
||||
case webrtc.PeerConnectionStateNew:
|
||||
met.connectionState.Set(0)
|
||||
case webrtc.PeerConnectionStateConnecting:
|
||||
met.connectionState.Set(4)
|
||||
case webrtc.PeerConnectionStateConnected:
|
||||
met.connectionState.Set(5)
|
||||
case webrtc.PeerConnectionStateDisconnected:
|
||||
met.connectionState.Set(3)
|
||||
case webrtc.PeerConnectionStateFailed:
|
||||
met.connectionState.Set(2)
|
||||
case webrtc.PeerConnectionStateClosed:
|
||||
met.connectionState.Set(1)
|
||||
met.reset()
|
||||
default:
|
||||
met.connectionState.Set(-1)
|
||||
}
|
||||
|
||||
met.connectionStateCount.Add(1)
|
||||
}
|
||||
|
||||
func (met *metrics) SetVideoID(videoId string) {
|
||||
met.videoIdsMu.Lock()
|
||||
defer met.videoIdsMu.Unlock()
|
||||
|
||||
if _, found := met.videoIds[videoId]; !found {
|
||||
met.videoIds[videoId] = promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "video_listeners",
|
||||
Namespace: "neko",
|
||||
Subsystem: "webrtc",
|
||||
Help: "Listeners for Video pipelines by a session.",
|
||||
ConstLabels: map[string]string{
|
||||
"session_id": met.sessionId,
|
||||
"video_id": videoId,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
for id, entry := range met.videoIds {
|
||||
if id == videoId {
|
||||
entry.Set(1)
|
||||
} else {
|
||||
entry.Set(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (met *metrics) SetReceiverEstimatedMaximumBitrate(bitrate float32) {
|
||||
met.receiverEstimatedMaximumBitrate.Set(float64(bitrate))
|
||||
}
|
||||
|
||||
func (met *metrics) SetReceiverEstimatedTargetBitrate(bitrate float64) {
|
||||
met.receiverEstimatedTargetBitrate.Set(bitrate)
|
||||
}
|
||||
|
||||
func (met *metrics) SetReceiverReport(report rtcp.ReceptionReport) {
|
||||
met.receiverReportDelay.Set(float64(report.Delay))
|
||||
met.receiverReportJitter.Set(float64(report.Jitter))
|
||||
met.receiverReportTotalLost.Set(float64(report.TotalLost))
|
||||
}
|
||||
|
||||
func (met *metrics) SetIceTransportStats(data webrtc.TransportStats) {
|
||||
met.iceBytesSent.Set(float64(data.BytesSent))
|
||||
met.iceBytesReceived.Set(float64(data.BytesReceived))
|
||||
}
|
||||
|
||||
func (met *metrics) SetSctpTransportStats(data webrtc.TransportStats) {
|
||||
met.sctpBytesSent.Set(float64(data.BytesSent))
|
||||
met.sctpBytesReceived.Set(float64(data.BytesReceived))
|
||||
}
|
||||
|
||||
//
|
||||
// collectors
|
||||
//
|
||||
|
||||
func (met *metrics) rtcpReceiver(rtcpCh chan []rtcp.Packet) {
|
||||
for {
|
||||
packets, ok := <-rtcpCh
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
for _, p := range packets {
|
||||
switch rtcpPacket := p.(type) {
|
||||
case *rtcp.ReceiverEstimatedMaximumBitrate: // TODO: Deprecated.
|
||||
met.SetReceiverEstimatedMaximumBitrate(rtcpPacket.Bitrate)
|
||||
|
||||
case *rtcp.ReceiverReport:
|
||||
l := len(rtcpPacket.Reports)
|
||||
if l > 0 {
|
||||
// use only last report
|
||||
met.SetReceiverReport(rtcpPacket.Reports[l-1])
|
||||
}
|
||||
case *rtcp.TransportLayerNack:
|
||||
for _, pair := range rtcpPacket.Nacks {
|
||||
packetList := pair.PacketList()
|
||||
met.transportLayerNacks.Add(float64(len(packetList)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (met *metrics) connectionStats(connection *webrtc.PeerConnection) {
|
||||
ticker := time.NewTicker(connectionStatsInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if connection.ConnectionState() == webrtc.PeerConnectionStateClosed {
|
||||
break
|
||||
}
|
||||
|
||||
stats := connection.GetStats()
|
||||
|
||||
data, ok := stats["iceTransport"].(webrtc.TransportStats)
|
||||
if ok {
|
||||
met.SetIceTransportStats(data)
|
||||
}
|
||||
|
||||
data, ok = stats["sctpTransport"].(webrtc.TransportStats)
|
||||
if ok {
|
||||
met.SetSctpTransportStats(data)
|
||||
}
|
||||
|
||||
remoteCandidates := map[string]webrtc.ICECandidateStats{}
|
||||
nominatedRemoteCandidates := map[string]struct{}{}
|
||||
for _, entry := range stats {
|
||||
// only remote ice candidate stats
|
||||
candidate, ok := entry.(webrtc.ICECandidateStats)
|
||||
if ok && candidate.Type == webrtc.StatsTypeRemoteCandidate {
|
||||
met.NewICECandidate(candidate)
|
||||
remoteCandidates[candidate.ID] = candidate
|
||||
}
|
||||
|
||||
// only nominated ice candidate pair stats
|
||||
pair, ok := entry.(webrtc.ICECandidatePairStats)
|
||||
if ok && pair.Nominated {
|
||||
nominatedRemoteCandidates[pair.RemoteCandidateID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
iceCandidatesUsed := []webrtc.ICECandidateStats{}
|
||||
for id := range nominatedRemoteCandidates {
|
||||
if candidate, ok := remoteCandidates[id]; ok {
|
||||
iceCandidatesUsed = append(iceCandidatesUsed, candidate)
|
||||
}
|
||||
}
|
||||
|
||||
met.SetICECandidatesUsed(iceCandidatesUsed)
|
||||
}
|
||||
}
|
55
server/internal/webrtc/payload/receive.go
Normal file
55
server/internal/webrtc/payload/receive.go
Normal file
@ -0,0 +1,55 @@
|
||||
package payload
|
||||
|
||||
import "math"
|
||||
|
||||
const (
|
||||
OP_MOVE = 0x01
|
||||
OP_SCROLL = 0x02
|
||||
OP_KEY_DOWN = 0x03
|
||||
OP_KEY_UP = 0x04
|
||||
OP_BTN_DOWN = 0x05
|
||||
OP_BTN_UP = 0x06
|
||||
OP_PING = 0x07
|
||||
// touch events
|
||||
OP_TOUCH_BEGIN = 0x08
|
||||
OP_TOUCH_UPDATE = 0x09
|
||||
OP_TOUCH_END = 0x0a
|
||||
)
|
||||
|
||||
type Move struct {
|
||||
X uint16
|
||||
Y uint16
|
||||
}
|
||||
|
||||
// TODO: remove this once the client is fixed
|
||||
type Scroll_Old struct {
|
||||
X int16
|
||||
Y int16
|
||||
}
|
||||
|
||||
type Scroll struct {
|
||||
DeltaX int16
|
||||
DeltaY int16
|
||||
ControlKey bool
|
||||
}
|
||||
|
||||
type Key struct {
|
||||
Key uint32
|
||||
}
|
||||
|
||||
type Ping struct {
|
||||
// client's timestamp split into two uint32
|
||||
ClientTs1 uint32
|
||||
ClientTs2 uint32
|
||||
}
|
||||
|
||||
func (p Ping) ClientTs() uint64 {
|
||||
return (uint64(p.ClientTs1) * uint64(math.MaxUint32)) + uint64(p.ClientTs2)
|
||||
}
|
||||
|
||||
type Touch struct {
|
||||
TouchId uint32
|
||||
X int32
|
||||
Y int32
|
||||
Pressure uint8
|
||||
}
|
33
server/internal/webrtc/payload/send.go
Normal file
33
server/internal/webrtc/payload/send.go
Normal file
@ -0,0 +1,33 @@
|
||||
package payload
|
||||
|
||||
import "math"
|
||||
|
||||
const (
|
||||
OP_CURSOR_POSITION = 0x01
|
||||
OP_CURSOR_IMAGE = 0x02
|
||||
OP_PONG = 0x03
|
||||
)
|
||||
|
||||
type CursorPosition struct {
|
||||
X uint16
|
||||
Y uint16
|
||||
}
|
||||
|
||||
type CursorImage struct {
|
||||
Width uint16
|
||||
Height uint16
|
||||
Xhot uint16
|
||||
Yhot uint16
|
||||
}
|
||||
|
||||
type Pong struct {
|
||||
Ping
|
||||
|
||||
// server's timestamp split into two uint32
|
||||
ServerTs1 uint32
|
||||
ServerTs2 uint32
|
||||
}
|
||||
|
||||
func (p Pong) ServerTs() uint64 {
|
||||
return (uint64(p.ServerTs1) * uint64(math.MaxUint32)) + uint64(p.ServerTs2)
|
||||
}
|
6
server/internal/webrtc/payload/types.go
Normal file
6
server/internal/webrtc/payload/types.go
Normal file
@ -0,0 +1,6 @@
|
||||
package payload
|
||||
|
||||
type Header struct {
|
||||
Event uint8
|
||||
Length uint16
|
||||
}
|
543
server/internal/webrtc/peer.go
Normal file
543
server/internal/webrtc/peer.go
Normal file
@ -0,0 +1,543 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pion/interceptor/pkg/cc"
|
||||
"github.com/pion/rtcp"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/demodesk/neko/internal/config"
|
||||
"github.com/demodesk/neko/internal/webrtc/payload"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type WebRTCPeerCtx struct {
|
||||
mu sync.Mutex
|
||||
logger zerolog.Logger
|
||||
session types.Session
|
||||
metrics *metrics
|
||||
connection *webrtc.PeerConnection
|
||||
// bandwidth estimator
|
||||
estimator cc.BandwidthEstimator
|
||||
estimateTrend *utils.TrendDetector
|
||||
// stream selectors
|
||||
video types.StreamSelectorManager
|
||||
audio types.StreamSinkManager
|
||||
// tracks & channels
|
||||
audioTrack *Track
|
||||
videoTrack *Track
|
||||
dataChannel *webrtc.DataChannel
|
||||
rtcpChannel chan []rtcp.Packet
|
||||
// config
|
||||
iceTrickle bool
|
||||
estimatorConfig config.WebRTCEstimator
|
||||
paused bool
|
||||
videoAuto bool
|
||||
videoDisabled bool
|
||||
audioDisabled bool
|
||||
}
|
||||
|
||||
//
|
||||
// connection
|
||||
//
|
||||
|
||||
func (peer *WebRTCPeerCtx) CreateOffer(ICERestart bool) (*webrtc.SessionDescription, error) {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
offer, err := peer.connection.CreateOffer(&webrtc.OfferOptions{
|
||||
ICERestart: ICERestart,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return peer.setLocalDescription(offer)
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) CreateAnswer() (*webrtc.SessionDescription, error) {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
answer, err := peer.connection.CreateAnswer(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return peer.setLocalDescription(answer)
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) setLocalDescription(description webrtc.SessionDescription) (*webrtc.SessionDescription, error) {
|
||||
if !peer.iceTrickle {
|
||||
// Create channel that is blocked until ICE Gathering is complete
|
||||
gatherComplete := webrtc.GatheringCompletePromise(peer.connection)
|
||||
|
||||
if err := peer.connection.SetLocalDescription(description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
<-gatherComplete
|
||||
} else {
|
||||
if err := peer.connection.SetLocalDescription(description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return peer.connection.LocalDescription(), nil
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) SetRemoteDescription(desc webrtc.SessionDescription) error {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
return peer.connection.SetRemoteDescription(desc)
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) SetCandidate(candidate webrtc.ICECandidateInit) error {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
return peer.connection.AddICECandidate(candidate)
|
||||
}
|
||||
|
||||
// TODO: Add shutdown function?
|
||||
func (peer *WebRTCPeerCtx) Destroy() {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
var err error
|
||||
|
||||
// if peer connection is not closed, close it
|
||||
if peer.connection.ConnectionState() != webrtc.PeerConnectionStateClosed {
|
||||
err = peer.connection.Close()
|
||||
}
|
||||
|
||||
peer.logger.Err(err).Msg("peer connection destroyed")
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) estimatorReader() {
|
||||
conf := peer.estimatorConfig
|
||||
|
||||
// if estimator is not in debug mode, use a nop logger
|
||||
var debugLogger zerolog.Logger
|
||||
if conf.Debug {
|
||||
debugLogger = peer.logger.With().Str("component", "estimator").Logger().Level(zerolog.DebugLevel)
|
||||
} else {
|
||||
debugLogger = zerolog.Nop()
|
||||
}
|
||||
|
||||
// if estimator is disabled, do nothing
|
||||
if peer.estimator == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// use a ticker to get current client target bitrate
|
||||
ticker := time.NewTicker(conf.ReadInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// since when is the estimate stable/unstable
|
||||
stableSince := time.Now() // we asume stable at start
|
||||
unstableSince := time.Time{}
|
||||
// since when are we neutral but cannot accomodate current bitrate
|
||||
// we migt be stalled or estimator just reached zer (very bad connection)
|
||||
stalledSince := time.Time{}
|
||||
// when was the last upgrade/downgrade
|
||||
lastUpgradeTime := time.Time{}
|
||||
lastDowngradeTime := time.Time{}
|
||||
|
||||
for range ticker.C {
|
||||
targetBitrate := peer.estimator.GetTargetBitrate()
|
||||
peer.metrics.SetReceiverEstimatedTargetBitrate(float64(targetBitrate))
|
||||
|
||||
// if peer connection is closed, stop reading
|
||||
if peer.connection.ConnectionState() == webrtc.PeerConnectionStateClosed {
|
||||
break
|
||||
}
|
||||
|
||||
// if estimation or video is disabled, do nothing
|
||||
if !peer.videoAuto || peer.videoDisabled || peer.paused || conf.Passive {
|
||||
continue
|
||||
}
|
||||
|
||||
// get trend direction to decide if we should upgrade or downgrade
|
||||
peer.estimateTrend.AddValue(int64(targetBitrate))
|
||||
direction := peer.estimateTrend.GetDirection()
|
||||
|
||||
// get current stream bitrate
|
||||
stream, ok := peer.videoTrack.Stream()
|
||||
if !ok {
|
||||
debugLogger.Warn().Msg("looks like we don't have a stream yet, skipping bitrate estimation")
|
||||
continue
|
||||
}
|
||||
|
||||
// if stream bitrate is 0, we need to wait for some time until we get a valid value
|
||||
streamId, streamBitrate := stream.ID(), stream.Bitrate()
|
||||
if streamBitrate == 0 {
|
||||
debugLogger.Warn().Msg("looks like stream bitrate is 0, we need to wait for some time")
|
||||
continue
|
||||
}
|
||||
|
||||
// check whats the difference between target and stream bitrate
|
||||
diff := float64(targetBitrate) / float64(streamBitrate)
|
||||
|
||||
debugLogger.Info().
|
||||
Float64("diff", diff).
|
||||
Int("target_bitrate", targetBitrate).
|
||||
Uint64("stream_bitrate", streamBitrate).
|
||||
Str("direction", direction.String()).
|
||||
Msg("got bitrate from estimator")
|
||||
|
||||
// if we can accomodate current stream or we are not netural anymore,
|
||||
// we are not stalled so we reset the stalled time
|
||||
if direction != utils.TrendDirectionNeutral || diff > 1+conf.DiffThreshold {
|
||||
stalledSince = time.Now()
|
||||
}
|
||||
|
||||
// if we are neutral and stalled for too long, we might be congesting
|
||||
stalled := direction == utils.TrendDirectionNeutral && time.Since(stalledSince) > conf.StalledDuration
|
||||
if stalled {
|
||||
debugLogger.Warn().
|
||||
Time("stalled_since", stalledSince).
|
||||
Msgf("it looks like we are stalled")
|
||||
}
|
||||
|
||||
// if we have an downward trend or are stalled, we might be congesting
|
||||
if direction == utils.TrendDirectionDownward || stalled {
|
||||
// we reset the stable time because we are congesting
|
||||
stableSince = time.Now()
|
||||
|
||||
// if we downgraded recently, we wait for some more time
|
||||
if time.Since(lastDowngradeTime) < conf.DowngradeBackoff {
|
||||
debugLogger.Debug().
|
||||
Time("last_downgrade", lastDowngradeTime).
|
||||
Msgf("downgraded recently, waiting for at least %v", conf.DowngradeBackoff)
|
||||
continue
|
||||
}
|
||||
|
||||
// if we are not unstable but we fluctuate we should wait for some more time
|
||||
if time.Since(unstableSince) < conf.UnstableDuration {
|
||||
debugLogger.Debug().
|
||||
Time("unstable_since", unstableSince).
|
||||
Msgf("we are not unstable long enough, waiting for at least %v", conf.UnstableDuration)
|
||||
continue
|
||||
}
|
||||
|
||||
// if we still have a big difference between target and stream bitrate, we wait for some more time
|
||||
if conf.DiffThreshold >= 0 && diff > 1+conf.DiffThreshold {
|
||||
debugLogger.Debug().
|
||||
Float64("diff", diff).
|
||||
Float64("threshold", conf.DiffThreshold).
|
||||
Msgf("we still have a big difference between target and stream bitrate, " +
|
||||
"therefore we still should be able to accomodate current stream")
|
||||
continue
|
||||
}
|
||||
|
||||
err := peer.SetVideo(types.PeerVideoRequest{
|
||||
Selector: &types.StreamSelector{
|
||||
ID: streamId,
|
||||
Type: types.StreamSelectorTypeLower,
|
||||
},
|
||||
})
|
||||
if err != nil && err != types.ErrWebRTCStreamNotFound {
|
||||
peer.logger.Warn().Err(err).Msg("failed to downgrade video stream")
|
||||
}
|
||||
lastDowngradeTime = time.Now()
|
||||
|
||||
if err == types.ErrWebRTCStreamNotFound {
|
||||
debugLogger.Info().Msg("looks like we are already on the lowest stream")
|
||||
} else {
|
||||
debugLogger.Info().Msg("downgraded video stream")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// we reset the unstable time because we are not congesting
|
||||
unstableSince = time.Now()
|
||||
|
||||
// if we have a neutral or upward trend, that means our estimate is stable
|
||||
// if we are on the highest stream, we don't need to do anything
|
||||
// but if there is a higher stream, we should try to upgrade and see if it works
|
||||
|
||||
// if we upgraded recently, we wait for some more time
|
||||
if time.Since(lastUpgradeTime) < conf.UpgradeBackoff {
|
||||
debugLogger.Debug().
|
||||
Time("last_upgrade", lastUpgradeTime).
|
||||
Msgf("upgraded recently, waiting for at least %v", conf.UpgradeBackoff)
|
||||
continue
|
||||
}
|
||||
|
||||
// if we are not stable for long enough, we wait for some more time
|
||||
// because bandwidth estimation might fluctuate
|
||||
if time.Since(stableSince) < conf.StableDuration {
|
||||
debugLogger.Debug().
|
||||
Time("stable_since", stableSince).
|
||||
Msgf("we are not stable long enough, waiting for at least %v", conf.StableDuration)
|
||||
continue
|
||||
}
|
||||
|
||||
// upgrade only if estimated bitrate passed the threshold
|
||||
if conf.DiffThreshold >= 0 && diff < 1+conf.DiffThreshold {
|
||||
debugLogger.Debug().
|
||||
Float64("diff", diff).
|
||||
Float64("threshold", conf.DiffThreshold).
|
||||
Msgf("looks like we don't have enough bitrate to accomodate higher stream, " +
|
||||
"therefore we should wait for some more time")
|
||||
continue
|
||||
}
|
||||
|
||||
err := peer.SetVideo(types.PeerVideoRequest{
|
||||
Selector: &types.StreamSelector{
|
||||
ID: streamId,
|
||||
Type: types.StreamSelectorTypeHigher,
|
||||
},
|
||||
})
|
||||
if err != nil && err != types.ErrWebRTCStreamNotFound {
|
||||
peer.logger.Warn().Err(err).Msg("failed to upgrade video stream")
|
||||
}
|
||||
lastUpgradeTime = time.Now()
|
||||
|
||||
if err == types.ErrWebRTCStreamNotFound {
|
||||
debugLogger.Info().Msg("looks like we are already on the highest stream")
|
||||
} else {
|
||||
debugLogger.Info().Msg("upgraded video stream")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) SetPaused(isPaused bool) error {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
peer.videoTrack.SetPaused(isPaused || peer.videoDisabled)
|
||||
peer.audioTrack.SetPaused(isPaused || peer.audioDisabled)
|
||||
|
||||
peer.logger.Info().Bool("is_paused", isPaused).Msg("set paused")
|
||||
peer.paused = isPaused
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) Paused() bool {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
return peer.paused
|
||||
}
|
||||
|
||||
//
|
||||
// video
|
||||
//
|
||||
|
||||
func (peer *WebRTCPeerCtx) SetVideo(r types.PeerVideoRequest) error {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
modified := false
|
||||
|
||||
// video disabled
|
||||
if r.Disabled != nil {
|
||||
disabled := *r.Disabled
|
||||
|
||||
// update only if changed
|
||||
if peer.videoDisabled != disabled {
|
||||
peer.videoDisabled = disabled
|
||||
peer.videoTrack.SetPaused(disabled || peer.paused)
|
||||
|
||||
peer.logger.Info().Bool("disabled", disabled).Msg("set video disabled")
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
// video selector
|
||||
if r.Selector != nil {
|
||||
selector := *r.Selector
|
||||
|
||||
// get requested video stream from selector
|
||||
stream, ok := peer.video.GetStream(selector)
|
||||
if !ok {
|
||||
return types.ErrWebRTCStreamNotFound
|
||||
}
|
||||
|
||||
// set video stream to track
|
||||
changed, err := peer.videoTrack.SetStream(stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update only if stream changed
|
||||
if changed {
|
||||
videoID := stream.ID()
|
||||
peer.metrics.SetVideoID(videoID)
|
||||
|
||||
peer.logger.Info().Str("video_id", videoID).Msg("set video")
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
// video auto
|
||||
if r.Auto != nil {
|
||||
videoAuto := *r.Auto
|
||||
|
||||
if peer.estimator == nil || peer.estimatorConfig.Passive {
|
||||
peer.logger.Warn().Msg("estimator is disabled or in passive mode, cannot change video auto")
|
||||
videoAuto = false // ensure video auto is disabled
|
||||
}
|
||||
|
||||
// update only if video auto changed
|
||||
if peer.videoAuto != videoAuto {
|
||||
peer.videoAuto = videoAuto
|
||||
|
||||
peer.logger.Info().Bool("video_auto", videoAuto).Msg("set video auto")
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
// send video signal if modified
|
||||
if modified {
|
||||
go func() {
|
||||
// in goroutine because of mutex and we don't want to block
|
||||
peer.session.Send(event.SIGNAL_VIDEO, peer.Video())
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) Video() types.PeerVideo {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
// get current video stream ID
|
||||
ID := ""
|
||||
stream, ok := peer.videoTrack.Stream()
|
||||
if ok {
|
||||
ID = stream.ID()
|
||||
}
|
||||
|
||||
return types.PeerVideo{
|
||||
Disabled: peer.videoDisabled,
|
||||
ID: ID,
|
||||
Video: ID, // TODO: Remove, used for backward compatibility
|
||||
Auto: peer.videoAuto,
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// audio
|
||||
//
|
||||
|
||||
func (peer *WebRTCPeerCtx) SetAudio(r types.PeerAudioRequest) error {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
modified := false
|
||||
|
||||
// audio disabled
|
||||
if r.Disabled != nil {
|
||||
disabled := *r.Disabled
|
||||
|
||||
// update only if changed
|
||||
if peer.audioDisabled != disabled {
|
||||
peer.audioDisabled = disabled
|
||||
peer.audioTrack.SetPaused(disabled || peer.paused)
|
||||
|
||||
peer.logger.Info().Bool("disabled", disabled).Msg("set audio disabled")
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
// send video signal if modified
|
||||
if modified {
|
||||
go func() {
|
||||
// in goroutine because of mutex and we don't want to block
|
||||
peer.session.Send(event.SIGNAL_AUDIO, peer.Audio())
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) Audio() types.PeerAudio {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
return types.PeerAudio{
|
||||
Disabled: peer.audioDisabled,
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// data channel
|
||||
//
|
||||
|
||||
func (peer *WebRTCPeerCtx) SendCursorPosition(x, y int) error {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
// do not send cursor position to host
|
||||
if peer.session.IsHost() {
|
||||
return nil
|
||||
}
|
||||
|
||||
header := payload.Header{
|
||||
Event: payload.OP_CURSOR_POSITION,
|
||||
Length: 7,
|
||||
}
|
||||
|
||||
data := payload.CursorPosition{
|
||||
X: uint16(x),
|
||||
Y: uint16(y),
|
||||
}
|
||||
|
||||
buffer := &bytes.Buffer{}
|
||||
|
||||
if err := binary.Write(buffer, binary.BigEndian, header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(buffer, binary.BigEndian, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return peer.dataChannel.Send(buffer.Bytes())
|
||||
}
|
||||
|
||||
func (peer *WebRTCPeerCtx) SendCursorImage(cur *types.CursorImage, img []byte) error {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
header := payload.Header{
|
||||
Event: payload.OP_CURSOR_IMAGE,
|
||||
Length: uint16(11 + len(img)),
|
||||
}
|
||||
|
||||
data := payload.CursorImage{
|
||||
Width: cur.Width,
|
||||
Height: cur.Height,
|
||||
Xhot: cur.Xhot,
|
||||
Yhot: cur.Yhot,
|
||||
}
|
||||
|
||||
buffer := &bytes.Buffer{}
|
||||
|
||||
if err := binary.Write(buffer, binary.BigEndian, header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(buffer, binary.BigEndian, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(buffer, binary.BigEndian, img); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return peer.dataChannel.Send(buffer.Bytes())
|
||||
}
|
27
server/internal/webrtc/pionlog/factory.go
Normal file
27
server/internal/webrtc/pionlog/factory.go
Normal file
@ -0,0 +1,27 @@
|
||||
package pionlog
|
||||
|
||||
import (
|
||||
"github.com/pion/logging"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func New(logger zerolog.Logger) Factory {
|
||||
return Factory{
|
||||
Logger: logger.With().Str("submodule", "pion").Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
type Factory struct {
|
||||
Logger zerolog.Logger
|
||||
}
|
||||
|
||||
func (l Factory) NewLogger(subsystem string) logging.LeveledLogger {
|
||||
if subsystem == "sctp" {
|
||||
return nulllog{}
|
||||
}
|
||||
|
||||
return logger{
|
||||
subsystem: subsystem,
|
||||
logger: l.Logger.With().Str("subsystem", subsystem).Logger(),
|
||||
}
|
||||
}
|
66
server/internal/webrtc/pionlog/logger.go
Normal file
66
server/internal/webrtc/pionlog/logger.go
Normal file
@ -0,0 +1,66 @@
|
||||
package pionlog
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type logger struct {
|
||||
logger zerolog.Logger
|
||||
subsystem string
|
||||
}
|
||||
|
||||
func (l logger) Trace(msg string) {
|
||||
l.logger.Trace().Msg(strings.TrimSpace(msg))
|
||||
}
|
||||
|
||||
func (l logger) Tracef(format string, args ...any) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
l.logger.Trace().Msg(strings.TrimSpace(msg))
|
||||
}
|
||||
|
||||
func (l logger) Debug(msg string) {
|
||||
l.logger.Debug().Msg(strings.TrimSpace(msg))
|
||||
}
|
||||
|
||||
func (l logger) Debugf(format string, args ...any) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
l.logger.Debug().Msg(strings.TrimSpace(msg))
|
||||
}
|
||||
|
||||
func (l logger) Info(msg string) {
|
||||
if strings.Contains(msg, "duplicated packet") {
|
||||
return
|
||||
}
|
||||
|
||||
l.logger.Info().Msg(strings.TrimSpace(msg))
|
||||
}
|
||||
|
||||
func (l logger) Infof(format string, args ...any) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
if strings.Contains(msg, "duplicated packet") {
|
||||
return
|
||||
}
|
||||
|
||||
l.logger.Info().Msg(strings.TrimSpace(msg))
|
||||
}
|
||||
|
||||
func (l logger) Warn(msg string) {
|
||||
l.logger.Warn().Msg(strings.TrimSpace(msg))
|
||||
}
|
||||
|
||||
func (l logger) Warnf(format string, args ...any) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
l.logger.Warn().Msg(strings.TrimSpace(msg))
|
||||
}
|
||||
|
||||
func (l logger) Error(msg string) {
|
||||
l.logger.Error().Msg(strings.TrimSpace(msg))
|
||||
}
|
||||
|
||||
func (l logger) Errorf(format string, args ...any) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
l.logger.Error().Msg(strings.TrimSpace(msg))
|
||||
}
|
14
server/internal/webrtc/pionlog/nullog.go
Normal file
14
server/internal/webrtc/pionlog/nullog.go
Normal file
@ -0,0 +1,14 @@
|
||||
package pionlog
|
||||
|
||||
type nulllog struct{}
|
||||
|
||||
func (l nulllog) Trace(msg string) {}
|
||||
func (l nulllog) Tracef(format string, args ...any) {}
|
||||
func (l nulllog) Debug(msg string) {}
|
||||
func (l nulllog) Debugf(format string, args ...any) {}
|
||||
func (l nulllog) Info(msg string) {}
|
||||
func (l nulllog) Infof(format string, args ...any) {}
|
||||
func (l nulllog) Warn(msg string) {}
|
||||
func (l nulllog) Warnf(format string, args ...any) {}
|
||||
func (l nulllog) Error(msg string) {}
|
||||
func (l nulllog) Errorf(format string, args ...any) {}
|
203
server/internal/webrtc/track.go
Normal file
203
server/internal/webrtc/track.go
Normal file
@ -0,0 +1,203 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/rtcp"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/pion/webrtc/v3/pkg/media"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/codec"
|
||||
)
|
||||
|
||||
type Track struct {
|
||||
logger zerolog.Logger
|
||||
track *webrtc.TrackLocalStaticSample
|
||||
|
||||
rtcpCh chan []rtcp.Packet
|
||||
sample chan types.Sample
|
||||
|
||||
paused bool
|
||||
stream types.StreamSinkManager
|
||||
streamMu sync.Mutex
|
||||
}
|
||||
|
||||
type trackOption func(*Track)
|
||||
|
||||
func WithRtcpChan(rtcp chan []rtcp.Packet) trackOption {
|
||||
return func(t *Track) {
|
||||
t.rtcpCh = rtcp
|
||||
}
|
||||
}
|
||||
|
||||
func NewTrack(logger zerolog.Logger, codec codec.RTPCodec, connection *webrtc.PeerConnection, opts ...trackOption) (*Track, error) {
|
||||
id := codec.Type.String()
|
||||
track, err := webrtc.NewTrackLocalStaticSample(codec.Capability, id, "stream")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := &Track{
|
||||
logger: logger.With().Str("id", id).Logger(),
|
||||
track: track,
|
||||
rtcpCh: nil,
|
||||
sample: make(chan types.Sample),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(t)
|
||||
}
|
||||
|
||||
sender, err := connection.AddTrack(t.track)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go t.rtcpReader(sender)
|
||||
go t.sampleReader()
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *Track) Shutdown() {
|
||||
t.RemoveStream()
|
||||
close(t.sample)
|
||||
}
|
||||
|
||||
func (t *Track) rtcpReader(sender *webrtc.RTPSender) {
|
||||
for {
|
||||
packets, _, err := sender.ReadRTCP()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
|
||||
t.logger.Debug().Msg("track rtcp reader closed")
|
||||
return
|
||||
}
|
||||
|
||||
t.logger.Warn().Err(err).Msg("failed to read track rtcp")
|
||||
continue
|
||||
}
|
||||
|
||||
if t.rtcpCh != nil {
|
||||
t.rtcpCh <- packets
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- sample ---
|
||||
|
||||
func (t *Track) sampleReader() {
|
||||
for {
|
||||
sample, ok := <-t.sample
|
||||
if !ok {
|
||||
t.logger.Debug().Msg("track sample reader closed")
|
||||
return
|
||||
}
|
||||
|
||||
err := t.track.WriteSample(media.Sample{
|
||||
Data: sample.Data,
|
||||
Duration: sample.Duration,
|
||||
Timestamp: sample.Timestamp,
|
||||
})
|
||||
|
||||
if err != nil && !errors.Is(err, io.ErrClosedPipe) {
|
||||
t.logger.Warn().Err(err).Msg("failed to write sample to track")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Track) WriteSample(sample types.Sample) {
|
||||
t.sample <- sample
|
||||
}
|
||||
|
||||
// --- stream ---
|
||||
|
||||
func (t *Track) SetStream(stream types.StreamSinkManager) (bool, error) {
|
||||
t.streamMu.Lock()
|
||||
defer t.streamMu.Unlock()
|
||||
|
||||
// if we already listen to the stream, do nothing
|
||||
if t.stream == stream {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// if paused, we switch the stream but don't add the listener
|
||||
if t.paused {
|
||||
t.stream = stream
|
||||
return true, nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if t.stream != nil {
|
||||
err = t.stream.MoveListenerTo(t, stream)
|
||||
} else {
|
||||
err = stream.AddListener(t)
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
t.stream = stream
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (t *Track) RemoveStream() {
|
||||
t.streamMu.Lock()
|
||||
defer t.streamMu.Unlock()
|
||||
|
||||
// if there is no stream, or paused we don't need to remove the listener
|
||||
if t.stream == nil || t.paused {
|
||||
t.stream = nil
|
||||
return
|
||||
}
|
||||
|
||||
err := t.stream.RemoveListener(t)
|
||||
if err != nil {
|
||||
t.logger.Warn().Err(err).Msg("failed to remove listener from stream")
|
||||
}
|
||||
|
||||
t.stream = nil
|
||||
}
|
||||
|
||||
func (t *Track) Stream() (types.StreamSinkManager, bool) {
|
||||
t.streamMu.Lock()
|
||||
defer t.streamMu.Unlock()
|
||||
|
||||
return t.stream, t.stream != nil
|
||||
}
|
||||
|
||||
// --- paused ---
|
||||
|
||||
func (t *Track) SetPaused(paused bool) {
|
||||
t.streamMu.Lock()
|
||||
defer t.streamMu.Unlock()
|
||||
|
||||
// if there is no state change or no stream, do nothing
|
||||
if t.paused == paused || t.stream == nil {
|
||||
t.paused = paused
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
if paused {
|
||||
err = t.stream.RemoveListener(t)
|
||||
} else {
|
||||
err = t.stream.AddListener(t)
|
||||
}
|
||||
if err != nil {
|
||||
t.logger.Warn().Err(err).Msg("failed to change listener state")
|
||||
return
|
||||
}
|
||||
|
||||
t.paused = paused
|
||||
}
|
||||
|
||||
func (t *Track) Paused() bool {
|
||||
t.streamMu.Lock()
|
||||
defer t.streamMu.Unlock()
|
||||
|
||||
return t.paused
|
||||
}
|
67
server/internal/websocket/filechooserdialog.go
Normal file
67
server/internal/websocket/filechooserdialog.go
Normal file
@ -0,0 +1,67 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
"github.com/demodesk/neko/pkg/types/message"
|
||||
)
|
||||
|
||||
func (manager *WebSocketManagerCtx) fileChooserDialogEvents() {
|
||||
var activeSession types.Session
|
||||
|
||||
// when dialog opens, everyone should be notified.
|
||||
manager.desktop.OnFileChooserDialogOpened(func() {
|
||||
manager.logger.Info().Msg("file chooser dialog opened")
|
||||
|
||||
host, hasHost := manager.sessions.GetHost()
|
||||
if !hasHost {
|
||||
manager.logger.Warn().Msg("no host for file chooser dialog found, closing")
|
||||
go manager.desktop.CloseFileChooserDialog()
|
||||
return
|
||||
}
|
||||
|
||||
activeSession = host
|
||||
|
||||
go manager.sessions.Broadcast(
|
||||
event.FILE_CHOOSER_DIALOG_OPENED,
|
||||
message.SessionID{
|
||||
ID: host.ID(),
|
||||
})
|
||||
})
|
||||
|
||||
// when dialog closes, everyone should be notified.
|
||||
manager.desktop.OnFileChooserDialogClosed(func() {
|
||||
manager.logger.Info().Msg("file chooser dialog closed")
|
||||
|
||||
activeSession = nil
|
||||
|
||||
go manager.sessions.Broadcast(
|
||||
event.FILE_CHOOSER_DIALOG_CLOSED,
|
||||
message.SessionID{})
|
||||
})
|
||||
|
||||
// when new user joins, and someone holds dialog, he shouldd be notified about it.
|
||||
manager.sessions.OnConnected(func(session types.Session) {
|
||||
if activeSession == nil {
|
||||
return
|
||||
}
|
||||
|
||||
manager.logger.Debug().Str("session_id", session.ID()).Msg("sending file chooser dialog status to a new session")
|
||||
|
||||
session.Send(
|
||||
event.FILE_CHOOSER_DIALOG_OPENED,
|
||||
message.SessionID{
|
||||
ID: activeSession.ID(),
|
||||
})
|
||||
})
|
||||
|
||||
// when user, that holds dialog, disconnects, it should be closed.
|
||||
manager.sessions.OnDisconnected(func(session types.Session) {
|
||||
if activeSession == nil || activeSession != session {
|
||||
return
|
||||
}
|
||||
|
||||
manager.logger.Info().Msg("file chooser dialog owner left, closing")
|
||||
manager.desktop.CloseFileChooserDialog()
|
||||
})
|
||||
}
|
23
server/internal/websocket/handler/clipboard.go
Normal file
23
server/internal/websocket/handler/clipboard.go
Normal file
@ -0,0 +1,23 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/message"
|
||||
)
|
||||
|
||||
func (h *MessageHandlerCtx) clipboardSet(session types.Session, payload *message.ClipboardData) error {
|
||||
if !session.Profile().CanAccessClipboard {
|
||||
return errors.New("cannot access clipboard")
|
||||
}
|
||||
|
||||
if !session.IsHost() {
|
||||
return errors.New("is not the host")
|
||||
}
|
||||
|
||||
return h.desktop.ClipboardSetText(types.ClipboardText{
|
||||
Text: payload.Text,
|
||||
// TODO: Send HTML?
|
||||
})
|
||||
}
|
228
server/internal/websocket/handler/control.go
Normal file
228
server/internal/websocket/handler/control.go
Normal file
@ -0,0 +1,228 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
"github.com/demodesk/neko/pkg/types/message"
|
||||
"github.com/demodesk/neko/pkg/xorg"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrIsNotAllowedToHost = errors.New("is not allowed to host")
|
||||
ErrIsNotTheHost = errors.New("is not the host")
|
||||
ErrIsAlreadyTheHost = errors.New("is already the host")
|
||||
ErrIsAlreadyHosted = errors.New("is already hosted")
|
||||
)
|
||||
|
||||
func (h *MessageHandlerCtx) controlRelease(session types.Session) error {
|
||||
if !session.Profile().CanHost || session.PrivateModeEnabled() {
|
||||
return ErrIsNotAllowedToHost
|
||||
}
|
||||
|
||||
if !session.IsHost() {
|
||||
return ErrIsNotTheHost
|
||||
}
|
||||
|
||||
h.desktop.ResetKeys()
|
||||
session.ClearHost()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) controlRequest(session types.Session) error {
|
||||
if !session.Profile().CanHost || session.PrivateModeEnabled() {
|
||||
return ErrIsNotAllowedToHost
|
||||
}
|
||||
|
||||
if session.IsHost() {
|
||||
return ErrIsAlreadyTheHost
|
||||
}
|
||||
|
||||
if h.sessions.Settings().LockedControls && !session.Profile().IsAdmin {
|
||||
return ErrIsNotAllowedToHost
|
||||
}
|
||||
|
||||
// if implicit hosting is enabled, set session as host without asking
|
||||
if h.sessions.Settings().ImplicitHosting {
|
||||
session.SetAsHost()
|
||||
return nil
|
||||
}
|
||||
|
||||
// if there is no host, set session as host
|
||||
host, hasHost := h.sessions.GetHost()
|
||||
if !hasHost {
|
||||
session.SetAsHost()
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: Some throttling mechanism to prevent spamming.
|
||||
|
||||
// let host know that someone wants to take control
|
||||
host.Send(
|
||||
event.CONTROL_REQUEST,
|
||||
message.SessionID{
|
||||
ID: session.ID(),
|
||||
})
|
||||
|
||||
return ErrIsAlreadyHosted
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) controlMove(session types.Session, payload *message.ControlPos) error {
|
||||
if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
|
||||
return err
|
||||
}
|
||||
|
||||
// handle active cursor movement
|
||||
h.desktop.Move(payload.X, payload.Y)
|
||||
h.webrtc.SetCursorPosition(payload.X, payload.Y)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) controlScroll(session types.Session, payload *message.ControlScroll) error {
|
||||
if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
|
||||
return err
|
||||
}
|
||||
|
||||
// TOOD: remove this once the client is fixed
|
||||
if payload.DeltaX == 0 && payload.DeltaY == 0 {
|
||||
payload.DeltaX = payload.X
|
||||
payload.DeltaY = payload.Y
|
||||
}
|
||||
|
||||
h.desktop.Scroll(payload.DeltaX, payload.DeltaY, payload.ControlKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) controlButtonPress(session types.Session, payload *message.ControlButton) error {
|
||||
if payload.ControlPos != nil {
|
||||
if err := h.controlMove(session, payload.ControlPos); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
|
||||
return err
|
||||
}
|
||||
|
||||
return h.desktop.ButtonPress(payload.Code)
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) controlButtonDown(session types.Session, payload *message.ControlButton) error {
|
||||
if payload.ControlPos != nil {
|
||||
if err := h.controlMove(session, payload.ControlPos); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
|
||||
return err
|
||||
}
|
||||
|
||||
return h.desktop.ButtonDown(payload.Code)
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) controlButtonUp(session types.Session, payload *message.ControlButton) error {
|
||||
if payload.ControlPos != nil {
|
||||
if err := h.controlMove(session, payload.ControlPos); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
|
||||
return err
|
||||
}
|
||||
|
||||
return h.desktop.ButtonUp(payload.Code)
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) controlKeyPress(session types.Session, payload *message.ControlKey) error {
|
||||
if payload.ControlPos != nil {
|
||||
if err := h.controlMove(session, payload.ControlPos); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
|
||||
return err
|
||||
}
|
||||
|
||||
return h.desktop.KeyPress(payload.Keysym)
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) controlKeyDown(session types.Session, payload *message.ControlKey) error {
|
||||
if payload.ControlPos != nil {
|
||||
if err := h.controlMove(session, payload.ControlPos); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
|
||||
return err
|
||||
}
|
||||
|
||||
return h.desktop.KeyDown(payload.Keysym)
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) controlKeyUp(session types.Session, payload *message.ControlKey) error {
|
||||
if payload.ControlPos != nil {
|
||||
if err := h.controlMove(session, payload.ControlPos); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
|
||||
return err
|
||||
}
|
||||
|
||||
return h.desktop.KeyUp(payload.Keysym)
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) controlTouchBegin(session types.Session, payload *message.ControlTouch) error {
|
||||
if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
|
||||
return err
|
||||
}
|
||||
return h.desktop.TouchBegin(payload.TouchId, payload.X, payload.Y, payload.Pressure)
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) controlTouchUpdate(session types.Session, payload *message.ControlTouch) error {
|
||||
if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
|
||||
return err
|
||||
}
|
||||
return h.desktop.TouchUpdate(payload.TouchId, payload.X, payload.Y, payload.Pressure)
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) controlTouchEnd(session types.Session, payload *message.ControlTouch) error {
|
||||
if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
|
||||
return err
|
||||
}
|
||||
return h.desktop.TouchEnd(payload.TouchId, payload.X, payload.Y, payload.Pressure)
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) controlCut(session types.Session) error {
|
||||
if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
|
||||
return err
|
||||
}
|
||||
|
||||
return h.desktop.KeyPress(xorg.XK_Control_L, xorg.XK_x)
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) controlCopy(session types.Session) error {
|
||||
if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
|
||||
return err
|
||||
}
|
||||
|
||||
return h.desktop.KeyPress(xorg.XK_Control_L, xorg.XK_c)
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) controlPaste(session types.Session, payload *message.ClipboardData) error {
|
||||
if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
|
||||
return err
|
||||
}
|
||||
|
||||
// if there have been set clipboard data, set them first
|
||||
if payload != nil && payload.Text != "" {
|
||||
if err := h.clipboardSet(session, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return h.desktop.KeyPress(xorg.XK_Control_L, xorg.XK_v)
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) controlSelectAll(session types.Session) error {
|
||||
if err := h.controlRequest(session); err != nil && !errors.Is(err, ErrIsAlreadyTheHost) {
|
||||
return err
|
||||
}
|
||||
|
||||
return h.desktop.KeyPress(xorg.XK_Control_L, xorg.XK_a)
|
||||
}
|
203
server/internal/websocket/handler/handler.go
Normal file
203
server/internal/websocket/handler/handler.go
Normal file
@ -0,0 +1,203 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
"github.com/demodesk/neko/pkg/types/message"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
func New(
|
||||
sessions types.SessionManager,
|
||||
desktop types.DesktopManager,
|
||||
capture types.CaptureManager,
|
||||
webrtc types.WebRTCManager,
|
||||
) *MessageHandlerCtx {
|
||||
return &MessageHandlerCtx{
|
||||
logger: log.With().Str("module", "websocket").Str("submodule", "handler").Logger(),
|
||||
sessions: sessions,
|
||||
desktop: desktop,
|
||||
capture: capture,
|
||||
webrtc: webrtc,
|
||||
}
|
||||
}
|
||||
|
||||
type MessageHandlerCtx struct {
|
||||
logger zerolog.Logger
|
||||
sessions types.SessionManager
|
||||
webrtc types.WebRTCManager
|
||||
desktop types.DesktopManager
|
||||
capture types.CaptureManager
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) Message(session types.Session, data types.WebSocketMessage) bool {
|
||||
var err error
|
||||
switch data.Event {
|
||||
// System Events
|
||||
case event.SYSTEM_LOGS:
|
||||
payload := &message.SystemLogs{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.systemLogs(session, payload)
|
||||
})
|
||||
|
||||
// Signal Events
|
||||
case event.SIGNAL_REQUEST:
|
||||
payload := &message.SignalRequest{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.signalRequest(session, payload)
|
||||
})
|
||||
case event.SIGNAL_RESTART:
|
||||
err = h.signalRestart(session)
|
||||
case event.SIGNAL_OFFER:
|
||||
payload := &message.SignalDescription{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.signalOffer(session, payload)
|
||||
})
|
||||
case event.SIGNAL_ANSWER:
|
||||
payload := &message.SignalDescription{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.signalAnswer(session, payload)
|
||||
})
|
||||
case event.SIGNAL_CANDIDATE:
|
||||
payload := &message.SignalCandidate{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.signalCandidate(session, payload)
|
||||
})
|
||||
case event.SIGNAL_VIDEO:
|
||||
payload := &message.SignalVideo{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.signalVideo(session, payload)
|
||||
})
|
||||
case event.SIGNAL_AUDIO:
|
||||
payload := &message.SignalAudio{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.signalAudio(session, payload)
|
||||
})
|
||||
|
||||
// Control Events
|
||||
case event.CONTROL_RELEASE:
|
||||
err = h.controlRelease(session)
|
||||
case event.CONTROL_REQUEST:
|
||||
err = h.controlRequest(session)
|
||||
case event.CONTROL_MOVE:
|
||||
payload := &message.ControlPos{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.controlMove(session, payload)
|
||||
})
|
||||
case event.CONTROL_SCROLL:
|
||||
payload := &message.ControlScroll{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.controlScroll(session, payload)
|
||||
})
|
||||
case event.CONTROL_BUTTONPRESS:
|
||||
payload := &message.ControlButton{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.controlButtonPress(session, payload)
|
||||
})
|
||||
case event.CONTROL_BUTTONDOWN:
|
||||
payload := &message.ControlButton{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.controlButtonDown(session, payload)
|
||||
})
|
||||
case event.CONTROL_BUTTONUP:
|
||||
payload := &message.ControlButton{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.controlButtonUp(session, payload)
|
||||
})
|
||||
case event.CONTROL_KEYPRESS:
|
||||
payload := &message.ControlKey{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.controlKeyPress(session, payload)
|
||||
})
|
||||
case event.CONTROL_KEYDOWN:
|
||||
payload := &message.ControlKey{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.controlKeyDown(session, payload)
|
||||
})
|
||||
case event.CONTROL_KEYUP:
|
||||
payload := &message.ControlKey{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.controlKeyUp(session, payload)
|
||||
})
|
||||
// touch
|
||||
case event.CONTROL_TOUCHBEGIN:
|
||||
payload := &message.ControlTouch{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.controlTouchBegin(session, payload)
|
||||
})
|
||||
case event.CONTROL_TOUCHUPDATE:
|
||||
payload := &message.ControlTouch{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.controlTouchUpdate(session, payload)
|
||||
})
|
||||
case event.CONTROL_TOUCHEND:
|
||||
payload := &message.ControlTouch{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.controlTouchEnd(session, payload)
|
||||
})
|
||||
// actions
|
||||
case event.CONTROL_CUT:
|
||||
err = h.controlCut(session)
|
||||
case event.CONTROL_COPY:
|
||||
err = h.controlCopy(session)
|
||||
case event.CONTROL_PASTE:
|
||||
payload := &message.ClipboardData{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.controlPaste(session, payload)
|
||||
})
|
||||
case event.CONTROL_SELECT_ALL:
|
||||
err = h.controlSelectAll(session)
|
||||
|
||||
// Screen Events
|
||||
case event.SCREEN_SET:
|
||||
payload := &message.ScreenSize{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.screenSet(session, payload)
|
||||
})
|
||||
|
||||
// Clipboard Events
|
||||
case event.CLIPBOARD_SET:
|
||||
payload := &message.ClipboardData{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.clipboardSet(session, payload)
|
||||
})
|
||||
|
||||
// Keyboard Events
|
||||
case event.KEYBOARD_MAP:
|
||||
payload := &message.KeyboardMap{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.keyboardMap(session, payload)
|
||||
})
|
||||
case event.KEYBOARD_MODIFIERS:
|
||||
payload := &message.KeyboardModifiers{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.keyboardModifiers(session, payload)
|
||||
})
|
||||
|
||||
// Send Events
|
||||
case event.SEND_UNICAST:
|
||||
payload := &message.SendUnicast{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.sendUnicast(session, payload)
|
||||
})
|
||||
case event.SEND_BROADCAST:
|
||||
payload := &message.SendBroadcast{}
|
||||
err = utils.Unmarshal(payload, data.Payload, func() error {
|
||||
return h.sendBroadcast(session, payload)
|
||||
})
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
h.logger.Warn().Err(err).
|
||||
Str("event", data.Event).
|
||||
Str("session_id", session.ID()).
|
||||
Msg("message handler has failed")
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
25
server/internal/websocket/handler/keyboard.go
Normal file
25
server/internal/websocket/handler/keyboard.go
Normal file
@ -0,0 +1,25 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/message"
|
||||
)
|
||||
|
||||
func (h *MessageHandlerCtx) keyboardMap(session types.Session, payload *message.KeyboardMap) error {
|
||||
if !session.IsHost() {
|
||||
return errors.New("is not the host")
|
||||
}
|
||||
|
||||
return h.desktop.SetKeyboardMap(payload.KeyboardMap)
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) keyboardModifiers(session types.Session, payload *message.KeyboardModifiers) error {
|
||||
if !session.IsHost() {
|
||||
return errors.New("is not the host")
|
||||
}
|
||||
|
||||
h.desktop.SetKeyboardModifiers(payload.KeyboardModifiers)
|
||||
return nil
|
||||
}
|
26
server/internal/websocket/handler/screen.go
Normal file
26
server/internal/websocket/handler/screen.go
Normal file
@ -0,0 +1,26 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
"github.com/demodesk/neko/pkg/types/message"
|
||||
)
|
||||
|
||||
func (h *MessageHandlerCtx) screenSet(session types.Session, payload *message.ScreenSize) error {
|
||||
if !session.Profile().IsAdmin {
|
||||
return errors.New("is not the admin")
|
||||
}
|
||||
|
||||
size, err := h.desktop.SetScreenSize(payload.ScreenSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.sessions.Broadcast(event.SCREEN_UPDATED, message.ScreenSizeUpdate{
|
||||
ID: session.ID(),
|
||||
ScreenSize: size,
|
||||
})
|
||||
return nil
|
||||
}
|
39
server/internal/websocket/handler/send.go
Normal file
39
server/internal/websocket/handler/send.go
Normal file
@ -0,0 +1,39 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
"github.com/demodesk/neko/pkg/types/message"
|
||||
)
|
||||
|
||||
func (h *MessageHandlerCtx) sendUnicast(session types.Session, payload *message.SendUnicast) error {
|
||||
receiver, ok := h.sessions.Get(payload.Receiver)
|
||||
if !ok {
|
||||
return errors.New("receiver session ID not found")
|
||||
}
|
||||
|
||||
receiver.Send(
|
||||
event.SEND_UNICAST,
|
||||
message.SendUnicast{
|
||||
Sender: session.ID(),
|
||||
Receiver: receiver.ID(),
|
||||
Subject: payload.Subject,
|
||||
Body: payload.Body,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) sendBroadcast(session types.Session, payload *message.SendBroadcast) error {
|
||||
h.sessions.Broadcast(
|
||||
event.SEND_BROADCAST,
|
||||
message.SendBroadcast{
|
||||
Sender: session.ID(),
|
||||
Subject: payload.Subject,
|
||||
Body: payload.Body,
|
||||
}, session.ID())
|
||||
|
||||
return nil
|
||||
}
|
106
server/internal/websocket/handler/session.go
Normal file
106
server/internal/websocket/handler/session.go
Normal file
@ -0,0 +1,106 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
"github.com/demodesk/neko/pkg/types/message"
|
||||
)
|
||||
|
||||
func (h *MessageHandlerCtx) SessionCreated(session types.Session) error {
|
||||
h.sessions.Broadcast(
|
||||
event.SESSION_CREATED,
|
||||
message.SessionData{
|
||||
ID: session.ID(),
|
||||
Profile: session.Profile(),
|
||||
State: session.State(),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) SessionDeleted(session types.Session) error {
|
||||
h.sessions.Broadcast(
|
||||
event.SESSION_DELETED,
|
||||
message.SessionID{
|
||||
ID: session.ID(),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) SessionConnected(session types.Session) error {
|
||||
if err := h.systemInit(session); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if session.Profile().IsAdmin {
|
||||
if err := h.systemAdmin(session); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update settings in atomic way
|
||||
h.sessions.UpdateSettingsFunc(session, func(settings *types.Settings) bool {
|
||||
// if control protection & locked controls: unlock controls
|
||||
if settings.LockedControls && settings.ControlProtection {
|
||||
settings.LockedControls = false
|
||||
return true // update settings
|
||||
}
|
||||
return false // do not update settings
|
||||
})
|
||||
}
|
||||
|
||||
return h.SessionStateChanged(session)
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) SessionDisconnected(session types.Session) error {
|
||||
// clear host if exists
|
||||
if session.IsHost() {
|
||||
h.desktop.ResetKeys()
|
||||
session.ClearHost()
|
||||
}
|
||||
|
||||
if session.Profile().IsAdmin {
|
||||
hasAdmin := false
|
||||
h.sessions.Range(func(s types.Session) bool {
|
||||
if s.Profile().IsAdmin && s.ID() != session.ID() && s.State().IsConnected {
|
||||
hasAdmin = true
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// update settings in atomic way
|
||||
h.sessions.UpdateSettingsFunc(session, func(settings *types.Settings) bool {
|
||||
// if control protection & not locked controls & no admin: lock controls
|
||||
if !settings.LockedControls && settings.ControlProtection && !hasAdmin {
|
||||
settings.LockedControls = true
|
||||
return true // update settings
|
||||
}
|
||||
return false // do not update settings
|
||||
})
|
||||
}
|
||||
|
||||
return h.SessionStateChanged(session)
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) SessionProfileChanged(session types.Session, new, old types.MemberProfile) error {
|
||||
h.sessions.Broadcast(
|
||||
event.SESSION_PROFILE,
|
||||
message.MemberProfile{
|
||||
ID: session.ID(),
|
||||
MemberProfile: new,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) SessionStateChanged(session types.Session) error {
|
||||
h.sessions.Broadcast(
|
||||
event.SESSION_STATE,
|
||||
message.SessionState{
|
||||
ID: session.ID(),
|
||||
SessionState: session.State(),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
162
server/internal/websocket/handler/signal.go
Normal file
162
server/internal/websocket/handler/signal.go
Normal file
@ -0,0 +1,162 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
"github.com/demodesk/neko/pkg/types/message"
|
||||
"github.com/pion/webrtc/v3"
|
||||
)
|
||||
|
||||
func (h *MessageHandlerCtx) signalRequest(session types.Session, payload *message.SignalRequest) error {
|
||||
if !session.Profile().CanWatch {
|
||||
return errors.New("not allowed to watch")
|
||||
}
|
||||
|
||||
offer, peer, err := h.webrtc.CreatePeer(session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set webrtc as paused if session has private mode enabled
|
||||
if session.PrivateModeEnabled() {
|
||||
peer.SetPaused(true)
|
||||
}
|
||||
|
||||
video := payload.Video
|
||||
|
||||
// use default first video, if not provided
|
||||
if video.Selector == nil {
|
||||
videos := h.capture.Video().IDs()
|
||||
video.Selector = &types.StreamSelector{
|
||||
ID: videos[0],
|
||||
Type: types.StreamSelectorTypeExact,
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Remove, used for compatibility with old clients.
|
||||
if video.Auto == nil {
|
||||
video.Auto = &payload.Auto
|
||||
}
|
||||
|
||||
// set video stream
|
||||
err = peer.SetVideo(video)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
audio := payload.Audio
|
||||
|
||||
// enable by default if not requested otherwise
|
||||
if audio.Disabled == nil {
|
||||
disabled := false
|
||||
audio.Disabled = &disabled
|
||||
}
|
||||
|
||||
// set audio stream
|
||||
err = peer.SetAudio(audio)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session.Send(
|
||||
event.SIGNAL_PROVIDE,
|
||||
message.SignalProvide{
|
||||
SDP: offer.SDP,
|
||||
ICEServers: h.webrtc.ICEServers(),
|
||||
|
||||
Video: peer.Video(),
|
||||
Audio: peer.Audio(),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) signalRestart(session types.Session) error {
|
||||
peer := session.GetWebRTCPeer()
|
||||
if peer == nil {
|
||||
return errors.New("webRTC peer does not exist")
|
||||
}
|
||||
|
||||
offer, err := peer.CreateOffer(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: Use offer event instead.
|
||||
session.Send(
|
||||
event.SIGNAL_RESTART,
|
||||
message.SignalDescription{
|
||||
SDP: offer.SDP,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) signalOffer(session types.Session, payload *message.SignalDescription) error {
|
||||
peer := session.GetWebRTCPeer()
|
||||
if peer == nil {
|
||||
return errors.New("webRTC peer does not exist")
|
||||
}
|
||||
|
||||
err := peer.SetRemoteDescription(webrtc.SessionDescription{
|
||||
SDP: payload.SDP,
|
||||
Type: webrtc.SDPTypeOffer,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
answer, err := peer.CreateAnswer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session.Send(
|
||||
event.SIGNAL_ANSWER,
|
||||
message.SignalDescription{
|
||||
SDP: answer.SDP,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) signalAnswer(session types.Session, payload *message.SignalDescription) error {
|
||||
peer := session.GetWebRTCPeer()
|
||||
if peer == nil {
|
||||
return errors.New("webRTC peer does not exist")
|
||||
}
|
||||
|
||||
return peer.SetRemoteDescription(webrtc.SessionDescription{
|
||||
SDP: payload.SDP,
|
||||
Type: webrtc.SDPTypeAnswer,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) signalCandidate(session types.Session, payload *message.SignalCandidate) error {
|
||||
peer := session.GetWebRTCPeer()
|
||||
if peer == nil {
|
||||
return errors.New("webRTC peer does not exist")
|
||||
}
|
||||
|
||||
return peer.SetCandidate(payload.ICECandidateInit)
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) signalVideo(session types.Session, payload *message.SignalVideo) error {
|
||||
peer := session.GetWebRTCPeer()
|
||||
if peer == nil {
|
||||
return errors.New("webRTC peer does not exist")
|
||||
}
|
||||
|
||||
return peer.SetVideo(payload.PeerVideoRequest)
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) signalAudio(session types.Session, payload *message.SignalAudio) error {
|
||||
peer := session.GetWebRTCPeer()
|
||||
if peer == nil {
|
||||
return errors.New("webRTC peer does not exist")
|
||||
}
|
||||
|
||||
return peer.SetAudio(payload.PeerAudioRequest)
|
||||
}
|
96
server/internal/websocket/handler/system.go
Normal file
96
server/internal/websocket/handler/system.go
Normal file
@ -0,0 +1,96 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
"github.com/demodesk/neko/pkg/types/message"
|
||||
)
|
||||
|
||||
func (h *MessageHandlerCtx) systemInit(session types.Session) error {
|
||||
host, hasHost := h.sessions.GetHost()
|
||||
|
||||
var hostID string
|
||||
if hasHost {
|
||||
hostID = host.ID()
|
||||
}
|
||||
|
||||
controlHost := message.ControlHost{
|
||||
HasHost: hasHost,
|
||||
HostID: hostID,
|
||||
}
|
||||
|
||||
sessions := map[string]message.SessionData{}
|
||||
for _, session := range h.sessions.List() {
|
||||
sessionId := session.ID()
|
||||
sessions[sessionId] = message.SessionData{
|
||||
ID: sessionId,
|
||||
Profile: session.Profile(),
|
||||
State: session.State(),
|
||||
}
|
||||
}
|
||||
|
||||
session.Send(
|
||||
event.SYSTEM_INIT,
|
||||
message.SystemInit{
|
||||
SessionId: session.ID(),
|
||||
ControlHost: controlHost,
|
||||
ScreenSize: h.desktop.GetScreenSize(),
|
||||
Sessions: sessions,
|
||||
Settings: h.sessions.Settings(),
|
||||
TouchEvents: h.desktop.HasTouchSupport(),
|
||||
ScreencastEnabled: h.capture.Screencast().Enabled(),
|
||||
WebRTC: message.SystemWebRTC{
|
||||
Videos: h.capture.Video().IDs(),
|
||||
},
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) systemAdmin(session types.Session) error {
|
||||
configurations := h.desktop.ScreenConfigurations()
|
||||
|
||||
list := make([]types.ScreenSize, 0, len(configurations))
|
||||
for _, conf := range configurations {
|
||||
list = append(list, types.ScreenSize{
|
||||
Width: conf.Width,
|
||||
Height: conf.Height,
|
||||
Rate: conf.Rate,
|
||||
})
|
||||
}
|
||||
|
||||
broadcast := h.capture.Broadcast()
|
||||
session.Send(
|
||||
event.SYSTEM_ADMIN,
|
||||
message.SystemAdmin{
|
||||
ScreenSizesList: list, // TODO: remove
|
||||
BroadcastStatus: message.BroadcastStatus{
|
||||
IsActive: broadcast.Started(),
|
||||
URL: broadcast.Url(),
|
||||
},
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandlerCtx) systemLogs(session types.Session, payload *message.SystemLogs) error {
|
||||
for _, msg := range *payload {
|
||||
level, _ := zerolog.ParseLevel(msg.Level)
|
||||
|
||||
if level < zerolog.DebugLevel || level > zerolog.ErrorLevel {
|
||||
level = zerolog.NoLevel
|
||||
}
|
||||
|
||||
// do not use handler logger context
|
||||
log.WithLevel(level).
|
||||
Fields(msg.Fields).
|
||||
Str("module", "client").
|
||||
Str("session_id", session.ID()).
|
||||
Msg(msg.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
431
server/internal/websocket/manager.go
Normal file
431
server/internal/websocket/manager.go
Normal file
@ -0,0 +1,431 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/demodesk/neko/internal/websocket/handler"
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
"github.com/demodesk/neko/pkg/types/message"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
// send pings to peer with this period - must be less than pongWait
|
||||
const pingPeriod = 10 * time.Second
|
||||
|
||||
// period for sending inactive cursor messages
|
||||
const inactiveCursorsPeriod = 750 * time.Millisecond
|
||||
|
||||
// maximum payload length for logging
|
||||
const maxPayloadLogLength = 10_000
|
||||
|
||||
// events that are not logged in debug mode
|
||||
var nologEvents = []string{
|
||||
// don't log twice
|
||||
event.SYSTEM_LOGS,
|
||||
// don't log heartbeat
|
||||
event.SYSTEM_HEARTBEAT,
|
||||
// don't log every cursor update
|
||||
event.SESSION_CURSORS,
|
||||
}
|
||||
|
||||
func New(
|
||||
sessions types.SessionManager,
|
||||
desktop types.DesktopManager,
|
||||
capture types.CaptureManager,
|
||||
webrtc types.WebRTCManager,
|
||||
) *WebSocketManagerCtx {
|
||||
logger := log.With().Str("module", "websocket").Logger()
|
||||
|
||||
return &WebSocketManagerCtx{
|
||||
logger: logger,
|
||||
shutdown: make(chan struct{}),
|
||||
sessions: sessions,
|
||||
desktop: desktop,
|
||||
handler: handler.New(sessions, desktop, capture, webrtc),
|
||||
handlers: []types.WebSocketHandler{},
|
||||
}
|
||||
}
|
||||
|
||||
type WebSocketManagerCtx struct {
|
||||
logger zerolog.Logger
|
||||
wg sync.WaitGroup
|
||||
shutdown chan struct{}
|
||||
sessions types.SessionManager
|
||||
desktop types.DesktopManager
|
||||
handler *handler.MessageHandlerCtx
|
||||
handlers []types.WebSocketHandler
|
||||
|
||||
shutdownInactiveCursors chan struct{}
|
||||
}
|
||||
|
||||
func (manager *WebSocketManagerCtx) Start() {
|
||||
manager.sessions.OnCreated(func(session types.Session) {
|
||||
err := manager.handler.SessionCreated(session)
|
||||
manager.logger.Err(err).
|
||||
Str("session_id", session.ID()).
|
||||
Msg("session created")
|
||||
})
|
||||
|
||||
manager.sessions.OnDeleted(func(session types.Session) {
|
||||
err := manager.handler.SessionDeleted(session)
|
||||
manager.logger.Err(err).
|
||||
Str("session_id", session.ID()).
|
||||
Msg("session deleted")
|
||||
})
|
||||
|
||||
manager.sessions.OnConnected(func(session types.Session) {
|
||||
err := manager.handler.SessionConnected(session)
|
||||
manager.logger.Err(err).
|
||||
Str("session_id", session.ID()).
|
||||
Msg("session connected")
|
||||
})
|
||||
|
||||
manager.sessions.OnDisconnected(func(session types.Session) {
|
||||
err := manager.handler.SessionDisconnected(session)
|
||||
manager.logger.Err(err).
|
||||
Str("session_id", session.ID()).
|
||||
Msg("session disconnected")
|
||||
})
|
||||
|
||||
manager.sessions.OnProfileChanged(func(session types.Session, new, old types.MemberProfile) {
|
||||
err := manager.handler.SessionProfileChanged(session, new, old)
|
||||
manager.logger.Err(err).
|
||||
Str("session_id", session.ID()).
|
||||
Interface("new", new).
|
||||
Interface("old", old).
|
||||
Msg("session profile changed")
|
||||
})
|
||||
|
||||
manager.sessions.OnStateChanged(func(session types.Session) {
|
||||
err := manager.handler.SessionStateChanged(session)
|
||||
manager.logger.Err(err).
|
||||
Str("session_id", session.ID()).
|
||||
Msg("session state changed")
|
||||
})
|
||||
|
||||
manager.sessions.OnHostChanged(func(session, host types.Session) {
|
||||
payload := message.ControlHost{
|
||||
ID: session.ID(),
|
||||
HasHost: host != nil,
|
||||
}
|
||||
|
||||
if payload.HasHost {
|
||||
payload.HostID = host.ID()
|
||||
}
|
||||
|
||||
manager.sessions.Broadcast(event.CONTROL_HOST, payload)
|
||||
|
||||
manager.logger.Info().
|
||||
Str("session_id", session.ID()).
|
||||
Bool("has_host", payload.HasHost).
|
||||
Str("host_id", payload.HostID).
|
||||
Msg("session host changed")
|
||||
})
|
||||
|
||||
manager.sessions.OnSettingsChanged(func(session types.Session, new, old types.Settings) {
|
||||
// start inactive cursors
|
||||
if new.InactiveCursors && !old.InactiveCursors {
|
||||
manager.startInactiveCursors()
|
||||
}
|
||||
|
||||
// stop inactive cursors
|
||||
if !new.InactiveCursors && old.InactiveCursors {
|
||||
manager.stopInactiveCursors()
|
||||
}
|
||||
|
||||
manager.sessions.Broadcast(event.SYSTEM_SETTINGS, message.SystemSettingsUpdate{
|
||||
ID: session.ID(),
|
||||
Settings: new,
|
||||
})
|
||||
|
||||
manager.logger.Info().
|
||||
Str("session_id", session.ID()).
|
||||
Interface("new", new).
|
||||
Interface("old", old).
|
||||
Msg("settings changed")
|
||||
})
|
||||
|
||||
manager.desktop.OnClipboardUpdated(func() {
|
||||
host, hasHost := manager.sessions.GetHost()
|
||||
if !hasHost || !host.Profile().CanAccessClipboard {
|
||||
return
|
||||
}
|
||||
|
||||
manager.logger.Info().Msg("sync clipboard")
|
||||
|
||||
data, err := manager.desktop.ClipboardGetText()
|
||||
if err != nil {
|
||||
manager.logger.Err(err).Msg("could not get clipboard content")
|
||||
return
|
||||
}
|
||||
|
||||
host.Send(
|
||||
event.CLIPBOARD_UPDATED,
|
||||
message.ClipboardData{
|
||||
Text: data.Text,
|
||||
// TODO: Send HTML?
|
||||
})
|
||||
})
|
||||
|
||||
if manager.desktop.IsFileChooserDialogEnabled() {
|
||||
manager.fileChooserDialogEvents()
|
||||
}
|
||||
|
||||
if manager.sessions.Settings().InactiveCursors {
|
||||
manager.startInactiveCursors()
|
||||
}
|
||||
|
||||
manager.logger.Info().Msg("websocket starting")
|
||||
}
|
||||
|
||||
func (manager *WebSocketManagerCtx) Shutdown() error {
|
||||
manager.logger.Info().Msg("shutdown")
|
||||
close(manager.shutdown)
|
||||
manager.stopInactiveCursors()
|
||||
manager.wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (manager *WebSocketManagerCtx) AddHandler(handler types.WebSocketHandler) {
|
||||
manager.handlers = append(manager.handlers, handler)
|
||||
}
|
||||
|
||||
func (manager *WebSocketManagerCtx) Upgrade(checkOrigin types.CheckOrigin) types.RouterHandler {
|
||||
return func(w http.ResponseWriter, r *http.Request) error {
|
||||
upgrader := websocket.Upgrader{
|
||||
CheckOrigin: checkOrigin,
|
||||
// Do not return any error while handshake
|
||||
Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {},
|
||||
}
|
||||
|
||||
connection, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return utils.HttpBadRequest().WithInternalErr(err)
|
||||
}
|
||||
|
||||
// Cannot write HTTP response after connection upgrade
|
||||
manager.connect(connection, r)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *WebSocketManagerCtx) connect(connection *websocket.Conn, r *http.Request) {
|
||||
session, err := manager.sessions.Authenticate(r)
|
||||
if err != nil {
|
||||
manager.logger.Warn().Err(err).Msg("authentication failed")
|
||||
newPeer(manager.logger, connection).Destroy(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// add session id to all log messages
|
||||
logger := manager.logger.With().Str("session_id", session.ID()).Logger()
|
||||
|
||||
// create new peer
|
||||
peer := newPeer(logger, connection)
|
||||
|
||||
if !session.Profile().CanConnect {
|
||||
logger.Warn().Msg("connection disabled")
|
||||
peer.Destroy("connection disabled")
|
||||
return
|
||||
}
|
||||
|
||||
if session.State().IsConnected {
|
||||
logger.Warn().Msg("already connected")
|
||||
|
||||
if !manager.sessions.Settings().MercifulReconnect {
|
||||
peer.Destroy("already connected")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info().Msg("replacing peer connection")
|
||||
}
|
||||
|
||||
logger.Info().
|
||||
Str("address", connection.RemoteAddr().String()).
|
||||
Str("agent", r.UserAgent()).
|
||||
Msg("connection started")
|
||||
|
||||
session.ConnectWebSocketPeer(peer)
|
||||
|
||||
// this is a blocking function that lives
|
||||
// throughout whole websocket connection
|
||||
err = manager.handle(connection, peer, session)
|
||||
|
||||
logger.Info().
|
||||
Str("address", connection.RemoteAddr().String()).
|
||||
Str("agent", r.UserAgent()).
|
||||
Msg("connection ended")
|
||||
|
||||
if err == nil {
|
||||
logger.Debug().Msg("websocket close")
|
||||
session.DisconnectWebSocketPeer(peer, false)
|
||||
return
|
||||
}
|
||||
|
||||
delayedDisconnect := false
|
||||
|
||||
e, ok := err.(*websocket.CloseError)
|
||||
if !ok {
|
||||
err = errors.Unwrap(err) // unwrap if possible
|
||||
logger.Warn().Err(err).Msg("read message error")
|
||||
// client is expected to reconnect soon
|
||||
delayedDisconnect = true
|
||||
} else {
|
||||
switch e.Code {
|
||||
case websocket.CloseNormalClosure:
|
||||
logger.Debug().Str("reason", e.Text).Msg("websocket close")
|
||||
case websocket.CloseGoingAway:
|
||||
logger.Debug().Str("reason", "going away").Msg("websocket close")
|
||||
default:
|
||||
logger.Warn().Err(err).Msg("websocket close")
|
||||
// abnormal websocket closure:
|
||||
// client is expected to reconnect soon
|
||||
delayedDisconnect = true
|
||||
}
|
||||
}
|
||||
|
||||
session.DisconnectWebSocketPeer(peer, delayedDisconnect)
|
||||
}
|
||||
|
||||
func (manager *WebSocketManagerCtx) handle(connection *websocket.Conn, peer types.WebSocketPeer, session types.Session) error {
|
||||
// add session id to logger context
|
||||
logger := manager.logger.With().Str("session_id", session.ID()).Logger()
|
||||
|
||||
bytes := make(chan []byte)
|
||||
cancel := make(chan error)
|
||||
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
manager.wg.Add(1)
|
||||
go func() {
|
||||
defer manager.wg.Done()
|
||||
|
||||
for {
|
||||
_, raw, err := connection.ReadMessage()
|
||||
if err != nil {
|
||||
cancel <- err
|
||||
break
|
||||
}
|
||||
|
||||
bytes <- raw
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case raw := <-bytes:
|
||||
data := types.WebSocketMessage{}
|
||||
if err := json.Unmarshal(raw, &data); err != nil {
|
||||
logger.Err(err).Msg("message unmarshalling has failed")
|
||||
break
|
||||
}
|
||||
|
||||
// log events if not ignored
|
||||
if ok, _ := utils.ArrayIn(data.Event, nologEvents); !ok {
|
||||
payload := data.Payload
|
||||
if len(payload) > maxPayloadLogLength {
|
||||
payload = []byte("<truncated>")
|
||||
}
|
||||
|
||||
logger.Debug().
|
||||
Str("address", connection.RemoteAddr().String()).
|
||||
Str("event", data.Event).
|
||||
Str("payload", string(payload)).
|
||||
Msg("received message from client")
|
||||
}
|
||||
|
||||
handled := manager.handler.Message(session, data)
|
||||
for _, handler := range manager.handlers {
|
||||
if handled {
|
||||
break
|
||||
}
|
||||
|
||||
handled = handler(session, data)
|
||||
}
|
||||
|
||||
if !handled {
|
||||
logger.Warn().Str("event", data.Event).Msg("unhandled message")
|
||||
}
|
||||
case err := <-cancel:
|
||||
return err
|
||||
case <-manager.shutdown:
|
||||
peer.Destroy("connection shutdown")
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
if err := peer.Ping(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (manager *WebSocketManagerCtx) startInactiveCursors() {
|
||||
if manager.shutdownInactiveCursors != nil {
|
||||
manager.logger.Warn().Msg("inactive cursors handler already running")
|
||||
return
|
||||
}
|
||||
|
||||
manager.logger.Info().Msg("starting inactive cursors handler")
|
||||
manager.shutdownInactiveCursors = make(chan struct{})
|
||||
|
||||
manager.wg.Add(1)
|
||||
go func() {
|
||||
defer manager.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(inactiveCursorsPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
var currentEmpty bool
|
||||
var lastEmpty = false
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-manager.shutdownInactiveCursors:
|
||||
manager.logger.Info().Msg("stopping inactive cursors handler")
|
||||
manager.shutdownInactiveCursors = nil
|
||||
|
||||
// remove last cursor entries and send empty message
|
||||
_ = manager.sessions.PopCursors()
|
||||
manager.sessions.InactiveCursorsBroadcast(event.SESSION_CURSORS, []message.SessionCursors{})
|
||||
return
|
||||
case <-ticker.C:
|
||||
cursorsMap := manager.sessions.PopCursors()
|
||||
|
||||
currentEmpty = len(cursorsMap) == 0
|
||||
if currentEmpty && lastEmpty {
|
||||
continue
|
||||
}
|
||||
lastEmpty = currentEmpty
|
||||
|
||||
sessionCursors := []message.SessionCursors{}
|
||||
for session, cursors := range cursorsMap {
|
||||
sessionCursors = append(
|
||||
sessionCursors,
|
||||
message.SessionCursors{
|
||||
ID: session.ID(),
|
||||
Cursors: cursors,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
manager.sessions.InactiveCursorsBroadcast(event.SESSION_CURSORS, sessionCursors)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (manager *WebSocketManagerCtx) stopInactiveCursors() {
|
||||
if manager.shutdownInactiveCursors != nil {
|
||||
close(manager.shutdownInactiveCursors)
|
||||
}
|
||||
}
|
91
server/internal/websocket/peer.go
Normal file
91
server/internal/websocket/peer.go
Normal file
@ -0,0 +1,91 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/demodesk/neko/pkg/types"
|
||||
"github.com/demodesk/neko/pkg/types/event"
|
||||
"github.com/demodesk/neko/pkg/types/message"
|
||||
"github.com/demodesk/neko/pkg/utils"
|
||||
)
|
||||
|
||||
type WebSocketPeerCtx struct {
|
||||
mu sync.Mutex
|
||||
logger zerolog.Logger
|
||||
connection *websocket.Conn
|
||||
}
|
||||
|
||||
func newPeer(logger zerolog.Logger, connection *websocket.Conn) *WebSocketPeerCtx {
|
||||
return &WebSocketPeerCtx{
|
||||
logger: logger.With().Str("submodule", "peer").Logger(),
|
||||
connection: connection,
|
||||
}
|
||||
}
|
||||
|
||||
func (peer *WebSocketPeerCtx) Send(event string, payload any) {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
peer.logger.Err(err).Str("event", event).Msg("message marshalling has failed")
|
||||
return
|
||||
}
|
||||
|
||||
err = peer.connection.WriteJSON(types.WebSocketMessage{
|
||||
Event: event,
|
||||
Payload: raw,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
err = errors.Unwrap(err) // unwrap if possible
|
||||
peer.logger.Warn().Err(err).Str("event", event).Msg("send message error")
|
||||
return
|
||||
}
|
||||
|
||||
// log events if not ignored
|
||||
if ok, _ := utils.ArrayIn(event, nologEvents); !ok {
|
||||
if len(raw) > maxPayloadLogLength {
|
||||
raw = []byte("<truncated>")
|
||||
}
|
||||
|
||||
peer.logger.Debug().
|
||||
Str("address", peer.connection.RemoteAddr().String()).
|
||||
Str("event", event).
|
||||
Str("payload", string(raw)).
|
||||
Msg("sending message to client")
|
||||
}
|
||||
}
|
||||
|
||||
func (peer *WebSocketPeerCtx) Ping() error {
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
// application level heartbeat
|
||||
if err := peer.connection.WriteJSON(types.WebSocketMessage{
|
||||
Event: event.SYSTEM_HEARTBEAT,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return peer.connection.WriteMessage(websocket.PingMessage, nil)
|
||||
}
|
||||
|
||||
func (peer *WebSocketPeerCtx) Destroy(reason string) {
|
||||
peer.Send(
|
||||
event.SYSTEM_DISCONNECT,
|
||||
message.SystemDisconnect{
|
||||
Message: reason,
|
||||
})
|
||||
|
||||
peer.mu.Lock()
|
||||
defer peer.mu.Unlock()
|
||||
|
||||
err := peer.connection.Close()
|
||||
peer.logger.Err(err).Msg("peer connection destroyed")
|
||||
}
|
Reference in New Issue
Block a user