diff --git a/auth/jwt/middleware.go b/auth/jwt/middleware.go index 2fd0e85..f2d4eb5 100644 --- a/auth/jwt/middleware.go +++ b/auth/jwt/middleware.go @@ -10,9 +10,9 @@ ) const ( - // JWTContextKey holds the key used to store a JWT Token in the context + // JWTTokenContextKey holds the key used to store a JWT Token in the context JWTTokenContextKey = "JWTToken" - // JWTContextKey holds the key used to store a JWT in the context + // JWTClaimsContxtKey holds the key used to store the JWT Claims in the context JWTClaimsContextKey = "JWTClaims" ) diff --git a/auth/jwt/middleware_test.go b/auth/jwt/middleware_test.go index ecefa71..011c73a 100644 --- a/auth/jwt/middleware_test.go +++ b/auth/jwt/middleware_test.go @@ -9,25 +9,32 @@ ) var ( - key = "test_signing_key" + kid = "kid" + key = []byte("test_signing_key") method = jwt.SigningMethodHS256 invalidMethod = jwt.SigningMethodRS256 - claims = jwt.MapClaims{"user": "go-kit"} - signedKey = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.MMefQU5pwDeoWBSdyagqNlr1tDGddGUOMGiIWmMlFvk" + claims = Claims{"user": "go-kit"} + signedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E" invalidKey = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.e30.vKVCKto-Wn6rgz3vBdaZaCBGfCBDTXOENSo_X2Gq7qA" ) func TestSigner(t *testing.T) { e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil } - signer := NewSigner(key, method, claims)(e) - ctx := context.Background() - ctx1, err := signer(ctx, struct{}{}) + keys := KeySet{ + kid: { + Method: method, + Key: key, + }, + } + + signer := NewSigner(kid, keys, claims)(e) + ctx, err := signer(context.Background(), struct{}{}) if err != nil { t.Fatalf("Signer returned error: %s", err) } - token, ok := ctx1.(context.Context).Value(JWTTokenContextKey).(string) + token, ok := ctx.(context.Context).Value(JWTTokenContextKey).(string) if !ok { t.Fatal("Token did not exist in context") } @@ -40,10 +47,14 @@ func TestJWTParser(t *testing.T) { e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil } - keyfunc := func(token *jwt.Token) (interface{}, error) { return []byte(key), nil } - badKeyfunc := func(token *jwt.Token) (interface{}, error) { return []byte("bad"), nil } + keys := KeySet{ + kid: { + Method: method, + Key: key, + }, + } - parser := NewParser(keyfunc, method)(e) + parser := NewParser(keys)(e) // No Token is passed into the parser _, err := parser(context.Background(), struct{}{}) @@ -59,7 +70,14 @@ } // Invalid Method is used in the parser - badParser := NewParser(keyfunc, invalidMethod)(e) + invalidMethodKeys := KeySet{ + kid: { + Method: invalidMethod, + Key: key, + }, + } + + badParser := NewParser(invalidMethodKeys)(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) _, err = badParser(ctx, struct{}{}) if err == nil { @@ -67,7 +85,14 @@ } // Invalid key is used in the parser - badParser = NewParser(badKeyfunc, method)(e) + invalidKeys := KeySet{ + kid: { + Method: method, + Key: []byte("bad"), + }, + } + + badParser = NewParser(invalidKeys)(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) _, err = badParser(ctx, struct{}{}) if err == nil { @@ -81,7 +106,7 @@ t.Fatalf("Parser returned error: %s", err) } - cl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(jwt.MapClaims) + cl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(Claims) if !ok { t.Fatal("Claims were not passed into context correctly") }