diff --git a/auth/jwt/middleware.go b/auth/jwt/middleware.go index 84863bf..090078c 100644 --- a/auth/jwt/middleware.go +++ b/auth/jwt/middleware.go @@ -44,14 +44,11 @@ ErrUnexpectedSigningMethod = errors.New("unexpected signing method") ) -// Claims is a map of arbitrary claim data. -type Claims map[string]interface{} - -// NewSignerWithClaims creates a new JWT token generating middleware, specifying key ID, +// NewSigner creates a new JWT token generating middleware, specifying key ID, // signing string, signing method and the jwt.Claims you would like it to contain. // Tokens are signed with a Key ID header (kid) which is useful for determining // the key to use for parsing. Particularly useful for clients. -func NewSignerWithClaims(kid string, key []byte, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware { +func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { token := jwt.NewWithClaims(method, claims) @@ -69,18 +66,11 @@ } } -// NewSigner creates a new JWT token generating middleware, specifying key ID, -// signing string, signing method and the claims you would like it to contain. -// It passes these values onto NewSignerWithClaims to handle the signing process. -func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims Claims) endpoint.Middleware { - return NewSignerWithClaims(kid, key, method, jwt.MapClaims(claims)) -} - -// NewParserWithClaims creates a new JWT token parsing middleware, specifying a +// NewParser creates a new JWT token parsing middleware, specifying a // jwt.Keyfunc interface, the signing method as well as the claims to parse into. // NewParserWithClaims adds the resulting claims to endpoint context or returns error on invalid token. // Particularly useful for servers. -func NewParserWithClaims(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware { +func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { // tokenString is stored in the context from the transport handlers. @@ -126,20 +116,9 @@ return nil, ErrTokenInvalid } - if tokenClaims, ok := token.Claims.(jwt.MapClaims); ok { - ctx = context.WithValue(ctx, JWTClaimsContextKey, Claims(tokenClaims)) - } else { - ctx = context.WithValue(ctx, JWTClaimsContextKey, token.Claims) - } + ctx = context.WithValue(ctx, JWTClaimsContextKey, token.Claims) return next(ctx, request) } } } - -// NewParser creates a new JWT token parsing middleware, specifying a -// jwt.KeyFunc interface and the signing method. It will utilize NewParserWithClaims -// and fall back to implementing the jwt.MapClaims type. -func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod) endpoint.Middleware { - return NewParserWithClaims(keyFunc, method, jwt.MapClaims{}) -} diff --git a/auth/jwt/middleware_test.go b/auth/jwt/middleware_test.go index 0ba2999..cc34066 100644 --- a/auth/jwt/middleware_test.go +++ b/auth/jwt/middleware_test.go @@ -13,7 +13,6 @@ key = []byte("test_signing_key") method = jwt.SigningMethodHS256 invalidMethod = jwt.SigningMethodRS256 - claims = Claims{"user": "go-kit"} mapClaims = jwt.MapClaims{"user": "go-kit"} standardClaims = jwt.StandardClaims{Audience: "go-kit"} // Signed tokens generated at https://jwt.io/ @@ -41,13 +40,10 @@ func TestNewSigner(t *testing.T) { e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil } - signer := NewSigner(kid, key, method, claims)(e) + signer := NewSigner(kid, key, method, mapClaims)(e) signingValidator(t, signer, signedKey) - signer = NewSignerWithClaims(kid, key, method, mapClaims)(e) - signingValidator(t, signer, signedKey) - - signer = NewSignerWithClaims(kid, key, method, standardClaims)(e) + signer = NewSigner(kid, key, method, standardClaims)(e) signingValidator(t, signer, standardSignedKey) } @@ -58,7 +54,7 @@ return key, nil } - parser := NewParser(keys, method)(e) + parser := NewParser(keys, method, jwt.MapClaims{})(e) // No Token is passed into the parser _, err := parser(context.Background(), struct{}{}) @@ -78,7 +74,7 @@ } // Invalid Method is used in the parser - badParser := NewParser(keys, invalidMethod)(e) + badParser := NewParser(keys, invalidMethod, jwt.MapClaims{})(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) _, err = badParser(ctx, struct{}{}) if err == nil { @@ -94,7 +90,7 @@ return []byte("bad"), nil } - badParser = NewParser(invalidKeys, method)(e) + badParser = NewParser(invalidKeys, method, jwt.MapClaims{})(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) _, err = badParser(ctx, struct{}{}) if err == nil { @@ -108,16 +104,16 @@ t.Fatalf("Parser returned error: %s", err) } - cl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(Claims) + cl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(jwt.MapClaims) if !ok { t.Fatal("Claims were not passed into context correctly") } - if cl["user"] != claims["user"] { - t.Fatalf("JWT Claims.user did not match: expecting %s got %s", claims["user"], cl["user"]) + if cl["user"] != mapClaims["user"] { + t.Fatalf("JWT Claims.user did not match: expecting %s got %s", mapClaims["user"], cl["user"]) } - parser = NewParserWithClaims(keys, method, &jwt.StandardClaims{})(e) + parser = NewParser(keys, method, &jwt.StandardClaims{})(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey) ctx1, err = parser(ctx, struct{}{}) if err != nil {