diff --git a/circuitbreaker/gobreaker_test.go b/circuitbreaker/gobreaker_test.go index 483e4a9..6425e61 100644 --- a/circuitbreaker/gobreaker_test.go +++ b/circuitbreaker/gobreaker_test.go @@ -1,77 +1,19 @@ package circuitbreaker_test import ( - "errors" "testing" - "time" "github.com/sony/gobreaker" - "golang.org/x/net/context" "github.com/go-kit/kit/circuitbreaker" - "github.com/go-kit/kit/endpoint" ) func TestGobreaker(t *testing.T) { var ( - thru int - last gobreaker.State - myError = errors.New("❤️") - timeout = time.Millisecond - stateChange = func(_ string, from, to gobreaker.State) { last = to } + breaker = circuitbreaker.Gobreaker(gobreaker.Settings{}) + primeWith = 100 + shouldPass = func(n int) bool { return n <= 5 } // https://github.com/sony/gobreaker/blob/bfa846d/gobreaker.go#L76 + circuitOpenError = "circuit breaker is open" ) - - var e endpoint.Endpoint - e = func(context.Context, interface{}) (interface{}, error) { thru++; return struct{}{}, myError } - e = circuitbreaker.Gobreaker(gobreaker.Settings{ - Timeout: timeout, - OnStateChange: stateChange, - })(e) - - // "Default ReadyToTrip returns true when the number of consecutive - // failures is more than 5." - // https://github.com/sony/gobreaker/blob/bfa846d/gobreaker.go#L76 - for i := 0; i < 5; i++ { - if _, err := e(context.Background(), struct{}{}); err != myError { - t.Errorf("want %v, have %v", myError, err) - } - } - - if want, have := 5, thru; want != have { - t.Errorf("want %d, have %d", want, have) - } - - e(context.Background(), struct{}{}) - if want, have := 6, thru; want != have { // got thru - t.Errorf("want %d, have %d", want, have) - } - if want, have := gobreaker.StateOpen, last; want != have { // tripped - t.Errorf("want %v, have %v", want, have) - } - - e(context.Background(), struct{}{}) - if want, have := 6, thru; want != have { // didn't get thru - t.Errorf("want %d, have %d", want, have) - } - - time.Sleep(2 * timeout) - - e(context.Background(), struct{}{}) - if want, have := 7, thru; want != have { // got thru via halfopen - t.Errorf("want %d, have %d", want, have) - } - if want, have := gobreaker.StateOpen, last; want != have { // re-tripped - t.Errorf("want %v, have %v", want, have) - } - - time.Sleep(2 * timeout) - - myError = nil - e(context.Background(), struct{}{}) - if want, have := 8, thru; want != have { // got thru via halfopen - t.Errorf("want %d, have %d", want, have) - } - if want, have := gobreaker.StateClosed, last; want != have { // now it's good - t.Errorf("want %v, have %v", want, have) - } + testFailingEndpoint(t, breaker, primeWith, shouldPass, circuitOpenError) } diff --git a/circuitbreaker/handy_breaker.go b/circuitbreaker/handy_breaker.go index 93510a0..6e1584d 100644 --- a/circuitbreaker/handy_breaker.go +++ b/circuitbreaker/handy_breaker.go @@ -1,7 +1,6 @@ package circuitbreaker import ( - "errors" "time" "github.com/streadway/handy/breaker" @@ -9,10 +8,6 @@ "github.com/go-kit/kit/endpoint" ) - -// ErrCircuitBreakerOpen is returned when the HandyBreaker's circuit is open -// and the request is stopped from proceeding. -var ErrCircuitBreakerOpen = errors.New("circuit breaker open") // HandyBreaker returns an endpoint.Middleware that implements the circuit // breaker pattern using the streadway/handy/breaker package. Only errors @@ -26,7 +21,7 @@ return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { if !b.Allow() { - return nil, ErrCircuitBreakerOpen + return nil, breaker.ErrCircuitOpen } defer func(begin time.Time) { diff --git a/circuitbreaker/handy_breaker_test.go b/circuitbreaker/handy_breaker_test.go index 82d5d8b..aa520ec 100644 --- a/circuitbreaker/handy_breaker_test.go +++ b/circuitbreaker/handy_breaker_test.go @@ -1,59 +1,20 @@ package circuitbreaker_test import ( - "errors" "testing" - "github.com/streadway/handy/breaker" - - "golang.org/x/net/context" + handybreaker "github.com/streadway/handy/breaker" "github.com/go-kit/kit/circuitbreaker" - "github.com/go-kit/kit/endpoint" ) func TestHandyBreaker(t *testing.T) { var ( - thru = 0 - myError = error(nil) - ratio = 0.05 - primeWith = breaker.DefaultMinObservations * 10 - shouldPass = func(failed int) bool { return (float64(failed) / float64(primeWith+failed)) <= ratio } - extraTries = 10 + failureRatio = 0.05 + breaker = circuitbreaker.HandyBreaker(failureRatio) + primeWith = handybreaker.DefaultMinObservations * 10 + shouldPass = func(n int) bool { return (float64(n) / float64(primeWith+n)) <= failureRatio } + openCircuitError = handybreaker.ErrCircuitOpen.Error() ) - - var e endpoint.Endpoint - e = func(context.Context, interface{}) (interface{}, error) { thru++; return struct{}{}, myError } - e = circuitbreaker.HandyBreaker(ratio)(e) - - // Prime with some successes. - for i := 0; i < primeWith; i++ { - if _, err := e(context.Background(), struct{}{}); err != nil { - t.Fatal(err) - } - } - - // Now we start throwing errors. - myError = errors.New(":(") - - // The first few should get thru. - var letThru int - for i := 0; shouldPass(i); i++ { // off-by-one - letThru++ - if _, err := e(context.Background(), struct{}{}); err != myError { - t.Fatalf("want %v, have %v", myError, err) - } - } - - // But the rest should be blocked by an open circuit. - for i := 1; i <= extraTries; i++ { - if _, err := e(context.Background(), struct{}{}); err != circuitbreaker.ErrCircuitBreakerOpen { - t.Errorf("with request #%d, want %v, have %v", primeWith+letThru+i, circuitbreaker.ErrCircuitBreakerOpen, err) - } - } - - // Confirm the rest didn't get through. - if want, have := primeWith+letThru, thru; want != have { - t.Errorf("want %d, have %d", want, have) - } + testFailingEndpoint(t, breaker, primeWith, shouldPass, openCircuitError) } diff --git a/circuitbreaker/hystrix_test.go b/circuitbreaker/hystrix_test.go index d18959a..af92a54 100644 --- a/circuitbreaker/hystrix_test.go +++ b/circuitbreaker/hystrix_test.go @@ -1,119 +1,35 @@ package circuitbreaker_test import ( - "errors" + stdlog "log" + "os" "testing" - "time" "github.com/afex/hystrix-go/hystrix" - "golang.org/x/net/context" "github.com/go-kit/kit/circuitbreaker" - "github.com/go-kit/kit/endpoint" + kitlog "github.com/go-kit/kit/log" ) -func TestHystrixCircuitBreakerOpen(t *testing.T) { - var ( - thru = 0 - myError = error(nil) - ratio = 0.04 - primeWith = hystrix.DefaultVolumeThreshold * 2 - shouldPass = func(failed int) bool { return (float64(failed) / float64(primeWith+failed)) <= ratio } - extraTries = 10 +func TestHystrix(t *testing.T) { + logger := kitlog.NewLogfmtLogger(os.Stderr) + stdlog.SetOutput(kitlog.NewStdlibAdapter(logger)) + + const ( + commandName = "my-endpoint" + errorPercent = 5 + maxConcurrent = 1000 ) - - // configure hystrix - hystrix.ConfigureCommand("myEndpoint", hystrix.CommandConfig{ - ErrorPercentThreshold: 5, - MaxConcurrentRequests: 200, + hystrix.ConfigureCommand(commandName, hystrix.CommandConfig{ + ErrorPercentThreshold: errorPercent, + MaxConcurrentRequests: maxConcurrent, }) - var e endpoint.Endpoint - e = func(context.Context, interface{}) (interface{}, error) { thru++; return struct{}{}, myError } - e = circuitbreaker.Hystrix("myEndpoint")(e) - - // prime - for i := 0; i < primeWith; i++ { - if _, err := e(context.Background(), struct{}{}); err != nil { - t.Fatal(err) - } - } - - // Now we start throwing errors. - myError = errors.New(":(") - - // The first few should get thru. - var letThru int - for i := 0; shouldPass(i); i++ { // off-by-one - letThru++ - if _, err := e(context.Background(), struct{}{}); err != myError { - t.Fatalf("want %v, have %v", myError, err) - } - } - - // But the rest should be blocked by an open circuit. - for i := 1; i <= extraTries; i++ { - if _, err := e(context.Background(), struct{}{}); err != hystrix.ErrCircuitOpen { - t.Errorf("with request #%d, want %v, have %v", primeWith+letThru+i, hystrix.ErrCircuitOpen, err) - } - } - - // Confirm the rest didn't get through. - if want, have := primeWith+letThru, thru; want != have { - t.Errorf("want %d, have %d", want, have) - } + var ( + breaker = circuitbreaker.Hystrix(commandName) + primeWith = hystrix.DefaultVolumeThreshold * 2 + shouldPass = func(n int) bool { return (float64(n) / float64(primeWith+n)) <= (float64(errorPercent-1) / 100.0) } + openCircuitError = hystrix.ErrCircuitOpen.Error() + ) + testFailingEndpoint(t, breaker, primeWith, shouldPass, openCircuitError) } - -func TestHystrixTimeout(t *testing.T) { - var ( - timeout = time.Millisecond * 0 - primeWith = hystrix.DefaultVolumeThreshold * 2 - failNumber = 2 // 5% threshold - ) - - // configure hystrix - hystrix.ConfigureCommand("timeoutEndpoint", hystrix.CommandConfig{ - ErrorPercentThreshold: 5, - MaxConcurrentRequests: 200, - SleepWindow: 5, // milliseconds - Timeout: 1, // milliseconds - }) - - var e endpoint.Endpoint - e = func(context.Context, interface{}) (interface{}, error) { - time.Sleep(2 * timeout) - return struct{}{}, nil - } - e = circuitbreaker.Hystrix("timeoutEndpoint")(e) - - // prime - for i := 0; i < primeWith; i++ { - if _, err := e(context.Background(), struct{}{}); err != nil { - t.Errorf("expecting %v, have %v", nil, err) - } - } - - // times out - timeout = time.Millisecond * 2 - for i := 0; i < failNumber; i++ { - if _, err := e(context.Background(), struct{}{}); err != hystrix.ErrTimeout { - t.Errorf("%d expecting %v, have %v", i, hystrix.ErrTimeout, err) - } - } - - // fix timeout - timeout = time.Millisecond * 0 - - // fails for a little while still - for i := 0; i < failNumber; i++ { - if _, err := e(context.Background(), struct{}{}); err != hystrix.ErrCircuitOpen { - t.Errorf("expecting %v, have %v", hystrix.ErrCircuitOpen, err) - } - } - - // back to OK - time.Sleep(time.Millisecond * 5) - if _, err := e(context.Background(), struct{}{}); err != nil { - t.Errorf("expecting %v, have %v", nil, err) - } -} diff --git a/circuitbreaker/util_test.go b/circuitbreaker/util_test.go new file mode 100644 index 0000000..c6c6f03 --- /dev/null +++ b/circuitbreaker/util_test.go @@ -0,0 +1,59 @@ +package circuitbreaker_test + +import ( + "errors" + "testing" + + "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" +) + +func testFailingEndpoint(t *testing.T, breaker endpoint.Middleware, primeWith int, shouldPass func(int) bool, openCircuitError string) { + // Create a mock endpoint and wrap it with the breaker. + m := mock{} + var e endpoint.Endpoint + e = m.endpoint + e = breaker(e) + + // Prime the endpoint with successful requests. + for i := 0; i < primeWith; i++ { + if _, err := e(context.Background(), struct{}{}); err != nil { + t.Fatalf("during priming, got error: %v", err) + } + } + + // Switch the endpoint to start throwing errors. + m.err = errors.New("tragedy+disaster") + m.thru = 0 + + // The first several should be allowed through and yield our error. + for i := 0; shouldPass(i); i++ { + if _, err := e(context.Background(), struct{}{}); err != m.err { + t.Fatalf("want %v, have %v", m.err, err) + } + } + thru := m.thru + + // But the rest should be blocked by an open circuit. + for i := 0; i < 10; i++ { + if _, err := e(context.Background(), struct{}{}); err.Error() != openCircuitError { + t.Fatalf("want %q, have %q", openCircuitError, err.Error()) + } + } + + // Make sure none of those got through. + if want, have := thru, m.thru; want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +type mock struct { + thru int + err error +} + +func (m *mock) endpoint(context.Context, interface{}) (interface{}, error) { + m.thru++ + return struct{}{}, m.err +}