Remove ambiguous Claims type.
Cameron Stitt
6 years ago
43 | 43 | ErrUnexpectedSigningMethod = errors.New("unexpected signing method") |
44 | 44 | ) |
45 | 45 | |
46 | // Claims is a map of arbitrary claim data. | |
47 | type Claims map[string]interface{} | |
48 | ||
49 | // NewSignerWithClaims creates a new JWT token generating middleware, specifying key ID, | |
46 | // NewSigner creates a new JWT token generating middleware, specifying key ID, | |
50 | 47 | // signing string, signing method and the jwt.Claims you would like it to contain. |
51 | 48 | // Tokens are signed with a Key ID header (kid) which is useful for determining |
52 | 49 | // the key to use for parsing. Particularly useful for clients. |
53 | func NewSignerWithClaims(kid string, key []byte, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware { | |
50 | func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware { | |
54 | 51 | return func(next endpoint.Endpoint) endpoint.Endpoint { |
55 | 52 | return func(ctx context.Context, request interface{}) (response interface{}, err error) { |
56 | 53 | token := jwt.NewWithClaims(method, claims) |
68 | 65 | } |
69 | 66 | } |
70 | 67 | |
71 | // NewSigner creates a new JWT token generating middleware, specifying key ID, | |
72 | // signing string, signing method and the claims you would like it to contain. | |
73 | // It passes these values onto NewSignerWithClaims to handle the signing process. | |
74 | func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims Claims) endpoint.Middleware { | |
75 | return NewSignerWithClaims(kid, key, method, jwt.MapClaims(claims)) | |
76 | } | |
77 | ||
78 | // NewParserWithClaims creates a new JWT token parsing middleware, specifying a | |
68 | // NewParser creates a new JWT token parsing middleware, specifying a | |
79 | 69 | // jwt.Keyfunc interface, the signing method as well as the claims to parse into. |
80 | 70 | // NewParserWithClaims adds the resulting claims to endpoint context or returns error on invalid token. |
81 | 71 | // Particularly useful for servers. |
82 | func NewParserWithClaims(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware { | |
72 | func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware { | |
83 | 73 | return func(next endpoint.Endpoint) endpoint.Endpoint { |
84 | 74 | return func(ctx context.Context, request interface{}) (response interface{}, err error) { |
85 | 75 | // tokenString is stored in the context from the transport handlers. |
125 | 115 | return nil, ErrTokenInvalid |
126 | 116 | } |
127 | 117 | |
128 | if tokenClaims, ok := token.Claims.(jwt.MapClaims); ok { | |
129 | ctx = context.WithValue(ctx, JWTClaimsContextKey, Claims(tokenClaims)) | |
130 | } else { | |
131 | ctx = context.WithValue(ctx, JWTClaimsContextKey, token.Claims) | |
132 | } | |
118 | ctx = context.WithValue(ctx, JWTClaimsContextKey, token.Claims) | |
133 | 119 | |
134 | 120 | return next(ctx, request) |
135 | 121 | } |
136 | 122 | } |
137 | 123 | } |
138 | ||
139 | // NewParser creates a new JWT token parsing middleware, specifying a | |
140 | // jwt.KeyFunc interface and the signing method. It will utilize NewParserWithClaims | |
141 | // and fall back to implementing the jwt.MapClaims type. | |
142 | func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod) endpoint.Middleware { | |
143 | return NewParserWithClaims(keyFunc, method, jwt.MapClaims{}) | |
144 | } |
12 | 12 | key = []byte("test_signing_key") |
13 | 13 | method = jwt.SigningMethodHS256 |
14 | 14 | invalidMethod = jwt.SigningMethodRS256 |
15 | claims = Claims{"user": "go-kit"} | |
16 | 15 | mapClaims = jwt.MapClaims{"user": "go-kit"} |
17 | 16 | standardClaims = jwt.StandardClaims{Audience: "go-kit"} |
18 | 17 | // Signed tokens generated at https://jwt.io/ |
40 | 39 | func TestNewSigner(t *testing.T) { |
41 | 40 | e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil } |
42 | 41 | |
43 | signer := NewSigner(kid, key, method, claims)(e) | |
42 | signer := NewSigner(kid, key, method, mapClaims)(e) | |
44 | 43 | signingValidator(t, signer, signedKey) |
45 | 44 | |
46 | signer = NewSignerWithClaims(kid, key, method, mapClaims)(e) | |
47 | signingValidator(t, signer, signedKey) | |
48 | ||
49 | signer = NewSignerWithClaims(kid, key, method, standardClaims)(e) | |
45 | signer = NewSigner(kid, key, method, standardClaims)(e) | |
50 | 46 | signingValidator(t, signer, standardSignedKey) |
51 | 47 | } |
52 | 48 | |
57 | 53 | return key, nil |
58 | 54 | } |
59 | 55 | |
60 | parser := NewParser(keys, method)(e) | |
56 | parser := NewParser(keys, method, jwt.MapClaims{})(e) | |
61 | 57 | |
62 | 58 | // No Token is passed into the parser |
63 | 59 | _, err := parser(context.Background(), struct{}{}) |
77 | 73 | } |
78 | 74 | |
79 | 75 | // Invalid Method is used in the parser |
80 | badParser := NewParser(keys, invalidMethod)(e) | |
76 | badParser := NewParser(keys, invalidMethod, jwt.MapClaims{})(e) | |
81 | 77 | ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) |
82 | 78 | _, err = badParser(ctx, struct{}{}) |
83 | 79 | if err == nil { |
93 | 89 | return []byte("bad"), nil |
94 | 90 | } |
95 | 91 | |
96 | badParser = NewParser(invalidKeys, method)(e) | |
92 | badParser = NewParser(invalidKeys, method, jwt.MapClaims{})(e) | |
97 | 93 | ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) |
98 | 94 | _, err = badParser(ctx, struct{}{}) |
99 | 95 | if err == nil { |
107 | 103 | t.Fatalf("Parser returned error: %s", err) |
108 | 104 | } |
109 | 105 | |
110 | cl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(Claims) | |
106 | cl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(jwt.MapClaims) | |
111 | 107 | if !ok { |
112 | 108 | t.Fatal("Claims were not passed into context correctly") |
113 | 109 | } |
114 | 110 | |
115 | if cl["user"] != claims["user"] { | |
116 | t.Fatalf("JWT Claims.user did not match: expecting %s got %s", claims["user"], cl["user"]) | |
111 | if cl["user"] != mapClaims["user"] { | |
112 | t.Fatalf("JWT Claims.user did not match: expecting %s got %s", mapClaims["user"], cl["user"]) | |
117 | 113 | } |
118 | 114 | |
119 | parser = NewParserWithClaims(keys, method, &jwt.StandardClaims{})(e) | |
115 | parser = NewParser(keys, method, &jwt.StandardClaims{})(e) | |
120 | 116 | ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey) |
121 | 117 | ctx1, err = parser(ctx, struct{}{}) |
122 | 118 | if err != nil { |