2020-11-01 00:27:55 +13:00
|
|
|
package utils
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"fmt"
|
|
|
|
"net/http"
|
|
|
|
"strings"
|
|
|
|
|
|
|
|
"github.com/go-chi/render"
|
|
|
|
"github.com/dgrijalva/jwt-go"
|
|
|
|
)
|
|
|
|
|
2020-11-01 09:58:57 +13:00
|
|
|
// key is the type of context keys owned by this package. Because it is
// unexported, values of this type cannot collide with context keys defined
// by any other package.
type key int

// keyPrincipalID is the context key under which AuthMiddleware stores the
// verified JWT claims of the current request.
const keyPrincipalID key = 0
|
|
|
|
|
2020-11-01 00:27:55 +13:00
|
|
|
func GetUserName(r *http.Request) interface{} {
|
2020-11-01 09:58:57 +13:00
|
|
|
props, _ := r.Context().Value(keyPrincipalID).(jwt.MapClaims)
|
2020-11-01 00:27:55 +13:00
|
|
|
return props["user_name"]
|
|
|
|
}
|
|
|
|
|
|
|
|
// HttpMiddleware is the conventional net/http middleware shape: a function
// that wraps a handler and returns the wrapping handler. It is declared as a
// type alias, so it is the very same type as func(http.Handler) http.Handler.
type HttpMiddleware = func(http.Handler) http.Handler
|
|
|
|
|
|
|
|
func AuthMiddleware(next http.Handler, jwtSecrets ...[]byte) http.Handler {
|
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
authHeader := strings.Split(r.Header.Get("Authorization"), "Bearer ")
|
|
|
|
if len(authHeader) != 2 {
|
2020-11-01 09:58:57 +13:00
|
|
|
_ = render.Render(w, r, ErrMessage(401, "Malformed JWT token."))
|
2020-11-01 00:27:55 +13:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
jwtToken := authHeader[1]
|
|
|
|
var jwtVerified *jwt.Token
|
|
|
|
var err error
|
|
|
|
for _, jwtSecret := range jwtSecrets {
|
|
|
|
jwtVerified, err = jwt.Parse(jwtToken, func(token *jwt.Token) (interface{}, error) {
|
|
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
|
|
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
|
|
|
|
}
|
|
|
|
|
|
|
|
return jwtSecret, nil
|
|
|
|
})
|
|
|
|
|
|
|
|
if err == nil {
|
|
|
|
break
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if err != nil {
|
2020-11-01 09:58:57 +13:00
|
|
|
_ = render.Render(w, r, ErrMessage(401, "Invalid JWT token."))
|
2020-11-01 00:27:55 +13:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
if claims, ok := jwtVerified.Claims.(jwt.MapClaims); ok && jwtVerified.Valid {
|
2020-11-01 09:58:57 +13:00
|
|
|
ctx := context.WithValue(r.Context(), keyPrincipalID, claims)
|
2020-11-01 00:27:55 +13:00
|
|
|
// Access context values in handlers like this
|
|
|
|
// props, _ := r.Context().Value("props").(jwt.MapClaims)
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
|
|
} else {
|
2020-11-01 09:58:57 +13:00
|
|
|
_ = render.Render(w, r, ErrMessage(401, "Unauthorized."))
|
2020-11-01 00:27:55 +13:00
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|