From a18482b54e10a3c434a4263e3324ac3ee287edc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= Date: Sat, 14 Nov 2020 17:51:18 +0100 Subject: [PATCH] use websocket authentication & refactor. --- internal/api/member/handler.go | 6 +-- internal/api/room/clipboard.go | 24 ++-------- internal/api/room/handler.go | 17 +++---- internal/api/room/screen.go | 33 ++++--------- internal/api/router.go | 43 ++++++++++------- internal/api/utils/auth.go | 65 ------------------------- internal/api/utils/error.go | 88 ---------------------------------- internal/config/server.go | 14 ------ internal/utils/http.go | 76 +++++++++++++++++++++++++++++ 9 files changed, 124 insertions(+), 242 deletions(-) delete mode 100644 internal/api/utils/auth.go delete mode 100644 internal/api/utils/error.go create mode 100644 internal/utils/http.go diff --git a/internal/api/member/handler.go b/internal/api/member/handler.go index 82383fed..0f921a85 100644 --- a/internal/api/member/handler.go +++ b/internal/api/member/handler.go @@ -3,7 +3,6 @@ package member import ( "github.com/go-chi/chi" - "demodesk/neko/internal/api/utils" "demodesk/neko/internal/types" ) @@ -21,10 +20,7 @@ func New( } } -func (h *MemberHandler) Router( - usersOnly utils.HttpMiddleware, - adminsOnly utils.HttpMiddleware, -) *chi.Mux { +func (h *MemberHandler) Router() *chi.Mux { r := chi.NewRouter() return r diff --git a/internal/api/room/clipboard.go b/internal/api/room/clipboard.go index cef06f33..37259fe1 100644 --- a/internal/api/room/clipboard.go +++ b/internal/api/room/clipboard.go @@ -3,45 +3,29 @@ package room import ( "net/http" - "github.com/go-chi/render" - - "demodesk/neko/internal/api/utils" + "demodesk/neko/internal/utils" ) type ClipboardData struct { Text string `json:"text"` } -func (a *ClipboardData) Bind(r *http.Request) error { - // Bind will run after the unmarshalling is complete, its a - // good time to focus some post-processing after a decoding. - return nil -} - -func (a *ClipboardData) Render(w http.ResponseWriter, r *http.Request) error { - // Pre-processing before a response is marshalled and sent - // across the wire - return nil -} - func (h *RoomHandler) ClipboardRead(w http.ResponseWriter, r *http.Request) { // TODO: error check? text := h.desktop.ReadClipboard() - render.JSON(w, r, ClipboardData{ + utils.HttpSuccess(w, ClipboardData{ Text: text, }) } func (h *RoomHandler) ClipboardWrite(w http.ResponseWriter, r *http.Request) { data := &ClipboardData{} - if err := render.Bind(r, data); err != nil { - _ = render.Render(w, r, utils.ErrBadRequest(err)) + if !utils.HttpJsonRequest(w, r, data) { return } // TODO: error check? h.desktop.WriteClipboard(data.Text) - - w.WriteHeader(http.StatusNoContent) + utils.HttpSuccess(w) } diff --git a/internal/api/room/handler.go b/internal/api/room/handler.go index 7be10bcc..159755cb 100644 --- a/internal/api/room/handler.go +++ b/internal/api/room/handler.go @@ -3,7 +3,6 @@ package room import ( "github.com/go-chi/chi" - "demodesk/neko/internal/api/utils" "demodesk/neko/internal/types" ) @@ -27,20 +26,18 @@ func New( } } -func (h *RoomHandler) Router( - usersOnly utils.HttpMiddleware, - adminsOnly utils.HttpMiddleware, -) *chi.Mux { +func (h *RoomHandler) Router() *chi.Mux { r := chi.NewRouter() - + + // TODO: Authorizaton. r.Route("/screen", func(r chi.Router) { - r.With(usersOnly).Get("/", h.ScreenConfiguration) - r.With(adminsOnly).Post("/", h.ScreenConfigurationChange) + r.Get("/", h.ScreenConfiguration) + r.Post("/", h.ScreenConfigurationChange) - r.With(adminsOnly).Get("/configurations", h.ScreenConfigurationsList) + r.Get("/configurations", h.ScreenConfigurationsList) }) - r.With(adminsOnly).Route("/clipboard", func(r chi.Router) { + r.Route("/clipboard", func(r chi.Router) { r.Get("/", h.ClipboardRead) r.Post("/", h.ClipboardWrite) }) diff --git a/internal/api/room/screen.go b/internal/api/room/screen.go index 0e8d6c9b..b22808bf 100644 --- a/internal/api/room/screen.go +++ b/internal/api/room/screen.go @@ -3,9 +3,7 @@ package room import ( "net/http" - "github.com/go-chi/render" - - "demodesk/neko/internal/api/utils" + "demodesk/neko/internal/utils" ) type ScreenConfiguration struct { @@ -14,27 +12,15 @@ type ScreenConfiguration struct { Rate int `json:"rate"` } -func (a *ScreenConfiguration) Bind(r *http.Request) error { - // Bind will run after the unmarshalling is complete, its a - // good time to focus some post-processing after a decoding. - return nil -} - -func (a *ScreenConfiguration) Render(w http.ResponseWriter, r *http.Request) error { - // Pre-processing before a response is marshalled and sent - // across the wire - return nil -} - func (h *RoomHandler) ScreenConfiguration(w http.ResponseWriter, r *http.Request) { size := h.desktop.GetScreenSize() if size == nil { - _ = render.Render(w, r, utils.ErrMessage(500, "Unable to get screen configuration.")) + utils.HttpInternalServer(w, "Unable to get screen configuration.") return } - render.JSON(w, r, ScreenConfiguration{ + utils.HttpSuccess(w, ScreenConfiguration{ Width: size.Width, Height: size.Height, Rate: int(size.Rate), @@ -43,28 +29,27 @@ func (h *RoomHandler) ScreenConfiguration(w http.ResponseWriter, r *http.Request func (h *RoomHandler) ScreenConfigurationChange(w http.ResponseWriter, r *http.Request) { data := &ScreenConfiguration{} - if err := render.Bind(r, data); err != nil { - _ = render.Render(w, r, utils.ErrBadRequest(err)) + if !utils.HttpJsonRequest(w, r, data) { return } if err := h.desktop.ChangeScreenSize(data.Width, data.Height, data.Rate); err != nil { - _ = render.Render(w, r, utils.ErrUnprocessableEntity(err)) + utils.HttpUnprocessableEntity(w, err) return } // TODO: Broadcast change to all sessions. - render.JSON(w, r, data) + utils.HttpSuccess(w, data) } func (h *RoomHandler) ScreenConfigurationsList(w http.ResponseWriter, r *http.Request) { - list := []render.Renderer{} + list := []ScreenConfiguration{} ScreenConfigurations := h.desktop.ScreenConfigurations() for _, size := range ScreenConfigurations { for _, fps := range size.Rates { - list = append(list, &ScreenConfiguration{ + list = append(list, ScreenConfiguration{ Width: size.Width, Height: size.Height, Rate: int(fps), @@ -72,5 +57,5 @@ func (h *RoomHandler) ScreenConfigurationsList(w http.ResponseWriter, r *http.Re } } - _ = render.RenderList(w, r, list) + utils.HttpSuccess(w, list) } diff --git a/internal/api/router.go b/internal/api/router.go index 247144b9..735dd64a 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -1,6 +1,7 @@ package api import ( + "context" "net/http" "github.com/go-chi/chi" @@ -8,8 +9,8 @@ import ( "demodesk/neko/internal/api/member" "demodesk/neko/internal/api/room" "demodesk/neko/internal/types" + "demodesk/neko/internal/utils" "demodesk/neko/internal/config" - "demodesk/neko/internal/api/utils" ) type ApiManagerCtx struct { @@ -18,8 +19,9 @@ type ApiManagerCtx struct { capture types.CaptureManager } -var AdminToken []byte -var UserToken []byte +const ( + keySessionCtx int = iota +) func New( sessions types.SessionManager, @@ -27,8 +29,6 @@ func New( capture types.CaptureManager, conf *config.Server, ) *ApiManagerCtx { - AdminToken = []byte(conf.AdminToken) - UserToken = []byte(conf.UserToken) return &ApiManagerCtx{ sessions: sessions, @@ -37,18 +37,29 @@ func New( } } -func (a *ApiManagerCtx) Mount(r *chi.Mux) { - memberHandler := member.New(a.sessions) - r.Mount("/member", memberHandler.Router(UsersOnly, AdminsOnly)) +func (api *ApiManagerCtx) Mount(r *chi.Mux) { + r.Use(api.Authenticate) - roomHandler := room.New(a.sessions, a.desktop, a.capture) - r.Mount("/room", roomHandler.Router(UsersOnly, AdminsOnly)) + memberHandler := member.New(api.sessions) + r.Mount("/member", memberHandler.Router()) + + roomHandler := room.New(api.sessions, api.desktop, api.capture) + r.Mount("/room", roomHandler.Router()) + + r.Get("/test", func(w http.ResponseWriter, r *http.Request) { + session, _ := r.Context().Value(keySessionCtx).(types.Session) + utils.HttpBadRequest(w, "Hi `" + session.ID() + "`, you are authenticated.") + }) } -func UsersOnly(next http.Handler) http.Handler { - return utils.AuthMiddleware(next, UserToken, AdminToken) -} - -func AdminsOnly(next http.Handler) http.Handler { - return utils.AuthMiddleware(next, AdminToken) +func (api *ApiManagerCtx) Authenticate(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + session, err := api.sessions.Authenticate(r) + if err != nil { + utils.HttpNotAuthenticated(w, err) + } else { + ctx := context.WithValue(r.Context(), keySessionCtx, session) + next.ServeHTTP(w, r.WithContext(ctx)) + } + }) } diff --git a/internal/api/utils/auth.go b/internal/api/utils/auth.go deleted file mode 100644 index 8d25bbfe..00000000 --- a/internal/api/utils/auth.go +++ /dev/null @@ -1,65 +0,0 @@ -package utils - -import ( - "context" - "fmt" - "net/http" - "strings" - - "github.com/go-chi/render" - "github.com/dgrijalva/jwt-go" -) - -type key int - -const ( - keyPrincipalID key = iota -) - -func GetUserName(r *http.Request) interface{} { - props, _ := r.Context().Value(keyPrincipalID).(jwt.MapClaims) - return props["user_name"] -} - -type HttpMiddleware = func(next http.Handler) http.Handler - -func AuthMiddleware(next http.Handler, jwtSecrets ...[]byte) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - authHeader := strings.Split(r.Header.Get("Authorization"), "Bearer ") - if len(authHeader) != 2 { - _ = render.Render(w, r, ErrMessage(401, "Malformed JWT token.")) - return - } - - jwtToken := authHeader[1] - var jwtVerified *jwt.Token - var err error - for _, jwtSecret := range jwtSecrets { - jwtVerified, err = jwt.Parse(jwtToken, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) - } - - return jwtSecret, nil - }) - - if err == nil { - break - } - } - - if err != nil { - _ = render.Render(w, r, ErrMessage(401, "Invalid JWT token.")) - return - } - - if claims, ok := jwtVerified.Claims.(jwt.MapClaims); ok && jwtVerified.Valid { - ctx := context.WithValue(r.Context(), keyPrincipalID, claims) - // Access context values in handlers like this - // props, _ := r.Context().Value("props").(jwt.MapClaims) - next.ServeHTTP(w, r.WithContext(ctx)) - } else { - _ = render.Render(w, r, ErrMessage(401, "Unauthorized.")) - } - }) -} diff --git a/internal/api/utils/error.go b/internal/api/utils/error.go deleted file mode 100644 index 05ee8ed3..00000000 --- a/internal/api/utils/error.go +++ /dev/null @@ -1,88 +0,0 @@ -package utils - -import ( - "net/http" - - "github.com/go-chi/render" -) - -//-- -// Error response payloads & renderers -//-- - -// ErrResponse renderer type for handling all sorts of errors. -// -// In the best case scenario, the excellent github.com/pkg/errors package -// helps reveal information on the error, setting it on Err, and in the Render() -// method, using it to set the application-specific error code in AppCode. -type ErrResponse struct { - Err error `json:"-"` // low-level runtime error - HTTPStatusCode int `json:"-"` // http response status code - - StatusText string `json:"status"` // user-level status message - AppCode int64 `json:"code,omitempty"` // application-specific error code - ErrorText string `json:"error,omitempty"` // application-level error message, for debugging -} - -func (e *ErrResponse) Render(w http.ResponseWriter, r *http.Request) error { - render.Status(r, e.HTTPStatusCode) - return nil -} - -func ErrMessage(HTTPStatusCode int, StatusText string) render.Renderer { - return &ErrResponse{ - HTTPStatusCode: HTTPStatusCode, - StatusText: StatusText, - } -} - -func ErrBadRequest(err error) render.Renderer { - return &ErrResponse{ - Err: err, - HTTPStatusCode: 400, - StatusText: "Bad request.", - ErrorText: err.Error(), - } -} - -func ErrUnprocessableEntity(err error) render.Renderer { - return &ErrResponse{ - Err: err, - HTTPStatusCode: 400, - StatusText: "Unprocessable Entity.", - ErrorText: err.Error(), - } -} - -func ErrInternalServer(err error) render.Renderer { - return &ErrResponse{ - Err: err, - HTTPStatusCode: 500, - StatusText: "Internal server error.", - ErrorText: err.Error(), - } -} - -func ErrNot(err error) render.Renderer { - return &ErrResponse{ - Err: err, - HTTPStatusCode: 422, - StatusText: "Error rendering response.", - ErrorText: err.Error(), - } -} - -var ErrNotAuthenticated = &ErrResponse{ - HTTPStatusCode: 401, - StatusText: "Invalid or missing access token.", -} - -var ErrNotAuthorized = &ErrResponse{ - HTTPStatusCode: 403, - StatusText: "Access token does not have the required scope.", -} - -var ErrNotFound = &ErrResponse{ - HTTPStatusCode: 404, - StatusText: "Resource not found.", -} diff --git a/internal/config/server.go b/internal/config/server.go index eab02938..0153dce4 100644 --- a/internal/config/server.go +++ b/internal/config/server.go @@ -11,8 +11,6 @@ type Server struct { Bind string Static string //Proxy bool - UserToken string - AdminToken string } func (Server) Init(cmd *cobra.Command) error { @@ -41,16 +39,6 @@ func (Server) Init(cmd *cobra.Command) error { // return err //} - cmd.PersistentFlags().String("user_token", "user_secret", "JWT token for users") - if err := viper.BindPFlag("user_token", cmd.PersistentFlags().Lookup("user_token")); err != nil { - return err - } - - cmd.PersistentFlags().String("admin_token", "admin_secret", "JWT token for admins") - if err := viper.BindPFlag("admin_token", cmd.PersistentFlags().Lookup("admin_token")); err != nil { - return err - } - return nil } @@ -60,6 +48,4 @@ func (s *Server) Set() { s.Bind = viper.GetString("bind") s.Static = viper.GetString("static") //s.Proxy = viper.GetBool("proxy") - s.UserToken = viper.GetString("user_token") - s.AdminToken = viper.GetString("admin_token") } diff --git a/internal/utils/http.go b/internal/utils/http.go new file mode 100644 index 00000000..0ec6d255 --- /dev/null +++ b/internal/utils/http.go @@ -0,0 +1,76 @@ +package utils + +import ( + "fmt" + "net/http" + "encoding/json" +) + +type ErrResponse struct { + Message string `json:"message"` +} + +func HttpJsonRequest(w http.ResponseWriter, r *http.Request, res interface{}) bool { + if err := json.NewDecoder(r.Body).Decode(res); err != nil { + HttpBadRequest(w, err) + return false + } + + return true +} + +func HttpJsonResponse(w http.ResponseWriter, status int, res interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + + if err := json.NewEncoder(w).Encode(res); err != nil { + // TODO: Log. + //log.Warn().Err(err).Msg("failed writing json error response") + } +} + +func HttpError(w http.ResponseWriter, status int, res interface{}) { + HttpJsonResponse(w, status, &ErrResponse{ + Message: fmt.Sprint(res), + }) +} + +func HttpSuccess(w http.ResponseWriter, res ...interface{}) { + if len(res) == 0 { + w.WriteHeader(http.StatusNoContent) + } else { + HttpJsonResponse(w, http.StatusOK, res[0]) + } +} + +func HttpBadRequest(w http.ResponseWriter, res ...interface{}) { + defHttpError(w, http.StatusBadRequest, "Bad Request.", res...) +} + +func HttpNotAuthenticated(w http.ResponseWriter, res ...interface{}) { + defHttpError(w, http.StatusUnauthorized, "Invalid or missing access token.", res...) +} + +func HttpNotAuthorized(w http.ResponseWriter, res ...interface{}) { + defHttpError(w, http.StatusForbidden, "Access token does not have the required scope.", res...) +} + +func HttpNotFound(w http.ResponseWriter, res ...interface{}) { + defHttpError(w, http.StatusNotFound, "Resource not found.", res...) +} + +func HttpUnprocessableEntity(w http.ResponseWriter, res ...interface{}) { + defHttpError(w, http.StatusUnprocessableEntity, "Unprocessable Entity.", res...) +} + +func HttpInternalServer(w http.ResponseWriter, res ...interface{}) { + defHttpError(w, http.StatusInternalServerError, "Internal server error.", res...) +} + +func defHttpError(w http.ResponseWriter, status int, text string, res ...interface{}) { + if len(res) == 0 { + HttpError(w, status, text) + } else { + HttpError(w, status, res[0]) + } +}