diff --git a/auth/jwt/middleware.go b/auth/jwt/middleware.go index cce4695..e7dcb9d 100644 --- a/auth/jwt/middleware.go +++ b/auth/jwt/middleware.go @@ -94,21 +94,24 @@ return keyFunc(token) }) if err != nil { - if e, ok := err.(*jwt.ValidationError); ok && e.Inner != nil { - if e.Errors&jwt.ValidationErrorMalformed != 0 { + if e, ok := err.(*jwt.ValidationError); ok { + switch { + case e.Errors&jwt.ValidationErrorMalformed != 0: // Token is malformed return nil, ErrTokenMalformed - } else if e.Errors&jwt.ValidationErrorExpired != 0 { + case e.Errors&jwt.ValidationErrorExpired != 0: // Token is expired return nil, ErrTokenExpired - } else if e.Errors&jwt.ValidationErrorNotValidYet != 0 { + case e.Errors&jwt.ValidationErrorNotValidYet != 0: // Token is not active yet return nil, ErrTokenNotActive + case e.Inner != nil: + // report e.Inner + return nil, e.Inner } - - return nil, e.Inner + // We have a ValidationError but have no specific Go kit error for it. + // Fall through to return original error. } - return nil, err } diff --git a/auth/jwt/middleware_test.go b/auth/jwt/middleware_test.go index 76889d6..8a45201 100644 --- a/auth/jwt/middleware_test.go +++ b/auth/jwt/middleware_test.go @@ -3,6 +3,7 @@ import ( "context" "testing" + "time" "crypto/subtle" @@ -33,6 +34,7 @@ standardSignedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJnby1raXQifQ.L5ypIJjCOOv3jJ8G5SelaHvR04UJuxmcBN5QW3m_aoY" customSignedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJteV9wcm9wZXJ0eSI6InNvbWUgdmFsdWUiLCJhdWQiOiJnby1raXQifQ.s8F-IDrV4WPJUsqr7qfDi-3GRlcKR0SRnkTeUT_U-i0" invalidKey = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.e30.vKVCKto-Wn6rgz3vBdaZaCBGfCBDTXOENSo_X2Gq7qA" + malformedKey = "malformed.jwt.token" ) func signingValidator(t *testing.T, signer endpoint.Endpoint, expectedKey string) { @@ -130,6 +132,41 @@ t.Fatalf("JWT Claims.user did not match: expecting %s got %s", mapClaims["user"], cl["user"]) } + // Test for malformed token error response + parser = NewParser(keys, method, &jwt.StandardClaims{})(e) + ctx = context.WithValue(context.Background(), JWTTokenContextKey, malformedKey) + ctx1, err = parser(ctx, struct{}{}) + if want, have := ErrTokenMalformed, err; want != have { + t.Fatalf("Expected %+v, got %+v", want, have) + } + + // Test for expired token error response + parser = NewParser(keys, method, &jwt.StandardClaims{})(e) + expired := jwt.NewWithClaims(method, jwt.StandardClaims{ExpiresAt: time.Now().Unix() - 100}) + token, err := expired.SignedString(key) + if err != nil { + t.Fatalf("Unable to Sign Token: %+v", err) + } + ctx = context.WithValue(context.Background(), JWTTokenContextKey, token) + ctx1, err = parser(ctx, struct{}{}) + if want, have := ErrTokenExpired, err; want != have { + t.Fatalf("Expected %+v, got %+v", want, have) + } + + // Test for not activated token error response + parser = NewParser(keys, method, &jwt.StandardClaims{})(e) + notactive := jwt.NewWithClaims(method, jwt.StandardClaims{NotBefore: time.Now().Unix() + 100}) + token, err = notactive.SignedString(key) + if err != nil { + t.Fatalf("Unable to Sign Token: %+v", err) + } + ctx = context.WithValue(context.Background(), JWTTokenContextKey, token) + ctx1, err = parser(ctx, struct{}{}) + if want, have := ErrTokenNotActive, err; want != have { + t.Fatalf("Expected %+v, got %+v", want, have) + } + + // test valid standard claims token parser = NewParser(keys, method, &jwt.StandardClaims{})(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey) ctx1, err = parser(ctx, struct{}{}) @@ -144,6 +181,7 @@ t.Fatalf("JWT jwt.StandardClaims.Audience did not match: expecting %s got %s", standardClaims.Audience, stdCl.Audience) } + // test valid customized claims token parser = NewParser(keys, method, &customClaims{})(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, customSignedKey) ctx1, err = parser(ctx, struct{}{})