diff --git a/ratelimit/token_bucket.go b/ratelimit/token_bucket.go index f4bb16d..d829cca 100644 --- a/ratelimit/token_bucket.go +++ b/ratelimit/token_bucket.go @@ -4,64 +4,123 @@ "errors" "time" - "github.com/tsenart/tb" + juju "github.com/juju/ratelimit" "golang.org/x/net/context" "github.com/go-kit/kit/endpoint" ) -// ErrThrottled is returned in the request path when the rate limiter is +// ErrLimited is returned in the request path when the rate limiter is // triggered and the request is rejected. -var ErrThrottled = errors.New("throttled") +var ErrLimited = errors.New("rate limit exceeded") -// NewTokenBucketThrottler returns an endpoint.Middleware that acts as a rate -// limiter based on a "token-bucket" algorithm. Requests that would exceed the -// maximum request rate are rejected with an error. -func NewTokenBucketThrottler(options ...TokenBucketOption) endpoint.Middleware { - t := tokenBucketThrottler{ - freq: 100 * time.Millisecond, - key: "", - rate: 100, - take: 1, +// NewTokenBucketLimiter returns an endpoint.Middleware that acts as a rate +// limiter based on a token-bucket algorithm. Requests that would exceed the +// maximum request rate are simply rejected with an error. +func NewTokenBucketLimiter(options ...TokenBucketLimiterOption) endpoint.Middleware { + limiter := tokenBucketLimiter{ + rate: 100, + capacity: 100, + take: 1, } for _, option := range options { - option(&t) + option(&limiter) } - throttler := tb.NewThrottler(t.freq) + tb := juju.NewBucketWithRate(limiter.rate, limiter.capacity) return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { - if throttler.Halt(t.key, t.take, t.rate) { - return nil, ErrThrottled + if tb.TakeAvailable(limiter.take) == 0 { + return nil, ErrLimited } return next(ctx, request) } } } -type tokenBucketThrottler struct { - freq time.Duration - key string - rate int64 - take int64 +type tokenBucketLimiter struct { + rate float64 + capacity int64 + take int64 } -// TokenBucketOption sets an option on the token bucket throttler. -type TokenBucketOption func(*tokenBucketThrottler) +// TokenBucketLimiterOption sets a parameter on the TokenBucketLimiter. +type TokenBucketLimiterOption func(*tokenBucketLimiter) -// TokenBucketFillFrequency sets the interval at which tokens are replenished -// into the bucket. By default, it's 100 milliseconds. -func TokenBucketFillFrequency(freq time.Duration) TokenBucketOption { - return func(t *tokenBucketThrottler) { t.freq = freq } +// TokenBucketLimiterRate sets the rate (per second) at which tokens are +// replenished into the bucket. For most use cases, this should be the same as +// the capacity. By default, the rate is 100. +func TokenBucketLimiterRate(rate float64) TokenBucketLimiterOption { + return func(tb *tokenBucketLimiter) { tb.rate = rate } } -// TokenBucketMaxRate sets the maximum allowed request rate. -// By default, it's 100. -func TokenBucketMaxRate(rate int64) TokenBucketOption { - return func(t *tokenBucketThrottler) { t.rate = rate } +// TokenBucketLimiterCapacity sets the maximum number of tokens that the +// bucket will hold. For most use cases, this should be the same as the rate. +// By default, the capacity is 100. +func TokenBucketLimiterCapacity(capacity int64) TokenBucketLimiterOption { + return func(tb *tokenBucketLimiter) { tb.capacity = capacity } } -// TokenBucketTake sets the number of tokens taken with each request. -// By default, it's 1. -func TokenBucketTake(take int64) TokenBucketOption { - return func(t *tokenBucketThrottler) { t.take = take } +// TokenBucketLimiterTake sets the number of tokens that will be taken from +// the bucket with each request. By default, this is 1. +func TokenBucketLimiterTake(take int64) TokenBucketLimiterOption { + return func(tb *tokenBucketLimiter) { tb.take = take } } + +// NewTokenBucketThrottler returns an endpoint.Middleware that acts as a +// request throttler based on a token-bucket algorithm. Requests that would +// exceed the maximum request rate are delayed via a parameterized sleep +// function. +func NewTokenBucketThrottler(options ...TokenBucketThrottlerOption) endpoint.Middleware { + throttler := tokenBucketThrottler{ + tokenBucketLimiter: tokenBucketLimiter{ + rate: 100, + capacity: 100, + take: 1, + }, + sleep: time.Sleep, + } + for _, option := range options { + option(&throttler) + } + tb := juju.NewBucketWithRate(throttler.rate, throttler.capacity) + return func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + throttler.sleep(tb.Take(throttler.take)) + return next(ctx, request) + } + } +} + +type tokenBucketThrottler struct { + tokenBucketLimiter + sleep func(time.Duration) +} + +// TokenBucketThrottlerOption sets a parameter on the TokenBucketThrottler. +type TokenBucketThrottlerOption func(*tokenBucketThrottler) + +// TokenBucketThrottlerRate sets the rate (per second) at which tokens are +// replenished into the bucket. For most use cases, this should be the same as +// the capacity. By default, the rate is 100. +func TokenBucketThrottlerRate(rate float64) TokenBucketThrottlerOption { + return func(tb *tokenBucketThrottler) { tb.rate = rate } +} + +// TokenBucketThrottlerCapacity sets the maximum number of tokens that the +// bucket will hold. For most use cases, this should be the same as the rate. +// By default, the capacity is 100. +func TokenBucketThrottlerCapacity(capacity int64) TokenBucketThrottlerOption { + return func(tb *tokenBucketThrottler) { tb.capacity = capacity } +} + +// TokenBucketThrottlerTake sets the number of tokens that will be taken from +// the bucket with each request. By default, this is 1. +func TokenBucketThrottlerTake(take int64) TokenBucketThrottlerOption { + return func(tb *tokenBucketThrottler) { tb.take = take } +} + +// TokenBucketThrottlerSleep sets the sleep function that's invoked to +// throttle requests. By default, it's time.Sleep. +func TokenBucketThrottlerSleep(sleep func(time.Duration)) TokenBucketThrottlerOption { + return func(tb *tokenBucketThrottler) { tb.sleep = sleep } +} diff --git a/ratelimit/token_bucket_test.go b/ratelimit/token_bucket_test.go index 36a27a6..96f2231 100644 --- a/ratelimit/token_bucket_test.go +++ b/ratelimit/token_bucket_test.go @@ -1,7 +1,9 @@ package ratelimit_test import ( + "math" "testing" + "time" "golang.org/x/net/context" @@ -9,21 +11,50 @@ "github.com/go-kit/kit/ratelimit" ) -func TestTokenBucketThrottler(t *testing.T) { +func TestTokenBucketLimiter(t *testing.T) { e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } - testRateLimit(t, ratelimit.NewTokenBucketThrottler(ratelimit.TokenBucketMaxRate(0))(e), 0) // all fail - testRateLimit(t, ratelimit.NewTokenBucketThrottler(ratelimit.TokenBucketMaxRate(1))(e), 1) // first pass - testRateLimit(t, ratelimit.NewTokenBucketThrottler(ratelimit.TokenBucketMaxRate(100))(e), 100) // 100 pass + for _, n := range []int{1, 2, 100} { + testLimiter(t, ratelimit.NewTokenBucketLimiter( + ratelimit.TokenBucketLimiterRate(float64(n)), + ratelimit.TokenBucketLimiterCapacity(int64(n)), + )(e), int(n)) + } } -func testRateLimit(t *testing.T, e endpoint.Endpoint, rate int) { - ctx := context.Background() +func TestTokenBucketThrottler(t *testing.T) { + d := time.Duration(0) + s := func(d0 time.Duration) { d = d0 } + + e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } + e = ratelimit.NewTokenBucketThrottler( + ratelimit.TokenBucketThrottlerRate(1), + ratelimit.TokenBucketThrottlerCapacity(1), + ratelimit.TokenBucketThrottlerSleep(s), + )(e) + + // First request should go through with no delay. + e(context.Background(), struct{}{}) + if want, have := time.Duration(0), d; want != have { + t.Errorf("want %s, have %s", want, have) + } + + // Next request should request a ~1s sleep. + e(context.Background(), struct{}{}) + if want, have, tol := time.Second, d, time.Millisecond; math.Abs(float64(want-have)) > float64(tol) { + t.Errorf("want %s, have %s", want, have) + } +} + +func testLimiter(t *testing.T, e endpoint.Endpoint, rate int) { + // First requests should succeed. for i := 0; i < rate; i++ { - if _, err := e(ctx, struct{}{}); err != nil { + if _, err := e(context.Background(), struct{}{}); err != nil { t.Fatalf("rate=%d: request %d/%d failed: %v", rate, i+1, rate, err) } } - if _, err := e(ctx, struct{}{}); err != ratelimit.ErrThrottled { - t.Errorf("rate=%d: want %v, have %v", rate, ratelimit.ErrThrottled, err) + + // Next request should fail. + if _, err := e(context.Background(), struct{}{}); err != ratelimit.ErrLimited { + t.Errorf("rate=%d: want %v, have %v", rate, ratelimit.ErrLimited, err) } }