diff --git a/auth/jwt/middleware.go b/auth/jwt/middleware.go index b03b9aa..ff44ed7 100644 --- a/auth/jwt/middleware.go +++ b/auth/jwt/middleware.go @@ -66,13 +66,25 @@ } } -type claimsFactory func() jwt.Claims +type ClaimsFactory func() jwt.Claims + +// MapClaimsFactory is a ClaimsFactory that returns +// an empty jwt.MapClaims. +func MapClaimsFactory() jwt.Claims { + return jwt.MapClaims{} +} + +// StandardClaimsFactory is a ClaimsFactory that returns +// an empty jwt.StandardClaims. +func StandardClaimsFactory() jwt.Claims { + return &jwt.StandardClaims{} +} // NewParser creates a new JWT token parsing middleware, specifying a // jwt.Keyfunc interface, the signing method and the claims type to be used. NewParser // adds the resulting claims to endpoint context or returns error on invalid token. // Particularly useful for servers. -func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, newClaims claimsFactory) endpoint.Middleware { +func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, newClaims ClaimsFactory) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { // tokenString is stored in the context from the transport handlers. diff --git a/auth/jwt/middleware_test.go b/auth/jwt/middleware_test.go index 977bef6..3278e13 100644 --- a/auth/jwt/middleware_test.go +++ b/auth/jwt/middleware_test.go @@ -74,7 +74,7 @@ return key, nil } - parser := NewParser(keys, method, func() jwt.Claims { return jwt.MapClaims{} })(e) + parser := NewParser(keys, method, MapClaimsFactory)(e) // No Token is passed into the parser _, err := parser(context.Background(), struct{}{}) @@ -94,7 +94,7 @@ } // Invalid Method is used in the parser - badParser := NewParser(keys, invalidMethod, func() jwt.Claims { return jwt.MapClaims{} })(e) + badParser := NewParser(keys, invalidMethod, MapClaimsFactory)(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) _, err = badParser(ctx, struct{}{}) if err == nil { @@ -110,7 +110,7 @@ return []byte("bad"), nil } - badParser = NewParser(invalidKeys, method, func() jwt.Claims { return jwt.MapClaims{} })(e) + badParser = NewParser(invalidKeys, method, MapClaimsFactory)(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) _, err = badParser(ctx, struct{}{}) if err == nil { @@ -134,7 +134,7 @@ } // Test for malformed token error response - parser = NewParser(keys, method, func() jwt.Claims { return &jwt.StandardClaims{} })(e) + parser = NewParser(keys, method, StandardClaimsFactory)(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, malformedKey) ctx1, err = parser(ctx, struct{}{}) if want, have := ErrTokenMalformed, err; want != have { @@ -142,7 +142,7 @@ } // Test for expired token error response - parser = NewParser(keys, method, func() jwt.Claims { return &jwt.StandardClaims{} })(e) + parser = NewParser(keys, method, StandardClaimsFactory)(e) expired := jwt.NewWithClaims(method, jwt.StandardClaims{ExpiresAt: time.Now().Unix() - 100}) token, err := expired.SignedString(key) if err != nil { @@ -155,7 +155,7 @@ } // Test for not activated token error response - parser = NewParser(keys, method, func() jwt.Claims { return &jwt.StandardClaims{} })(e) + parser = NewParser(keys, method, StandardClaimsFactory)(e) notactive := jwt.NewWithClaims(method, jwt.StandardClaims{NotBefore: time.Now().Unix() + 100}) token, err = notactive.SignedString(key) if err != nil { @@ -168,7 +168,7 @@ } // test valid standard claims token - parser = NewParser(keys, method, func() jwt.Claims { return &jwt.StandardClaims{} })(e) + parser = NewParser(keys, method, StandardClaimsFactory)(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey) ctx1, err = parser(ctx, struct{}{}) if err != nil { @@ -204,7 +204,7 @@ func TestIssue562(t *testing.T) { var ( kf = func(token *jwt.Token) (interface{}, error) { return []byte("secret"), nil } - e = NewParser(kf, jwt.SigningMethodHS256, func() jwt.Claims { return jwt.MapClaims{} })(endpoint.Nop) + e = NewParser(kf, jwt.SigningMethodHS256, MapClaimsFactory)(endpoint.Nop) key = JWTTokenContextKey val = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E" ctx = context.WithValue(context.Background(), key, val)