Big re-org
Having Publishers return a set of endpoints directly (synchronously)
allows us to eliminate a lot of boilerplate related to pub/sub
semantics, as well as the cache type. Thanks, @rogpeppe!
Peter Bourgon
8 years ago
0 | # package loadbalancer | |
1 | ||
2 | `package loadbalancer` provides a client-side load balancer abstraction. | |
3 | ||
4 | A publisher is responsible for emitting the most recent set of endpoints for a | |
5 | single logical service. Publishers exist for static endpoints, and endpoints | |
6 | discovered via periodic DNS SRV lookups on a single logical name. Consul and | |
7 | etcd publishers are planned. | |
8 | ||
9 | Different load balancing strategies are implemented on top of publishers. Go | |
10 | kit currently provides random and round-robin semantics. Smarter behaviors, | |
11 | e.g. load balancing based on underlying endpoint priority/weight, is planned. | |
12 | ||
13 | ## Rationale | |
14 | ||
15 | TODO | |
16 | ||
17 | ## Usage | |
18 | ||
19 | In your client, define a publisher, wrap it with a balancing strategy, and pass | |
20 | it to a retry strategy, which returns an endpoint. Use that endpoint to make | |
21 | requests, or wrap it with other value-add middleware. | |
22 | ||
23 | ```go | |
24 | func main() { | |
25 | var ( | |
26 | fooPublisher = loadbalancer.NewDNSSRVPublisher("foo.mynet.local", 5*time.Second, makeEndpoint) | |
27 | fooBalancer = loadbalancer.RoundRobin(fooPublisher) | |
28 | fooEndpoint = loadbalancer.Retry(3, time.Second, fooBalancer) | |
29 | ) | |
30 | http.HandleFunc("/", handle(fooEndpoint)) | |
31 | log.Fatal(http.ListenAndServe(":8080", nil)) | |
32 | } | |
33 | ||
34 | func makeEndpoint(hostport string) endpoint.Endpoint { | |
35 | // Convert a host:port to a endpoint via your defined transport. | |
36 | } | |
37 | ||
38 | func handle(foo endpoint.Endpoint) http.HandlerFunc { | |
39 | return func(w http.ResponseWriter, r *http.Request) { | |
40 | // foo is usable as a load-balanced remote endpoint. | |
41 | } | |
42 | } | |
43 | ``` |
0 | package dnssrv | |
1 | ||
2 | import ( | |
3 | "crypto/md5" | |
4 | "fmt" | |
5 | "net" | |
6 | "sort" | |
7 | "time" | |
8 | ||
9 | "github.com/go-kit/kit/endpoint" | |
10 | "github.com/go-kit/kit/loadbalancer" | |
11 | "github.com/go-kit/kit/log" | |
12 | ) | |
13 | ||
14 | // Publisher yields endpoints taken from the named DNS SRV record. The name is | |
15 | // resolved on a fixed schedule. Priorities and weights are ignored. | |
16 | type Publisher struct { | |
17 | name string | |
18 | ttl time.Duration | |
19 | factory loadbalancer.Factory | |
20 | logger log.Logger | |
21 | endpoints chan []endpoint.Endpoint | |
22 | quit chan struct{} | |
23 | } | |
24 | ||
25 | // NewPublisher returns a DNS SRV publisher. The name is resolved | |
26 | // synchronously as part of construction; if that resolution fails, the | |
27 | // constructor will return an error. The factory is used to convert a | |
28 | // host:port to a usable endpoint. The logger is used to report DNS and | |
29 | // factory errors. | |
30 | func NewPublisher(name string, ttl time.Duration, f loadbalancer.Factory, logger log.Logger) (*Publisher, error) { | |
31 | logger = log.NewContext(logger).With("component", "DNS SRV Publisher") | |
32 | addrs, md5, err := resolve(name) | |
33 | if err != nil { | |
34 | return nil, err | |
35 | } | |
36 | p := &Publisher{ | |
37 | name: name, | |
38 | ttl: ttl, | |
39 | factory: f, | |
40 | logger: logger, | |
41 | endpoints: make(chan []endpoint.Endpoint), | |
42 | quit: make(chan struct{}), | |
43 | } | |
44 | go p.loop(lift(addrs, f, logger), md5) | |
45 | return p, nil | |
46 | } | |
47 | ||
48 | // Stop terminates the publisher. | |
49 | func (p *Publisher) Stop() { | |
50 | close(p.quit) | |
51 | } | |
52 | ||
53 | func (p *Publisher) loop(endpoints []endpoint.Endpoint, md5 string) { | |
54 | t := newTicker(p.ttl) | |
55 | defer t.Stop() | |
56 | for { | |
57 | select { | |
58 | case p.endpoints <- endpoints: | |
59 | ||
60 | case <-t.C: | |
61 | // TODO should we do this out-of-band? | |
62 | addrs, newmd5, err := resolve(p.name) | |
63 | if err != nil { | |
64 | p.logger.Log("name", p.name, "err", err) | |
65 | continue // don't replace good endpoints with bad ones | |
66 | } | |
67 | if newmd5 == md5 { | |
68 | continue // no change | |
69 | } | |
70 | endpoints = lift(addrs, p.factory, p.logger) | |
71 | md5 = newmd5 | |
72 | ||
73 | case <-p.quit: | |
74 | return | |
75 | } | |
76 | } | |
77 | } | |
78 | ||
79 | // Endpoints implements the Publisher interface. | |
80 | func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) { | |
81 | return <-p.endpoints, nil | |
82 | } | |
83 | ||
84 | var ( | |
85 | lookupSRV = net.LookupSRV | |
86 | newTicker = time.NewTicker | |
87 | ) | |
88 | ||
89 | func resolve(name string) (addrs []*net.SRV, md5sum string, err error) { | |
90 | _, addrs, err = lookupSRV("", "", name) | |
91 | if err != nil { | |
92 | return addrs, "", err | |
93 | } | |
94 | hostports := make([]string, len(addrs)) | |
95 | for i, addr := range addrs { | |
96 | hostports[i] = fmt.Sprintf("%s:%d", addr.Target, addr.Port) | |
97 | } | |
98 | sort.Sort(sort.StringSlice(hostports)) | |
99 | h := md5.New() | |
100 | for _, hostport := range hostports { | |
101 | fmt.Fprintf(h, hostport) | |
102 | } | |
103 | return addrs, fmt.Sprintf("%x", h.Sum(nil)), nil | |
104 | } | |
105 | ||
106 | func lift(addrs []*net.SRV, f loadbalancer.Factory, logger log.Logger) []endpoint.Endpoint { | |
107 | endpoints := make([]endpoint.Endpoint, 0, len(addrs)) | |
108 | for _, addr := range addrs { | |
109 | endpoint, err := f(addr2instance(addr)) | |
110 | if err != nil { | |
111 | logger.Log("instance", addr2instance(addr), "err", err) | |
112 | continue | |
113 | } | |
114 | endpoints = append(endpoints, endpoint) | |
115 | } | |
116 | return endpoints | |
117 | } | |
118 | ||
119 | func addr2instance(addr *net.SRV) string { | |
120 | return net.JoinHostPort(addr.Target, fmt.Sprint(addr.Port)) | |
121 | } |
0 | package dnssrv | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "net" | |
5 | "sync/atomic" | |
6 | "testing" | |
7 | "time" | |
8 | ||
9 | "golang.org/x/net/context" | |
10 | ||
11 | "github.com/go-kit/kit/endpoint" | |
12 | "github.com/go-kit/kit/log" | |
13 | ) | |
14 | ||
15 | func TestPublisher(t *testing.T) { | |
16 | var ( | |
17 | target = "my-target" | |
18 | port = uint16(1234) | |
19 | addr = &net.SRV{Target: target, Port: port} | |
20 | addrs = []*net.SRV{addr} | |
21 | name = "my-name" | |
22 | ttl = time.Second | |
23 | logger = log.NewNopLogger() | |
24 | e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
25 | ) | |
26 | ||
27 | oldLookup := lookupSRV | |
28 | defer func() { lookupSRV = oldLookup }() | |
29 | lookupSRV = mockLookupSRV(addrs, nil, nil) | |
30 | ||
31 | factory := func(instance string) (endpoint.Endpoint, error) { | |
32 | if want, have := addr2instance(addr), instance; want != have { | |
33 | t.Errorf("want %q, have %q", want, have) | |
34 | } | |
35 | return e, nil | |
36 | } | |
37 | ||
38 | p, err := NewPublisher(name, ttl, factory, logger) | |
39 | if err != nil { | |
40 | t.Fatal(err) | |
41 | } | |
42 | defer p.Stop() | |
43 | ||
44 | if _, err := p.Endpoints(); err != nil { | |
45 | t.Fatal(err) | |
46 | } | |
47 | } | |
48 | ||
49 | func TestBadLookup(t *testing.T) { | |
50 | oldLookup := lookupSRV | |
51 | defer func() { lookupSRV = oldLookup }() | |
52 | lookupSRV = mockLookupSRV([]*net.SRV{}, errors.New("kaboom"), nil) | |
53 | ||
54 | var ( | |
55 | name = "some-name" | |
56 | ttl = time.Second | |
57 | factory = func(string) (endpoint.Endpoint, error) { return nil, errors.New("unreachable") } | |
58 | logger = log.NewNopLogger() | |
59 | ) | |
60 | ||
61 | if _, err := NewPublisher(name, ttl, factory, logger); err == nil { | |
62 | t.Fatal("wanted error, got none") | |
63 | } | |
64 | } | |
65 | ||
66 | func TestBadFactory(t *testing.T) { | |
67 | var ( | |
68 | addr = &net.SRV{Target: "foo", Port: 1234} | |
69 | addrs = []*net.SRV{addr} | |
70 | name = "some-name" | |
71 | ttl = time.Second | |
72 | factory = func(string) (endpoint.Endpoint, error) { return nil, errors.New("kaboom") } | |
73 | logger = log.NewNopLogger() | |
74 | ) | |
75 | ||
76 | oldLookup := lookupSRV | |
77 | defer func() { lookupSRV = oldLookup }() | |
78 | lookupSRV = mockLookupSRV(addrs, nil, nil) | |
79 | ||
80 | p, err := NewPublisher(name, ttl, factory, logger) | |
81 | if err != nil { | |
82 | t.Fatal(err) | |
83 | } | |
84 | defer p.Stop() | |
85 | ||
86 | endpoints, err := p.Endpoints() | |
87 | if err != nil { | |
88 | t.Fatal(err) | |
89 | } | |
90 | if want, have := 0, len(endpoints); want != have { | |
91 | t.Errorf("want %q, have %q", want, have) | |
92 | } | |
93 | } | |
94 | ||
95 | func TestRefreshWithChange(t *testing.T) { | |
96 | t.Skip("TODO") | |
97 | } | |
98 | ||
99 | func TestRefreshNoChange(t *testing.T) { | |
100 | var ( | |
101 | tick = make(chan time.Time) | |
102 | target = "my-target" | |
103 | port = uint16(5678) | |
104 | addr = &net.SRV{Target: target, Port: port} | |
105 | addrs = []*net.SRV{addr} | |
106 | name = "my-name" | |
107 | ttl = time.Second | |
108 | factory = func(string) (endpoint.Endpoint, error) { return nil, errors.New("kaboom") } | |
109 | logger = log.NewNopLogger() | |
110 | ) | |
111 | ||
112 | oldTicker := newTicker | |
113 | defer func() { newTicker = oldTicker }() | |
114 | newTicker = func(time.Duration) *time.Ticker { return &time.Ticker{C: tick} } | |
115 | ||
116 | var resolves uint64 | |
117 | oldLookup := lookupSRV | |
118 | defer func() { lookupSRV = oldLookup }() | |
119 | lookupSRV = mockLookupSRV(addrs, nil, &resolves) | |
120 | ||
121 | p, err := NewPublisher(name, ttl, factory, logger) | |
122 | if err != nil { | |
123 | t.Fatal(err) | |
124 | } | |
125 | defer p.Stop() | |
126 | ||
127 | tick <- time.Now() | |
128 | if want, have := uint64(2), resolves; want != have { | |
129 | t.Errorf("want %d, have %d", want, have) | |
130 | } | |
131 | } | |
132 | ||
133 | func TestRefreshResolveError(t *testing.T) { | |
134 | t.Skip("TODO") | |
135 | } | |
136 | ||
137 | func mockLookupSRV(addrs []*net.SRV, err error, count *uint64) func(service, proto, name string) (string, []*net.SRV, error) { | |
138 | return func(service, proto, name string) (string, []*net.SRV, error) { | |
139 | if count != nil { | |
140 | atomic.AddUint64(count, 1) | |
141 | } | |
142 | return "", addrs, err | |
143 | } | |
144 | } |
0 | package loadbalancer | |
1 | ||
2 | import "github.com/go-kit/kit/endpoint" | |
3 | ||
4 | // Factory is a function that converts an instance string, e.g. a host:port, | |
5 | // to a usable endpoint. Factories are used by load balancers to lift | |
6 | // instances returned by Publishers into endpoints. Users are expected to | |
7 | // provide their own factory functions that assume specific transports, or can | |
8 | // deduce transports by parsing the instance string. | |
9 | type Factory func(instance string) (endpoint.Endpoint, error) |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
8 | // LoadBalancer yields endpoints one-by-one. | |
9 | type LoadBalancer interface { | |
10 | Count() int | |
11 | Get() (endpoint.Endpoint, error) | |
12 | } | |
13 | ||
14 | // ErrNoEndpointsAvailable is given by a load balancer or strategy when no | |
15 | // endpoints are available to be returned. | |
16 | var ErrNoEndpointsAvailable = errors.New("no endpoints available") |
0 | package loadbalancer | |
1 | ||
2 | import "errors" | |
3 | ||
4 | // ErrNoEndpoints is returned when a load balancer (or one of its components) | |
5 | // has no endpoints to return. In a request lifecycle, this is usually a fatal | |
6 | // error. | |
7 | var ErrNoEndpoints = errors.New("no endpoints available") |
0 | package dns | |
1 | ||
2 | import ( | |
3 | "crypto/md5" | |
4 | "fmt" | |
5 | "net" | |
6 | "sort" | |
7 | "time" | |
8 | ||
9 | "github.com/go-kit/kit/endpoint" | |
10 | ) | |
11 | ||
12 | // SRVPublisher implements Publisher. | |
13 | type SRVPublisher struct { | |
14 | subscribe chan chan<- []endpoint.Endpoint | |
15 | unsubscribe chan chan<- []endpoint.Endpoint | |
16 | quit chan struct{} | |
17 | } | |
18 | ||
19 | // NewSRVPublisher returns a publisher that resolves the SRV name every ttl, | |
20 | // and yields endpoints constructed via the makeEndpoint factory. | |
21 | func NewSRVPublisher(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) *SRVPublisher { | |
22 | p := &SRVPublisher{ | |
23 | subscribe: make(chan chan<- []endpoint.Endpoint), | |
24 | unsubscribe: make(chan chan<- []endpoint.Endpoint), | |
25 | quit: make(chan struct{}), | |
26 | } | |
27 | go p.loop(name, ttl, makeEndpoint) | |
28 | return p | |
29 | } | |
30 | ||
31 | // Subscribe implements Publisher. | |
32 | func (p *SRVPublisher) Subscribe(c chan<- []endpoint.Endpoint) { | |
33 | p.subscribe <- c | |
34 | } | |
35 | ||
36 | // Unsubscribe implements Publisher. | |
37 | func (p *SRVPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) { | |
38 | p.unsubscribe <- c | |
39 | } | |
40 | ||
41 | // Stop implements Publisher. | |
42 | func (p *SRVPublisher) Stop() { | |
43 | close(p.quit) | |
44 | } | |
45 | ||
46 | var newTicker = time.NewTicker | |
47 | ||
48 | func (p *SRVPublisher) loop(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) { | |
49 | var ( | |
50 | subscriptions = map[chan<- []endpoint.Endpoint]struct{}{} | |
51 | addrs, md5, _ = resolve(name) | |
52 | endpoints = convert(addrs, makeEndpoint) | |
53 | ticker = newTicker(ttl) | |
54 | ) | |
55 | defer ticker.Stop() | |
56 | for { | |
57 | select { | |
58 | case <-ticker.C: | |
59 | addrs, newmd5, err := resolve(name) | |
60 | if err == nil && newmd5 != md5 { | |
61 | endpoints = convert(addrs, makeEndpoint) | |
62 | for c := range subscriptions { | |
63 | c <- endpoints | |
64 | } | |
65 | md5 = newmd5 | |
66 | } | |
67 | ||
68 | case c := <-p.subscribe: | |
69 | subscriptions[c] = struct{}{} | |
70 | c <- endpoints | |
71 | ||
72 | case c := <-p.unsubscribe: | |
73 | delete(subscriptions, c) | |
74 | ||
75 | case <-p.quit: | |
76 | return | |
77 | } | |
78 | } | |
79 | } | |
80 | ||
81 | // Allow mocking in tests. | |
82 | var resolve = func(name string) (addrs []*net.SRV, md5sum string, err error) { | |
83 | _, addrs, err = net.LookupSRV("", "", name) | |
84 | if err != nil { | |
85 | return addrs, "", err | |
86 | } | |
87 | hostports := make([]string, len(addrs)) | |
88 | for i, addr := range addrs { | |
89 | hostports[i] = fmt.Sprintf("%s:%d", addr.Target, addr.Port) | |
90 | } | |
91 | sort.Sort(sort.StringSlice(hostports)) | |
92 | h := md5.New() | |
93 | for _, hostport := range hostports { | |
94 | fmt.Fprintf(h, hostport) | |
95 | } | |
96 | return addrs, fmt.Sprintf("%x", h.Sum(nil)), nil | |
97 | } | |
98 | ||
99 | func convert(addrs []*net.SRV, makeEndpoint func(hostport string) endpoint.Endpoint) []endpoint.Endpoint { | |
100 | endpoints := make([]endpoint.Endpoint, len(addrs)) | |
101 | for i, addr := range addrs { | |
102 | endpoints[i] = makeEndpoint(addr2hostport(addr)) | |
103 | } | |
104 | return endpoints | |
105 | } | |
106 | ||
107 | func addr2hostport(addr *net.SRV) string { | |
108 | return net.JoinHostPort(addr.Target, fmt.Sprint(addr.Port)) | |
109 | } |
0 | package dns | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "net" | |
5 | "testing" | |
6 | "time" | |
7 | ||
8 | "golang.org/x/net/context" | |
9 | ||
10 | "github.com/go-kit/kit/endpoint" | |
11 | ) | |
12 | ||
13 | func TestDNSSRVPublisher(t *testing.T) { | |
14 | // Reset the vars when we're done | |
15 | oldResolve := resolve | |
16 | defer func() { resolve = oldResolve }() | |
17 | oldNewTicker := newTicker | |
18 | defer func() { newTicker = oldNewTicker }() | |
19 | ||
20 | // Set up a fixture and swap the vars | |
21 | a := []*net.SRV{ | |
22 | {Target: "foo", Port: 123}, | |
23 | {Target: "bar", Port: 456}, | |
24 | {Target: "baz", Port: 789}, | |
25 | } | |
26 | ticker := make(chan time.Time) | |
27 | resolve = func(string) ([]*net.SRV, string, error) { return a, fmt.Sprint(len(a)), nil } | |
28 | newTicker = func(time.Duration) *time.Ticker { return &time.Ticker{C: ticker} } | |
29 | ||
30 | // Construct endpoint | |
31 | m := map[string]int{} | |
32 | e := func(hostport string) endpoint.Endpoint { | |
33 | return func(context.Context, interface{}) (interface{}, error) { | |
34 | m[hostport]++ | |
35 | return struct{}{}, nil | |
36 | } | |
37 | } | |
38 | ||
39 | // Build the publisher | |
40 | var ( | |
41 | name = "irrelevant" | |
42 | ttl = time.Second | |
43 | makeEndpoint = func(hostport string) endpoint.Endpoint { return e(hostport) } | |
44 | ) | |
45 | p := NewSRVPublisher(name, ttl, makeEndpoint) | |
46 | defer p.Stop() | |
47 | ||
48 | // Subscribe | |
49 | c := make(chan []endpoint.Endpoint, 1) | |
50 | p.Subscribe(c) | |
51 | defer p.Unsubscribe(c) | |
52 | ||
53 | // Invoke all of the endpoints | |
54 | for _, e := range <-c { | |
55 | e(context.Background(), struct{}{}) | |
56 | } | |
57 | ||
58 | // Make sure we invoked what we expected to | |
59 | for _, addr := range a { | |
60 | hostport := addr2hostport(addr) | |
61 | if want, have := 1, m[hostport]; want != have { | |
62 | t.Errorf("%q: want %d, have %d", name, want, have) | |
63 | } | |
64 | delete(m, hostport) | |
65 | } | |
66 | if want, have := 0, len(m); want != have { | |
67 | t.Errorf("want %d, have %d", want, have) | |
68 | } | |
69 | ||
70 | // Reset the fixture, trigger the timer, count the endpoints | |
71 | a = []*net.SRV{} | |
72 | ticker <- time.Now() | |
73 | if want, have := len(a), len(<-c); want != have { | |
74 | t.Errorf("want %d, have %d", want, have) | |
75 | } | |
76 | } |
0 | package publisher | |
1 | ||
2 | import "github.com/go-kit/kit/endpoint" | |
3 | ||
4 | // Publisher produces endpoints. | |
5 | type Publisher interface { | |
6 | Subscribe(chan<- []endpoint.Endpoint) | |
7 | Unsubscribe(chan<- []endpoint.Endpoint) | |
8 | Stop() | |
9 | } |
0 | package static | |
1 | ||
2 | import ( | |
3 | "sync" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
8 | // Publisher holds a static set of endpoints. | |
9 | type Publisher struct { | |
10 | mu sync.Mutex | |
11 | current []endpoint.Endpoint | |
12 | subscribers map[chan<- []endpoint.Endpoint]struct{} | |
13 | } | |
14 | ||
15 | // NewPublisher returns a publisher that yields a static set of endpoints, | |
16 | // which can be completely replaced. | |
17 | func NewPublisher(endpoints []endpoint.Endpoint) *Publisher { | |
18 | return &Publisher{ | |
19 | current: endpoints, | |
20 | subscribers: map[chan<- []endpoint.Endpoint]struct{}{}, | |
21 | } | |
22 | } | |
23 | ||
24 | // Subscribe implements Publisher. | |
25 | func (p *Publisher) Subscribe(c chan<- []endpoint.Endpoint) { | |
26 | p.mu.Lock() | |
27 | defer p.mu.Unlock() | |
28 | p.subscribers[c] = struct{}{} | |
29 | c <- p.current | |
30 | } | |
31 | ||
32 | // Unsubscribe implements Publisher. | |
33 | func (p *Publisher) Unsubscribe(c chan<- []endpoint.Endpoint) { | |
34 | p.mu.Lock() | |
35 | defer p.mu.Unlock() | |
36 | delete(p.subscribers, c) | |
37 | } | |
38 | ||
39 | // Stop implements Publisher, but is a no-op. | |
40 | func (p *Publisher) Stop() {} | |
41 | ||
42 | // Replace replaces the endpoints and notifies all subscribers. | |
43 | func (p *Publisher) Replace(endpoints []endpoint.Endpoint) { | |
44 | p.mu.Lock() | |
45 | defer p.mu.Unlock() | |
46 | p.current = endpoints | |
47 | for c := range p.subscribers { | |
48 | c <- p.current | |
49 | } | |
50 | } |
0 | package static_test | |
1 | ||
2 | import ( | |
3 | "testing" | |
4 | ||
5 | "golang.org/x/net/context" | |
6 | ||
7 | "github.com/go-kit/kit/endpoint" | |
8 | "github.com/go-kit/kit/loadbalancer/publisher/static" | |
9 | ) | |
10 | ||
11 | func TestStaticPublisher(t *testing.T) { | |
12 | endpoints := []endpoint.Endpoint{ | |
13 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, | |
14 | } | |
15 | p := static.NewPublisher(endpoints) | |
16 | defer p.Stop() | |
17 | ||
18 | c := make(chan []endpoint.Endpoint, 1) | |
19 | p.Subscribe(c) | |
20 | if want, have := len(endpoints), len(<-c); want != have { | |
21 | t.Errorf("want %d, have %d", want, have) | |
22 | } | |
23 | ||
24 | endpoints = []endpoint.Endpoint{} | |
25 | p.Replace(endpoints) | |
26 | if want, have := len(endpoints), len(<-c); want != have { | |
27 | t.Errorf("want %d, have %d", want, have) | |
28 | } | |
29 | } |
0 | package loadbalancer | |
1 | ||
2 | import "github.com/go-kit/kit/endpoint" | |
3 | ||
4 | // Publisher describes something that provides a set of identical endpoints. | |
5 | // Different publisher implementations exist for different kinds of service | |
6 | // discovery systems. | |
7 | type Publisher interface { | |
8 | Endpoints() ([]endpoint.Endpoint, error) | |
9 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "math/rand" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
8 | // Random is a completely stateless load balancer that chooses a random | |
9 | // endpoint to return each time. | |
10 | type Random struct { | |
11 | p Publisher | |
12 | r *rand.Rand | |
13 | } | |
14 | ||
15 | // NewRandom returns a new Random load balancer. | |
16 | func NewRandom(p Publisher, seed int64) *Random { | |
17 | return &Random{ | |
18 | p: p, | |
19 | r: rand.New(rand.NewSource(seed)), | |
20 | } | |
21 | } | |
22 | ||
23 | // Endpoint implements the LoadBalancer interface. | |
24 | func (r *Random) Endpoint() (endpoint.Endpoint, error) { | |
25 | endpoints, err := r.p.Endpoints() | |
26 | if err != nil { | |
27 | return nil, err | |
28 | } | |
29 | if len(endpoints) <= 0 { | |
30 | return nil, ErrNoEndpoints | |
31 | } | |
32 | return endpoints[r.r.Intn(len(endpoints))], nil | |
33 | } |
0 | package loadbalancer_test | |
1 | ||
2 | import ( | |
3 | "math" | |
4 | "testing" | |
5 | ||
6 | "golang.org/x/net/context" | |
7 | ||
8 | "github.com/go-kit/kit/endpoint" | |
9 | "github.com/go-kit/kit/loadbalancer" | |
10 | "github.com/go-kit/kit/loadbalancer/static" | |
11 | ) | |
12 | ||
13 | func TestRandomDistribution(t *testing.T) { | |
14 | var ( | |
15 | n = 3 | |
16 | endpoints = make([]endpoint.Endpoint, n) | |
17 | counts = make([]int, n) | |
18 | seed = int64(123) | |
19 | ctx = context.Background() | |
20 | iterations = 100000 | |
21 | want = iterations / n | |
22 | tolerance = want / 100 // 1% | |
23 | ) | |
24 | ||
25 | for i := 0; i < n; i++ { | |
26 | i0 := i | |
27 | endpoints[i] = func(context.Context, interface{}) (interface{}, error) { counts[i0]++; return struct{}{}, nil } | |
28 | } | |
29 | ||
30 | lb := loadbalancer.NewRandom(static.Publisher(endpoints), seed) | |
31 | ||
32 | for i := 0; i < iterations; i++ { | |
33 | e, err := lb.Endpoint() | |
34 | if err != nil { | |
35 | t.Fatal(err) | |
36 | } | |
37 | e(ctx, struct{}{}) | |
38 | } | |
39 | ||
40 | for i, have := range counts { | |
41 | if math.Abs(float64(want-have)) > float64(tolerance) { | |
42 | t.Errorf("%d: want %d, have %d", i, want, have) | |
43 | } | |
44 | } | |
45 | } | |
46 | ||
47 | func TestRandomBadPublisher(t *testing.T) { | |
48 | t.Skip("TODO") | |
49 | } | |
50 | ||
51 | func TestRandomNoEndpoints(t *testing.T) { | |
52 | lb := loadbalancer.NewRandom(static.Publisher([]endpoint.Endpoint{}), 123) | |
53 | _, have := lb.Endpoint() | |
54 | if want := loadbalancer.ErrNoEndpoints; want != have { | |
55 | t.Errorf("want %q, have %q", want, have) | |
56 | } | |
57 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "strings" | |
5 | "time" | |
6 | ||
7 | "golang.org/x/net/context" | |
8 | ||
9 | "github.com/go-kit/kit/endpoint" | |
10 | ) | |
11 | ||
12 | // Retry yields an endpoint that takes endpoints from the load balancer. | |
13 | // Invocations that return errors will be retried until they succeed, up to | |
14 | // max times, or until the timeout is elapsed, whichever comes first. | |
15 | func Retry(max int, timeout time.Duration, lb LoadBalancer) endpoint.Endpoint { | |
16 | return func(ctx context.Context, request interface{}) (interface{}, error) { | |
17 | var ( | |
18 | newctx, cancel = context.WithTimeout(ctx, timeout) | |
19 | responses = make(chan interface{}, 1) | |
20 | errs = make(chan error, 1) | |
21 | a = []string{} | |
22 | ) | |
23 | defer cancel() | |
24 | for i := 1; i <= max; i++ { | |
25 | go func() { | |
26 | e, err := lb.Get() | |
27 | if err != nil { | |
28 | errs <- err | |
29 | return | |
30 | } | |
31 | response, err := e(newctx, request) | |
32 | if err != nil { | |
33 | errs <- err | |
34 | return | |
35 | } | |
36 | responses <- response | |
37 | }() | |
38 | ||
39 | select { | |
40 | case <-newctx.Done(): | |
41 | return nil, newctx.Err() | |
42 | case response := <-responses: | |
43 | return response, nil | |
44 | case err := <-errs: | |
45 | a = append(a, err.Error()) | |
46 | continue | |
47 | } | |
48 | } | |
49 | return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; ")) | |
50 | } | |
51 | } |
0 | package loadbalancer_test | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "testing" | |
5 | "time" | |
6 | ||
7 | "golang.org/x/net/context" | |
8 | ||
9 | "github.com/go-kit/kit/endpoint" | |
10 | "github.com/go-kit/kit/loadbalancer" | |
11 | "github.com/go-kit/kit/loadbalancer/publisher/static" | |
12 | "github.com/go-kit/kit/loadbalancer/strategy" | |
13 | ) | |
14 | ||
15 | func TestRetryMax(t *testing.T) { | |
16 | var ( | |
17 | endpoints = []endpoint.Endpoint{} | |
18 | p = static.NewPublisher(endpoints) | |
19 | lb = strategy.RoundRobin(p) | |
20 | ) | |
21 | ||
22 | if _, err := loadbalancer.Retry(999, time.Second, lb)(context.Background(), struct{}{}); err == nil { | |
23 | t.Errorf("expected error, got none") | |
24 | } | |
25 | ||
26 | endpoints = []endpoint.Endpoint{ | |
27 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") }, | |
28 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") }, | |
29 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ }, | |
30 | } | |
31 | p.Replace(endpoints) | |
32 | time.Sleep(10 * time.Millisecond) //assertLoadBalancerNotEmpty(t, lb) // TODO | |
33 | ||
34 | if _, err := loadbalancer.Retry(len(endpoints)-1, time.Second, lb)(context.Background(), struct{}{}); err == nil { | |
35 | t.Errorf("expected error, got none") | |
36 | } | |
37 | ||
38 | if _, err := loadbalancer.Retry(len(endpoints), time.Second, lb)(context.Background(), struct{}{}); err != nil { | |
39 | t.Error(err) | |
40 | } | |
41 | } | |
42 | ||
43 | func TestRetryTimeout(t *testing.T) { | |
44 | var ( | |
45 | step = make(chan struct{}) | |
46 | e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil } | |
47 | timeout = time.Millisecond | |
48 | retry = loadbalancer.Retry(999, timeout, strategy.RoundRobin(static.NewPublisher([]endpoint.Endpoint{e}))) | |
49 | errs = make(chan error) | |
50 | invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err } | |
51 | ) | |
52 | ||
53 | go invoke() // invoke the endpoint | |
54 | step <- struct{}{} // tell the endpoint to return | |
55 | if err := <-errs; err != nil { // that should succeed | |
56 | t.Error(err) | |
57 | } | |
58 | ||
59 | go invoke() // invoke the endpoint | |
60 | time.Sleep(2 * timeout) // wait | |
61 | time.Sleep(2 * timeout) // wait again (CI servers!!) | |
62 | step <- struct{}{} // tell the endpoint to return | |
63 | if err := <-errs; err != context.DeadlineExceeded { // that should not succeed | |
64 | t.Errorf("wanted error, got none") | |
65 | } | |
66 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "sync/atomic" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
8 | // RoundRobin is a simple load balancer that returns each of the published | |
9 | // endpoints in sequence. | |
10 | type RoundRobin struct { | |
11 | p Publisher | |
12 | counter uint64 | |
13 | } | |
14 | ||
15 | // NewRoundRobin returns a new RoundRobin load balancer. | |
16 | func NewRoundRobin(p Publisher) *RoundRobin { | |
17 | return &RoundRobin{ | |
18 | p: p, | |
19 | counter: 0, | |
20 | } | |
21 | } | |
22 | ||
23 | // Endpoint implements the LoadBalancer interface. | |
24 | func (rr *RoundRobin) Endpoint() (endpoint.Endpoint, error) { | |
25 | endpoints, err := rr.p.Endpoints() | |
26 | if err != nil { | |
27 | return nil, err | |
28 | } | |
29 | if len(endpoints) <= 0 { | |
30 | return nil, ErrNoEndpoints | |
31 | } | |
32 | var old uint64 | |
33 | for { | |
34 | old = atomic.LoadUint64(&rr.counter) | |
35 | if atomic.CompareAndSwapUint64(&rr.counter, old, old+1) { | |
36 | break | |
37 | } | |
38 | } | |
39 | return endpoints[old%uint64(len(endpoints))], nil | |
40 | } |
0 | package loadbalancer_test | |
1 | ||
2 | import ( | |
3 | "reflect" | |
4 | "testing" | |
5 | ||
6 | "github.com/go-kit/kit/endpoint" | |
7 | "github.com/go-kit/kit/loadbalancer" | |
8 | "github.com/go-kit/kit/loadbalancer/static" | |
9 | "golang.org/x/net/context" | |
10 | ) | |
11 | ||
12 | func TestRoundRobinDistribution(t *testing.T) { | |
13 | var ( | |
14 | ctx = context.Background() | |
15 | counts = []int{0, 0, 0} | |
16 | endpoints = []endpoint.Endpoint{ | |
17 | func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil }, | |
18 | func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil }, | |
19 | func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil }, | |
20 | } | |
21 | ) | |
22 | ||
23 | lb := loadbalancer.NewRoundRobin(static.Publisher(endpoints)) | |
24 | ||
25 | for i, want := range [][]int{ | |
26 | {1, 0, 0}, | |
27 | {1, 1, 0}, | |
28 | {1, 1, 1}, | |
29 | {2, 1, 1}, | |
30 | {2, 2, 1}, | |
31 | {2, 2, 2}, | |
32 | {3, 2, 2}, | |
33 | } { | |
34 | e, err := lb.Endpoint() | |
35 | if err != nil { | |
36 | t.Fatal(err) | |
37 | } | |
38 | e(ctx, struct{}{}) | |
39 | if have := counts; !reflect.DeepEqual(want, have) { | |
40 | t.Fatalf("%d: want %v, have %v", i, want, have) | |
41 | } | |
42 | ||
43 | } | |
44 | } | |
45 | ||
46 | func TestRoundRobinBadPublisher(t *testing.T) { | |
47 | t.Skip("TODO") | |
48 | } |
0 | package static | |
1 | ||
2 | import "github.com/go-kit/kit/endpoint" | |
3 | ||
4 | // Publisher yields the same set of static endpoints. | |
5 | type Publisher []endpoint.Endpoint | |
6 | ||
7 | // Endpoints implements the Publisher interface. | |
8 | func (p Publisher) Endpoints() ([]endpoint.Endpoint, error) { return p, nil } |
0 | package static_test | |
1 | ||
2 | import ( | |
3 | "reflect" | |
4 | "testing" | |
5 | ||
6 | "golang.org/x/net/context" | |
7 | ||
8 | "github.com/go-kit/kit/endpoint" | |
9 | "github.com/go-kit/kit/loadbalancer/static" | |
10 | ) | |
11 | ||
12 | func TestStatic(t *testing.T) { | |
13 | var ( | |
14 | e1 = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
15 | e2 = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
16 | endpoints = []endpoint.Endpoint{e1, e2} | |
17 | ) | |
18 | p := static.Publisher(endpoints) | |
19 | have, err := p.Endpoints() | |
20 | if err != nil { | |
21 | t.Fatal(err) | |
22 | } | |
23 | if want := endpoints; !reflect.DeepEqual(want, have) { | |
24 | t.Fatalf("want %#+v, have %#+v", want, have) | |
25 | } | |
26 | } |
0 | package strategy | |
1 | ||
2 | import ( | |
3 | "github.com/go-kit/kit/endpoint" | |
4 | "github.com/go-kit/kit/loadbalancer/publisher" | |
5 | ) | |
6 | ||
7 | type cache struct { | |
8 | req chan []endpoint.Endpoint | |
9 | cnt chan int | |
10 | quit chan struct{} | |
11 | } | |
12 | ||
13 | func newCache(p publisher.Publisher) *cache { | |
14 | c := &cache{ | |
15 | req: make(chan []endpoint.Endpoint), | |
16 | cnt: make(chan int), | |
17 | quit: make(chan struct{}), | |
18 | } | |
19 | go c.loop(p) | |
20 | return c | |
21 | } | |
22 | ||
23 | func (c *cache) loop(p publisher.Publisher) { | |
24 | e := make(chan []endpoint.Endpoint, 1) | |
25 | p.Subscribe(e) | |
26 | defer p.Unsubscribe(e) | |
27 | endpoints := <-e | |
28 | for { | |
29 | select { | |
30 | case endpoints = <-e: | |
31 | case c.cnt <- len(endpoints): | |
32 | case c.req <- endpoints: | |
33 | case <-c.quit: | |
34 | return | |
35 | } | |
36 | } | |
37 | } | |
38 | ||
39 | func (c *cache) count() int { | |
40 | return <-c.cnt | |
41 | } | |
42 | ||
43 | func (c *cache) get() []endpoint.Endpoint { | |
44 | return <-c.req | |
45 | } | |
46 | ||
47 | func (c *cache) stop() { | |
48 | close(c.quit) | |
49 | } |
0 | package strategy | |
1 | ||
2 | import ( | |
3 | "runtime" | |
4 | "testing" | |
5 | ||
6 | "golang.org/x/net/context" | |
7 | ||
8 | "github.com/go-kit/kit/endpoint" | |
9 | "github.com/go-kit/kit/loadbalancer/publisher/static" | |
10 | ) | |
11 | ||
12 | func TestCache(t *testing.T) { | |
13 | e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
14 | endpoints := []endpoint.Endpoint{e} | |
15 | ||
16 | p := static.NewPublisher(endpoints) | |
17 | defer p.Stop() | |
18 | ||
19 | c := newCache(p) | |
20 | defer c.stop() | |
21 | ||
22 | for _, n := range []int{2, 10, 0} { | |
23 | endpoints = make([]endpoint.Endpoint, n) | |
24 | for i := 0; i < n; i++ { | |
25 | endpoints[i] = e | |
26 | } | |
27 | p.Replace(endpoints) | |
28 | runtime.Gosched() | |
29 | if want, have := len(endpoints), len(c.get()); want != have { | |
30 | t.Errorf("want %d, have %d", want, have) | |
31 | } | |
32 | } | |
33 | } |
0 | package strategy | |
1 | ||
2 | import ( | |
3 | "math/rand" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | "github.com/go-kit/kit/loadbalancer" | |
7 | "github.com/go-kit/kit/loadbalancer/publisher" | |
8 | ) | |
9 | ||
10 | // Random returns a load balancer that yields random endpoints. | |
11 | func Random(p publisher.Publisher) loadbalancer.LoadBalancer { | |
12 | return random{newCache(p)} | |
13 | } | |
14 | ||
15 | type random struct{ *cache } | |
16 | ||
17 | func (r random) Count() int { return r.cache.count() } | |
18 | ||
19 | func (r random) Get() (endpoint.Endpoint, error) { | |
20 | endpoints := r.cache.get() | |
21 | if len(endpoints) <= 0 { | |
22 | return nil, loadbalancer.ErrNoEndpointsAvailable | |
23 | } | |
24 | return endpoints[rand.Intn(len(endpoints))], nil | |
25 | } |
0 | package strategy_test | |
1 | ||
2 | import ( | |
3 | "math" | |
4 | "testing" | |
5 | ||
6 | "github.com/go-kit/kit/endpoint" | |
7 | "github.com/go-kit/kit/loadbalancer/publisher/static" | |
8 | "github.com/go-kit/kit/loadbalancer/strategy" | |
9 | "golang.org/x/net/context" | |
10 | ) | |
11 | ||
12 | func TestRandom(t *testing.T) { | |
13 | p := static.NewPublisher([]endpoint.Endpoint{}) | |
14 | defer p.Stop() | |
15 | ||
16 | lb := strategy.Random(p) | |
17 | if _, err := lb.Get(); err == nil { | |
18 | t.Error("want error, got none") | |
19 | } | |
20 | ||
21 | counts := []int{0, 0, 0} | |
22 | p.Replace([]endpoint.Endpoint{ | |
23 | func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil }, | |
24 | func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil }, | |
25 | func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil }, | |
26 | }) | |
27 | assertLoadBalancerNotEmpty(t, lb) | |
28 | ||
29 | n := 10000 | |
30 | for i := 0; i < n; i++ { | |
31 | e, _ := lb.Get() | |
32 | e(context.Background(), struct{}{}) | |
33 | } | |
34 | ||
35 | want := float64(n) / float64(len(counts)) | |
36 | tolerance := (want / 100.0) * 5 // 5% | |
37 | for _, have := range counts { | |
38 | if math.Abs(want-float64(have)) > tolerance { | |
39 | t.Errorf("want %.0f, have %d", want, have) | |
40 | } | |
41 | } | |
42 | } |
0 | package strategy | |
1 | ||
2 | import ( | |
3 | "sync/atomic" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | "github.com/go-kit/kit/loadbalancer" | |
7 | "github.com/go-kit/kit/loadbalancer/publisher" | |
8 | ) | |
9 | ||
10 | // RoundRobin returns a load balancer that yields endpoints in sequence. | |
11 | func RoundRobin(p publisher.Publisher) loadbalancer.LoadBalancer { | |
12 | return &roundRobin{newCache(p), 0} | |
13 | } | |
14 | ||
15 | type roundRobin struct { | |
16 | *cache | |
17 | uint64 | |
18 | } | |
19 | ||
20 | func (r *roundRobin) Count() int { return r.cache.count() } | |
21 | ||
22 | func (r *roundRobin) Get() (endpoint.Endpoint, error) { | |
23 | endpoints := r.cache.get() | |
24 | if len(endpoints) <= 0 { | |
25 | return nil, loadbalancer.ErrNoEndpointsAvailable | |
26 | } | |
27 | var old uint64 | |
28 | for { | |
29 | old = atomic.LoadUint64(&r.uint64) | |
30 | if atomic.CompareAndSwapUint64(&r.uint64, old, old+1) { | |
31 | break | |
32 | } | |
33 | } | |
34 | return endpoints[old%uint64(len(endpoints))], nil | |
35 | } |
0 | package strategy_test | |
1 | ||
2 | import ( | |
3 | "reflect" | |
4 | "testing" | |
5 | ||
6 | "golang.org/x/net/context" | |
7 | ||
8 | "github.com/go-kit/kit/endpoint" | |
9 | "github.com/go-kit/kit/loadbalancer/publisher/static" | |
10 | "github.com/go-kit/kit/loadbalancer/strategy" | |
11 | ) | |
12 | ||
13 | func TestRoundRobin(t *testing.T) { | |
14 | p := static.NewPublisher([]endpoint.Endpoint{}) | |
15 | defer p.Stop() | |
16 | ||
17 | lb := strategy.RoundRobin(p) | |
18 | if _, err := lb.Get(); err == nil { | |
19 | t.Error("want error, got none") | |
20 | } | |
21 | ||
22 | counts := []int{0, 0, 0} | |
23 | p.Replace([]endpoint.Endpoint{ | |
24 | func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil }, | |
25 | func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil }, | |
26 | func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil }, | |
27 | }) | |
28 | assertLoadBalancerNotEmpty(t, lb) | |
29 | ||
30 | for i, want := range [][]int{ | |
31 | {1, 0, 0}, | |
32 | {1, 1, 0}, | |
33 | {1, 1, 1}, | |
34 | {2, 1, 1}, | |
35 | {2, 2, 1}, | |
36 | {2, 2, 2}, | |
37 | {3, 2, 2}, | |
38 | } { | |
39 | e, _ := lb.Get() | |
40 | e(context.Background(), struct{}{}) | |
41 | if have := counts; !reflect.DeepEqual(want, have) { | |
42 | t.Errorf("%d: want %v, have %v", i+1, want, have) | |
43 | } | |
44 | } | |
45 | } |
0 | package strategy | |
1 | ||
2 | import "github.com/go-kit/kit/endpoint" | |
3 | ||
4 | // Strategy yields endpoints to consumers according to some algorithm. | |
5 | type Strategy interface { | |
6 | Next() (endpoint.Endpoint, error) | |
7 | Stop() | |
8 | } |
0 | package strategy_test | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "testing" | |
5 | "time" | |
6 | ||
7 | "github.com/go-kit/kit/loadbalancer" | |
8 | ) | |
9 | ||
10 | func assertLoadBalancerNotEmpty(t *testing.T, lb loadbalancer.LoadBalancer) { | |
11 | if err := within(10*time.Millisecond, func() bool { | |
12 | return lb.Count() > 0 | |
13 | }); err != nil { | |
14 | t.Fatal("Publisher never updated endpoints") | |
15 | } | |
16 | } | |
17 | ||
18 | func within(d time.Duration, f func() bool) error { | |
19 | var ( | |
20 | deadline = time.After(d) | |
21 | ticker = time.NewTicker(d / 10) | |
22 | ) | |
23 | defer ticker.Stop() | |
24 | for { | |
25 | select { | |
26 | case <-ticker.C: | |
27 | if f() { | |
28 | return nil | |
29 | } | |
30 | case <-deadline: | |
31 | return fmt.Errorf("deadline exceeded") | |
32 | } | |
33 | } | |
34 | } |