diff --git a/circuitbreaker/gobreaker.go b/circuitbreaker/gobreaker.go new file mode 100644 index 0000000..f3a210d --- /dev/null +++ b/circuitbreaker/gobreaker.go @@ -0,0 +1,22 @@ +package circuitbreaker + +import ( + "github.com/sony/gobreaker" + "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" +) + +// Gobreaker returns an endpoint.Middleware that implements the circuit +// breaker pattern using the sony/gobreaker package. Only errors returned by +// the wrapped endpoint count against the circuit breaker's error count. +// +// See http://godoc.org/github.com/sony/gobreaker for more information. +func Gobreaker(settings gobreaker.Settings) endpoint.Middleware { + cb := gobreaker.NewCircuitBreaker(settings) + return func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + return cb.Execute(func() (interface{}, error) { return next(ctx, request) }) + } + } +} diff --git a/circuitbreaker/gobreaker_test.go b/circuitbreaker/gobreaker_test.go new file mode 100644 index 0000000..483e4a9 --- /dev/null +++ b/circuitbreaker/gobreaker_test.go @@ -0,0 +1,77 @@ +package circuitbreaker_test + +import ( + "errors" + "testing" + "time" + + "github.com/sony/gobreaker" + "golang.org/x/net/context" + + "github.com/go-kit/kit/circuitbreaker" + "github.com/go-kit/kit/endpoint" +) + +func TestGobreaker(t *testing.T) { + var ( + thru int + last gobreaker.State + myError = errors.New("❤️") + timeout = time.Millisecond + stateChange = func(_ string, from, to gobreaker.State) { last = to } + ) + + var e endpoint.Endpoint + e = func(context.Context, interface{}) (interface{}, error) { thru++; return struct{}{}, myError } + e = circuitbreaker.Gobreaker(gobreaker.Settings{ + Timeout: timeout, + OnStateChange: stateChange, + })(e) + + // "Default ReadyToTrip returns true when the number of consecutive + // failures is more than 5." + // https://github.com/sony/gobreaker/blob/bfa846d/gobreaker.go#L76 + for i := 0; i < 5; i++ { + if _, err := e(context.Background(), struct{}{}); err != myError { + t.Errorf("want %v, have %v", myError, err) + } + } + + if want, have := 5, thru; want != have { + t.Errorf("want %d, have %d", want, have) + } + + e(context.Background(), struct{}{}) + if want, have := 6, thru; want != have { // got thru + t.Errorf("want %d, have %d", want, have) + } + if want, have := gobreaker.StateOpen, last; want != have { // tripped + t.Errorf("want %v, have %v", want, have) + } + + e(context.Background(), struct{}{}) + if want, have := 6, thru; want != have { // didn't get thru + t.Errorf("want %d, have %d", want, have) + } + + time.Sleep(2 * timeout) + + e(context.Background(), struct{}{}) + if want, have := 7, thru; want != have { // got thru via halfopen + t.Errorf("want %d, have %d", want, have) + } + if want, have := gobreaker.StateOpen, last; want != have { // re-tripped + t.Errorf("want %v, have %v", want, have) + } + + time.Sleep(2 * timeout) + + myError = nil + e(context.Background(), struct{}{}) + if want, have := 8, thru; want != have { // got thru via halfopen + t.Errorf("want %d, have %d", want, have) + } + if want, have := gobreaker.StateClosed, last; want != have { // now it's good + t.Errorf("want %v, have %v", want, have) + } +} diff --git a/circuitbreaker/handy_breaker.go b/circuitbreaker/handy_breaker.go new file mode 100644 index 0000000..93510a0 --- /dev/null +++ b/circuitbreaker/handy_breaker.go @@ -0,0 +1,44 @@ +package circuitbreaker + +import ( + "errors" + "time" + + "github.com/streadway/handy/breaker" + "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" +) + +// ErrCircuitBreakerOpen is returned when the HandyBreaker's circuit is open +// and the request is stopped from proceeding. +var ErrCircuitBreakerOpen = errors.New("circuit breaker open") + +// HandyBreaker returns an endpoint.Middleware that implements the circuit +// breaker pattern using the streadway/handy/breaker package. Only errors +// returned by the wrapped endpoint count against the circuit breaker's error +// count. +// +// See http://godoc.org/github.com/streadway/handy/breaker for more +// information. +func HandyBreaker(failureRatio float64) endpoint.Middleware { + b := breaker.NewBreaker(failureRatio) + return func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + if !b.Allow() { + return nil, ErrCircuitBreakerOpen + } + + defer func(begin time.Time) { + if err == nil { + b.Success(time.Since(begin)) + } else { + b.Failure(time.Since(begin)) + } + }(time.Now()) + + response, err = next(ctx, request) + return + } + } +} diff --git a/circuitbreaker/handy_breaker_test.go b/circuitbreaker/handy_breaker_test.go new file mode 100644 index 0000000..82d5d8b --- /dev/null +++ b/circuitbreaker/handy_breaker_test.go @@ -0,0 +1,59 @@ +package circuitbreaker_test + +import ( + "errors" + "testing" + + "github.com/streadway/handy/breaker" + + "golang.org/x/net/context" + + "github.com/go-kit/kit/circuitbreaker" + "github.com/go-kit/kit/endpoint" +) + +func TestHandyBreaker(t *testing.T) { + var ( + thru = 0 + myError = error(nil) + ratio = 0.05 + primeWith = breaker.DefaultMinObservations * 10 + shouldPass = func(failed int) bool { return (float64(failed) / float64(primeWith+failed)) <= ratio } + extraTries = 10 + ) + + var e endpoint.Endpoint + e = func(context.Context, interface{}) (interface{}, error) { thru++; return struct{}{}, myError } + e = circuitbreaker.HandyBreaker(ratio)(e) + + // Prime with some successes. + for i := 0; i < primeWith; i++ { + if _, err := e(context.Background(), struct{}{}); err != nil { + t.Fatal(err) + } + } + + // Now we start throwing errors. + myError = errors.New(":(") + + // The first few should get thru. + var letThru int + for i := 0; shouldPass(i); i++ { // off-by-one + letThru++ + if _, err := e(context.Background(), struct{}{}); err != myError { + t.Fatalf("want %v, have %v", myError, err) + } + } + + // But the rest should be blocked by an open circuit. + for i := 1; i <= extraTries; i++ { + if _, err := e(context.Background(), struct{}{}); err != circuitbreaker.ErrCircuitBreakerOpen { + t.Errorf("with request #%d, want %v, have %v", primeWith+letThru+i, circuitbreaker.ErrCircuitBreakerOpen, err) + } + } + + // Confirm the rest didn't get through. + if want, have := primeWith+letThru, thru; want != have { + t.Errorf("want %d, have %d", want, have) + } +} diff --git a/circuitbreaker/sony_gobreaker.go b/circuitbreaker/sony_gobreaker.go deleted file mode 100644 index b92eb28..0000000 --- a/circuitbreaker/sony_gobreaker.go +++ /dev/null @@ -1,21 +0,0 @@ -package circuitbreaker - -import ( - "github.com/sony/gobreaker" - "golang.org/x/net/context" - - "github.com/go-kit/kit/endpoint" -) - -// NewSonyCircuitBreaker returns an endpoint.Middleware that permits the -// request if the underlying circuit breaker allows it. Only errors returned -// by the wrapped endpoint count against the circuit breaker's error count. -// See github.com/sony/gobreaker for more information. -func NewSonyCircuitBreaker(settings gobreaker.Settings) endpoint.Middleware { - cb := gobreaker.NewCircuitBreaker(settings) - return func(next endpoint.Endpoint) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - return cb.Execute(func() (interface{}, error) { return next(ctx, request) }) - } - } -} diff --git a/circuitbreaker/sony_gobreaker_test.go b/circuitbreaker/sony_gobreaker_test.go deleted file mode 100644 index e2b08dc..0000000 --- a/circuitbreaker/sony_gobreaker_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package circuitbreaker_test - -import ( - "errors" - "testing" - "time" - - "github.com/sony/gobreaker" - "golang.org/x/net/context" - - "github.com/go-kit/kit/circuitbreaker" - "github.com/go-kit/kit/endpoint" -) - -func TestSonyCircuitBreaker(t *testing.T) { - var ( - thru int - last gobreaker.State - myError = errors.New("❤️") - timeout = time.Millisecond - stateChange = func(_ string, from, to gobreaker.State) { last = to } - ) - - var e endpoint.Endpoint - e = func(context.Context, interface{}) (interface{}, error) { thru++; return struct{}{}, myError } - e = circuitbreaker.NewSonyCircuitBreaker(gobreaker.Settings{ - Timeout: timeout, - OnStateChange: stateChange, - })(e) - - // "Default ReadyToTrip returns true when the number of consecutive - // failures is more than 5." - // https://github.com/sony/gobreaker/blob/bfa846d/gobreaker.go#L76 - for i := 0; i < 5; i++ { - if _, err := e(context.Background(), struct{}{}); err != myError { - t.Errorf("want %v, have %v", myError, err) - } - } - - if want, have := 5, thru; want != have { - t.Errorf("want %d, have %d", want, have) - } - - e(context.Background(), struct{}{}) - if want, have := 6, thru; want != have { // got thru - t.Errorf("want %d, have %d", want, have) - } - if want, have := gobreaker.StateOpen, last; want != have { // tripped - t.Errorf("want %v, have %v", want, have) - } - - e(context.Background(), struct{}{}) - if want, have := 6, thru; want != have { // didn't get thru - t.Errorf("want %d, have %d", want, have) - } - - time.Sleep(2 * timeout) - - e(context.Background(), struct{}{}) - if want, have := 7, thru; want != have { // got thru via halfopen - t.Errorf("want %d, have %d", want, have) - } - if want, have := gobreaker.StateOpen, last; want != have { // re-tripped - t.Errorf("want %v, have %v", want, have) - } - - time.Sleep(2 * timeout) - - myError = nil - e(context.Background(), struct{}{}) - if want, have := 8, thru; want != have { // got thru via halfopen - t.Errorf("want %d, have %d", want, have) - } - if want, have := gobreaker.StateClosed, last; want != have { // now it's good - t.Errorf("want %v, have %v", want, have) - } -}