http refactor.

This commit is contained in:
Miroslav Šedivý 2021-09-17 00:24:33 +02:00
parent 4fa11e6a2a
commit 5a7cdd31fe
6 changed files with 300 additions and 138 deletions

View File

@ -1,86 +1,129 @@
package http package http
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"time" "time"
"github.com/go-chi/chi/middleware" "github.com/go-chi/chi/middleware"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"demodesk/neko/internal/http/auth"
"demodesk/neko/internal/types"
"demodesk/neko/internal/utils"
) )
func Logger(next http.Handler) http.Handler { type logEntryKey int
fn := func(w http.ResponseWriter, r *http.Request) {
req := map[string]interface{}{}
// exclude healthcheck from logs const logEntryKeyCtx logEntryKey = iota
if r.RequestURI == "/api/health" {
next.ServeHTTP(w, r) func setLogEntry(r *http.Request, data logEntry) *http.Request {
ctx := context.WithValue(r.Context(), logEntryKeyCtx, data)
return r.WithContext(ctx)
}
func getLogEntry(r *http.Request) logEntry {
return r.Context().Value(logEntryKeyCtx).(logEntry)
}
func LoggerMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, setLogEntry(r, newLogEntry(w, r)))
})
}
type logEntry struct {
req struct {
time time.Time
id string
scheme string
proto string
method string
remote string
agent string
uri string
}
res struct {
time time.Time
code int
bytes int
}
err error
elapsed time.Duration
hasSession bool
session types.Session
}
func newLogEntry(w http.ResponseWriter, r *http.Request) logEntry {
e := logEntry{}
e.req.time = time.Now()
if reqID := middleware.GetReqID(r.Context()); reqID != "" {
e.req.id = reqID
}
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
e.req.scheme = scheme
e.req.proto = r.Proto
e.req.method = r.Method
e.req.remote = r.RemoteAddr
e.req.agent = r.UserAgent()
e.req.uri = fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI)
return e
}
func (e *logEntry) SetResponse(w http.ResponseWriter, r *http.Request) {
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
e.res.time = time.Now()
e.res.code = ww.Status()
e.res.bytes = ww.BytesWritten()
e.elapsed = e.res.time.Sub(e.req.time)
e.session, e.hasSession = auth.GetSession(r)
}
func (e *logEntry) SetError(err error) {
e.err = err
}
func (e *logEntry) Write() {
logger := log.With().
Str("module", "http").
Float64("elapsed", float64(e.elapsed.Nanoseconds())/1000000.0).
Interface("req", e.req).
Interface("res", e.res).
Logger()
if e.hasSession {
logger = logger.With().Str("session_id", e.session.ID()).Logger()
}
if e.err != nil {
httpErr, ok := e.err.(*utils.HTTPError)
if !ok {
logger.Err(e.err).Msgf("request failed (%d)", e.res.code)
return return
} }
if reqID := middleware.GetReqID(r.Context()); reqID != "" { if httpErr.Message == "" {
req["id"] = reqID httpErr.Message = http.StatusText(httpErr.Code)
} }
scheme := "http" logger := logger.Error().Err(httpErr.InternalErr)
if r.TLS != nil {
scheme = "https" message := httpErr.Message
if httpErr.InternalMsg != "" {
message = httpErr.InternalMsg
} }
req["scheme"] = scheme logger.Msgf("request failed (%d): %s", e.res.code, message)
req["proto"] = r.Proto return
req["method"] = r.Method
req["remote"] = r.RemoteAddr
req["agent"] = r.UserAgent()
req["uri"] = fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI)
fields := map[string]interface{}{}
fields["req"] = req
entry := &entry{
fields: fields,
}
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
t1 := time.Now()
defer func() {
entry.Write(ww.Status(), ww.BytesWritten(), time.Since(t1))
}()
next.ServeHTTP(ww, r)
} }
return http.HandlerFunc(fn)
} logger.Debug().Msgf("request complete (%d)", e.res.code)
type entry struct {
fields map[string]interface{}
errors []map[string]interface{}
}
func (e *entry) Write(status, bytes int, elapsed time.Duration) {
res := map[string]interface{}{}
res["time"] = time.Now().UTC().Format(time.RFC1123)
res["status"] = status
res["bytes"] = bytes
res["elapsed"] = float64(elapsed.Nanoseconds()) / 1000000.0
e.fields["res"] = res
e.fields["module"] = "http"
if len(e.errors) > 0 {
e.fields["errors"] = e.errors
log.Error().Fields(e.fields).Msgf("request failed (%d)", status)
} else {
log.Debug().Fields(e.fields).Msgf("request complete (%d)", status)
}
}
func (e *entry) Panic(v interface{}, stack []byte) {
err := map[string]interface{}{}
err["message"] = fmt.Sprintf("%+v", v)
err["stack"] = string(stack)
e.errors = append(e.errors, err)
} }

