diff --git a/circuitbreaker/gobreaker.go b/circuitbreaker/gobreaker.go index f3a210d..b00de95 100644 --- a/circuitbreaker/gobreaker.go +++ b/circuitbreaker/gobreaker.go @@ -12,8 +12,7 @@ // the wrapped endpoint count against the circuit breaker's error count. // // See http://godoc.org/github.com/sony/gobreaker for more information. -func Gobreaker(settings gobreaker.Settings) endpoint.Middleware { - cb := gobreaker.NewCircuitBreaker(settings) +func Gobreaker(cb *gobreaker.CircuitBreaker) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { return cb.Execute(func() (interface{}, error) { return next(ctx, request) }) diff --git a/circuitbreaker/gobreaker_test.go b/circuitbreaker/gobreaker_test.go index 6425e61..448c6fb 100644 --- a/circuitbreaker/gobreaker_test.go +++ b/circuitbreaker/gobreaker_test.go @@ -10,7 +10,7 @@ func TestGobreaker(t *testing.T) { var ( - breaker = circuitbreaker.Gobreaker(gobreaker.Settings{}) + breaker = circuitbreaker.Gobreaker(gobreaker.NewCircuitBreaker(gobreaker.Settings{})) primeWith = 100 shouldPass = func(n int) bool { return n <= 5 } // https://github.com/sony/gobreaker/blob/bfa846d/gobreaker.go#L76 circuitOpenError = "circuit breaker is open" diff --git a/circuitbreaker/handy_breaker.go b/circuitbreaker/handy_breaker.go index 6e1584d..5875d4f 100644 --- a/circuitbreaker/handy_breaker.go +++ b/circuitbreaker/handy_breaker.go @@ -16,19 +16,18 @@ // // See http://godoc.org/github.com/streadway/handy/breaker for more // information. -func HandyBreaker(failureRatio float64) endpoint.Middleware { - b := breaker.NewBreaker(failureRatio) +func HandyBreaker(cb breaker.Breaker) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { - if !b.Allow() { + if !cb.Allow() { return nil, breaker.ErrCircuitOpen } defer func(begin time.Time) { if err == nil { - b.Success(time.Since(begin)) + cb.Success(time.Since(begin)) } else { - b.Failure(time.Since(begin)) + cb.Failure(time.Since(begin)) } }(time.Now()) diff --git a/circuitbreaker/handy_breaker_test.go b/circuitbreaker/handy_breaker_test.go index aa520ec..dc2d615 100644 --- a/circuitbreaker/handy_breaker_test.go +++ b/circuitbreaker/handy_breaker_test.go @@ -11,7 +11,7 @@ func TestHandyBreaker(t *testing.T) { var ( failureRatio = 0.05 - breaker = circuitbreaker.HandyBreaker(failureRatio) + breaker = circuitbreaker.HandyBreaker(handybreaker.NewBreaker(failureRatio)) primeWith = handybreaker.DefaultMinObservations * 10 shouldPass = func(n int) bool { return (float64(n) / float64(primeWith+n)) <= failureRatio } openCircuitError = handybreaker.ErrCircuitOpen.Error()