Codebase list golang-github-go-kit-kit / 78fd391
Split out loadbalancer components to packages Peter Bourgon 8 years ago
31 changed file(s) with 568 addition(s) and 672 deletion(s). Raw diff Collapse all Expand all
+0
-47
loadbalancer/cache.go less more
0 package loadbalancer
1
2 import "github.com/go-kit/kit/endpoint"
3
4 type cache struct {
5 req chan []endpoint.Endpoint
6 cnt chan int
7 quit chan struct{}
8 }
9
10 func newCache(p Publisher) *cache {
11 c := &cache{
12 req: make(chan []endpoint.Endpoint),
13 cnt: make(chan int),
14 quit: make(chan struct{}),
15 }
16 go c.loop(p)
17 return c
18 }
19
20 func (c *cache) loop(p Publisher) {
21 e := make(chan []endpoint.Endpoint, 1)
22 p.Subscribe(e)
23 defer p.Unsubscribe(e)
24 endpoints := <-e
25 for {
26 select {
27 case endpoints = <-e:
28 case c.cnt <- len(endpoints):
29 case c.req <- endpoints:
30 case <-c.quit:
31 return
32 }
33 }
34 }
35
36 func (c *cache) count() int {
37 return <-c.cnt
38 }
39
40 func (c *cache) get() []endpoint.Endpoint {
41 return <-c.req
42 }
43
44 func (c *cache) stop() {
45 close(c.quit)
46 }
+0
-33
loadbalancer/cache_internal_test.go less more
0 package loadbalancer
1
2 import (
3 "runtime"
4 "testing"
5
6 "golang.org/x/net/context"
7
8 "github.com/go-kit/kit/endpoint"
9 )
10
11 func TestCache(t *testing.T) {
12 e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
13 endpoints := []endpoint.Endpoint{e}
14
15 p := NewStaticPublisher(endpoints)
16 defer p.Stop()
17
18 c := newCache(p)
19 defer c.stop()
20
21 for _, n := range []int{2, 10, 0} {
22 endpoints = make([]endpoint.Endpoint, n)
23 for i := 0; i < n; i++ {
24 endpoints[i] = e
25 }
26 p.Replace(endpoints)
27 runtime.Gosched()
28 if want, have := len(endpoints), len(c.get()); want != have {
29 t.Errorf("want %d, have %d", want, have)
30 }
31 }
32 }
+0
-105
loadbalancer/dns_srv_publisher.go less more
0 package loadbalancer
1
2 import (
3 "crypto/md5"
4 "fmt"
5 "net"
6 "sort"
7 "time"
8
9 "github.com/go-kit/kit/endpoint"
10 )
11
12 type dnssrvPublisher struct {
13 subscribe chan chan<- []endpoint.Endpoint
14 unsubscribe chan chan<- []endpoint.Endpoint
15 quit chan struct{}
16 }
17
18 // NewDNSSRVPublisher returns a publisher that resolves the SRV name every ttl, and
19 func NewDNSSRVPublisher(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) Publisher {
20 p := &dnssrvPublisher{
21 subscribe: make(chan chan<- []endpoint.Endpoint),
22 unsubscribe: make(chan chan<- []endpoint.Endpoint),
23 quit: make(chan struct{}),
24 }
25 go p.loop(name, ttl, makeEndpoint)
26 return p
27 }
28
29 func (p *dnssrvPublisher) Subscribe(c chan<- []endpoint.Endpoint) {
30 p.subscribe <- c
31 }
32
33 func (p *dnssrvPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) {
34 p.unsubscribe <- c
35 }
36
37 func (p *dnssrvPublisher) Stop() {
38 close(p.quit)
39 }
40
41 var newTicker = time.NewTicker
42
43 func (p *dnssrvPublisher) loop(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) {
44 var (
45 subscriptions = map[chan<- []endpoint.Endpoint]struct{}{}
46 addrs, md5, _ = resolve(name)
47 endpoints = convert(addrs, makeEndpoint)
48 ticker = newTicker(ttl)
49 )
50 defer ticker.Stop()
51 for {
52 select {
53 case <-ticker.C:
54 addrs, newmd5, err := resolve(name)
55 if err == nil && newmd5 != md5 {
56 endpoints = convert(addrs, makeEndpoint)
57 for c := range subscriptions {
58 c <- endpoints
59 }
60 md5 = newmd5
61 }
62
63 case c := <-p.subscribe:
64 subscriptions[c] = struct{}{}
65 c <- endpoints
66
67 case c := <-p.unsubscribe:
68 delete(subscriptions, c)
69
70 case <-p.quit:
71 return
72 }
73 }
74 }
75
76 // Allow mocking in tests.
77 var resolve = func(name string) (addrs []*net.SRV, md5sum string, err error) {
78 _, addrs, err = net.LookupSRV("", "", name)
79 if err != nil {
80 return addrs, "", err
81 }
82 hostports := make([]string, len(addrs))
83 for i, addr := range addrs {
84 hostports[i] = fmt.Sprintf("%s:%d", addr.Target, addr.Port)
85 }
86 sort.Sort(sort.StringSlice(hostports))
87 h := md5.New()
88 for _, hostport := range hostports {
89 fmt.Fprintf(h, hostport)
90 }
91 return addrs, fmt.Sprintf("%x", h.Sum(nil)), nil
92 }
93
94 func convert(addrs []*net.SRV, makeEndpoint func(hostport string) endpoint.Endpoint) []endpoint.Endpoint {
95 endpoints := make([]endpoint.Endpoint, len(addrs))
96 for i, addr := range addrs {
97 endpoints[i] = makeEndpoint(addr2hostport(addr))
98 }
99 return endpoints
100 }
101
102 func addr2hostport(addr *net.SRV) string {
103 return net.JoinHostPort(addr.Target, fmt.Sprintf("%d", addr.Port))
104 }
+0
-77
loadbalancer/dns_srv_publisher_internal_test.go less more
0 package loadbalancer
1
2 import (
3 "fmt"
4 "net"
5 "testing"
6 "time"
7
8 "golang.org/x/net/context"
9
10 "github.com/go-kit/kit/endpoint"
11 )
12
13 func TestDNSSRVPublisher(t *testing.T) {
14 // Reset the vars when we're done
15 oldResolve := resolve
16 defer func() { resolve = oldResolve }()
17 oldNewTicker := newTicker
18 defer func() { newTicker = oldNewTicker }()
19
20 // Set up a fixture and swap the vars
21 a := []*net.SRV{
22 {Target: "foo", Port: 123},
23 {Target: "bar", Port: 456},
24 {Target: "baz", Port: 789},
25 }
26 ticker := make(chan time.Time)
27 resolve = func(string) ([]*net.SRV, string, error) { return a, fmt.Sprint(len(a)), nil }
28 newTicker = func(time.Duration) *time.Ticker { return &time.Ticker{C: ticker} }
29
30 // Construct endpoint
31 m := map[string]int{}
32 e := func(hostport string) endpoint.Endpoint {
33 return func(context.Context, interface{}) (interface{}, error) {
34 m[hostport]++
35 return struct{}{}, nil
36 }
37 }
38
39 // Build the publisher
40 var (
41 name = "irrelevant"
42 ttl = time.Second
43 makeEndpoint = func(hostport string) endpoint.Endpoint { return e(hostport) }
44 )
45 p := NewDNSSRVPublisher(name, ttl, makeEndpoint)
46 defer p.Stop()
47
48 // Subscribe
49 c := make(chan []endpoint.Endpoint, 1)
50 p.Subscribe(c)
51 defer p.Unsubscribe(c)
52
53 // Invoke all of the endpoints
54 for _, e := range <-c {
55 e(context.Background(), struct{}{})
56 }
57
58 // Make sure we invoked what we expected to
59 for _, addr := range a {
60 hostport := addr2hostport(addr)
61 if want, have := 1, m[hostport]; want != have {
62 t.Errorf("%q: want %d, have %d", name, want, have)
63 }
64 delete(m, hostport)
65 }
66 if want, have := 0, len(m); want != have {
67 t.Errorf("want %d, have %d", want, have)
68 }
69
70 // Reset the fixture, trigger the timer, count the endpoints
71 a = []*net.SRV{}
72 ticker <- time.Now()
73 if want, have := len(a), len(<-c); want != have {
74 t.Errorf("want %d, have %d", want, have)
75 }
76 }
+0
-41
loadbalancer/endpoint_cache.go less more
0 package loadbalancer
1
2 import "github.com/go-kit/kit/endpoint"
3
4 type endpointCache struct {
5 requests chan []endpoint.Endpoint
6 quit chan struct{}
7 }
8
9 func newEndpointCache(p Publisher) *endpointCache {
10 c := &endpointCache{
11 requests: make(chan []endpoint.Endpoint),
12 quit: make(chan struct{}),
13 }
14 go c.loop(p)
15 return c
16 }
17
18 func (c *endpointCache) loop(p Publisher) {
19 updates := make(chan []endpoint.Endpoint, 1)
20 p.Subscribe(updates)
21 defer p.Unsubscribe(updates)
22 endpoints := <-updates
23
24 for {
25 select {
26 case endpoints = <-updates:
27 case c.requests <- endpoints:
28 case <-c.quit:
29 return
30 }
31 }
32 }
33
34 func (c *endpointCache) get() []endpoint.Endpoint {
35 return <-c.requests
36 }
37
38 func (c *endpointCache) stop() {
39 close(c.quit)
40 }
+0
-27
loadbalancer/endpoint_cache_internal_test.go less more
0 package loadbalancer
1
2 import (
3 "testing"
4
5 "golang.org/x/net/context"
6
7 "github.com/go-kit/kit/endpoint"
8 )
9
10 func TestEndpointCache(t *testing.T) {
11 endpoints := []endpoint.Endpoint{
12 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
13 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
14 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
15 }
16
17 p := NewStaticPublisher(endpoints)
18 defer p.Stop()
19
20 c := newEndpointCache(p)
21 defer c.stop()
22
23 if want, have := len(endpoints), len(c.get()); want != have {
24 t.Errorf("want %d, have %d", want, have)
25 }
26 }
1111 Get() (endpoint.Endpoint, error)
1212 }
1313
14 // ErrNoEndpointsAvailable is given by a load balancer when no endpoints are
15 // available to be returned.
14 // ErrNoEndpointsAvailable is given by a load balancer or strategy when no
15 // endpoints are available to be returned.
1616 var ErrNoEndpointsAvailable = errors.New("no endpoints available")
+0
-46
loadbalancer/mock_publisher_test.go less more
0 package loadbalancer_test
1
2 import (
3 "runtime"
4 "sync"
5
6 "github.com/go-kit/kit/endpoint"
7 )
8
9 type mockPublisher struct {
10 sync.Mutex
11 e []endpoint.Endpoint
12 s map[chan<- []endpoint.Endpoint]struct{}
13 }
14
15 func newMockPublisher(endpoints []endpoint.Endpoint) *mockPublisher {
16 return &mockPublisher{
17 e: endpoints,
18 s: map[chan<- []endpoint.Endpoint]struct{}{},
19 }
20 }
21
22 func (p *mockPublisher) replace(endpoints []endpoint.Endpoint) {
23 p.Lock()
24 defer p.Unlock()
25 p.e = endpoints
26 for s := range p.s {
27 s <- p.e
28 }
29 runtime.Gosched()
30 }
31
32 func (p *mockPublisher) Subscribe(c chan<- []endpoint.Endpoint) {
33 p.Lock()
34 defer p.Unlock()
35 p.s[c] = struct{}{}
36 c <- p.e
37 }
38
39 func (p *mockPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) {
40 p.Lock()
41 defer p.Unlock()
42 delete(p.s, c)
43 }
44
45 func (p *mockPublisher) Stop() {}
0 package dns
1
2 import (
3 "crypto/md5"
4 "fmt"
5 "net"
6 "sort"
7 "time"
8
9 "github.com/go-kit/kit/endpoint"
10 )
11
12 // SRVPublisher implements Publisher.
13 type SRVPublisher struct {
14 subscribe chan chan<- []endpoint.Endpoint
15 unsubscribe chan chan<- []endpoint.Endpoint
16 quit chan struct{}
17 }
18
19 // NewSRVPublisher returns a publisher that resolves the SRV name every ttl,
20 // and yields endpoints constructed via the makeEndpoint factory.
21 func NewSRVPublisher(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) *SRVPublisher {
22 p := &SRVPublisher{
23 subscribe: make(chan chan<- []endpoint.Endpoint),
24 unsubscribe: make(chan chan<- []endpoint.Endpoint),
25 quit: make(chan struct{}),
26 }
27 go p.loop(name, ttl, makeEndpoint)
28 return p
29 }
30
31 // Subscribe implements Publisher.
32 func (p *SRVPublisher) Subscribe(c chan<- []endpoint.Endpoint) {
33 p.subscribe <- c
34 }
35
36 // Unsubscribe implements Publisher.
37 func (p *SRVPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) {
38 p.unsubscribe <- c
39 }
40
41 // Stop implements Publisher.
42 func (p *SRVPublisher) Stop() {
43 close(p.quit)
44 }
45
46 var newTicker = time.NewTicker
47
48 func (p *SRVPublisher) loop(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) {
49 var (
50 subscriptions = map[chan<- []endpoint.Endpoint]struct{}{}
51 addrs, md5, _ = resolve(name)
52 endpoints = convert(addrs, makeEndpoint)
53 ticker = newTicker(ttl)
54 )
55 defer ticker.Stop()
56 for {
57 select {
58 case <-ticker.C:
59 addrs, newmd5, err := resolve(name)
60 if err == nil && newmd5 != md5 {
61 endpoints = convert(addrs, makeEndpoint)
62 for c := range subscriptions {
63 c <- endpoints
64 }
65 md5 = newmd5
66 }
67
68 case c := <-p.subscribe:
69 subscriptions[c] = struct{}{}
70 c <- endpoints
71
72 case c := <-p.unsubscribe:
73 delete(subscriptions, c)
74
75 case <-p.quit:
76 return
77 }
78 }
79 }
80
81 // Allow mocking in tests.
82 var resolve = func(name string) (addrs []*net.SRV, md5sum string, err error) {
83 _, addrs, err = net.LookupSRV("", "", name)
84 if err != nil {
85 return addrs, "", err
86 }
87 hostports := make([]string, len(addrs))
88 for i, addr := range addrs {
89 hostports[i] = fmt.Sprintf("%s:%d", addr.Target, addr.Port)
90 }
91 sort.Sort(sort.StringSlice(hostports))
92 h := md5.New()
93 for _, hostport := range hostports {
94 fmt.Fprintf(h, hostport)
95 }
96 return addrs, fmt.Sprintf("%x", h.Sum(nil)), nil
97 }
98
99 func convert(addrs []*net.SRV, makeEndpoint func(hostport string) endpoint.Endpoint) []endpoint.Endpoint {
100 endpoints := make([]endpoint.Endpoint, len(addrs))
101 for i, addr := range addrs {
102 endpoints[i] = makeEndpoint(addr2hostport(addr))
103 }
104 return endpoints
105 }
106
107 func addr2hostport(addr *net.SRV) string {
108 return net.JoinHostPort(addr.Target, fmt.Sprint(addr.Port))
109 }
0 package dns
1
2 import (
3 "fmt"
4 "net"
5 "testing"
6 "time"
7
8 "golang.org/x/net/context"
9
10 "github.com/go-kit/kit/endpoint"
11 )
12
13 func TestDNSSRVPublisher(t *testing.T) {
14 // Reset the vars when we're done
15 oldResolve := resolve
16 defer func() { resolve = oldResolve }()
17 oldNewTicker := newTicker
18 defer func() { newTicker = oldNewTicker }()
19
20 // Set up a fixture and swap the vars
21 a := []*net.SRV{
22 {Target: "foo", Port: 123},
23 {Target: "bar", Port: 456},
24 {Target: "baz", Port: 789},
25 }
26 ticker := make(chan time.Time)
27 resolve = func(string) ([]*net.SRV, string, error) { return a, fmt.Sprint(len(a)), nil }
28 newTicker = func(time.Duration) *time.Ticker { return &time.Ticker{C: ticker} }
29
30 // Construct endpoint
31 m := map[string]int{}
32 e := func(hostport string) endpoint.Endpoint {
33 return func(context.Context, interface{}) (interface{}, error) {
34 m[hostport]++
35 return struct{}{}, nil
36 }
37 }
38
39 // Build the publisher
40 var (
41 name = "irrelevant"
42 ttl = time.Second
43 makeEndpoint = func(hostport string) endpoint.Endpoint { return e(hostport) }
44 )
45 p := NewSRVPublisher(name, ttl, makeEndpoint)
46 defer p.Stop()
47
48 // Subscribe
49 c := make(chan []endpoint.Endpoint, 1)
50 p.Subscribe(c)
51 defer p.Unsubscribe(c)
52
53 // Invoke all of the endpoints
54 for _, e := range <-c {
55 e(context.Background(), struct{}{})
56 }
57
58 // Make sure we invoked what we expected to
59 for _, addr := range a {
60 hostport := addr2hostport(addr)
61 if want, have := 1, m[hostport]; want != have {
62 t.Errorf("%q: want %d, have %d", name, want, have)
63 }
64 delete(m, hostport)
65 }
66 if want, have := 0, len(m); want != have {
67 t.Errorf("want %d, have %d", want, have)
68 }
69
70 // Reset the fixture, trigger the timer, count the endpoints
71 a = []*net.SRV{}
72 ticker <- time.Now()
73 if want, have := len(a), len(<-c); want != have {
74 t.Errorf("want %d, have %d", want, have)
75 }
76 }
0 package publisher
1
2 import "github.com/go-kit/kit/endpoint"
3
4 // Publisher produces endpoints.
5 type Publisher interface {
6 Subscribe(chan<- []endpoint.Endpoint)
7 Unsubscribe(chan<- []endpoint.Endpoint)
8 Stop()
9 }
0 package static
1
2 import (
3 "sync"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // Publisher holds a static set of endpoints.
9 type Publisher struct {
10 mu sync.Mutex
11 current []endpoint.Endpoint
12 subscribers map[chan<- []endpoint.Endpoint]struct{}
13 }
14
15 // NewPublisher returns a publisher that yields a static set of endpoints,
16 // which can be completely replaced.
17 func NewPublisher(endpoints []endpoint.Endpoint) *Publisher {
18 return &Publisher{
19 current: endpoints,
20 subscribers: map[chan<- []endpoint.Endpoint]struct{}{},
21 }
22 }
23
24 // Subscribe implements Publisher.
25 func (p *Publisher) Subscribe(c chan<- []endpoint.Endpoint) {
26 p.mu.Lock()
27 defer p.mu.Unlock()
28 p.subscribers[c] = struct{}{}
29 c <- p.current
30 }
31
32 // Unsubscribe implements Publisher.
33 func (p *Publisher) Unsubscribe(c chan<- []endpoint.Endpoint) {
34 p.mu.Lock()
35 defer p.mu.Unlock()
36 delete(p.subscribers, c)
37 }
38
39 // Stop implements Publisher, but is a no-op.
40 func (p *Publisher) Stop() {}
41
42 // Replace replaces the endpoints and notifies all subscribers.
43 func (p *Publisher) Replace(endpoints []endpoint.Endpoint) {
44 p.mu.Lock()
45 defer p.mu.Unlock()
46 p.current = endpoints
47 for c := range p.subscribers {
48 c <- p.current
49 }
50 }
0 package static_test
1
2 import (
3 "testing"
4
5 "golang.org/x/net/context"
6
7 "github.com/go-kit/kit/endpoint"
8 "github.com/go-kit/kit/loadbalancer/publisher/static"
9 )
10
11 func TestStaticPublisher(t *testing.T) {
12 endpoints := []endpoint.Endpoint{
13 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
14 }
15 p := static.NewPublisher(endpoints)
16 defer p.Stop()
17
18 c := make(chan []endpoint.Endpoint, 1)
19 p.Subscribe(c)
20 if want, have := len(endpoints), len(<-c); want != have {
21 t.Errorf("want %d, have %d", want, have)
22 }
23
24 endpoints = []endpoint.Endpoint{}
25 p.Replace(endpoints)
26 if want, have := len(endpoints), len(<-c); want != have {
27 t.Errorf("want %d, have %d", want, have)
28 }
29 }
+0
-10
loadbalancer/publisher.go less more
0 package loadbalancer
1
2 import "github.com/go-kit/kit/endpoint"
3
4 // Publisher produces endpoints.
5 type Publisher interface {
6 Subscribe(chan<- []endpoint.Endpoint)
7 Unsubscribe(chan<- []endpoint.Endpoint)
8 Stop()
9 }
+0
-24
loadbalancer/random.go less more
0 package loadbalancer
1
2 import (
3 "math/rand"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // Random returns a load balancer that yields random endpoints.
9 func Random(p Publisher) LoadBalancer {
10 return random{newCache(p)}
11 }
12
13 type random struct{ *cache }
14
15 func (r random) Count() int { return r.cache.count() }
16
17 func (r random) Get() (endpoint.Endpoint, error) {
18 endpoints := r.cache.get()
19 if len(endpoints) <= 0 {
20 return nil, ErrNoEndpointsAvailable
21 }
22 return endpoints[rand.Intn(len(endpoints))], nil
23 }
+0
-42
loadbalancer/random_test.go less more
0 package loadbalancer_test
1
2 import (
3 "math"
4 "testing"
5
6 "github.com/go-kit/kit/endpoint"
7 "github.com/go-kit/kit/loadbalancer"
8 "golang.org/x/net/context"
9 )
10
11 func TestRandom(t *testing.T) {
12 p := loadbalancer.NewStaticPublisher([]endpoint.Endpoint{})
13 defer p.Stop()
14
15 lb := loadbalancer.Random(p)
16 if _, err := lb.Get(); err == nil {
17 t.Error("want error, got none")
18 }
19
20 counts := []int{0, 0, 0}
21 p.Replace([]endpoint.Endpoint{
22 func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil },
23 func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil },
24 func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil },
25 })
26 assertLoadBalancerNotEmpty(t, lb)
27
28 n := 10000
29 for i := 0; i < n; i++ {
30 e, _ := lb.Get()
31 e(context.Background(), struct{}{})
32 }
33
34 want := float64(n) / float64(len(counts))
35 tolerance := (want / 100.0) * 5 // 5%
36 for _, have := range counts {
37 if math.Abs(want-float64(have)) > tolerance {
38 t.Errorf("want %.0f, have %d", want, have)
39 }
40 }
41 }
11
22 import (
33 "errors"
4 "testing"
45 "time"
6
7 "golang.org/x/net/context"
58
69 "github.com/go-kit/kit/endpoint"
710 "github.com/go-kit/kit/loadbalancer"
8 "golang.org/x/net/context"
9
10 "testing"
11 "github.com/go-kit/kit/loadbalancer/publisher/static"
12 "github.com/go-kit/kit/loadbalancer/strategy"
1113 )
1214
1315 func TestRetryMax(t *testing.T) {
1416 var (
1517 endpoints = []endpoint.Endpoint{}
16 p = loadbalancer.NewStaticPublisher(endpoints)
17 lb = loadbalancer.RoundRobin(p)
18 p = static.NewPublisher(endpoints)
19 lb = strategy.RoundRobin(p)
1820 )
1921
2022 if _, err := loadbalancer.Retry(999, time.Second, lb)(context.Background(), struct{}{}); err == nil {
2729 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
2830 }
2931 p.Replace(endpoints)
30 assertLoadBalancerNotEmpty(t, lb)
32 time.Sleep(10 * time.Millisecond) //assertLoadBalancerNotEmpty(t, lb) // TODO
3133
3234 if _, err := loadbalancer.Retry(len(endpoints)-1, time.Second, lb)(context.Background(), struct{}{}); err == nil {
3335 t.Errorf("expected error, got none")
4345 step = make(chan struct{})
4446 e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil }
4547 timeout = time.Millisecond
46 retry = loadbalancer.Retry(999, timeout, loadbalancer.RoundRobin(loadbalancer.NewStaticPublisher([]endpoint.Endpoint{e})))
48 retry = loadbalancer.Retry(999, timeout, strategy.RoundRobin(static.NewPublisher([]endpoint.Endpoint{e})))
4749 errs = make(chan error)
4850 invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err }
4951 )
+0
-34
loadbalancer/round_robin.go less more
0 package loadbalancer
1
2 import (
3 "sync/atomic"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // RoundRobin returns a load balancer that yields endpoints in sequence.
9 func RoundRobin(p Publisher) LoadBalancer {
10 return &roundRobin{newCache(p), 0}
11 }
12
13 type roundRobin struct {
14 *cache
15 uint64
16 }
17
18 func (r *roundRobin) Count() int { return r.cache.count() }
19
20 func (r *roundRobin) Get() (endpoint.Endpoint, error) {
21 endpoints := r.cache.get()
22 if len(endpoints) <= 0 {
23 return nil, ErrNoEndpointsAvailable
24 }
25 var old uint64
26 for {
27 old = atomic.LoadUint64(&r.uint64)
28 if atomic.CompareAndSwapUint64(&r.uint64, old, old+1) {
29 break
30 }
31 }
32 return endpoints[old%uint64(len(endpoints))], nil
33 }
+0
-44
loadbalancer/round_robin_test.go less more
0 package loadbalancer_test
1
2 import (
3 "reflect"
4 "testing"
5
6 "github.com/go-kit/kit/endpoint"
7 "github.com/go-kit/kit/loadbalancer"
8 "golang.org/x/net/context"
9 )
10
11 func TestRoundRobin(t *testing.T) {
12 p := loadbalancer.NewStaticPublisher([]endpoint.Endpoint{})
13 defer p.Stop()
14
15 lb := loadbalancer.RoundRobin(p)
16 if _, err := lb.Get(); err == nil {
17 t.Error("want error, got none")
18 }
19
20 counts := []int{0, 0, 0}
21 p.Replace([]endpoint.Endpoint{
22 func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil },
23 func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil },
24 func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil },
25 })
26 assertLoadBalancerNotEmpty(t, lb)
27
28 for i, want := range [][]int{
29 {1, 0, 0},
30 {1, 1, 0},
31 {1, 1, 1},
32 {2, 1, 1},
33 {2, 2, 1},
34 {2, 2, 2},
35 {3, 2, 2},
36 } {
37 e, _ := lb.Get()
38 e(context.Background(), struct{}{})
39 if have := counts; !reflect.DeepEqual(want, have) {
40 t.Errorf("%d: want %v, have %v", i+1, want, have)
41 }
42 }
43 }
+0
-51
loadbalancer/static_publisher.go less more
0 package loadbalancer
1
2 import (
3 "sync"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // NewStaticPublisher returns a publisher that yields a static set of
9 // endpoints, which can be completely replaced.
10 func NewStaticPublisher(endpoints []endpoint.Endpoint) *StaticPublisher {
11 return &StaticPublisher{
12 current: endpoints,
13 subscribers: map[chan<- []endpoint.Endpoint]struct{}{},
14 }
15 }
16
17 // StaticPublisher holds a static set of endpoints.
18 type StaticPublisher struct {
19 mu sync.Mutex
20 current []endpoint.Endpoint
21 subscribers map[chan<- []endpoint.Endpoint]struct{}
22 }
23
24 // Subscribe implements Publisher.
25 func (p *StaticPublisher) Subscribe(c chan<- []endpoint.Endpoint) {
26 p.mu.Lock()
27 defer p.mu.Unlock()
28 p.subscribers[c] = struct{}{}
29 c <- p.current
30 }
31
32 // Unsubscribe implements Publisher.
33 func (p *StaticPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) {
34 p.mu.Lock()
35 defer p.mu.Unlock()
36 delete(p.subscribers, c)
37 }
38
39 // Stop implements Publisher, but is a no-op.
40 func (p *StaticPublisher) Stop() {}
41
42 // Replace replaces the endpoints and notifies all subscribers.
43 func (p *StaticPublisher) Replace(endpoints []endpoint.Endpoint) {
44 p.mu.Lock()
45 defer p.mu.Unlock()
46 p.current = endpoints
47 for c := range p.subscribers {
48 c <- p.current
49 }
50 }
+0
-30
loadbalancer/static_publisher_test.go less more
0 package loadbalancer_test
1
2 import (
3 "testing"
4
5 "golang.org/x/net/context"
6
7 "github.com/go-kit/kit/endpoint"
8 "github.com/go-kit/kit/loadbalancer"
9 )
10
11 func TestStaticPublisher(t *testing.T) {
12 endpoints := []endpoint.Endpoint{
13 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
14 }
15 p := loadbalancer.NewStaticPublisher(endpoints)
16 defer p.Stop()
17
18 c := make(chan []endpoint.Endpoint, 1)
19 p.Subscribe(c)
20 if want, have := len(endpoints), len(<-c); want != have {
21 t.Errorf("want %d, have %d", want, have)
22 }
23
24 endpoints = []endpoint.Endpoint{}
25 p.Replace(endpoints)
26 if want, have := len(endpoints), len(<-c); want != have {
27 t.Errorf("want %d, have %d", want, have)
28 }
29 }
+0
-17
loadbalancer/strategies.go less more
0 package loadbalancer
1
2 import (
3 "errors"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // Strategy yields endpoints to consumers according to some algorithm.
9 type Strategy interface {
10 Next() (endpoint.Endpoint, error)
11 Stop()
12 }
13
14 // ErrNoEndpoints is returned by a strategy when there are no endpoints
15 // available.
16 var ErrNoEndpoints = errors.New("no endpoints available")
0 package strategy
1
2 import (
3 "github.com/go-kit/kit/endpoint"
4 "github.com/go-kit/kit/loadbalancer/publisher"
5 )
6
7 type cache struct {
8 req chan []endpoint.Endpoint
9 cnt chan int
10 quit chan struct{}
11 }
12
13 func newCache(p publisher.Publisher) *cache {
14 c := &cache{
15 req: make(chan []endpoint.Endpoint),
16 cnt: make(chan int),
17 quit: make(chan struct{}),
18 }
19 go c.loop(p)
20 return c
21 }
22
23 func (c *cache) loop(p publisher.Publisher) {
24 e := make(chan []endpoint.Endpoint, 1)
25 p.Subscribe(e)
26 defer p.Unsubscribe(e)
27 endpoints := <-e
28 for {
29 select {
30 case endpoints = <-e:
31 case c.cnt <- len(endpoints):
32 case c.req <- endpoints:
33 case <-c.quit:
34 return
35 }
36 }
37 }
38
39 func (c *cache) count() int {
40 return <-c.cnt
41 }
42
43 func (c *cache) get() []endpoint.Endpoint {
44 return <-c.req
45 }
46
47 func (c *cache) stop() {
48 close(c.quit)
49 }
0 package strategy
1
2 import (
3 "runtime"
4 "testing"
5
6 "golang.org/x/net/context"
7
8 "github.com/go-kit/kit/endpoint"
9 "github.com/go-kit/kit/loadbalancer/publisher/static"
10 )
11
12 func TestCache(t *testing.T) {
13 e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
14 endpoints := []endpoint.Endpoint{e}
15
16 p := static.NewPublisher(endpoints)
17 defer p.Stop()
18
19 c := newCache(p)
20 defer c.stop()
21
22 for _, n := range []int{2, 10, 0} {
23 endpoints = make([]endpoint.Endpoint, n)
24 for i := 0; i < n; i++ {
25 endpoints[i] = e
26 }
27 p.Replace(endpoints)
28 runtime.Gosched()
29 if want, have := len(endpoints), len(c.get()); want != have {
30 t.Errorf("want %d, have %d", want, have)
31 }
32 }
33 }
0 package strategy
1
2 import (
3 "math/rand"
4
5 "github.com/go-kit/kit/endpoint"
6 "github.com/go-kit/kit/loadbalancer"
7 "github.com/go-kit/kit/loadbalancer/publisher"
8 )
9
10 // Random returns a load balancer that yields random endpoints.
11 func Random(p publisher.Publisher) loadbalancer.LoadBalancer {
12 return random{newCache(p)}
13 }
14
15 type random struct{ *cache }
16
17 func (r random) Count() int { return r.cache.count() }
18
19 func (r random) Get() (endpoint.Endpoint, error) {
20 endpoints := r.cache.get()
21 if len(endpoints) <= 0 {
22 return nil, loadbalancer.ErrNoEndpointsAvailable
23 }
24 return endpoints[rand.Intn(len(endpoints))], nil
25 }
0 package strategy_test
1
2 import (
3 "math"
4 "testing"
5
6 "github.com/go-kit/kit/endpoint"
7 "github.com/go-kit/kit/loadbalancer/publisher/static"
8 "github.com/go-kit/kit/loadbalancer/strategy"
9 "golang.org/x/net/context"
10 )
11
12 func TestRandom(t *testing.T) {
13 p := static.NewPublisher([]endpoint.Endpoint{})
14 defer p.Stop()
15
16 lb := strategy.Random(p)
17 if _, err := lb.Get(); err == nil {
18 t.Error("want error, got none")
19 }
20
21 counts := []int{0, 0, 0}
22 p.Replace([]endpoint.Endpoint{
23 func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil },
24 func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil },
25 func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil },
26 })
27 assertLoadBalancerNotEmpty(t, lb)
28
29 n := 10000
30 for i := 0; i < n; i++ {
31 e, _ := lb.Get()
32 e(context.Background(), struct{}{})
33 }
34
35 want := float64(n) / float64(len(counts))
36 tolerance := (want / 100.0) * 5 // 5%
37 for _, have := range counts {
38 if math.Abs(want-float64(have)) > tolerance {
39 t.Errorf("want %.0f, have %d", want, have)
40 }
41 }
42 }
0 package strategy
1
2 import (
3 "sync/atomic"
4
5 "github.com/go-kit/kit/endpoint"
6 "github.com/go-kit/kit/loadbalancer"
7 "github.com/go-kit/kit/loadbalancer/publisher"
8 )
9
10 // RoundRobin returns a load balancer that yields endpoints in sequence.
11 func RoundRobin(p publisher.Publisher) loadbalancer.LoadBalancer {
12 return &roundRobin{newCache(p), 0}
13 }
14
15 type roundRobin struct {
16 *cache
17 uint64
18 }
19
20 func (r *roundRobin) Count() int { return r.cache.count() }
21
22 func (r *roundRobin) Get() (endpoint.Endpoint, error) {
23 endpoints := r.cache.get()
24 if len(endpoints) <= 0 {
25 return nil, loadbalancer.ErrNoEndpointsAvailable
26 }
27 var old uint64
28 for {
29 old = atomic.LoadUint64(&r.uint64)
30 if atomic.CompareAndSwapUint64(&r.uint64, old, old+1) {
31 break
32 }
33 }
34 return endpoints[old%uint64(len(endpoints))], nil
35 }
0 package strategy_test
1
2 import (
3 "reflect"
4 "testing"
5
6 "golang.org/x/net/context"
7
8 "github.com/go-kit/kit/endpoint"
9 "github.com/go-kit/kit/loadbalancer/publisher/static"
10 "github.com/go-kit/kit/loadbalancer/strategy"
11 )
12
13 func TestRoundRobin(t *testing.T) {
14 p := static.NewPublisher([]endpoint.Endpoint{})
15 defer p.Stop()
16
17 lb := strategy.RoundRobin(p)
18 if _, err := lb.Get(); err == nil {
19 t.Error("want error, got none")
20 }
21
22 counts := []int{0, 0, 0}
23 p.Replace([]endpoint.Endpoint{
24 func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil },
25 func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil },
26 func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil },
27 })
28 assertLoadBalancerNotEmpty(t, lb)
29
30 for i, want := range [][]int{
31 {1, 0, 0},
32 {1, 1, 0},
33 {1, 1, 1},
34 {2, 1, 1},
35 {2, 2, 1},
36 {2, 2, 2},
37 {3, 2, 2},
38 } {
39 e, _ := lb.Get()
40 e(context.Background(), struct{}{})
41 if have := counts; !reflect.DeepEqual(want, have) {
42 t.Errorf("%d: want %v, have %v", i+1, want, have)
43 }
44 }
45 }
0 package strategy
1
2 import "github.com/go-kit/kit/endpoint"
3
4 // Strategy yields endpoints to consumers according to some algorithm.
5 type Strategy interface {
6 Next() (endpoint.Endpoint, error)
7 Stop()
8 }
0 package strategy_test
1
2 import (
3 "fmt"
4 "testing"
5 "time"
6
7 "github.com/go-kit/kit/loadbalancer"
8 )
9
10 func assertLoadBalancerNotEmpty(t *testing.T, lb loadbalancer.LoadBalancer) {
11 if err := within(10*time.Millisecond, func() bool {
12 return lb.Count() > 0
13 }); err != nil {
14 t.Fatal("Publisher never updated endpoints")
15 }
16 }
17
18 func within(d time.Duration, f func() bool) error {
19 var (
20 deadline = time.After(d)
21 ticker = time.NewTicker(d / 10)
22 )
23 defer ticker.Stop()
24 for {
25 select {
26 case <-ticker.C:
27 if f() {
28 return nil
29 }
30 case <-deadline:
31 return fmt.Errorf("deadline exceeded")
32 }
33 }
34 }
+0
-35
loadbalancer/util_test.go less more
0 package loadbalancer_test
1
2 import (
3 "fmt"
4 "testing"
5 "time"
6
7 "github.com/go-kit/kit/loadbalancer"
8 )
9
10 func assertLoadBalancerNotEmpty(t *testing.T, lb loadbalancer.LoadBalancer) {
11 if err := within(10*time.Millisecond, func() bool {
12 return lb.Count() > 0
13 }); err != nil {
14 t.Fatal("Publisher never updated endpoints")
15 }
16 }
17
18 func within(d time.Duration, f func() bool) error {
19 var (
20 deadline = time.After(d)
21 ticker = time.NewTicker(d / 10)
22 )
23 defer ticker.Stop()
24 for {
25 select {
26 case <-ticker.C:
27 if f() {
28 return nil
29 }
30 case <-deadline:
31 return fmt.Errorf("deadline exceeded")
32 }
33 }
34 }