Codebase list golang-github-go-kit-kit / 6044042
Merge pull request #90 from go-kit/load-balancer-packages Enhancements to loadbalancer Peter Bourgon 8 years ago
26 changed file(s) with 620 addition(s) and 654 deletion(s). Raw diff Collapse all Expand all
66 discovered via periodic DNS SRV lookups on a single logical name. Consul and
77 etcd publishers are planned.
88
9 Different load balancing strategies are implemented on top of publishers. Go
10 kit currently provides random and round-robin semantics. Smarter behaviors,
9 Different load balancers are implemented on top of publishers. Go kit
10 currently provides random and round-robin load balancers. Smarter behaviors,
1111 e.g. load balancing based on underlying endpoint priority/weight, is planned.
1212
1313 ## Rationale
1616
1717 ## Usage
1818
19 In your client, define a publisher, wrap it with a balancing strategy, and pass
20 it to a retry strategy, which returns an endpoint. Use that endpoint to make
21 requests, or wrap it with other value-add middleware.
19 In your client, construct a publisher for a specific remote service, and pass
20 it to a load balancer. Then, request an endpoint from the load balancer
21 whenever you need to make a request to that remote service.
22
23 ```go
24 import (
25 "github.com/go-kit/kit/loadbalancer"
26 "github.com/go-kit/kit/loadbalancer/dnssrv"
27 )
28
29 func main() {
30 // Construct a load balancer for foosvc, which gets foosvc instances by
31 // polling a specific DNS SRV name.
32 p := dnssrv.NewPublisher("foosvc.internal.domain", 5*time.Second, fooFactory, logger)
33 lb := loadbalancer.NewRoundRobin(p)
34
35 // Get a new endpoint from the load balancer.
36 endpoint, err := lb.Endpoint()
37 if err != nil {
38 panic(err)
39 }
40
41 // Use the endpoint to make a request.
42 response, err := endpoint(ctx, request)
43 }
44
45 func fooFactory(instance string) (endpoint.Endpoint, error) {
46 // Convert an instance (host:port) to an endpoint, via a defined transport binding.
47 }
48 ```
49
50 It's also possible to wrap a load balancer with a retry strategy, so that it
51 can be used as an endpoint directly. This may make load balancers more
52 convenient to use, at the cost of fine-grained control of failures.
2253
2354 ```go
2455 func main() {
25 var (
26 fooPublisher = loadbalancer.NewDNSSRVPublisher("foo.mynet.local", 5*time.Second, makeEndpoint)
27 fooBalancer = loadbalancer.RoundRobin(fooPublisher)
28 fooEndpoint = loadbalancer.Retry(3, time.Second, fooBalancer)
29 )
30 http.HandleFunc("/", handle(fooEndpoint))
31 log.Fatal(http.ListenAndServe(":8080", nil))
56 p := dnssrv.NewPublisher("foosvc.internal.domain", 5*time.Second, fooFactory, logger)
57 lb := loadbalancer.NewRoundRobin(p)
58 endpoint := loadbalancer.Retry(3, 5*time.Seconds, lb)
59
60 response, err := endpoint(ctx, request) // requests will be automatically load balanced
3261 }
33
34 func makeEndpoint(hostport string) endpoint.Endpoint {
35 // Convert a host:port to a endpoint via your defined transport.
36 }
37
38 func handle(foo endpoint.Endpoint) http.HandlerFunc {
39 return func(w http.ResponseWriter, r *http.Request) {
40 // foo is usable as a load-balanced remote endpoint.
41 }
42 }
43 ```
62 ```
+0
-47
loadbalancer/cache.go less more
0 package loadbalancer
1
2 import "github.com/go-kit/kit/endpoint"
3
4 type cache struct {
5 req chan []endpoint.Endpoint
6 cnt chan int
7 quit chan struct{}
8 }
9
10 func newCache(p Publisher) *cache {
11 c := &cache{
12 req: make(chan []endpoint.Endpoint),
13 cnt: make(chan int),
14 quit: make(chan struct{}),
15 }
16 go c.loop(p)
17 return c
18 }
19
20 func (c *cache) loop(p Publisher) {
21 e := make(chan []endpoint.Endpoint, 1)
22 p.Subscribe(e)
23 defer p.Unsubscribe(e)
24 endpoints := <-e
25 for {
26 select {
27 case endpoints = <-e:
28 case c.cnt <- len(endpoints):
29 case c.req <- endpoints:
30 case <-c.quit:
31 return
32 }
33 }
34 }
35
36 func (c *cache) count() int {
37 return <-c.cnt
38 }
39
40 func (c *cache) get() []endpoint.Endpoint {
41 return <-c.req
42 }
43
44 func (c *cache) stop() {
45 close(c.quit)
46 }
+0
-33
loadbalancer/cache_internal_test.go less more
0 package loadbalancer
1
2 import (
3 "runtime"
4 "testing"
5
6 "golang.org/x/net/context"
7
8 "github.com/go-kit/kit/endpoint"
9 )
10
11 func TestCache(t *testing.T) {
12 e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
13 endpoints := []endpoint.Endpoint{e}
14
15 p := NewStaticPublisher(endpoints)
16 defer p.Stop()
17
18 c := newCache(p)
19 defer c.stop()
20
21 for _, n := range []int{2, 10, 0} {
22 endpoints = make([]endpoint.Endpoint, n)
23 for i := 0; i < n; i++ {
24 endpoints[i] = e
25 }
26 p.Replace(endpoints)
27 runtime.Gosched()
28 if want, have := len(endpoints), len(c.get()); want != have {
29 t.Errorf("want %d, have %d", want, have)
30 }
31 }
32 }
+0
-105
loadbalancer/dns_srv_publisher.go less more
0 package loadbalancer
1
2 import (
3 "crypto/md5"
4 "fmt"
5 "net"
6 "sort"
7 "time"
8
9 "github.com/go-kit/kit/endpoint"
10 )
11
12 type dnssrvPublisher struct {
13 subscribe chan chan<- []endpoint.Endpoint
14 unsubscribe chan chan<- []endpoint.Endpoint
15 quit chan struct{}
16 }
17
18 // NewDNSSRVPublisher returns a publisher that resolves the SRV name every ttl, and
19 func NewDNSSRVPublisher(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) Publisher {
20 p := &dnssrvPublisher{
21 subscribe: make(chan chan<- []endpoint.Endpoint),
22 unsubscribe: make(chan chan<- []endpoint.Endpoint),
23 quit: make(chan struct{}),
24 }
25 go p.loop(name, ttl, makeEndpoint)
26 return p
27 }
28
29 func (p *dnssrvPublisher) Subscribe(c chan<- []endpoint.Endpoint) {
30 p.subscribe <- c
31 }
32
33 func (p *dnssrvPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) {
34 p.unsubscribe <- c
35 }
36
37 func (p *dnssrvPublisher) Stop() {
38 close(p.quit)
39 }
40
41 var newTicker = time.NewTicker
42
43 func (p *dnssrvPublisher) loop(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) {
44 var (
45 subscriptions = map[chan<- []endpoint.Endpoint]struct{}{}
46 addrs, md5, _ = resolve(name)
47 endpoints = convert(addrs, makeEndpoint)
48 ticker = newTicker(ttl)
49 )
50 defer ticker.Stop()
51 for {
52 select {
53 case <-ticker.C:
54 addrs, newmd5, err := resolve(name)
55 if err == nil && newmd5 != md5 {
56 endpoints = convert(addrs, makeEndpoint)
57 for c := range subscriptions {
58 c <- endpoints
59 }
60 md5 = newmd5
61 }
62
63 case c := <-p.subscribe:
64 subscriptions[c] = struct{}{}
65 c <- endpoints
66
67 case c := <-p.unsubscribe:
68 delete(subscriptions, c)
69
70 case <-p.quit:
71 return
72 }
73 }
74 }
75
76 // Allow mocking in tests.
77 var resolve = func(name string) (addrs []*net.SRV, md5sum string, err error) {
78 _, addrs, err = net.LookupSRV("", "", name)
79 if err != nil {
80 return addrs, "", err
81 }
82 hostports := make([]string, len(addrs))
83 for i, addr := range addrs {
84 hostports[i] = fmt.Sprintf("%s:%d", addr.Target, addr.Port)
85 }
86 sort.Sort(sort.StringSlice(hostports))
87 h := md5.New()
88 for _, hostport := range hostports {
89 fmt.Fprintf(h, hostport)
90 }
91 return addrs, fmt.Sprintf("%x", h.Sum(nil)), nil
92 }
93
94 func convert(addrs []*net.SRV, makeEndpoint func(hostport string) endpoint.Endpoint) []endpoint.Endpoint {
95 endpoints := make([]endpoint.Endpoint, len(addrs))
96 for i, addr := range addrs {
97 endpoints[i] = makeEndpoint(addr2hostport(addr))
98 }
99 return endpoints
100 }
101
102 func addr2hostport(addr *net.SRV) string {
103 return net.JoinHostPort(addr.Target, fmt.Sprintf("%d", addr.Port))
104 }
+0
-77
loadbalancer/dns_srv_publisher_internal_test.go less more
0 package loadbalancer
1
2 import (
3 "fmt"
4 "net"
5 "testing"
6 "time"
7
8 "golang.org/x/net/context"
9
10 "github.com/go-kit/kit/endpoint"
11 )
12
13 func TestDNSSRVPublisher(t *testing.T) {
14 // Reset the vars when we're done
15 oldResolve := resolve
16 defer func() { resolve = oldResolve }()
17 oldNewTicker := newTicker
18 defer func() { newTicker = oldNewTicker }()
19
20 // Set up a fixture and swap the vars
21 a := []*net.SRV{
22 {Target: "foo", Port: 123},
23 {Target: "bar", Port: 456},
24 {Target: "baz", Port: 789},
25 }
26 ticker := make(chan time.Time)
27 resolve = func(string) ([]*net.SRV, string, error) { return a, fmt.Sprint(len(a)), nil }
28 newTicker = func(time.Duration) *time.Ticker { return &time.Ticker{C: ticker} }
29
30 // Construct endpoint
31 m := map[string]int{}
32 e := func(hostport string) endpoint.Endpoint {
33 return func(context.Context, interface{}) (interface{}, error) {
34 m[hostport]++
35 return struct{}{}, nil
36 }
37 }
38
39 // Build the publisher
40 var (
41 name = "irrelevant"
42 ttl = time.Second
43 makeEndpoint = func(hostport string) endpoint.Endpoint { return e(hostport) }
44 )
45 p := NewDNSSRVPublisher(name, ttl, makeEndpoint)
46 defer p.Stop()
47
48 // Subscribe
49 c := make(chan []endpoint.Endpoint, 1)
50 p.Subscribe(c)
51 defer p.Unsubscribe(c)
52
53 // Invoke all of the endpoints
54 for _, e := range <-c {
55 e(context.Background(), struct{}{})
56 }
57
58 // Make sure we invoked what we expected to
59 for _, addr := range a {
60 hostport := addr2hostport(addr)
61 if want, have := 1, m[hostport]; want != have {
62 t.Errorf("%q: want %d, have %d", name, want, have)
63 }
64 delete(m, hostport)
65 }
66 if want, have := 0, len(m); want != have {
67 t.Errorf("want %d, have %d", want, have)
68 }
69
70 // Reset the fixture, trigger the timer, count the endpoints
71 a = []*net.SRV{}
72 ticker <- time.Now()
73 if want, have := len(a), len(<-c); want != have {
74 t.Errorf("want %d, have %d", want, have)
75 }
76 }
0 package dnssrv
1
2 import (
3 "crypto/md5"
4 "fmt"
5 "net"
6 "sort"
7 "time"
8
9 "github.com/go-kit/kit/endpoint"
10 "github.com/go-kit/kit/loadbalancer"
11 "github.com/go-kit/kit/log"
12 )
13
14 // Publisher yields endpoints taken from the named DNS SRV record. The name is
15 // resolved on a fixed schedule. Priorities and weights are ignored.
16 type Publisher struct {
17 name string
18 ttl time.Duration
19 factory loadbalancer.Factory
20 logger log.Logger
21 endpoints chan []endpoint.Endpoint
22 quit chan struct{}
23 }
24
25 // NewPublisher returns a DNS SRV publisher. The name is resolved
26 // synchronously as part of construction; if that resolution fails, the
27 // constructor will return an error. The factory is used to convert a
28 // host:port to a usable endpoint. The logger is used to report DNS and
29 // factory errors.
30 func NewPublisher(name string, ttl time.Duration, f loadbalancer.Factory, logger log.Logger) (*Publisher, error) {
31 logger = log.NewContext(logger).With("component", "DNS SRV Publisher")
32 addrs, md5, err := resolve(name)
33 if err != nil {
34 return nil, err
35 }
36 p := &Publisher{
37 name: name,
38 ttl: ttl,
39 factory: f,
40 logger: logger,
41 endpoints: make(chan []endpoint.Endpoint),
42 quit: make(chan struct{}),
43 }
44 go p.loop(makeEndpoints(addrs, f, logger), md5)
45 return p, nil
46 }
47
48 // Stop terminates the publisher.
49 func (p *Publisher) Stop() {
50 close(p.quit)
51 }
52
53 func (p *Publisher) loop(endpoints []endpoint.Endpoint, md5 string) {
54 t := newTicker(p.ttl)
55 defer t.Stop()
56 for {
57 select {
58 case p.endpoints <- endpoints:
59
60 case <-t.C:
61 // TODO should we do this out-of-band?
62 addrs, newmd5, err := resolve(p.name)
63 if err != nil {
64 p.logger.Log("name", p.name, "err", err)
65 continue // don't replace good endpoints with bad ones
66 }
67 if newmd5 == md5 {
68 continue // no change
69 }
70 endpoints = makeEndpoints(addrs, p.factory, p.logger)
71 md5 = newmd5
72
73 case <-p.quit:
74 return
75 }
76 }
77 }
78
79 // Endpoints implements the Publisher interface.
80 func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) {
81 select {
82 case endpoints := <-p.endpoints:
83 return endpoints, nil
84 case <-p.quit:
85 return nil, loadbalancer.ErrPublisherStopped
86 }
87 }
88
89 var (
90 lookupSRV = net.LookupSRV
91 newTicker = time.NewTicker
92 )
93
94 func resolve(name string) (addrs []*net.SRV, md5sum string, err error) {
95 _, addrs, err = lookupSRV("", "", name)
96 if err != nil {
97 return addrs, "", err
98 }
99 hostports := make([]string, len(addrs))
100 for i, addr := range addrs {
101 hostports[i] = fmt.Sprintf("%s:%d", addr.Target, addr.Port)
102 }
103 sort.Sort(sort.StringSlice(hostports))
104 h := md5.New()
105 for _, hostport := range hostports {
106 fmt.Fprintf(h, hostport)
107 }
108 return addrs, fmt.Sprintf("%x", h.Sum(nil)), nil
109 }
110
111 func makeEndpoints(addrs []*net.SRV, f loadbalancer.Factory, logger log.Logger) []endpoint.Endpoint {
112 endpoints := make([]endpoint.Endpoint, 0, len(addrs))
113 for _, addr := range addrs {
114 endpoint, err := f(addr2instance(addr))
115 if err != nil {
116 logger.Log("instance", addr2instance(addr), "err", err)
117 continue
118 }
119 endpoints = append(endpoints, endpoint)
120 }
121 return endpoints
122 }
123
124 func addr2instance(addr *net.SRV) string {
125 return net.JoinHostPort(addr.Target, fmt.Sprint(addr.Port))
126 }
0 package dnssrv
1
2 import (
3 "errors"
4 "net"
5 "sync/atomic"
6 "testing"
7 "time"
8
9 "golang.org/x/net/context"
10
11 "github.com/go-kit/kit/endpoint"
12 "github.com/go-kit/kit/loadbalancer"
13 "github.com/go-kit/kit/log"
14 )
15
16 func TestPublisher(t *testing.T) {
17 var (
18 target = "my-target"
19 port = uint16(1234)
20 addr = &net.SRV{Target: target, Port: port}
21 addrs = []*net.SRV{addr}
22 name = "my-name"
23 ttl = time.Second
24 logger = log.NewNopLogger()
25 e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
26 )
27
28 oldLookup := lookupSRV
29 defer func() { lookupSRV = oldLookup }()
30 lookupSRV = mockLookupSRV(addrs, nil, nil)
31
32 factory := func(instance string) (endpoint.Endpoint, error) {
33 if want, have := addr2instance(addr), instance; want != have {
34 t.Errorf("want %q, have %q", want, have)
35 }
36 return e, nil
37 }
38
39 p, err := NewPublisher(name, ttl, factory, logger)
40 if err != nil {
41 t.Fatal(err)
42 }
43 defer p.Stop()
44
45 if _, err := p.Endpoints(); err != nil {
46 t.Fatal(err)
47 }
48 }
49
50 func TestBadLookup(t *testing.T) {
51 oldLookup := lookupSRV
52 defer func() { lookupSRV = oldLookup }()
53 lookupSRV = mockLookupSRV([]*net.SRV{}, errors.New("kaboom"), nil)
54
55 var (
56 name = "some-name"
57 ttl = time.Second
58 factory = func(string) (endpoint.Endpoint, error) { return nil, errors.New("unreachable") }
59 logger = log.NewNopLogger()
60 )
61
62 if _, err := NewPublisher(name, ttl, factory, logger); err == nil {
63 t.Fatal("wanted error, got none")
64 }
65 }
66
67 func TestBadFactory(t *testing.T) {
68 var (
69 addr = &net.SRV{Target: "foo", Port: 1234}
70 addrs = []*net.SRV{addr}
71 name = "some-name"
72 ttl = time.Second
73 factory = func(string) (endpoint.Endpoint, error) { return nil, errors.New("kaboom") }
74 logger = log.NewNopLogger()
75 )
76
77 oldLookup := lookupSRV
78 defer func() { lookupSRV = oldLookup }()
79 lookupSRV = mockLookupSRV(addrs, nil, nil)
80
81 p, err := NewPublisher(name, ttl, factory, logger)
82 if err != nil {
83 t.Fatal(err)
84 }
85 defer p.Stop()
86
87 endpoints, err := p.Endpoints()
88 if err != nil {
89 t.Fatal(err)
90 }
91 if want, have := 0, len(endpoints); want != have {
92 t.Errorf("want %q, have %q", want, have)
93 }
94 }
95
96 func TestRefreshWithChange(t *testing.T) {
97 t.Skip("TODO")
98 }
99
100 func TestRefreshNoChange(t *testing.T) {
101 var (
102 tick = make(chan time.Time)
103 target = "my-target"
104 port = uint16(5678)
105 addr = &net.SRV{Target: target, Port: port}
106 addrs = []*net.SRV{addr}
107 name = "my-name"
108 ttl = time.Second
109 factory = func(string) (endpoint.Endpoint, error) { return nil, errors.New("kaboom") }
110 logger = log.NewNopLogger()
111 )
112
113 oldTicker := newTicker
114 defer func() { newTicker = oldTicker }()
115 newTicker = func(time.Duration) *time.Ticker { return &time.Ticker{C: tick} }
116
117 var resolves uint64
118 oldLookup := lookupSRV
119 defer func() { lookupSRV = oldLookup }()
120 lookupSRV = mockLookupSRV(addrs, nil, &resolves)
121
122 p, err := NewPublisher(name, ttl, factory, logger)
123 if err != nil {
124 t.Fatal(err)
125 }
126 defer p.Stop()
127
128 tick <- time.Now()
129 if want, have := uint64(2), resolves; want != have {
130 t.Errorf("want %d, have %d", want, have)
131 }
132 }
133
134 func TestRefreshResolveError(t *testing.T) {
135 t.Skip("TODO")
136 }
137
138 func TestErrPublisherStopped(t *testing.T) {
139 var (
140 name = "my-name"
141 ttl = time.Second
142 factory = func(string) (endpoint.Endpoint, error) { return nil, errors.New("kaboom") }
143 logger = log.NewNopLogger()
144 )
145
146 oldLookup := lookupSRV
147 defer func() { lookupSRV = oldLookup }()
148 lookupSRV = mockLookupSRV([]*net.SRV{}, nil, nil)
149
150 p, err := NewPublisher(name, ttl, factory, logger)
151 if err != nil {
152 t.Fatal(err)
153 }
154
155 p.Stop()
156 _, have := p.Endpoints()
157 if want := loadbalancer.ErrPublisherStopped; want != have {
158 t.Fatalf("want %v, have %v", want, have)
159 }
160 }
161
162 func mockLookupSRV(addrs []*net.SRV, err error, count *uint64) func(service, proto, name string) (string, []*net.SRV, error) {
163 return func(service, proto, name string) (string, []*net.SRV, error) {
164 if count != nil {
165 atomic.AddUint64(count, 1)
166 }
167 return "", addrs, err
168 }
169 }
+0
-41
loadbalancer/endpoint_cache.go less more
0 package loadbalancer
1
2 import "github.com/go-kit/kit/endpoint"
3
4 type endpointCache struct {
5 requests chan []endpoint.Endpoint
6 quit chan struct{}
7 }
8
9 func newEndpointCache(p Publisher) *endpointCache {
10 c := &endpointCache{
11 requests: make(chan []endpoint.Endpoint),
12 quit: make(chan struct{}),
13 }
14 go c.loop(p)
15 return c
16 }
17
18 func (c *endpointCache) loop(p Publisher) {
19 updates := make(chan []endpoint.Endpoint, 1)
20 p.Subscribe(updates)
21 defer p.Unsubscribe(updates)
22 endpoints := <-updates
23
24 for {
25 select {
26 case endpoints = <-updates:
27 case c.requests <- endpoints:
28 case <-c.quit:
29 return
30 }
31 }
32 }
33
34 func (c *endpointCache) get() []endpoint.Endpoint {
35 return <-c.requests
36 }
37
38 func (c *endpointCache) stop() {
39 close(c.quit)
40 }
+0
-27
loadbalancer/endpoint_cache_internal_test.go less more
0 package loadbalancer
1
2 import (
3 "testing"
4
5 "golang.org/x/net/context"
6
7 "github.com/go-kit/kit/endpoint"
8 )
9
10 func TestEndpointCache(t *testing.T) {
11 endpoints := []endpoint.Endpoint{
12 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
13 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
14 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
15 }
16
17 p := NewStaticPublisher(endpoints)
18 defer p.Stop()
19
20 c := newEndpointCache(p)
21 defer c.stop()
22
23 if want, have := len(endpoints), len(c.get()); want != have {
24 t.Errorf("want %d, have %d", want, have)
25 }
26 }
0 package loadbalancer
1
2 import "github.com/go-kit/kit/endpoint"
3
4 // Factory is a function that converts an instance string, e.g. a host:port,
5 // to a usable endpoint. Factories are used by load balancers to lift
6 // instances returned by Publishers into endpoints. Users are expected to
7 // provide their own factory functions that assume specific transports, or can
8 // deduce transports by parsing the instance string.
9 type Factory func(instance string) (endpoint.Endpoint, error)
+0
-17
loadbalancer/load_balancer.go less more
0 package loadbalancer
1
2 import (
3 "errors"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // LoadBalancer yields endpoints one-by-one.
9 type LoadBalancer interface {
10 Count() int
11 Get() (endpoint.Endpoint, error)
12 }
13
14 // ErrNoEndpointsAvailable is given by a load balancer when no endpoints are
15 // available to be returned.
16 var ErrNoEndpointsAvailable = errors.New("no endpoints available")
0 package loadbalancer
1
2 import (
3 "errors"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // LoadBalancer describes something that can yield endpoints for a remote
9 // service method.
10 type LoadBalancer interface {
11 Endpoint() (endpoint.Endpoint, error)
12 }
13
14 // ErrNoEndpoints is returned when a load balancer (or one of its components)
15 // has no endpoints to return. In a request lifecycle, this is usually a fatal
16 // error.
17 var ErrNoEndpoints = errors.New("no endpoints available")
+0
-46
loadbalancer/mock_publisher_test.go less more
0 package loadbalancer_test
1
2 import (
3 "runtime"
4 "sync"
5
6 "github.com/go-kit/kit/endpoint"
7 )
8
9 type mockPublisher struct {
10 sync.Mutex
11 e []endpoint.Endpoint
12 s map[chan<- []endpoint.Endpoint]struct{}
13 }
14
15 func newMockPublisher(endpoints []endpoint.Endpoint) *mockPublisher {
16 return &mockPublisher{
17 e: endpoints,
18 s: map[chan<- []endpoint.Endpoint]struct{}{},
19 }
20 }
21
22 func (p *mockPublisher) replace(endpoints []endpoint.Endpoint) {
23 p.Lock()
24 defer p.Unlock()
25 p.e = endpoints
26 for s := range p.s {
27 s <- p.e
28 }
29 runtime.Gosched()
30 }
31
32 func (p *mockPublisher) Subscribe(c chan<- []endpoint.Endpoint) {
33 p.Lock()
34 defer p.Unlock()
35 p.s[c] = struct{}{}
36 c <- p.e
37 }
38
39 func (p *mockPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) {
40 p.Lock()
41 defer p.Unlock()
42 delete(p.s, c)
43 }
44
45 func (p *mockPublisher) Stop() {}
00 package loadbalancer
11
2 import "github.com/go-kit/kit/endpoint"
2 import (
3 "errors"
34
4 // Publisher produces endpoints.
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // Publisher describes something that provides a set of identical endpoints.
9 // Different publisher implementations exist for different kinds of service
10 // discovery systems.
511 type Publisher interface {
6 Subscribe(chan<- []endpoint.Endpoint)
7 Unsubscribe(chan<- []endpoint.Endpoint)
8 Stop()
12 Endpoints() ([]endpoint.Endpoint, error)
913 }
14
15 // ErrPublisherStopped is returned by publishers when the underlying
16 // implementation has been terminated and can no longer serve requests.
17 var ErrPublisherStopped = errors.New("publisher stopped")
55 "github.com/go-kit/kit/endpoint"
66 )
77
8 // Random returns a load balancer that yields random endpoints.
9 func Random(p Publisher) LoadBalancer {
10 return random{newCache(p)}
8 // Random is a completely stateless load balancer that chooses a random
9 // endpoint to return each time.
10 type Random struct {
11 p Publisher
12 r *rand.Rand
1113 }
1214
13 type random struct{ *cache }
15 // NewRandom returns a new Random load balancer.
16 func NewRandom(p Publisher, seed int64) *Random {
17 return &Random{
18 p: p,
19 r: rand.New(rand.NewSource(seed)),
20 }
21 }
1422
15 func (r random) Count() int { return r.cache.count() }
16
17 func (r random) Get() (endpoint.Endpoint, error) {
18 endpoints := r.cache.get()
23 // Endpoint implements the LoadBalancer interface.
24 func (r *Random) Endpoint() (endpoint.Endpoint, error) {
25 endpoints, err := r.p.Endpoints()
26 if err != nil {
27 return nil, err
28 }
1929 if len(endpoints) <= 0 {
20 return nil, ErrNoEndpointsAvailable
30 return nil, ErrNoEndpoints
2131 }
22 return endpoints[rand.Intn(len(endpoints))], nil
32 return endpoints[r.r.Intn(len(endpoints))], nil
2333 }
33 "math"
44 "testing"
55
6 "golang.org/x/net/context"
7
68 "github.com/go-kit/kit/endpoint"
79 "github.com/go-kit/kit/loadbalancer"
8 "golang.org/x/net/context"
10 "github.com/go-kit/kit/loadbalancer/static"
911 )
1012
11 func TestRandom(t *testing.T) {
12 p := loadbalancer.NewStaticPublisher([]endpoint.Endpoint{})
13 defer p.Stop()
13 func TestRandomDistribution(t *testing.T) {
14 var (
15 n = 3
16 endpoints = make([]endpoint.Endpoint, n)
17 counts = make([]int, n)
18 seed = int64(123)
19 ctx = context.Background()
20 iterations = 100000
21 want = iterations / n
22 tolerance = want / 100 // 1%
23 )
1424
15 lb := loadbalancer.Random(p)
16 if _, err := lb.Get(); err == nil {
17 t.Error("want error, got none")
25 for i := 0; i < n; i++ {
26 i0 := i
27 endpoints[i] = func(context.Context, interface{}) (interface{}, error) { counts[i0]++; return struct{}{}, nil }
1828 }
1929
20 counts := []int{0, 0, 0}
21 p.Replace([]endpoint.Endpoint{
22 func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil },
23 func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil },
24 func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil },
25 })
26 assertLoadBalancerNotEmpty(t, lb)
30 lb := loadbalancer.NewRandom(static.NewPublisher(endpoints), seed)
2731
28 n := 10000
29 for i := 0; i < n; i++ {
30 e, _ := lb.Get()
31 e(context.Background(), struct{}{})
32 for i := 0; i < iterations; i++ {
33 e, err := lb.Endpoint()
34 if err != nil {
35 t.Fatal(err)
36 }
37 e(ctx, struct{}{})
3238 }
3339
34 want := float64(n) / float64(len(counts))
35 tolerance := (want / 100.0) * 5 // 5%
36 for _, have := range counts {
37 if math.Abs(want-float64(have)) > tolerance {
38 t.Errorf("want %.0f, have %d", want, have)
40 for i, have := range counts {
41 if math.Abs(float64(want-have)) > float64(tolerance) {
42 t.Errorf("%d: want %d, have %d", i, want, have)
3943 }
4044 }
4145 }
46
47 func TestRandomBadPublisher(t *testing.T) {
48 t.Skip("TODO")
49 }
50
51 func TestRandomNoEndpoints(t *testing.T) {
52 lb := loadbalancer.NewRandom(static.NewPublisher([]endpoint.Endpoint{}), 123)
53 _, have := lb.Endpoint()
54 if want := loadbalancer.ErrNoEndpoints; want != have {
55 t.Errorf("want %q, have %q", want, have)
56 }
57 }
99 "github.com/go-kit/kit/endpoint"
1010 )
1111
12 // Retry yields an endpoint that takes endpoints from the load balancer.
13 // Invocations that return errors will be retried until they succeed, up to
14 // max times, or until the timeout is elapsed, whichever comes first.
12 // Retry wraps the load balancer to make it behave like a simple endpoint.
13 // Requests to the endpoint will be automatically load balanced via the load
14 // balancer. Requests that return errors will be retried until they succeed,
15 // up to max times, or until the timeout is elapsed, whichever comes first.
1516 func Retry(max int, timeout time.Duration, lb LoadBalancer) endpoint.Endpoint {
1617 return func(ctx context.Context, request interface{}) (interface{}, error) {
1718 var (
2324 defer cancel()
2425 for i := 1; i <= max; i++ {
2526 go func() {
26 e, err := lb.Get()
27 e, err := lb.Endpoint()
2728 if err != nil {
2829 errs <- err
2930 return
11
22 import (
33 "errors"
4 "testing"
45 "time"
6
7 "golang.org/x/net/context"
58
69 "github.com/go-kit/kit/endpoint"
710 "github.com/go-kit/kit/loadbalancer"
8 "golang.org/x/net/context"
9
10 "testing"
11 "github.com/go-kit/kit/loadbalancer/static"
1112 )
1213
13 func TestRetryMax(t *testing.T) {
14 func TestRetryMaxTotalFail(t *testing.T) {
1415 var (
15 endpoints = []endpoint.Endpoint{}
16 p = loadbalancer.NewStaticPublisher(endpoints)
17 lb = loadbalancer.RoundRobin(p)
16 endpoints = []endpoint.Endpoint{} // no endpoints
17 p = static.NewPublisher(endpoints)
18 lb = loadbalancer.NewRoundRobin(p)
19 retry = loadbalancer.Retry(999, time.Second, lb) // lots of retries
20 ctx = context.Background()
1821 )
22 if _, err := retry(ctx, struct{}{}); err == nil {
23 t.Errorf("expected error, got none") // should fail
24 }
25 }
1926
20 if _, err := loadbalancer.Retry(999, time.Second, lb)(context.Background(), struct{}{}); err == nil {
27 func TestRetryMaxPartialFail(t *testing.T) {
28 var (
29 endpoints = []endpoint.Endpoint{
30 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") },
31 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") },
32 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
33 }
34 retries = len(endpoints) - 1 // not quite enough retries
35 p = static.NewPublisher(endpoints)
36 lb = loadbalancer.NewRoundRobin(p)
37 ctx = context.Background()
38 )
39 if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err == nil {
2140 t.Errorf("expected error, got none")
2241 }
42 }
2343
24 endpoints = []endpoint.Endpoint{
25 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") },
26 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") },
27 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
28 }
29 p.Replace(endpoints)
30 assertLoadBalancerNotEmpty(t, lb)
31
32 if _, err := loadbalancer.Retry(len(endpoints)-1, time.Second, lb)(context.Background(), struct{}{}); err == nil {
33 t.Errorf("expected error, got none")
34 }
35
36 if _, err := loadbalancer.Retry(len(endpoints), time.Second, lb)(context.Background(), struct{}{}); err != nil {
44 func TestRetryMaxSuccess(t *testing.T) {
45 var (
46 endpoints = []endpoint.Endpoint{
47 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") },
48 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") },
49 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
50 }
51 retries = len(endpoints) // exactly enough retries
52 p = static.NewPublisher(endpoints)
53 lb = loadbalancer.NewRoundRobin(p)
54 ctx = context.Background()
55 )
56 if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err != nil {
3757 t.Error(err)
3858 }
3959 }
4363 step = make(chan struct{})
4464 e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil }
4565 timeout = time.Millisecond
46 retry = loadbalancer.Retry(999, timeout, loadbalancer.RoundRobin(loadbalancer.NewStaticPublisher([]endpoint.Endpoint{e})))
47 errs = make(chan error)
66 retry = loadbalancer.Retry(999, timeout, loadbalancer.NewRoundRobin(static.NewPublisher([]endpoint.Endpoint{e})))
67 errs = make(chan error, 1)
4868 invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err }
4969 )
5070
51 go invoke() // invoke the endpoint
52 step <- struct{}{} // tell the endpoint to return
53 if err := <-errs; err != nil { // that should succeed
71 go func() { step <- struct{}{} }() // queue up a flush of the endpoint
72 invoke() // invoke the endpoint and trigger the flush
73 if err := <-errs; err != nil { // that should succeed
5474 t.Error(err)
5575 }
5676
57 go invoke() // invoke the endpoint
58 time.Sleep(2 * timeout) // wait
59 time.Sleep(2 * timeout) // wait again (CI servers!!)
60 step <- struct{}{} // tell the endpoint to return
61 if err := <-errs; err != context.DeadlineExceeded { // that should not succeed
62 t.Errorf("wanted error, got none")
77 go func() { time.Sleep(10 * timeout); step <- struct{}{} }() // a delayed flush
78 invoke() // invoke the endpoint
79 if err := <-errs; err != context.DeadlineExceeded { // that should not succeed
80 t.Errorf("wanted %v, got none", context.DeadlineExceeded)
6381 }
6482 }
55 "github.com/go-kit/kit/endpoint"
66 )
77
8 // RoundRobin returns a load balancer that yields endpoints in sequence.
9 func RoundRobin(p Publisher) LoadBalancer {
10 return &roundRobin{newCache(p), 0}
8 // RoundRobin is a simple load balancer that returns each of the published
9 // endpoints in sequence.
10 type RoundRobin struct {
11 p Publisher
12 counter uint64
1113 }
1214
13 type roundRobin struct {
14 *cache
15 uint64
15 // NewRoundRobin returns a new RoundRobin load balancer.
16 func NewRoundRobin(p Publisher) *RoundRobin {
17 return &RoundRobin{
18 p: p,
19 counter: 0,
20 }
1621 }
1722
18 func (r *roundRobin) Count() int { return r.cache.count() }
19
20 func (r *roundRobin) Get() (endpoint.Endpoint, error) {
21 endpoints := r.cache.get()
23 // Endpoint implements the LoadBalancer interface.
24 func (rr *RoundRobin) Endpoint() (endpoint.Endpoint, error) {
25 endpoints, err := rr.p.Endpoints()
26 if err != nil {
27 return nil, err
28 }
2229 if len(endpoints) <= 0 {
23 return nil, ErrNoEndpointsAvailable
30 return nil, ErrNoEndpoints
2431 }
2532 var old uint64
2633 for {
27 old = atomic.LoadUint64(&r.uint64)
28 if atomic.CompareAndSwapUint64(&r.uint64, old, old+1) {
34 old = atomic.LoadUint64(&rr.counter)
35 if atomic.CompareAndSwapUint64(&rr.counter, old, old+1) {
2936 break
3037 }
3138 }
55
66 "github.com/go-kit/kit/endpoint"
77 "github.com/go-kit/kit/loadbalancer"
8 "github.com/go-kit/kit/loadbalancer/static"
89 "golang.org/x/net/context"
910 )
1011
11 func TestRoundRobin(t *testing.T) {
12 p := loadbalancer.NewStaticPublisher([]endpoint.Endpoint{})
13 defer p.Stop()
12 func TestRoundRobinDistribution(t *testing.T) {
13 var (
14 ctx = context.Background()
15 counts = []int{0, 0, 0}
16 endpoints = []endpoint.Endpoint{
17 func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil },
18 func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil },
19 func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil },
20 }
21 )
1422
15 lb := loadbalancer.RoundRobin(p)
16 if _, err := lb.Get(); err == nil {
17 t.Error("want error, got none")
18 }
19
20 counts := []int{0, 0, 0}
21 p.Replace([]endpoint.Endpoint{
22 func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil },
23 func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil },
24 func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil },
25 })
26 assertLoadBalancerNotEmpty(t, lb)
23 lb := loadbalancer.NewRoundRobin(static.NewPublisher(endpoints))
2724
2825 for i, want := range [][]int{
2926 {1, 0, 0},
3431 {2, 2, 2},
3532 {3, 2, 2},
3633 } {
37 e, _ := lb.Get()
38 e(context.Background(), struct{}{})
34 e, err := lb.Endpoint()
35 if err != nil {
36 t.Fatal(err)
37 }
38 e(ctx, struct{}{})
3939 if have := counts; !reflect.DeepEqual(want, have) {
40 t.Errorf("%d: want %v, have %v", i+1, want, have)
40 t.Fatalf("%d: want %v, have %v", i, want, have)
4141 }
42
4243 }
4344 }
45
46 func TestRoundRobinBadPublisher(t *testing.T) {
47 t.Skip("TODO")
48 }
0 package static
1
2 import (
3 "sync"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // Publisher yields the same set of static endpoints.
9 type Publisher struct {
10 mtx sync.RWMutex
11 endpoints []endpoint.Endpoint
12 }
13
14 // NewPublisher returns a static endpoint Publisher.
15 func NewPublisher(endpoints []endpoint.Endpoint) *Publisher {
16 return &Publisher{
17 endpoints: endpoints,
18 }
19 }
20
21 // Endpoints implements the Publisher interface.
22 func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) {
23 p.mtx.RLock()
24 defer p.mtx.RUnlock()
25 return p.endpoints, nil
26 }
27
28 // Replace is a utility method to swap out the underlying endpoints of an
29 // existing static publisher. It's useful mostly for testing.
30 func (p *Publisher) Replace(endpoints []endpoint.Endpoint) {
31 p.mtx.Lock()
32 defer p.mtx.Unlock()
33 p.endpoints = endpoints
34 }
0 package static_test
1
2 import (
3 "reflect"
4 "testing"
5
6 "golang.org/x/net/context"
7
8 "github.com/go-kit/kit/endpoint"
9 "github.com/go-kit/kit/loadbalancer/static"
10 )
11
12 func TestStatic(t *testing.T) {
13 var (
14 e1 = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
15 e2 = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
16 endpoints = []endpoint.Endpoint{e1, e2}
17 )
18 p := static.NewPublisher(endpoints)
19 have, err := p.Endpoints()
20 if err != nil {
21 t.Fatal(err)
22 }
23 if want := endpoints; !reflect.DeepEqual(want, have) {
24 t.Fatalf("want %#+v, have %#+v", want, have)
25 }
26 }
27
28 func TestStaticReplace(t *testing.T) {
29 p := static.NewPublisher([]endpoint.Endpoint{
30 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
31 })
32 have, err := p.Endpoints()
33 if err != nil {
34 t.Fatal(err)
35 }
36 if want, have := 1, len(have); want != have {
37 t.Fatalf("want %d, have %d", want, have)
38 }
39 p.Replace([]endpoint.Endpoint{})
40 have, err = p.Endpoints()
41 if err != nil {
42 t.Fatal(err)
43 }
44 if want, have := 0, len(have); want != have {
45 t.Fatalf("want %d, have %d", want, have)
46 }
47 }
+0
-51
loadbalancer/static_publisher.go less more
0 package loadbalancer
1
2 import (
3 "sync"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // NewStaticPublisher returns a publisher that yields a static set of
9 // endpoints, which can be completely replaced.
10 func NewStaticPublisher(endpoints []endpoint.Endpoint) *StaticPublisher {
11 return &StaticPublisher{
12 current: endpoints,
13 subscribers: map[chan<- []endpoint.Endpoint]struct{}{},
14 }
15 }
16
17 // StaticPublisher holds a static set of endpoints.
18 type StaticPublisher struct {
19 mu sync.Mutex
20 current []endpoint.Endpoint
21 subscribers map[chan<- []endpoint.Endpoint]struct{}
22 }
23
24 // Subscribe implements Publisher.
25 func (p *StaticPublisher) Subscribe(c chan<- []endpoint.Endpoint) {
26 p.mu.Lock()
27 defer p.mu.Unlock()
28 p.subscribers[c] = struct{}{}
29 c <- p.current
30 }
31
32 // Unsubscribe implements Publisher.
33 func (p *StaticPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) {
34 p.mu.Lock()
35 defer p.mu.Unlock()
36 delete(p.subscribers, c)
37 }
38
39 // Stop implements Publisher, but is a no-op.
40 func (p *StaticPublisher) Stop() {}
41
42 // Replace replaces the endpoints and notifies all subscribers.
43 func (p *StaticPublisher) Replace(endpoints []endpoint.Endpoint) {
44 p.mu.Lock()
45 defer p.mu.Unlock()
46 p.current = endpoints
47 for c := range p.subscribers {
48 c <- p.current
49 }
50 }
+0
-30
loadbalancer/static_publisher_test.go less more
0 package loadbalancer_test
1
2 import (
3 "testing"
4
5 "golang.org/x/net/context"
6
7 "github.com/go-kit/kit/endpoint"
8 "github.com/go-kit/kit/loadbalancer"
9 )
10
11 func TestStaticPublisher(t *testing.T) {
12 endpoints := []endpoint.Endpoint{
13 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
14 }
15 p := loadbalancer.NewStaticPublisher(endpoints)
16 defer p.Stop()
17
18 c := make(chan []endpoint.Endpoint, 1)
19 p.Subscribe(c)
20 if want, have := len(endpoints), len(<-c); want != have {
21 t.Errorf("want %d, have %d", want, have)
22 }
23
24 endpoints = []endpoint.Endpoint{}
25 p.Replace(endpoints)
26 if want, have := len(endpoints), len(<-c); want != have {
27 t.Errorf("want %d, have %d", want, have)
28 }
29 }
+0
-17
loadbalancer/strategies.go less more
0 package loadbalancer
1
2 import (
3 "errors"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // Strategy yields endpoints to consumers according to some algorithm.
9 type Strategy interface {
10 Next() (endpoint.Endpoint, error)
11 Stop()
12 }
13
14 // ErrNoEndpoints is returned by a strategy when there are no endpoints
15 // available.
16 var ErrNoEndpoints = errors.New("no endpoints available")
+0
-35
loadbalancer/util_test.go less more
0 package loadbalancer_test
1
2 import (
3 "fmt"
4 "testing"
5 "time"
6
7 "github.com/go-kit/kit/loadbalancer"
8 )
9
10 func assertLoadBalancerNotEmpty(t *testing.T, lb loadbalancer.LoadBalancer) {
11 if err := within(10*time.Millisecond, func() bool {
12 return lb.Count() > 0
13 }); err != nil {
14 t.Fatal("Publisher never updated endpoints")
15 }
16 }
17
18 func within(d time.Duration, f func() bool) error {
19 var (
20 deadline = time.After(d)
21 ticker = time.NewTicker(d / 10)
22 )
23 defer ticker.Stop()
24 for {
25 select {
26 case <-ticker.C:
27 if f() {
28 return nil
29 }
30 case <-deadline:
31 return fmt.Errorf("deadline exceeded")
32 }
33 }
34 }