Codebase list golang-github-go-kit-kit / c9c7219
Merge pull request #488 from cam-stitt/jwt-claims Add *WithClaims methods to jwt middleware for more advanced usage. Peter Bourgon authored 7 years ago GitHub committed 7 years ago
2 changed file(s) with 85 addition(s) and 32 deletion(s). Raw diff Collapse all Expand all
4343 ErrUnexpectedSigningMethod = errors.New("unexpected signing method")
4444 )
4545
46 // Claims is a map of arbitrary claim data.
47 type Claims map[string]interface{}
48
4946 // NewSigner creates a new JWT token generating middleware, specifying key ID,
5047 // signing string, signing method and the claims you would like it to contain.
5148 // Tokens are signed with a Key ID header (kid) which is useful for determining
5249 // the key to use for parsing. Particularly useful for clients.
53 func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims Claims) endpoint.Middleware {
50 func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware {
5451 return func(next endpoint.Endpoint) endpoint.Endpoint {
5552 return func(ctx context.Context, request interface{}) (response interface{}, err error) {
56 token := jwt.NewWithClaims(method, jwt.MapClaims(claims))
53 token := jwt.NewWithClaims(method, claims)
5754 token.Header["kid"] = kid
5855
5956 // Sign and get the complete encoded token as a string using the secret
6966 }
7067
7168 // NewParser creates a new JWT token parsing middleware, specifying a
72 // jwt.Keyfunc interface and the signing method. NewParser adds the resulting
73 // claims to endpoint context or returns error on invalid token. Particularly
74 // useful for servers.
75 func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod) endpoint.Middleware {
69 // jwt.Keyfunc interface, the signing method and the claims type to be used. NewParser
70 // adds the resulting claims to endpoint context or returns error on invalid token.
71 // Particularly useful for servers.
72 func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware {
7673 return func(next endpoint.Endpoint) endpoint.Endpoint {
7774 return func(ctx context.Context, request interface{}) (response interface{}, err error) {
7875 // tokenString is stored in the context from the transport handlers.
8784 // of the token to identify which key to use, but the parsed token
8885 // (head and claims) is provided to the callback, providing
8986 // flexibility.
90 token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
87 token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
9188 // Don't forget to validate the alg is what you expect:
9289 if token.Method != method {
9390 return nil, ErrUnexpectedSigningMethod
118115 return nil, ErrTokenInvalid
119116 }
120117
121 if claims, ok := token.Claims.(jwt.MapClaims); ok {
122 ctx = context.WithValue(ctx, JWTClaimsContextKey, Claims(claims))
123 }
118 ctx = context.WithValue(ctx, JWTClaimsContextKey, token.Claims)
124119
125120 return next(ctx, request)
126121 }
33 "context"
44 "testing"
55
6 "crypto/subtle"
7
68 jwt "github.com/dgrijalva/jwt-go"
9 "github.com/go-kit/kit/endpoint"
710 )
811
12 type customClaims struct {
13 MyProperty string `json:"my_property"`
14 jwt.StandardClaims
15 }
16
17 func (c customClaims) VerifyMyProperty(p string) bool {
18 return subtle.ConstantTimeCompare([]byte(c.MyProperty), []byte(p)) != 0
19 }
20
921 var (
10 kid = "kid"
11 key = []byte("test_signing_key")
12 method = jwt.SigningMethodHS256
13 invalidMethod = jwt.SigningMethodRS256
14 claims = Claims{"user": "go-kit"}
22 kid = "kid"
23 key = []byte("test_signing_key")
24 myProperty = "some value"
25 method = jwt.SigningMethodHS256
26 invalidMethod = jwt.SigningMethodRS256
27 mapClaims = jwt.MapClaims{"user": "go-kit"}
28 standardClaims = jwt.StandardClaims{Audience: "go-kit"}
29 myCustomClaims = customClaims{MyProperty: myProperty, StandardClaims: standardClaims}
1530 // Signed tokens generated at https://jwt.io/
16 signedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E"
17 invalidKey = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.e30.vKVCKto-Wn6rgz3vBdaZaCBGfCBDTXOENSo_X2Gq7qA"
31 signedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E"
32 standardSignedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJnby1raXQifQ.L5ypIJjCOOv3jJ8G5SelaHvR04UJuxmcBN5QW3m_aoY"
33 customSignedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJteV9wcm9wZXJ0eSI6InNvbWUgdmFsdWUiLCJhdWQiOiJnby1raXQifQ.s8F-IDrV4WPJUsqr7qfDi-3GRlcKR0SRnkTeUT_U-i0"
34 invalidKey = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.e30.vKVCKto-Wn6rgz3vBdaZaCBGfCBDTXOENSo_X2Gq7qA"
1835 )
1936
20 func TestSigner(t *testing.T) {
21 e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil }
22
23 signer := NewSigner(kid, key, method, claims)(e)
37 func signingValidator(t *testing.T, signer endpoint.Endpoint, expectedKey string) {
2438 ctx, err := signer(context.Background(), struct{}{})
2539 if err != nil {
2640 t.Fatalf("Signer returned error: %s", err)
3145 t.Fatal("Token did not exist in context")
3246 }
3347
34 if token != signedKey {
35 t.Fatalf("JWT tokens did not match: expecting %s got %s", signedKey, token)
48 if token != expectedKey {
49 t.Fatalf("JWT tokens did not match: expecting %s got %s", expectedKey, token)
3650 }
51 }
52
53 func TestNewSigner(t *testing.T) {
54 e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil }
55
56 signer := NewSigner(kid, key, method, mapClaims)(e)
57 signingValidator(t, signer, signedKey)
58
59 signer = NewSigner(kid, key, method, standardClaims)(e)
60 signingValidator(t, signer, standardSignedKey)
61
62 signer = NewSigner(kid, key, method, myCustomClaims)(e)
63 signingValidator(t, signer, customSignedKey)
3764 }
3865
3966 func TestJWTParser(t *testing.T) {
4370 return key, nil
4471 }
4572
46 parser := NewParser(keys, method)(e)
73 parser := NewParser(keys, method, jwt.MapClaims{})(e)
4774
4875 // No Token is passed into the parser
4976 _, err := parser(context.Background(), struct{}{})
6390 }
6491
6592 // Invalid Method is used in the parser
66 badParser := NewParser(keys, invalidMethod)(e)
93 badParser := NewParser(keys, invalidMethod, jwt.MapClaims{})(e)
6794 ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
6895 _, err = badParser(ctx, struct{}{})
6996 if err == nil {
79106 return []byte("bad"), nil
80107 }
81108
82 badParser = NewParser(invalidKeys, method)(e)
109 badParser = NewParser(invalidKeys, method, jwt.MapClaims{})(e)
83110 ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
84111 _, err = badParser(ctx, struct{}{})
85112 if err == nil {
93120 t.Fatalf("Parser returned error: %s", err)
94121 }
95122
96 cl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(Claims)
123 cl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(jwt.MapClaims)
97124 if !ok {
98125 t.Fatal("Claims were not passed into context correctly")
99126 }
100127
101 if cl["user"] != claims["user"] {
102 t.Fatalf("JWT Claims.user did not match: expecting %s got %s", claims["user"], cl["user"])
128 if cl["user"] != mapClaims["user"] {
129 t.Fatalf("JWT Claims.user did not match: expecting %s got %s", mapClaims["user"], cl["user"])
130 }
131
132 parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
133 ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey)
134 ctx1, err = parser(ctx, struct{}{})
135 if err != nil {
136 t.Fatalf("Parser returned error: %s", err)
137 }
138 stdCl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(*jwt.StandardClaims)
139 if !ok {
140 t.Fatal("Claims were not passed into context correctly")
141 }
142 if !stdCl.VerifyAudience("go-kit", true) {
143 t.Fatalf("JWT jwt.StandardClaims.Audience did not match: expecting %s got %s", standardClaims.Audience, stdCl.Audience)
144 }
145
146 parser = NewParser(keys, method, &customClaims{})(e)
147 ctx = context.WithValue(context.Background(), JWTTokenContextKey, customSignedKey)
148 ctx1, err = parser(ctx, struct{}{})
149 if err != nil {
150 t.Fatalf("Parser returned error: %s", err)
151 }
152 custCl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(*customClaims)
153 if !ok {
154 t.Fatal("Claims were not passed into context correctly")
155 }
156 if !custCl.VerifyAudience("go-kit", true) {
157 t.Fatalf("JWT customClaims.Audience did not match: expecting %s got %s", standardClaims.Audience, custCl.Audience)
158 }
159 if !custCl.VerifyMyProperty(myProperty) {
160 t.Fatalf("JWT customClaims.MyProperty did not match: expecting %s got %s", myProperty, custCl.MyProperty)
103161 }
104162 }