Codebase list golang-github-go-kit-kit / 57c36a7 ratelimit / token_bucket.go
57c36a7

Tree @57c36a7 (Download .tar.gz)

token_bucket.go @57c36a7raw · history · blame

package ratelimit

import (
	"errors"
	"time"

	juju "github.com/juju/ratelimit"
	"golang.org/x/net/context"

	"github.com/go-kit/kit/endpoint"
)

// ErrLimited is returned in the request path when the rate limiter is
// triggered and the request is rejected. The endpoint produced by
// NewTokenBucketLimiter returns it, together with a nil response, instead of
// invoking the wrapped endpoint. It is a single package-level value, so
// callers can detect rate limiting by comparing a returned error against it.
var ErrLimited = errors.New("rate limit exceeded")

// NewTokenBucketLimiter returns an endpoint.Middleware that acts as a rate
// limiter based on a token-bucket algorithm. Requests that would exceed the
// maximum request rate are simply rejected with ErrLimited; the wrapped
// endpoint is not invoked for them.
//
// The defaults are a rate of 100 tokens per second, a capacity of 100 tokens
// and a cost of 1 token per request. Each can be overridden through options,
// which are applied in the order given. A single bucket is created per call
// and is shared by every endpoint the returned middleware wraps.
//
// A request is admitted only when its full cost can be paid immediately. A
// request that cannot be paid for in full is rejected and leaves the bucket
// untouched.
func NewTokenBucketLimiter(options ...TokenBucketLimiterOption) endpoint.Middleware {
	limiter := tokenBucketLimiter{
		rate:     100,
		capacity: 100,
		take:     1,
	}
	for _, option := range options {
		option(&limiter)
	}
	tb := juju.NewBucketWithRate(limiter.rate, limiter.capacity)
	return func(next endpoint.Endpoint) endpoint.Endpoint {
		return func(ctx context.Context, request interface{}) (interface{}, error) {
			// TakeAvailable hands back however many tokens happen to be left,
			// so testing its result against zero would admit a request that
			// costs N tokens while only 1..N-1 remain, and drain the bucket as
			// a side effect. A zero maxWait makes the take all-or-nothing.
			//
			// NOTE(review): confirm how the vendored juju/ratelimit treats a
			// non-positive count here; it may differ from TakeAvailable.
			if _, ok := tb.TakeMaxDuration(limiter.take, 0); !ok {
				return nil, ErrLimited
			}
			return next(ctx, request)
		}
	}
}

// tokenBucketLimiter collects the settings that the limiter options write to
// before the underlying token bucket is constructed.
type tokenBucketLimiter struct {
	// rate is the refill rate passed to the bucket constructor.
	rate float64
	// capacity is the bucket size passed to the bucket constructor; take is
	// the number of tokens requested from the bucket for each request.
	capacity, take int64
}

// TokenBucketLimiterOption sets a parameter on the TokenBucketLimiter. It is
// a function that mutates the limiter's settings in place.
// NewTokenBucketLimiter starts from its defaults and applies each option, in
// the order given, before the token bucket is constructed, so a later option
// overrides an earlier one that sets the same field.
type TokenBucketLimiterOption func(*tokenBucketLimiter)

// TokenBucketLimiterRate returns an option that overrides the rate, in
// tokens per second, at which the limiter's bucket is refilled. The rate
// defaults to 100 and, for most use cases, should match the capacity.
func TokenBucketLimiterRate(rate float64) TokenBucketLimiterOption {
	return func(l *tokenBucketLimiter) {
		l.rate = rate
	}
}

// TokenBucketLimiterCapacity returns an option that overrides the maximum
// number of tokens the limiter's bucket can hold. The capacity defaults to
// 100 and, for most use cases, should match the rate.
func TokenBucketLimiterCapacity(capacity int64) TokenBucketLimiterOption {
	return func(l *tokenBucketLimiter) {
		l.capacity = capacity
	}
}

// TokenBucketLimiterTake returns an option that overrides how many tokens
// each request takes from the limiter's bucket. The default is 1.
func TokenBucketLimiterTake(take int64) TokenBucketLimiterOption {
	return func(l *tokenBucketLimiter) {
		l.take = take
	}
}

// NewTokenBucketThrottler builds an endpoint.Middleware that throttles
// requests with a token-bucket algorithm. A request that would exceed the
// maximum request rate is not rejected: the delay reported by the bucket is
// handed to the configured sleep function, and the wrapped endpoint is
// invoked once that function returns.
//
// The defaults are a rate of 100, a capacity of 100, one token per request
// and time.Sleep as the sleep function. Options are applied in the order
// given before the bucket is created.
func NewTokenBucketThrottler(options ...TokenBucketThrottlerOption) endpoint.Middleware {
	cfg := tokenBucketThrottler{
		tokenBucketLimiter: tokenBucketLimiter{rate: 100, capacity: 100, take: 1},
		sleep:              time.Sleep,
	}
	for i := range options {
		options[i](&cfg)
	}

	// One bucket per middleware, shared by every endpoint it wraps.
	bucket := juju.NewBucketWithRate(cfg.rate, cfg.capacity)

	throttle := func(next endpoint.Endpoint) endpoint.Endpoint {
		return func(ctx context.Context, request interface{}) (interface{}, error) {
			delay := bucket.Take(cfg.take)
			cfg.sleep(delay)
			return next(ctx, request)
		}
	}
	return throttle
}

// tokenBucketThrottler collects the settings that the throttler options
// write to: the bucket parameters it shares with the limiter, plus the
// function used to delay a request.
type tokenBucketThrottler struct {
	// Embedded so that rate, capacity and take are accessible as promoted
	// fields of the throttler.
	tokenBucketLimiter
	// sleep is called with the duration returned by the bucket for each
	// request, before the wrapped endpoint is invoked.
	sleep func(time.Duration)
}

// TokenBucketThrottlerOption sets a parameter on the TokenBucketThrottler.
// It is a function that mutates the throttler's settings in place.
// NewTokenBucketThrottler starts from its defaults and applies each option,
// in the order given, before the token bucket is constructed, so a later
// option overrides an earlier one that sets the same field.
type TokenBucketThrottlerOption func(*tokenBucketThrottler)

// TokenBucketThrottlerRate returns an option that overrides the rate, in
// tokens per second, at which the throttler's bucket is refilled. The rate
// defaults to 100 and, for most use cases, should match the capacity.
func TokenBucketThrottlerRate(rate float64) TokenBucketThrottlerOption {
	return func(t *tokenBucketThrottler) {
		t.rate = rate
	}
}

// TokenBucketThrottlerCapacity returns an option that overrides the maximum
// number of tokens the throttler's bucket can hold. The capacity defaults to
// 100 and, for most use cases, should match the rate.
func TokenBucketThrottlerCapacity(capacity int64) TokenBucketThrottlerOption {
	return func(t *tokenBucketThrottler) {
		t.capacity = capacity
	}
}

// TokenBucketThrottlerTake returns an option that overrides how many tokens
// each request takes from the throttler's bucket. The default is 1.
func TokenBucketThrottlerTake(take int64) TokenBucketThrottlerOption {
	return func(t *tokenBucketThrottler) {
		t.take = take
	}
}

// TokenBucketThrottlerSleep returns an option that replaces the function the
// throttler calls to delay a request. The default is time.Sleep.
func TokenBucketThrottlerSleep(sleep func(time.Duration)) TokenBucketThrottlerOption {
	return func(t *tokenBucketThrottler) {
		t.sleep = sleep
	}
}