diff --git a/internal/http/logger.go b/internal/http/logger.go index 419be2a7..fead8fdd 100644 --- a/internal/http/logger.go +++ b/internal/http/logger.go @@ -1,86 +1,129 @@ package http import ( + "context" "fmt" "net/http" "time" "github.com/go-chi/chi/middleware" "github.com/rs/zerolog/log" + + "demodesk/neko/internal/http/auth" + "demodesk/neko/internal/types" + "demodesk/neko/internal/utils" ) -func Logger(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - req := map[string]interface{}{} +type logEntryKey int - // exclude healthcheck from logs - if r.RequestURI == "/api/health" { - next.ServeHTTP(w, r) +const logEntryKeyCtx logEntryKey = iota + +func setLogEntry(r *http.Request, data logEntry) *http.Request { + ctx := context.WithValue(r.Context(), logEntryKeyCtx, data) + return r.WithContext(ctx) +} + +func getLogEntry(r *http.Request) logEntry { + return r.Context().Value(logEntryKeyCtx).(logEntry) +} + +func LoggerMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, setLogEntry(r, newLogEntry(w, r))) + }) +} + +type logEntry struct { + req struct { + time time.Time + id string + scheme string + proto string + method string + remote string + agent string + uri string + } + res struct { + time time.Time + code int + bytes int + } + err error + elapsed time.Duration + hasSession bool + session types.Session +} + +func newLogEntry(w http.ResponseWriter, r *http.Request) logEntry { + e := logEntry{} + e.req.time = time.Now() + + if reqID := middleware.GetReqID(r.Context()); reqID != "" { + e.req.id = reqID + } + + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + + e.req.scheme = scheme + e.req.proto = r.Proto + e.req.method = r.Method + e.req.remote = r.RemoteAddr + e.req.agent = r.UserAgent() + e.req.uri = fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI) + return e +} + +func (e *logEntry) SetResponse(w http.ResponseWriter, r *http.Request) { + ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) + e.res.time = time.Now() + e.res.code = ww.Status() + e.res.bytes = ww.BytesWritten() + + e.elapsed = e.res.time.Sub(e.req.time) + e.session, e.hasSession = auth.GetSession(r) +} + +func (e *logEntry) SetError(err error) { + e.err = err +} + +func (e *logEntry) Write() { + logger := log.With(). + Str("module", "http"). + Float64("elapsed", float64(e.elapsed.Nanoseconds())/1000000.0). + Interface("req", e.req). + Interface("res", e.res). + Logger() + + if e.hasSession { + logger = logger.With().Str("session_id", e.session.ID()).Logger() + } + + if e.err != nil { + httpErr, ok := e.err.(*utils.HTTPError) + if !ok { + logger.Err(e.err).Msgf("request failed (%d)", e.res.code) return } - if reqID := middleware.GetReqID(r.Context()); reqID != "" { - req["id"] = reqID + if httpErr.Message == "" { + httpErr.Message = http.StatusText(httpErr.Code) } - scheme := "http" - if r.TLS != nil { - scheme = "https" + logger := logger.Error().Err(httpErr.InternalErr) + + message := httpErr.Message + if httpErr.InternalMsg != "" { + message = httpErr.InternalMsg } - req["scheme"] = scheme - req["proto"] = r.Proto - req["method"] = r.Method - req["remote"] = r.RemoteAddr - req["agent"] = r.UserAgent() - req["uri"] = fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI) - - fields := map[string]interface{}{} - fields["req"] = req - - entry := &entry{ - fields: fields, - } - - ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) - t1 := time.Now() - - defer func() { - entry.Write(ww.Status(), ww.BytesWritten(), time.Since(t1)) - }() - - next.ServeHTTP(ww, r) + logger.Msgf("request failed (%d): %s", e.res.code, message) + return } - return http.HandlerFunc(fn) -} - -type entry struct { - fields map[string]interface{} - errors []map[string]interface{} -} - -func (e *entry) Write(status, bytes int, elapsed time.Duration) { - res := map[string]interface{}{} - res["time"] = time.Now().UTC().Format(time.RFC1123) - res["status"] = status - res["bytes"] = bytes - res["elapsed"] = float64(elapsed.Nanoseconds()) / 1000000.0 - - e.fields["res"] = res - e.fields["module"] = "http" - - if len(e.errors) > 0 { - e.fields["errors"] = e.errors - log.Error().Fields(e.fields).Msgf("request failed (%d)", status) - } else { - log.Debug().Fields(e.fields).Msgf("request complete (%d)", status) - } -} - -func (e *entry) Panic(v interface{}, stack []byte) { - err := map[string]interface{}{} - err["message"] = fmt.Sprintf("%+v", v) - err["stack"] = string(stack) - - e.errors = append(e.errors, err) + + logger.Debug().Msgf("request complete (%d)", e.res.code) } diff --git a/internal/http/manager.go b/internal/http/manager.go index 47a461d8..c5bc750d 100644 --- a/internal/http/manager.go +++ b/internal/http/manager.go @@ -5,8 +5,6 @@ import ( "net/http" "os" - "github.com/go-chi/chi" - "github.com/go-chi/chi/middleware" "github.com/go-chi/cors" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -18,16 +16,15 @@ import ( type HttpManagerCtx struct { logger zerolog.Logger config *config.Server - router *chi.Mux + router *RouterCtx http *http.Server } func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, config *config.Server) *HttpManagerCtx { logger := log.With().Str("module", "http").Logger() - router := chi.NewRouter() - router.Use(middleware.Recoverer) // Recover from panics without crashing server - router.Use(cors.Handler(cors.Options{ + router := newRouter() + router.UseBypass(cors.Handler(cors.Options{ AllowOriginFunc: func(r *http.Request, origin string) bool { return config.AllowOrigin(origin) }, @@ -37,32 +34,28 @@ func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, c AllowCredentials: true, MaxAge: 300, // Maximum value not ignored by any of major browsers })) - router.Use(middleware.RequestID) // Create a request ID for each request - router.Use(Logger) // Log API request calls using custom logger function router.Route("/api", ApiManager.Route) - router.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) { - WebSocketManager.Upgrade(w, r, func(r *http.Request) bool { + router.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) error { + return WebSocketManager.Upgrade(w, r, func(r *http.Request) bool { return config.AllowOrigin(r.Header.Get("Origin")) }) }) if config.Static != "" { fs := http.FileServer(http.Dir(config.Static)) - router.Get("/*", func(w http.ResponseWriter, r *http.Request) { - if _, err := os.Stat(config.Static + r.URL.Path); !os.IsNotExist(err) { + router.Get("/*", func(w http.ResponseWriter, r *http.Request) error { + _, err := os.Stat(config.Static + r.URL.Path) + + if !os.IsNotExist(err) { fs.ServeHTTP(w, r) - } else { - http.NotFound(w, r) } + + return err }) } - router.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.NotFound(w, r) - })) - return &HttpManagerCtx{ logger: logger, config: config, diff --git a/internal/http/router.go b/internal/http/router.go new file mode 100644 index 00000000..87679a4d --- /dev/null +++ b/internal/http/router.go @@ -0,0 +1,112 @@ +package http + +import ( + "demodesk/neko/internal/types" + "demodesk/neko/internal/utils" + "net/http" + + "github.com/go-chi/chi" + "github.com/go-chi/chi/middleware" +) + +type RouterCtx struct { + chi chi.Router +} + +func newRouter() *RouterCtx { + router := chi.NewRouter() + router.Use(middleware.Recoverer) // Recover from panics without crashing server + router.Use(middleware.RequestID) // Create a request ID for each request + router.Use(LoggerMiddleware) + return &RouterCtx{router} +} + +func (r *RouterCtx) Group(fn func(types.Router)) { + r.chi.Group(func(c chi.Router) { + fn(&RouterCtx{c}) + }) +} + +func (r *RouterCtx) Route(pattern string, fn func(types.Router)) { + r.chi.Route(pattern, func(c chi.Router) { + fn(&RouterCtx{c}) + }) +} + +func (r *RouterCtx) Get(pattern string, fn types.RouterHandler) { + r.chi.Get(pattern, routeHandler(fn)) +} +func (r *RouterCtx) Post(pattern string, fn types.RouterHandler) { + r.chi.Post(pattern, routeHandler(fn)) +} +func (r *RouterCtx) Put(pattern string, fn types.RouterHandler) { + r.chi.Put(pattern, routeHandler(fn)) +} +func (r *RouterCtx) Delete(pattern string, fn types.RouterHandler) { + r.chi.Delete(pattern, routeHandler(fn)) +} + +func (r *RouterCtx) With(fn types.MiddlewareHandler) types.Router { + c := r.chi.With(middlewareHandler(fn)) + return &RouterCtx{c} +} + +func (r *RouterCtx) WithBypass(fn func(next http.Handler) http.Handler) types.Router { + c := r.chi.With(fn) + return &RouterCtx{c} +} + +func (r *RouterCtx) Use(fn types.MiddlewareHandler) { + r.chi.Use(middlewareHandler(fn)) +} +func (r *RouterCtx) UseBypass(fn func(next http.Handler) http.Handler) { + r.chi.Use(fn) +} + +func (r *RouterCtx) ServeHTTP(w http.ResponseWriter, req *http.Request) { + r.chi.ServeHTTP(w, req) +} + +func errorHandler(err error, w http.ResponseWriter, r *http.Request) { + httpErr, ok := err.(*utils.HTTPError) + if !ok { + httpErr = utils.HttpInternalServerError().WithInternalErr(err) + } + + utils.HttpJsonResponse(w, httpErr.Code, httpErr) +} + +func routeHandler(fn types.RouterHandler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + logEntry := getLogEntry(r) + + if err := fn(w, r); err != nil { + logEntry.SetError(err) + errorHandler(err, w, r) + } + + logEntry.SetResponse(w, r) + logEntry.Write() + } +} + +func middlewareHandler(fn types.MiddlewareHandler) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + logEntry := getLogEntry(r) + + ctx, err := fn(w, r) + if err != nil { + logEntry.SetError(err) + errorHandler(err, w, r) + logEntry.SetResponse(w, r) + logEntry.Write() + return + } + if ctx != nil { + r = r.WithContext(ctx) + } + next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/types/http.go b/internal/types/http.go new file mode 100644 index 00000000..3725333b --- /dev/null +++ b/internal/types/http.go @@ -0,0 +1,28 @@ +package types + +import ( + "context" + "net/http" +) + +type RouterHandler func(w http.ResponseWriter, r *http.Request) error +type MiddlewareHandler func(w http.ResponseWriter, r *http.Request) (context.Context, error) + +type Router interface { + Group(fn func(Router)) + Route(pattern string, fn func(Router)) + Get(pattern string, fn RouterHandler) + Post(pattern string, fn RouterHandler) + Put(pattern string, fn RouterHandler) + Delete(pattern string, fn RouterHandler) + With(fn MiddlewareHandler) Router + WithBypass(fn func(next http.Handler) http.Handler) Router + Use(fn MiddlewareHandler) + UseBypass(fn func(next http.Handler) http.Handler) + ServeHTTP(w http.ResponseWriter, req *http.Request) +} + +type HttpManager interface { + Start() + Shutdown() error +} diff --git a/internal/utils/http.go b/internal/utils/http.go index 0056940d..f2b6e79f 100644 --- a/internal/utils/http.go +++ b/internal/utils/http.go @@ -9,18 +9,18 @@ import ( "github.com/rs/zerolog/log" ) -func HttpJsonRequest(w http.ResponseWriter, r *http.Request, res interface{}) bool { - if err := json.NewDecoder(r.Body).Decode(res); err != nil { - if err == io.EOF { - HttpBadRequest(w).WithInternalErr(err).Msg("no data provided") - } else { - HttpBadRequest(w).WithInternalErr(err).Msg("unable to parse provided data") - } +func HttpJsonRequest(w http.ResponseWriter, r *http.Request, res interface{}) error { + err := json.NewDecoder(r.Body).Decode(res) - return false + if err == nil { + return nil } - return true + if err == io.EOF { + return HttpBadRequest("no data provided").WithInternalErr(err) + } + + return HttpBadRequest("unable to parse provided data").WithInternalErr(err) } func HttpJsonResponse(w http.ResponseWriter, code int, res interface{}) { @@ -32,12 +32,14 @@ func HttpJsonResponse(w http.ResponseWriter, code int, res interface{}) { } } -func HttpSuccess(w http.ResponseWriter, res ...interface{}) { +func HttpSuccess(w http.ResponseWriter, res ...interface{}) error { if len(res) == 0 { w.WriteHeader(http.StatusNoContent) } else { HttpJsonResponse(w, http.StatusOK, res[0]) } + + return nil } // HTTPError is an error with a message and an HTTP status code. @@ -47,8 +49,6 @@ type HTTPError struct { InternalErr error `json:"-"` InternalMsg string `json:"-"` - - w http.ResponseWriter `json:"-"` } func (e *HTTPError) Error() string { @@ -84,64 +84,50 @@ func (e *HTTPError) WithInternalMsgf(fmtStr string, args ...interface{}) *HTTPEr } // Sends error with custom formated message -func (e *HTTPError) Msgf(fmtSt string, args ...interface{}) { +func (e *HTTPError) Msgf(fmtSt string, args ...interface{}) *HTTPError { e.Message = fmt.Sprintf(fmtSt, args...) - e.Send() + return e } // Sends error with custom message -func (e *HTTPError) Msg(str string) { +func (e *HTTPError) Msg(str string) *HTTPError { e.Message = str - e.Send() + return e } -// Sends error with default status text -func (e *HTTPError) Send() { - if e.Message == "" { - e.Message = http.StatusText(e.Code) +func HttpError(code int, res ...string) *HTTPError { + err := &HTTPError{ + Code: code, + Message: http.StatusText(code), } - logger := log.Error(). - Err(e.InternalErr). - Str("module", "http"). - Int("code", e.Code) - - message := e.Message - if e.InternalMsg != "" { - message = e.InternalMsg + if len(res) == 1 { + err.Message = res[0] } - logger.Msg(message) - HttpJsonResponse(e.w, e.Code, e) + return err } -func HttpError(w http.ResponseWriter, code int) *HTTPError { - return &HTTPError{ - Code: code, - w: w, - } +func HttpBadRequest(res ...string) *HTTPError { + return HttpError(http.StatusBadRequest, res...) } -func HttpBadRequest(w http.ResponseWriter) *HTTPError { - return HttpError(w, http.StatusBadRequest) +func HttpUnauthorized(res ...string) *HTTPError { + return HttpError(http.StatusUnauthorized, res...) } -func HttpUnauthorized(w http.ResponseWriter) *HTTPError { - return HttpError(w, http.StatusUnauthorized) +func HttpForbidden(res ...string) *HTTPError { + return HttpError(http.StatusForbidden, res...) } -func HttpForbidden(w http.ResponseWriter) *HTTPError { - return HttpError(w, http.StatusForbidden) +func HttpNotFound(res ...string) *HTTPError { + return HttpError(http.StatusNotFound, res...) } -func HttpNotFound(w http.ResponseWriter) *HTTPError { - return HttpError(w, http.StatusNotFound) +func HttpUnprocessableEntity(res ...string) *HTTPError { + return HttpError(http.StatusUnprocessableEntity, res...) } -func HttpUnprocessableEntity(w http.ResponseWriter) *HTTPError { - return HttpError(w, http.StatusUnprocessableEntity) -} - -func HttpInternalServerError(w http.ResponseWriter, err error) *HTTPError { - return HttpError(w, http.StatusInternalServerError).WithInternalErr(err) +func HttpInternalServerError(res ...string) *HTTPError { + return HttpError(http.StatusInternalServerError, res...) } diff --git a/internal/websocket/manager.go b/internal/websocket/manager.go index 3d68d74c..540447d3 100644 --- a/internal/websocket/manager.go +++ b/internal/websocket/manager.go @@ -145,7 +145,7 @@ func (manager *WebSocketManagerCtx) AddHandler(handler types.WebSocketHandler) { manager.handlers = append(manager.handlers, handler) } -func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin types.CheckOrigin) { +func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin types.CheckOrigin) error { manager.logger.Debug(). Str("address", r.RemoteAddr). Str("agent", r.UserAgent()). @@ -157,8 +157,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque connection, err := upgrader.Upgrade(w, r, nil) if err != nil { - manager.logger.Err(err).Msg("failed to upgrade connection") - return + return err } // create new peer @@ -168,7 +167,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque if err != nil { manager.logger.Warn().Err(err).Msg("authentication failed") peer.Destroy(err.Error()) - return + return nil } // add session id to all log messages @@ -178,7 +177,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque if !session.Profile().CanConnect { logger.Warn().Msg("connection disabled") peer.Destroy("connection disabled") - return + return nil } if session.State().IsConnected { @@ -186,7 +185,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque if !manager.sessions.MercifulReconnect() { peer.Destroy("already connected") - return + return nil } logger.Info().Msg("replacing peer connection") @@ -211,6 +210,7 @@ func (manager *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Reque }() manager.handle(connection, session) + return nil } func (manager *WebSocketManagerCtx) handle(connection *websocket.Conn, session types.Session) {