diff --git a/loadbalancer/cache.go b/loadbalancer/cache.go deleted file mode 100644 index b93b230..0000000 --- a/loadbalancer/cache.go +++ /dev/null @@ -1,47 +0,0 @@ -package loadbalancer - -import "github.com/go-kit/kit/endpoint" - -type cache struct { - req chan []endpoint.Endpoint - cnt chan int - quit chan struct{} -} - -func newCache(p Publisher) *cache { - c := &cache{ - req: make(chan []endpoint.Endpoint), - cnt: make(chan int), - quit: make(chan struct{}), - } - go c.loop(p) - return c -} - -func (c *cache) loop(p Publisher) { - e := make(chan []endpoint.Endpoint, 1) - p.Subscribe(e) - defer p.Unsubscribe(e) - endpoints := <-e - for { - select { - case endpoints = <-e: - case c.cnt <- len(endpoints): - case c.req <- endpoints: - case <-c.quit: - return - } - } -} - -func (c *cache) count() int { - return <-c.cnt -} - -func (c *cache) get() []endpoint.Endpoint { - return <-c.req -} - -func (c *cache) stop() { - close(c.quit) -} diff --git a/loadbalancer/cache_internal_test.go b/loadbalancer/cache_internal_test.go deleted file mode 100644 index 0cad6f7..0000000 --- a/loadbalancer/cache_internal_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package loadbalancer - -import ( - "runtime" - "testing" - - "golang.org/x/net/context" - - "github.com/go-kit/kit/endpoint" -) - -func TestCache(t *testing.T) { - e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } - endpoints := []endpoint.Endpoint{e} - - p := NewStaticPublisher(endpoints) - defer p.Stop() - - c := newCache(p) - defer c.stop() - - for _, n := range []int{2, 10, 0} { - endpoints = make([]endpoint.Endpoint, n) - for i := 0; i < n; i++ { - endpoints[i] = e - } - p.Replace(endpoints) - runtime.Gosched() - if want, have := len(endpoints), len(c.get()); want != have { - t.Errorf("want %d, have %d", want, have) - } - } -} diff --git a/loadbalancer/dns_srv_publisher.go b/loadbalancer/dns_srv_publisher.go deleted file mode 100644 index 154e854..0000000 --- a/loadbalancer/dns_srv_publisher.go +++ /dev/null @@ -1,105 +0,0 @@ -package loadbalancer - -import ( - "crypto/md5" - "fmt" - "net" - "sort" - "time" - - "github.com/go-kit/kit/endpoint" -) - -type dnssrvPublisher struct { - subscribe chan chan<- []endpoint.Endpoint - unsubscribe chan chan<- []endpoint.Endpoint - quit chan struct{} -} - -// NewDNSSRVPublisher returns a publisher that resolves the SRV name every ttl, and -func NewDNSSRVPublisher(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) Publisher { - p := &dnssrvPublisher{ - subscribe: make(chan chan<- []endpoint.Endpoint), - unsubscribe: make(chan chan<- []endpoint.Endpoint), - quit: make(chan struct{}), - } - go p.loop(name, ttl, makeEndpoint) - return p -} - -func (p *dnssrvPublisher) Subscribe(c chan<- []endpoint.Endpoint) { - p.subscribe <- c -} - -func (p *dnssrvPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) { - p.unsubscribe <- c -} - -func (p *dnssrvPublisher) Stop() { - close(p.quit) -} - -var newTicker = time.NewTicker - -func (p *dnssrvPublisher) loop(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) { - var ( - subscriptions = map[chan<- []endpoint.Endpoint]struct{}{} - addrs, md5, _ = resolve(name) - endpoints = convert(addrs, makeEndpoint) - ticker = newTicker(ttl) - ) - defer ticker.Stop() - for { - select { - case <-ticker.C: - addrs, newmd5, err := resolve(name) - if err == nil && newmd5 != md5 { - endpoints = convert(addrs, makeEndpoint) - for c := range subscriptions { - c <- endpoints - } - md5 = newmd5 - } - - case c := <-p.subscribe: - subscriptions[c] = struct{}{} - c <- endpoints - - case c := <-p.unsubscribe: - delete(subscriptions, c) - - case <-p.quit: - return - } - } -} - -// Allow mocking in tests. -var resolve = func(name string) (addrs []*net.SRV, md5sum string, err error) { - _, addrs, err = net.LookupSRV("", "", name) - if err != nil { - return addrs, "", err - } - hostports := make([]string, len(addrs)) - for i, addr := range addrs { - hostports[i] = fmt.Sprintf("%s:%d", addr.Target, addr.Port) - } - sort.Sort(sort.StringSlice(hostports)) - h := md5.New() - for _, hostport := range hostports { - fmt.Fprintf(h, hostport) - } - return addrs, fmt.Sprintf("%x", h.Sum(nil)), nil -} - -func convert(addrs []*net.SRV, makeEndpoint func(hostport string) endpoint.Endpoint) []endpoint.Endpoint { - endpoints := make([]endpoint.Endpoint, len(addrs)) - for i, addr := range addrs { - endpoints[i] = makeEndpoint(addr2hostport(addr)) - } - return endpoints -} - -func addr2hostport(addr *net.SRV) string { - return net.JoinHostPort(addr.Target, fmt.Sprintf("%d", addr.Port)) -} diff --git a/loadbalancer/dns_srv_publisher_internal_test.go b/loadbalancer/dns_srv_publisher_internal_test.go deleted file mode 100644 index 2f8c978..0000000 --- a/loadbalancer/dns_srv_publisher_internal_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package loadbalancer - -import ( - "fmt" - "net" - "testing" - "time" - - "golang.org/x/net/context" - - "github.com/go-kit/kit/endpoint" -) - -func TestDNSSRVPublisher(t *testing.T) { - // Reset the vars when we're done - oldResolve := resolve - defer func() { resolve = oldResolve }() - oldNewTicker := newTicker - defer func() { newTicker = oldNewTicker }() - - // Set up a fixture and swap the vars - a := []*net.SRV{ - {Target: "foo", Port: 123}, - {Target: "bar", Port: 456}, - {Target: "baz", Port: 789}, - } - ticker := make(chan time.Time) - resolve = func(string) ([]*net.SRV, string, error) { return a, fmt.Sprint(len(a)), nil } - newTicker = func(time.Duration) *time.Ticker { return &time.Ticker{C: ticker} } - - // Construct endpoint - m := map[string]int{} - e := func(hostport string) endpoint.Endpoint { - return func(context.Context, interface{}) (interface{}, error) { - m[hostport]++ - return struct{}{}, nil - } - } - - // Build the publisher - var ( - name = "irrelevant" - ttl = time.Second - makeEndpoint = func(hostport string) endpoint.Endpoint { return e(hostport) } - ) - p := NewDNSSRVPublisher(name, ttl, makeEndpoint) - defer p.Stop() - - // Subscribe - c := make(chan []endpoint.Endpoint, 1) - p.Subscribe(c) - defer p.Unsubscribe(c) - - // Invoke all of the endpoints - for _, e := range <-c { - e(context.Background(), struct{}{}) - } - - // Make sure we invoked what we expected to - for _, addr := range a { - hostport := addr2hostport(addr) - if want, have := 1, m[hostport]; want != have { - t.Errorf("%q: want %d, have %d", name, want, have) - } - delete(m, hostport) - } - if want, have := 0, len(m); want != have { - t.Errorf("want %d, have %d", want, have) - } - - // Reset the fixture, trigger the timer, count the endpoints - a = []*net.SRV{} - ticker <- time.Now() - if want, have := len(a), len(<-c); want != have { - t.Errorf("want %d, have %d", want, have) - } -} diff --git a/loadbalancer/endpoint_cache.go b/loadbalancer/endpoint_cache.go deleted file mode 100644 index 41851fd..0000000 --- a/loadbalancer/endpoint_cache.go +++ /dev/null @@ -1,41 +0,0 @@ -package loadbalancer - -import "github.com/go-kit/kit/endpoint" - -type endpointCache struct { - requests chan []endpoint.Endpoint - quit chan struct{} -} - -func newEndpointCache(p Publisher) *endpointCache { - c := &endpointCache{ - requests: make(chan []endpoint.Endpoint), - quit: make(chan struct{}), - } - go c.loop(p) - return c -} - -func (c *endpointCache) loop(p Publisher) { - updates := make(chan []endpoint.Endpoint, 1) - p.Subscribe(updates) - defer p.Unsubscribe(updates) - endpoints := <-updates - - for { - select { - case endpoints = <-updates: - case c.requests <- endpoints: - case <-c.quit: - return - } - } -} - -func (c *endpointCache) get() []endpoint.Endpoint { - return <-c.requests -} - -func (c *endpointCache) stop() { - close(c.quit) -} diff --git a/loadbalancer/endpoint_cache_internal_test.go b/loadbalancer/endpoint_cache_internal_test.go deleted file mode 100644 index c3e25a0..0000000 --- a/loadbalancer/endpoint_cache_internal_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package loadbalancer - -import ( - "testing" - - "golang.org/x/net/context" - - "github.com/go-kit/kit/endpoint" -) - -func TestEndpointCache(t *testing.T) { - endpoints := []endpoint.Endpoint{ - func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, - func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, - func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, - } - - p := NewStaticPublisher(endpoints) - defer p.Stop() - - c := newEndpointCache(p) - defer c.stop() - - if want, have := len(endpoints), len(c.get()); want != have { - t.Errorf("want %d, have %d", want, have) - } -} diff --git a/loadbalancer/load_balancer.go b/loadbalancer/load_balancer.go index 0e9d09e..241b766 100644 --- a/loadbalancer/load_balancer.go +++ b/loadbalancer/load_balancer.go @@ -12,6 +12,6 @@ Get() (endpoint.Endpoint, error) } -// ErrNoEndpointsAvailable is given by a load balancer when no endpoints are -// available to be returned. +// ErrNoEndpointsAvailable is given by a load balancer or strategy when no +// endpoints are available to be returned. var ErrNoEndpointsAvailable = errors.New("no endpoints available") diff --git a/loadbalancer/mock_publisher_test.go b/loadbalancer/mock_publisher_test.go deleted file mode 100644 index 510a419..0000000 --- a/loadbalancer/mock_publisher_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package loadbalancer_test - -import ( - "runtime" - "sync" - - "github.com/go-kit/kit/endpoint" -) - -type mockPublisher struct { - sync.Mutex - e []endpoint.Endpoint - s map[chan<- []endpoint.Endpoint]struct{} -} - -func newMockPublisher(endpoints []endpoint.Endpoint) *mockPublisher { - return &mockPublisher{ - e: endpoints, - s: map[chan<- []endpoint.Endpoint]struct{}{}, - } -} - -func (p *mockPublisher) replace(endpoints []endpoint.Endpoint) { - p.Lock() - defer p.Unlock() - p.e = endpoints - for s := range p.s { - s <- p.e - } - runtime.Gosched() -} - -func (p *mockPublisher) Subscribe(c chan<- []endpoint.Endpoint) { - p.Lock() - defer p.Unlock() - p.s[c] = struct{}{} - c <- p.e -} - -func (p *mockPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) { - p.Lock() - defer p.Unlock() - delete(p.s, c) -} - -func (p *mockPublisher) Stop() {} diff --git a/loadbalancer/publisher/dns/srv_publisher.go b/loadbalancer/publisher/dns/srv_publisher.go new file mode 100644 index 0000000..e991919 --- /dev/null +++ b/loadbalancer/publisher/dns/srv_publisher.go @@ -0,0 +1,110 @@ +package dns + +import ( + "crypto/md5" + "fmt" + "net" + "sort" + "time" + + "github.com/go-kit/kit/endpoint" +) + +// SRVPublisher implements Publisher. +type SRVPublisher struct { + subscribe chan chan<- []endpoint.Endpoint + unsubscribe chan chan<- []endpoint.Endpoint + quit chan struct{} +} + +// NewSRVPublisher returns a publisher that resolves the SRV name every ttl, +// and yields endpoints constructed via the makeEndpoint factory. +func NewSRVPublisher(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) *SRVPublisher { + p := &SRVPublisher{ + subscribe: make(chan chan<- []endpoint.Endpoint), + unsubscribe: make(chan chan<- []endpoint.Endpoint), + quit: make(chan struct{}), + } + go p.loop(name, ttl, makeEndpoint) + return p +} + +// Subscribe implements Publisher. +func (p *SRVPublisher) Subscribe(c chan<- []endpoint.Endpoint) { + p.subscribe <- c +} + +// Unsubscribe implements Publisher. +func (p *SRVPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) { + p.unsubscribe <- c +} + +// Stop implements Publisher. +func (p *SRVPublisher) Stop() { + close(p.quit) +} + +var newTicker = time.NewTicker + +func (p *SRVPublisher) loop(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) { + var ( + subscriptions = map[chan<- []endpoint.Endpoint]struct{}{} + addrs, md5, _ = resolve(name) + endpoints = convert(addrs, makeEndpoint) + ticker = newTicker(ttl) + ) + defer ticker.Stop() + for { + select { + case <-ticker.C: + addrs, newmd5, err := resolve(name) + if err == nil && newmd5 != md5 { + endpoints = convert(addrs, makeEndpoint) + for c := range subscriptions { + c <- endpoints + } + md5 = newmd5 + } + + case c := <-p.subscribe: + subscriptions[c] = struct{}{} + c <- endpoints + + case c := <-p.unsubscribe: + delete(subscriptions, c) + + case <-p.quit: + return + } + } +} + +// Allow mocking in tests. +var resolve = func(name string) (addrs []*net.SRV, md5sum string, err error) { + _, addrs, err = net.LookupSRV("", "", name) + if err != nil { + return addrs, "", err + } + hostports := make([]string, len(addrs)) + for i, addr := range addrs { + hostports[i] = fmt.Sprintf("%s:%d", addr.Target, addr.Port) + } + sort.Sort(sort.StringSlice(hostports)) + h := md5.New() + for _, hostport := range hostports { + fmt.Fprintf(h, hostport) + } + return addrs, fmt.Sprintf("%x", h.Sum(nil)), nil +} + +func convert(addrs []*net.SRV, makeEndpoint func(hostport string) endpoint.Endpoint) []endpoint.Endpoint { + endpoints := make([]endpoint.Endpoint, len(addrs)) + for i, addr := range addrs { + endpoints[i] = makeEndpoint(addr2hostport(addr)) + } + return endpoints +} + +func addr2hostport(addr *net.SRV) string { + return net.JoinHostPort(addr.Target, fmt.Sprint(addr.Port)) +} diff --git a/loadbalancer/publisher/dns/srv_publisher_internal_test.go b/loadbalancer/publisher/dns/srv_publisher_internal_test.go new file mode 100644 index 0000000..f356041 --- /dev/null +++ b/loadbalancer/publisher/dns/srv_publisher_internal_test.go @@ -0,0 +1,77 @@ +package dns + +import ( + "fmt" + "net" + "testing" + "time" + + "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" +) + +func TestDNSSRVPublisher(t *testing.T) { + // Reset the vars when we're done + oldResolve := resolve + defer func() { resolve = oldResolve }() + oldNewTicker := newTicker + defer func() { newTicker = oldNewTicker }() + + // Set up a fixture and swap the vars + a := []*net.SRV{ + {Target: "foo", Port: 123}, + {Target: "bar", Port: 456}, + {Target: "baz", Port: 789}, + } + ticker := make(chan time.Time) + resolve = func(string) ([]*net.SRV, string, error) { return a, fmt.Sprint(len(a)), nil } + newTicker = func(time.Duration) *time.Ticker { return &time.Ticker{C: ticker} } + + // Construct endpoint + m := map[string]int{} + e := func(hostport string) endpoint.Endpoint { + return func(context.Context, interface{}) (interface{}, error) { + m[hostport]++ + return struct{}{}, nil + } + } + + // Build the publisher + var ( + name = "irrelevant" + ttl = time.Second + makeEndpoint = func(hostport string) endpoint.Endpoint { return e(hostport) } + ) + p := NewSRVPublisher(name, ttl, makeEndpoint) + defer p.Stop() + + // Subscribe + c := make(chan []endpoint.Endpoint, 1) + p.Subscribe(c) + defer p.Unsubscribe(c) + + // Invoke all of the endpoints + for _, e := range <-c { + e(context.Background(), struct{}{}) + } + + // Make sure we invoked what we expected to + for _, addr := range a { + hostport := addr2hostport(addr) + if want, have := 1, m[hostport]; want != have { + t.Errorf("%q: want %d, have %d", name, want, have) + } + delete(m, hostport) + } + if want, have := 0, len(m); want != have { + t.Errorf("want %d, have %d", want, have) + } + + // Reset the fixture, trigger the timer, count the endpoints + a = []*net.SRV{} + ticker <- time.Now() + if want, have := len(a), len(<-c); want != have { + t.Errorf("want %d, have %d", want, have) + } +} diff --git a/loadbalancer/publisher/publisher.go b/loadbalancer/publisher/publisher.go new file mode 100644 index 0000000..f8c2fcc --- /dev/null +++ b/loadbalancer/publisher/publisher.go @@ -0,0 +1,10 @@ +package publisher + +import "github.com/go-kit/kit/endpoint" + +// Publisher produces endpoints. +type Publisher interface { + Subscribe(chan<- []endpoint.Endpoint) + Unsubscribe(chan<- []endpoint.Endpoint) + Stop() +} diff --git a/loadbalancer/publisher/static/publisher.go b/loadbalancer/publisher/static/publisher.go new file mode 100644 index 0000000..59c539e --- /dev/null +++ b/loadbalancer/publisher/static/publisher.go @@ -0,0 +1,51 @@ +package static + +import ( + "sync" + + "github.com/go-kit/kit/endpoint" +) + +// Publisher holds a static set of endpoints. +type Publisher struct { + mu sync.Mutex + current []endpoint.Endpoint + subscribers map[chan<- []endpoint.Endpoint]struct{} +} + +// NewPublisher returns a publisher that yields a static set of endpoints, +// which can be completely replaced. +func NewPublisher(endpoints []endpoint.Endpoint) *Publisher { + return &Publisher{ + current: endpoints, + subscribers: map[chan<- []endpoint.Endpoint]struct{}{}, + } +} + +// Subscribe implements Publisher. +func (p *Publisher) Subscribe(c chan<- []endpoint.Endpoint) { + p.mu.Lock() + defer p.mu.Unlock() + p.subscribers[c] = struct{}{} + c <- p.current +} + +// Unsubscribe implements Publisher. +func (p *Publisher) Unsubscribe(c chan<- []endpoint.Endpoint) { + p.mu.Lock() + defer p.mu.Unlock() + delete(p.subscribers, c) +} + +// Stop implements Publisher, but is a no-op. +func (p *Publisher) Stop() {} + +// Replace replaces the endpoints and notifies all subscribers. +func (p *Publisher) Replace(endpoints []endpoint.Endpoint) { + p.mu.Lock() + defer p.mu.Unlock() + p.current = endpoints + for c := range p.subscribers { + c <- p.current + } +} diff --git a/loadbalancer/publisher/static/publisher_test.go b/loadbalancer/publisher/static/publisher_test.go new file mode 100644 index 0000000..d5cdcee --- /dev/null +++ b/loadbalancer/publisher/static/publisher_test.go @@ -0,0 +1,30 @@ +package static_test + +import ( + "testing" + + "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/loadbalancer/publisher/static" +) + +func TestStaticPublisher(t *testing.T) { + endpoints := []endpoint.Endpoint{ + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + } + p := static.NewPublisher(endpoints) + defer p.Stop() + + c := make(chan []endpoint.Endpoint, 1) + p.Subscribe(c) + if want, have := len(endpoints), len(<-c); want != have { + t.Errorf("want %d, have %d", want, have) + } + + endpoints = []endpoint.Endpoint{} + p.Replace(endpoints) + if want, have := len(endpoints), len(<-c); want != have { + t.Errorf("want %d, have %d", want, have) + } +} diff --git a/loadbalancer/publisher.go b/loadbalancer/publisher.go deleted file mode 100644 index f697e10..0000000 --- a/loadbalancer/publisher.go +++ /dev/null @@ -1,10 +0,0 @@ -package loadbalancer - -import "github.com/go-kit/kit/endpoint" - -// Publisher produces endpoints. -type Publisher interface { - Subscribe(chan<- []endpoint.Endpoint) - Unsubscribe(chan<- []endpoint.Endpoint) - Stop() -} diff --git a/loadbalancer/random.go b/loadbalancer/random.go deleted file mode 100644 index b9cf17d..0000000 --- a/loadbalancer/random.go +++ /dev/null @@ -1,24 +0,0 @@ -package loadbalancer - -import ( - "math/rand" - - "github.com/go-kit/kit/endpoint" -) - -// Random returns a load balancer that yields random endpoints. -func Random(p Publisher) LoadBalancer { - return random{newCache(p)} -} - -type random struct{ *cache } - -func (r random) Count() int { return r.cache.count() } - -func (r random) Get() (endpoint.Endpoint, error) { - endpoints := r.cache.get() - if len(endpoints) <= 0 { - return nil, ErrNoEndpointsAvailable - } - return endpoints[rand.Intn(len(endpoints))], nil -} diff --git a/loadbalancer/random_test.go b/loadbalancer/random_test.go deleted file mode 100644 index bf3ac9c..0000000 --- a/loadbalancer/random_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package loadbalancer_test - -import ( - "math" - "testing" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/loadbalancer" - "golang.org/x/net/context" -) - -func TestRandom(t *testing.T) { - p := loadbalancer.NewStaticPublisher([]endpoint.Endpoint{}) - defer p.Stop() - - lb := loadbalancer.Random(p) - if _, err := lb.Get(); err == nil { - t.Error("want error, got none") - } - - counts := []int{0, 0, 0} - p.Replace([]endpoint.Endpoint{ - func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil }, - func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil }, - func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil }, - }) - assertLoadBalancerNotEmpty(t, lb) - - n := 10000 - for i := 0; i < n; i++ { - e, _ := lb.Get() - e(context.Background(), struct{}{}) - } - - want := float64(n) / float64(len(counts)) - tolerance := (want / 100.0) * 5 // 5% - for _, have := range counts { - if math.Abs(want-float64(have)) > tolerance { - t.Errorf("want %.0f, have %d", want, have) - } - } -} diff --git a/loadbalancer/retry_test.go b/loadbalancer/retry_test.go index 7a129ca..53ad619 100644 --- a/loadbalancer/retry_test.go +++ b/loadbalancer/retry_test.go @@ -2,20 +2,22 @@ import ( "errors" + "testing" "time" + + "golang.org/x/net/context" "github.com/go-kit/kit/endpoint" "github.com/go-kit/kit/loadbalancer" - "golang.org/x/net/context" - - "testing" + "github.com/go-kit/kit/loadbalancer/publisher/static" + "github.com/go-kit/kit/loadbalancer/strategy" ) func TestRetryMax(t *testing.T) { var ( endpoints = []endpoint.Endpoint{} - p = loadbalancer.NewStaticPublisher(endpoints) - lb = loadbalancer.RoundRobin(p) + p = static.NewPublisher(endpoints) + lb = strategy.RoundRobin(p) ) if _, err := loadbalancer.Retry(999, time.Second, lb)(context.Background(), struct{}{}); err == nil { @@ -28,7 +30,7 @@ func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ }, } p.Replace(endpoints) - assertLoadBalancerNotEmpty(t, lb) + time.Sleep(10 * time.Millisecond) //assertLoadBalancerNotEmpty(t, lb) // TODO if _, err := loadbalancer.Retry(len(endpoints)-1, time.Second, lb)(context.Background(), struct{}{}); err == nil { t.Errorf("expected error, got none") @@ -44,7 +46,7 @@ step = make(chan struct{}) e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil } timeout = time.Millisecond - retry = loadbalancer.Retry(999, timeout, loadbalancer.RoundRobin(loadbalancer.NewStaticPublisher([]endpoint.Endpoint{e}))) + retry = loadbalancer.Retry(999, timeout, strategy.RoundRobin(static.NewPublisher([]endpoint.Endpoint{e}))) errs = make(chan error) invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err } ) diff --git a/loadbalancer/round_robin.go b/loadbalancer/round_robin.go deleted file mode 100644 index 1ca844e..0000000 --- a/loadbalancer/round_robin.go +++ /dev/null @@ -1,34 +0,0 @@ -package loadbalancer - -import ( - "sync/atomic" - - "github.com/go-kit/kit/endpoint" -) - -// RoundRobin returns a load balancer that yields endpoints in sequence. -func RoundRobin(p Publisher) LoadBalancer { - return &roundRobin{newCache(p), 0} -} - -type roundRobin struct { - *cache - uint64 -} - -func (r *roundRobin) Count() int { return r.cache.count() } - -func (r *roundRobin) Get() (endpoint.Endpoint, error) { - endpoints := r.cache.get() - if len(endpoints) <= 0 { - return nil, ErrNoEndpointsAvailable - } - var old uint64 - for { - old = atomic.LoadUint64(&r.uint64) - if atomic.CompareAndSwapUint64(&r.uint64, old, old+1) { - break - } - } - return endpoints[old%uint64(len(endpoints))], nil -} diff --git a/loadbalancer/round_robin_test.go b/loadbalancer/round_robin_test.go deleted file mode 100644 index d5f6f9d..0000000 --- a/loadbalancer/round_robin_test.go +++ /dev/null @@ -1,44 +0,0 @@ -package loadbalancer_test - -import ( - "reflect" - "testing" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/loadbalancer" - "golang.org/x/net/context" -) - -func TestRoundRobin(t *testing.T) { - p := loadbalancer.NewStaticPublisher([]endpoint.Endpoint{}) - defer p.Stop() - - lb := loadbalancer.RoundRobin(p) - if _, err := lb.Get(); err == nil { - t.Error("want error, got none") - } - - counts := []int{0, 0, 0} - p.Replace([]endpoint.Endpoint{ - func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil }, - func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil }, - func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil }, - }) - assertLoadBalancerNotEmpty(t, lb) - - for i, want := range [][]int{ - {1, 0, 0}, - {1, 1, 0}, - {1, 1, 1}, - {2, 1, 1}, - {2, 2, 1}, - {2, 2, 2}, - {3, 2, 2}, - } { - e, _ := lb.Get() - e(context.Background(), struct{}{}) - if have := counts; !reflect.DeepEqual(want, have) { - t.Errorf("%d: want %v, have %v", i+1, want, have) - } - } -} diff --git a/loadbalancer/static_publisher.go b/loadbalancer/static_publisher.go deleted file mode 100644 index 42ae831..0000000 --- a/loadbalancer/static_publisher.go +++ /dev/null @@ -1,51 +0,0 @@ -package loadbalancer - -import ( - "sync" - - "github.com/go-kit/kit/endpoint" -) - -// NewStaticPublisher returns a publisher that yields a static set of -// endpoints, which can be completely replaced. -func NewStaticPublisher(endpoints []endpoint.Endpoint) *StaticPublisher { - return &StaticPublisher{ - current: endpoints, - subscribers: map[chan<- []endpoint.Endpoint]struct{}{}, - } -} - -// StaticPublisher holds a static set of endpoints. -type StaticPublisher struct { - mu sync.Mutex - current []endpoint.Endpoint - subscribers map[chan<- []endpoint.Endpoint]struct{} -} - -// Subscribe implements Publisher. -func (p *StaticPublisher) Subscribe(c chan<- []endpoint.Endpoint) { - p.mu.Lock() - defer p.mu.Unlock() - p.subscribers[c] = struct{}{} - c <- p.current -} - -// Unsubscribe implements Publisher. -func (p *StaticPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) { - p.mu.Lock() - defer p.mu.Unlock() - delete(p.subscribers, c) -} - -// Stop implements Publisher, but is a no-op. -func (p *StaticPublisher) Stop() {} - -// Replace replaces the endpoints and notifies all subscribers. -func (p *StaticPublisher) Replace(endpoints []endpoint.Endpoint) { - p.mu.Lock() - defer p.mu.Unlock() - p.current = endpoints - for c := range p.subscribers { - c <- p.current - } -} diff --git a/loadbalancer/static_publisher_test.go b/loadbalancer/static_publisher_test.go deleted file mode 100644 index 1fc4579..0000000 --- a/loadbalancer/static_publisher_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package loadbalancer_test - -import ( - "testing" - - "golang.org/x/net/context" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/loadbalancer" -) - -func TestStaticPublisher(t *testing.T) { - endpoints := []endpoint.Endpoint{ - func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, - } - p := loadbalancer.NewStaticPublisher(endpoints) - defer p.Stop() - - c := make(chan []endpoint.Endpoint, 1) - p.Subscribe(c) - if want, have := len(endpoints), len(<-c); want != have { - t.Errorf("want %d, have %d", want, have) - } - - endpoints = []endpoint.Endpoint{} - p.Replace(endpoints) - if want, have := len(endpoints), len(<-c); want != have { - t.Errorf("want %d, have %d", want, have) - } -} diff --git a/loadbalancer/strategies.go b/loadbalancer/strategies.go deleted file mode 100644 index 47165b7..0000000 --- a/loadbalancer/strategies.go +++ /dev/null @@ -1,17 +0,0 @@ -package loadbalancer - -import ( - "errors" - - "github.com/go-kit/kit/endpoint" -) - -// Strategy yields endpoints to consumers according to some algorithm. -type Strategy interface { - Next() (endpoint.Endpoint, error) - Stop() -} - -// ErrNoEndpoints is returned by a strategy when there are no endpoints -// available. -var ErrNoEndpoints = errors.New("no endpoints available") diff --git a/loadbalancer/strategy/cache.go b/loadbalancer/strategy/cache.go new file mode 100644 index 0000000..ac942ac --- /dev/null +++ b/loadbalancer/strategy/cache.go @@ -0,0 +1,50 @@ +package strategy + +import ( + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/loadbalancer/publisher" +) + +type cache struct { + req chan []endpoint.Endpoint + cnt chan int + quit chan struct{} +} + +func newCache(p publisher.Publisher) *cache { + c := &cache{ + req: make(chan []endpoint.Endpoint), + cnt: make(chan int), + quit: make(chan struct{}), + } + go c.loop(p) + return c +} + +func (c *cache) loop(p publisher.Publisher) { + e := make(chan []endpoint.Endpoint, 1) + p.Subscribe(e) + defer p.Unsubscribe(e) + endpoints := <-e + for { + select { + case endpoints = <-e: + case c.cnt <- len(endpoints): + case c.req <- endpoints: + case <-c.quit: + return + } + } +} + +func (c *cache) count() int { + return <-c.cnt +} + +func (c *cache) get() []endpoint.Endpoint { + return <-c.req +} + +func (c *cache) stop() { + close(c.quit) +} diff --git a/loadbalancer/strategy/cache_internal_test.go b/loadbalancer/strategy/cache_internal_test.go new file mode 100644 index 0000000..7378b44 --- /dev/null +++ b/loadbalancer/strategy/cache_internal_test.go @@ -0,0 +1,34 @@ +package strategy + +import ( + "runtime" + "testing" + + "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/loadbalancer/publisher/static" +) + +func TestCache(t *testing.T) { + e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } + endpoints := []endpoint.Endpoint{e} + + p := static.NewPublisher(endpoints) + defer p.Stop() + + c := newCache(p) + defer c.stop() + + for _, n := range []int{2, 10, 0} { + endpoints = make([]endpoint.Endpoint, n) + for i := 0; i < n; i++ { + endpoints[i] = e + } + p.Replace(endpoints) + runtime.Gosched() + if want, have := len(endpoints), len(c.get()); want != have { + t.Errorf("want %d, have %d", want, have) + } + } +} diff --git a/loadbalancer/strategy/random.go b/loadbalancer/strategy/random.go new file mode 100644 index 0000000..4dcb8cc --- /dev/null +++ b/loadbalancer/strategy/random.go @@ -0,0 +1,26 @@ +package strategy + +import ( + "math/rand" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/loadbalancer" + "github.com/go-kit/kit/loadbalancer/publisher" +) + +// Random returns a load balancer that yields random endpoints. +func Random(p publisher.Publisher) loadbalancer.LoadBalancer { + return random{newCache(p)} +} + +type random struct{ *cache } + +func (r random) Count() int { return r.cache.count() } + +func (r random) Get() (endpoint.Endpoint, error) { + endpoints := r.cache.get() + if len(endpoints) <= 0 { + return nil, loadbalancer.ErrNoEndpointsAvailable + } + return endpoints[rand.Intn(len(endpoints))], nil +} diff --git a/loadbalancer/strategy/random_test.go b/loadbalancer/strategy/random_test.go new file mode 100644 index 0000000..92358e0 --- /dev/null +++ b/loadbalancer/strategy/random_test.go @@ -0,0 +1,43 @@ +package strategy_test + +import ( + "math" + "testing" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/loadbalancer/publisher/static" + "github.com/go-kit/kit/loadbalancer/strategy" + "golang.org/x/net/context" +) + +func TestRandom(t *testing.T) { + p := static.NewPublisher([]endpoint.Endpoint{}) + defer p.Stop() + + lb := strategy.Random(p) + if _, err := lb.Get(); err == nil { + t.Error("want error, got none") + } + + counts := []int{0, 0, 0} + p.Replace([]endpoint.Endpoint{ + func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil }, + func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil }, + func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil }, + }) + assertLoadBalancerNotEmpty(t, lb) + + n := 10000 + for i := 0; i < n; i++ { + e, _ := lb.Get() + e(context.Background(), struct{}{}) + } + + want := float64(n) / float64(len(counts)) + tolerance := (want / 100.0) * 5 // 5% + for _, have := range counts { + if math.Abs(want-float64(have)) > tolerance { + t.Errorf("want %.0f, have %d", want, have) + } + } +} diff --git a/loadbalancer/strategy/round_robin.go b/loadbalancer/strategy/round_robin.go new file mode 100644 index 0000000..4d0cfc6 --- /dev/null +++ b/loadbalancer/strategy/round_robin.go @@ -0,0 +1,36 @@ +package strategy + +import ( + "sync/atomic" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/loadbalancer" + "github.com/go-kit/kit/loadbalancer/publisher" +) + +// RoundRobin returns a load balancer that yields endpoints in sequence. +func RoundRobin(p publisher.Publisher) loadbalancer.LoadBalancer { + return &roundRobin{newCache(p), 0} +} + +type roundRobin struct { + *cache + uint64 +} + +func (r *roundRobin) Count() int { return r.cache.count() } + +func (r *roundRobin) Get() (endpoint.Endpoint, error) { + endpoints := r.cache.get() + if len(endpoints) <= 0 { + return nil, loadbalancer.ErrNoEndpointsAvailable + } + var old uint64 + for { + old = atomic.LoadUint64(&r.uint64) + if atomic.CompareAndSwapUint64(&r.uint64, old, old+1) { + break + } + } + return endpoints[old%uint64(len(endpoints))], nil +} diff --git a/loadbalancer/strategy/round_robin_test.go b/loadbalancer/strategy/round_robin_test.go new file mode 100644 index 0000000..a1d54f4 --- /dev/null +++ b/loadbalancer/strategy/round_robin_test.go @@ -0,0 +1,46 @@ +package strategy_test + +import ( + "reflect" + "testing" + + "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/loadbalancer/publisher/static" + "github.com/go-kit/kit/loadbalancer/strategy" +) + +func TestRoundRobin(t *testing.T) { + p := static.NewPublisher([]endpoint.Endpoint{}) + defer p.Stop() + + lb := strategy.RoundRobin(p) + if _, err := lb.Get(); err == nil { + t.Error("want error, got none") + } + + counts := []int{0, 0, 0} + p.Replace([]endpoint.Endpoint{ + func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil }, + func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil }, + func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil }, + }) + assertLoadBalancerNotEmpty(t, lb) + + for i, want := range [][]int{ + {1, 0, 0}, + {1, 1, 0}, + {1, 1, 1}, + {2, 1, 1}, + {2, 2, 1}, + {2, 2, 2}, + {3, 2, 2}, + } { + e, _ := lb.Get() + e(context.Background(), struct{}{}) + if have := counts; !reflect.DeepEqual(want, have) { + t.Errorf("%d: want %v, have %v", i+1, want, have) + } + } +} diff --git a/loadbalancer/strategy/strategy.go b/loadbalancer/strategy/strategy.go new file mode 100644 index 0000000..42f18b7 --- /dev/null +++ b/loadbalancer/strategy/strategy.go @@ -0,0 +1,9 @@ +package strategy + +import "github.com/go-kit/kit/endpoint" + +// Strategy yields endpoints to consumers according to some algorithm. +type Strategy interface { + Next() (endpoint.Endpoint, error) + Stop() +} diff --git a/loadbalancer/strategy/util_test.go b/loadbalancer/strategy/util_test.go new file mode 100644 index 0000000..d478447 --- /dev/null +++ b/loadbalancer/strategy/util_test.go @@ -0,0 +1,35 @@ +package strategy_test + +import ( + "fmt" + "testing" + "time" + + "github.com/go-kit/kit/loadbalancer" +) + +func assertLoadBalancerNotEmpty(t *testing.T, lb loadbalancer.LoadBalancer) { + if err := within(10*time.Millisecond, func() bool { + return lb.Count() > 0 + }); err != nil { + t.Fatal("Publisher never updated endpoints") + } +} + +func within(d time.Duration, f func() bool) error { + var ( + deadline = time.After(d) + ticker = time.NewTicker(d / 10) + ) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if f() { + return nil + } + case <-deadline: + return fmt.Errorf("deadline exceeded") + } + } +} diff --git a/loadbalancer/util_test.go b/loadbalancer/util_test.go deleted file mode 100644 index 4f3bf6b..0000000 --- a/loadbalancer/util_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package loadbalancer_test - -import ( - "fmt" - "testing" - "time" - - "github.com/go-kit/kit/loadbalancer" -) - -func assertLoadBalancerNotEmpty(t *testing.T, lb loadbalancer.LoadBalancer) { - if err := within(10*time.Millisecond, func() bool { - return lb.Count() > 0 - }); err != nil { - t.Fatal("Publisher never updated endpoints") - } -} - -func within(d time.Duration, f func() bool) error { - var ( - deadline = time.After(d) - ticker = time.NewTicker(d / 10) - ) - defer ticker.Stop() - for { - select { - case <-ticker.C: - if f() { - return nil - } - case <-deadline: - return fmt.Errorf("deadline exceeded") - } - } -}