diff --git a/auth/jwt/README.md b/auth/jwt/README.md new file mode 100644 index 0000000..bec4f67 --- /dev/null +++ b/auth/jwt/README.md @@ -0,0 +1,122 @@ +# package auth/jwt + +`package auth/jwt` provides a set of interfaces for service authorization +through [JSON Web Tokens](https://jwt.io/). + +## Usage + +NewParser takes a key function and an expected signing method and returns an +`endpoint.Middleware`. The middleware will parse a token passed into the +context via the `jwt.JWTTokenContextKey`. If the token is valid, any claims +will be added to the context via the `jwt.JWTClaimsContextKey`. + +```go +import ( + stdjwt "github.com/dgrijalva/jwt-go" + + "github.com/go-kit/kit/auth/jwt" + "github.com/go-kit/kit/endpoint" +) + +func main() { + var exampleEndpoint endpoint.Endpoint + { + kf := func(token *stdjwt.Token) (interface{}, error) { return []byte("SigningString"), nil } + exampleEndpoint = MakeExampleEndpoint(service) + exampleEndpoint = jwt.NewParser(kf, stdjwt.SigningMethodHS256)(exampleEndpoint) + } +} +``` + +NewSigner takes a JWT key ID header, the signing key, signing method, and a +claims object. It returns an `endpoint.Middleware`. The middleware will build +the token string and add it to the context via the `jwt.JWTTokenContextKey`. + +```go +import ( + stdjwt "github.com/dgrijalva/jwt-go" + + "github.com/go-kit/kit/auth/jwt" + "github.com/go-kit/kit/endpoint" +) + +func main() { + var exampleEndpoint endpoint.Endpoint + { + exampleEndpoint = grpctransport.NewClient(...).Endpoint() + exampleEndpoint = jwt.NewSigner( + "kid-header", + []byte("SigningString"), + stdjwt.SigningMethodHS256, + jwt.Claims{}, + )(exampleEndpoint) + } +} +``` + +In order for the parser and the signer to work, the authorization headers need +to be passed between the request and the context. `ToHTTPContext()`, +`FromHTTPContext()`, `ToGRPCContext()`, and `FromGRPCContext()` are given as +helpers to do this. These functions implement the correlating transport's +RequestFunc interface and can be passed as ClientBefore or ServerBefore +options. + +Example of use in a client: + +```go +import ( + stdjwt "github.com/dgrijalva/jwt-go" + + grpctransport "github.com/go-kit/kit/transport/grpc" + "github.com/go-kit/kit/auth/jwt" + "github.com/go-kit/kit/endpoint" +) + +func main() { + + options := []httptransport.ClientOption{} + var exampleEndpoint endpoint.Endpoint + { + exampleEndpoint = grpctransport.NewClient(..., grpctransport.ClientBefore(jwt.FromGRPCContext())).Endpoint() + exampleEndpoint = jwt.NewSigner( + "kid-header", + []byte("SigningString"), + stdjwt.SigningMethodHS256, + jwt.Claims{}, + )(exampleEndpoint) + } +} +``` + +Example of use in a server: + +```go +import ( + "golang.org/x/net/context" + + "github.com/go-kit/kit/auth/jwt" + "github.com/go-kit/kit/log" + grpctransport "github.com/go-kit/kit/transport/grpc" +) + +func MakeGRPCServer(ctx context.Context, endpoints Endpoints, logger log.Logger) pb.ExampleServer { + options := []grpctransport.ServerOption{grpctransport.ServerErrorLogger(logger)} + + return &grpcServer{ + createUser: grpctransport.NewServer( + ctx, + endpoints.CreateUserEndpoint, + DecodeGRPCCreateUserRequest, + EncodeGRPCCreateUserResponse, + append(options, grpctransport.ServerBefore(jwt.ToGRPCContext()))..., + ), + getUser: grpctransport.NewServer( + ctx, + endpoints.GetUserEndpoint, + DecodeGRPCGetUserRequest, + EncodeGRPCGetUserResponse, + options..., + ), + } +} +``` diff --git a/auth/jwt/middleware.go b/auth/jwt/middleware.go new file mode 100644 index 0000000..8b5f826 --- /dev/null +++ b/auth/jwt/middleware.go @@ -0,0 +1,122 @@ +package jwt + +import ( + "errors" + + jwt "github.com/dgrijalva/jwt-go" + "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" +) + +type contextKey string + +const ( + // JWTTokenContextKey holds the key used to store a JWT Token in the + // context. + JWTTokenContextKey contextKey = "JWTToken" + // JWTClaimsContxtKey holds the key used to store the JWT Claims in the + // context. + JWTClaimsContextKey contextKey = "JWTClaims" +) + +var ( + // ErrTokenContextMissing denotes a token was not passed into the parsing + // middleware's context. + ErrTokenContextMissing = errors.New("token up for parsing was not passed through the context") + // ErrTokenInvalid denotes a token was not able to be validated. + ErrTokenInvalid = errors.New("JWT Token was invalid") + // ErrTokenExpired denotes a token's expire header (exp) has since passed. + ErrTokenExpired = errors.New("JWT Token is expired") + // ErrTokenMalformed denotes a token was not formatted as a JWT token. + ErrTokenMalformed = errors.New("JWT Token is malformed") + // ErrTokenNotActive denotes a token's not before header (nbf) is in the + // future. + ErrTokenNotActive = errors.New("token is not valid yet") + // ErrUncesptedSigningMethod denotes a token was signed with an unexpected + // signing method. + ErrUnexpectedSigningMethod = errors.New("unexpected signing method") +) + +type Claims map[string]interface{} + +// NewSigner creates a new JWT token generating middleware, specifying key ID, +// signing string, signing method and the claims you would like it to contain. +// Tokens are signed with a Key ID header (kid) which is useful for determining +// the key to use for parsing. Particularly useful for clients. +func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims Claims) endpoint.Middleware { + return func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + token := jwt.NewWithClaims(method, jwt.MapClaims(claims)) + token.Header["kid"] = kid + + // Sign and get the complete encoded token as a string using the secret + tokenString, err := token.SignedString(key) + if err != nil { + return nil, err + } + ctx = context.WithValue(ctx, JWTTokenContextKey, tokenString) + + return next(ctx, request) + } + } +} + +// NewParser creates a new JWT token parsing middleware, specifying a +// jwt.Keyfunc interface and the signing method. NewParser adds the resulting +// claims to endpoint context or returns error on invalid token. Particularly +// useful for servers. +func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod) endpoint.Middleware { + return func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + // tokenString is stored in the context from the transport handlers. + tokenString, ok := ctx.Value(JWTTokenContextKey).(string) + if !ok { + return nil, ErrTokenContextMissing + } + + // Parse takes the token string and a function for looking up the + // key. The latter is especially useful if you use multiple keys + // for your application. The standard is to use 'kid' in the head + // of the token to identify which key to use, but the parsed token + // (head and claims) is provided to the callback, providing + // flexibility. + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + // Don't forget to validate the alg is what you expect: + if token.Method != method { + return nil, ErrUnexpectedSigningMethod + } + + return keyFunc(token) + }) + if err != nil { + if e, ok := err.(*jwt.ValidationError); ok && e.Inner != nil { + if e.Errors&jwt.ValidationErrorMalformed != 0 { + // Token is malformed + return nil, ErrTokenMalformed + } else if e.Errors&jwt.ValidationErrorExpired != 0 { + // Token is expired + return nil, ErrTokenExpired + } else if e.Errors&jwt.ValidationErrorNotValidYet != 0 { + // Token is not active yet + return nil, ErrTokenNotActive + } + + return nil, e.Inner + } + + return nil, err + } + + if !token.Valid { + return nil, ErrTokenInvalid + } + + if claims, ok := token.Claims.(jwt.MapClaims); ok { + ctx = context.WithValue(ctx, JWTClaimsContextKey, Claims(claims)) + } + + return next(ctx, request) + } + } +} diff --git a/auth/jwt/middleware_test.go b/auth/jwt/middleware_test.go new file mode 100644 index 0000000..46bae68 --- /dev/null +++ b/auth/jwt/middleware_test.go @@ -0,0 +1,106 @@ +package jwt + +import ( + "testing" + + jwt "github.com/dgrijalva/jwt-go" + + "golang.org/x/net/context" +) + +var ( + kid = "kid" + key = []byte("test_signing_key") + method = jwt.SigningMethodHS256 + invalidMethod = jwt.SigningMethodRS256 + claims = Claims{"user": "go-kit"} + // Signed tokens generated at https://jwt.io/ + signedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E" + invalidKey = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.e30.vKVCKto-Wn6rgz3vBdaZaCBGfCBDTXOENSo_X2Gq7qA" +) + +func TestSigner(t *testing.T) { + e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil } + + signer := NewSigner(kid, key, method, claims)(e) + ctx, err := signer(context.Background(), struct{}{}) + if err != nil { + t.Fatalf("Signer returned error: %s", err) + } + + token, ok := ctx.(context.Context).Value(JWTTokenContextKey).(string) + if !ok { + t.Fatal("Token did not exist in context") + } + + if token != signedKey { + t.Fatalf("JWT tokens did not match: expecting %s got %s", signedKey, token) + } +} + +func TestJWTParser(t *testing.T) { + e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil } + + keys := func(token *jwt.Token) (interface{}, error) { + return key, nil + } + + parser := NewParser(keys, method)(e) + + // No Token is passed into the parser + _, err := parser(context.Background(), struct{}{}) + if err == nil { + t.Error("Parser should have returned an error") + } + + if err != ErrTokenContextMissing { + t.Errorf("unexpected error returned, expected: %s got: %s", ErrTokenContextMissing, err) + } + + // Invalid Token is passed into the parser + ctx := context.WithValue(context.Background(), JWTTokenContextKey, invalidKey) + _, err = parser(ctx, struct{}{}) + if err == nil { + t.Error("Parser should have returned an error") + } + + // Invalid Method is used in the parser + badParser := NewParser(keys, invalidMethod)(e) + ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) + _, err = badParser(ctx, struct{}{}) + if err == nil { + t.Error("Parser should have returned an error") + } + + if err != ErrUnexpectedSigningMethod { + t.Errorf("unexpected error returned, expected: %s got: %s", ErrUnexpectedSigningMethod, err) + } + + // Invalid key is used in the parser + invalidKeys := func(token *jwt.Token) (interface{}, error) { + return []byte("bad"), nil + } + + badParser = NewParser(invalidKeys, method)(e) + ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) + _, err = badParser(ctx, struct{}{}) + if err == nil { + t.Error("Parser should have returned an error") + } + + // Correct token is passed into the parser + ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) + ctx1, err := parser(ctx, struct{}{}) + if err != nil { + t.Fatalf("Parser returned error: %s", err) + } + + cl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(Claims) + if !ok { + t.Fatal("Claims were not passed into context correctly") + } + + if cl["user"] != claims["user"] { + t.Fatalf("JWT Claims.user did not match: expecting %s got %s", claims["user"], cl["user"]) + } +} diff --git a/auth/jwt/transport.go b/auth/jwt/transport.go new file mode 100644 index 0000000..e496993 --- /dev/null +++ b/auth/jwt/transport.go @@ -0,0 +1,89 @@ +package jwt + +import ( + "fmt" + stdhttp "net/http" + "strings" + + "golang.org/x/net/context" + "google.golang.org/grpc/metadata" + + "github.com/go-kit/kit/transport/grpc" + "github.com/go-kit/kit/transport/http" +) + +const ( + bearer string = "bearer" + bearerFormat string = "Bearer %s" +) + +// ToHTTPContext moves JWT token from request header to contexti. Particularly +// useful for servers +func ToHTTPContext() http.RequestFunc { + return func(ctx context.Context, r *stdhttp.Request) context.Context { + token, ok := extractTokenFromAuthHeader(r.Header.Get("Authorization")) + if !ok { + return ctx + } + + return context.WithValue(ctx, JWTTokenContextKey, token) + } +} + +// FromHTTPContext moves JWT token from context to request header. Particularly +// useful for clients +func FromHTTPContext() http.RequestFunc { + return func(ctx context.Context, r *stdhttp.Request) context.Context { + token, ok := ctx.Value(JWTTokenContextKey).(string) + if ok { + r.Header.Add("Authorization", generateAuthHeaderFromToken(token)) + } + return ctx + } +} + +// ToGRPCContext moves JWT token from grpc metadata to context. Particularly +// userful for servers +func ToGRPCContext() grpc.RequestFunc { + return func(ctx context.Context, md *metadata.MD) context.Context { + // capital "Key" is illegal in HTTP/2. + authHeader, ok := (*md)["authorization"] + if !ok { + return ctx + } + + token, ok := extractTokenFromAuthHeader(authHeader[0]) + if ok { + ctx = context.WithValue(ctx, JWTTokenContextKey, token) + } + + return ctx + } +} + +// FromGRPCContext moves JWT token from context to grpc metadata. Particularly +// useful for clients +func FromGRPCContext() grpc.RequestFunc { + return func(ctx context.Context, md *metadata.MD) context.Context { + token, ok := ctx.Value(JWTTokenContextKey).(string) + if ok { + // capital "Key" is illegal in HTTP/2. + (*md)["authorization"] = []string{generateAuthHeaderFromToken(token)} + } + + return ctx + } +} + +func extractTokenFromAuthHeader(val string) (token string, ok bool) { + authHeaderParts := strings.Split(val, " ") + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != bearer { + return "", false + } + + return authHeaderParts[1], true +} + +func generateAuthHeaderFromToken(token string) string { + return fmt.Sprintf(bearerFormat, token) +} diff --git a/auth/jwt/transport_test.go b/auth/jwt/transport_test.go new file mode 100644 index 0000000..829d87f --- /dev/null +++ b/auth/jwt/transport_test.go @@ -0,0 +1,126 @@ +package jwt + +import ( + "fmt" + "net/http" + "testing" + + "google.golang.org/grpc/metadata" + + "golang.org/x/net/context" +) + +func TestToHTTPContext(t *testing.T) { + reqFunc := ToHTTPContext() + + // When the header doesn't exist + ctx := reqFunc(context.Background(), &http.Request{}) + + if ctx.Value(JWTTokenContextKey) != nil { + t.Error("Context shouldn't contain the encoded JWT") + } + + // Authorization header value has invalid format + header := http.Header{} + header.Set("Authorization", "no expected auth header format value") + ctx = reqFunc(context.Background(), &http.Request{Header: header}) + + if ctx.Value(JWTTokenContextKey) != nil { + t.Error("Context shouldn't contain the encoded JWT") + } + + // Authorization header is correct + header.Set("Authorization", generateAuthHeaderFromToken(signedKey)) + ctx = reqFunc(context.Background(), &http.Request{Header: header}) + + token := ctx.Value(JWTTokenContextKey).(string) + if token != signedKey { + t.Errorf("Context doesn't contain the expected encoded token value; expected: %s, got: %s", signedKey, token) + } +} + +func TestFromHTTPContext(t *testing.T) { + reqFunc := FromHTTPContext() + + // No JWT Token is passed in the context + ctx := context.Background() + r := http.Request{} + reqFunc(ctx, &r) + + token := r.Header.Get("Authorization") + if token != "" { + t.Error("authorization key should not exist in metadata") + } + + // Correct JWT Token is passed in the context + ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) + r = http.Request{Header: http.Header{}} + reqFunc(ctx, &r) + + token = r.Header.Get("Authorization") + expected := generateAuthHeaderFromToken(signedKey) + + if token != expected { + t.Errorf("Authorization header does not contain the expected JWT token; expected %s, got %s", expected, token) + } +} + +func TestToGRPCContext(t *testing.T) { + md := metadata.MD{} + reqFunc := ToGRPCContext() + + // No Authorization header is passed + ctx := reqFunc(context.Background(), &md) + token := ctx.Value(JWTTokenContextKey) + if token != nil { + t.Error("Context should not contain a JWT Token") + } + + // Invalid Authorization header is passed + md["authorization"] = []string{fmt.Sprintf("%s", signedKey)} + ctx = reqFunc(context.Background(), &md) + token = ctx.Value(JWTTokenContextKey) + if token != nil { + t.Error("Context should not contain a JWT Token") + } + + // Authorization header is correct + md["authorization"] = []string{fmt.Sprintf("Bearer %s", signedKey)} + ctx = reqFunc(context.Background(), &md) + token, ok := ctx.Value(JWTTokenContextKey).(string) + if !ok { + t.Fatal("JWT Token not passed to context correctly") + } + + if token != signedKey { + t.Errorf("JWT tokens did not match: expecting %s got %s", signedKey, token) + } +} + +func TestFromGRPCContext(t *testing.T) { + reqFunc := FromGRPCContext() + + // No JWT Token is passed in the context + ctx := context.Background() + md := metadata.MD{} + reqFunc(ctx, &md) + + _, ok := md["authorization"] + if ok { + t.Error("authorization key should not exist in metadata") + } + + // Correct JWT Token is passed in the context + ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) + md = metadata.MD{} + reqFunc(ctx, &md) + + token, ok := md["authorization"] + if !ok { + t.Fatal("JWT Token not passed to metadata correctly") + } + + if token[0] != generateAuthHeaderFromToken(signedKey) { + t.Errorf("JWT tokens did not match: expecting %s got %s", signedKey, token[0]) + } +}