diff --git a/internal/api/members/bluk.go b/internal/api/members/bluk.go index d317b1a0..808db44a 100644 --- a/internal/api/members/bluk.go +++ b/internal/api/members/bluk.go @@ -14,25 +14,25 @@ type MemberBulkUpdatePayload struct { Profile types.MemberProfile `json:"profile"` } -func (h *MembersHandler) membersBulkUpdate(w http.ResponseWriter, r *http.Request) { +func (h *MembersHandler) membersBulkUpdate(w http.ResponseWriter, r *http.Request) error { bytes, err := io.ReadAll(r.Body) if err != nil { - utils.HttpBadRequest(w).WithInternalErr(err).Msg("unable to read post body") - return + return utils.HttpBadRequest("unable to read post body").WithInternalErr(err) } header := &MemberBulkUpdatePayload{} if err := json.Unmarshal(bytes, &header); err != nil { - utils.HttpBadRequest(w).WithInternalErr(err).Msg("unable to unmarshal payload") - return + return utils.HttpBadRequest("unable to unmarshal payload").WithInternalErr(err) } for _, memberId := range header.IDs { // TODO: Bulk select? profile, err := h.members.Select(memberId) if err != nil { - utils.HttpInternalServerError(w, err).WithInternalMsg("unable to select member profile").Msgf("failed to update member %s", memberId) - return + return utils.HttpInternalServerError(). + WithInternalErr(err). + WithInternalMsg("unable to select member profile"). + Msgf("failed to update member %s", memberId) } body := &MemberBulkUpdatePayload{ @@ -40,15 +40,18 @@ func (h *MembersHandler) membersBulkUpdate(w http.ResponseWriter, r *http.Reques } if err := json.Unmarshal(bytes, &body); err != nil { - utils.HttpBadRequest(w).WithInternalErr(err).Msgf("unable to unmarshal payload for member %s", memberId) - return + return utils.HttpBadRequest(). + WithInternalErr(err). + Msgf("unable to unmarshal payload for member %s", memberId) } if err := h.members.UpdateProfile(memberId, body.Profile); err != nil { - utils.HttpInternalServerError(w, err).WithInternalMsg("unable to update member profile").Msgf("failed to update member %s", memberId) - return + return utils.HttpInternalServerError(). + WithInternalErr(err). + WithInternalMsg("unable to update member profile"). + Msgf("failed to update member %s", memberId) } } - utils.HttpSuccess(w) + return utils.HttpSuccess(w) } diff --git a/internal/api/members/controler.go b/internal/api/members/controler.go index 6e5990a4..d26d979d 100644 --- a/internal/api/members/controler.go +++ b/internal/api/members/controler.go @@ -24,7 +24,7 @@ type MemberPasswordPayload struct { Password string `json:"password"` } -func (h *MembersHandler) membersList(w http.ResponseWriter, r *http.Request) { +func (h *MembersHandler) membersList(w http.ResponseWriter, r *http.Request) error { limit, err := strconv.Atoi(r.URL.Query().Get("limit")) if err != nil { // TODO: Default zero. @@ -39,8 +39,7 @@ func (h *MembersHandler) membersList(w http.ResponseWriter, r *http.Request) { entries, err := h.members.SelectAll(limit, offset) if err != nil { - utils.HttpInternalServerError(w, err).Send() - return + return utils.HttpInternalServerError().WithInternalErr(err) } members := []MemberDataPayload{} @@ -51,10 +50,10 @@ func (h *MembersHandler) membersList(w http.ResponseWriter, r *http.Request) { }) } - utils.HttpSuccess(w, members) + return utils.HttpSuccess(w, members) } -func (h *MembersHandler) membersCreate(w http.ResponseWriter, r *http.Request) { +func (h *MembersHandler) membersCreate(w http.ResponseWriter, r *http.Request) error { data := &MemberCreatePayload{ // default values Profile: types.MemberProfile{ @@ -67,82 +66,76 @@ func (h *MembersHandler) membersCreate(w http.ResponseWriter, r *http.Request) { }, } - if !utils.HttpJsonRequest(w, r, data) { - return + if err := utils.HttpJsonRequest(w, r, data); err != nil { + return err } if data.Username == "" { - utils.HttpBadRequest(w).Msg("username cannot be empty") - return + return utils.HttpBadRequest("username cannot be empty") } if data.Password == "" { - utils.HttpBadRequest(w).Msg("password cannot be empty") - return + return utils.HttpBadRequest("password cannot be empty") } id, err := h.members.Insert(data.Username, data.Password, data.Profile) if err != nil { if errors.Is(err, types.ErrMemberAlreadyExists) { - utils.HttpUnprocessableEntity(w).Msg("member already exists") - } else { - utils.HttpInternalServerError(w, err).Send() + return utils.HttpUnprocessableEntity("member already exists") } - return + + return utils.HttpInternalServerError().WithInternalErr(err) } - utils.HttpSuccess(w, MemberDataPayload{ + return utils.HttpSuccess(w, MemberDataPayload{ ID: id, Profile: data.Profile, }) } -func (h *MembersHandler) membersRead(w http.ResponseWriter, r *http.Request) { +func (h *MembersHandler) membersRead(w http.ResponseWriter, r *http.Request) error { member := GetMember(r) profile := member.Profile - utils.HttpSuccess(w, profile) + return utils.HttpSuccess(w, profile) } -func (h *MembersHandler) membersUpdateProfile(w http.ResponseWriter, r *http.Request) { +func (h *MembersHandler) membersUpdateProfile(w http.ResponseWriter, r *http.Request) error { member := GetMember(r) - profile := member.Profile + data := &member.Profile - if !utils.HttpJsonRequest(w, r, &profile) { - return + if err := utils.HttpJsonRequest(w, r, data); err != nil { + return err } - if err := h.members.UpdateProfile(member.ID, profile); err != nil { - utils.HttpInternalServerError(w, err).Send() - return + if err := h.members.UpdateProfile(member.ID, *data); err != nil { + return utils.HttpInternalServerError().WithInternalErr(err) } - utils.HttpSuccess(w) + return utils.HttpSuccess(w) } -func (h *MembersHandler) membersUpdatePassword(w http.ResponseWriter, r *http.Request) { +func (h *MembersHandler) membersUpdatePassword(w http.ResponseWriter, r *http.Request) error { member := GetMember(r) - data := MemberPasswordPayload{} + data := &MemberPasswordPayload{} - if !utils.HttpJsonRequest(w, r, &data) { - return + if err := utils.HttpJsonRequest(w, r, data); err != nil { + return err } if err := h.members.UpdatePassword(member.ID, data.Password); err != nil { - utils.HttpInternalServerError(w, err).Send() - return + return utils.HttpInternalServerError().WithInternalErr(err) } - utils.HttpSuccess(w) + return utils.HttpSuccess(w) } -func (h *MembersHandler) membersDelete(w http.ResponseWriter, r *http.Request) { +func (h *MembersHandler) membersDelete(w http.ResponseWriter, r *http.Request) error { member := GetMember(r) if err := h.members.Delete(member.ID); err != nil { - utils.HttpInternalServerError(w, err).Send() - return + return utils.HttpInternalServerError().WithInternalErr(err) } - utils.HttpSuccess(w) + return utils.HttpSuccess(w) } diff --git a/internal/api/members/handler.go b/internal/api/members/handler.go index 4c25f94e..276c3b40 100644 --- a/internal/api/members/handler.go +++ b/internal/api/members/handler.go @@ -30,12 +30,12 @@ func New( } } -func (h *MembersHandler) Route(r chi.Router) { +func (h *MembersHandler) Route(r types.Router) { r.Get("/", h.membersList) - r.With(auth.AdminsOnly).Group(func(r chi.Router) { + r.With(auth.AdminsOnly).Group(func(r types.Router) { r.Post("/", h.membersCreate) - r.With(h.ExtractMember).Route("/{memberId}", func(r chi.Router) { + r.With(h.ExtractMember).Route("/{memberId}", func(r types.Router) { r.Get("/", h.membersRead) r.Post("/", h.membersUpdateProfile) r.Post("/password", h.membersUpdatePassword) @@ -44,8 +44,8 @@ func (h *MembersHandler) Route(r chi.Router) { }) } -func (h *MembersHandler) RouteBulk(r chi.Router) { - r.With(auth.AdminsOnly).Group(func(r chi.Router) { +func (h *MembersHandler) RouteBulk(r types.Router) { + r.With(auth.AdminsOnly).Group(func(r types.Router) { r.Post("/update", h.membersBulkUpdate) }) } @@ -55,33 +55,28 @@ type MemberData struct { Profile types.MemberProfile } -func SetMember(r *http.Request, session MemberData) *http.Request { - ctx := context.WithValue(r.Context(), keyMemberCtx, session) - return r.WithContext(ctx) +func SetMember(r *http.Request, session MemberData) context.Context { + return context.WithValue(r.Context(), keyMemberCtx, session) } func GetMember(r *http.Request) MemberData { return r.Context().Value(keyMemberCtx).(MemberData) } -func (h *MembersHandler) ExtractMember(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - memberId := chi.URLParam(r, "memberId") +func (h *MembersHandler) ExtractMember(w http.ResponseWriter, r *http.Request) (context.Context, error) { + memberId := chi.URLParam(r, "memberId") - profile, err := h.members.Select(memberId) - if err != nil { - if errors.Is(err, types.ErrMemberDoesNotExist) { - utils.HttpNotFound(w).Msg("member not found") - } else { - utils.HttpInternalServerError(w, err).Send() - } - - return + profile, err := h.members.Select(memberId) + if err != nil { + if errors.Is(err, types.ErrMemberDoesNotExist) { + return nil, utils.HttpNotFound("member not found") } - next.ServeHTTP(w, SetMember(r, MemberData{ - ID: memberId, - Profile: profile, - })) - }) + return nil, utils.HttpInternalServerError().WithInternalErr(err) + } + + return SetMember(r, MemberData{ + ID: memberId, + Profile: profile, + }), nil } diff --git a/internal/api/room/broadcast.go b/internal/api/room/broadcast.go index 24c91a92..c7afef9f 100644 --- a/internal/api/room/broadcast.go +++ b/internal/api/room/broadcast.go @@ -13,34 +13,32 @@ type BroadcastStatusPayload struct { IsActive bool `json:"is_active"` } -func (h *RoomHandler) broadcastStatus(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) broadcastStatus(w http.ResponseWriter, r *http.Request) error { broadcast := h.capture.Broadcast() - utils.HttpSuccess(w, BroadcastStatusPayload{ + + return utils.HttpSuccess(w, BroadcastStatusPayload{ IsActive: broadcast.Started(), URL: broadcast.Url(), }) } -func (h *RoomHandler) boradcastStart(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) boradcastStart(w http.ResponseWriter, r *http.Request) error { data := &BroadcastStatusPayload{} - if !utils.HttpJsonRequest(w, r, data) { - return + if err := utils.HttpJsonRequest(w, r, data); err != nil { + return err } if data.URL == "" { - utils.HttpBadRequest(w).Msg("missing broadcast URL") - return + return utils.HttpBadRequest("missing broadcast URL") } broadcast := h.capture.Broadcast() if broadcast.Started() { - utils.HttpUnprocessableEntity(w).Msg("server is already broadcasting") - return + return utils.HttpUnprocessableEntity("server is already broadcasting") } if err := broadcast.Start(data.URL); err != nil { - utils.HttpInternalServerError(w, err).Send() - return + return utils.HttpInternalServerError().WithInternalErr(err) } h.sessions.AdminBroadcast( @@ -50,14 +48,13 @@ func (h *RoomHandler) boradcastStart(w http.ResponseWriter, r *http.Request) { URL: broadcast.Url(), }, nil) - utils.HttpSuccess(w) + return utils.HttpSuccess(w) } -func (h *RoomHandler) boradcastStop(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) boradcastStop(w http.ResponseWriter, r *http.Request) error { broadcast := h.capture.Broadcast() if !broadcast.Started() { - utils.HttpUnprocessableEntity(w).Msg("server is not broadcasting") - return + return utils.HttpUnprocessableEntity("server is not broadcasting") } broadcast.Stop() @@ -69,5 +66,5 @@ func (h *RoomHandler) boradcastStop(w http.ResponseWriter, r *http.Request) { URL: broadcast.Url(), }, nil) - utils.HttpSuccess(w) + return utils.HttpSuccess(w) } diff --git a/internal/api/room/clipboard.go b/internal/api/room/clipboard.go index 6c6dfcaf..203282c4 100644 --- a/internal/api/room/clipboard.go +++ b/internal/api/room/clipboard.go @@ -4,6 +4,7 @@ import ( // TODO: Unused now. //"bytes" //"strings" + "net/http" "demodesk/neko/internal/types" @@ -15,23 +16,22 @@ type ClipboardPayload struct { HTML string `json:"html,omitempty"` } -func (h *RoomHandler) clipboardGetText(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) clipboardGetText(w http.ResponseWriter, r *http.Request) error { data, err := h.desktop.ClipboardGetText() if err != nil { - utils.HttpInternalServerError(w, err).Send() - return + return utils.HttpInternalServerError().WithInternalErr(err) } - utils.HttpSuccess(w, ClipboardPayload{ + return utils.HttpSuccess(w, ClipboardPayload{ Text: data.Text, HTML: data.HTML, }) } -func (h *RoomHandler) clipboardSetText(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) clipboardSetText(w http.ResponseWriter, r *http.Request) error { data := &ClipboardPayload{} - if !utils.HttpJsonRequest(w, r, data) { - return + if err := utils.HttpJsonRequest(w, r, data); err != nil { + return err } err := h.desktop.ClipboardSetText(types.ClipboardText{ @@ -40,32 +40,30 @@ func (h *RoomHandler) clipboardSetText(w http.ResponseWriter, r *http.Request) { }) if err != nil { - utils.HttpInternalServerError(w, err).Send() - return + return utils.HttpInternalServerError().WithInternalErr(err) } - utils.HttpSuccess(w) + return utils.HttpSuccess(w) } -func (h *RoomHandler) clipboardGetImage(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) clipboardGetImage(w http.ResponseWriter, r *http.Request) error { bytes, err := h.desktop.ClipboardGetBinary("image/png") if err != nil { - utils.HttpInternalServerError(w, err).Send() - return + return utils.HttpInternalServerError().WithInternalErr(err) } w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") w.Header().Set("Content-Type", "image/png") - //nolint - w.Write(bytes) + + _, err = w.Write(bytes) + return err } /* TODO: Unused now. -func (h *RoomHandler) clipboardSetImage(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) clipboardSetImage(w http.ResponseWriter, r *http.Request) error { err := r.ParseMultipartForm(MAX_UPLOAD_SIZE) if err != nil { - utils.HttpBadRequest(w).WithInternalErr(err).Msg("failed to parse multipart form") - return + return utils.HttpBadRequest("failed to parse multipart form").WithInternalErr(err) } //nolint @@ -73,41 +71,37 @@ func (h *RoomHandler) clipboardSetImage(w http.ResponseWriter, r *http.Request) file, header, err := r.FormFile("file") if err != nil { - utils.HttpBadRequest(w).WithInternalErr(err).Msg("no file received") - return + return utils.HttpBadRequest("no file received").WithInternalErr(err) } defer file.Close() mime := header.Header.Get("Content-Type") if !strings.HasPrefix(mime, "image/") { - utils.HttpBadRequest(w).Msg("file must be image") - return + return utils.HttpBadRequest("file must be image") } buffer := new(bytes.Buffer) _, err = buffer.ReadFrom(file) if err != nil { - utils.HttpInternalServerError(w, err).WithInternalMsg("unable to read from uploaded file").Send() - return + return utils.HttpInternalServerError().WithInternalErr(err).WithInternalMsg("unable to read from uploaded file") } err = h.desktop.ClipboardSetBinary("image/png", buffer.Bytes()) if err != nil { - utils.HttpInternalServerError(w, err).WithInternalMsg("unable set image to clipboard").Send() - return + return utils.HttpInternalServerError().WithInternalErr(err).WithInternalMsg("unable set image to clipboard") } - utils.HttpSuccess(w) + return utils.HttpSuccess(w) } -func (h *RoomHandler) clipboardGetTargets(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) clipboardGetTargets(w http.ResponseWriter, r *http.Request) error { targets, err := h.desktop.ClipboardGetTargets() if err != nil { - utils.HttpInternalServerError(w, err).Send() - return + return utils.HttpInternalServerError().WithInternalErr(err) } - utils.HttpSuccess(w, targets) + return utils.HttpSuccess(w, targets) } + */ diff --git a/internal/api/room/control.go b/internal/api/room/control.go index b4aeb5e1..400df38c 100644 --- a/internal/api/room/control.go +++ b/internal/api/room/control.go @@ -18,82 +18,76 @@ type ControlTargetPayload struct { ID string `json:"id"` } -func (h *RoomHandler) controlStatus(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) controlStatus(w http.ResponseWriter, r *http.Request) error { host := h.sessions.GetHost() - if host == nil { - utils.HttpSuccess(w, ControlStatusPayload{ - HasHost: false, - }) - } else { - utils.HttpSuccess(w, ControlStatusPayload{ + if host != nil { + return utils.HttpSuccess(w, ControlStatusPayload{ HasHost: true, HostId: host.ID(), }) } + + return utils.HttpSuccess(w, ControlStatusPayload{ + HasHost: false, + }) } -func (h *RoomHandler) controlRequest(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) controlRequest(w http.ResponseWriter, r *http.Request) error { host := h.sessions.GetHost() if host != nil { - utils.HttpUnprocessableEntity(w).Msg("there is already a host") - return + return utils.HttpUnprocessableEntity("there is already a host") } - session := auth.GetSession(r) + session, _ := auth.GetSession(r) h.sessions.SetHost(session) - utils.HttpSuccess(w) + return utils.HttpSuccess(w) } -func (h *RoomHandler) controlRelease(w http.ResponseWriter, r *http.Request) { - session := auth.GetSession(r) +func (h *RoomHandler) controlRelease(w http.ResponseWriter, r *http.Request) error { + session, _ := auth.GetSession(r) if !session.IsHost() { - utils.HttpUnprocessableEntity(w).Msg("session is not the host") - return + return utils.HttpUnprocessableEntity("session is not the host") } h.desktop.ResetKeys() h.sessions.ClearHost() - utils.HttpSuccess(w) + return utils.HttpSuccess(w) } -func (h *RoomHandler) controlTake(w http.ResponseWriter, r *http.Request) { - session := auth.GetSession(r) +func (h *RoomHandler) controlTake(w http.ResponseWriter, r *http.Request) error { + session, _ := auth.GetSession(r) h.sessions.SetHost(session) - utils.HttpSuccess(w) + return utils.HttpSuccess(w) } -func (h *RoomHandler) controlGive(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) controlGive(w http.ResponseWriter, r *http.Request) error { sessionId := chi.URLParam(r, "sessionId") target, ok := h.sessions.Get(sessionId) if !ok { - utils.HttpNotFound(w).Msg("target session was not found") - return + return utils.HttpNotFound("target session was not found") } if !target.Profile().CanHost { - utils.HttpBadRequest(w).Msg("target session is not allowed to host") - return + return utils.HttpBadRequest("target session is not allowed to host") } h.sessions.SetHost(target) - utils.HttpSuccess(w) + return utils.HttpSuccess(w) } -func (h *RoomHandler) controlReset(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) controlReset(w http.ResponseWriter, r *http.Request) error { host := h.sessions.GetHost() - if host == nil { - utils.HttpSuccess(w) - return + + if host != nil { + h.desktop.ResetKeys() + h.sessions.ClearHost() } - h.desktop.ResetKeys() - h.sessions.ClearHost() - - utils.HttpSuccess(w) + return utils.HttpSuccess(w) } diff --git a/internal/api/room/handler.go b/internal/api/room/handler.go index 91f0c52a..254f722f 100644 --- a/internal/api/room/handler.go +++ b/internal/api/room/handler.go @@ -1,10 +1,9 @@ package room import ( + "context" "net/http" - "github.com/go-chi/chi" - "demodesk/neko/internal/http/auth" "demodesk/neko/internal/types" "demodesk/neko/internal/utils" @@ -30,14 +29,14 @@ func New( } } -func (h *RoomHandler) Route(r chi.Router) { - r.With(auth.AdminsOnly).Route("/broadcast", func(r chi.Router) { +func (h *RoomHandler) Route(r types.Router) { + r.With(auth.AdminsOnly).Route("/broadcast", func(r types.Router) { r.Get("/", h.broadcastStatus) r.Post("/start", h.boradcastStart) r.Post("/stop", h.boradcastStop) }) - r.With(auth.CanAccessClipboardOnly).With(auth.HostsOnly).Route("/clipboard", func(r chi.Router) { + r.With(auth.CanAccessClipboardOnly).With(auth.HostsOnly).Route("/clipboard", func(r types.Router) { r.Get("/", h.clipboardGetText) r.Post("/", h.clipboardSetText) r.Get("/image.png", h.clipboardGetImage) @@ -52,7 +51,7 @@ func (h *RoomHandler) Route(r chi.Router) { //r.Get("/targets", h.clipboardGetTargets) }) - r.With(auth.CanHostOnly).Route("/keyboard", func(r chi.Router) { + r.With(auth.CanHostOnly).Route("/keyboard", func(r types.Router) { r.Get("/map", h.keyboardMapGet) r.With(auth.HostsOnly).Post("/map", h.keyboardMapSet) @@ -60,7 +59,7 @@ func (h *RoomHandler) Route(r chi.Router) { r.With(auth.HostsOnly).Post("/modifiers", h.keyboardModifiersSet) }) - r.With(auth.CanHostOnly).Route("/control", func(r chi.Router) { + r.With(auth.CanHostOnly).Route("/control", func(r types.Router) { r.Get("/", h.controlStatus) r.Post("/request", h.controlRequest) r.Post("/release", h.controlRelease) @@ -70,7 +69,7 @@ func (h *RoomHandler) Route(r chi.Router) { r.With(auth.AdminsOnly).Post("/reset", h.controlReset) }) - r.With(auth.CanWatchOnly).Route("/screen", func(r chi.Router) { + r.With(auth.CanWatchOnly).Route("/screen", func(r types.Router) { r.Get("/", h.screenConfiguration) r.With(auth.AdminsOnly).Post("/", h.screenConfigurationChange) r.With(auth.AdminsOnly).Get("/configurations", h.screenConfigurationsList) @@ -79,20 +78,18 @@ func (h *RoomHandler) Route(r chi.Router) { r.With(auth.AdminsOnly).Get("/shot.jpg", h.screenShotGet) }) - r.With(h.uploadMiddleware).Route("/upload", func(r chi.Router) { + r.With(h.uploadMiddleware).Route("/upload", func(r types.Router) { r.Post("/drop", h.uploadDrop) r.Post("/dialog", h.uploadDialogPost) r.Delete("/dialog", h.uploadDialogClose) }) } -func (h *RoomHandler) uploadMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - session := auth.GetSession(r) - if !session.IsHost() && (!session.Profile().CanHost || !h.sessions.ImplicitHosting()) { - utils.HttpForbidden(w).Msg("without implicit hosting, only host can upload files") - } else { - next.ServeHTTP(w, r) - } - }) +func (h *RoomHandler) uploadMiddleware(w http.ResponseWriter, r *http.Request) (context.Context, error) { + session, ok := auth.GetSession(r) + if !ok || (!session.IsHost() && (!session.Profile().CanHost || !h.sessions.ImplicitHosting())) { + return nil, utils.HttpForbidden("without implicit hosting, only host can upload files") + } + + return nil, nil } diff --git a/internal/api/room/keyboard.go b/internal/api/room/keyboard.go index 289faf97..91cccab0 100644 --- a/internal/api/room/keyboard.go +++ b/internal/api/room/keyboard.go @@ -17,10 +17,10 @@ type KeyboardModifiersData struct { CapsLock *bool `json:"capslock"` } -func (h *RoomHandler) keyboardMapSet(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) keyboardMapSet(w http.ResponseWriter, r *http.Request) error { data := &KeyboardMapData{} - if !utils.HttpJsonRequest(w, r, data) { - return + if err := utils.HttpJsonRequest(w, r, data); err != nil { + return err } err := h.desktop.SetKeyboardMap(types.KeyboardMap{ @@ -29,44 +29,43 @@ func (h *RoomHandler) keyboardMapSet(w http.ResponseWriter, r *http.Request) { }) if err != nil { - utils.HttpInternalServerError(w, err).Send() - return + return utils.HttpInternalServerError().WithInternalErr(err) } - utils.HttpSuccess(w) + return utils.HttpSuccess(w) } -func (h *RoomHandler) keyboardMapGet(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) keyboardMapGet(w http.ResponseWriter, r *http.Request) error { data, err := h.desktop.GetKeyboardMap() if err != nil { - utils.HttpInternalServerError(w, err).Send() - return + return utils.HttpInternalServerError().WithInternalErr(err) } - utils.HttpSuccess(w, KeyboardMapData{ + return utils.HttpSuccess(w, KeyboardMapData{ Layout: data.Layout, Variant: data.Variant, }) } -func (h *RoomHandler) keyboardModifiersSet(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) keyboardModifiersSet(w http.ResponseWriter, r *http.Request) error { data := &KeyboardModifiersData{} - if !utils.HttpJsonRequest(w, r, data) { - return + if err := utils.HttpJsonRequest(w, r, data); err != nil { + return err } h.desktop.SetKeyboardModifiers(types.KeyboardModifiers{ NumLock: data.NumLock, CapsLock: data.CapsLock, }) - utils.HttpSuccess(w) + + return utils.HttpSuccess(w) } -func (h *RoomHandler) keyboardModifiersGet(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) keyboardModifiersGet(w http.ResponseWriter, r *http.Request) error { data := h.desktop.GetKeyboardModifiers() - utils.HttpSuccess(w, KeyboardModifiersData{ + return utils.HttpSuccess(w, KeyboardModifiersData{ NumLock: data.NumLock, CapsLock: data.CapsLock, }) diff --git a/internal/api/room/screen.go b/internal/api/room/screen.go index 925702dd..67c658ef 100644 --- a/internal/api/room/screen.go +++ b/internal/api/room/screen.go @@ -16,25 +16,24 @@ type ScreenConfigurationPayload struct { Rate int16 `json:"rate"` } -func (h *RoomHandler) screenConfiguration(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) screenConfiguration(w http.ResponseWriter, r *http.Request) error { size := h.desktop.GetScreenSize() if size == nil { - utils.HttpInternalServerError(w, nil).WithInternalMsg("unable to get screen configuration").Send() - return + return utils.HttpInternalServerError().WithInternalMsg("unable to get screen configuration") } - utils.HttpSuccess(w, ScreenConfigurationPayload{ + return utils.HttpSuccess(w, ScreenConfigurationPayload{ Width: size.Width, Height: size.Height, Rate: size.Rate, }) } -func (h *RoomHandler) screenConfigurationChange(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) screenConfigurationChange(w http.ResponseWriter, r *http.Request) error { data := &ScreenConfigurationPayload{} - if !utils.HttpJsonRequest(w, r, data) { - return + if err := utils.HttpJsonRequest(w, r, data); err != nil { + return err } if err := h.desktop.SetScreenSize(types.ScreenSize{ @@ -42,8 +41,7 @@ func (h *RoomHandler) screenConfigurationChange(w http.ResponseWriter, r *http.R Height: data.Height, Rate: data.Rate, }); err != nil { - utils.HttpUnprocessableEntity(w).WithInternalErr(err).Msg("cannot set screen size") - return + return utils.HttpUnprocessableEntity("cannot set screen size").WithInternalErr(err) } h.sessions.Broadcast( @@ -54,10 +52,10 @@ func (h *RoomHandler) screenConfigurationChange(w http.ResponseWriter, r *http.R Rate: data.Rate, }, nil) - utils.HttpSuccess(w, data) + return utils.HttpSuccess(w, data) } -func (h *RoomHandler) screenConfigurationsList(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) screenConfigurationsList(w http.ResponseWriter, r *http.Request) error { list := []ScreenConfigurationPayload{} ScreenConfigurations := h.desktop.ScreenConfigurations() @@ -71,10 +69,10 @@ func (h *RoomHandler) screenConfigurationsList(w http.ResponseWriter, r *http.Re } } - utils.HttpSuccess(w, list) + return utils.HttpSuccess(w, list) } -func (h *RoomHandler) screenShotGet(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) screenShotGet(w http.ResponseWriter, r *http.Request) error { quality, err := strconv.Atoi(r.URL.Query().Get("quality")) if err != nil { quality = 90 @@ -83,31 +81,30 @@ func (h *RoomHandler) screenShotGet(w http.ResponseWriter, r *http.Request) { img := h.desktop.GetScreenshotImage() bytes, err := utils.CreateJPGImage(img, quality) if err != nil { - utils.HttpInternalServerError(w, err).Send() - return + return utils.HttpInternalServerError().WithInternalErr(err) } w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") w.Header().Set("Content-Type", "image/jpeg") - //nolint - w.Write(bytes) + + _, err = w.Write(bytes) + return err } -func (h *RoomHandler) screenCastGet(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) screenCastGet(w http.ResponseWriter, r *http.Request) error { screencast := h.capture.Screencast() if !screencast.Enabled() { - utils.HttpBadRequest(w).Msg("screencast pipeline is not enabled") - return + return utils.HttpBadRequest("screencast pipeline is not enabled") } bytes, err := screencast.Image() if err != nil { - utils.HttpInternalServerError(w, err).Send() - return + return utils.HttpInternalServerError().WithInternalErr(err) } w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") w.Header().Set("Content-Type", "image/jpeg") - //nolint - w.Write(bytes) + + _, err = w.Write(bytes) + return err } diff --git a/internal/api/room/upload.go b/internal/api/room/upload.go index c4afa22d..894bcead 100644 --- a/internal/api/room/upload.go +++ b/internal/api/room/upload.go @@ -15,11 +15,10 @@ import ( // maximum upload size of 32 MB const maxUploadSize = 32 << 20 -func (h *RoomHandler) uploadDrop(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) uploadDrop(w http.ResponseWriter, r *http.Request) error { err := r.ParseMultipartForm(maxUploadSize) if err != nil { - utils.HttpBadRequest(w).WithInternalErr(err).Msg("failed to parse multipart form") - return + return utils.HttpBadRequest("failed to parse multipart form").WithInternalErr(err) } //nolint @@ -27,26 +26,24 @@ func (h *RoomHandler) uploadDrop(w http.ResponseWriter, r *http.Request) { X, err := strconv.Atoi(r.FormValue("x")) if err != nil { - utils.HttpBadRequest(w).WithInternalErr(err).Msg("no X coordinate received") - return + return utils.HttpBadRequest("no X coordinate received").WithInternalErr(err) } Y, err := strconv.Atoi(r.FormValue("y")) if err != nil { - utils.HttpBadRequest(w).WithInternalErr(err).Msg("no Y coordinate received") - return + return utils.HttpBadRequest("no Y coordinate received").WithInternalErr(err) } req_files := r.MultipartForm.File["files"] if len(req_files) == 0 { - utils.HttpBadRequest(w).Msg("no files received") - return + return utils.HttpBadRequest("no files received") } dir, err := os.MkdirTemp("", "neko-drop-*") if err != nil { - utils.HttpInternalServerError(w, err).WithInternalMsg("unable to create temporary directory").Send() - return + return utils.HttpInternalServerError(). + WithInternalErr(err). + WithInternalMsg("unable to create temporary directory") } files := []string{} @@ -55,62 +52,65 @@ func (h *RoomHandler) uploadDrop(w http.ResponseWriter, r *http.Request) { srcFile, err := req_file.Open() if err != nil { - utils.HttpInternalServerError(w, err).WithInternalMsg("unable to open uploaded file").Send() - return + return utils.HttpInternalServerError(). + WithInternalErr(err). + WithInternalMsg("unable to open uploaded file") } defer srcFile.Close() dstFile, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { - utils.HttpInternalServerError(w, err).WithInternalMsg("unable to open destination file").Send() - return + return utils.HttpInternalServerError(). + WithInternalErr(err). + WithInternalMsg("unable to open destination file") } defer dstFile.Close() _, err = io.Copy(dstFile, srcFile) if err != nil { - utils.HttpInternalServerError(w, err).WithInternalMsg("unable to copy uploaded file to destination file").Send() - return + return utils.HttpInternalServerError(). + WithInternalErr(err). + WithInternalMsg("unable to copy uploaded file to destination file") } files = append(files, path) } if !h.desktop.DropFiles(X, Y, files) { - utils.HttpInternalServerError(w, nil).WithInternalMsg("unable to drop files").Send() - return + return utils.HttpInternalServerError(). + WithInternalMsg("unable to drop files") } - utils.HttpSuccess(w) + return utils.HttpSuccess(w) } -func (h *RoomHandler) uploadDialogPost(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) uploadDialogPost(w http.ResponseWriter, r *http.Request) error { err := r.ParseMultipartForm(maxUploadSize) if err != nil { - utils.HttpBadRequest(w).WithInternalErr(err).Msg("failed to parse multipart form") - return + return utils.HttpBadRequest("failed to parse multipart form").WithInternalErr(err) } //nolint defer r.MultipartForm.RemoveAll() if !h.desktop.IsFileChooserDialogOpened() { - utils.HttpUnprocessableEntity(w).Msg("file chooser dialog is not open") - return + return utils.HttpUnprocessableEntity("file chooser dialog is not open") } req_files := r.MultipartForm.File["files"] if len(req_files) == 0 { - utils.HttpInternalServerError(w, err).WithInternalMsg("unable to copy uploaded file to destination file").Send() - return + return utils.HttpInternalServerError(). + WithInternalErr(err). + WithInternalMsg("unable to copy uploaded file to destination file") } dir, err := os.MkdirTemp("", "neko-dialog-*") if err != nil { - utils.HttpInternalServerError(w, err).WithInternalMsg("unable to create temporary directory").Send() - return + return utils.HttpInternalServerError(). + WithInternalErr(err). + WithInternalMsg("unable to create temporary directory") } for _, req_file := range req_files { @@ -118,41 +118,45 @@ func (h *RoomHandler) uploadDialogPost(w http.ResponseWriter, r *http.Request) { srcFile, err := req_file.Open() if err != nil { - utils.HttpInternalServerError(w, err).WithInternalMsg("unable to open uploaded file").Send() - return + return utils.HttpInternalServerError(). + WithInternalErr(err). + WithInternalMsg("unable to open uploaded file") } defer srcFile.Close() dstFile, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { - utils.HttpInternalServerError(w, err).WithInternalMsg("unable to open destination file").Send() - return + return utils.HttpInternalServerError(). + WithInternalErr(err). + WithInternalMsg("unable to open destination file") } defer dstFile.Close() _, err = io.Copy(dstFile, srcFile) if err != nil { - utils.HttpInternalServerError(w, err).WithInternalMsg("unable to copy uploaded file to destination file").Send() - return + return utils.HttpInternalServerError(). + WithInternalErr(err). + WithInternalMsg("unable to copy uploaded file to destination file") } } if err := h.desktop.HandleFileChooserDialog(dir); err != nil { - utils.HttpInternalServerError(w, err).WithInternalMsg("unable to handle file chooser dialog").Send() - return + return utils.HttpInternalServerError(). + WithInternalErr(err). + WithInternalMsg("unable to handle file chooser dialog") } - utils.HttpSuccess(w) + return utils.HttpSuccess(w) } -func (h *RoomHandler) uploadDialogClose(w http.ResponseWriter, r *http.Request) { +func (h *RoomHandler) uploadDialogClose(w http.ResponseWriter, r *http.Request) error { if !h.desktop.IsFileChooserDialogOpened() { - utils.HttpUnprocessableEntity(w).Msg("file chooser dialog is not open") - return + return utils.HttpUnprocessableEntity("file chooser dialog is not open") } h.desktop.CloseFileChooserDialog() - utils.HttpSuccess(w) + + return utils.HttpSuccess(w) } diff --git a/internal/api/router.go b/internal/api/router.go index b0e49619..c64b9e68 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -1,11 +1,10 @@ package api import ( + "context" "errors" "net/http" - "github.com/go-chi/chi" - "demodesk/neko/internal/api/members" "demodesk/neko/internal/api/room" "demodesk/neko/internal/config" @@ -19,7 +18,7 @@ type ApiManagerCtx struct { members types.MemberManager desktop types.DesktopManager capture types.CaptureManager - routers map[string]func(chi.Router) + routers map[string]func(types.Router) } func New( @@ -35,15 +34,15 @@ func New( members: members, desktop: desktop, capture: capture, - routers: make(map[string]func(chi.Router)), + routers: make(map[string]func(types.Router)), } } -func (api *ApiManagerCtx) Route(r chi.Router) { +func (api *ApiManagerCtx) Route(r types.Router) { r.Post("/login", api.Login) // Authenticated area - r.Group(func(r chi.Router) { + r.Group(func(r types.Router) { r.Use(api.Authenticate) r.Post("/logout", api.Logout) @@ -61,33 +60,29 @@ func (api *ApiManagerCtx) Route(r chi.Router) { } }) - r.Get("/health", func(w http.ResponseWriter, r *http.Request) { - //nolint - w.Write([]byte("true")) + r.Get("/health", func(w http.ResponseWriter, r *http.Request) error { + _, err := w.Write([]byte("true")) + return err }) } -func (api *ApiManagerCtx) Authenticate(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - session, err := api.sessions.Authenticate(r) - if err != nil { - if api.sessions.CookieEnabled() { - api.sessions.CookieClearToken(w, r) - } - - if errors.Is(err, types.ErrSessionLoginDisabled) { - utils.HttpForbidden(w).Msg("login is disabled for this session") - } else { - utils.HttpUnauthorized(w).WithInternalErr(err).Send() - } - - return +func (api *ApiManagerCtx) Authenticate(w http.ResponseWriter, r *http.Request) (context.Context, error) { + session, err := api.sessions.Authenticate(r) + if err != nil { + if api.sessions.CookieEnabled() { + api.sessions.CookieClearToken(w, r) } - next.ServeHTTP(w, auth.SetSession(r, session)) - }) + if errors.Is(err, types.ErrSessionLoginDisabled) { + return nil, utils.HttpForbidden("login is disabled for this session") + } + + return nil, utils.HttpUnauthorized().WithInternalErr(err) + } + + return auth.SetSession(r, session), nil } -func (api *ApiManagerCtx) AddRouter(path string, router func(chi.Router)) { +func (api *ApiManagerCtx) AddRouter(path string, router func(types.Router)) { api.routers[path] = router } diff --git a/internal/api/session.go b/internal/api/session.go index 79e8c294..7e97434d 100644 --- a/internal/api/session.go +++ b/internal/api/session.go @@ -20,16 +20,15 @@ type SessionDataPayload struct { State types.SessionState `json:"state"` } -func (api *ApiManagerCtx) Login(w http.ResponseWriter, r *http.Request) { +func (api *ApiManagerCtx) Login(w http.ResponseWriter, r *http.Request) error { data := &SessionLoginPayload{} - if !utils.HttpJsonRequest(w, r, data) { - return + if err := utils.HttpJsonRequest(w, r, data); err != nil { + return err } session, token, err := api.members.Login(data.Username, data.Password) if err != nil { - utils.HttpUnauthorized(w).WithInternalErr(err).Send() - return + return utils.HttpUnauthorized().WithInternalErr(err) } sessionData := SessionDataPayload{ @@ -44,29 +43,28 @@ func (api *ApiManagerCtx) Login(w http.ResponseWriter, r *http.Request) { sessionData.Token = token } - utils.HttpSuccess(w, sessionData) + return utils.HttpSuccess(w, sessionData) } -func (api *ApiManagerCtx) Logout(w http.ResponseWriter, r *http.Request) { - session := auth.GetSession(r) +func (api *ApiManagerCtx) Logout(w http.ResponseWriter, r *http.Request) error { + session, _ := auth.GetSession(r) err := api.members.Logout(session.ID()) if err != nil { - utils.HttpUnauthorized(w).WithInternalErr(err).Send() - return + return utils.HttpUnauthorized().WithInternalErr(err) } if api.sessions.CookieEnabled() { api.sessions.CookieClearToken(w, r) } - utils.HttpSuccess(w, true) + return utils.HttpSuccess(w, true) } -func (api *ApiManagerCtx) Whoami(w http.ResponseWriter, r *http.Request) { - session := auth.GetSession(r) +func (api *ApiManagerCtx) Whoami(w http.ResponseWriter, r *http.Request) error { + session, _ := auth.GetSession(r) - utils.HttpSuccess(w, SessionDataPayload{ + return utils.HttpSuccess(w, SessionDataPayload{ ID: session.ID(), Profile: session.Profile(), State: session.State(), diff --git a/internal/http/auth/auth.go b/internal/http/auth/auth.go index 9d25ad90..2b350e63 100644 --- a/internal/http/auth/auth.go +++ b/internal/http/auth/auth.go @@ -12,66 +12,56 @@ type key int const keySessionCtx key = iota -func SetSession(r *http.Request, session types.Session) *http.Request { - ctx := context.WithValue(r.Context(), keySessionCtx, session) - return r.WithContext(ctx) +func SetSession(r *http.Request, session types.Session) context.Context { + return context.WithValue(r.Context(), keySessionCtx, session) } -func GetSession(r *http.Request) types.Session { - return r.Context().Value(keySessionCtx).(types.Session) +func GetSession(r *http.Request) (types.Session, bool) { + session, ok := r.Context().Value(keySessionCtx).(types.Session) + return session, ok } -func AdminsOnly(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - session := GetSession(r) - if !session.Profile().IsAdmin { - utils.HttpForbidden(w).Msg("session is not admin") - } else { - next.ServeHTTP(w, r) - } - }) +func AdminsOnly(w http.ResponseWriter, r *http.Request) (context.Context, error) { + session, ok := GetSession(r) + if !ok || !session.Profile().IsAdmin { + return nil, utils.HttpForbidden("session is not admin") + } + + return nil, nil } -func HostsOnly(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - session := GetSession(r) - if !session.IsHost() { - utils.HttpForbidden(w).Msg("session is not host") - } else { - next.ServeHTTP(w, r) - } - }) +func HostsOnly(w http.ResponseWriter, r *http.Request) (context.Context, error) { + session, ok := GetSession(r) + if !ok || !session.IsHost() { + return nil, utils.HttpForbidden("session is not host") + } + + return nil, nil } -func CanWatchOnly(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - session := GetSession(r) - if !session.Profile().CanWatch { - utils.HttpForbidden(w).Msg("session cannot watch") - } else { - next.ServeHTTP(w, r) - } - }) +func CanWatchOnly(w http.ResponseWriter, r *http.Request) (context.Context, error) { + session, ok := GetSession(r) + if !ok || !session.Profile().CanWatch { + return nil, utils.HttpForbidden("session cannot watch") + } + + return nil, nil } -func CanHostOnly(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - session := GetSession(r) - if !session.Profile().CanHost { - utils.HttpForbidden(w).Msg("session cannot host") - } else { - next.ServeHTTP(w, r) - } - }) +func CanHostOnly(w http.ResponseWriter, r *http.Request) (context.Context, error) { + session, ok := GetSession(r) + if !ok || !session.Profile().CanHost { + return nil, utils.HttpForbidden("session cannot host") + } + + return nil, nil } -func CanAccessClipboardOnly(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - session := GetSession(r) - if !session.Profile().CanAccessClipboard { - utils.HttpForbidden(w).Msg("session cannot access clipboard") - } else { - next.ServeHTTP(w, r) - } - }) +func CanAccessClipboardOnly(w http.ResponseWriter, r *http.Request) (context.Context, error) { + session, ok := GetSession(r) + if !ok || !session.Profile().CanAccessClipboard { + return nil, utils.HttpForbidden("session cannot access clipboard") + } + + return nil, nil } diff --git a/internal/types/api.go b/internal/types/api.go index 16571b91..d9625bb3 100644 --- a/internal/types/api.go +++ b/internal/types/api.go @@ -1,10 +1,6 @@ package types -import ( - "github.com/go-chi/chi" -) - type ApiManager interface { - Route(r chi.Router) - AddRouter(path string, router func(chi.Router)) + Route(r Router) + AddRouter(path string, router func(Router)) } diff --git a/internal/types/websocket.go b/internal/types/websocket.go index 885a0c41..e8c22d6f 100644 --- a/internal/types/websocket.go +++ b/internal/types/websocket.go @@ -23,5 +23,5 @@ type WebSocketManager interface { Start() Shutdown() error AddHandler(handler WebSocketHandler) - Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin CheckOrigin) + Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin CheckOrigin) error }