View File

@ -5,8 +5,6 @@ import (
"net/http" "net/http"
"os" "os"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"github.com/go-chi/cors" "github.com/go-chi/cors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -18,16 +16,15 @@ import (
type HttpManagerCtx struct { type HttpManagerCtx struct {
logger zerolog.Logger logger zerolog.Logger
config *config.Server config *config.Server
router *chi.Mux router *RouterCtx
http *http.Server http *http.Server
} }
func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, config *config.Server) *HttpManagerCtx { func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, config *config.Server) *HttpManagerCtx {
logger := log.With().Str("module", "http").Logger() logger := log.With().Str("module", "http").Logger()
router := chi.NewRouter() router := newRouter()
router.Use(middleware.Recoverer) // Recover from panics without crashing server router.UseBypass(cors.Handler(cors.Options{
router.Use(cors.Handler(cors.Options{
AllowOriginFunc: func(r *http.Request, origin string) bool { AllowOriginFunc: func(r *http.Request, origin string) bool {
return config.AllowOrigin(origin) return config.AllowOrigin(origin)
}, },
@ -37,32 +34,28 @@ func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, c
AllowCredentials: true, AllowCredentials: true,
MaxAge: 300, // Maximum value not ignored by any of major browsers MaxAge: 300, // Maximum value not ignored by any of major browsers
})) }))
router.Use(middleware.RequestID) // Create a request ID for each request
router.Use(Logger) // Log API request calls using custom logger function
router.Route("/api", ApiManager.Route) router.Route("/api", ApiManager.Route)
router.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) { router.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) error {
WebSocketManager.Upgrade(w, r, func(r *http.Request) bool { return WebSocketManager.Upgrade(w, r, func(r *http.Request) bool {
return config.AllowOrigin(r.Header.Get("Origin")) return config.AllowOrigin(r.Header.Get("Origin"))
}) })
}) })
if config.Static != "" { if config.Static != "" {
fs := http.FileServer(http.Dir(config.Static)) fs := http.FileServer(http.Dir(config.Static))
router.Get("/*", func(w http.ResponseWriter, r *http.Request) { router.Get("/*", func(w http.ResponseWriter, r *http.Request) error {
if _, err := os.Stat(config.Static + r.URL.Path); !os.IsNotExist(err) { _, err := os.Stat(config.Static + r.URL.Path)
if !os.IsNotExist(err) {
fs.ServeHTTP(w, r) fs.ServeHTTP(w, r)
} else {
http.NotFound(w, r)
} }
return err
}) })
} }
router.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
}))
return &HttpManagerCtx{ return &HttpManagerCtx{
logger: logger, logger: logger,
config: config, config: config,

112
internal/http/router.go Normal file
View File

@ -0,0 +1,112 @@
package http
import (
"demodesk/neko/internal/types"
"demodesk/neko/internal/utils"
"net/http"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
)
type RouterCtx struct {
chi chi.Router
}
func newRouter() *RouterCtx {
router := chi.NewRouter()
router.Use(middleware.Recoverer) // Recover from panics without crashing server
router.Use(middleware.RequestID) // Create a request ID for each request
router.Use(LoggerMiddleware)
return &RouterCtx{router}
}
func (r *RouterCtx) Group(fn func(types.Router)) {
r.chi.Group(func(c chi.Router) {
fn(&RouterCtx{c})
})
}
func (r *RouterCtx) Route(pattern string, fn func(types.Router)) {
r.chi.Route(pattern, func(c chi.Router) {
fn(&RouterCtx{c})
})
}
func (r *RouterCtx) Get(pattern string, fn types.RouterHandler) {
r.chi.Get(pattern, routeHandler(fn))
}
func (r *RouterCtx) Post(pattern string, fn types.RouterHandler) {
r.chi.Post(pattern, routeHandler(fn))
}
func (r *RouterCtx) Put(pattern string, fn types.RouterHandler) {
r.chi.Put(pattern, routeHandler(fn))
}
func (r *RouterCtx) Delete(pattern string, fn types.RouterHandler) {
r.chi.Delete(pattern, routeHandler(fn))
}
func (r *RouterCtx) With(fn types.MiddlewareHandler) types.Router {
c := r.chi.With(middlewareHandler(fn))
return &RouterCtx{c}
}
func (r *RouterCtx) WithBypass(fn func(next http.Handler) http.Handler) types.Router {
c := r.chi.With(fn)
return &RouterCtx{c}
}
func (r *RouterCtx) Use(fn types.MiddlewareHandler) {
r.chi.Use(middlewareHandler(fn))
}
func (r *RouterCtx) UseBypass(fn func(next http.Handler) http.Handler) {
r.chi.Use(fn)
}
func (r *RouterCtx) ServeHTTP(w http.ResponseWriter, req *http.Request) {
r.chi.ServeHTTP(w, req)
}
func errorHandler(err error, w http.ResponseWriter, r *http.Request) {
httpErr, ok := err.(*utils.HTTPError)
if !ok {
httpErr = utils.HttpInternalServerError().WithInternalErr(err)
}
utils.HttpJsonResponse(w, httpErr.Code, httpErr)
}
func routeHandler(fn types.RouterHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
logEntry := getLogEntry(r)
if err := fn(w, r); err != nil {
logEntry.SetError(err)
errorHandler(err, w, r)
}
logEntry.SetResponse(w, r)
logEntry.Write()
}
}
func middlewareHandler(fn types.MiddlewareHandler) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logEntry := getLogEntry(r)
ctx, err := fn(w, r)
if err != nil {
logEntry.SetError(err)
errorHandler(err, w, r)
logEntry.SetResponse(w, r)
logEntry.Write()
return
}
if ctx != nil {
r = r.WithContext(ctx)
}
next.ServeHTTP(w, r)
})
}
}

