Codebase list golang-github-go-kit-kit / e32d2aa auth / jwt / middleware.go
e32d2aa

Tree @e32d2aa (Download .tar.gz)

middleware.go @e32d2aa · raw · history · blame

package jwt

import (
	"errors"

	"golang.org/x/net/context"

	jwt "github.com/dgrijalva/jwt-go"
	"github.com/go-kit/kit/endpoint"
)

// JWTTokenContextKey holds the key used to store a JWT Token in the
// context.
const JWTTokenContextKey = "JWTToken"

// JWTClaimsContextKey holds the key used to store the JWT Claims in the
// context.
const JWTClaimsContextKey = "JWTClaims"

// ErrTokenContextMissing is returned by the parser middleware when the
// context holds no token string under JWTTokenContextKey.
var ErrTokenContextMissing = errors.New("Token up for parsing was not passed through the context")

// ErrTokenInvalid is returned when a token parses without error but is not
// reported as valid.
var ErrTokenInvalid = errors.New("JWT Token was invalid")

// ErrUnexpectedSigningMethod is returned when the signing method of a token
// differs from the method registered for its key ID.
var ErrUnexpectedSigningMethod = errors.New("Unexpected signing method")

// ErrKIDNotFound is returned when a key ID has no entry in the KeySet.
var ErrKIDNotFound = errors.New("Key ID was not found in key set")

// ErrNoKIDHeader is returned when a token carries no 'kid' header.
var ErrNoKIDHeader = errors.New("Token doesn't have 'kid' header")

// Claims is the map of JWT claims that NewSigner embeds in the tokens it
// generates and that NewParser stores in the context, under
// JWTClaimsContextKey, after a token has been validated.
type Claims map[string]interface{}

// KeySet maps a key ID (the value of a token's "kid" header) to the signing
// method and key registered for that ID. Method is the signing method used
// to create tokens and the method a parsed token is required to match; Key
// is the key handed to the jwt library for signing and verification.
type KeySet map[string]struct {
	Method jwt.SigningMethod
	Key    []byte
}

// NewSigner returns a middleware that generates a signed JWT for every
// request and places the encoded token string in the context under
// JWTTokenContextKey before calling the next endpoint. The token carries the
// given claims and a "kid" header set to kid; it is signed with the method
// and key registered for kid in keys. Particularly useful for clients.
//
// The wrapped endpoint returns ErrKIDNotFound when keys has no entry for
// kid, and passes through any error reported while signing.
func NewSigner(kid string, keys KeySet, claims Claims) endpoint.Middleware {
	return func(next endpoint.Endpoint) endpoint.Endpoint {
		signAndForward := func(ctx context.Context, request interface{}) (interface{}, error) {
			entry, found := keys[kid]
			if !found {
				return nil, ErrKIDNotFound
			}

			unsigned := jwt.NewWithClaims(entry.Method, jwt.MapClaims(claims))
			unsigned.Header["kid"] = kid

			// Produce the complete encoded token, signed with the key's secret.
			signed, signErr := unsigned.SignedString(entry.Key)
			if signErr != nil {
				return nil, signErr
			}

			return next(context.WithValue(ctx, JWTTokenContextKey, signed), request)
		}
		return signAndForward
	}
}

// NewParser returns a middleware that validates the JWT found in the request
// context and, on success, adds its claims to the context before calling the
// next endpoint. Particularly useful for servers.
//
// The token string is read from the context under JWTTokenContextKey. The
// verification key is selected from keys using the token's "kid" header, and
// the token's signing method must equal the method registered for that key.
//
// The wrapped endpoint returns:
//   - ErrTokenContextMissing if the context holds no token string;
//   - ErrNoKIDHeader if the "kid" header is absent or is not a string;
//   - ErrKIDNotFound if keys has no entry for the key ID;
//   - ErrUnexpectedSigningMethod if the signing method does not match;
//   - ErrTokenInvalid if parsing succeeds but the token is not valid;
//   - otherwise the error from jwt.Parse, replaced by its inner error when it
//     is a *jwt.ValidationError that carries one.
func NewParser(keys KeySet) endpoint.Middleware {
	return func(next endpoint.Endpoint) endpoint.Endpoint {
		return func(ctx context.Context, request interface{}) (response interface{}, err error) {
			// tokenString is stored in the context from the transport handlers
			tokenString, ok := ctx.Value(JWTTokenContextKey).(string)
			if !ok {
				return nil, ErrTokenContextMissing
			}

			// Parse takes the token string and a function for looking up the key. The
			// parsed token (head and claims) is provided to the callback, which uses
			// the 'kid' header to pick the key out of the key set.
			token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
				// The header comes from the untrusted token, so 'kid' may be missing
				// or hold a non-string JSON value. Use a checked assertion so such a
				// token is rejected instead of panicking.
				kid, ok := token.Header["kid"].(string)
				if !ok {
					return nil, ErrNoKIDHeader
				}

				key, ok := keys[kid]
				if !ok {
					return nil, ErrKIDNotFound
				}

				// Don't forget to validate the alg is what you expect:
				if token.Method != key.Method {
					return nil, ErrUnexpectedSigningMethod
				}

				return key.Key, nil
			})
			if err != nil {
				// Surface the underlying cause rather than the ValidationError wrapper
				// so callers can compare against the sentinel errors above.
				if e, ok := err.(*jwt.ValidationError); ok && e.Inner != nil {
					return nil, e.Inner
				}

				return nil, err
			}

			if !token.Valid {
				return nil, ErrTokenInvalid
			}

			// Claims are only stored when they have the expected map form.
			if claims, ok := token.Claims.(jwt.MapClaims); ok {
				ctx = context.WithValue(ctx, JWTClaimsContextKey, Claims(claims))
			}

			return next(ctx, request)
		}
	}
}