Codebase list golang-github-go-kit-kit / 8357da3
Big re-org Having Publishers return a set of endpoints directly (synchronously) allows us to eliminate a lot of boilerplate related to pub/sub semantics, as well as the cache type. Thanks, @rogpeppe! Peter Bourgon 8 years ago
28 changed file(s) with 513 addition(s) and 737 deletion(s). Raw diff Collapse all Expand all
+0
-44
loadbalancer/README.md less more
0 # package loadbalancer
1
2 `package loadbalancer` provides a client-side load balancer abstraction.
3
4 A publisher is responsible for emitting the most recent set of endpoints for a
5 single logical service. Publishers exist for static endpoints, and endpoints
6 discovered via periodic DNS SRV lookups on a single logical name. Consul and
7 etcd publishers are planned.
8
9 Different load balancing strategies are implemented on top of publishers. Go
10 kit currently provides random and round-robin semantics. Smarter behaviors,
11 e.g. load balancing based on underlying endpoint priority/weight, is planned.
12
13 ## Rationale
14
15 TODO
16
17 ## Usage
18
19 In your client, define a publisher, wrap it with a balancing strategy, and pass
20 it to a retry strategy, which returns an endpoint. Use that endpoint to make
21 requests, or wrap it with other value-add middleware.
22
23 ```go
24 func main() {
25 var (
26 fooPublisher = loadbalancer.NewDNSSRVPublisher("foo.mynet.local", 5*time.Second, makeEndpoint)
27 fooBalancer = loadbalancer.RoundRobin(fooPublisher)
28 fooEndpoint = loadbalancer.Retry(3, time.Second, fooBalancer)
29 )
30 http.HandleFunc("/", handle(fooEndpoint))
31 log.Fatal(http.ListenAndServe(":8080", nil))
32 }
33
34 func makeEndpoint(hostport string) endpoint.Endpoint {
35 // Convert a host:port to a endpoint via your defined transport.
36 }
37
38 func handle(foo endpoint.Endpoint) http.HandlerFunc {
39 return func(w http.ResponseWriter, r *http.Request) {
40 // foo is usable as a load-balanced remote endpoint.
41 }
42 }
43 ```
0 package dnssrv
1
2 import (
3 "crypto/md5"
4 "fmt"
5 "net"
6 "sort"
7 "time"
8
9 "github.com/go-kit/kit/endpoint"
10 "github.com/go-kit/kit/loadbalancer"
11 "github.com/go-kit/kit/log"
12 )
13
14 // Publisher yields endpoints taken from the named DNS SRV record. The name is
15 // resolved on a fixed schedule. Priorities and weights are ignored.
16 type Publisher struct {
17 name string
18 ttl time.Duration
19 factory loadbalancer.Factory
20 logger log.Logger
21 endpoints chan []endpoint.Endpoint
22 quit chan struct{}
23 }
24
25 // NewPublisher returns a DNS SRV publisher. The name is resolved
26 // synchronously as part of construction; if that resolution fails, the
27 // constructor will return an error. The factory is used to convert a
28 // host:port to a usable endpoint. The logger is used to report DNS and
29 // factory errors.
30 func NewPublisher(name string, ttl time.Duration, f loadbalancer.Factory, logger log.Logger) (*Publisher, error) {
31 logger = log.NewContext(logger).With("component", "DNS SRV Publisher")
32 addrs, md5, err := resolve(name)
33 if err != nil {
34 return nil, err
35 }
36 p := &Publisher{
37 name: name,
38 ttl: ttl,
39 factory: f,
40 logger: logger,
41 endpoints: make(chan []endpoint.Endpoint),
42 quit: make(chan struct{}),
43 }
44 go p.loop(lift(addrs, f, logger), md5)
45 return p, nil
46 }
47
48 // Stop terminates the publisher.
49 func (p *Publisher) Stop() {
50 close(p.quit)
51 }
52
53 func (p *Publisher) loop(endpoints []endpoint.Endpoint, md5 string) {
54 t := newTicker(p.ttl)
55 defer t.Stop()
56 for {
57 select {
58 case p.endpoints <- endpoints:
59
60 case <-t.C:
61 // TODO should we do this out-of-band?
62 addrs, newmd5, err := resolve(p.name)
63 if err != nil {
64 p.logger.Log("name", p.name, "err", err)
65 continue // don't replace good endpoints with bad ones
66 }
67 if newmd5 == md5 {
68 continue // no change
69 }
70 endpoints = lift(addrs, p.factory, p.logger)
71 md5 = newmd5
72
73 case <-p.quit:
74 return
75 }
76 }
77 }
78
79 // Endpoints implements the Publisher interface.
80 func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) {
81 return <-p.endpoints, nil
82 }
83
84 var (
85 lookupSRV = net.LookupSRV
86 newTicker = time.NewTicker
87 )
88
89 func resolve(name string) (addrs []*net.SRV, md5sum string, err error) {
90 _, addrs, err = lookupSRV("", "", name)
91 if err != nil {
92 return addrs, "", err
93 }
94 hostports := make([]string, len(addrs))
95 for i, addr := range addrs {
96 hostports[i] = fmt.Sprintf("%s:%d", addr.Target, addr.Port)
97 }
98 sort.Sort(sort.StringSlice(hostports))
99 h := md5.New()
100 for _, hostport := range hostports {
101 fmt.Fprintf(h, hostport)
102 }
103 return addrs, fmt.Sprintf("%x", h.Sum(nil)), nil
104 }
105
106 func lift(addrs []*net.SRV, f loadbalancer.Factory, logger log.Logger) []endpoint.Endpoint {
107 endpoints := make([]endpoint.Endpoint, 0, len(addrs))
108 for _, addr := range addrs {
109 endpoint, err := f(addr2instance(addr))
110 if err != nil {
111 logger.Log("instance", addr2instance(addr), "err", err)
112 continue
113 }
114 endpoints = append(endpoints, endpoint)
115 }
116 return endpoints
117 }
118
119 func addr2instance(addr *net.SRV) string {
120 return net.JoinHostPort(addr.Target, fmt.Sprint(addr.Port))
121 }
0 package dnssrv
1
2 import (
3 "errors"
4 "net"
5 "sync/atomic"
6 "testing"
7 "time"
8
9 "golang.org/x/net/context"
10
11 "github.com/go-kit/kit/endpoint"
12 "github.com/go-kit/kit/log"
13 )
14
15 func TestPublisher(t *testing.T) {
16 var (
17 target = "my-target"
18 port = uint16(1234)
19 addr = &net.SRV{Target: target, Port: port}
20 addrs = []*net.SRV{addr}
21 name = "my-name"
22 ttl = time.Second
23 logger = log.NewNopLogger()
24 e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
25 )
26
27 oldLookup := lookupSRV
28 defer func() { lookupSRV = oldLookup }()
29 lookupSRV = mockLookupSRV(addrs, nil, nil)
30
31 factory := func(instance string) (endpoint.Endpoint, error) {
32 if want, have := addr2instance(addr), instance; want != have {
33 t.Errorf("want %q, have %q", want, have)
34 }
35 return e, nil
36 }
37
38 p, err := NewPublisher(name, ttl, factory, logger)
39 if err != nil {
40 t.Fatal(err)
41 }
42 defer p.Stop()
43
44 if _, err := p.Endpoints(); err != nil {
45 t.Fatal(err)
46 }
47 }
48
49 func TestBadLookup(t *testing.T) {
50 oldLookup := lookupSRV
51 defer func() { lookupSRV = oldLookup }()
52 lookupSRV = mockLookupSRV([]*net.SRV{}, errors.New("kaboom"), nil)
53
54 var (
55 name = "some-name"
56 ttl = time.Second
57 factory = func(string) (endpoint.Endpoint, error) { return nil, errors.New("unreachable") }
58 logger = log.NewNopLogger()
59 )
60
61 if _, err := NewPublisher(name, ttl, factory, logger); err == nil {
62 t.Fatal("wanted error, got none")
63 }
64 }
65
66 func TestBadFactory(t *testing.T) {
67 var (
68 addr = &net.SRV{Target: "foo", Port: 1234}
69 addrs = []*net.SRV{addr}
70 name = "some-name"
71 ttl = time.Second
72 factory = func(string) (endpoint.Endpoint, error) { return nil, errors.New("kaboom") }
73 logger = log.NewNopLogger()
74 )
75
76 oldLookup := lookupSRV
77 defer func() { lookupSRV = oldLookup }()
78 lookupSRV = mockLookupSRV(addrs, nil, nil)
79
80 p, err := NewPublisher(name, ttl, factory, logger)
81 if err != nil {
82 t.Fatal(err)
83 }
84 defer p.Stop()
85
86 endpoints, err := p.Endpoints()
87 if err != nil {
88 t.Fatal(err)
89 }
90 if want, have := 0, len(endpoints); want != have {
91 t.Errorf("want %q, have %q", want, have)
92 }
93 }
94
95 func TestRefreshWithChange(t *testing.T) {
96 t.Skip("TODO")
97 }
98
99 func TestRefreshNoChange(t *testing.T) {
100 var (
101 tick = make(chan time.Time)
102 target = "my-target"
103 port = uint16(5678)
104 addr = &net.SRV{Target: target, Port: port}
105 addrs = []*net.SRV{addr}
106 name = "my-name"
107 ttl = time.Second
108 factory = func(string) (endpoint.Endpoint, error) { return nil, errors.New("kaboom") }
109 logger = log.NewNopLogger()
110 )
111
112 oldTicker := newTicker
113 defer func() { newTicker = oldTicker }()
114 newTicker = func(time.Duration) *time.Ticker { return &time.Ticker{C: tick} }
115
116 var resolves uint64
117 oldLookup := lookupSRV
118 defer func() { lookupSRV = oldLookup }()
119 lookupSRV = mockLookupSRV(addrs, nil, &resolves)
120
121 p, err := NewPublisher(name, ttl, factory, logger)
122 if err != nil {
123 t.Fatal(err)
124 }
125 defer p.Stop()
126
127 tick <- time.Now()
128 if want, have := uint64(2), resolves; want != have {
129 t.Errorf("want %d, have %d", want, have)
130 }
131 }
132
133 func TestRefreshResolveError(t *testing.T) {
134 t.Skip("TODO")
135 }
136
137 func mockLookupSRV(addrs []*net.SRV, err error, count *uint64) func(service, proto, name string) (string, []*net.SRV, error) {
138 return func(service, proto, name string) (string, []*net.SRV, error) {
139 if count != nil {
140 atomic.AddUint64(count, 1)
141 }
142 return "", addrs, err
143 }
144 }
0 package loadbalancer
1
2 import "github.com/go-kit/kit/endpoint"
3
4 // Factory is a function that converts an instance string, e.g. a host:port,
5 // to a usable endpoint. Factories are used by load balancers to lift
6 // instances returned by Publishers into endpoints. Users are expected to
7 // provide their own factory functions that assume specific transports, or can
8 // deduce transports by parsing the instance string.
9 type Factory func(instance string) (endpoint.Endpoint, error)
+0
-17
loadbalancer/load_balancer.go less more
0 package loadbalancer
1
2 import (
3 "errors"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // LoadBalancer yields endpoints one-by-one.
9 type LoadBalancer interface {
10 Count() int
11 Get() (endpoint.Endpoint, error)
12 }
13
14 // ErrNoEndpointsAvailable is given by a load balancer or strategy when no
15 // endpoints are available to be returned.
16 var ErrNoEndpointsAvailable = errors.New("no endpoints available")
0 package loadbalancer
1
2 import "errors"
3
4 // ErrNoEndpoints is returned when a load balancer (or one of its components)
5 // has no endpoints to return. In a request lifecycle, this is usually a fatal
6 // error.
7 var ErrNoEndpoints = errors.New("no endpoints available")
+0
-110
loadbalancer/publisher/dns/srv_publisher.go less more
0 package dns
1
2 import (
3 "crypto/md5"
4 "fmt"
5 "net"
6 "sort"
7 "time"
8
9 "github.com/go-kit/kit/endpoint"
10 )
11
12 // SRVPublisher implements Publisher.
13 type SRVPublisher struct {
14 subscribe chan chan<- []endpoint.Endpoint
15 unsubscribe chan chan<- []endpoint.Endpoint
16 quit chan struct{}
17 }
18
19 // NewSRVPublisher returns a publisher that resolves the SRV name every ttl,
20 // and yields endpoints constructed via the makeEndpoint factory.
21 func NewSRVPublisher(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) *SRVPublisher {
22 p := &SRVPublisher{
23 subscribe: make(chan chan<- []endpoint.Endpoint),
24 unsubscribe: make(chan chan<- []endpoint.Endpoint),
25 quit: make(chan struct{}),
26 }
27 go p.loop(name, ttl, makeEndpoint)
28 return p
29 }
30
31 // Subscribe implements Publisher.
32 func (p *SRVPublisher) Subscribe(c chan<- []endpoint.Endpoint) {
33 p.subscribe <- c
34 }
35
36 // Unsubscribe implements Publisher.
37 func (p *SRVPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) {
38 p.unsubscribe <- c
39 }
40
41 // Stop implements Publisher.
42 func (p *SRVPublisher) Stop() {
43 close(p.quit)
44 }
45
46 var newTicker = time.NewTicker
47
48 func (p *SRVPublisher) loop(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) {
49 var (
50 subscriptions = map[chan<- []endpoint.Endpoint]struct{}{}
51 addrs, md5, _ = resolve(name)
52 endpoints = convert(addrs, makeEndpoint)
53 ticker = newTicker(ttl)
54 )
55 defer ticker.Stop()
56 for {
57 select {
58 case <-ticker.C:
59 addrs, newmd5, err := resolve(name)
60 if err == nil && newmd5 != md5 {
61 endpoints = convert(addrs, makeEndpoint)
62 for c := range subscriptions {
63 c <- endpoints
64 }
65 md5 = newmd5
66 }
67
68 case c := <-p.subscribe:
69 subscriptions[c] = struct{}{}
70 c <- endpoints
71
72 case c := <-p.unsubscribe:
73 delete(subscriptions, c)
74
75 case <-p.quit:
76 return
77 }
78 }
79 }
80
81 // Allow mocking in tests.
82 var resolve = func(name string) (addrs []*net.SRV, md5sum string, err error) {
83 _, addrs, err = net.LookupSRV("", "", name)
84 if err != nil {
85 return addrs, "", err
86 }
87 hostports := make([]string, len(addrs))
88 for i, addr := range addrs {
89 hostports[i] = fmt.Sprintf("%s:%d", addr.Target, addr.Port)
90 }
91 sort.Sort(sort.StringSlice(hostports))
92 h := md5.New()
93 for _, hostport := range hostports {
94 fmt.Fprintf(h, hostport)
95 }
96 return addrs, fmt.Sprintf("%x", h.Sum(nil)), nil
97 }
98
99 func convert(addrs []*net.SRV, makeEndpoint func(hostport string) endpoint.Endpoint) []endpoint.Endpoint {
100 endpoints := make([]endpoint.Endpoint, len(addrs))
101 for i, addr := range addrs {
102 endpoints[i] = makeEndpoint(addr2hostport(addr))
103 }
104 return endpoints
105 }
106
107 func addr2hostport(addr *net.SRV) string {
108 return net.JoinHostPort(addr.Target, fmt.Sprint(addr.Port))
109 }
+0
-77
loadbalancer/publisher/dns/srv_publisher_internal_test.go less more
0 package dns
1
2 import (
3 "fmt"
4 "net"
5 "testing"
6 "time"
7
8 "golang.org/x/net/context"
9
10 "github.com/go-kit/kit/endpoint"
11 )
12
13 func TestDNSSRVPublisher(t *testing.T) {
14 // Reset the vars when we're done
15 oldResolve := resolve
16 defer func() { resolve = oldResolve }()
17 oldNewTicker := newTicker
18 defer func() { newTicker = oldNewTicker }()
19
20 // Set up a fixture and swap the vars
21 a := []*net.SRV{
22 {Target: "foo", Port: 123},
23 {Target: "bar", Port: 456},
24 {Target: "baz", Port: 789},
25 }
26 ticker := make(chan time.Time)
27 resolve = func(string) ([]*net.SRV, string, error) { return a, fmt.Sprint(len(a)), nil }
28 newTicker = func(time.Duration) *time.Ticker { return &time.Ticker{C: ticker} }
29
30 // Construct endpoint
31 m := map[string]int{}
32 e := func(hostport string) endpoint.Endpoint {
33 return func(context.Context, interface{}) (interface{}, error) {
34 m[hostport]++
35 return struct{}{}, nil
36 }
37 }
38
39 // Build the publisher
40 var (
41 name = "irrelevant"
42 ttl = time.Second
43 makeEndpoint = func(hostport string) endpoint.Endpoint { return e(hostport) }
44 )
45 p := NewSRVPublisher(name, ttl, makeEndpoint)
46 defer p.Stop()
47
48 // Subscribe
49 c := make(chan []endpoint.Endpoint, 1)
50 p.Subscribe(c)
51 defer p.Unsubscribe(c)
52
53 // Invoke all of the endpoints
54 for _, e := range <-c {
55 e(context.Background(), struct{}{})
56 }
57
58 // Make sure we invoked what we expected to
59 for _, addr := range a {
60 hostport := addr2hostport(addr)
61 if want, have := 1, m[hostport]; want != have {
62 t.Errorf("%q: want %d, have %d", name, want, have)
63 }
64 delete(m, hostport)
65 }
66 if want, have := 0, len(m); want != have {
67 t.Errorf("want %d, have %d", want, have)
68 }
69
70 // Reset the fixture, trigger the timer, count the endpoints
71 a = []*net.SRV{}
72 ticker <- time.Now()
73 if want, have := len(a), len(<-c); want != have {
74 t.Errorf("want %d, have %d", want, have)
75 }
76 }
+0
-10
loadbalancer/publisher/publisher.go less more
0 package publisher
1
2 import "github.com/go-kit/kit/endpoint"
3
4 // Publisher produces endpoints.
5 type Publisher interface {
6 Subscribe(chan<- []endpoint.Endpoint)
7 Unsubscribe(chan<- []endpoint.Endpoint)
8 Stop()
9 }
+0
-51
loadbalancer/publisher/static/publisher.go less more
0 package static
1
2 import (
3 "sync"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // Publisher holds a static set of endpoints.
9 type Publisher struct {
10 mu sync.Mutex
11 current []endpoint.Endpoint
12 subscribers map[chan<- []endpoint.Endpoint]struct{}
13 }
14
15 // NewPublisher returns a publisher that yields a static set of endpoints,
16 // which can be completely replaced.
17 func NewPublisher(endpoints []endpoint.Endpoint) *Publisher {
18 return &Publisher{
19 current: endpoints,
20 subscribers: map[chan<- []endpoint.Endpoint]struct{}{},
21 }
22 }
23
24 // Subscribe implements Publisher.
25 func (p *Publisher) Subscribe(c chan<- []endpoint.Endpoint) {
26 p.mu.Lock()
27 defer p.mu.Unlock()
28 p.subscribers[c] = struct{}{}
29 c <- p.current
30 }
31
32 // Unsubscribe implements Publisher.
33 func (p *Publisher) Unsubscribe(c chan<- []endpoint.Endpoint) {
34 p.mu.Lock()
35 defer p.mu.Unlock()
36 delete(p.subscribers, c)
37 }
38
39 // Stop implements Publisher, but is a no-op.
40 func (p *Publisher) Stop() {}
41
42 // Replace replaces the endpoints and notifies all subscribers.
43 func (p *Publisher) Replace(endpoints []endpoint.Endpoint) {
44 p.mu.Lock()
45 defer p.mu.Unlock()
46 p.current = endpoints
47 for c := range p.subscribers {
48 c <- p.current
49 }
50 }
+0
-30
loadbalancer/publisher/static/publisher_test.go less more
0 package static_test
1
2 import (
3 "testing"
4
5 "golang.org/x/net/context"
6
7 "github.com/go-kit/kit/endpoint"
8 "github.com/go-kit/kit/loadbalancer/publisher/static"
9 )
10
11 func TestStaticPublisher(t *testing.T) {
12 endpoints := []endpoint.Endpoint{
13 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
14 }
15 p := static.NewPublisher(endpoints)
16 defer p.Stop()
17
18 c := make(chan []endpoint.Endpoint, 1)
19 p.Subscribe(c)
20 if want, have := len(endpoints), len(<-c); want != have {
21 t.Errorf("want %d, have %d", want, have)
22 }
23
24 endpoints = []endpoint.Endpoint{}
25 p.Replace(endpoints)
26 if want, have := len(endpoints), len(<-c); want != have {
27 t.Errorf("want %d, have %d", want, have)
28 }
29 }
0 package loadbalancer
1
2 import "github.com/go-kit/kit/endpoint"
3
4 // Publisher describes something that provides a set of identical endpoints.
5 // Different publisher implementations exist for different kinds of service
6 // discovery systems.
7 type Publisher interface {
8 Endpoints() ([]endpoint.Endpoint, error)
9 }
0 package loadbalancer
1
2 import (
3 "math/rand"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // Random is a completely stateless load balancer that chooses a random
9 // endpoint to return each time.
10 type Random struct {
11 p Publisher
12 r *rand.Rand
13 }
14
15 // NewRandom returns a new Random load balancer.
16 func NewRandom(p Publisher, seed int64) *Random {
17 return &Random{
18 p: p,
19 r: rand.New(rand.NewSource(seed)),
20 }
21 }
22
23 // Endpoint implements the LoadBalancer interface.
24 func (r *Random) Endpoint() (endpoint.Endpoint, error) {
25 endpoints, err := r.p.Endpoints()
26 if err != nil {
27 return nil, err
28 }
29 if len(endpoints) <= 0 {
30 return nil, ErrNoEndpoints
31 }
32 return endpoints[r.r.Intn(len(endpoints))], nil
33 }
0 package loadbalancer_test
1
2 import (
3 "math"
4 "testing"
5
6 "golang.org/x/net/context"
7
8 "github.com/go-kit/kit/endpoint"
9 "github.com/go-kit/kit/loadbalancer"
10 "github.com/go-kit/kit/loadbalancer/static"
11 )
12
13 func TestRandomDistribution(t *testing.T) {
14 var (
15 n = 3
16 endpoints = make([]endpoint.Endpoint, n)
17 counts = make([]int, n)
18 seed = int64(123)
19 ctx = context.Background()
20 iterations = 100000
21 want = iterations / n
22 tolerance = want / 100 // 1%
23 )
24
25 for i := 0; i < n; i++ {
26 i0 := i
27 endpoints[i] = func(context.Context, interface{}) (interface{}, error) { counts[i0]++; return struct{}{}, nil }
28 }
29
30 lb := loadbalancer.NewRandom(static.Publisher(endpoints), seed)
31
32 for i := 0; i < iterations; i++ {
33 e, err := lb.Endpoint()
34 if err != nil {
35 t.Fatal(err)
36 }
37 e(ctx, struct{}{})
38 }
39
40 for i, have := range counts {
41 if math.Abs(float64(want-have)) > float64(tolerance) {
42 t.Errorf("%d: want %d, have %d", i, want, have)
43 }
44 }
45 }
46
47 func TestRandomBadPublisher(t *testing.T) {
48 t.Skip("TODO")
49 }
50
51 func TestRandomNoEndpoints(t *testing.T) {
52 lb := loadbalancer.NewRandom(static.Publisher([]endpoint.Endpoint{}), 123)
53 _, have := lb.Endpoint()
54 if want := loadbalancer.ErrNoEndpoints; want != have {
55 t.Errorf("want %q, have %q", want, have)
56 }
57 }
+0
-52
loadbalancer/retry.go less more
0 package loadbalancer
1
2 import (
3 "fmt"
4 "strings"
5 "time"
6
7 "golang.org/x/net/context"
8
9 "github.com/go-kit/kit/endpoint"
10 )
11
12 // Retry yields an endpoint that takes endpoints from the load balancer.
13 // Invocations that return errors will be retried until they succeed, up to
14 // max times, or until the timeout is elapsed, whichever comes first.
15 func Retry(max int, timeout time.Duration, lb LoadBalancer) endpoint.Endpoint {
16 return func(ctx context.Context, request interface{}) (interface{}, error) {
17 var (
18 newctx, cancel = context.WithTimeout(ctx, timeout)
19 responses = make(chan interface{}, 1)
20 errs = make(chan error, 1)
21 a = []string{}
22 )
23 defer cancel()
24 for i := 1; i <= max; i++ {
25 go func() {
26 e, err := lb.Get()
27 if err != nil {
28 errs <- err
29 return
30 }
31 response, err := e(newctx, request)
32 if err != nil {
33 errs <- err
34 return
35 }
36 responses <- response
37 }()
38
39 select {
40 case <-newctx.Done():
41 return nil, newctx.Err()
42 case response := <-responses:
43 return response, nil
44 case err := <-errs:
45 a = append(a, err.Error())
46 continue
47 }
48 }
49 return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; "))
50 }
51 }
+0
-67
loadbalancer/retry_test.go less more
0 package loadbalancer_test
1
2 import (
3 "errors"
4 "testing"
5 "time"
6
7 "golang.org/x/net/context"
8
9 "github.com/go-kit/kit/endpoint"
10 "github.com/go-kit/kit/loadbalancer"
11 "github.com/go-kit/kit/loadbalancer/publisher/static"
12 "github.com/go-kit/kit/loadbalancer/strategy"
13 )
14
15 func TestRetryMax(t *testing.T) {
16 var (
17 endpoints = []endpoint.Endpoint{}
18 p = static.NewPublisher(endpoints)
19 lb = strategy.RoundRobin(p)
20 )
21
22 if _, err := loadbalancer.Retry(999, time.Second, lb)(context.Background(), struct{}{}); err == nil {
23 t.Errorf("expected error, got none")
24 }
25
26 endpoints = []endpoint.Endpoint{
27 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") },
28 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") },
29 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
30 }
31 p.Replace(endpoints)
32 time.Sleep(10 * time.Millisecond) //assertLoadBalancerNotEmpty(t, lb) // TODO
33
34 if _, err := loadbalancer.Retry(len(endpoints)-1, time.Second, lb)(context.Background(), struct{}{}); err == nil {
35 t.Errorf("expected error, got none")
36 }
37
38 if _, err := loadbalancer.Retry(len(endpoints), time.Second, lb)(context.Background(), struct{}{}); err != nil {
39 t.Error(err)
40 }
41 }
42
43 func TestRetryTimeout(t *testing.T) {
44 var (
45 step = make(chan struct{})
46 e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil }
47 timeout = time.Millisecond
48 retry = loadbalancer.Retry(999, timeout, strategy.RoundRobin(static.NewPublisher([]endpoint.Endpoint{e})))
49 errs = make(chan error)
50 invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err }
51 )
52
53 go invoke() // invoke the endpoint
54 step <- struct{}{} // tell the endpoint to return
55 if err := <-errs; err != nil { // that should succeed
56 t.Error(err)
57 }
58
59 go invoke() // invoke the endpoint
60 time.Sleep(2 * timeout) // wait
61 time.Sleep(2 * timeout) // wait again (CI servers!!)
62 step <- struct{}{} // tell the endpoint to return
63 if err := <-errs; err != context.DeadlineExceeded { // that should not succeed
64 t.Errorf("wanted error, got none")
65 }
66 }
0 package loadbalancer
1
2 import (
3 "sync/atomic"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // RoundRobin is a simple load balancer that returns each of the published
9 // endpoints in sequence.
10 type RoundRobin struct {
11 p Publisher
12 counter uint64
13 }
14
15 // NewRoundRobin returns a new RoundRobin load balancer.
16 func NewRoundRobin(p Publisher) *RoundRobin {
17 return &RoundRobin{
18 p: p,
19 counter: 0,
20 }
21 }
22
23 // Endpoint implements the LoadBalancer interface.
24 func (rr *RoundRobin) Endpoint() (endpoint.Endpoint, error) {
25 endpoints, err := rr.p.Endpoints()
26 if err != nil {
27 return nil, err
28 }
29 if len(endpoints) <= 0 {
30 return nil, ErrNoEndpoints
31 }
32 var old uint64
33 for {
34 old = atomic.LoadUint64(&rr.counter)
35 if atomic.CompareAndSwapUint64(&rr.counter, old, old+1) {
36 break
37 }
38 }
39 return endpoints[old%uint64(len(endpoints))], nil
40 }
0 package loadbalancer_test
1
2 import (
3 "reflect"
4 "testing"
5
6 "github.com/go-kit/kit/endpoint"
7 "github.com/go-kit/kit/loadbalancer"
8 "github.com/go-kit/kit/loadbalancer/static"
9 "golang.org/x/net/context"
10 )
11
12 func TestRoundRobinDistribution(t *testing.T) {
13 var (
14 ctx = context.Background()
15 counts = []int{0, 0, 0}
16 endpoints = []endpoint.Endpoint{
17 func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil },
18 func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil },
19 func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil },
20 }
21 )
22
23 lb := loadbalancer.NewRoundRobin(static.Publisher(endpoints))
24
25 for i, want := range [][]int{
26 {1, 0, 0},
27 {1, 1, 0},
28 {1, 1, 1},
29 {2, 1, 1},
30 {2, 2, 1},
31 {2, 2, 2},
32 {3, 2, 2},
33 } {
34 e, err := lb.Endpoint()
35 if err != nil {
36 t.Fatal(err)
37 }
38 e(ctx, struct{}{})
39 if have := counts; !reflect.DeepEqual(want, have) {
40 t.Fatalf("%d: want %v, have %v", i, want, have)
41 }
42
43 }
44 }
45
46 func TestRoundRobinBadPublisher(t *testing.T) {
47 t.Skip("TODO")
48 }
0 package static
1
2 import "github.com/go-kit/kit/endpoint"
3
4 // Publisher yields the same set of static endpoints.
5 type Publisher []endpoint.Endpoint
6
7 // Endpoints implements the Publisher interface.
8 func (p Publisher) Endpoints() ([]endpoint.Endpoint, error) { return p, nil }
0 package static_test
1
2 import (
3 "reflect"
4 "testing"
5
6 "golang.org/x/net/context"
7
8 "github.com/go-kit/kit/endpoint"
9 "github.com/go-kit/kit/loadbalancer/static"
10 )
11
12 func TestStatic(t *testing.T) {
13 var (
14 e1 = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
15 e2 = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
16 endpoints = []endpoint.Endpoint{e1, e2}
17 )
18 p := static.Publisher(endpoints)
19 have, err := p.Endpoints()
20 if err != nil {
21 t.Fatal(err)
22 }
23 if want := endpoints; !reflect.DeepEqual(want, have) {
24 t.Fatalf("want %#+v, have %#+v", want, have)
25 }
26 }
+0
-50
loadbalancer/strategy/cache.go less more
0 package strategy
1
2 import (
3 "github.com/go-kit/kit/endpoint"
4 "github.com/go-kit/kit/loadbalancer/publisher"
5 )
6
7 type cache struct {
8 req chan []endpoint.Endpoint
9 cnt chan int
10 quit chan struct{}
11 }
12
13 func newCache(p publisher.Publisher) *cache {
14 c := &cache{
15 req: make(chan []endpoint.Endpoint),
16 cnt: make(chan int),
17 quit: make(chan struct{}),
18 }
19 go c.loop(p)
20 return c
21 }
22
23 func (c *cache) loop(p publisher.Publisher) {
24 e := make(chan []endpoint.Endpoint, 1)
25 p.Subscribe(e)
26 defer p.Unsubscribe(e)
27 endpoints := <-e
28 for {
29 select {
30 case endpoints = <-e:
31 case c.cnt <- len(endpoints):
32 case c.req <- endpoints:
33 case <-c.quit:
34 return
35 }
36 }
37 }
38
39 func (c *cache) count() int {
40 return <-c.cnt
41 }
42
43 func (c *cache) get() []endpoint.Endpoint {
44 return <-c.req
45 }
46
47 func (c *cache) stop() {
48 close(c.quit)
49 }
+0
-34
loadbalancer/strategy/cache_internal_test.go less more
0 package strategy
1
2 import (
3 "runtime"
4 "testing"
5
6 "golang.org/x/net/context"
7
8 "github.com/go-kit/kit/endpoint"
9 "github.com/go-kit/kit/loadbalancer/publisher/static"
10 )
11
12 func TestCache(t *testing.T) {
13 e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
14 endpoints := []endpoint.Endpoint{e}
15
16 p := static.NewPublisher(endpoints)
17 defer p.Stop()
18
19 c := newCache(p)
20 defer c.stop()
21
22 for _, n := range []int{2, 10, 0} {
23 endpoints = make([]endpoint.Endpoint, n)
24 for i := 0; i < n; i++ {
25 endpoints[i] = e
26 }
27 p.Replace(endpoints)
28 runtime.Gosched()
29 if want, have := len(endpoints), len(c.get()); want != have {
30 t.Errorf("want %d, have %d", want, have)
31 }
32 }
33 }
+0
-26
loadbalancer/strategy/random.go less more
0 package strategy
1
2 import (
3 "math/rand"
4
5 "github.com/go-kit/kit/endpoint"
6 "github.com/go-kit/kit/loadbalancer"
7 "github.com/go-kit/kit/loadbalancer/publisher"
8 )
9
10 // Random returns a load balancer that yields random endpoints.
11 func Random(p publisher.Publisher) loadbalancer.LoadBalancer {
12 return random{newCache(p)}
13 }
14
15 type random struct{ *cache }
16
17 func (r random) Count() int { return r.cache.count() }
18
19 func (r random) Get() (endpoint.Endpoint, error) {
20 endpoints := r.cache.get()
21 if len(endpoints) <= 0 {
22 return nil, loadbalancer.ErrNoEndpointsAvailable
23 }
24 return endpoints[rand.Intn(len(endpoints))], nil
25 }
+0
-43
loadbalancer/strategy/random_test.go less more
0 package strategy_test
1
2 import (
3 "math"
4 "testing"
5
6 "github.com/go-kit/kit/endpoint"
7 "github.com/go-kit/kit/loadbalancer/publisher/static"
8 "github.com/go-kit/kit/loadbalancer/strategy"
9 "golang.org/x/net/context"
10 )
11
12 func TestRandom(t *testing.T) {
13 p := static.NewPublisher([]endpoint.Endpoint{})
14 defer p.Stop()
15
16 lb := strategy.Random(p)
17 if _, err := lb.Get(); err == nil {
18 t.Error("want error, got none")
19 }
20
21 counts := []int{0, 0, 0}
22 p.Replace([]endpoint.Endpoint{
23 func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil },
24 func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil },
25 func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil },
26 })
27 assertLoadBalancerNotEmpty(t, lb)
28
29 n := 10000
30 for i := 0; i < n; i++ {
31 e, _ := lb.Get()
32 e(context.Background(), struct{}{})
33 }
34
35 want := float64(n) / float64(len(counts))
36 tolerance := (want / 100.0) * 5 // 5%
37 for _, have := range counts {
38 if math.Abs(want-float64(have)) > tolerance {
39 t.Errorf("want %.0f, have %d", want, have)
40 }
41 }
42 }
+0
-36
loadbalancer/strategy/round_robin.go less more
0 package strategy
1
2 import (
3 "sync/atomic"
4
5 "github.com/go-kit/kit/endpoint"
6 "github.com/go-kit/kit/loadbalancer"
7 "github.com/go-kit/kit/loadbalancer/publisher"
8 )
9
10 // RoundRobin returns a load balancer that yields endpoints in sequence.
11 func RoundRobin(p publisher.Publisher) loadbalancer.LoadBalancer {
12 return &roundRobin{newCache(p), 0}
13 }
14
15 type roundRobin struct {
16 *cache
17 uint64
18 }
19
20 func (r *roundRobin) Count() int { return r.cache.count() }
21
22 func (r *roundRobin) Get() (endpoint.Endpoint, error) {
23 endpoints := r.cache.get()
24 if len(endpoints) <= 0 {
25 return nil, loadbalancer.ErrNoEndpointsAvailable
26 }
27 var old uint64
28 for {
29 old = atomic.LoadUint64(&r.uint64)
30 if atomic.CompareAndSwapUint64(&r.uint64, old, old+1) {
31 break
32 }
33 }
34 return endpoints[old%uint64(len(endpoints))], nil
35 }
+0
-46
loadbalancer/strategy/round_robin_test.go less more
0 package strategy_test
1
2 import (
3 "reflect"
4 "testing"
5
6 "golang.org/x/net/context"
7
8 "github.com/go-kit/kit/endpoint"
9 "github.com/go-kit/kit/loadbalancer/publisher/static"
10 "github.com/go-kit/kit/loadbalancer/strategy"
11 )
12
13 func TestRoundRobin(t *testing.T) {
14 p := static.NewPublisher([]endpoint.Endpoint{})
15 defer p.Stop()
16
17 lb := strategy.RoundRobin(p)
18 if _, err := lb.Get(); err == nil {
19 t.Error("want error, got none")
20 }
21
22 counts := []int{0, 0, 0}
23 p.Replace([]endpoint.Endpoint{
24 func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil },
25 func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil },
26 func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil },
27 })
28 assertLoadBalancerNotEmpty(t, lb)
29
30 for i, want := range [][]int{
31 {1, 0, 0},
32 {1, 1, 0},
33 {1, 1, 1},
34 {2, 1, 1},
35 {2, 2, 1},
36 {2, 2, 2},
37 {3, 2, 2},
38 } {
39 e, _ := lb.Get()
40 e(context.Background(), struct{}{})
41 if have := counts; !reflect.DeepEqual(want, have) {
42 t.Errorf("%d: want %v, have %v", i+1, want, have)
43 }
44 }
45 }
+0
-9
loadbalancer/strategy/strategy.go less more
0 package strategy
1
2 import "github.com/go-kit/kit/endpoint"
3
4 // Strategy yields endpoints to consumers according to some algorithm.
5 type Strategy interface {
6 Next() (endpoint.Endpoint, error)
7 Stop()
8 }
+0
-35
loadbalancer/strategy/util_test.go less more
0 package strategy_test
1
2 import (
3 "fmt"
4 "testing"
5 "time"
6
7 "github.com/go-kit/kit/loadbalancer"
8 )
9
10 func assertLoadBalancerNotEmpty(t *testing.T, lb loadbalancer.LoadBalancer) {
11 if err := within(10*time.Millisecond, func() bool {
12 return lb.Count() > 0
13 }); err != nil {
14 t.Fatal("Publisher never updated endpoints")
15 }
16 }
17
18 func within(d time.Duration, f func() bool) error {
19 var (
20 deadline = time.After(d)
21 ticker = time.NewTicker(d / 10)
22 )
23 defer ticker.Stop()
24 for {
25 select {
26 case <-ticker.C:
27 if f() {
28 return nil
29 }
30 case <-deadline:
31 return fmt.Errorf("deadline exceeded")
32 }
33 }
34 }