28
internal/types/http.go Normal file
View File

@ -0,0 +1,28 @@
package types
import (
"context"
"net/http"
)
type RouterHandler func(w http.ResponseWriter, r *http.Request) error
type MiddlewareHandler func(w http.ResponseWriter, r *http.Request) (context.Context, error)
type Router interface {
Group(fn func(Router))
Route(pattern string, fn func(Router))
Get(pattern string, fn RouterHandler)
Post(pattern string, fn RouterHandler)
Put(pattern string, fn RouterHandler)
Delete(pattern string, fn RouterHandler)
With(fn MiddlewareHandler) Router
WithBypass(fn func(next http.Handler) http.Handler) Router
Use(fn MiddlewareHandler)
UseBypass(fn func(next http.Handler) http.Handler)
ServeHTTP(w http.ResponseWriter, req *http.Request)
}
type HttpManager interface {
Start()
Shutdown() error
}

View File

@ -9,18 +9,18 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func HttpJsonRequest(w http.ResponseWriter, r *http.Request, res interface{}) bool { func HttpJsonRequest(w http.ResponseWriter, r *http.Request, res interface{}) error {
if err := json.NewDecoder(r.Body).Decode(res); err != nil { err := json.NewDecoder(r.Body).Decode(res)
if err == io.EOF {
HttpBadRequest(w).WithInternalErr(err).Msg("no data provided")
} else {
HttpBadRequest(w).WithInternalErr(err).Msg("unable to parse provided data")
}
return false if err == nil {
return nil
} }
return true if err == io.EOF {
return HttpBadRequest("no data provided").WithInternalErr(err)
}
return HttpBadRequest("unable to parse provided data").WithInternalErr(err)
} }
func HttpJsonResponse(w http.ResponseWriter, code int, res interface{}) { func HttpJsonResponse(w http.ResponseWriter, code int, res interface{}) {
@ -32,12 +32,14 @@ func HttpJsonResponse(w http.ResponseWriter, code int, res interface{}) {
} }
} }
func HttpSuccess(w http.ResponseWriter, res ...interface{}) { func HttpSuccess(w http.ResponseWriter, res ...interface{}) error {
if len(res) == 0 { if len(res) == 0 {
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
} else { } else {
HttpJsonResponse(w, http.StatusOK, res[0]) HttpJsonResponse(w, http.StatusOK, res[0])
} }
return nil
} }
// HTTPError is an error with a message and an HTTP status code. // HTTPError is an error with a message and an HTTP status code.
@ -47,8 +49,6 @@ type HTTPError struct {
InternalErr error `json:"-"` InternalErr error `json:"-"`
InternalMsg string `json:"-"` InternalMsg string `json:"-"`
w http.ResponseWriter `json:"-"`
} }
func (e *HTTPError) Error() string { func (e *HTTPError) Error() string {
@ -84,64 +84,50 @@ func (e *HTTPError) WithInternalMsgf(fmtStr string, args ...interface{}) *HTTPEr
} }
// Sends error with custom formated message // Sends error with custom formated message
func (e *HTTPError) Msgf(fmtSt string, args ...interface{}) { func (e *HTTPError) Msgf(fmtSt string, args ...interface{}) *HTTPError {
e.Message = fmt.Sprintf(fmtSt, args...) e.Message = fmt.Sprintf(fmtSt, args...)
e.Send() return e
} }
// Sends error with custom message // Sends error with custom message
func (e *HTTPError) Msg(str string) { func (e *HTTPError) Msg(str string) *HTTPError {
e.Message = str e.Message = str
e.Send() return e
} }
// Sends error with default status text func HttpError(code int, res ...string) *HTTPError {
func (e *HTTPError) Send() { err := &HTTPError{
if e.Message == "" { Code: code,
e.Message = http.StatusText(e.Code) Message: http.StatusText(code),
} }
logger := log.Error(). if len(res) == 1 {
Err(e.InternalErr). err.Message = res[0]
Str("module", "http").
Int("code", e.Code)
message := e.Message
if e.InternalMsg != "" {
message = e.InternalMsg
} }
logger.Msg(message) return err
HttpJsonResponse(e.w, e.Code, e)
} }
func HttpError(w http.ResponseWriter, code int) *HTTPError { func HttpBadRequest(res ...string) *HTTPError {
return &HTTPError{ return HttpError(http.StatusBadRequest, res...)
Code: code,
w: w,
}
} }
func HttpBadRequest(w http.ResponseWriter) *HTTPError { func HttpUnauthorized(res ...string) *HTTPError {
return HttpError(w, http.StatusBadRequest) return HttpError(http.StatusUnauthorized, res...)
} }
func HttpUnauthorized(w http.ResponseWriter) *HTTPError { func HttpForbidden(res ...string) *HTTPError {
return HttpError(w, http.StatusUnauthorized) return HttpError(http.StatusForbidden, res...)
} }
func HttpForbidden(w http.ResponseWriter) *HTTPError { func HttpNotFound(res ...string) *HTTPError {
return HttpError(w, http.StatusForbidden) return HttpError(http.StatusNotFound, res...)
} }
func HttpNotFound(w http.ResponseWriter) *HTTPError { func HttpUnprocessableEntity(res ...string) *HTTPError {
return HttpError(w, http.StatusNotFound) return HttpError(http.StatusUnprocessableEntity, res...)
} }
func HttpUnprocessableEntity(w http.ResponseWriter) *HTTPError { func HttpInternalServerError(res ...string) *HTTPError {
return HttpError(w, http.StatusUnprocessableEntity) return HttpError(http.StatusInternalServerError, res...)
}
func HttpInternalServerError(w http.ResponseWriter, err error) *HTTPError {
return HttpError(w, http.StatusInternalServerError).WithInternalErr(err)
} }

