Codebase list golang-github-go-kit-kit / c84463f auth / jwt / middleware.go
c84463f

Tree @c84463f (Download .tar.gz)

middleware.go @c84463f · raw · history · blame

package jwt

import (
	"errors"

	"golang.org/x/net/context"

	jwt "github.com/dgrijalva/jwt-go"
	"github.com/go-kit/kit/endpoint"
)

// contextKey is an unexported key type for values this package stores in a
// context.Context. Using a distinct type rather than a plain string keeps these
// keys from colliding with keys defined by other packages.
type contextKey string

const (
	// JWTTokenContextKey holds the key used to store a JWT Token in the context.
	JWTTokenContextKey contextKey = "JWTToken"
	// JWTClaimsContextKey holds the key used to store the JWT Claims in the context.
	JWTClaimsContextKey contextKey = "JWTClaims"
)

// Sentinel errors returned by the middlewares in this package. Their messages
// are kept exactly as originally published, since callers may match on them.
var (
	// ErrTokenContextMissing is returned by the parsing middleware when no
	// token string is found in the context under JWTTokenContextKey.
	ErrTokenContextMissing = errors.New("Token up for parsing was not passed through the context")

	// ErrTokenInvalid is returned when a parsed token is not marked valid.
	ErrTokenInvalid = errors.New("JWT Token was invalid")

	// ErrTokenExpired is returned when token validation reports an expired token.
	ErrTokenExpired = errors.New("JWT Token is expired")

	// ErrTokenMalformed is returned when token validation reports a malformed token.
	ErrTokenMalformed = errors.New("JWT Token is malformed")

	// ErrTokenNotActive is returned when token validation reports a token that
	// is not valid yet.
	ErrTokenNotActive = errors.New("Token is not valid yet")

	// ErrUnexpectedSigningMethod is returned when a token's signing method does
	// not match the one the parsing middleware was configured with.
	ErrUnexpectedSigningMethod = errors.New("Unexpected signing method")
)

// Claims holds JWT claims keyed by claim name. NewSigner converts it to
// jwt.MapClaims when building tokens, and NewParser stores a parsed token's
// jwt.MapClaims in the context as a Claims value under JWTClaimsContextKey.
type Claims map[string]interface{}

// NewSigner creates a new JWT token generating middleware, specifying the key
// ID, signing key, signing method and the claims you would like the token to
// contain. Particularly useful for clients.
//
// On every request the middleware builds a token from claims, sets its "kid"
// header to kid, and signs it with key. The encoded token string is stored in
// the context under JWTTokenContextKey before next is called. If signing
// fails, the error is returned and next is not called.
//
// The claims map is converted (not copied) to jwt.MapClaims and is read on
// every request, so it must not be mutated after NewSigner is called.
//
// NOTE(review): key is handed to SignedString as a []byte, which presumably
// restricts this middleware to HMAC signing methods; confirm before using
// RSA/ECDSA methods.
func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims Claims) endpoint.Middleware {
	return func(next endpoint.Endpoint) endpoint.Endpoint {
		return func(ctx context.Context, request interface{}) (response interface{}, err error) {
			token := jwt.NewWithClaims(method, jwt.MapClaims(claims))
			token.Header["kid"] = kid

			// Sign and get the complete encoded token as a string using the key.
			tokenString, err := token.SignedString(key)
			if err != nil {
				return nil, err
			}
			ctx = context.WithValue(ctx, JWTTokenContextKey, tokenString)

			return next(ctx, request)
		}
	}
}

// NewParser creates a new JWT token parsing middleware, specifying a
// jwt.Keyfunc and the expected signing method. Particularly useful for servers.
//
// The middleware reads the raw token string from the context under
// JWTTokenContextKey and returns ErrTokenContextMissing if it is absent or not
// a string. A token whose signing method differs from method is rejected with
// ErrUnexpectedSigningMethod before keyFunc is consulted. jwt-go validation
// failures are reported as ErrTokenMalformed, ErrTokenExpired or
// ErrTokenNotActive when the corresponding flag is set. Other failures are
// reported as the wrapped inner error if there is one, and as the raw parse
// error otherwise. On success, the token's claims are stored in the context
// under JWTClaimsContextKey as a Claims value and next is called.
func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod) endpoint.Middleware {
	return func(next endpoint.Endpoint) endpoint.Endpoint {
		return func(ctx context.Context, request interface{}) (response interface{}, err error) {
			// tokenString is stored in the context by the transport handlers.
			tokenString, ok := ctx.Value(JWTTokenContextKey).(string)
			if !ok {
				return nil, ErrTokenContextMissing
			}

			// Parse takes the token string and a function for looking up the key.
			// The latter is especially useful if you use multiple keys for your
			// application. The standard is to use 'kid' in the head of the token to
			// identify which key to use, but the parsed token (head and claims) is
			// provided to the callback, providing flexibility.
			token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
				// Validate the alg is what we expect, so a token cannot pick a
				// different algorithm than the one this server is configured for.
				if token.Method != method {
					return nil, ErrUnexpectedSigningMethod
				}

				return keyFunc(token)
			})
			if err != nil {
				if e, ok := err.(*jwt.ValidationError); ok {
					// Check the flag bits before Inner: some ValidationErrors (for
					// example a token with the wrong number of segments) may carry
					// only flags and a nil Inner. They must still map to our errors.
					switch {
					case e.Errors&jwt.ValidationErrorMalformed != 0:
						return nil, ErrTokenMalformed
					case e.Errors&jwt.ValidationErrorExpired != 0:
						return nil, ErrTokenExpired
					case e.Errors&jwt.ValidationErrorNotValidYet != 0:
						return nil, ErrTokenNotActive
					case e.Inner != nil:
						// Surfaces e.g. ErrUnexpectedSigningMethod or a keyFunc error.
						return nil, e.Inner
					}
					// No specific error applies; fall through to the original error.
				}

				return nil, err
			}

			if !token.Valid {
				return nil, ErrTokenInvalid
			}

			if claims, ok := token.Claims.(jwt.MapClaims); ok {
				ctx = context.WithValue(ctx, JWTClaimsContextKey, Claims(claims))
			}

			return next(ctx, request)
		}
	}
}