package lb_test
import (
"context"
"errors"
"testing"
"time"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/sd"
"github.com/go-kit/kit/sd/lb"
)
func TestRetryMaxTotalFail(t *testing.T) {
var (
endpoints = sd.FixedEndpointer{} // no endpoints
rr = lb.NewRoundRobin(endpoints)
retry = lb.Retry(999, time.Second, rr) // lots of retries
ctx = context.Background()
)
if _, err := retry(ctx, struct{}{}); err == nil {
t.Errorf("expected error, got none") // should fail
}
}
func TestRetryMaxPartialFail(t *testing.T) {
var (
endpoints = []endpoint.Endpoint{
func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") },
func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") },
func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
}
endpointer = sd.FixedEndpointer{
0: endpoints[0],
1: endpoints[1],
2: endpoints[2],
}
retries = len(endpoints) - 1 // not quite enough retries
rr = lb.NewRoundRobin(endpointer)
ctx = context.Background()
)
if _, err := lb.Retry(retries, time.Second, rr)(ctx, struct{}{}); err == nil {
t.Errorf("expected error two, got none")
}
}
func TestRetryMaxSuccess(t *testing.T) {
var (
endpoints = []endpoint.Endpoint{
func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") },
func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") },
func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
}
endpointer = sd.FixedEndpointer{
0: endpoints[0],
1: endpoints[1],
2: endpoints[2],
}
retries = len(endpoints) // exactly enough retries
rr = lb.NewRoundRobin(endpointer)
ctx = context.Background()
)
if _, err := lb.Retry(retries, time.Second, rr)(ctx, struct{}{}); err != nil {
t.Error(err)
}
}
func TestRetryTimeout(t *testing.T) {
var (
step = make(chan struct{})
e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil }
timeout = time.Millisecond
retry = lb.Retry(999, timeout, lb.NewRoundRobin(sd.FixedEndpointer{0: e}))
errs = make(chan error, 1)
invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err }
)
go func() { step <- struct{}{} }() // queue up a flush of the endpoint
invoke() // invoke the endpoint and trigger the flush
if err := <-errs; err != nil { // that should succeed
t.Error(err)
}
go func() { time.Sleep(10 * timeout); step <- struct{}{} }() // a delayed flush
invoke() // invoke the endpoint
if err := <-errs; err != context.DeadlineExceeded { // that should not succeed
t.Errorf("wanted %v, got none", context.DeadlineExceeded)
}
}
func TestAbortEarlyCustomMessage(t *testing.T) {
var (
myErr = errors.New("aborting early")
cb = func(int, error) (bool, error) { return false, myErr }
endpoints = sd.FixedEndpointer{} // no endpoints
rr = lb.NewRoundRobin(endpoints)
retry = lb.RetryWithCallback(time.Second, rr, cb) // lots of retries
ctx = context.Background()
)
_, err := retry(ctx, struct{}{})
if want, have := myErr, err.(lb.RetryError).Final; want != have {
t.Errorf("want %v, have %v", want, have)
}
}
func TestErrorPassedUnchangedToCallback(t *testing.T) {
var (
myErr = errors.New("my custom error")
cb = func(_ int, err error) (bool, error) {
if want, have := myErr, err; want != have {
t.Errorf("want %v, have %v", want, have)
}
return false, nil
}
endpoint = func(ctx context.Context, request interface{}) (interface{}, error) {
return nil, myErr
}
endpoints = sd.FixedEndpointer{endpoint} // no endpoints
rr = lb.NewRoundRobin(endpoints)
retry = lb.RetryWithCallback(time.Second, rr, cb) // lots of retries
ctx = context.Background()
)
_, err := retry(ctx, struct{}{})
if want, have := myErr, err.(lb.RetryError).Final; want != have {
t.Errorf("want %v, have %v", want, have)
}
}
func TestHandleNilCallback(t *testing.T) {
var (
endpointer = sd.FixedEndpointer{
func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
}
rr = lb.NewRoundRobin(endpointer)
ctx = context.Background()
)
retry := lb.RetryWithCallback(time.Second, rr, nil)
if _, err := retry(ctx, struct{}{}); err != nil {
t.Error(err)
}
}