server proxy, router opts and optional CORS.

This commit is contained in:
Miroslav Šedivý 2023-11-19 14:35:19 +01:00
parent cd9d31a627
commit 501280f8aa
4 changed files with 95 additions and 33 deletions

View File

@ -13,6 +13,7 @@ type Server struct {
Cert string Cert string
Key string Key string
Bind string Bind string
Proxy bool
Static string Static string
PathPrefix string PathPrefix string
PProf bool PProf bool
@ -36,6 +37,11 @@ func (Server) Init(cmd *cobra.Command) error {
return err return err
} }
cmd.PersistentFlags().Bool("server.proxy", false, "trust reverse proxy headers")
if err := viper.BindPFlag("server.proxy", cmd.PersistentFlags().Lookup("server.proxy")); err != nil {
return err
}
cmd.PersistentFlags().String("server.static", "", "path to neko client files to serve") cmd.PersistentFlags().String("server.static", "", "path to neko client files to serve")
if err := viper.BindPFlag("server.static", cmd.PersistentFlags().Lookup("server.static")); err != nil { if err := viper.BindPFlag("server.static", cmd.PersistentFlags().Lookup("server.static")); err != nil {
return err return err
@ -56,7 +62,7 @@ func (Server) Init(cmd *cobra.Command) error {
return err return err
} }
cmd.PersistentFlags().StringSlice("server.cors", []string{"*"}, "list of allowed origins for CORS") cmd.PersistentFlags().StringSlice("server.cors", []string{}, "list of allowed origins for CORS, if empty CORS is disabled, if '*' is present all origins are allowed")
if err := viper.BindPFlag("server.cors", cmd.PersistentFlags().Lookup("server.cors")); err != nil { if err := viper.BindPFlag("server.cors", cmd.PersistentFlags().Lookup("server.cors")); err != nil {
return err return err
} }
@ -68,6 +74,7 @@ func (s *Server) Set() {
s.Cert = viper.GetString("server.cert") s.Cert = viper.GetString("server.cert")
s.Key = viper.GetString("server.key") s.Key = viper.GetString("server.key")
s.Bind = viper.GetString("server.bind") s.Bind = viper.GetString("server.bind")
s.Proxy = viper.GetBool("server.proxy")
s.Static = viper.GetString("server.static") s.Static = viper.GetString("server.static")
s.PathPrefix = path.Join("/", path.Clean(viper.GetString("server.path_prefix"))) s.PathPrefix = path.Join("/", path.Clean(viper.GetString("server.path_prefix")))
s.PProf = viper.GetBool("server.pprof") s.PProf = viper.GetBool("server.pprof")
@ -80,7 +87,17 @@ func (s *Server) Set() {
} }
} }
func (s *Server) HasCors() bool {
return len(s.CORS) > 0
}
func (s *Server) AllowOrigin(origin string) bool { func (s *Server) AllowOrigin(origin string) bool {
// if CORS is disabled, allow all origins
if len(s.CORS) == 0 {
return true
}
// if CORS is enabled, allow only origins in the list
in, _ := utils.ArrayIn(origin, s.CORS) in, _ := utils.ArrayIn(origin, s.CORS)
return in || s.CORS[0] == "*" return in || s.CORS[0] == "*"
} }

View File