View File

@ -145,7 +145,7 @@ func (manager *WebSocketManagerCtx) AddHandler(handler types.WebSocketHandler) {
manager.handlers = append(manager.handlers, handler) manager.handlers = append(manager.handlers, handler)
} }
func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin types.CheckOrigin) { func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin types.CheckOrigin) error {
manager.logger.Debug(). manager.logger.Debug().
Str("address", r.RemoteAddr). Str("address", r.RemoteAddr).
Str("agent", r.UserAgent()). Str("agent", r.UserAgent()).
@ -157,8 +157,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque
connection, err := upgrader.Upgrade(w, r, nil) connection, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
manager.logger.Err(err).Msg("failed to upgrade connection") return err
return
} }
// create new peer // create new peer
@ -168,7 +167,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque
if err != nil { if err != nil {
manager.logger.Warn().Err(err).Msg("authentication failed") manager.logger.Warn().Err(err).Msg("authentication failed")
peer.Destroy(err.Error()) peer.Destroy(err.Error())
return return nil
} }
// add session id to all log messages // add session id to all log messages
@ -178,7 +177,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque
if !session.Profile().CanConnect { if !session.Profile().CanConnect {
logger.Warn().Msg("connection disabled") logger.Warn().Msg("connection disabled")
peer.Destroy("connection disabled") peer.Destroy("connection disabled")
return return nil
} }
if session.State().IsConnected { if session.State().IsConnected {
@ -186,7 +185,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque
if !manager.sessions.MercifulReconnect() { if !manager.sessions.MercifulReconnect() {
peer.Destroy("already connected") peer.Destroy("already connected")
return return nil
} }
logger.Info().Msg("replacing peer connection") logger.Info().Msg("replacing peer connection")
@ -211,6 +210,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque
}() }()
manager.handle(connection, session) manager.handle(connection, session)
return nil
} }
func (manager *WebSocketManagerCtx) handle(connection *websocket.Conn, session types.Session) { func (manager *WebSocketManagerCtx) handle(connection *websocket.Conn, session types.Session) {