diff --git a/loadbalancer/README.md b/loadbalancer/README.md new file mode 100644 index 0000000..62dddc1 --- /dev/null +++ b/loadbalancer/README.md @@ -0,0 +1,63 @@ +# package loadbalancer + +`package loadbalancer` provides a client-side load balancer abstraction. + +A publisher is responsible for emitting the most recent set of endpoints for a +single logical service. Publishers exist for static endpoints, and endpoints +discovered via periodic DNS SRV lookups on a single logical name. Consul and +etcd publishers are planned. + +Different load balancers are implemented on top of publishers. Go kit +currently provides random and round-robin load balancers. Smarter behaviors, +e.g. load balancing based on underlying endpoint priority/weight, is planned. + +## Rationale + +TODO + +## Usage + +In your client, construct a publisher for a specific remote service, and pass +it to a load balancer. Then, request an endpoint from the load balancer +whenever you need to make a request to that remote service. + +```go +import ( + "github.com/go-kit/kit/loadbalancer" + "github.com/go-kit/kit/loadbalancer/dnssrv" +) + +func main() { + // Construct a load balancer for foosvc, which gets foosvc instances by + // polling a specific DNS SRV name. + p := dnssrv.NewPublisher("foosvc.internal.domain", 5*time.Second, fooFactory, logger) + lb := loadbalancer.NewRoundRobin(p) + + // Get a new endpoint from the load balancer. + endpoint, err := lb.Endpoint() + if err != nil { + panic(err) + } + + // Use the endpoint to make a request. + response, err := endpoint(ctx, request) +} + +func fooFactory(instance string) (endpoint.Endpoint, error) { + // Convert an instance (host:port) to an endpoint, via a defined transport binding. +} +``` + +It's also possible to wrap a load balancer with a retry strategy, so that it +can used as an endpoint directly. This may make load balancers more convenient +to use, at the cost of fine-grained control of failure. + +```go +func main() { + p := dnssrv.NewPublisher("foosvc.internal.domain", 5*time.Second, fooFactory, logger) + lb := loadbalancer.NewRoundRobin(p) + endpoint := loadbalancer.Retry(3, 5*time.Seconds, lb) + + response, err := endpoint(ctx, request) // requests will be automatically load balanced +} +``` \ No newline at end of file diff --git a/loadbalancer/loadbalancer.go b/loadbalancer/loadbalancer.go index d551ae1..6a99e66 100644 --- a/loadbalancer/loadbalancer.go +++ b/loadbalancer/loadbalancer.go @@ -1,6 +1,16 @@ package loadbalancer -import "errors" +import ( + "errors" + + "github.com/go-kit/kit/endpoint" +) + +// LoadBalancer describes something that can yield endpoints for a remote +// service method. +type LoadBalancer interface { + Endpoint() (endpoint.Endpoint, error) +} // ErrNoEndpoints is returned when a load balancer (or one of its components) // has no endpoints to return. In a request lifecycle, this is usually a fatal diff --git a/loadbalancer/random_test.go b/loadbalancer/random_test.go index 01f95e5..3cb0643 100644 --- a/loadbalancer/random_test.go +++ b/loadbalancer/random_test.go @@ -28,7 +28,7 @@ endpoints[i] = func(context.Context, interface{}) (interface{}, error) { counts[i0]++; return struct{}{}, nil } } - lb := loadbalancer.NewRandom(static.Publisher(endpoints), seed) + lb := loadbalancer.NewRandom(static.NewPublisher(endpoints), seed) for i := 0; i < iterations; i++ { e, err := lb.Endpoint() @@ -50,7 +50,7 @@ } func TestRandomNoEndpoints(t *testing.T) { - lb := loadbalancer.NewRandom(static.Publisher([]endpoint.Endpoint{}), 123) + lb := loadbalancer.NewRandom(static.NewPublisher([]endpoint.Endpoint{}), 123) _, have := lb.Endpoint() if want := loadbalancer.ErrNoEndpoints; want != have { t.Errorf("want %q, have %q", want, have) diff --git a/loadbalancer/retry.go b/loadbalancer/retry.go new file mode 100644 index 0000000..a8a2f67 --- /dev/null +++ b/loadbalancer/retry.go @@ -0,0 +1,53 @@ +package loadbalancer + +import ( + "fmt" + "strings" + "time" + + "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" +) + +// Retry wraps the load balancer to make it behave like a simple endpoint. +// Requests to the endpoint will be automatically load balanced via the load +// balancer. Requests that return errors will be retried until they succeed, +// up to max times, or until the timeout is elapsed, whichever comes first. +func Retry(max int, timeout time.Duration, lb LoadBalancer) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + var ( + newctx, cancel = context.WithTimeout(ctx, timeout) + responses = make(chan interface{}, 1) + errs = make(chan error, 1) + a = []string{} + ) + defer cancel() + for i := 1; i <= max; i++ { + go func() { + e, err := lb.Endpoint() + if err != nil { + errs <- err + return + } + response, err := e(newctx, request) + if err != nil { + errs <- err + return + } + responses <- response + }() + + select { + case <-newctx.Done(): + return nil, newctx.Err() + case response := <-responses: + return response, nil + case err := <-errs: + a = append(a, err.Error()) + continue + } + } + return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; ")) + } +} diff --git a/loadbalancer/retry_test.go b/loadbalancer/retry_test.go new file mode 100644 index 0000000..fe60e3d --- /dev/null +++ b/loadbalancer/retry_test.go @@ -0,0 +1,83 @@ +package loadbalancer_test + +import ( + "errors" + "testing" + "time" + + "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/loadbalancer" + "github.com/go-kit/kit/loadbalancer/static" +) + +func TestRetryMaxTotalFail(t *testing.T) { + var ( + endpoints = []endpoint.Endpoint{} // no endpoints + p = static.NewPublisher(endpoints) + lb = loadbalancer.NewRoundRobin(p) + retry = loadbalancer.Retry(999, time.Second, lb) // lots of retries + ctx = context.Background() + ) + if _, err := retry(ctx, struct{}{}); err == nil { + t.Errorf("expected error, got none") // should fail + } +} + +func TestRetryMaxPartialFail(t *testing.T) { + var ( + endpoints = []endpoint.Endpoint{ + func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") }, + func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") }, + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ }, + } + retries = len(endpoints) - 1 // not quite enough retries + p = static.NewPublisher(endpoints) + lb = loadbalancer.NewRoundRobin(p) + ctx = context.Background() + ) + if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err == nil { + t.Errorf("expected error, got none") + } +} + +func TestRetryMaxSuccess(t *testing.T) { + var ( + endpoints = []endpoint.Endpoint{ + func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") }, + func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") }, + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ }, + } + retries = len(endpoints) // exactly enough retries + p = static.NewPublisher(endpoints) + lb = loadbalancer.NewRoundRobin(p) + ctx = context.Background() + ) + if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err != nil { + t.Error(err) + } +} + +func TestRetryTimeout(t *testing.T) { + var ( + step = make(chan struct{}) + e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil } + timeout = time.Millisecond + retry = loadbalancer.Retry(999, timeout, loadbalancer.NewRoundRobin(static.NewPublisher([]endpoint.Endpoint{e}))) + errs = make(chan error, 1) + invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err } + ) + + go func() { step <- struct{}{} }() // queue up a flush of the endpoint + invoke() // invoke the endpoint and trigger the flush + if err := <-errs; err != nil { // that should succeed + t.Error(err) + } + + go func() { time.Sleep(10 * timeout); step <- struct{}{} }() // a delayed flush + invoke() // invoke the endpoint + if err := <-errs; err != context.DeadlineExceeded { // that should not succeed + t.Errorf("wanted %v, got none", context.DeadlineExceeded) + } +} diff --git a/loadbalancer/round_robin_test.go b/loadbalancer/round_robin_test.go index 561bcf6..a04357c 100644 --- a/loadbalancer/round_robin_test.go +++ b/loadbalancer/round_robin_test.go @@ -21,7 +21,7 @@ } ) - lb := loadbalancer.NewRoundRobin(static.Publisher(endpoints)) + lb := loadbalancer.NewRoundRobin(static.NewPublisher(endpoints)) for i, want := range [][]int{ {1, 0, 0}, diff --git a/loadbalancer/static/publisher.go b/loadbalancer/static/publisher.go index 6ffd3dd..d1e8ece 100644 --- a/loadbalancer/static/publisher.go +++ b/loadbalancer/static/publisher.go @@ -1,9 +1,35 @@ package static -import "github.com/go-kit/kit/endpoint" +import ( + "sync" + + "github.com/go-kit/kit/endpoint" +) // Publisher yields the same set of static endpoints. -type Publisher []endpoint.Endpoint +type Publisher struct { + mtx sync.RWMutex + endpoints []endpoint.Endpoint +} + +// NewPublisher returns a static endpoint Publisher. +func NewPublisher(endpoints []endpoint.Endpoint) *Publisher { + return &Publisher{ + endpoints: endpoints, + } +} // Endpoints implements the Publisher interface. -func (p Publisher) Endpoints() ([]endpoint.Endpoint, error) { return p, nil } +func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) { + p.mtx.RLock() + defer p.mtx.RUnlock() + return p.endpoints, nil +} + +// Replace is a utility method to swap out the underlying endpoints of an +// existing static publisher. It's useful mostly for testing. +func (p *Publisher) Replace(endpoints []endpoint.Endpoint) { + p.mtx.Lock() + defer p.mtx.Unlock() + p.endpoints = endpoints +} diff --git a/loadbalancer/static/publisher_test.go b/loadbalancer/static/publisher_test.go index 8152c93..fbe677c 100644 --- a/loadbalancer/static/publisher_test.go +++ b/loadbalancer/static/publisher_test.go @@ -16,7 +16,7 @@ e2 = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } endpoints = []endpoint.Endpoint{e1, e2} ) - p := static.Publisher(endpoints) + p := static.NewPublisher(endpoints) have, err := p.Endpoints() if err != nil { t.Fatal(err) @@ -25,3 +25,24 @@ t.Fatalf("want %#+v, have %#+v", want, have) } } + +func TestStaticReplace(t *testing.T) { + p := static.NewPublisher([]endpoint.Endpoint{ + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + }) + have, err := p.Endpoints() + if err != nil { + t.Fatal(err) + } + if want, have := 1, len(have); want != have { + t.Fatalf("want %d, have %d", want, have) + } + p.Replace([]endpoint.Endpoint{}) + have, err = p.Endpoints() + if err != nil { + t.Fatal(err) + } + if want, have := 0, len(have); want != have { + t.Fatalf("want %d, have %d", want, have) + } +}