2020-10-22 16:54:50 +02:00
package config
import (
2023-04-28 22:19:52 +02:00
"path"
2020-10-22 16:54:50 +02:00
"github.com/spf13/cobra"
"github.com/spf13/viper"
2021-01-23 18:18:14 +01:00
2022-07-14 00:58:22 +02:00
"github.com/demodesk/neko/pkg/utils"
2020-10-22 16:54:50 +02:00
)
// Server holds the HTTP server settings read from the "server.*" configuration
// keys registered by Init and loaded by Set.
type Server struct {
	// Cert and Key are paths to the SSL certificate and key used to secure the
	// server; Bind is the address/port/socket to serve on.
	Cert, Key, Bind string

	// Proxy controls whether reverse proxy headers are trusted.
	Proxy bool

	// Static is the path to the client files to serve; PathPrefix is the path
	// prefix for HTTP requests (Set normalizes it to an absolute, cleaned path).
	Static, PathPrefix string

	// PProf enables the /debug/pprof endpoint; Metrics enables prometheus
	// metrics at /metrics.
	PProf, Metrics bool

	// CORS lists the allowed origins; "*" allows every origin. An empty list is
	// treated by HasCors and AllowOrigin as CORS being disabled.
	CORS []string
}
func ( Server ) Init ( cmd * cobra . Command ) error {
2021-03-16 15:24:58 +01:00
cmd . PersistentFlags ( ) . String ( "server.bind" , "127.0.0.1:8080" , "address/port/socket to serve neko" )
if err := viper . BindPFlag ( "server.bind" , cmd . PersistentFlags ( ) . Lookup ( "server.bind" ) ) ; err != nil {
2020-10-22 16:54:50 +02:00
return err
}
2021-03-16 15:24:58 +01:00
cmd . PersistentFlags ( ) . String ( "server.cert" , "" , "path to the SSL cert used to secure the neko server" )
if err := viper . BindPFlag ( "server.cert" , cmd . PersistentFlags ( ) . Lookup ( "server.cert" ) ) ; err != nil {
2020-10-22 16:54:50 +02:00
return err
}
2021-03-16 15:24:58 +01:00
cmd . PersistentFlags ( ) . String ( "server.key" , "" , "path to the SSL key used to secure the neko server" )
if err := viper . BindPFlag ( "server.key" , cmd . PersistentFlags ( ) . Lookup ( "server.key" ) ) ; err != nil {
2020-10-22 16:54:50 +02:00
return err
}
2023-11-19 14:35:19 +01:00
cmd . PersistentFlags ( ) . Bool ( "server.proxy" , false , "trust reverse proxy headers" )
if err := viper . BindPFlag ( "server.proxy" , cmd . PersistentFlags ( ) . Lookup ( "server.proxy" ) ) ; err != nil {
return err
}
2021-03-16 15:24:58 +01:00
cmd . PersistentFlags ( ) . String ( "server.static" , "" , "path to neko client files to serve" )
if err := viper . BindPFlag ( "server.static" , cmd . PersistentFlags ( ) . Lookup ( "server.static" ) ) ; err != nil {
2020-10-22 16:54:50 +02:00
return err
}
2023-04-28 22:19:52 +02:00
cmd . PersistentFlags ( ) . String ( "server.path_prefix" , "/" , "path prefix for HTTP requests" )
if err := viper . BindPFlag ( "server.path_prefix" , cmd . PersistentFlags ( ) . Lookup ( "server.path_prefix" ) ) ; err != nil {
return err
}
2022-02-12 20:22:50 +01:00
cmd . PersistentFlags ( ) . Bool ( "server.pprof" , false , "enable pprof endpoint available at /debug/pprof" )
if err := viper . BindPFlag ( "server.pprof" , cmd . PersistentFlags ( ) . Lookup ( "server.pprof" ) ) ; err != nil {
return err
}
2022-07-04 18:26:29 +02:00
cmd . PersistentFlags ( ) . Bool ( "server.metrics" , true , "enable prometheus metrics available at /metrics" )
if err := viper . BindPFlag ( "server.metrics" , cmd . PersistentFlags ( ) . Lookup ( "server.metrics" ) ) ; err != nil {
return err
}
2023-11-19 14:35:19 +01:00
cmd . PersistentFlags ( ) . StringSlice ( "server.cors" , [ ] string { } , "list of allowed origins for CORS, if empty CORS is disabled, if '*' is present all origins are allowed" )
2021-03-16 15:24:58 +01:00
if err := viper . BindPFlag ( "server.cors" , cmd . PersistentFlags ( ) . Lookup ( "server.cors" ) ) ; err != nil {
2021-01-23 18:18:14 +01:00
return err
}
2020-10-22 16:54:50 +02:00
return nil
}
// Set populates s from the values viper holds for the "server.*" keys that
// Init registers.
//
// PathPrefix is normalized to an absolute, cleaned path (e.g. "" -> "/",
// "api/" -> "/api").
//
// CORS is read as a list of allowed origins. If the "*" wildcard appears
// anywhere in it, the list is collapsed to just []string{"*"}. An empty list
// stays empty, which disables CORS as documented by the "server.cors" flag.
func (s *Server) Set() {
	s.Cert = viper.GetString("server.cert")
	s.Key = viper.GetString("server.key")
	s.Bind = viper.GetString("server.bind")
	s.Proxy = viper.GetBool("server.proxy")
	s.Static = viper.GetString("server.static")
	s.PathPrefix = path.Join("/", path.Clean(viper.GetString("server.path_prefix")))
	s.PProf = viper.GetBool("server.pprof")
	s.Metrics = viper.GetBool("server.metrics")

	s.CORS = viper.GetStringSlice("server.cors")
	// Only a present wildcard is normalized. An empty list must NOT be turned
	// into {"*"}: that would make HasCors always true and silently enable
	// allow-all CORS, contradicting the flag's "if empty CORS is disabled".
	if in, _ := utils.ArrayIn("*", s.CORS); in {
		s.CORS = []string{"*"}
	}
}
2021-01-23 18:18:14 +01:00
2023-11-19 14:35:19 +01:00
// HasCors reports whether any allowed origin is configured. A nil or empty
// CORS list means CORS is disabled.
func (s *Server) HasCors() bool {
	return len(s.CORS) != 0
}
2021-01-23 18:18:14 +01:00
func ( s * Server ) AllowOrigin ( origin string ) bool {
2023-11-19 14:35:19 +01:00
// if CORS is disabled, allow all origins
if len ( s . CORS ) == 0 {
return true
}
// if CORS is enabled, allow only origins in the list
2021-01-23 18:18:14 +01:00
in , _ := utils . ArrayIn ( origin , s . CORS )
return in || s . CORS [ 0 ] == "*"
}