http refactor.

This commit is contained in:
Miroslav Šedivý 2021-09-17 00:24:33 +02:00
parent 4fa11e6a2a
commit 5a7cdd31fe
6 changed files with 300 additions and 138 deletions

View File

@ -1,86 +1,129 @@
package http
import (
"context"
"fmt"
"net/http"
"time"
"github.com/go-chi/chi/middleware"
"github.com/rs/zerolog/log"
"demodesk/neko/internal/http/auth"
"demodesk/neko/internal/types"
"demodesk/neko/internal/utils"
)
func Logger(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
req := map[string]interface{}{}
type logEntryKey int
// exclude healthcheck from logs
if r.RequestURI == "/api/health" {
next.ServeHTTP(w, r)
const logEntryKeyCtx logEntryKey = iota
func setLogEntry(r *http.Request, data logEntry) *http.Request {
ctx := context.WithValue(r.Context(), logEntryKeyCtx, data)
return r.WithContext(ctx)
}
func getLogEntry(r *http.Request) logEntry {
return r.Context().Value(logEntryKeyCtx).(logEntry)
}
func LoggerMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, setLogEntry(r, newLogEntry(w, r)))
})
}
type logEntry struct {
req struct {
time time.Time
id string
scheme string
proto string
method string
remote string
agent string
uri string
}
res struct {
time time.Time
code int
bytes int
}
err error
elapsed time.Duration
hasSession bool
session types.Session
}
func newLogEntry(w http.ResponseWriter, r *http.Request) logEntry {
e := logEntry{}
e.req.time = time.Now()
if reqID := middleware.GetReqID(r.Context()); reqID != "" {
e.req.id = reqID
}
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
e.req.scheme = scheme
e.req.proto = r.Proto
e.req.method = r.Method
e.req.remote = r.RemoteAddr
e.req.agent = r.UserAgent()
e.req.uri = fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI)
return e
}
func (e *logEntry) SetResponse(w http.ResponseWriter, r *http.Request) {
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
e.res.time = time.Now()
e.res.code = ww.Status()
e.res.bytes = ww.BytesWritten()
e.elapsed = e.res.time.Sub(e.req.time)
e.session, e.hasSession = auth.GetSession(r)
}
func (e *logEntry) SetError(err error) {
e.err = err
}
func (e *logEntry) Write() {
logger := log.With().
Str("module", "http").
Float64("elapsed", float64(e.elapsed.Nanoseconds())/1000000.0).
Interface("req", e.req).
Interface("res", e.res).
Logger()
if e.hasSession {
logger = logger.With().Str("session_id", e.session.ID()).Logger()
}
if e.err != nil {
httpErr, ok := e.err.(*utils.HTTPError)
if !ok {
logger.Err(e.err).Msgf("request failed (%d)", e.res.code)
return
}
if reqID := middleware.GetReqID(r.Context()); reqID != "" {
req["id"] = reqID
if httpErr.Message == "" {
httpErr.Message = http.StatusText(httpErr.Code)
}
scheme := "http"
if r.TLS != nil {
scheme = "https"
logger := logger.Error().Err(httpErr.InternalErr)
message := httpErr.Message
if httpErr.InternalMsg != "" {
message = httpErr.InternalMsg
}
req["scheme"] = scheme
req["proto"] = r.Proto
req["method"] = r.Method
req["remote"] = r.RemoteAddr
req["agent"] = r.UserAgent()
req["uri"] = fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI)
fields := map[string]interface{}{}
fields["req"] = req
entry := &entry{
fields: fields,
}
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
t1 := time.Now()
defer func() {
entry.Write(ww.Status(), ww.BytesWritten(), time.Since(t1))
}()
next.ServeHTTP(ww, r)
logger.Msgf("request failed (%d): %s", e.res.code, message)
return
}
return http.HandlerFunc(fn)
}
type entry struct {
fields map[string]interface{}
errors []map[string]interface{}
}
func (e *entry) Write(status, bytes int, elapsed time.Duration) {
res := map[string]interface{}{}
res["time"] = time.Now().UTC().Format(time.RFC1123)
res["status"] = status
res["bytes"] = bytes
res["elapsed"] = float64(elapsed.Nanoseconds()) / 1000000.0
e.fields["res"] = res
e.fields["module"] = "http"
if len(e.errors) > 0 {
e.fields["errors"] = e.errors
log.Error().Fields(e.fields).Msgf("request failed (%d)", status)
} else {
log.Debug().Fields(e.fields).Msgf("request complete (%d)", status)
}
}
func (e *entry) Panic(v interface{}, stack []byte) {
err := map[string]interface{}{}
err["message"] = fmt.Sprintf("%+v", v)
err["stack"] = string(stack)
e.errors = append(e.errors, err)
logger.Debug().Msgf("request complete (%d)", e.res.code)
}

View File

@ -5,8 +5,6 @@ import (
"net/http"
"os"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"github.com/go-chi/cors"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
@ -18,16 +16,15 @@ import (
type HttpManagerCtx struct {
logger zerolog.Logger
config *config.Server
router *chi.Mux
router *RouterCtx
http *http.Server
}
func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, config *config.Server) *HttpManagerCtx {
logger := log.With().Str("module", "http").Logger()
router := chi.NewRouter()
router.Use(middleware.Recoverer) // Recover from panics without crashing server
router.Use(cors.Handler(cors.Options{
router := newRouter()
router.UseBypass(cors.Handler(cors.Options{
AllowOriginFunc: func(r *http.Request, origin string) bool {
return config.AllowOrigin(origin)
},
@ -37,32 +34,28 @@ func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, c
AllowCredentials: true,
MaxAge: 300, // Maximum value not ignored by any of major browsers
}))
router.Use(middleware.RequestID) // Create a request ID for each request
router.Use(Logger) // Log API request calls using custom logger function
router.Route("/api", ApiManager.Route)
router.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
WebSocketManager.Upgrade(w, r, func(r *http.Request) bool {
router.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) error {
return WebSocketManager.Upgrade(w, r, func(r *http.Request) bool {
return config.AllowOrigin(r.Header.Get("Origin"))
})
})
if config.Static != "" {
fs := http.FileServer(http.Dir(config.Static))
router.Get("/*", func(w http.ResponseWriter, r *http.Request) {
if _, err := os.Stat(config.Static + r.URL.Path); !os.IsNotExist(err) {
router.Get("/*", func(w http.ResponseWriter, r *http.Request) error {
_, err := os.Stat(config.Static + r.URL.Path)
if !os.IsNotExist(err) {
fs.ServeHTTP(w, r)
} else {
http.NotFound(w, r)
}
return err
})
}
router.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
}))
return &HttpManagerCtx{
logger: logger,
config: config,

112
internal/http/router.go Normal file
View File

@ -0,0 +1,112 @@
package http
import (
"demodesk/neko/internal/types"
"demodesk/neko/internal/utils"
"net/http"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
)
type RouterCtx struct {
chi chi.Router
}
func newRouter() *RouterCtx {
router := chi.NewRouter()
router.Use(middleware.Recoverer) // Recover from panics without crashing server
router.Use(middleware.RequestID) // Create a request ID for each request
router.Use(LoggerMiddleware)
return &RouterCtx{router}
}
func (r *RouterCtx) Group(fn func(types.Router)) {
r.chi.Group(func(c chi.Router) {
fn(&RouterCtx{c})
})
}
func (r *RouterCtx) Route(pattern string, fn func(types.Router)) {
r.chi.Route(pattern, func(c chi.Router) {
fn(&RouterCtx{c})
})
}
func (r *RouterCtx) Get(pattern string, fn types.RouterHandler) {
r.chi.Get(pattern, routeHandler(fn))
}
func (r *RouterCtx) Post(pattern string, fn types.RouterHandler) {
r.chi.Post(pattern, routeHandler(fn))
}
func (r *RouterCtx) Put(pattern string, fn types.RouterHandler) {
r.chi.Put(pattern, routeHandler(fn))
}
func (r *RouterCtx) Delete(pattern string, fn types.RouterHandler) {
r.chi.Delete(pattern, routeHandler(fn))
}
func (r *RouterCtx) With(fn types.MiddlewareHandler) types.Router {
c := r.chi.With(middlewareHandler(fn))
return &RouterCtx{c}
}
func (r *RouterCtx) WithBypass(fn func(next http.Handler) http.Handler) types.Router {
c := r.chi.With(fn)
return &RouterCtx{c}
}
func (r *RouterCtx) Use(fn types.MiddlewareHandler) {
r.chi.Use(middlewareHandler(fn))
}
func (r *RouterCtx) UseBypass(fn func(next http.Handler) http.Handler) {
r.chi.Use(fn)
}
func (r *RouterCtx) ServeHTTP(w http.ResponseWriter, req *http.Request) {
r.chi.ServeHTTP(w, req)
}
func errorHandler(err error, w http.ResponseWriter, r *http.Request) {
httpErr, ok := err.(*utils.HTTPError)
if !ok {
httpErr = utils.HttpInternalServerError().WithInternalErr(err)
}
utils.HttpJsonResponse(w, httpErr.Code, httpErr)
}
func routeHandler(fn types.RouterHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
logEntry := getLogEntry(r)
if err := fn(w, r); err != nil {
logEntry.SetError(err)
errorHandler(err, w, r)
}
logEntry.SetResponse(w, r)
logEntry.Write()
}
}
func middlewareHandler(fn types.MiddlewareHandler) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logEntry := getLogEntry(r)
ctx, err := fn(w, r)
if err != nil {
logEntry.SetError(err)
errorHandler(err, w, r)
logEntry.SetResponse(w, r)
logEntry.Write()
return
}
if ctx != nil {
r = r.WithContext(ctx)
}
next.ServeHTTP(w, r)
})
}
}

28
internal/types/http.go Normal file
View File

@ -0,0 +1,28 @@
package types
import (
"context"
"net/http"
)
type RouterHandler func(w http.ResponseWriter, r *http.Request) error
type MiddlewareHandler func(w http.ResponseWriter, r *http.Request) (context.Context, error)
type Router interface {
Group(fn func(Router))
Route(pattern string, fn func(Router))
Get(pattern string, fn RouterHandler)
Post(pattern string, fn RouterHandler)
Put(pattern string, fn RouterHandler)
Delete(pattern string, fn RouterHandler)
With(fn MiddlewareHandler) Router
WithBypass(fn func(next http.Handler) http.Handler) Router
Use(fn MiddlewareHandler)
UseBypass(fn func(next http.Handler) http.Handler)
ServeHTTP(w http.ResponseWriter, req *http.Request)
}
type HttpManager interface {
Start()
Shutdown() error
}

View File

@ -9,18 +9,18 @@ import (
"github.com/rs/zerolog/log"
)
func HttpJsonRequest(w http.ResponseWriter, r *http.Request, res interface{}) bool {
if err := json.NewDecoder(r.Body).Decode(res); err != nil {
if err == io.EOF {
HttpBadRequest(w).WithInternalErr(err).Msg("no data provided")
} else {
HttpBadRequest(w).WithInternalErr(err).Msg("unable to parse provided data")
}
func HttpJsonRequest(w http.ResponseWriter, r *http.Request, res interface{}) error {
err := json.NewDecoder(r.Body).Decode(res)
return false
if err == nil {
return nil
}
return true
if err == io.EOF {
return HttpBadRequest("no data provided").WithInternalErr(err)
}
return HttpBadRequest("unable to parse provided data").WithInternalErr(err)
}
func HttpJsonResponse(w http.ResponseWriter, code int, res interface{}) {
@ -32,12 +32,14 @@ func HttpJsonResponse(w http.ResponseWriter, code int, res interface{}) {
}
}
func HttpSuccess(w http.ResponseWriter, res ...interface{}) {
func HttpSuccess(w http.ResponseWriter, res ...interface{}) error {
if len(res) == 0 {
w.WriteHeader(http.StatusNoContent)
} else {
HttpJsonResponse(w, http.StatusOK, res[0])
}
return nil
}
// HTTPError is an error with a message and an HTTP status code.
@ -47,8 +49,6 @@ type HTTPError struct {
InternalErr error `json:"-"`
InternalMsg string `json:"-"`
w http.ResponseWriter `json:"-"`
}
func (e *HTTPError) Error() string {
@ -84,64 +84,50 @@ func (e *HTTPError) WithInternalMsgf(fmtStr string, args ...interface{}) *HTTPEr
}
// Sends error with custom formated message
func (e *HTTPError) Msgf(fmtSt string, args ...interface{}) {
func (e *HTTPError) Msgf(fmtSt string, args ...interface{}) *HTTPError {
e.Message = fmt.Sprintf(fmtSt, args...)
e.Send()
return e
}
// Sends error with custom message
func (e *HTTPError) Msg(str string) {
func (e *HTTPError) Msg(str string) *HTTPError {
e.Message = str
e.Send()
return e
}
// Sends error with default status text
func (e *HTTPError) Send() {
if e.Message == "" {
e.Message = http.StatusText(e.Code)
func HttpError(code int, res ...string) *HTTPError {
err := &HTTPError{
Code: code,
Message: http.StatusText(code),
}
logger := log.Error().
Err(e.InternalErr).
Str("module", "http").
Int("code", e.Code)
message := e.Message
if e.InternalMsg != "" {
message = e.InternalMsg
if len(res) == 1 {
err.Message = res[0]
}
logger.Msg(message)
HttpJsonResponse(e.w, e.Code, e)
return err
}
func HttpError(w http.ResponseWriter, code int) *HTTPError {
return &HTTPError{
Code: code,
w: w,
}
func HttpBadRequest(res ...string) *HTTPError {
return HttpError(http.StatusBadRequest, res...)
}
func HttpBadRequest(w http.ResponseWriter) *HTTPError {
return HttpError(w, http.StatusBadRequest)
func HttpUnauthorized(res ...string) *HTTPError {
return HttpError(http.StatusUnauthorized, res...)
}
func HttpUnauthorized(w http.ResponseWriter) *HTTPError {
return HttpError(w, http.StatusUnauthorized)
func HttpForbidden(res ...string) *HTTPError {
return HttpError(http.StatusForbidden, res...)
}
func HttpForbidden(w http.ResponseWriter) *HTTPError {
return HttpError(w, http.StatusForbidden)
func HttpNotFound(res ...string) *HTTPError {
return HttpError(http.StatusNotFound, res...)
}
func HttpNotFound(w http.ResponseWriter) *HTTPError {
return HttpError(w, http.StatusNotFound)
func HttpUnprocessableEntity(res ...string) *HTTPError {
return HttpError(http.StatusUnprocessableEntity, res...)
}
func HttpUnprocessableEntity(w http.ResponseWriter) *HTTPError {
return HttpError(w, http.StatusUnprocessableEntity)
}
func HttpInternalServerError(w http.ResponseWriter, err error) *HTTPError {
return HttpError(w, http.StatusInternalServerError).WithInternalErr(err)
func HttpInternalServerError(res ...string) *HTTPError {
return HttpError(http.StatusInternalServerError, res...)
}

View File

@ -145,7 +145,7 @@ func (manager *WebSocketManagerCtx) AddHandler(handler types.WebSocketHandler) {
manager.handlers = append(manager.handlers, handler)
}
func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin types.CheckOrigin) {
func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin types.CheckOrigin) error {
manager.logger.Debug().
Str("address", r.RemoteAddr).
Str("agent", r.UserAgent()).
@ -157,8 +157,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque
connection, err := upgrader.Upgrade(w, r, nil)
if err != nil {
manager.logger.Err(err).Msg("failed to upgrade connection")
return
return err
}
// create new peer
@ -168,7 +167,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque
if err != nil {
manager.logger.Warn().Err(err).Msg("authentication failed")
peer.Destroy(err.Error())
return
return nil
}
// add session id to all log messages
@ -178,7 +177,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque
if !session.Profile().CanConnect {
logger.Warn().Msg("connection disabled")
peer.Destroy("connection disabled")
return
return nil
}
if session.State().IsConnected {
@ -186,7 +185,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque
if !manager.sessions.MercifulReconnect() {
peer.Destroy("already connected")
return
return nil
}
logger.Info().Msg("replacing peer connection")
@ -211,6 +210,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque
}()
manager.handle(connection, session)
return nil
}
func (manager *WebSocketManagerCtx) handle(connection *websocket.Conn, session types.Session) {