Move jwtToken strings into a const variable & rename metadata key to 'authorization'
Brian Kassouf
7 years ago
12 | 12 | "github.com/go-kit/kit/endpoint" |
13 | 13 | ) |
14 | 14 | |
15 | const ( | |
16 | // JWTContextKey holds the key used to store a JWT Token in the context | |
17 | JWTTokenContextKey = "JWTToken" | |
18 | // JWTContextKey holds the key used to store a JWT in the context | |
19 | JWTClaimsContextKey = "JWTClaims" | |
20 | ) | |
21 | ||
15 | 22 | // Create a new JWT token generating middleware, specifying signing method and the claims |
16 | 23 | // you would like it to contain. Particulary useful for clients. |
17 | 24 | func NewSigner(key string, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware { |
24 | 31 | if err != nil { |
25 | 32 | return nil, err |
26 | 33 | } |
27 | md := metadata.Pairs("jwtToken", tokenString) | |
34 | md := metadata.MD{JWTTokenContextKey: []string{tokenString}} | |
28 | 35 | ctx = metadata.NewContext(ctx, md) |
29 | 36 | |
30 | 37 | return next(ctx, request) |
39 | 46 | return func(next endpoint.Endpoint) endpoint.Endpoint { |
40 | 47 | return func(ctx context.Context, request interface{}) (response interface{}, err error) { |
41 | 48 | // tokenString is stored in the context from the transport handlers |
42 | tokenString, ok := ctx.Value("jwtToken").(string) | |
49 | tokenString, ok := ctx.Value(JWTTokenContextKey).(string) | |
43 | 50 | if !ok { |
44 | 51 | return nil, errors.New("Token up for parsing was not passed through the context") |
45 | 52 | } |
64 | 71 | } |
65 | 72 | |
66 | 73 | if claims, ok := token.Claims.(jwt.MapClaims); ok { |
67 | ctx = context.WithValue(ctx, "jwtClaims", claims) | |
74 | ctx = context.WithValue(ctx, JWTClaimsContextKey, claims) | |
68 | 75 | } |
69 | 76 | |
70 | 77 | return next(ctx, request) |
32 | 32 | t.Fatal("Could not retrieve metadata from context") |
33 | 33 | } |
34 | 34 | |
35 | token, ok := md["jwttoken"] | |
35 | token, ok := md[jwt.JWTTokenContextKey] | |
36 | 36 | if !ok { |
37 | 37 | t.Fatal("Token did not exist in context") |
38 | 38 | } |
48 | 48 | keyfunc := func(token *stdjwt.Token) (interface{}, error) { return []byte(key), nil } |
49 | 49 | |
50 | 50 | parser := jwt.NewParser(keyfunc, method)(e) |
51 | ctx := context.WithValue(context.Background(), "jwtToken", signedKey) | |
51 | ctx := context.WithValue(context.Background(), jwt.JWTTokenContextKey, signedKey) | |
52 | 52 | ctx1, err := parser(ctx, struct{}{}) |
53 | 53 | if err != nil { |
54 | 54 | t.Fatalf("Parser returned error: %s", err) |
55 | 55 | } |
56 | 56 | |
57 | cl, ok := ctx1.(context.Context).Value("jwtClaims").(stdjwt.MapClaims) | |
57 | cl, ok := ctx1.(context.Context).Value(jwt.JWTClaimsContextKey).(stdjwt.MapClaims) | |
58 | 58 | if !ok { |
59 | 59 | t.Fatal("Claims were not passed into context correctly") |
60 | 60 | } |
17 | 17 | return ctx |
18 | 18 | } |
19 | 19 | |
20 | return context.WithValue(ctx, "jwtToken", token) | |
20 | return context.WithValue(ctx, JWTTokenContextKey, token) | |
21 | 21 | } |
22 | 22 | } |
23 | 23 | |
24 | 24 | func ToGRPCContext() grpc.RequestFunc { |
25 | 25 | return func(ctx context.Context, md *metadata.MD) context.Context { |
26 | authHeader, ok := (*md)["Authorization"] | |
26 | // capital "Key" is illegal in HTTP/2. | |
27 | authHeader, ok := (*md)["authorization"] | |
27 | 28 | if !ok { |
28 | 29 | return ctx |
29 | 30 | } |
30 | 31 | |
31 | 32 | token, ok := extractTokenFromAuthHeader(authHeader[0]) |
32 | 33 | if ok { |
33 | ctx = context.WithValue(ctx, "jwtToken", token) | |
34 | ctx = context.WithValue(ctx, JWTTokenContextKey, token) | |
34 | 35 | } |
35 | 36 | |
36 | 37 | return ctx |
44 | 45 | return ctx |
45 | 46 | } |
46 | 47 | |
47 | token, ok := md1["jwttoken"] | |
48 | token, ok := md1[JWTTokenContextKey] | |
48 | 49 | if ok { |
49 | (*md)["Authorization"] = []string{generateAuthHeaderFromToken(token[0])} | |
50 | // capital "Key" is illegal in HTTP/2. | |
51 | (*md)["authorization"] = []string{generateAuthHeaderFromToken(token[0])} | |
50 | 52 | } |
51 | 53 | |
52 | 54 | return ctx |
11 | 11 | |
12 | 12 | func TestToGRPCContext(t *testing.T) { |
13 | 13 | md := metadata.MD{} |
14 | md["Authorization"] = []string{fmt.Sprintf("Bearer %s", signedKey)} | |
14 | md["authorization"] = []string{fmt.Sprintf("Bearer %s", signedKey)} | |
15 | 15 | ctx := context.Background() |
16 | 16 | reqFunc := jwt.ToGRPCContext() |
17 | 17 | |
18 | 18 | ctx = reqFunc(ctx, &md) |
19 | token, ok := ctx.Value("jwtToken").(string) | |
19 | token, ok := ctx.Value(jwt.JWTTokenContextKey).(string) | |
20 | 20 | if !ok { |
21 | 21 | t.Fatal("JWT Token not passed to context correctly") |
22 | 22 | } |
27 | 27 | } |
28 | 28 | |
29 | 29 | func TestFromGRPCContext(t *testing.T) { |
30 | ctx := context.WithValue(context.Background(), "jwtToken", signedKey) | |
30 | ctx := metadata.NewContext(context.Background(), metadata.MD{jwt.JWTTokenContextKey: []string{signedKey}}) | |
31 | 31 | |
32 | 32 | reqFunc := jwt.FromGRPCContext() |
33 | 33 | md := metadata.MD{} |
34 | 34 | reqFunc(ctx, &md) |
35 | token, ok := md["jwttoken"] | |
35 | ||
36 | token, ok := md["authorization"] | |
36 | 37 | if !ok { |
37 | 38 | t.Fatal("JWT Token not passed to metadata correctly") |
38 | 39 | } |
39 | 40 | |
40 | if token[0] != signedKey { | |
41 | if token[0] != generateAuthHeaderFromToken(signedKey) { | |
41 | 42 | t.Fatalf("JWT tokens did not match: expecting %s got %s", signedKey, token[0]) |
42 | 43 | } |
43 | 44 | } |
45 | ||
46 | func generateAuthHeaderFromToken(token string) string { | |
47 | return fmt.Sprintf("Bearer %s", token) | |
48 | } |