@ -5,7 +5,6 @@ import (
"net/http" "net/http"
"os" "os"
"github.com/go-chi/cors"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -24,24 +23,31 @@ type HttpManagerCtx struct {
func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, config *config.Server) *HttpManagerCtx { func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, config *config.Server) *HttpManagerCtx {
logger := log.With().Str("module", "http").Logger() logger := log.With().Str("module", "http").Logger()
router := newRouter(logger) opts := []RouterOption{
router.UseBypass(cors.Handler(cors.Options{ WithRequestID(), // create a request id for each request
AllowOriginFunc: func(r *http.Request, origin string) bool { }
return config.AllowOrigin(origin)
}, // use real ip if behind proxy
AllowedMethods: []string{"GET", "POST", "DELETE", "OPTIONS"}, // before logger so it can log the real ip
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, if config.Proxy {
ExposedHeaders: []string{"Link"}, opts = append(opts, WithRealIP())
AllowCredentials: true, }
MaxAge: 300, // Maximum value not ignored by any of major browsers
})) opts = append(opts,
WithLogger(logger),
WithRecoverer(), // recover from panics without crashing server
)
if config.HasCors() {
opts = append(opts, WithCORS(config.AllowOrigin))
}
if config.PathPrefix != "/" { if config.PathPrefix != "/" {
router.UseBypass(func(h http.Handler) http.Handler { opts = append(opts, WithPathPrefix(config.PathPrefix))
return http.StripPrefix(config.PathPrefix, h)
})
} }
router := newRouter(opts...)
router.Route("/api", ApiManager.Route) router.Route("/api", ApiManager.Route)
router.Get("/api/ws", WebSocketManager.Upgrade(func(r *http.Request) bool { router.Get("/api/ws", WebSocketManager.Upgrade(func(r *http.Request) bool {

View File

@ -5,6 +5,7 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/go-chi/chi/middleware" "github.com/go-chi/chi/middleware"
"github.com/go-chi/cors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/demodesk/neko/pkg/auth" "github.com/demodesk/neko/pkg/auth"
@ -12,16 +13,65 @@ import (
"github.com/demodesk/neko/pkg/utils" "github.com/demodesk/neko/pkg/utils"
) )
type RouterOption func(*router)
func WithRequestID() RouterOption {
return func(r *router) {
r.chi.Use(middleware.RequestID)
}
}
func WithLogger(logger zerolog.Logger) RouterOption {
return func(r *router) {
r.chi.Use(middleware.RequestLogger(&logFormatter{logger}))
}
}
func WithRecoverer() RouterOption {
return func(r *router) {
r.chi.Use(middleware.Recoverer)
}
}
func WithCORS(allowOrigin func(origin string) bool) RouterOption {
return func(r *router) {
r.chi.Use(cors.Handler(cors.Options{
AllowOriginFunc: func(r *http.Request, origin string) bool {
return allowOrigin(origin)
},
AllowedMethods: []string{"GET", "POST", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
ExposedHeaders: []string{"Link"},
AllowCredentials: true,
MaxAge: 300, // Maximum value not ignored by any of major browsers
}))
}
}
func WithPathPrefix(prefix string) RouterOption {
return func(r *router) {
r.chi.Use(func(h http.Handler) http.Handler {
return http.StripPrefix(prefix, h)
})
}
}
func WithRealIP() RouterOption {
return func(r *router) {
r.chi.Use(middleware.RealIP)
}
}
type router struct { type router struct {
chi chi.Router chi chi.Router
} }
func newRouter(logger zerolog.Logger) *router { func newRouter(opts ...RouterOption) types.Router {
r := chi.NewRouter() r := &router{chi.NewRouter()}
r.Use(middleware.RequestID) // Create a request ID for each request for _, opt := range opts {
r.Use(middleware.RequestLogger(&logFormatter{logger})) opt(r)
r.Use(middleware.Recoverer) // Recover from panics without crashing server }
return &router{r} return r
} }
func (r *router) Group(fn func(types.Router)) { func (r *router) Group(fn func(types.Router)) {
@ -61,19 +111,10 @@ func (r *router) With(fn types.MiddlewareHandler) types.Router {
return &router{c} return &router{c}
} }
func (r *router) WithBypass(fn func(next http.Handler) http.Handler) types.Router {
c := r.chi.With(fn)
return &router{c}
}
func (r *router) Use(fn types.MiddlewareHandler) { func (r *router) Use(fn types.MiddlewareHandler) {
r.chi.Use(middlewareHandler(fn)) r.chi.Use(middlewareHandler(fn))
} }
func (r *router) UseBypass(fn func(next http.Handler) http.Handler) {
r.chi.Use(fn)
}
func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
r.chi.ServeHTTP(w, req) r.chi.ServeHTTP(w, req)
} }

View File

@ -17,9 +17,7 @@ type Router interface {
Patch(pattern string, fn RouterHandler) Patch(pattern string, fn RouterHandler)
Delete(pattern string, fn RouterHandler) Delete(pattern string, fn RouterHandler)
With(fn MiddlewareHandler) Router With(fn MiddlewareHandler) Router
WithBypass(fn func(next http.Handler) http.Handler) Router
Use(fn MiddlewareHandler) Use(fn MiddlewareHandler)
UseBypass(fn func(next http.Handler) http.Handler)
ServeHTTP(w http.ResponseWriter, req *http.Request) ServeHTTP(w http.ResponseWriter, req *http.Request)
} }