add CORS.

This commit is contained in:
Miroslav Šedivý 2021-01-23 18:18:14 +01:00
parent 79d67c4a09
commit d30d6deb79
6 changed files with 48 additions and 14 deletions

1
go.mod
View File

@ -5,6 +5,7 @@ go 1.13
require ( require (
github.com/fsnotify/fsnotify v1.4.9 // indirect github.com/fsnotify/fsnotify v1.4.9 // indirect
github.com/go-chi/chi v4.1.0+incompatible github.com/go-chi/chi v4.1.0+incompatible
github.com/go-chi/cors v1.1.1
github.com/golang/protobuf v1.3.5 // indirect github.com/golang/protobuf v1.3.5 // indirect
github.com/gorilla/websocket v1.4.2 github.com/gorilla/websocket v1.4.2
github.com/kataras/go-events v0.0.2 github.com/kataras/go-events v0.0.2

2
go.sum
View File

@ -50,6 +50,8 @@ github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeME
github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0=
github.com/go-chi/chi v4.1.0+incompatible h1:ETj3cggsVIY2Xao5ExCu6YhEh5MD6JTfcBzS37R260w= github.com/go-chi/chi v4.1.0+incompatible h1:ETj3cggsVIY2Xao5ExCu6YhEh5MD6JTfcBzS37R260w=
github.com/go-chi/chi v4.1.0+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= github.com/go-chi/chi v4.1.0+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ=
github.com/go-chi/cors v1.1.1 h1:eHuqxsIw89iXcWnWUN8R72JMibABJTN/4IOYI5WERvw=
github.com/go-chi/cors v1.1.1/go.mod h1:K2Yje0VW/SJzxiyMYu6iPQYa7hMjQX2i/F491VChg1I=
github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=

View File

@ -3,6 +3,8 @@ package config
import ( import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper" "github.com/spf13/viper"
"demodesk/neko/internal/utils"
) )
type Server struct { type Server struct {
@ -10,6 +12,7 @@ type Server struct {
Key string Key string
Bind string Bind string
Static string Static string
CORS []string
//Proxy bool //Proxy bool
} }
@ -34,6 +37,11 @@ func (Server) Init(cmd *cobra.Command) error {
return err return err
} }
cmd.PersistentFlags().StringSlice("cors", []string{"http://icms:3001"}, "list of allowed origins for CORS")
if err := viper.BindPFlag("cors", cmd.PersistentFlags().Lookup("cors")); err != nil {
return err
}
//cmd.PersistentFlags().Bool("proxy", false, "allow reverse proxies") //cmd.PersistentFlags().Bool("proxy", false, "allow reverse proxies")
//if err := viper.BindPFlag("proxy", cmd.PersistentFlags().Lookup("proxy")); err != nil { //if err := viper.BindPFlag("proxy", cmd.PersistentFlags().Lookup("proxy")); err != nil {
// return err // return err
@ -47,5 +55,17 @@ func (s *Server) Set() {
s.Key = viper.GetString("key") s.Key = viper.GetString("key")
s.Bind = viper.GetString("bind") s.Bind = viper.GetString("bind")
s.Static = viper.GetString("static") s.Static = viper.GetString("static")
s.CORS = viper.GetStringSlice("cors")
in, _ := utils.ArrayIn("*", s.CORS)
if len(s.CORS) == 0 || in {
s.CORS = []string{"*"}
}
//s.Proxy = viper.GetBool("proxy") //s.Proxy = viper.GetBool("proxy")
} }
func (s *Server) AllowOrigin(origin string) bool {
in, _ := utils.ArrayIn(origin, s.CORS)
return in || s.CORS[0] == "*"
}

View File

