Codebase list golang-github-go-kit-kit / b315cae
Merge pull request #67 from go-kit/load-balancer-2 package loadbalancer (take 2) Peter Bourgon 8 years ago
14 changed file(s) with 621 addition(s) and 0 deletion(s). Raw diff Collapse all Expand all
0 package loadbalancer
1
2 import "github.com/go-kit/kit/endpoint"
3
4 type cache struct {
5 req chan []endpoint.Endpoint
6 quit chan struct{}
7 }
8
9 func newCache(p Publisher) *cache {
10 c := &cache{
11 req: make(chan []endpoint.Endpoint),
12 quit: make(chan struct{}),
13 }
14 go c.loop(p)
15 return c
16 }
17
18 func (c *cache) loop(p Publisher) {
19 e := make(chan []endpoint.Endpoint, 1)
20 p.Subscribe(e)
21 defer p.Unsubscribe(e)
22 endpoints := <-e
23 for {
24 select {
25 case endpoints = <-e:
26 case c.req <- endpoints:
27 case <-c.quit:
28 return
29 }
30 }
31 }
32
33 func (c *cache) get() []endpoint.Endpoint {
34 return <-c.req
35 }
36
37 func (c *cache) stop() {
38 close(c.quit)
39 }
0 package loadbalancer
1
2 import (
3 "runtime"
4 "testing"
5
6 "golang.org/x/net/context"
7
8 "github.com/go-kit/kit/endpoint"
9 )
10
11 func TestCache(t *testing.T) {
12 e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
13 endpoints := []endpoint.Endpoint{e}
14
15 p := NewStaticPublisher(endpoints)
16 defer p.Stop()
17
18 c := newCache(p)
19 defer c.stop()
20
21 for _, n := range []int{2, 10, 0} {
22 endpoints = make([]endpoint.Endpoint, n)
23 for i := 0; i < n; i++ {
24 endpoints[i] = e
25 }
26 p.Replace(endpoints)
27 runtime.Gosched()
28 if want, have := len(endpoints), len(c.get()); want != have {
29 t.Errorf("want %d, have %d", want, have)
30 }
31 }
32 }
0 package loadbalancer
1
2 import (
3 "crypto/md5"
4 "fmt"
5 "net"
6 "sort"
7 "time"
8
9 "github.com/go-kit/kit/endpoint"
10 )
11
12 type dnssrvPublisher struct {
13 subscribe chan chan<- []endpoint.Endpoint
14 unsubscribe chan chan<- []endpoint.Endpoint
15 quit chan struct{}
16 }
17
18 // NewDNSSRVPublisher returns a publisher that resolves the SRV name every ttl, and
19 func NewDNSSRVPublisher(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) Publisher {
20 p := &dnssrvPublisher{
21 subscribe: make(chan chan<- []endpoint.Endpoint),
22 unsubscribe: make(chan chan<- []endpoint.Endpoint),
23 quit: make(chan struct{}),
24 }
25 go p.loop(name, ttl, makeEndpoint)
26 return p
27 }
28
29 func (p *dnssrvPublisher) Subscribe(c chan<- []endpoint.Endpoint) {
30 p.subscribe <- c
31 }
32
33 func (p *dnssrvPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) {
34 p.unsubscribe <- c
35 }
36
37 func (p *dnssrvPublisher) Stop() {
38 close(p.quit)
39 }
40
41 var newTicker = time.NewTicker
42
43 func (p *dnssrvPublisher) loop(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) {
44 var (
45 subscriptions = map[chan<- []endpoint.Endpoint]struct{}{}
46 addrs, md5, _ = resolve(name)
47 endpoints = convert(addrs, makeEndpoint)
48 ticker = newTicker(ttl)
49 )
50 defer ticker.Stop()
51 for {
52 select {
53 case <-ticker.C:
54 addrs, newmd5, err := resolve(name)
55 if err == nil && newmd5 != md5 {
56 endpoints = convert(addrs, makeEndpoint)
57 for c := range subscriptions {
58 c <- endpoints
59 }
60 md5 = newmd5
61 }
62
63 case c := <-p.subscribe:
64 subscriptions[c] = struct{}{}
65 c <- endpoints
66
67 case c := <-p.unsubscribe:
68 delete(subscriptions, c)
69
70 case <-p.quit:
71 return
72 }
73 }
74 }
75
76 // Allow mocking in tests.
77 var resolve = func(name string) (addrs []*net.SRV, md5sum string, err error) {
78 _, addrs, err = net.LookupSRV("", "", name)
79 if err != nil {
80 return addrs, "", err
81 }
82 hostports := make([]string, len(addrs))
83 for i, addr := range addrs {
84 hostports[i] = fmt.Sprintf("%s:%d", addr.Target, addr.Port)
85 }
86 sort.Sort(sort.StringSlice(hostports))
87 h := md5.New()
88 for _, hostport := range hostports {
89 fmt.Fprintf(h, hostport)
90 }
91 return addrs, fmt.Sprintf("%x", h.Sum(nil)), nil
92 }
93
94 func convert(addrs []*net.SRV, makeEndpoint func(hostport string) endpoint.Endpoint) []endpoint.Endpoint {
95 endpoints := make([]endpoint.Endpoint, len(addrs))
96 for i, addr := range addrs {
97 endpoints[i] = makeEndpoint(addr2hostport(addr))
98 }
99 return endpoints
100 }
101
102 func addr2hostport(addr *net.SRV) string {
103 return net.JoinHostPort(addr.Target, fmt.Sprintf("%d", addr.Port))
104 }
0 package loadbalancer
1
2 import (
3 "fmt"
4 "net"
5 "testing"
6 "time"
7
8 "golang.org/x/net/context"
9
10 "github.com/go-kit/kit/endpoint"
11 )
12
13 func TestDNSSRVPublisher(t *testing.T) {
14 // Reset the vars when we're done
15 oldResolve := resolve
16 defer func() { resolve = oldResolve }()
17 oldNewTicker := newTicker
18 defer func() { newTicker = oldNewTicker }()
19
20 // Set up a fixture and swap the vars
21 a := []*net.SRV{
22 {Target: "foo", Port: 123},
23 {Target: "bar", Port: 456},
24 {Target: "baz", Port: 789},
25 }
26 ticker := make(chan time.Time)
27 resolve = func(string) ([]*net.SRV, string, error) { return a, fmt.Sprint(len(a)), nil }
28 newTicker = func(time.Duration) *time.Ticker { return &time.Ticker{C: ticker} }
29
30 // Construct endpoint
31 m := map[string]int{}
32 e := func(hostport string) endpoint.Endpoint {
33 return func(context.Context, interface{}) (interface{}, error) {
34 m[hostport]++
35 return struct{}{}, nil
36 }
37 }
38
39 // Build the publisher
40 var (
41 name = "irrelevant"
42 ttl = time.Second
43 makeEndpoint = func(hostport string) endpoint.Endpoint { return e(hostport) }
44 )
45 p := NewDNSSRVPublisher(name, ttl, makeEndpoint)
46 defer p.Stop()
47
48 // Subscribe
49 c := make(chan []endpoint.Endpoint, 1)
50 p.Subscribe(c)
51 defer p.Unsubscribe(c)
52
53 // Invoke all of the endpoints
54 for _, e := range <-c {
55 e(context.Background(), struct{}{})
56 }
57
58 // Make sure we invoked what we expected to
59 for _, addr := range a {
60 hostport := addr2hostport(addr)
61 if want, have := 1, m[hostport]; want != have {
62 t.Errorf("%q: want %d, have %d", name, want, have)
63 }
64 delete(m, hostport)
65 }
66 if want, have := 0, len(m); want != have {
67 t.Errorf("want %d, have %d", want, have)
68 }
69
70 // Reset the fixture, trigger the timer, count the endpoints
71 a = []*net.SRV{}
72 ticker <- time.Now()
73 if want, have := len(a), len(<-c); want != have {
74 t.Errorf("want %d, have %d", want, have)
75 }
76 }
0 package loadbalancer
1
2 import (
3 "errors"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // LoadBalancer yields endpoints one-by-one.
9 type LoadBalancer interface {
10 Get() (endpoint.Endpoint, error)
11 }
12
13 // ErrNoEndpointsAvailable is given by a load balancer when no endpoints are
14 // available to be returned.
15 var ErrNoEndpointsAvailable = errors.New("no endpoints available")
0 package loadbalancer
1
2 import "github.com/go-kit/kit/endpoint"
3
4 // Publisher produces endpoints.
5 type Publisher interface {
6 Subscribe(chan<- []endpoint.Endpoint)
7 Unsubscribe(chan<- []endpoint.Endpoint)
8 Stop()
9 }
0 package loadbalancer
1
2 import (
3 "math/rand"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // Random returns a load balancer that yields random endpoints.
9 func Random(p Publisher) LoadBalancer {
10 return random{newCache(p)}
11 }
12
13 type random struct{ *cache }
14
15 func (r random) Get() (endpoint.Endpoint, error) {
16 endpoints := r.cache.get()
17 if len(endpoints) <= 0 {
18 return nil, ErrNoEndpointsAvailable
19 }
20 return endpoints[rand.Intn(len(endpoints))], nil
21 }
0 package loadbalancer_test
1
2 import (
3 "math"
4 "runtime"
5 "testing"
6
7 "github.com/go-kit/kit/endpoint"
8 "github.com/go-kit/kit/loadbalancer"
9 "golang.org/x/net/context"
10 )
11
12 func TestRandom(t *testing.T) {
13 p := loadbalancer.NewStaticPublisher([]endpoint.Endpoint{})
14 defer p.Stop()
15
16 lb := loadbalancer.Random(p)
17 if _, err := lb.Get(); err == nil {
18 t.Error("want error, got none")
19 }
20
21 counts := []int{0, 0, 0}
22 p.Replace([]endpoint.Endpoint{
23 func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil },
24 func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil },
25 func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil },
26 })
27 runtime.Gosched()
28
29 n := 10000
30 for i := 0; i < n; i++ {
31 e, _ := lb.Get()
32 e(context.Background(), struct{}{})
33 }
34
35 want := float64(n) / float64(len(counts))
36 tolerance := want / 100.0 // 1%
37 for _, have := range counts {
38 if math.Abs(want-float64(have)) > tolerance {
39 t.Errorf("want %.0f, have %d", want, have)
40 }
41 }
42 }
0 package loadbalancer
1
2 import (
3 "fmt"
4 "strings"
5 "time"
6
7 "golang.org/x/net/context"
8
9 "github.com/go-kit/kit/endpoint"
10 )
11
12 // Retry yields an endpoint that takes endpoints from the load balancer.
13 // Invocations that return errors will be retried until they succeed, up to
14 // max times, or until the timeout is elapsed, whichever comes first.
15 func Retry(max int, timeout time.Duration, lb LoadBalancer) endpoint.Endpoint {
16 return func(ctx context.Context, request interface{}) (interface{}, error) {
17 var (
18 newctx, cancel = context.WithTimeout(ctx, timeout)
19 responses = make(chan interface{}, 1)
20 errs = make(chan error, 1)
21 a = []string{}
22 )
23 defer cancel()
24 for i := 1; i <= max; i++ {
25 go func() {
26 e, err := lb.Get()
27 if err != nil {
28 errs <- err
29 return
30 }
31 response, err := e(newctx, request)
32 if err != nil {
33 errs <- err
34 return
35 }
36 responses <- response
37 }()
38
39 select {
40 case <-newctx.Done():
41 return nil, newctx.Err()
42 case response := <-responses:
43 return response, nil
44 case err := <-errs:
45 a = append(a, err.Error())
46 continue
47 }
48 }
49 return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; "))
50 }
51 }
0 package loadbalancer_test
1
2 import (
3 "errors"
4 "runtime"
5 "time"
6
7 "github.com/go-kit/kit/endpoint"
8 "github.com/go-kit/kit/loadbalancer"
9 "golang.org/x/net/context"
10
11 "testing"
12 )
13
14 func TestRetryMax(t *testing.T) {
15 var (
16 endpoints = []endpoint.Endpoint{}
17 p = loadbalancer.NewStaticPublisher(endpoints)
18 lb = loadbalancer.RoundRobin(p)
19 )
20
21 if _, err := loadbalancer.Retry(999, time.Second, lb)(context.Background(), struct{}{}); err == nil {
22 t.Errorf("expected error, got none")
23 }
24
25 endpoints = []endpoint.Endpoint{
26 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") },
27 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") },
28 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
29 }
30 p.Replace(endpoints)
31 runtime.Gosched()
32
33 if _, err := loadbalancer.Retry(len(endpoints)-1, time.Second, lb)(context.Background(), struct{}{}); err == nil {
34 t.Errorf("expected error, got none")
35 }
36
37 if _, err := loadbalancer.Retry(len(endpoints), time.Second, lb)(context.Background(), struct{}{}); err != nil {
38 t.Error(err)
39 }
40 }
41
42 func TestRetryTimeout(t *testing.T) {
43 var (
44 step = make(chan struct{})
45 e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil }
46 timeout = time.Millisecond
47 retry = loadbalancer.Retry(999, timeout, loadbalancer.RoundRobin(loadbalancer.NewStaticPublisher([]endpoint.Endpoint{e})))
48 errs = make(chan error)
49 invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err }
50 )
51
52 go invoke() // invoke the endpoint
53 step <- struct{}{} // tell the endpoint to return
54 if err := <-errs; err != nil { // that should succeed
55 t.Error(err)
56 }
57
58 go invoke() // invoke the endpoint
59 time.Sleep(2 * timeout) // wait
60 step <- struct{}{} // tell the endpoint to return
61 if err := <-errs; err != context.DeadlineExceeded { // that should not succeed
62 t.Errorf("wanted error, got none")
63 }
64 }
0 package loadbalancer
1
2 import (
3 "sync/atomic"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // RoundRobin returns a load balancer that yields endpoints in sequence.
9 func RoundRobin(p Publisher) LoadBalancer {
10 return &roundRobin{newCache(p), 0}
11 }
12
13 type roundRobin struct {
14 *cache
15 uint64
16 }
17
18 func (r *roundRobin) Get() (endpoint.Endpoint, error) {
19 endpoints := r.cache.get()
20 if len(endpoints) <= 0 {
21 return nil, ErrNoEndpointsAvailable
22 }
23 var old uint64
24 for {
25 old = atomic.LoadUint64(&r.uint64)
26 if atomic.CompareAndSwapUint64(&r.uint64, old, old+1) {
27 break
28 }
29 }
30 return endpoints[old%uint64(len(endpoints))], nil
31 }
0 package loadbalancer_test
1
2 import (
3 "reflect"
4 "runtime"
5 "testing"
6
7 "github.com/go-kit/kit/endpoint"
8 "github.com/go-kit/kit/loadbalancer"
9 "golang.org/x/net/context"
10 )
11
12 func TestRoundRobin(t *testing.T) {
13 p := loadbalancer.NewStaticPublisher([]endpoint.Endpoint{})
14 defer p.Stop()
15
16 lb := loadbalancer.RoundRobin(p)
17 if _, err := lb.Get(); err == nil {
18 t.Error("want error, got none")
19 }
20
21 counts := []int{0, 0, 0}
22 p.Replace([]endpoint.Endpoint{
23 func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil },
24 func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil },
25 func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil },
26 })
27 runtime.Gosched()
28
29 for i, want := range [][]int{
30 {1, 0, 0},
31 {1, 1, 0},
32 {1, 1, 1},
33 {2, 1, 1},
34 {2, 2, 1},
35 {2, 2, 2},
36 {3, 2, 2},
37 } {
38 e, _ := lb.Get()
39 e(context.Background(), struct{}{})
40 if have := counts; !reflect.DeepEqual(want, have) {
41 t.Errorf("%d: want %v, have %v", i+1, want, have)
42 }
43 }
44 }
0 package loadbalancer
1
2 import (
3 "sync"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // NewStaticPublisher returns a publisher that yields a static set of
9 // endpoints, which can be completely replaced.
10 func NewStaticPublisher(endpoints []endpoint.Endpoint) *StaticPublisher {
11 return &StaticPublisher{
12 current: endpoints,
13 subscribers: map[chan<- []endpoint.Endpoint]struct{}{},
14 }
15 }
16
17 // StaticPublisher holds a static set of endpoints.
18 type StaticPublisher struct {
19 sync.Mutex
20 current []endpoint.Endpoint
21 subscribers map[chan<- []endpoint.Endpoint]struct{}
22 }
23
24 // Subscribe implements Publisher.
25 func (p *StaticPublisher) Subscribe(c chan<- []endpoint.Endpoint) {
26 p.Lock()
27 defer p.Unlock()
28 p.subscribers[c] = struct{}{}
29 c <- p.current
30 }
31
32 // Unsubscribe implements Publisher.
33 func (p *StaticPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) {
34 p.Lock()
35 defer p.Unlock()
36 delete(p.subscribers, c)
37 }
38
39 // Stop implements Publisher, but is a no-op.
40 func (p *StaticPublisher) Stop() {}
41
42 // Replace replaces the endpoints and notifies all subscribers.
43 func (p *StaticPublisher) Replace(endpoints []endpoint.Endpoint) {
44 p.Lock()
45 defer p.Unlock()
46 p.current = endpoints
47 for c := range p.subscribers {
48 c <- p.current
49 }
50 }
0 package loadbalancer_test
1
2 import (
3 "testing"
4
5 "golang.org/x/net/context"
6
7 "github.com/go-kit/kit/endpoint"
8 "github.com/go-kit/kit/loadbalancer"
9 )
10
11 func TestStaticPublisher(t *testing.T) {
12 endpoints := []endpoint.Endpoint{
13 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
14 }
15 p := loadbalancer.NewStaticPublisher(endpoints)
16 defer p.Stop()
17
18 c := make(chan []endpoint.Endpoint, 1)
19 p.Subscribe(c)
20 if want, have := len(endpoints), len(<-c); want != have {
21 t.Errorf("want %d, have %d", want, have)
22 }
23
24 endpoints = []endpoint.Endpoint{}
25 p.Replace(endpoints)
26 if want, have := len(endpoints), len(<-c); want != have {
27 t.Errorf("want %d, have %d", want, have)
28 }
29 }