Merge pull request #67 from go-kit/load-balancer-2
package loadbalancer (take 2)
Peter Bourgon
8 years ago
0 | package loadbalancer | |
1 | ||
2 | import "github.com/go-kit/kit/endpoint" | |
3 | ||
4 | type cache struct { | |
5 | req chan []endpoint.Endpoint | |
6 | quit chan struct{} | |
7 | } | |
8 | ||
9 | func newCache(p Publisher) *cache { | |
10 | c := &cache{ | |
11 | req: make(chan []endpoint.Endpoint), | |
12 | quit: make(chan struct{}), | |
13 | } | |
14 | go c.loop(p) | |
15 | return c | |
16 | } | |
17 | ||
18 | func (c *cache) loop(p Publisher) { | |
19 | e := make(chan []endpoint.Endpoint, 1) | |
20 | p.Subscribe(e) | |
21 | defer p.Unsubscribe(e) | |
22 | endpoints := <-e | |
23 | for { | |
24 | select { | |
25 | case endpoints = <-e: | |
26 | case c.req <- endpoints: | |
27 | case <-c.quit: | |
28 | return | |
29 | } | |
30 | } | |
31 | } | |
32 | ||
33 | func (c *cache) get() []endpoint.Endpoint { | |
34 | return <-c.req | |
35 | } | |
36 | ||
37 | func (c *cache) stop() { | |
38 | close(c.quit) | |
39 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "runtime" | |
4 | "testing" | |
5 | ||
6 | "golang.org/x/net/context" | |
7 | ||
8 | "github.com/go-kit/kit/endpoint" | |
9 | ) | |
10 | ||
11 | func TestCache(t *testing.T) { | |
12 | e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
13 | endpoints := []endpoint.Endpoint{e} | |
14 | ||
15 | p := NewStaticPublisher(endpoints) | |
16 | defer p.Stop() | |
17 | ||
18 | c := newCache(p) | |
19 | defer c.stop() | |
20 | ||
21 | for _, n := range []int{2, 10, 0} { | |
22 | endpoints = make([]endpoint.Endpoint, n) | |
23 | for i := 0; i < n; i++ { | |
24 | endpoints[i] = e | |
25 | } | |
26 | p.Replace(endpoints) | |
27 | runtime.Gosched() | |
28 | if want, have := len(endpoints), len(c.get()); want != have { | |
29 | t.Errorf("want %d, have %d", want, have) | |
30 | } | |
31 | } | |
32 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "crypto/md5" | |
4 | "fmt" | |
5 | "net" | |
6 | "sort" | |
7 | "time" | |
8 | ||
9 | "github.com/go-kit/kit/endpoint" | |
10 | ) | |
11 | ||
12 | type dnssrvPublisher struct { | |
13 | subscribe chan chan<- []endpoint.Endpoint | |
14 | unsubscribe chan chan<- []endpoint.Endpoint | |
15 | quit chan struct{} | |
16 | } | |
17 | ||
18 | // NewDNSSRVPublisher returns a publisher that resolves the SRV name every ttl, and | |
19 | func NewDNSSRVPublisher(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) Publisher { | |
20 | p := &dnssrvPublisher{ | |
21 | subscribe: make(chan chan<- []endpoint.Endpoint), | |
22 | unsubscribe: make(chan chan<- []endpoint.Endpoint), | |
23 | quit: make(chan struct{}), | |
24 | } | |
25 | go p.loop(name, ttl, makeEndpoint) | |
26 | return p | |
27 | } | |
28 | ||
29 | func (p *dnssrvPublisher) Subscribe(c chan<- []endpoint.Endpoint) { | |
30 | p.subscribe <- c | |
31 | } | |
32 | ||
33 | func (p *dnssrvPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) { | |
34 | p.unsubscribe <- c | |
35 | } | |
36 | ||
37 | func (p *dnssrvPublisher) Stop() { | |
38 | close(p.quit) | |
39 | } | |
40 | ||
41 | var newTicker = time.NewTicker | |
42 | ||
43 | func (p *dnssrvPublisher) loop(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) { | |
44 | var ( | |
45 | subscriptions = map[chan<- []endpoint.Endpoint]struct{}{} | |
46 | addrs, md5, _ = resolve(name) | |
47 | endpoints = convert(addrs, makeEndpoint) | |
48 | ticker = newTicker(ttl) | |
49 | ) | |
50 | defer ticker.Stop() | |
51 | for { | |
52 | select { | |
53 | case <-ticker.C: | |
54 | addrs, newmd5, err := resolve(name) | |
55 | if err == nil && newmd5 != md5 { | |
56 | endpoints = convert(addrs, makeEndpoint) | |
57 | for c := range subscriptions { | |
58 | c <- endpoints | |
59 | } | |
60 | md5 = newmd5 | |
61 | } | |
62 | ||
63 | case c := <-p.subscribe: | |
64 | subscriptions[c] = struct{}{} | |
65 | c <- endpoints | |
66 | ||
67 | case c := <-p.unsubscribe: | |
68 | delete(subscriptions, c) | |
69 | ||
70 | case <-p.quit: | |
71 | return | |
72 | } | |
73 | } | |
74 | } | |
75 | ||
76 | // Allow mocking in tests. | |
77 | var resolve = func(name string) (addrs []*net.SRV, md5sum string, err error) { | |
78 | _, addrs, err = net.LookupSRV("", "", name) | |
79 | if err != nil { | |
80 | return addrs, "", err | |
81 | } | |
82 | hostports := make([]string, len(addrs)) | |
83 | for i, addr := range addrs { | |
84 | hostports[i] = fmt.Sprintf("%s:%d", addr.Target, addr.Port) | |
85 | } | |
86 | sort.Sort(sort.StringSlice(hostports)) | |
87 | h := md5.New() | |
88 | for _, hostport := range hostports { | |
89 | fmt.Fprintf(h, hostport) | |
90 | } | |
91 | return addrs, fmt.Sprintf("%x", h.Sum(nil)), nil | |
92 | } | |
93 | ||
94 | func convert(addrs []*net.SRV, makeEndpoint func(hostport string) endpoint.Endpoint) []endpoint.Endpoint { | |
95 | endpoints := make([]endpoint.Endpoint, len(addrs)) | |
96 | for i, addr := range addrs { | |
97 | endpoints[i] = makeEndpoint(addr2hostport(addr)) | |
98 | } | |
99 | return endpoints | |
100 | } | |
101 | ||
102 | func addr2hostport(addr *net.SRV) string { | |
103 | return net.JoinHostPort(addr.Target, fmt.Sprintf("%d", addr.Port)) | |
104 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "net" | |
5 | "testing" | |
6 | "time" | |
7 | ||
8 | "golang.org/x/net/context" | |
9 | ||
10 | "github.com/go-kit/kit/endpoint" | |
11 | ) | |
12 | ||
13 | func TestDNSSRVPublisher(t *testing.T) { | |
14 | // Reset the vars when we're done | |
15 | oldResolve := resolve | |
16 | defer func() { resolve = oldResolve }() | |
17 | oldNewTicker := newTicker | |
18 | defer func() { newTicker = oldNewTicker }() | |
19 | ||
20 | // Set up a fixture and swap the vars | |
21 | a := []*net.SRV{ | |
22 | {Target: "foo", Port: 123}, | |
23 | {Target: "bar", Port: 456}, | |
24 | {Target: "baz", Port: 789}, | |
25 | } | |
26 | ticker := make(chan time.Time) | |
27 | resolve = func(string) ([]*net.SRV, string, error) { return a, fmt.Sprint(len(a)), nil } | |
28 | newTicker = func(time.Duration) *time.Ticker { return &time.Ticker{C: ticker} } | |
29 | ||
30 | // Construct endpoint | |
31 | m := map[string]int{} | |
32 | e := func(hostport string) endpoint.Endpoint { | |
33 | return func(context.Context, interface{}) (interface{}, error) { | |
34 | m[hostport]++ | |
35 | return struct{}{}, nil | |
36 | } | |
37 | } | |
38 | ||
39 | // Build the publisher | |
40 | var ( | |
41 | name = "irrelevant" | |
42 | ttl = time.Second | |
43 | makeEndpoint = func(hostport string) endpoint.Endpoint { return e(hostport) } | |
44 | ) | |
45 | p := NewDNSSRVPublisher(name, ttl, makeEndpoint) | |
46 | defer p.Stop() | |
47 | ||
48 | // Subscribe | |
49 | c := make(chan []endpoint.Endpoint, 1) | |
50 | p.Subscribe(c) | |
51 | defer p.Unsubscribe(c) | |
52 | ||
53 | // Invoke all of the endpoints | |
54 | for _, e := range <-c { | |
55 | e(context.Background(), struct{}{}) | |
56 | } | |
57 | ||
58 | // Make sure we invoked what we expected to | |
59 | for _, addr := range a { | |
60 | hostport := addr2hostport(addr) | |
61 | if want, have := 1, m[hostport]; want != have { | |
62 | t.Errorf("%q: want %d, have %d", name, want, have) | |
63 | } | |
64 | delete(m, hostport) | |
65 | } | |
66 | if want, have := 0, len(m); want != have { | |
67 | t.Errorf("want %d, have %d", want, have) | |
68 | } | |
69 | ||
70 | // Reset the fixture, trigger the timer, count the endpoints | |
71 | a = []*net.SRV{} | |
72 | ticker <- time.Now() | |
73 | if want, have := len(a), len(<-c); want != have { | |
74 | t.Errorf("want %d, have %d", want, have) | |
75 | } | |
76 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
8 | // LoadBalancer yields endpoints one-by-one. | |
9 | type LoadBalancer interface { | |
10 | Get() (endpoint.Endpoint, error) | |
11 | } | |
12 | ||
13 | // ErrNoEndpointsAvailable is given by a load balancer when no endpoints are | |
14 | // available to be returned. | |
15 | var ErrNoEndpointsAvailable = errors.New("no endpoints available") |
0 | package loadbalancer | |
1 | ||
2 | import "github.com/go-kit/kit/endpoint" | |
3 | ||
4 | // Publisher produces endpoints. | |
5 | type Publisher interface { | |
6 | Subscribe(chan<- []endpoint.Endpoint) | |
7 | Unsubscribe(chan<- []endpoint.Endpoint) | |
8 | Stop() | |
9 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "math/rand" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
8 | // Random returns a load balancer that yields random endpoints. | |
9 | func Random(p Publisher) LoadBalancer { | |
10 | return random{newCache(p)} | |
11 | } | |
12 | ||
13 | type random struct{ *cache } | |
14 | ||
15 | func (r random) Get() (endpoint.Endpoint, error) { | |
16 | endpoints := r.cache.get() | |
17 | if len(endpoints) <= 0 { | |
18 | return nil, ErrNoEndpointsAvailable | |
19 | } | |
20 | return endpoints[rand.Intn(len(endpoints))], nil | |
21 | } |
0 | package loadbalancer_test | |
1 | ||
2 | import ( | |
3 | "math" | |
4 | "runtime" | |
5 | "testing" | |
6 | ||
7 | "github.com/go-kit/kit/endpoint" | |
8 | "github.com/go-kit/kit/loadbalancer" | |
9 | "golang.org/x/net/context" | |
10 | ) | |
11 | ||
12 | func TestRandom(t *testing.T) { | |
13 | p := loadbalancer.NewStaticPublisher([]endpoint.Endpoint{}) | |
14 | defer p.Stop() | |
15 | ||
16 | lb := loadbalancer.Random(p) | |
17 | if _, err := lb.Get(); err == nil { | |
18 | t.Error("want error, got none") | |
19 | } | |
20 | ||
21 | counts := []int{0, 0, 0} | |
22 | p.Replace([]endpoint.Endpoint{ | |
23 | func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil }, | |
24 | func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil }, | |
25 | func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil }, | |
26 | }) | |
27 | runtime.Gosched() | |
28 | ||
29 | n := 10000 | |
30 | for i := 0; i < n; i++ { | |
31 | e, _ := lb.Get() | |
32 | e(context.Background(), struct{}{}) | |
33 | } | |
34 | ||
35 | want := float64(n) / float64(len(counts)) | |
36 | tolerance := want / 100.0 // 1% | |
37 | for _, have := range counts { | |
38 | if math.Abs(want-float64(have)) > tolerance { | |
39 | t.Errorf("want %.0f, have %d", want, have) | |
40 | } | |
41 | } | |
42 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "strings" | |
5 | "time" | |
6 | ||
7 | "golang.org/x/net/context" | |
8 | ||
9 | "github.com/go-kit/kit/endpoint" | |
10 | ) | |
11 | ||
12 | // Retry yields an endpoint that takes endpoints from the load balancer. | |
13 | // Invocations that return errors will be retried until they succeed, up to | |
14 | // max times, or until the timeout is elapsed, whichever comes first. | |
15 | func Retry(max int, timeout time.Duration, lb LoadBalancer) endpoint.Endpoint { | |
16 | return func(ctx context.Context, request interface{}) (interface{}, error) { | |
17 | var ( | |
18 | newctx, cancel = context.WithTimeout(ctx, timeout) | |
19 | responses = make(chan interface{}, 1) | |
20 | errs = make(chan error, 1) | |
21 | a = []string{} | |
22 | ) | |
23 | defer cancel() | |
24 | for i := 1; i <= max; i++ { | |
25 | go func() { | |
26 | e, err := lb.Get() | |
27 | if err != nil { | |
28 | errs <- err | |
29 | return | |
30 | } | |
31 | response, err := e(newctx, request) | |
32 | if err != nil { | |
33 | errs <- err | |
34 | return | |
35 | } | |
36 | responses <- response | |
37 | }() | |
38 | ||
39 | select { | |
40 | case <-newctx.Done(): | |
41 | return nil, newctx.Err() | |
42 | case response := <-responses: | |
43 | return response, nil | |
44 | case err := <-errs: | |
45 | a = append(a, err.Error()) | |
46 | continue | |
47 | } | |
48 | } | |
49 | return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; ")) | |
50 | } | |
51 | } |
0 | package loadbalancer_test | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "runtime" | |
5 | "time" | |
6 | ||
7 | "github.com/go-kit/kit/endpoint" | |
8 | "github.com/go-kit/kit/loadbalancer" | |
9 | "golang.org/x/net/context" | |
10 | ||
11 | "testing" | |
12 | ) | |
13 | ||
14 | func TestRetryMax(t *testing.T) { | |
15 | var ( | |
16 | endpoints = []endpoint.Endpoint{} | |
17 | p = loadbalancer.NewStaticPublisher(endpoints) | |
18 | lb = loadbalancer.RoundRobin(p) | |
19 | ) | |
20 | ||
21 | if _, err := loadbalancer.Retry(999, time.Second, lb)(context.Background(), struct{}{}); err == nil { | |
22 | t.Errorf("expected error, got none") | |
23 | } | |
24 | ||
25 | endpoints = []endpoint.Endpoint{ | |
26 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") }, | |
27 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") }, | |
28 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ }, | |
29 | } | |
30 | p.Replace(endpoints) | |
31 | runtime.Gosched() | |
32 | ||
33 | if _, err := loadbalancer.Retry(len(endpoints)-1, time.Second, lb)(context.Background(), struct{}{}); err == nil { | |
34 | t.Errorf("expected error, got none") | |
35 | } | |
36 | ||
37 | if _, err := loadbalancer.Retry(len(endpoints), time.Second, lb)(context.Background(), struct{}{}); err != nil { | |
38 | t.Error(err) | |
39 | } | |
40 | } | |
41 | ||
42 | func TestRetryTimeout(t *testing.T) { | |
43 | var ( | |
44 | step = make(chan struct{}) | |
45 | e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil } | |
46 | timeout = time.Millisecond | |
47 | retry = loadbalancer.Retry(999, timeout, loadbalancer.RoundRobin(loadbalancer.NewStaticPublisher([]endpoint.Endpoint{e}))) | |
48 | errs = make(chan error) | |
49 | invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err } | |
50 | ) | |
51 | ||
52 | go invoke() // invoke the endpoint | |
53 | step <- struct{}{} // tell the endpoint to return | |
54 | if err := <-errs; err != nil { // that should succeed | |
55 | t.Error(err) | |
56 | } | |
57 | ||
58 | go invoke() // invoke the endpoint | |
59 | time.Sleep(2 * timeout) // wait | |
60 | step <- struct{}{} // tell the endpoint to return | |
61 | if err := <-errs; err != context.DeadlineExceeded { // that should not succeed | |
62 | t.Errorf("wanted error, got none") | |
63 | } | |
64 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "sync/atomic" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
8 | // RoundRobin returns a load balancer that yields endpoints in sequence. | |
9 | func RoundRobin(p Publisher) LoadBalancer { | |
10 | return &roundRobin{newCache(p), 0} | |
11 | } | |
12 | ||
13 | type roundRobin struct { | |
14 | *cache | |
15 | uint64 | |
16 | } | |
17 | ||
18 | func (r *roundRobin) Get() (endpoint.Endpoint, error) { | |
19 | endpoints := r.cache.get() | |
20 | if len(endpoints) <= 0 { | |
21 | return nil, ErrNoEndpointsAvailable | |
22 | } | |
23 | var old uint64 | |
24 | for { | |
25 | old = atomic.LoadUint64(&r.uint64) | |
26 | if atomic.CompareAndSwapUint64(&r.uint64, old, old+1) { | |
27 | break | |
28 | } | |
29 | } | |
30 | return endpoints[old%uint64(len(endpoints))], nil | |
31 | } |
0 | package loadbalancer_test | |
1 | ||
2 | import ( | |
3 | "reflect" | |
4 | "runtime" | |
5 | "testing" | |
6 | ||
7 | "github.com/go-kit/kit/endpoint" | |
8 | "github.com/go-kit/kit/loadbalancer" | |
9 | "golang.org/x/net/context" | |
10 | ) | |
11 | ||
12 | func TestRoundRobin(t *testing.T) { | |
13 | p := loadbalancer.NewStaticPublisher([]endpoint.Endpoint{}) | |
14 | defer p.Stop() | |
15 | ||
16 | lb := loadbalancer.RoundRobin(p) | |
17 | if _, err := lb.Get(); err == nil { | |
18 | t.Error("want error, got none") | |
19 | } | |
20 | ||
21 | counts := []int{0, 0, 0} | |
22 | p.Replace([]endpoint.Endpoint{ | |
23 | func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil }, | |
24 | func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil }, | |
25 | func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil }, | |
26 | }) | |
27 | runtime.Gosched() | |
28 | ||
29 | for i, want := range [][]int{ | |
30 | {1, 0, 0}, | |
31 | {1, 1, 0}, | |
32 | {1, 1, 1}, | |
33 | {2, 1, 1}, | |
34 | {2, 2, 1}, | |
35 | {2, 2, 2}, | |
36 | {3, 2, 2}, | |
37 | } { | |
38 | e, _ := lb.Get() | |
39 | e(context.Background(), struct{}{}) | |
40 | if have := counts; !reflect.DeepEqual(want, have) { | |
41 | t.Errorf("%d: want %v, have %v", i+1, want, have) | |
42 | } | |
43 | } | |
44 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "sync" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
8 | // NewStaticPublisher returns a publisher that yields a static set of | |
9 | // endpoints, which can be completely replaced. | |
10 | func NewStaticPublisher(endpoints []endpoint.Endpoint) *StaticPublisher { | |
11 | return &StaticPublisher{ | |
12 | current: endpoints, | |
13 | subscribers: map[chan<- []endpoint.Endpoint]struct{}{}, | |
14 | } | |
15 | } | |
16 | ||
17 | // StaticPublisher holds a static set of endpoints. | |
18 | type StaticPublisher struct { | |
19 | sync.Mutex | |
20 | current []endpoint.Endpoint | |
21 | subscribers map[chan<- []endpoint.Endpoint]struct{} | |
22 | } | |
23 | ||
24 | // Subscribe implements Publisher. | |
25 | func (p *StaticPublisher) Subscribe(c chan<- []endpoint.Endpoint) { | |
26 | p.Lock() | |
27 | defer p.Unlock() | |
28 | p.subscribers[c] = struct{}{} | |
29 | c <- p.current | |
30 | } | |
31 | ||
32 | // Unsubscribe implements Publisher. | |
33 | func (p *StaticPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) { | |
34 | p.Lock() | |
35 | defer p.Unlock() | |
36 | delete(p.subscribers, c) | |
37 | } | |
38 | ||
39 | // Stop implements Publisher, but is a no-op. | |
40 | func (p *StaticPublisher) Stop() {} | |
41 | ||
42 | // Replace replaces the endpoints and notifies all subscribers. | |
43 | func (p *StaticPublisher) Replace(endpoints []endpoint.Endpoint) { | |
44 | p.Lock() | |
45 | defer p.Unlock() | |
46 | p.current = endpoints | |
47 | for c := range p.subscribers { | |
48 | c <- p.current | |
49 | } | |
50 | } |
0 | package loadbalancer_test | |
1 | ||
2 | import ( | |
3 | "testing" | |
4 | ||
5 | "golang.org/x/net/context" | |
6 | ||
7 | "github.com/go-kit/kit/endpoint" | |
8 | "github.com/go-kit/kit/loadbalancer" | |
9 | ) | |
10 | ||
11 | func TestStaticPublisher(t *testing.T) { | |
12 | endpoints := []endpoint.Endpoint{ | |
13 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, | |
14 | } | |
15 | p := loadbalancer.NewStaticPublisher(endpoints) | |
16 | defer p.Stop() | |
17 | ||
18 | c := make(chan []endpoint.Endpoint, 1) | |
19 | p.Subscribe(c) | |
20 | if want, have := len(endpoints), len(<-c); want != have { | |
21 | t.Errorf("want %d, have %d", want, have) | |
22 | } | |
23 | ||
24 | endpoints = []endpoint.Endpoint{} | |
25 | p.Replace(endpoints) | |
26 | if want, have := len(endpoints), len(<-c); want != have { | |
27 | t.Errorf("want %d, have %d", want, have) | |
28 | } | |
29 | } |