@ -7,6 +7,7 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/go-chi/chi/middleware" "github.com/go-chi/chi/middleware"
"github.com/go-chi/cors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -27,6 +28,16 @@ func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, c
router := chi.NewRouter() router := chi.NewRouter()
router.Use(middleware.Recoverer) // Recover from panics without crashing server router.Use(middleware.Recoverer) // Recover from panics without crashing server
router.Use(cors.Handler(cors.Options{
AllowOriginFunc: func(r *http.Request, origin string) bool {
return conf.AllowOrigin(origin)
},
AllowedMethods: []string{"GET", "POST", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
ExposedHeaders: []string{"Link"},
AllowCredentials: false,
MaxAge: 300, // Maximum value not ignored by any of major browsers
}))
router.Use(middleware.RequestID) // Create a request ID for each request router.Use(middleware.RequestID) // Create a request ID for each request
router.Use(Logger) // Log API request calls using custom logger function router.Use(Logger) // Log API request calls using custom logger function
@ -34,7 +45,9 @@ func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, c
router.Get("/ws", func(w http.ResponseWriter, r *http.Request) { router.Get("/ws", func(w http.ResponseWriter, r *http.Request) {
//nolint //nolint
WebSocketManager.Upgrade(w, r) WebSocketManager.Upgrade(w, r, func(r *http.Request) bool {
return conf.AllowOrigin(r.Header.Get("Origin"))
})
}) })
if conf.Static != "" { if conf.Static != "" {

View File

@ -4,6 +4,8 @@ import "net/http"
type HandlerFunction func(Session, []byte) bool type HandlerFunction func(Session, []byte) bool
type CheckOrigin func(r *http.Request) bool
type WebSocketPeer interface { type WebSocketPeer interface {
Send(v interface{}) error Send(v interface{}) error
Destroy() error Destroy() error
@ -13,5 +15,5 @@ type WebSocketManager interface {
Start() Start()
Shutdown() error Shutdown() error
AddHandler(handler HandlerFunction) AddHandler(handler HandlerFunction)
Upgrade(w http.ResponseWriter, r *http.Request) error Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin CheckOrigin) error
} }

View File

@ -27,11 +27,6 @@ func New(
logger: logger, logger: logger,
sessions: sessions, sessions: sessions,
desktop: desktop, desktop: desktop,
upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
},
handler: handler.New(sessions, desktop, capture, webrtc), handler: handler.New(sessions, desktop, capture, webrtc),
handlers: []types.HandlerFunction{}, handlers: []types.HandlerFunction{},
} }
@ -42,7 +37,6 @@ const pingPeriod = 60 * time.Second
type WebSocketManagerCtx struct { type WebSocketManagerCtx struct {
logger zerolog.Logger logger zerolog.Logger
upgrader websocket.Upgrader
sessions types.SessionManager sessions types.SessionManager
desktop types.DesktopManager desktop types.DesktopManager
handler *handler.MessageHandlerCtx handler *handler.MessageHandlerCtx
@ -145,10 +139,14 @@ func (ws *WebSocketManagerCtx) AddHandler(handler types.HandlerFunction) {
ws.handlers = append(ws.handlers, handler) ws.handlers = append(ws.handlers, handler)
} }
func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) error { func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin types.CheckOrigin) error {
ws.logger.Debug().Msg("attempting to upgrade connection") ws.logger.Debug().Msg("attempting to upgrade connection")
connection, err := ws.upgrader.Upgrade(w, r, nil) upgrader := websocket.Upgrader{
CheckOrigin: checkOrigin,
}
connection, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
ws.logger.Error().Err(err).Msg("failed to upgrade connection") ws.logger.Error().Err(err).Msg("failed to upgrade connection")
return err return err
@ -227,11 +225,9 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e
func (ws *WebSocketManagerCtx) handle(connection *websocket.Conn, session types.Session) { func (ws *WebSocketManagerCtx) handle(connection *websocket.Conn, session types.Session) {
bytes := make(chan []byte) bytes := make(chan []byte)
cancel := make(chan struct{}) cancel := make(chan struct{})
ticker := time.NewTicker(pingPeriod)
defer func() { ticker := time.NewTicker(pingPeriod)
ticker.Stop() defer ticker.Stop()
}()
go func() { go func() {
for { for {