diff --git a/ratelimit/token_bucket.go b/ratelimit/token_bucket.go index d829cca..48a4f60 100644 --- a/ratelimit/token_bucket.go +++ b/ratelimit/token_bucket.go @@ -4,7 +4,7 @@ "errors" "time" - juju "github.com/juju/ratelimit" + "github.com/juju/ratelimit" "golang.org/x/net/context" "github.com/go-kit/kit/endpoint" @@ -17,19 +17,10 @@ // NewTokenBucketLimiter returns an endpoint.Middleware that acts as a rate // limiter based on a token-bucket algorithm. Requests that would exceed the // maximum request rate are simply rejected with an error. -func NewTokenBucketLimiter(options ...TokenBucketLimiterOption) endpoint.Middleware { - limiter := tokenBucketLimiter{ - rate: 100, - capacity: 100, - take: 1, - } - for _, option := range options { - option(&limiter) - } - tb := juju.NewBucketWithRate(limiter.rate, limiter.capacity) +func NewTokenBucketLimiter(tb *ratelimit.Bucket) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { - if tb.TakeAvailable(limiter.take) == 0 { + if tb.TakeAvailable(1) == 0 { return nil, ErrLimited } return next(ctx, request) @@ -37,90 +28,15 @@ } } -type tokenBucketLimiter struct { - rate float64 - capacity int64 - take int64 -} - -// TokenBucketLimiterOption sets a parameter on the TokenBucketLimiter. -type TokenBucketLimiterOption func(*tokenBucketLimiter) - -// TokenBucketLimiterRate sets the rate (per second) at which tokens are -// replenished into the bucket. For most use cases, this should be the same as -// the capacity. By default, the rate is 100. -func TokenBucketLimiterRate(rate float64) TokenBucketLimiterOption { - return func(tb *tokenBucketLimiter) { tb.rate = rate } -} - -// TokenBucketLimiterCapacity sets the maximum number of tokens that the -// bucket will hold. For most use cases, this should be the same as the rate. -// By default, the capacity is 100. -func TokenBucketLimiterCapacity(capacity int64) TokenBucketLimiterOption { - return func(tb *tokenBucketLimiter) { tb.capacity = capacity } -} - -// TokenBucketLimiterTake sets the number of tokens that will be taken from -// the bucket with each request. By default, this is 1. -func TokenBucketLimiterTake(take int64) TokenBucketLimiterOption { - return func(tb *tokenBucketLimiter) { tb.take = take } -} - // NewTokenBucketThrottler returns an endpoint.Middleware that acts as a // request throttler based on a token-bucket algorithm. Requests that would -// exceed the maximum request rate are delayed via a parameterized sleep -// function. -func NewTokenBucketThrottler(options ...TokenBucketThrottlerOption) endpoint.Middleware { - throttler := tokenBucketThrottler{ - tokenBucketLimiter: tokenBucketLimiter{ - rate: 100, - capacity: 100, - take: 1, - }, - sleep: time.Sleep, - } - for _, option := range options { - option(&throttler) - } - tb := juju.NewBucketWithRate(throttler.rate, throttler.capacity) +// exceed the maximum request rate are delayed via the parameterized sleep +// function. By default you may pass time.Sleep. +func NewTokenBucketThrottler(tb *ratelimit.Bucket, sleep func(time.Duration)) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { - throttler.sleep(tb.Take(throttler.take)) + sleep(tb.Take(1)) return next(ctx, request) } } } - -type tokenBucketThrottler struct { - tokenBucketLimiter - sleep func(time.Duration) -} - -// TokenBucketThrottlerOption sets a parameter on the TokenBucketThrottler. -type TokenBucketThrottlerOption func(*tokenBucketThrottler) - -// TokenBucketThrottlerRate sets the rate (per second) at which tokens are -// replenished into the bucket. For most use cases, this should be the same as -// the capacity. By default, the rate is 100. -func TokenBucketThrottlerRate(rate float64) TokenBucketThrottlerOption { - return func(tb *tokenBucketThrottler) { tb.rate = rate } -} - -// TokenBucketThrottlerCapacity sets the maximum number of tokens that the -// bucket will hold. For most use cases, this should be the same as the rate. -// By default, the capacity is 100. -func TokenBucketThrottlerCapacity(capacity int64) TokenBucketThrottlerOption { - return func(tb *tokenBucketThrottler) { tb.capacity = capacity } -} - -// TokenBucketThrottlerTake sets the number of tokens that will be taken from -// the bucket with each request. By default, this is 1. -func TokenBucketThrottlerTake(take int64) TokenBucketThrottlerOption { - return func(tb *tokenBucketThrottler) { tb.take = take } -} - -// TokenBucketThrottlerSleep sets the sleep function that's invoked to -// throttle requests. By default, it's time.Sleep. -func TokenBucketThrottlerSleep(sleep func(time.Duration)) TokenBucketThrottlerOption { - return func(tb *tokenBucketThrottler) { tb.sleep = sleep } -} diff --git a/ratelimit/token_bucket_test.go b/ratelimit/token_bucket_test.go index 96f2231..6f815de 100644 --- a/ratelimit/token_bucket_test.go +++ b/ratelimit/token_bucket_test.go @@ -5,6 +5,7 @@ "testing" "time" + jujuratelimit "github.com/juju/ratelimit" "golang.org/x/net/context" "github.com/go-kit/kit/endpoint" @@ -14,10 +15,8 @@ func TestTokenBucketLimiter(t *testing.T) { e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } for _, n := range []int{1, 2, 100} { - testLimiter(t, ratelimit.NewTokenBucketLimiter( - ratelimit.TokenBucketLimiterRate(float64(n)), - ratelimit.TokenBucketLimiterCapacity(int64(n)), - )(e), int(n)) + tb := jujuratelimit.NewBucketWithRate(float64(n), int64(n)) + testLimiter(t, ratelimit.NewTokenBucketLimiter(tb)(e), n) } } @@ -26,11 +25,7 @@ s := func(d0 time.Duration) { d = d0 } e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } - e = ratelimit.NewTokenBucketThrottler( - ratelimit.TokenBucketThrottlerRate(1), - ratelimit.TokenBucketThrottlerCapacity(1), - ratelimit.TokenBucketThrottlerSleep(s), - )(e) + e = ratelimit.NewTokenBucketThrottler(jujuratelimit.NewBucketWithRate(1, 1), s)(e) // First request should go through with no delay. e(context.Background(), struct{}{})