From d30d6deb7912d338ed9da1465bac211d2508fd6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= Date: Sat, 23 Jan 2021 18:18:14 +0100 Subject: [PATCH] add CORS. --- go.mod | 1 + go.sum | 2 ++ internal/config/server.go | 20 ++++++++++++++++++++ internal/http/manager.go | 15 ++++++++++++++- internal/types/websocket.go | 4 +++- internal/websocket/manager.go | 20 ++++++++------------ 6 files changed, 48 insertions(+), 14 deletions(-) diff --git a/go.mod b/go.mod index cd56a4d7..73ad1a85 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.13 require ( github.com/fsnotify/fsnotify v1.4.9 // indirect github.com/go-chi/chi v4.1.0+incompatible + github.com/go-chi/cors v1.1.1 github.com/golang/protobuf v1.3.5 // indirect github.com/gorilla/websocket v1.4.2 github.com/kataras/go-events v0.0.2 diff --git a/go.sum b/go.sum index 08536a1d..47791a58 100644 --- a/go.sum +++ b/go.sum @@ -50,6 +50,8 @@ github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeME github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= github.com/go-chi/chi v4.1.0+incompatible h1:ETj3cggsVIY2Xao5ExCu6YhEh5MD6JTfcBzS37R260w= github.com/go-chi/chi v4.1.0+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= +github.com/go-chi/cors v1.1.1 h1:eHuqxsIw89iXcWnWUN8R72JMibABJTN/4IOYI5WERvw= +github.com/go-chi/cors v1.1.1/go.mod h1:K2Yje0VW/SJzxiyMYu6iPQYa7hMjQX2i/F491VChg1I= github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= diff --git a/internal/config/server.go b/internal/config/server.go index 0153dce4..58848b39 100644 --- a/internal/config/server.go +++ b/internal/config/server.go @@ -3,6 +3,8 @@ package config import ( "github.com/spf13/cobra" "github.com/spf13/viper" + + "demodesk/neko/internal/utils" ) type Server struct { @@ -10,6 +12,7 @@ type Server struct { Key string Bind string Static string + CORS []string //Proxy bool } @@ -34,6 +37,11 @@ func (Server) Init(cmd *cobra.Command) error { return err } + cmd.PersistentFlags().StringSlice("cors", []string{"http://icms:3001"}, "list of allowed origins for CORS") + if err := viper.BindPFlag("cors", cmd.PersistentFlags().Lookup("cors")); err != nil { + return err + } + //cmd.PersistentFlags().Bool("proxy", false, "allow reverse proxies") //if err := viper.BindPFlag("proxy", cmd.PersistentFlags().Lookup("proxy")); err != nil { // return err @@ -47,5 +55,17 @@ func (s *Server) Set() { s.Key = viper.GetString("key") s.Bind = viper.GetString("bind") s.Static = viper.GetString("static") + + s.CORS = viper.GetStringSlice("cors") + in, _ := utils.ArrayIn("*", s.CORS) + if len(s.CORS) == 0 || in { + s.CORS = []string{"*"} + } + //s.Proxy = viper.GetBool("proxy") } + +func (s *Server) AllowOrigin(origin string) bool { + in, _ := utils.ArrayIn(origin, s.CORS) + return in || s.CORS[0] == "*" +} diff --git a/internal/http/manager.go b/internal/http/manager.go index 08d33eb6..3883507a 100644 --- a/internal/http/manager.go +++ b/internal/http/manager.go @@ -7,6 +7,7 @@ import ( "github.com/go-chi/chi" "github.com/go-chi/chi/middleware" + "github.com/go-chi/cors" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -27,6 +28,16 @@ func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, c router := chi.NewRouter() router.Use(middleware.Recoverer) // Recover from panics without crashing server + router.Use(cors.Handler(cors.Options{ + AllowOriginFunc: func(r *http.Request, origin string) bool { + return conf.AllowOrigin(origin) + }, + AllowedMethods: []string{"GET", "POST", "DELETE", "OPTIONS"}, + AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, + ExposedHeaders: []string{"Link"}, + AllowCredentials: false, + MaxAge: 300, // Maximum value not ignored by any of major browsers + })) router.Use(middleware.RequestID) // Create a request ID for each request router.Use(Logger) // Log API request calls using custom logger function @@ -34,7 +45,9 @@ func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, c router.Get("/ws", func(w http.ResponseWriter, r *http.Request) { //nolint - WebSocketManager.Upgrade(w, r) + WebSocketManager.Upgrade(w, r, func(r *http.Request) bool { + return conf.AllowOrigin(r.Header.Get("Origin")) + }) }) if conf.Static != "" { diff --git a/internal/types/websocket.go b/internal/types/websocket.go index 836e3d31..154d5b74 100644 --- a/internal/types/websocket.go +++ b/internal/types/websocket.go @@ -4,6 +4,8 @@ import "net/http" type HandlerFunction func(Session, []byte) bool +type CheckOrigin func(r *http.Request) bool + type WebSocketPeer interface { Send(v interface{}) error Destroy() error @@ -13,5 +15,5 @@ type WebSocketManager interface { Start() Shutdown() error AddHandler(handler HandlerFunction) - Upgrade(w http.ResponseWriter, r *http.Request) error + Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin CheckOrigin) error } diff --git a/internal/websocket/manager.go b/internal/websocket/manager.go index 511c9450..a63e181a 100644 --- a/internal/websocket/manager.go +++ b/internal/websocket/manager.go @@ -27,11 +27,6 @@ func New( logger: logger, sessions: sessions, desktop: desktop, - upgrader: websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, - }, handler: handler.New(sessions, desktop, capture, webrtc), handlers: []types.HandlerFunction{}, } @@ -42,7 +37,6 @@ const pingPeriod = 60 * time.Second type WebSocketManagerCtx struct { logger zerolog.Logger - upgrader websocket.Upgrader sessions types.SessionManager desktop types.DesktopManager handler *handler.MessageHandlerCtx @@ -145,10 +139,14 @@ func (ws *WebSocketManagerCtx) AddHandler(handler types.HandlerFunction) { ws.handlers = append(ws.handlers, handler) } -func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) error { +func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin types.CheckOrigin) error { ws.logger.Debug().Msg("attempting to upgrade connection") - connection, err := ws.upgrader.Upgrade(w, r, nil) + upgrader := websocket.Upgrader{ + CheckOrigin: checkOrigin, + } + + connection, err := upgrader.Upgrade(w, r, nil) if err != nil { ws.logger.Error().Err(err).Msg("failed to upgrade connection") return err @@ -227,11 +225,9 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e func (ws *WebSocketManagerCtx) handle(connection *websocket.Conn, session types.Session) { bytes := make(chan []byte) cancel := make(chan struct{}) - ticker := time.NewTicker(pingPeriod) - defer func() { - ticker.Stop() - }() + ticker := time.NewTicker(pingPeriod) + defer ticker.Stop() go func() { for {