neko/internal/api/utils/auth.go

66 lines
1.6 KiB
Go
Raw Normal View History

2020-11-01 00:27:55 +13:00
package utils
import (
"context"
"fmt"
"net/http"
"strings"
"github.com/go-chi/render"
"github.com/dgrijalva/jwt-go"
)
2020-11-01 09:58:57 +13:00
// key is the unexported type used for this package's context keys. Because
// the type is private, values stored under it cannot collide with context
// values set by other packages.
type key int

// keyPrincipalID is the context key under which AuthMiddleware stores the
// verified JWT claims (jwt.MapClaims) for downstream handlers.
const keyPrincipalID key = 0
2020-11-01 00:27:55 +13:00
func GetUserName(r *http.Request) interface{} {
2020-11-01 09:58:57 +13:00
props, _ := r.Context().Value(keyPrincipalID).(jwt.MapClaims)
2020-11-01 00:27:55 +13:00
return props["user_name"]
}
// HttpMiddleware is an alias for the conventional net/http middleware shape:
// a function that wraps one handler and returns another.
type HttpMiddleware = func(http.Handler) http.Handler
// AuthMiddleware wraps next so that it is only reached by requests carrying a
// valid HMAC-signed JWT in an "Authorization: Bearer <token>" header. The
// scheme name is matched case-insensitively.
//
// The token is checked against each secret in jwtSecrets in order, and the
// first secret that verifies it is accepted. This lets several secrets be
// valid at once. On success, the token's jwt.MapClaims are stored in the
// request context under keyPrincipalID, where GetUserName can read them.
//
// Every failure responds with a 401 rendered via ErrMessage, and next is not
// called:
//   - missing header, non-Bearer scheme, or empty token: "Malformed JWT token."
//   - no secret verifies the token, including when jwtSecrets is empty:
//     "Invalid JWT token."
//   - claims are not jwt.MapClaims or the token is not valid: "Unauthorized."
func AuthMiddleware(next http.Handler, jwtSecrets ...[]byte) http.Handler {
	const bearerPrefix = "Bearer "

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Render errors are ignored throughout: the response is already being
		// written, so there is no better channel to report them to the client.

		authHeader := r.Header.Get("Authorization")
		// Require the header to actually start with the Bearer scheme.
		// Splitting on "Bearer " anywhere in the value would accept
		// "xyzBearer tok". The scheme itself is case-insensitive (RFC 7235).
		if len(authHeader) < len(bearerPrefix) ||
			!strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
			_ = render.Render(w, r, ErrMessage(401, "Malformed JWT token."))
			return
		}
		jwtToken := strings.TrimSpace(authHeader[len(bearerPrefix):])
		if jwtToken == "" {
			_ = render.Render(w, r, ErrMessage(401, "Malformed JWT token."))
			return
		}

		var jwtVerified *jwt.Token
		var err error
		for _, jwtSecret := range jwtSecrets {
			jwtVerified, err = jwt.Parse(jwtToken, func(token *jwt.Token) (interface{}, error) {
				// Only accept HMAC algorithms, so a token cannot choose
				// "none" or an asymmetric algorithm to bypass the shared secret.
				if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
					return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
				}
				return jwtSecret, nil
			})
			if err == nil {
				break
			}
		}
		// With no secrets configured the loop never runs: err stays nil and
		// jwtVerified stays nil. Treat that as a verification failure instead
		// of dereferencing a nil token below.
		if err != nil || jwtVerified == nil {
			_ = render.Render(w, r, ErrMessage(401, "Invalid JWT token."))
			return
		}

		claims, ok := jwtVerified.Claims.(jwt.MapClaims)
		if !ok || !jwtVerified.Valid {
			_ = render.Render(w, r, ErrMessage(401, "Unauthorized."))
			return
		}

		// Downstream handlers read the claims with GetUserName, or directly via
		// r.Context().Value(keyPrincipalID).(jwt.MapClaims).
		ctx := context.WithValue(r.Context(), keyPrincipalID, claims)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}