use websocket authentication & refactor.

This commit is contained in:
Miroslav Šedivý 2020-11-14 17:51:18 +01:00
parent f136a31b03
commit a18482b54e
9 changed files with 124 additions and 242 deletions

View File

@ -3,7 +3,6 @@ package member
import ( import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"demodesk/neko/internal/api/utils"
"demodesk/neko/internal/types" "demodesk/neko/internal/types"
) )
@ -21,10 +20,7 @@ func New(
} }
} }
func (h *MemberHandler) Router( func (h *MemberHandler) Router() *chi.Mux {
usersOnly utils.HttpMiddleware,
adminsOnly utils.HttpMiddleware,
) *chi.Mux {
r := chi.NewRouter() r := chi.NewRouter()
return r return r

View File

@ -3,45 +3,29 @@ package room
import ( import (
"net/http" "net/http"
"github.com/go-chi/render" "demodesk/neko/internal/utils"
"demodesk/neko/internal/api/utils"
) )
type ClipboardData struct { type ClipboardData struct {
Text string `json:"text"` Text string `json:"text"`
} }
func (a *ClipboardData) Bind(r *http.Request) error {
// Bind will run after the unmarshalling is complete, its a
// good time to focus some post-processing after a decoding.
return nil
}
func (a *ClipboardData) Render(w http.ResponseWriter, r *http.Request) error {
// Pre-processing before a response is marshalled and sent
// across the wire
return nil
}
func (h *RoomHandler) ClipboardRead(w http.ResponseWriter, r *http.Request) { func (h *RoomHandler) ClipboardRead(w http.ResponseWriter, r *http.Request) {
// TODO: error check? // TODO: error check?
text := h.desktop.ReadClipboard() text := h.desktop.ReadClipboard()
render.JSON(w, r, ClipboardData{ utils.HttpSuccess(w, ClipboardData{
Text: text, Text: text,
}) })
} }
func (h *RoomHandler) ClipboardWrite(w http.ResponseWriter, r *http.Request) { func (h *RoomHandler) ClipboardWrite(w http.ResponseWriter, r *http.Request) {
data := &ClipboardData{} data := &ClipboardData{}
if err := render.Bind(r, data); err != nil { if !utils.HttpJsonRequest(w, r, data) {
_ = render.Render(w, r, utils.ErrBadRequest(err))
return return
} }
// TODO: error check? // TODO: error check?
h.desktop.WriteClipboard(data.Text) h.desktop.WriteClipboard(data.Text)
utils.HttpSuccess(w)
w.WriteHeader(http.StatusNoContent)
} }

View File

@ -3,7 +3,6 @@ package room
import ( import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"demodesk/neko/internal/api/utils"
"demodesk/neko/internal/types" "demodesk/neko/internal/types"
) )
@ -27,20 +26,18 @@ func New(
} }
} }
func (h *RoomHandler) Router( func (h *RoomHandler) Router() *chi.Mux {
usersOnly utils.HttpMiddleware,
adminsOnly utils.HttpMiddleware,
) *chi.Mux {
r := chi.NewRouter() r := chi.NewRouter()
// TODO: Authorizaton.
r.Route("/screen", func(r chi.Router) { r.Route("/screen", func(r chi.Router) {
r.With(usersOnly).Get("/", h.ScreenConfiguration) r.Get("/", h.ScreenConfiguration)
r.With(adminsOnly).Post("/", h.ScreenConfigurationChange) r.Post("/", h.ScreenConfigurationChange)
r.With(adminsOnly).Get("/configurations", h.ScreenConfigurationsList) r.Get("/configurations", h.ScreenConfigurationsList)
}) })
r.With(adminsOnly).Route("/clipboard", func(r chi.Router) { r.Route("/clipboard", func(r chi.Router) {
r.Get("/", h.ClipboardRead) r.Get("/", h.ClipboardRead)
r.Post("/", h.ClipboardWrite) r.Post("/", h.ClipboardWrite)
}) })

View File

@ -3,9 +3,7 @@ package room
import ( import (
"net/http" "net/http"
"github.com/go-chi/render" "demodesk/neko/internal/utils"
"demodesk/neko/internal/api/utils"
) )
type ScreenConfiguration struct { type ScreenConfiguration struct {
@ -14,27 +12,15 @@ type ScreenConfiguration struct {
Rate int `json:"rate"` Rate int `json:"rate"`
} }
func (a *ScreenConfiguration) Bind(r *http.Request) error {
// Bind will run after the unmarshalling is complete, its a
// good time to focus some post-processing after a decoding.
return nil
}
func (a *ScreenConfiguration) Render(w http.ResponseWriter, r *http.Request) error {
// Pre-processing before a response is marshalled and sent
// across the wire
return nil
}
func (h *RoomHandler) ScreenConfiguration(w http.ResponseWriter, r *http.Request) { func (h *RoomHandler) ScreenConfiguration(w http.ResponseWriter, r *http.Request) {
size := h.desktop.GetScreenSize() size := h.desktop.GetScreenSize()
if size == nil { if size == nil {
_ = render.Render(w, r, utils.ErrMessage(500, "Unable to get screen configuration.")) utils.HttpInternalServer(w, "Unable to get screen configuration.")
return return
} }
render.JSON(w, r, ScreenConfiguration{ utils.HttpSuccess(w, ScreenConfiguration{
Width: size.Width, Width: size.Width,
Height: size.Height, Height: size.Height,
Rate: int(size.Rate), Rate: int(size.Rate),
@ -43,28 +29,27 @@ func (h *RoomHandler) ScreenConfiguration(w http.ResponseWriter, r *http.Request
func (h *RoomHandler) ScreenConfigurationChange(w http.ResponseWriter, r *http.Request) { func (h *RoomHandler) ScreenConfigurationChange(w http.ResponseWriter, r *http.Request) {
data := &ScreenConfiguration{} data := &ScreenConfiguration{}
if err := render.Bind(r, data); err != nil { if !utils.HttpJsonRequest(w, r, data) {
_ = render.Render(w, r, utils.ErrBadRequest(err))
return return
} }
if err := h.desktop.ChangeScreenSize(data.Width, data.Height, data.Rate); err != nil { if err := h.desktop.ChangeScreenSize(data.Width, data.Height, data.Rate); err != nil {
_ = render.Render(w, r, utils.ErrUnprocessableEntity(err)) utils.HttpUnprocessableEntity(w, err)
return return
} }
// TODO: Broadcast change to all sessions. // TODO: Broadcast change to all sessions.
render.JSON(w, r, data) utils.HttpSuccess(w, data)
} }
func (h *RoomHandler) ScreenConfigurationsList(w http.ResponseWriter, r *http.Request) { func (h *RoomHandler) ScreenConfigurationsList(w http.ResponseWriter, r *http.Request) {
list := []render.Renderer{} list := []ScreenConfiguration{}
ScreenConfigurations := h.desktop.ScreenConfigurations() ScreenConfigurations := h.desktop.ScreenConfigurations()
for _, size := range ScreenConfigurations { for _, size := range ScreenConfigurations {
for _, fps := range size.Rates { for _, fps := range size.Rates {
list = append(list, &ScreenConfiguration{ list = append(list, ScreenConfiguration{
Width: size.Width, Width: size.Width,
Height: size.Height, Height: size.Height,
Rate: int(fps), Rate: int(fps),
@ -72,5 +57,5 @@ func (h *RoomHandler) ScreenConfigurationsList(w http.ResponseWriter, r *http.Re
} }
} }
_ = render.RenderList(w, r, list) utils.HttpSuccess(w, list)
} }

View File

@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"net/http" "net/http"
"github.com/go-chi/chi" "github.com/go-chi/chi"
@ -8,8 +9,8 @@ import (
"demodesk/neko/internal/api/member" "demodesk/neko/internal/api/member"
"demodesk/neko/internal/api/room" "demodesk/neko/internal/api/room"
"demodesk/neko/internal/types" "demodesk/neko/internal/types"
"demodesk/neko/internal/utils"
"demodesk/neko/internal/config" "demodesk/neko/internal/config"
"demodesk/neko/internal/api/utils"
) )
type ApiManagerCtx struct { type ApiManagerCtx struct {
@ -18,8 +19,9 @@ type ApiManagerCtx struct {
capture types.CaptureManager capture types.CaptureManager
} }
var AdminToken []byte const (
var UserToken []byte keySessionCtx int = iota
)
func New( func New(
sessions types.SessionManager, sessions types.SessionManager,
@ -27,8 +29,6 @@ func New(
capture types.CaptureManager, capture types.CaptureManager,
conf *config.Server, conf *config.Server,
) *ApiManagerCtx { ) *ApiManagerCtx {
AdminToken = []byte(conf.AdminToken)
UserToken = []byte(conf.UserToken)
return &ApiManagerCtx{ return &ApiManagerCtx{
sessions: sessions, sessions: sessions,
@ -37,18 +37,29 @@ func New(
} }
} }
func (a *ApiManagerCtx) Mount(r *chi.Mux) { func (api *ApiManagerCtx) Mount(r *chi.Mux) {
memberHandler := member.New(a.sessions) r.Use(api.Authenticate)
r.Mount("/member", memberHandler.Router(UsersOnly, AdminsOnly))
roomHandler := room.New(a.sessions, a.desktop, a.capture) memberHandler := member.New(api.sessions)
r.Mount("/room", roomHandler.Router(UsersOnly, AdminsOnly)) r.Mount("/member", memberHandler.Router())
roomHandler := room.New(api.sessions, api.desktop, api.capture)
r.Mount("/room", roomHandler.Router())
r.Get("/test", func(w http.ResponseWriter, r *http.Request) {
session, _ := r.Context().Value(keySessionCtx).(types.Session)
utils.HttpBadRequest(w, "Hi `" + session.ID() + "`, you are authenticated.")
})
} }
func UsersOnly(next http.Handler) http.Handler { func (api *ApiManagerCtx) Authenticate(next http.Handler) http.Handler {
return utils.AuthMiddleware(next, UserToken, AdminToken) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session, err := api.sessions.Authenticate(r)
if err != nil {
utils.HttpNotAuthenticated(w, err)
} else {
ctx := context.WithValue(r.Context(), keySessionCtx, session)
next.ServeHTTP(w, r.WithContext(ctx))
} }
})
func AdminsOnly(next http.Handler) http.Handler {
return utils.AuthMiddleware(next, AdminToken)
} }

View File

@ -1,65 +0,0 @@
package utils
import (
"context"
"fmt"
"net/http"
"strings"
"github.com/go-chi/render"
"github.com/dgrijalva/jwt-go"
)
type key int
const (
keyPrincipalID key = iota
)
func GetUserName(r *http.Request) interface{} {
props, _ := r.Context().Value(keyPrincipalID).(jwt.MapClaims)
return props["user_name"]
}
type HttpMiddleware = func(next http.Handler) http.Handler
func AuthMiddleware(next http.Handler, jwtSecrets ...[]byte) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := strings.Split(r.Header.Get("Authorization"), "Bearer ")
if len(authHeader) != 2 {
_ = render.Render(w, r, ErrMessage(401, "Malformed JWT token."))
return
}
jwtToken := authHeader[1]
var jwtVerified *jwt.Token
var err error
for _, jwtSecret := range jwtSecrets {
jwtVerified, err = jwt.Parse(jwtToken, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
}
return jwtSecret, nil
})
if err == nil {
break
}
}
if err != nil {
_ = render.Render(w, r, ErrMessage(401, "Invalid JWT token."))
return
}
if claims, ok := jwtVerified.Claims.(jwt.MapClaims); ok && jwtVerified.Valid {
ctx := context.WithValue(r.Context(), keyPrincipalID, claims)
// Access context values in handlers like this
// props, _ := r.Context().Value("props").(jwt.MapClaims)
next.ServeHTTP(w, r.WithContext(ctx))
} else {
_ = render.Render(w, r, ErrMessage(401, "Unauthorized."))
}
})
}

View File

@ -1,88 +0,0 @@
package utils
import (
"net/http"
"github.com/go-chi/render"
)
//--
// Error response payloads & renderers
//--
// ErrResponse renderer type for handling all sorts of errors.
//
// In the best case scenario, the excellent github.com/pkg/errors package
// helps reveal information on the error, setting it on Err, and in the Render()
// method, using it to set the application-specific error code in AppCode.
type ErrResponse struct {
Err error `json:"-"` // low-level runtime error
HTTPStatusCode int `json:"-"` // http response status code
StatusText string `json:"status"` // user-level status message
AppCode int64 `json:"code,omitempty"` // application-specific error code
ErrorText string `json:"error,omitempty"` // application-level error message, for debugging
}
func (e *ErrResponse) Render(w http.ResponseWriter, r *http.Request) error {
render.Status(r, e.HTTPStatusCode)
return nil
}
func ErrMessage(HTTPStatusCode int, StatusText string) render.Renderer {
return &ErrResponse{
HTTPStatusCode: HTTPStatusCode,
StatusText: StatusText,
}
}
func ErrBadRequest(err error) render.Renderer {
return &ErrResponse{
Err: err,
HTTPStatusCode: 400,
StatusText: "Bad request.",
ErrorText: err.Error(),
}
}
func ErrUnprocessableEntity(err error) render.Renderer {
return &ErrResponse{
Err: err,
HTTPStatusCode: 400,
StatusText: "Unprocessable Entity.",
ErrorText: err.Error(),
}
}
func ErrInternalServer(err error) render.Renderer {
return &ErrResponse{
Err: err,
HTTPStatusCode: 500,
StatusText: "Internal server error.",
ErrorText: err.Error(),
}
}
func ErrNot(err error) render.Renderer {
return &ErrResponse{
Err: err,
HTTPStatusCode: 422,
StatusText: "Error rendering response.",
ErrorText: err.Error(),
}
}
var ErrNotAuthenticated = &ErrResponse{
HTTPStatusCode: 401,
StatusText: "Invalid or missing access token.",
}
var ErrNotAuthorized = &ErrResponse{
HTTPStatusCode: 403,
StatusText: "Access token does not have the required scope.",
}
var ErrNotFound = &ErrResponse{
HTTPStatusCode: 404,
StatusText: "Resource not found.",
}

View File

@ -11,8 +11,6 @@ type Server struct {
Bind string Bind string
Static string Static string
//Proxy bool //Proxy bool
UserToken string
AdminToken string
} }
func (Server) Init(cmd *cobra.Command) error { func (Server) Init(cmd *cobra.Command) error {
@ -41,16 +39,6 @@ func (Server) Init(cmd *cobra.Command) error {
// return err // return err
//} //}
cmd.PersistentFlags().String("user_token", "user_secret", "JWT token for users")
if err := viper.BindPFlag("user_token", cmd.PersistentFlags().Lookup("user_token")); err != nil {
return err
}
cmd.PersistentFlags().String("admin_token", "admin_secret", "JWT token for admins")
if err := viper.BindPFlag("admin_token", cmd.PersistentFlags().Lookup("admin_token")); err != nil {
return err
}
return nil return nil
} }
@ -60,6 +48,4 @@ func (s *Server) Set() {
s.Bind = viper.GetString("bind") s.Bind = viper.GetString("bind")
s.Static = viper.GetString("static") s.Static = viper.GetString("static")
//s.Proxy = viper.GetBool("proxy") //s.Proxy = viper.GetBool("proxy")
s.UserToken = viper.GetString("user_token")
s.AdminToken = viper.GetString("admin_token")
} }

76
internal/utils/http.go Normal file
View File

@ -0,0 +1,76 @@
package utils
import (
"fmt"
"net/http"
"encoding/json"
)
type ErrResponse struct {
Message string `json:"message"`
}
func HttpJsonRequest(w http.ResponseWriter, r *http.Request, res interface{}) bool {
if err := json.NewDecoder(r.Body).Decode(res); err != nil {
HttpBadRequest(w, err)
return false
}
return true
}
func HttpJsonResponse(w http.ResponseWriter, status int, res interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(res); err != nil {
// TODO: Log.
//log.Warn().Err(err).Msg("failed writing json error response")
}
}
func HttpError(w http.ResponseWriter, status int, res interface{}) {
HttpJsonResponse(w, status, &ErrResponse{
Message: fmt.Sprint(res),
})
}
func HttpSuccess(w http.ResponseWriter, res ...interface{}) {
if len(res) == 0 {
w.WriteHeader(http.StatusNoContent)
} else {
HttpJsonResponse(w, http.StatusOK, res[0])
}
}
func HttpBadRequest(w http.ResponseWriter, res ...interface{}) {
defHttpError(w, http.StatusBadRequest, "Bad Request.", res...)
}
func HttpNotAuthenticated(w http.ResponseWriter, res ...interface{}) {
defHttpError(w, http.StatusUnauthorized, "Invalid or missing access token.", res...)
}
func HttpNotAuthorized(w http.ResponseWriter, res ...interface{}) {
defHttpError(w, http.StatusForbidden, "Access token does not have the required scope.", res...)
}
func HttpNotFound(w http.ResponseWriter, res ...interface{}) {
defHttpError(w, http.StatusNotFound, "Resource not found.", res...)
}
func HttpUnprocessableEntity(w http.ResponseWriter, res ...interface{}) {
defHttpError(w, http.StatusUnprocessableEntity, "Unprocessable Entity.", res...)
}
func HttpInternalServer(w http.ResponseWriter, res ...interface{}) {
defHttpError(w, http.StatusInternalServerError, "Internal server error.", res...)
}
func defHttpError(w http.ResponseWriter, status int, text string, res ...interface{}) {
if len(res) == 0 {
HttpError(w, status, text)
} else {
HttpError(w, status, res[0])
}
}