Fix typo in jwt package (#1070)
* Fix typo in jwt package
* change JWTTokenContextKey to JWTContextKey
* revise errors as well as comments
* Revision of typo fixing in jwt package
* revert the API identifier to its previous value
* add JWTTokenContextKey side by side of JWTContextKey for historical compatibility
* Fixed a typo in JWT package (#1071)
* define JWTContextKey as a new constant
* mark JWTTokenContextKey as a deprecated constant
* revise corresponding error messages
* Fixed a typo in JWT package (#1071)
* define JWTContextKey as a new constant
* mark JWTTokenContextKey as a deprecated constant
* revise corresponding error messages
Co-authored-by: Amid <amid.dev@protonmail.com>
Amid authored 2 years ago
GitHub committed 2 years ago
6 | 6 | |
7 | 7 | NewParser takes a key function and an expected signing method and returns an |
8 | 8 | `endpoint.Middleware`. The middleware will parse a token passed into the |
9 | context via the `jwt.JWTTokenContextKey`. If the token is valid, any claims | |
9 | context via the `jwt.JWTContextKey`. If the token is valid, any claims | |
10 | 10 | will be added to the context via the `jwt.JWTClaimsContextKey`. |
11 | 11 | |
12 | 12 | ```go |
29 | 29 | |
30 | 30 | NewSigner takes a JWT key ID header, the signing key, signing method, and a |
31 | 31 | claims object. It returns an `endpoint.Middleware`. The middleware will build |
32 | the token string and add it to the context via the `jwt.JWTTokenContextKey`. | |
32 | the token string and add it to the context via the `jwt.JWTContextKey`. | |
33 | 33 | |
34 | 34 | ```go |
35 | 35 | import ( |
11 | 11 | type contextKey string |
12 | 12 | |
13 | 13 | const ( |
14 | // JWTTokenContextKey holds the key used to store a JWT Token in the | |
15 | // context. | |
16 | JWTTokenContextKey contextKey = "JWTToken" | |
14 | // JWTContextKey holds the key used to store a JWT in the context. | |
15 | JWTContextKey contextKey = "JWTToken" | |
16 | ||
17 | // JWTTokenContextKey is an alias for JWTContextKey. | |
18 | // | |
19 | // Deprecated: prefer JWTContextKey. | |
20 | JWTTokenContextKey = JWTContextKey | |
17 | 21 | |
18 | 22 | // JWTClaimsContextKey holds the key used to store the JWT Claims in the |
19 | 23 | // context. |
26 | 30 | ErrTokenContextMissing = errors.New("token up for parsing was not passed through the context") |
27 | 31 | |
28 | 32 | // ErrTokenInvalid denotes a token was not able to be validated. |
29 | ErrTokenInvalid = errors.New("JWT Token was invalid") | |
33 | ErrTokenInvalid = errors.New("JWT was invalid") | |
30 | 34 | |
31 | 35 | // ErrTokenExpired denotes a token's expire header (exp) has since passed. |
32 | ErrTokenExpired = errors.New("JWT Token is expired") | |
36 | ErrTokenExpired = errors.New("JWT is expired") | |
33 | 37 | |
34 | // ErrTokenMalformed denotes a token was not formatted as a JWT token. | |
35 | ErrTokenMalformed = errors.New("JWT Token is malformed") | |
38 | // ErrTokenMalformed denotes a token was not formatted as a JWT. | |
39 | ErrTokenMalformed = errors.New("JWT is malformed") | |
36 | 40 | |
37 | 41 | // ErrTokenNotActive denotes a token's not before header (nbf) is in the |
38 | 42 | // future. |
43 | 47 | ErrUnexpectedSigningMethod = errors.New("unexpected signing method") |
44 | 48 | ) |
45 | 49 | |
46 | // NewSigner creates a new JWT token generating middleware, specifying key ID, | |
50 | // NewSigner creates a new JWT generating middleware, specifying key ID, | |
47 | 51 | // signing string, signing method and the claims you would like it to contain. |
48 | 52 | // Tokens are signed with a Key ID header (kid) which is useful for determining |
49 | 53 | // the key to use for parsing. Particularly useful for clients. |
58 | 62 | if err != nil { |
59 | 63 | return nil, err |
60 | 64 | } |
61 | ctx = context.WithValue(ctx, JWTTokenContextKey, tokenString) | |
65 | ctx = context.WithValue(ctx, JWTContextKey, tokenString) | |
62 | 66 | |
63 | 67 | return next(ctx, request) |
64 | 68 | } |
81 | 85 | return &jwt.StandardClaims{} |
82 | 86 | } |
83 | 87 | |
84 | // NewParser creates a new JWT token parsing middleware, specifying a | |
88 | // NewParser creates a new JWT parsing middleware, specifying a | |
85 | 89 | // jwt.Keyfunc interface, the signing method and the claims type to be used. NewParser |
86 | 90 | // adds the resulting claims to endpoint context or returns error on invalid token. |
87 | 91 | // Particularly useful for servers. |
89 | 93 | return func(next endpoint.Endpoint) endpoint.Endpoint { |
90 | 94 | return func(ctx context.Context, request interface{}) (response interface{}, err error) { |
91 | 95 | // tokenString is stored in the context from the transport handlers. |
92 | tokenString, ok := ctx.Value(JWTTokenContextKey).(string) | |
96 | tokenString, ok := ctx.Value(JWTContextKey).(string) | |
93 | 97 | if !ok { |
94 | 98 | return nil, ErrTokenContextMissing |
95 | 99 | } |
43 | 43 | t.Fatalf("Signer returned error: %s", err) |
44 | 44 | } |
45 | 45 | |
46 | token, ok := ctx.(context.Context).Value(JWTTokenContextKey).(string) | |
46 | token, ok := ctx.(context.Context).Value(JWTContextKey).(string) | |
47 | 47 | if !ok { |
48 | 48 | t.Fatal("Token did not exist in context") |
49 | 49 | } |
50 | 50 | |
51 | 51 | if token != expectedKey { |
52 | t.Fatalf("JWT tokens did not match: expecting %s got %s", expectedKey, token) | |
52 | t.Fatalf("JWTs did not match: expecting %s got %s", expectedKey, token) | |
53 | 53 | } |
54 | 54 | } |
55 | 55 | |
86 | 86 | } |
87 | 87 | |
88 | 88 | // Invalid Token is passed into the parser |
89 | ctx := context.WithValue(context.Background(), JWTTokenContextKey, invalidKey) | |
89 | ctx := context.WithValue(context.Background(), JWTContextKey, invalidKey) | |
90 | 90 | _, err = parser(ctx, struct{}{}) |
91 | 91 | if err == nil { |
92 | 92 | t.Error("Parser should have returned an error") |
94 | 94 | |
95 | 95 | // Invalid Method is used in the parser |
96 | 96 | badParser := NewParser(keys, invalidMethod, MapClaimsFactory)(e) |
97 | ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) | |
97 | ctx = context.WithValue(context.Background(), JWTContextKey, signedKey) | |
98 | 98 | _, err = badParser(ctx, struct{}{}) |
99 | 99 | if err == nil { |
100 | 100 | t.Error("Parser should have returned an error") |
110 | 110 | } |
111 | 111 | |
112 | 112 | badParser = NewParser(invalidKeys, method, MapClaimsFactory)(e) |
113 | ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) | |
113 | ctx = context.WithValue(context.Background(), JWTContextKey, signedKey) | |
114 | 114 | _, err = badParser(ctx, struct{}{}) |
115 | 115 | if err == nil { |
116 | 116 | t.Error("Parser should have returned an error") |
117 | 117 | } |
118 | 118 | |
119 | 119 | // Correct token is passed into the parser |
120 | ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) | |
120 | ctx = context.WithValue(context.Background(), JWTContextKey, signedKey) | |
121 | 121 | ctx1, err := parser(ctx, struct{}{}) |
122 | 122 | if err != nil { |
123 | 123 | t.Fatalf("Parser returned error: %s", err) |
134 | 134 | |
135 | 135 | // Test for malformed token error response |
136 | 136 | parser = NewParser(keys, method, StandardClaimsFactory)(e) |
137 | ctx = context.WithValue(context.Background(), JWTTokenContextKey, malformedKey) | |
137 | ctx = context.WithValue(context.Background(), JWTContextKey, malformedKey) | |
138 | 138 | ctx1, err = parser(ctx, struct{}{}) |
139 | 139 | if want, have := ErrTokenMalformed, err; want != have { |
140 | 140 | t.Fatalf("Expected %+v, got %+v", want, have) |
147 | 147 | if err != nil { |
148 | 148 | t.Fatalf("Unable to Sign Token: %+v", err) |
149 | 149 | } |
150 | ctx = context.WithValue(context.Background(), JWTTokenContextKey, token) | |
150 | ctx = context.WithValue(context.Background(), JWTContextKey, token) | |
151 | 151 | ctx1, err = parser(ctx, struct{}{}) |
152 | 152 | if want, have := ErrTokenExpired, err; want != have { |
153 | 153 | t.Fatalf("Expected %+v, got %+v", want, have) |
160 | 160 | if err != nil { |
161 | 161 | t.Fatalf("Unable to Sign Token: %+v", err) |
162 | 162 | } |
163 | ctx = context.WithValue(context.Background(), JWTTokenContextKey, token) | |
163 | ctx = context.WithValue(context.Background(), JWTContextKey, token) | |
164 | 164 | ctx1, err = parser(ctx, struct{}{}) |
165 | 165 | if want, have := ErrTokenNotActive, err; want != have { |
166 | 166 | t.Fatalf("Expected %+v, got %+v", want, have) |
168 | 168 | |
169 | 169 | // test valid standard claims token |
170 | 170 | parser = NewParser(keys, method, StandardClaimsFactory)(e) |
171 | ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey) | |
171 | ctx = context.WithValue(context.Background(), JWTContextKey, standardSignedKey) | |
172 | 172 | ctx1, err = parser(ctx, struct{}{}) |
173 | 173 | if err != nil { |
174 | 174 | t.Fatalf("Parser returned error: %s", err) |
183 | 183 | |
184 | 184 | // test valid customized claims token |
185 | 185 | parser = NewParser(keys, method, func() jwt.Claims { return &customClaims{} })(e) |
186 | ctx = context.WithValue(context.Background(), JWTTokenContextKey, customSignedKey) | |
186 | ctx = context.WithValue(context.Background(), JWTContextKey, customSignedKey) | |
187 | 187 | ctx1, err = parser(ctx, struct{}{}) |
188 | 188 | if err != nil { |
189 | 189 | t.Fatalf("Parser returned error: %s", err) |
204 | 204 | var ( |
205 | 205 | kf = func(token *jwt.Token) (interface{}, error) { return []byte("secret"), nil } |
206 | 206 | e = NewParser(kf, jwt.SigningMethodHS256, MapClaimsFactory)(endpoint.Nop) |
207 | key = JWTTokenContextKey | |
207 | key = JWTContextKey | |
208 | 208 | val = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E" |
209 | 209 | ctx = context.WithValue(context.Background(), key, val) |
210 | 210 | ) |
25 | 25 | return ctx |
26 | 26 | } |
27 | 27 | |
28 | return context.WithValue(ctx, JWTTokenContextKey, token) | |
28 | return context.WithValue(ctx, JWTContextKey, token) | |
29 | 29 | } |
30 | 30 | } |
31 | 31 | |
33 | 33 | // useful for clients. |
34 | 34 | func ContextToHTTP() http.RequestFunc { |
35 | 35 | return func(ctx context.Context, r *stdhttp.Request) context.Context { |
36 | token, ok := ctx.Value(JWTTokenContextKey).(string) | |
36 | token, ok := ctx.Value(JWTContextKey).(string) | |
37 | 37 | if ok { |
38 | 38 | r.Header.Add("Authorization", generateAuthHeaderFromToken(token)) |
39 | 39 | } |
53 | 53 | |
54 | 54 | token, ok := extractTokenFromAuthHeader(authHeader[0]) |
55 | 55 | if ok { |
56 | ctx = context.WithValue(ctx, JWTTokenContextKey, token) | |
56 | ctx = context.WithValue(ctx, JWTContextKey, token) | |
57 | 57 | } |
58 | 58 | |
59 | 59 | return ctx |
64 | 64 | // useful for clients. |
65 | 65 | func ContextToGRPC() grpc.ClientRequestFunc { |
66 | 66 | return func(ctx context.Context, md *metadata.MD) context.Context { |
67 | token, ok := ctx.Value(JWTTokenContextKey).(string) | |
67 | token, ok := ctx.Value(JWTContextKey).(string) | |
68 | 68 | if ok { |
69 | 69 | // capital "Key" is illegal in HTTP/2. |
70 | 70 | (*md)["authorization"] = []string{generateAuthHeaderFromToken(token)} |
14 | 14 | // When the header doesn't exist |
15 | 15 | ctx := reqFunc(context.Background(), &http.Request{}) |
16 | 16 | |
17 | if ctx.Value(JWTTokenContextKey) != nil { | |
17 | if ctx.Value(JWTContextKey) != nil { | |
18 | 18 | t.Error("Context shouldn't contain the encoded JWT") |
19 | 19 | } |
20 | 20 | |
23 | 23 | header.Set("Authorization", "no expected auth header format value") |
24 | 24 | ctx = reqFunc(context.Background(), &http.Request{Header: header}) |
25 | 25 | |
26 | if ctx.Value(JWTTokenContextKey) != nil { | |
26 | if ctx.Value(JWTContextKey) != nil { | |
27 | 27 | t.Error("Context shouldn't contain the encoded JWT") |
28 | 28 | } |
29 | 29 | |
31 | 31 | header.Set("Authorization", generateAuthHeaderFromToken(signedKey)) |
32 | 32 | ctx = reqFunc(context.Background(), &http.Request{Header: header}) |
33 | 33 | |
34 | token := ctx.Value(JWTTokenContextKey).(string) | |
34 | token := ctx.Value(JWTContextKey).(string) | |
35 | 35 | if token != signedKey { |
36 | 36 | t.Errorf("Context doesn't contain the expected encoded token value; expected: %s, got: %s", signedKey, token) |
37 | 37 | } |
40 | 40 | func TestContextToHTTP(t *testing.T) { |
41 | 41 | reqFunc := ContextToHTTP() |
42 | 42 | |
43 | // No JWT Token is passed in the context | |
43 | // No JWT is passed in the context | |
44 | 44 | ctx := context.Background() |
45 | 45 | r := http.Request{} |
46 | 46 | reqFunc(ctx, &r) |
50 | 50 | t.Error("authorization key should not exist in metadata") |
51 | 51 | } |
52 | 52 | |
53 | // Correct JWT Token is passed in the context | |
54 | ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) | |
53 | // Correct JWT is passed in the context | |
54 | ctx = context.WithValue(context.Background(), JWTContextKey, signedKey) | |
55 | 55 | r = http.Request{Header: http.Header{}} |
56 | 56 | reqFunc(ctx, &r) |
57 | 57 | |
59 | 59 | expected := generateAuthHeaderFromToken(signedKey) |
60 | 60 | |
61 | 61 | if token != expected { |
62 | t.Errorf("Authorization header does not contain the expected JWT token; expected %s, got %s", expected, token) | |
62 | t.Errorf("Authorization header does not contain the expected JWT; expected %s, got %s", expected, token) | |
63 | 63 | } |
64 | 64 | } |
65 | 65 | |
69 | 69 | |
70 | 70 | // No Authorization header is passed |
71 | 71 | ctx := reqFunc(context.Background(), md) |
72 | token := ctx.Value(JWTTokenContextKey) | |
72 | token := ctx.Value(JWTContextKey) | |
73 | 73 | if token != nil { |
74 | t.Error("Context should not contain a JWT Token") | |
74 | t.Error("Context should not contain a JWT") | |
75 | 75 | } |
76 | 76 | |
77 | 77 | // Invalid Authorization header is passed |
78 | 78 | md["authorization"] = []string{fmt.Sprintf("%s", signedKey)} |
79 | 79 | ctx = reqFunc(context.Background(), md) |
80 | token = ctx.Value(JWTTokenContextKey) | |
80 | token = ctx.Value(JWTContextKey) | |
81 | 81 | if token != nil { |
82 | t.Error("Context should not contain a JWT Token") | |
82 | t.Error("Context should not contain a JWT") | |
83 | 83 | } |
84 | 84 | |
85 | 85 | // Authorization header is correct |
86 | 86 | md["authorization"] = []string{fmt.Sprintf("Bearer %s", signedKey)} |
87 | 87 | ctx = reqFunc(context.Background(), md) |
88 | token, ok := ctx.Value(JWTTokenContextKey).(string) | |
88 | token, ok := ctx.Value(JWTContextKey).(string) | |
89 | 89 | if !ok { |
90 | t.Fatal("JWT Token not passed to context correctly") | |
90 | t.Fatal("JWT not passed to context correctly") | |
91 | 91 | } |
92 | 92 | |
93 | 93 | if token != signedKey { |
94 | t.Errorf("JWT tokens did not match: expecting %s got %s", signedKey, token) | |
94 | t.Errorf("JWTs did not match: expecting %s got %s", signedKey, token) | |
95 | 95 | } |
96 | 96 | } |
97 | 97 | |
98 | 98 | func TestContextToGRPC(t *testing.T) { |
99 | 99 | reqFunc := ContextToGRPC() |
100 | 100 | |
101 | // No JWT Token is passed in the context | |
101 | // No JWT is passed in the context | |
102 | 102 | ctx := context.Background() |
103 | 103 | md := metadata.MD{} |
104 | 104 | reqFunc(ctx, &md) |
108 | 108 | t.Error("authorization key should not exist in metadata") |
109 | 109 | } |
110 | 110 | |
111 | // Correct JWT Token is passed in the context | |
112 | ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) | |
111 | // Correct JWT is passed in the context | |
112 | ctx = context.WithValue(context.Background(), JWTContextKey, signedKey) | |
113 | 113 | md = metadata.MD{} |
114 | 114 | reqFunc(ctx, &md) |
115 | 115 | |
116 | 116 | token, ok := md["authorization"] |
117 | 117 | if !ok { |
118 | t.Fatal("JWT Token not passed to metadata correctly") | |
118 | t.Fatal("JWT not passed to metadata correctly") | |
119 | 119 | } |
120 | 120 | |
121 | 121 | if token[0] != generateAuthHeaderFromToken(signedKey) { |
122 | t.Errorf("JWT tokens did not match: expecting %s got %s", signedKey, token[0]) | |
122 | t.Errorf("JWTs did not match: expecting %s got %s", signedKey, token[0]) | |
123 | 123 | } |
124 | 124 | } |