diff --git a/auth/jwt/middleware.go b/auth/jwt/middleware.go index b5ccf0b..84863bf 100644 --- a/auth/jwt/middleware.go +++ b/auth/jwt/middleware.go @@ -47,14 +47,14 @@ // Claims is a map of arbitrary claim data. type Claims map[string]interface{} -// NewSigner creates a new JWT token generating middleware, specifying key ID, -// signing string, signing method and the claims you would like it to contain. +// NewSignerWithClaims creates a new JWT token generating middleware, specifying key ID, +// signing string, signing method and the jwt.Claims you would like it to contain. // Tokens are signed with a Key ID header (kid) which is useful for determining // the key to use for parsing. Particularly useful for clients. -func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims Claims) endpoint.Middleware { +func NewSignerWithClaims(kid string, key []byte, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { - token := jwt.NewWithClaims(method, jwt.MapClaims(claims)) + token := jwt.NewWithClaims(method, claims) token.Header["kid"] = kid // Sign and get the complete encoded token as a string using the secret @@ -69,11 +69,18 @@ } } -// NewParser creates a new JWT token parsing middleware, specifying a -// jwt.Keyfunc interface and the signing method. NewParser adds the resulting -// claims to endpoint context or returns error on invalid token. Particularly -// useful for servers. -func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod) endpoint.Middleware { +// NewSigner creates a new JWT token generating middleware, specifying key ID, +// signing string, signing method and the claims you would like it to contain. +// It passes these values onto NewSignerWithClaims to handle the signing process. +func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims Claims) endpoint.Middleware { + return NewSignerWithClaims(kid, key, method, jwt.MapClaims(claims)) +} + +// NewParserWithClaims creates a new JWT token parsing middleware, specifying a +// jwt.Keyfunc interface, the signing method as well as the claims to parse into. +// NewParserWithClaims adds the resulting claims to endpoint context or returns error on invalid token. +// Particularly useful for servers. +func NewParserWithClaims(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { // tokenString is stored in the context from the transport handlers. @@ -88,7 +95,7 @@ // of the token to identify which key to use, but the parsed token // (head and claims) is provided to the callback, providing // flexibility. - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { // Don't forget to validate the alg is what you expect: if token.Method != method { return nil, ErrUnexpectedSigningMethod @@ -119,11 +126,20 @@ return nil, ErrTokenInvalid } - if claims, ok := token.Claims.(jwt.MapClaims); ok { - ctx = context.WithValue(ctx, JWTClaimsContextKey, Claims(claims)) + if tokenClaims, ok := token.Claims.(jwt.MapClaims); ok { + ctx = context.WithValue(ctx, JWTClaimsContextKey, Claims(tokenClaims)) + } else { + ctx = context.WithValue(ctx, JWTClaimsContextKey, token.Claims) } return next(ctx, request) } } } + +// NewParser creates a new JWT token parsing middleware, specifying a +// jwt.KeyFunc interface and the signing method. It will utilize NewParserWithClaims +// and fall back to implementing the jwt.MapClaims type. +func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod) endpoint.Middleware { + return NewParserWithClaims(keyFunc, method, jwt.MapClaims{}) +} diff --git a/auth/jwt/middleware_test.go b/auth/jwt/middleware_test.go index 99b943c..0ba2999 100644 --- a/auth/jwt/middleware_test.go +++ b/auth/jwt/middleware_test.go @@ -5,23 +5,24 @@ "testing" jwt "github.com/dgrijalva/jwt-go" + "github.com/go-kit/kit/endpoint" ) var ( - kid = "kid" - key = []byte("test_signing_key") - method = jwt.SigningMethodHS256 - invalidMethod = jwt.SigningMethodRS256 - claims = Claims{"user": "go-kit"} + kid = "kid" + key = []byte("test_signing_key") + method = jwt.SigningMethodHS256 + invalidMethod = jwt.SigningMethodRS256 + claims = Claims{"user": "go-kit"} + mapClaims = jwt.MapClaims{"user": "go-kit"} + standardClaims = jwt.StandardClaims{Audience: "go-kit"} // Signed tokens generated at https://jwt.io/ - signedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E" - invalidKey = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.e30.vKVCKto-Wn6rgz3vBdaZaCBGfCBDTXOENSo_X2Gq7qA" + signedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E" + standardSignedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJnby1raXQifQ.L5ypIJjCOOv3jJ8G5SelaHvR04UJuxmcBN5QW3m_aoY" + invalidKey = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.e30.vKVCKto-Wn6rgz3vBdaZaCBGfCBDTXOENSo_X2Gq7qA" ) -func TestSigner(t *testing.T) { - e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil } - - signer := NewSigner(kid, key, method, claims)(e) +func signingValidator(t *testing.T, signer endpoint.Endpoint, expectedKey string) { ctx, err := signer(context.Background(), struct{}{}) if err != nil { t.Fatalf("Signer returned error: %s", err) @@ -32,9 +33,22 @@ t.Fatal("Token did not exist in context") } - if token != signedKey { - t.Fatalf("JWT tokens did not match: expecting %s got %s", signedKey, token) + if token != expectedKey { + t.Fatalf("JWT tokens did not match: expecting %s got %s", expectedKey, token) } +} + +func TestNewSigner(t *testing.T) { + e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil } + + signer := NewSigner(kid, key, method, claims)(e) + signingValidator(t, signer, signedKey) + + signer = NewSignerWithClaims(kid, key, method, mapClaims)(e) + signingValidator(t, signer, signedKey) + + signer = NewSignerWithClaims(kid, key, method, standardClaims)(e) + signingValidator(t, signer, standardSignedKey) } func TestJWTParser(t *testing.T) { @@ -102,4 +116,18 @@ if cl["user"] != claims["user"] { t.Fatalf("JWT Claims.user did not match: expecting %s got %s", claims["user"], cl["user"]) } + + parser = NewParserWithClaims(keys, method, &jwt.StandardClaims{})(e) + ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey) + ctx1, err = parser(ctx, struct{}{}) + if err != nil { + t.Fatalf("Parser returned error: %s", err) + } + stdCl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(*jwt.StandardClaims) + if !ok { + t.Fatal("Claims were not passed into context correctly") + } + if !stdCl.VerifyAudience("go-kit", true) { + t.Fatal("JWT jwt.StandardClaims.Audience did not match: expecting %s got %s", standardClaims.Audience, stdCl.Audience) + } }