Refactor away from passing a function to passing a struct with multiple options for signing keys
Brian Kassouf
7 years ago
1 | 1 | |
2 | 2 | import ( |
3 | 3 | "errors" |
4 | "fmt" | |
5 | "reflect" | |
6 | 4 | |
7 | 5 | "golang.org/x/net/context" |
8 | 6 | |
17 | 15 | JWTClaimsContextKey = "JWTClaims" |
18 | 16 | ) |
19 | 17 | |
18 | var ( | |
19 | ErrTokenContextMissing = errors.New("Token up for parsing was not passed through the context") | |
20 | ErrTokenInvalid = errors.New("JWT Token was invalid") | |
21 | ErrUnexpectedSigningMethod = errors.New("Unexptected signing method") | |
22 | ErrKIDNotFound = errors.New("Key ID was not found in key set") | |
23 | ErrNoKIDHeader = errors.New("Token doesn't have 'kid' header") | |
24 | ) | |
25 | ||
26 | type Claims map[string]interface{} | |
27 | ||
28 | type KeySet map[string]struct { | |
29 | Method jwt.SigningMethod | |
30 | Key []byte | |
31 | } | |
32 | ||
20 | 33 | // Create a new JWT token generating middleware, specifying signing method and the claims |
21 | 34 | // you would like it to contain. Particularly useful for clients. |
22 | func NewSigner(key string, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware { | |
35 | func NewSigner(kid string, keys KeySet, claims Claims) endpoint.Middleware { | |
23 | 36 | return func(next endpoint.Endpoint) endpoint.Endpoint { |
24 | 37 | return func(ctx context.Context, request interface{}) (response interface{}, err error) { |
25 | token := jwt.NewWithClaims(method, claims) | |
38 | key, ok := keys[kid] | |
39 | if !ok { | |
40 | return nil, ErrKIDNotFound | |
41 | } | |
26 | 42 | |
43 | token := jwt.NewWithClaims(key.Method, jwt.MapClaims(claims)) | |
44 | token.Header["kid"] = kid | |
27 | 45 | // Sign and get the complete encoded token as a string using the secret |
28 | tokenString, err := token.SignedString([]byte(key)) | |
46 | tokenString, err := token.SignedString(key.Key) | |
29 | 47 | if err != nil { |
30 | 48 | return nil, err |
31 | 49 | } |
39 | 57 | // Create a new JWT token parsing middleware, specifying a jwt.Keyfunc interface and the |
40 | 58 | // signing method. Adds the resulting claims to endpoint context or returns error on invalid |
41 | 59 | // token. Particularly useful for servers. |
42 | func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod) endpoint.Middleware { | |
60 | func NewParser(keys KeySet) endpoint.Middleware { | |
43 | 61 | return func(next endpoint.Endpoint) endpoint.Endpoint { |
44 | 62 | return func(ctx context.Context, request interface{}) (response interface{}, err error) { |
45 | 63 | // tokenString is stored in the context from the transport handlers |
46 | 64 | tokenString, ok := ctx.Value(JWTTokenContextKey).(string) |
47 | 65 | if !ok { |
48 | return nil, errors.New("Token up for parsing was not passed through the context") | |
66 | return nil, ErrTokenContextMissing | |
49 | 67 | } |
50 | 68 | |
51 | 69 | // Parse takes the token string and a function for looking up the key. The latter is especially |
53 | 71 | // head of the token to identify which key to use, but the parsed token (head and claims) is provided |
54 | 72 | // to the callback, providing flexibility. |
55 | 73 | token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { |
74 | kid, ok := token.Header["kid"] | |
75 | if !ok { | |
76 | return nil, ErrNoKIDHeader | |
77 | } | |
78 | ||
79 | key, ok := keys[kid.(string)] | |
80 | if !ok { | |
81 | return nil, ErrKIDNotFound | |
82 | } | |
83 | ||
56 | 84 | // Don't forget to validate the alg is what you expect: |
57 | if reflect.TypeOf(token.Method) != reflect.TypeOf(method) { | |
58 | return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) | |
85 | if token.Method != key.Method { | |
86 | return nil, ErrUnexpectedSigningMethod | |
59 | 87 | } |
60 | return keyFunc(token) | |
88 | ||
89 | return key.Key, nil | |
61 | 90 | }) |
62 | 91 | if err != nil { |
63 | 92 | return nil, err |
64 | 93 | } |
65 | 94 | |
66 | 95 | if !token.Valid { |
67 | return nil, errors.New("Could not parse JWT Token") | |
96 | return nil, ErrTokenInvalid | |
68 | 97 | } |
69 | 98 | |
70 | 99 | if claims, ok := token.Claims.(jwt.MapClaims); ok { |
71 | ctx = context.WithValue(ctx, JWTClaimsContextKey, claims) | |
100 | ctx = context.WithValue(ctx, JWTClaimsContextKey, Claims(claims)) | |
72 | 101 | } |
73 | 102 | |
74 | 103 | return next(ctx, request) |