mirror of
https://github.com/m1k1o/neko.git
synced 2024-07-24 14:40:50 +12:00
add CORS.
This commit is contained in:
parent
79d67c4a09
commit
d30d6deb79
1
go.mod
1
go.mod
@ -5,6 +5,7 @@ go 1.13
|
|||||||
require (
|
require (
|
||||||
github.com/fsnotify/fsnotify v1.4.9 // indirect
|
github.com/fsnotify/fsnotify v1.4.9 // indirect
|
||||||
github.com/go-chi/chi v4.1.0+incompatible
|
github.com/go-chi/chi v4.1.0+incompatible
|
||||||
|
github.com/go-chi/cors v1.1.1
|
||||||
github.com/golang/protobuf v1.3.5 // indirect
|
github.com/golang/protobuf v1.3.5 // indirect
|
||||||
github.com/gorilla/websocket v1.4.2
|
github.com/gorilla/websocket v1.4.2
|
||||||
github.com/kataras/go-events v0.0.2
|
github.com/kataras/go-events v0.0.2
|
||||||
|
2
go.sum
2
go.sum
@ -50,6 +50,8 @@ github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeME
|
|||||||
github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0=
|
github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0=
|
||||||
github.com/go-chi/chi v4.1.0+incompatible h1:ETj3cggsVIY2Xao5ExCu6YhEh5MD6JTfcBzS37R260w=
|
github.com/go-chi/chi v4.1.0+incompatible h1:ETj3cggsVIY2Xao5ExCu6YhEh5MD6JTfcBzS37R260w=
|
||||||
github.com/go-chi/chi v4.1.0+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ=
|
github.com/go-chi/chi v4.1.0+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ=
|
||||||
|
github.com/go-chi/cors v1.1.1 h1:eHuqxsIw89iXcWnWUN8R72JMibABJTN/4IOYI5WERvw=
|
||||||
|
github.com/go-chi/cors v1.1.1/go.mod h1:K2Yje0VW/SJzxiyMYu6iPQYa7hMjQX2i/F491VChg1I=
|
||||||
github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
|
github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
|
||||||
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||||
github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
|
github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
|
||||||
|
@ -3,6 +3,8 @@ package config
|
|||||||
import (
|
import (
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
|
|
||||||
|
"demodesk/neko/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
@ -10,6 +12,7 @@ type Server struct {
|
|||||||
Key string
|
Key string
|
||||||
Bind string
|
Bind string
|
||||||
Static string
|
Static string
|
||||||
|
CORS []string
|
||||||
//Proxy bool
|
//Proxy bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -34,6 +37,11 @@ func (Server) Init(cmd *cobra.Command) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cmd.PersistentFlags().StringSlice("cors", []string{"http://icms:3001"}, "list of allowed origins for CORS")
|
||||||
|
if err := viper.BindPFlag("cors", cmd.PersistentFlags().Lookup("cors")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
//cmd.PersistentFlags().Bool("proxy", false, "allow reverse proxies")
|
//cmd.PersistentFlags().Bool("proxy", false, "allow reverse proxies")
|
||||||
//if err := viper.BindPFlag("proxy", cmd.PersistentFlags().Lookup("proxy")); err != nil {
|
//if err := viper.BindPFlag("proxy", cmd.PersistentFlags().Lookup("proxy")); err != nil {
|
||||||
// return err
|
// return err
|
||||||
@ -47,5 +55,17 @@ func (s *Server) Set() {
|
|||||||
s.Key = viper.GetString("key")
|
s.Key = viper.GetString("key")
|
||||||
s.Bind = viper.GetString("bind")
|
s.Bind = viper.GetString("bind")
|
||||||
s.Static = viper.GetString("static")
|
s.Static = viper.GetString("static")
|
||||||
|
|
||||||
|
s.CORS = viper.GetStringSlice("cors")
|
||||||
|
in, _ := utils.ArrayIn("*", s.CORS)
|
||||||
|
if len(s.CORS) == 0 || in {
|
||||||
|
s.CORS = []string{"*"}
|
||||||
|
}
|
||||||
|
|
||||||
//s.Proxy = viper.GetBool("proxy")
|
//s.Proxy = viper.GetBool("proxy")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) AllowOrigin(origin string) bool {
|
||||||
|
in, _ := utils.ArrayIn(origin, s.CORS)
|
||||||
|
return in || s.CORS[0] == "*"
|
||||||
|
}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/go-chi/chi/middleware"
|
"github.com/go-chi/chi/middleware"
|
||||||
|
"github.com/go-chi/cors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
@ -27,6 +28,16 @@ func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, c
|
|||||||
|
|
||||||
router := chi.NewRouter()
|
router := chi.NewRouter()
|
||||||
router.Use(middleware.Recoverer) // Recover from panics without crashing server
|
router.Use(middleware.Recoverer) // Recover from panics without crashing server
|
||||||
|
router.Use(cors.Handler(cors.Options{
|
||||||
|
AllowOriginFunc: func(r *http.Request, origin string) bool {
|
||||||
|
return conf.AllowOrigin(origin)
|
||||||
|
},
|
||||||
|
AllowedMethods: []string{"GET", "POST", "DELETE", "OPTIONS"},
|
||||||
|
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
|
||||||
|
ExposedHeaders: []string{"Link"},
|
||||||
|
AllowCredentials: false,
|
||||||
|
MaxAge: 300, // Maximum value not ignored by any of major browsers
|
||||||
|
}))
|
||||||
router.Use(middleware.RequestID) // Create a request ID for each request
|
router.Use(middleware.RequestID) // Create a request ID for each request
|
||||||
router.Use(Logger) // Log API request calls using custom logger function
|
router.Use(Logger) // Log API request calls using custom logger function
|
||||||
|
|
||||||
@ -34,7 +45,9 @@ func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, c
|
|||||||
|
|
||||||
router.Get("/ws", func(w http.ResponseWriter, r *http.Request) {
|
router.Get("/ws", func(w http.ResponseWriter, r *http.Request) {
|
||||||
//nolint
|
//nolint
|
||||||
WebSocketManager.Upgrade(w, r)
|
WebSocketManager.Upgrade(w, r, func(r *http.Request) bool {
|
||||||
|
return conf.AllowOrigin(r.Header.Get("Origin"))
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
if conf.Static != "" {
|
if conf.Static != "" {
|
||||||
|
@ -4,6 +4,8 @@ import "net/http"
|
|||||||
|
|
||||||
type HandlerFunction func(Session, []byte) bool
|
type HandlerFunction func(Session, []byte) bool
|
||||||
|
|
||||||
|
type CheckOrigin func(r *http.Request) bool
|
||||||
|
|
||||||
type WebSocketPeer interface {
|
type WebSocketPeer interface {
|
||||||
Send(v interface{}) error
|
Send(v interface{}) error
|
||||||
Destroy() error
|
Destroy() error
|
||||||
@ -13,5 +15,5 @@ type WebSocketManager interface {
|
|||||||
Start()
|
Start()
|
||||||
Shutdown() error
|
Shutdown() error
|
||||||
AddHandler(handler HandlerFunction)
|
AddHandler(handler HandlerFunction)
|
||||||
Upgrade(w http.ResponseWriter, r *http.Request) error
|
Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin CheckOrigin) error
|
||||||
}
|
}
|
||||||
|
@ -27,11 +27,6 @@ func New(
|
|||||||
logger: logger,
|
logger: logger,
|
||||||
sessions: sessions,
|
sessions: sessions,
|
||||||
desktop: desktop,
|
desktop: desktop,
|
||||||
upgrader: websocket.Upgrader{
|
|
||||||
CheckOrigin: func(r *http.Request) bool {
|
|
||||||
return true
|
|
||||||
},
|
|
||||||
},
|
|
||||||
handler: handler.New(sessions, desktop, capture, webrtc),
|
handler: handler.New(sessions, desktop, capture, webrtc),
|
||||||
handlers: []types.HandlerFunction{},
|
handlers: []types.HandlerFunction{},
|
||||||
}
|
}
|
||||||
@ -42,7 +37,6 @@ const pingPeriod = 60 * time.Second
|
|||||||
|
|
||||||
type WebSocketManagerCtx struct {
|
type WebSocketManagerCtx struct {
|
||||||
logger zerolog.Logger
|
logger zerolog.Logger
|
||||||
upgrader websocket.Upgrader
|
|
||||||
sessions types.SessionManager
|
sessions types.SessionManager
|
||||||
desktop types.DesktopManager
|
desktop types.DesktopManager
|
||||||
handler *handler.MessageHandlerCtx
|
handler *handler.MessageHandlerCtx
|
||||||
@ -145,10 +139,14 @@ func (ws *WebSocketManagerCtx) AddHandler(handler types.HandlerFunction) {
|
|||||||
ws.handlers = append(ws.handlers, handler)
|
ws.handlers = append(ws.handlers, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) error {
|
func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request, checkOrigin types.CheckOrigin) error {
|
||||||
ws.logger.Debug().Msg("attempting to upgrade connection")
|
ws.logger.Debug().Msg("attempting to upgrade connection")
|
||||||
|
|
||||||
connection, err := ws.upgrader.Upgrade(w, r, nil)
|
upgrader := websocket.Upgrader{
|
||||||
|
CheckOrigin: checkOrigin,
|
||||||
|
}
|
||||||
|
|
||||||
|
connection, err := upgrader.Upgrade(w, r, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ws.logger.Error().Err(err).Msg("failed to upgrade connection")
|
ws.logger.Error().Err(err).Msg("failed to upgrade connection")
|
||||||
return err
|
return err
|
||||||
@ -227,11 +225,9 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e
|
|||||||
func (ws *WebSocketManagerCtx) handle(connection *websocket.Conn, session types.Session) {
|
func (ws *WebSocketManagerCtx) handle(connection *websocket.Conn, session types.Session) {
|
||||||
bytes := make(chan []byte)
|
bytes := make(chan []byte)
|
||||||
cancel := make(chan struct{})
|
cancel := make(chan struct{})
|
||||||
ticker := time.NewTicker(pingPeriod)
|
|
||||||
|
|
||||||
defer func() {
|
ticker := time.NewTicker(pingPeriod)
|
||||||
ticker.Stop()
|
defer ticker.Stop()
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
|
Loading…
Reference in New Issue
Block a user