add CORS.

This commit is contained in:
Miroslav Šedivý
2021-01-23 18:18:14 +01:00
parent 79d67c4a09
commit d30d6deb79
6 changed files with 48 additions and 14 deletions

View File

@ -3,6 +3,8 @@ package config
import (
"github.com/spf13/cobra"
"github.com/spf13/viper"
"demodesk/neko/internal/utils"
)
type Server struct {
@ -10,6 +12,7 @@ type Server struct {
Key string
Bind string
Static string
CORS []string
//Proxy bool
}
@ -34,6 +37,11 @@ func (Server) Init(cmd *cobra.Command) error {
return err
}
cmd.PersistentFlags().StringSlice("cors", []string{"http://icms:3001"}, "list of allowed origins for CORS")
if err := viper.BindPFlag("cors", cmd.PersistentFlags().Lookup("cors")); err != nil {
return err
}
//cmd.PersistentFlags().Bool("proxy", false, "allow reverse proxies")
//if err := viper.BindPFlag("proxy", cmd.PersistentFlags().Lookup("proxy")); err != nil {
// return err
@ -47,5 +55,17 @@ func (s *Server) Set() {
s.Key = viper.GetString("key")
s.Bind = viper.GetString("bind")
s.Static = viper.GetString("static")
s.CORS = viper.GetStringSlice("cors")
in, _ := utils.ArrayIn("*", s.CORS)
if len(s.CORS) == 0 || in {
s.CORS = []string{"*"}
}
//s.Proxy = viper.GetBool("proxy")
}
func (s *Server) AllowOrigin(origin string) bool {
in, _ := utils.ArrayIn(origin, s.CORS)
return in || s.CORS[0] == "*"
}

View File

@ -7,6 +7,7 @@ import (
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"github.com/go-chi/cors"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
@ -27,6 +28,16 @@ func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, c
router := chi.NewRouter()
router.Use(middleware.Recoverer) // Recover from panics without crashing server
router.Use(cors.Handler(cors.Options{
AllowOriginFunc: func(r *http.Request, origin string) bool {
return conf.AllowOrigin(origin)
},
AllowedMethods: []string{"GET", "POST", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
ExposedHeaders: []string{"Link"},
AllowCredentials: false,
MaxAge: 300, // Maximum value not ignored by any of major browsers
}))
router.Use(middleware.RequestID) // Create a request ID for each request
router.Use(Logger) // Log API request calls using custom logger function
@ -34,7 +45,9 @@ func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, c
router.Get("/ws", func(w http.ResponseWriter, r *http.Request) {
//nolint
WebSocketManager.Upgrade(w, r)
WebSocketManager.Upgrade(w, r, func(r *http.Request) bool {
return conf.AllowOrigin(r.Header.Get("Origin"))
})
})
if conf.Static != "" {

View File

@ -4,6 +4,8 @@ import "net/http"
type HandlerFunction func(Session, []byte) bool
type CheckOrigin func(r *http.Request) bool
type WebSocketPeer interface {
Send(v interface{}) error
Destroy() error
@ -13,5 +15,5 @@ type WebSocketManager interface {
Start()
Shutdown() error
AddHandler(handler HandlerFunction)
Upgrade(w http.ResponseWriter, r *http.Request) error
Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin CheckOrigin) error
}

View File

@ -27,11 +27,6 @@ func New(
logger: logger,
sessions: sessions,
desktop: desktop,
upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
},
handler: handler.New(sessions, desktop, capture, webrtc),
handlers: []types.HandlerFunction{},
}
@ -42,7 +37,6 @@ const pingPeriod = 60 * time.Second
type WebSocketManagerCtx struct {
logger zerolog.Logger
upgrader websocket.Upgrader
sessions types.SessionManager
desktop types.DesktopManager
handler *handler.MessageHandlerCtx
@ -145,10 +139,14 @@ func (ws *WebSocketManagerCtx) AddHandler(handler types.HandlerFunction) {
ws.handlers = append(ws.handlers, handler)
}
func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) error {
func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin types.CheckOrigin) error {
ws.logger.Debug().Msg("attempting to upgrade connection")
connection, err := ws.upgrader.Upgrade(w, r, nil)
upgrader := websocket.Upgrader{
CheckOrigin: checkOrigin,
}
connection, err := upgrader.Upgrade(w, r, nil)
if err != nil {
ws.logger.Error().Err(err).Msg("failed to upgrade connection")
return err
@ -227,11 +225,9 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e
func (ws *WebSocketManagerCtx) handle(connection *websocket.Conn, session types.Session) {
bytes := make(chan []byte)
cancel := make(chan struct{})
ticker := time.NewTicker(pingPeriod)
defer func() {
ticker.Stop()
}()
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
go func() {
for {