Merge pull request #90 from go-kit/load-balancer-packages
Enhancements to loadbalancer
Peter Bourgon
8 years ago
6 | 6 | discovered via periodic DNS SRV lookups on a single logical name. Consul and |
7 | 7 | etcd publishers are planned. |
8 | 8 | |
9 | Different load balancing strategies are implemented on top of publishers. Go | |
10 | kit currently provides random and round-robin semantics. Smarter behaviors, | |
9 | Different load balancers are implemented on top of publishers. Go kit | |
10 | currently provides random and round-robin load balancers. Smarter behaviors, | |
11 | 11 | e.g. load balancing based on underlying endpoint priority/weight, is planned. |
12 | 12 | |
13 | 13 | ## Rationale |
16 | 16 | |
17 | 17 | ## Usage |
18 | 18 | |
19 | In your client, define a publisher, wrap it with a balancing strategy, and pass | |
20 | it to a retry strategy, which returns an endpoint. Use that endpoint to make | |
21 | requests, or wrap it with other value-add middleware. | |
19 | In your client, construct a publisher for a specific remote service, and pass | |
20 | it to a load balancer. Then, request an endpoint from the load balancer | |
21 | whenever you need to make a request to that remote service. | |
22 | ||
23 | ```go | |
24 | import ( | |
25 | "github.com/go-kit/kit/loadbalancer" | |
26 | "github.com/go-kit/kit/loadbalancer/dnssrv" | |
27 | ) | |
28 | ||
29 | func main() { | |
30 | // Construct a load balancer for foosvc, which gets foosvc instances by | |
31 | // polling a specific DNS SRV name. | |
32 | p := dnssrv.NewPublisher("foosvc.internal.domain", 5*time.Second, fooFactory, logger) | |
33 | lb := loadbalancer.NewRoundRobin(p) | |
34 | ||
35 | // Get a new endpoint from the load balancer. | |
36 | endpoint, err := lb.Endpoint() | |
37 | if err != nil { | |
38 | panic(err) | |
39 | } | |
40 | ||
41 | // Use the endpoint to make a request. | |
42 | response, err := endpoint(ctx, request) | |
43 | } | |
44 | ||
45 | func fooFactory(instance string) (endpoint.Endpoint, error) { | |
46 | // Convert an instance (host:port) to an endpoint, via a defined transport binding. | |
47 | } | |
48 | ``` | |
49 | ||
50 | It's also possible to wrap a load balancer with a retry strategy, so that it | |
51 | can be used as an endpoint directly. This may make load balancers more | |
52 | convenient to use, at the cost of fine-grained control of failures. | |
22 | 53 | |
23 | 54 | ```go |
24 | 55 | func main() { |
25 | var ( | |
26 | fooPublisher = loadbalancer.NewDNSSRVPublisher("foo.mynet.local", 5*time.Second, makeEndpoint) | |
27 | fooBalancer = loadbalancer.RoundRobin(fooPublisher) | |
28 | fooEndpoint = loadbalancer.Retry(3, time.Second, fooBalancer) | |
29 | ) | |
30 | http.HandleFunc("/", handle(fooEndpoint)) | |
31 | log.Fatal(http.ListenAndServe(":8080", nil)) | |
56 | p := dnssrv.NewPublisher("foosvc.internal.domain", 5*time.Second, fooFactory, logger) | |
57 | lb := loadbalancer.NewRoundRobin(p) | |
58 | endpoint := loadbalancer.Retry(3, 5*time.Seconds, lb) | |
59 | ||
60 | response, err := endpoint(ctx, request) // requests will be automatically load balanced | |
32 | 61 | } |
33 | ||
34 | func makeEndpoint(hostport string) endpoint.Endpoint { | |
35 | // Convert a host:port to a endpoint via your defined transport. | |
36 | } | |
37 | ||
38 | func handle(foo endpoint.Endpoint) http.HandlerFunc { | |
39 | return func(w http.ResponseWriter, r *http.Request) { | |
40 | // foo is usable as a load-balanced remote endpoint. | |
41 | } | |
42 | } | |
43 | ``` | |
62 | ```⏎ |
0 | package loadbalancer | |
1 | ||
2 | import "github.com/go-kit/kit/endpoint" | |
3 | ||
4 | type cache struct { | |
5 | req chan []endpoint.Endpoint | |
6 | cnt chan int | |
7 | quit chan struct{} | |
8 | } | |
9 | ||
10 | func newCache(p Publisher) *cache { | |
11 | c := &cache{ | |
12 | req: make(chan []endpoint.Endpoint), | |
13 | cnt: make(chan int), | |
14 | quit: make(chan struct{}), | |
15 | } | |
16 | go c.loop(p) | |
17 | return c | |
18 | } | |
19 | ||
20 | func (c *cache) loop(p Publisher) { | |
21 | e := make(chan []endpoint.Endpoint, 1) | |
22 | p.Subscribe(e) | |
23 | defer p.Unsubscribe(e) | |
24 | endpoints := <-e | |
25 | for { | |
26 | select { | |
27 | case endpoints = <-e: | |
28 | case c.cnt <- len(endpoints): | |
29 | case c.req <- endpoints: | |
30 | case <-c.quit: | |
31 | return | |
32 | } | |
33 | } | |
34 | } | |
35 | ||
36 | func (c *cache) count() int { | |
37 | return <-c.cnt | |
38 | } | |
39 | ||
40 | func (c *cache) get() []endpoint.Endpoint { | |
41 | return <-c.req | |
42 | } | |
43 | ||
44 | func (c *cache) stop() { | |
45 | close(c.quit) | |
46 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "runtime" | |
4 | "testing" | |
5 | ||
6 | "golang.org/x/net/context" | |
7 | ||
8 | "github.com/go-kit/kit/endpoint" | |
9 | ) | |
10 | ||
11 | func TestCache(t *testing.T) { | |
12 | e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
13 | endpoints := []endpoint.Endpoint{e} | |
14 | ||
15 | p := NewStaticPublisher(endpoints) | |
16 | defer p.Stop() | |
17 | ||
18 | c := newCache(p) | |
19 | defer c.stop() | |
20 | ||
21 | for _, n := range []int{2, 10, 0} { | |
22 | endpoints = make([]endpoint.Endpoint, n) | |
23 | for i := 0; i < n; i++ { | |
24 | endpoints[i] = e | |
25 | } | |
26 | p.Replace(endpoints) | |
27 | runtime.Gosched() | |
28 | if want, have := len(endpoints), len(c.get()); want != have { | |
29 | t.Errorf("want %d, have %d", want, have) | |
30 | } | |
31 | } | |
32 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "crypto/md5" | |
4 | "fmt" | |
5 | "net" | |
6 | "sort" | |
7 | "time" | |
8 | ||
9 | "github.com/go-kit/kit/endpoint" | |
10 | ) | |
11 | ||
12 | type dnssrvPublisher struct { | |
13 | subscribe chan chan<- []endpoint.Endpoint | |
14 | unsubscribe chan chan<- []endpoint.Endpoint | |
15 | quit chan struct{} | |
16 | } | |
17 | ||
18 | // NewDNSSRVPublisher returns a publisher that resolves the SRV name every ttl, and | |
19 | func NewDNSSRVPublisher(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) Publisher { | |
20 | p := &dnssrvPublisher{ | |
21 | subscribe: make(chan chan<- []endpoint.Endpoint), | |
22 | unsubscribe: make(chan chan<- []endpoint.Endpoint), | |
23 | quit: make(chan struct{}), | |
24 | } | |
25 | go p.loop(name, ttl, makeEndpoint) | |
26 | return p | |
27 | } | |
28 | ||
29 | func (p *dnssrvPublisher) Subscribe(c chan<- []endpoint.Endpoint) { | |
30 | p.subscribe <- c | |
31 | } | |
32 | ||
33 | func (p *dnssrvPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) { | |
34 | p.unsubscribe <- c | |
35 | } | |
36 | ||
37 | func (p *dnssrvPublisher) Stop() { | |
38 | close(p.quit) | |
39 | } | |
40 | ||
41 | var newTicker = time.NewTicker | |
42 | ||
43 | func (p *dnssrvPublisher) loop(name string, ttl time.Duration, makeEndpoint func(hostport string) endpoint.Endpoint) { | |
44 | var ( | |
45 | subscriptions = map[chan<- []endpoint.Endpoint]struct{}{} | |
46 | addrs, md5, _ = resolve(name) | |
47 | endpoints = convert(addrs, makeEndpoint) | |
48 | ticker = newTicker(ttl) | |
49 | ) | |
50 | defer ticker.Stop() | |
51 | for { | |
52 | select { | |
53 | case <-ticker.C: | |
54 | addrs, newmd5, err := resolve(name) | |
55 | if err == nil && newmd5 != md5 { | |
56 | endpoints = convert(addrs, makeEndpoint) | |
57 | for c := range subscriptions { | |
58 | c <- endpoints | |
59 | } | |
60 | md5 = newmd5 | |
61 | } | |
62 | ||
63 | case c := <-p.subscribe: | |
64 | subscriptions[c] = struct{}{} | |
65 | c <- endpoints | |
66 | ||
67 | case c := <-p.unsubscribe: | |
68 | delete(subscriptions, c) | |
69 | ||
70 | case <-p.quit: | |
71 | return | |
72 | } | |
73 | } | |
74 | } | |
75 | ||
76 | // Allow mocking in tests. | |
77 | var resolve = func(name string) (addrs []*net.SRV, md5sum string, err error) { | |
78 | _, addrs, err = net.LookupSRV("", "", name) | |
79 | if err != nil { | |
80 | return addrs, "", err | |
81 | } | |
82 | hostports := make([]string, len(addrs)) | |
83 | for i, addr := range addrs { | |
84 | hostports[i] = fmt.Sprintf("%s:%d", addr.Target, addr.Port) | |
85 | } | |
86 | sort.Sort(sort.StringSlice(hostports)) | |
87 | h := md5.New() | |
88 | for _, hostport := range hostports { | |
89 | fmt.Fprintf(h, hostport) | |
90 | } | |
91 | return addrs, fmt.Sprintf("%x", h.Sum(nil)), nil | |
92 | } | |
93 | ||
94 | func convert(addrs []*net.SRV, makeEndpoint func(hostport string) endpoint.Endpoint) []endpoint.Endpoint { | |
95 | endpoints := make([]endpoint.Endpoint, len(addrs)) | |
96 | for i, addr := range addrs { | |
97 | endpoints[i] = makeEndpoint(addr2hostport(addr)) | |
98 | } | |
99 | return endpoints | |
100 | } | |
101 | ||
102 | func addr2hostport(addr *net.SRV) string { | |
103 | return net.JoinHostPort(addr.Target, fmt.Sprintf("%d", addr.Port)) | |
104 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "net" | |
5 | "testing" | |
6 | "time" | |
7 | ||
8 | "golang.org/x/net/context" | |
9 | ||
10 | "github.com/go-kit/kit/endpoint" | |
11 | ) | |
12 | ||
13 | func TestDNSSRVPublisher(t *testing.T) { | |
14 | // Reset the vars when we're done | |
15 | oldResolve := resolve | |
16 | defer func() { resolve = oldResolve }() | |
17 | oldNewTicker := newTicker | |
18 | defer func() { newTicker = oldNewTicker }() | |
19 | ||
20 | // Set up a fixture and swap the vars | |
21 | a := []*net.SRV{ | |
22 | {Target: "foo", Port: 123}, | |
23 | {Target: "bar", Port: 456}, | |
24 | {Target: "baz", Port: 789}, | |
25 | } | |
26 | ticker := make(chan time.Time) | |
27 | resolve = func(string) ([]*net.SRV, string, error) { return a, fmt.Sprint(len(a)), nil } | |
28 | newTicker = func(time.Duration) *time.Ticker { return &time.Ticker{C: ticker} } | |
29 | ||
30 | // Construct endpoint | |
31 | m := map[string]int{} | |
32 | e := func(hostport string) endpoint.Endpoint { | |
33 | return func(context.Context, interface{}) (interface{}, error) { | |
34 | m[hostport]++ | |
35 | return struct{}{}, nil | |
36 | } | |
37 | } | |
38 | ||
39 | // Build the publisher | |
40 | var ( | |
41 | name = "irrelevant" | |
42 | ttl = time.Second | |
43 | makeEndpoint = func(hostport string) endpoint.Endpoint { return e(hostport) } | |
44 | ) | |
45 | p := NewDNSSRVPublisher(name, ttl, makeEndpoint) | |
46 | defer p.Stop() | |
47 | ||
48 | // Subscribe | |
49 | c := make(chan []endpoint.Endpoint, 1) | |
50 | p.Subscribe(c) | |
51 | defer p.Unsubscribe(c) | |
52 | ||
53 | // Invoke all of the endpoints | |
54 | for _, e := range <-c { | |
55 | e(context.Background(), struct{}{}) | |
56 | } | |
57 | ||
58 | // Make sure we invoked what we expected to | |
59 | for _, addr := range a { | |
60 | hostport := addr2hostport(addr) | |
61 | if want, have := 1, m[hostport]; want != have { | |
62 | t.Errorf("%q: want %d, have %d", name, want, have) | |
63 | } | |
64 | delete(m, hostport) | |
65 | } | |
66 | if want, have := 0, len(m); want != have { | |
67 | t.Errorf("want %d, have %d", want, have) | |
68 | } | |
69 | ||
70 | // Reset the fixture, trigger the timer, count the endpoints | |
71 | a = []*net.SRV{} | |
72 | ticker <- time.Now() | |
73 | if want, have := len(a), len(<-c); want != have { | |
74 | t.Errorf("want %d, have %d", want, have) | |
75 | } | |
76 | } |
0 | package dnssrv | |
1 | ||
2 | import ( | |
3 | "crypto/md5" | |
4 | "fmt" | |
5 | "net" | |
6 | "sort" | |
7 | "time" | |
8 | ||
9 | "github.com/go-kit/kit/endpoint" | |
10 | "github.com/go-kit/kit/loadbalancer" | |
11 | "github.com/go-kit/kit/log" | |
12 | ) | |
13 | ||
14 | // Publisher yields endpoints taken from the named DNS SRV record. The name is | |
15 | // resolved on a fixed schedule. Priorities and weights are ignored. | |
16 | type Publisher struct { | |
17 | name string | |
18 | ttl time.Duration | |
19 | factory loadbalancer.Factory | |
20 | logger log.Logger | |
21 | endpoints chan []endpoint.Endpoint | |
22 | quit chan struct{} | |
23 | } | |
24 | ||
25 | // NewPublisher returns a DNS SRV publisher. The name is resolved | |
26 | // synchronously as part of construction; if that resolution fails, the | |
27 | // constructor will return an error. The factory is used to convert a | |
28 | // host:port to a usable endpoint. The logger is used to report DNS and | |
29 | // factory errors. | |
30 | func NewPublisher(name string, ttl time.Duration, f loadbalancer.Factory, logger log.Logger) (*Publisher, error) { | |
31 | logger = log.NewContext(logger).With("component", "DNS SRV Publisher") | |
32 | addrs, md5, err := resolve(name) | |
33 | if err != nil { | |
34 | return nil, err | |
35 | } | |
36 | p := &Publisher{ | |
37 | name: name, | |
38 | ttl: ttl, | |
39 | factory: f, | |
40 | logger: logger, | |
41 | endpoints: make(chan []endpoint.Endpoint), | |
42 | quit: make(chan struct{}), | |
43 | } | |
44 | go p.loop(makeEndpoints(addrs, f, logger), md5) | |
45 | return p, nil | |
46 | } | |
47 | ||
48 | // Stop terminates the publisher. | |
49 | func (p *Publisher) Stop() { | |
50 | close(p.quit) | |
51 | } | |
52 | ||
53 | func (p *Publisher) loop(endpoints []endpoint.Endpoint, md5 string) { | |
54 | t := newTicker(p.ttl) | |
55 | defer t.Stop() | |
56 | for { | |
57 | select { | |
58 | case p.endpoints <- endpoints: | |
59 | ||
60 | case <-t.C: | |
61 | // TODO should we do this out-of-band? | |
62 | addrs, newmd5, err := resolve(p.name) | |
63 | if err != nil { | |
64 | p.logger.Log("name", p.name, "err", err) | |
65 | continue // don't replace good endpoints with bad ones | |
66 | } | |
67 | if newmd5 == md5 { | |
68 | continue // no change | |
69 | } | |
70 | endpoints = makeEndpoints(addrs, p.factory, p.logger) | |
71 | md5 = newmd5 | |
72 | ||
73 | case <-p.quit: | |
74 | return | |
75 | } | |
76 | } | |
77 | } | |
78 | ||
79 | // Endpoints implements the Publisher interface. | |
80 | func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) { | |
81 | select { | |
82 | case endpoints := <-p.endpoints: | |
83 | return endpoints, nil | |
84 | case <-p.quit: | |
85 | return nil, loadbalancer.ErrPublisherStopped | |
86 | } | |
87 | } | |
88 | ||
89 | var ( | |
90 | lookupSRV = net.LookupSRV | |
91 | newTicker = time.NewTicker | |
92 | ) | |
93 | ||
94 | func resolve(name string) (addrs []*net.SRV, md5sum string, err error) { | |
95 | _, addrs, err = lookupSRV("", "", name) | |
96 | if err != nil { | |
97 | return addrs, "", err | |
98 | } | |
99 | hostports := make([]string, len(addrs)) | |
100 | for i, addr := range addrs { | |
101 | hostports[i] = fmt.Sprintf("%s:%d", addr.Target, addr.Port) | |
102 | } | |
103 | sort.Sort(sort.StringSlice(hostports)) | |
104 | h := md5.New() | |
105 | for _, hostport := range hostports { | |
106 | fmt.Fprintf(h, hostport) | |
107 | } | |
108 | return addrs, fmt.Sprintf("%x", h.Sum(nil)), nil | |
109 | } | |
110 | ||
111 | func makeEndpoints(addrs []*net.SRV, f loadbalancer.Factory, logger log.Logger) []endpoint.Endpoint { | |
112 | endpoints := make([]endpoint.Endpoint, 0, len(addrs)) | |
113 | for _, addr := range addrs { | |
114 | endpoint, err := f(addr2instance(addr)) | |
115 | if err != nil { | |
116 | logger.Log("instance", addr2instance(addr), "err", err) | |
117 | continue | |
118 | } | |
119 | endpoints = append(endpoints, endpoint) | |
120 | } | |
121 | return endpoints | |
122 | } | |
123 | ||
124 | func addr2instance(addr *net.SRV) string { | |
125 | return net.JoinHostPort(addr.Target, fmt.Sprint(addr.Port)) | |
126 | } |
0 | package dnssrv | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "net" | |
5 | "sync/atomic" | |
6 | "testing" | |
7 | "time" | |
8 | ||
9 | "golang.org/x/net/context" | |
10 | ||
11 | "github.com/go-kit/kit/endpoint" | |
12 | "github.com/go-kit/kit/loadbalancer" | |
13 | "github.com/go-kit/kit/log" | |
14 | ) | |
15 | ||
16 | func TestPublisher(t *testing.T) { | |
17 | var ( | |
18 | target = "my-target" | |
19 | port = uint16(1234) | |
20 | addr = &net.SRV{Target: target, Port: port} | |
21 | addrs = []*net.SRV{addr} | |
22 | name = "my-name" | |
23 | ttl = time.Second | |
24 | logger = log.NewNopLogger() | |
25 | e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
26 | ) | |
27 | ||
28 | oldLookup := lookupSRV | |
29 | defer func() { lookupSRV = oldLookup }() | |
30 | lookupSRV = mockLookupSRV(addrs, nil, nil) | |
31 | ||
32 | factory := func(instance string) (endpoint.Endpoint, error) { | |
33 | if want, have := addr2instance(addr), instance; want != have { | |
34 | t.Errorf("want %q, have %q", want, have) | |
35 | } | |
36 | return e, nil | |
37 | } | |
38 | ||
39 | p, err := NewPublisher(name, ttl, factory, logger) | |
40 | if err != nil { | |
41 | t.Fatal(err) | |
42 | } | |
43 | defer p.Stop() | |
44 | ||
45 | if _, err := p.Endpoints(); err != nil { | |
46 | t.Fatal(err) | |
47 | } | |
48 | } | |
49 | ||
50 | func TestBadLookup(t *testing.T) { | |
51 | oldLookup := lookupSRV | |
52 | defer func() { lookupSRV = oldLookup }() | |
53 | lookupSRV = mockLookupSRV([]*net.SRV{}, errors.New("kaboom"), nil) | |
54 | ||
55 | var ( | |
56 | name = "some-name" | |
57 | ttl = time.Second | |
58 | factory = func(string) (endpoint.Endpoint, error) { return nil, errors.New("unreachable") } | |
59 | logger = log.NewNopLogger() | |
60 | ) | |
61 | ||
62 | if _, err := NewPublisher(name, ttl, factory, logger); err == nil { | |
63 | t.Fatal("wanted error, got none") | |
64 | } | |
65 | } | |
66 | ||
67 | func TestBadFactory(t *testing.T) { | |
68 | var ( | |
69 | addr = &net.SRV{Target: "foo", Port: 1234} | |
70 | addrs = []*net.SRV{addr} | |
71 | name = "some-name" | |
72 | ttl = time.Second | |
73 | factory = func(string) (endpoint.Endpoint, error) { return nil, errors.New("kaboom") } | |
74 | logger = log.NewNopLogger() | |
75 | ) | |
76 | ||
77 | oldLookup := lookupSRV | |
78 | defer func() { lookupSRV = oldLookup }() | |
79 | lookupSRV = mockLookupSRV(addrs, nil, nil) | |
80 | ||
81 | p, err := NewPublisher(name, ttl, factory, logger) | |
82 | if err != nil { | |
83 | t.Fatal(err) | |
84 | } | |
85 | defer p.Stop() | |
86 | ||
87 | endpoints, err := p.Endpoints() | |
88 | if err != nil { | |
89 | t.Fatal(err) | |
90 | } | |
91 | if want, have := 0, len(endpoints); want != have { | |
92 | t.Errorf("want %q, have %q", want, have) | |
93 | } | |
94 | } | |
95 | ||
96 | func TestRefreshWithChange(t *testing.T) { | |
97 | t.Skip("TODO") | |
98 | } | |
99 | ||
100 | func TestRefreshNoChange(t *testing.T) { | |
101 | var ( | |
102 | tick = make(chan time.Time) | |
103 | target = "my-target" | |
104 | port = uint16(5678) | |
105 | addr = &net.SRV{Target: target, Port: port} | |
106 | addrs = []*net.SRV{addr} | |
107 | name = "my-name" | |
108 | ttl = time.Second | |
109 | factory = func(string) (endpoint.Endpoint, error) { return nil, errors.New("kaboom") } | |
110 | logger = log.NewNopLogger() | |
111 | ) | |
112 | ||
113 | oldTicker := newTicker | |
114 | defer func() { newTicker = oldTicker }() | |
115 | newTicker = func(time.Duration) *time.Ticker { return &time.Ticker{C: tick} } | |
116 | ||
117 | var resolves uint64 | |
118 | oldLookup := lookupSRV | |
119 | defer func() { lookupSRV = oldLookup }() | |
120 | lookupSRV = mockLookupSRV(addrs, nil, &resolves) | |
121 | ||
122 | p, err := NewPublisher(name, ttl, factory, logger) | |
123 | if err != nil { | |
124 | t.Fatal(err) | |
125 | } | |
126 | defer p.Stop() | |
127 | ||
128 | tick <- time.Now() | |
129 | if want, have := uint64(2), resolves; want != have { | |
130 | t.Errorf("want %d, have %d", want, have) | |
131 | } | |
132 | } | |
133 | ||
134 | func TestRefreshResolveError(t *testing.T) { | |
135 | t.Skip("TODO") | |
136 | } | |
137 | ||
138 | func TestErrPublisherStopped(t *testing.T) { | |
139 | var ( | |
140 | name = "my-name" | |
141 | ttl = time.Second | |
142 | factory = func(string) (endpoint.Endpoint, error) { return nil, errors.New("kaboom") } | |
143 | logger = log.NewNopLogger() | |
144 | ) | |
145 | ||
146 | oldLookup := lookupSRV | |
147 | defer func() { lookupSRV = oldLookup }() | |
148 | lookupSRV = mockLookupSRV([]*net.SRV{}, nil, nil) | |
149 | ||
150 | p, err := NewPublisher(name, ttl, factory, logger) | |
151 | if err != nil { | |
152 | t.Fatal(err) | |
153 | } | |
154 | ||
155 | p.Stop() | |
156 | _, have := p.Endpoints() | |
157 | if want := loadbalancer.ErrPublisherStopped; want != have { | |
158 | t.Fatalf("want %v, have %v", want, have) | |
159 | } | |
160 | } | |
161 | ||
162 | func mockLookupSRV(addrs []*net.SRV, err error, count *uint64) func(service, proto, name string) (string, []*net.SRV, error) { | |
163 | return func(service, proto, name string) (string, []*net.SRV, error) { | |
164 | if count != nil { | |
165 | atomic.AddUint64(count, 1) | |
166 | } | |
167 | return "", addrs, err | |
168 | } | |
169 | } |
0 | package loadbalancer | |
1 | ||
2 | import "github.com/go-kit/kit/endpoint" | |
3 | ||
4 | type endpointCache struct { | |
5 | requests chan []endpoint.Endpoint | |
6 | quit chan struct{} | |
7 | } | |
8 | ||
9 | func newEndpointCache(p Publisher) *endpointCache { | |
10 | c := &endpointCache{ | |
11 | requests: make(chan []endpoint.Endpoint), | |
12 | quit: make(chan struct{}), | |
13 | } | |
14 | go c.loop(p) | |
15 | return c | |
16 | } | |
17 | ||
18 | func (c *endpointCache) loop(p Publisher) { | |
19 | updates := make(chan []endpoint.Endpoint, 1) | |
20 | p.Subscribe(updates) | |
21 | defer p.Unsubscribe(updates) | |
22 | endpoints := <-updates | |
23 | ||
24 | for { | |
25 | select { | |
26 | case endpoints = <-updates: | |
27 | case c.requests <- endpoints: | |
28 | case <-c.quit: | |
29 | return | |
30 | } | |
31 | } | |
32 | } | |
33 | ||
34 | func (c *endpointCache) get() []endpoint.Endpoint { | |
35 | return <-c.requests | |
36 | } | |
37 | ||
38 | func (c *endpointCache) stop() { | |
39 | close(c.quit) | |
40 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "testing" | |
4 | ||
5 | "golang.org/x/net/context" | |
6 | ||
7 | "github.com/go-kit/kit/endpoint" | |
8 | ) | |
9 | ||
10 | func TestEndpointCache(t *testing.T) { | |
11 | endpoints := []endpoint.Endpoint{ | |
12 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, | |
13 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, | |
14 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, | |
15 | } | |
16 | ||
17 | p := NewStaticPublisher(endpoints) | |
18 | defer p.Stop() | |
19 | ||
20 | c := newEndpointCache(p) | |
21 | defer c.stop() | |
22 | ||
23 | if want, have := len(endpoints), len(c.get()); want != have { | |
24 | t.Errorf("want %d, have %d", want, have) | |
25 | } | |
26 | } |
0 | package loadbalancer | |
1 | ||
2 | import "github.com/go-kit/kit/endpoint" | |
3 | ||
4 | // Factory is a function that converts an instance string, e.g. a host:port, | |
5 | // to a usable endpoint. Factories are used by load balancers to lift | |
6 | // instances returned by Publishers into endpoints. Users are expected to | |
7 | // provide their own factory functions that assume specific transports, or can | |
8 | // deduce transports by parsing the instance string. | |
9 | type Factory func(instance string) (endpoint.Endpoint, error) |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
8 | // LoadBalancer yields endpoints one-by-one. | |
9 | type LoadBalancer interface { | |
10 | Count() int | |
11 | Get() (endpoint.Endpoint, error) | |
12 | } | |
13 | ||
14 | // ErrNoEndpointsAvailable is given by a load balancer when no endpoints are | |
15 | // available to be returned. | |
16 | var ErrNoEndpointsAvailable = errors.New("no endpoints available") |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
8 | // LoadBalancer describes something that can yield endpoints for a remote | |
9 | // service method. | |
10 | type LoadBalancer interface { | |
11 | Endpoint() (endpoint.Endpoint, error) | |
12 | } | |
13 | ||
14 | // ErrNoEndpoints is returned when a load balancer (or one of its components) | |
15 | // has no endpoints to return. In a request lifecycle, this is usually a fatal | |
16 | // error. | |
17 | var ErrNoEndpoints = errors.New("no endpoints available") |
0 | package loadbalancer_test | |
1 | ||
2 | import ( | |
3 | "runtime" | |
4 | "sync" | |
5 | ||
6 | "github.com/go-kit/kit/endpoint" | |
7 | ) | |
8 | ||
9 | type mockPublisher struct { | |
10 | sync.Mutex | |
11 | e []endpoint.Endpoint | |
12 | s map[chan<- []endpoint.Endpoint]struct{} | |
13 | } | |
14 | ||
15 | func newMockPublisher(endpoints []endpoint.Endpoint) *mockPublisher { | |
16 | return &mockPublisher{ | |
17 | e: endpoints, | |
18 | s: map[chan<- []endpoint.Endpoint]struct{}{}, | |
19 | } | |
20 | } | |
21 | ||
22 | func (p *mockPublisher) replace(endpoints []endpoint.Endpoint) { | |
23 | p.Lock() | |
24 | defer p.Unlock() | |
25 | p.e = endpoints | |
26 | for s := range p.s { | |
27 | s <- p.e | |
28 | } | |
29 | runtime.Gosched() | |
30 | } | |
31 | ||
32 | func (p *mockPublisher) Subscribe(c chan<- []endpoint.Endpoint) { | |
33 | p.Lock() | |
34 | defer p.Unlock() | |
35 | p.s[c] = struct{}{} | |
36 | c <- p.e | |
37 | } | |
38 | ||
39 | func (p *mockPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) { | |
40 | p.Lock() | |
41 | defer p.Unlock() | |
42 | delete(p.s, c) | |
43 | } | |
44 | ||
45 | func (p *mockPublisher) Stop() {} |
0 | 0 | package loadbalancer |
1 | 1 | |
2 | import "github.com/go-kit/kit/endpoint" | |
2 | import ( | |
3 | "errors" | |
3 | 4 | |
4 | // Publisher produces endpoints. | |
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
8 | // Publisher describes something that provides a set of identical endpoints. | |
9 | // Different publisher implementations exist for different kinds of service | |
10 | // discovery systems. | |
5 | 11 | type Publisher interface { |
6 | Subscribe(chan<- []endpoint.Endpoint) | |
7 | Unsubscribe(chan<- []endpoint.Endpoint) | |
8 | Stop() | |
12 | Endpoints() ([]endpoint.Endpoint, error) | |
9 | 13 | } |
14 | ||
15 | // ErrPublisherStopped is returned by publishers when the underlying | |
16 | // implementation has been terminated and can no longer serve requests. | |
17 | var ErrPublisherStopped = errors.New("publisher stopped") |
5 | 5 | "github.com/go-kit/kit/endpoint" |
6 | 6 | ) |
7 | 7 | |
8 | // Random returns a load balancer that yields random endpoints. | |
9 | func Random(p Publisher) LoadBalancer { | |
10 | return random{newCache(p)} | |
8 | // Random is a completely stateless load balancer that chooses a random | |
9 | // endpoint to return each time. | |
10 | type Random struct { | |
11 | p Publisher | |
12 | r *rand.Rand | |
11 | 13 | } |
12 | 14 | |
13 | type random struct{ *cache } | |
15 | // NewRandom returns a new Random load balancer. | |
16 | func NewRandom(p Publisher, seed int64) *Random { | |
17 | return &Random{ | |
18 | p: p, | |
19 | r: rand.New(rand.NewSource(seed)), | |
20 | } | |
21 | } | |
14 | 22 | |
15 | func (r random) Count() int { return r.cache.count() } | |
16 | ||
17 | func (r random) Get() (endpoint.Endpoint, error) { | |
18 | endpoints := r.cache.get() | |
23 | // Endpoint implements the LoadBalancer interface. | |
24 | func (r *Random) Endpoint() (endpoint.Endpoint, error) { | |
25 | endpoints, err := r.p.Endpoints() | |
26 | if err != nil { | |
27 | return nil, err | |
28 | } | |
19 | 29 | if len(endpoints) <= 0 { |
20 | return nil, ErrNoEndpointsAvailable | |
30 | return nil, ErrNoEndpoints | |
21 | 31 | } |
22 | return endpoints[rand.Intn(len(endpoints))], nil | |
32 | return endpoints[r.r.Intn(len(endpoints))], nil | |
23 | 33 | } |
3 | 3 | "math" |
4 | 4 | "testing" |
5 | 5 | |
6 | "golang.org/x/net/context" | |
7 | ||
6 | 8 | "github.com/go-kit/kit/endpoint" |
7 | 9 | "github.com/go-kit/kit/loadbalancer" |
8 | "golang.org/x/net/context" | |
10 | "github.com/go-kit/kit/loadbalancer/static" | |
9 | 11 | ) |
10 | 12 | |
11 | func TestRandom(t *testing.T) { | |
12 | p := loadbalancer.NewStaticPublisher([]endpoint.Endpoint{}) | |
13 | defer p.Stop() | |
13 | func TestRandomDistribution(t *testing.T) { | |
14 | var ( | |
15 | n = 3 | |
16 | endpoints = make([]endpoint.Endpoint, n) | |
17 | counts = make([]int, n) | |
18 | seed = int64(123) | |
19 | ctx = context.Background() | |
20 | iterations = 100000 | |
21 | want = iterations / n | |
22 | tolerance = want / 100 // 1% | |
23 | ) | |
14 | 24 | |
15 | lb := loadbalancer.Random(p) | |
16 | if _, err := lb.Get(); err == nil { | |
17 | t.Error("want error, got none") | |
25 | for i := 0; i < n; i++ { | |
26 | i0 := i | |
27 | endpoints[i] = func(context.Context, interface{}) (interface{}, error) { counts[i0]++; return struct{}{}, nil } | |
18 | 28 | } |
19 | 29 | |
20 | counts := []int{0, 0, 0} | |
21 | p.Replace([]endpoint.Endpoint{ | |
22 | func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil }, | |
23 | func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil }, | |
24 | func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil }, | |
25 | }) | |
26 | assertLoadBalancerNotEmpty(t, lb) | |
30 | lb := loadbalancer.NewRandom(static.NewPublisher(endpoints), seed) | |
27 | 31 | |
28 | n := 10000 | |
29 | for i := 0; i < n; i++ { | |
30 | e, _ := lb.Get() | |
31 | e(context.Background(), struct{}{}) | |
32 | for i := 0; i < iterations; i++ { | |
33 | e, err := lb.Endpoint() | |
34 | if err != nil { | |
35 | t.Fatal(err) | |
36 | } | |
37 | e(ctx, struct{}{}) | |
32 | 38 | } |
33 | 39 | |
34 | want := float64(n) / float64(len(counts)) | |
35 | tolerance := (want / 100.0) * 5 // 5% | |
36 | for _, have := range counts { | |
37 | if math.Abs(want-float64(have)) > tolerance { | |
38 | t.Errorf("want %.0f, have %d", want, have) | |
40 | for i, have := range counts { | |
41 | if math.Abs(float64(want-have)) > float64(tolerance) { | |
42 | t.Errorf("%d: want %d, have %d", i, want, have) | |
39 | 43 | } |
40 | 44 | } |
41 | 45 | } |
46 | ||
47 | func TestRandomBadPublisher(t *testing.T) { | |
48 | t.Skip("TODO") | |
49 | } | |
50 | ||
51 | func TestRandomNoEndpoints(t *testing.T) { | |
52 | lb := loadbalancer.NewRandom(static.NewPublisher([]endpoint.Endpoint{}), 123) | |
53 | _, have := lb.Endpoint() | |
54 | if want := loadbalancer.ErrNoEndpoints; want != have { | |
55 | t.Errorf("want %q, have %q", want, have) | |
56 | } | |
57 | } |
9 | 9 | "github.com/go-kit/kit/endpoint" |
10 | 10 | ) |
11 | 11 | |
12 | // Retry yields an endpoint that takes endpoints from the load balancer. | |
13 | // Invocations that return errors will be retried until they succeed, up to | |
14 | // max times, or until the timeout is elapsed, whichever comes first. | |
12 | // Retry wraps the load balancer to make it behave like a simple endpoint. | |
13 | // Requests to the endpoint will be automatically load balanced via the load | |
14 | // balancer. Requests that return errors will be retried until they succeed, | |
15 | // up to max times, or until the timeout is elapsed, whichever comes first. | |
15 | 16 | func Retry(max int, timeout time.Duration, lb LoadBalancer) endpoint.Endpoint { |
16 | 17 | return func(ctx context.Context, request interface{}) (interface{}, error) { |
17 | 18 | var ( |
23 | 24 | defer cancel() |
24 | 25 | for i := 1; i <= max; i++ { |
25 | 26 | go func() { |
26 | e, err := lb.Get() | |
27 | e, err := lb.Endpoint() | |
27 | 28 | if err != nil { |
28 | 29 | errs <- err |
29 | 30 | return |
1 | 1 | |
2 | 2 | import ( |
3 | 3 | "errors" |
4 | "testing" | |
4 | 5 | "time" |
6 | ||
7 | "golang.org/x/net/context" | |
5 | 8 | |
6 | 9 | "github.com/go-kit/kit/endpoint" |
7 | 10 | "github.com/go-kit/kit/loadbalancer" |
8 | "golang.org/x/net/context" | |
9 | ||
10 | "testing" | |
11 | "github.com/go-kit/kit/loadbalancer/static" | |
11 | 12 | ) |
12 | 13 | |
13 | func TestRetryMax(t *testing.T) { | |
14 | func TestRetryMaxTotalFail(t *testing.T) { | |
14 | 15 | var ( |
15 | endpoints = []endpoint.Endpoint{} | |
16 | p = loadbalancer.NewStaticPublisher(endpoints) | |
17 | lb = loadbalancer.RoundRobin(p) | |
16 | endpoints = []endpoint.Endpoint{} // no endpoints | |
17 | p = static.NewPublisher(endpoints) | |
18 | lb = loadbalancer.NewRoundRobin(p) | |
19 | retry = loadbalancer.Retry(999, time.Second, lb) // lots of retries | |
20 | ctx = context.Background() | |
18 | 21 | ) |
22 | if _, err := retry(ctx, struct{}{}); err == nil { | |
23 | t.Errorf("expected error, got none") // should fail | |
24 | } | |
25 | } | |
19 | 26 | |
20 | if _, err := loadbalancer.Retry(999, time.Second, lb)(context.Background(), struct{}{}); err == nil { | |
27 | func TestRetryMaxPartialFail(t *testing.T) { | |
28 | var ( | |
29 | endpoints = []endpoint.Endpoint{ | |
30 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") }, | |
31 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") }, | |
32 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ }, | |
33 | } | |
34 | retries = len(endpoints) - 1 // not quite enough retries | |
35 | p = static.NewPublisher(endpoints) | |
36 | lb = loadbalancer.NewRoundRobin(p) | |
37 | ctx = context.Background() | |
38 | ) | |
39 | if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err == nil { | |
21 | 40 | t.Errorf("expected error, got none") |
22 | 41 | } |
42 | } | |
23 | 43 | |
24 | endpoints = []endpoint.Endpoint{ | |
25 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") }, | |
26 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") }, | |
27 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ }, | |
28 | } | |
29 | p.Replace(endpoints) | |
30 | assertLoadBalancerNotEmpty(t, lb) | |
31 | ||
32 | if _, err := loadbalancer.Retry(len(endpoints)-1, time.Second, lb)(context.Background(), struct{}{}); err == nil { | |
33 | t.Errorf("expected error, got none") | |
34 | } | |
35 | ||
36 | if _, err := loadbalancer.Retry(len(endpoints), time.Second, lb)(context.Background(), struct{}{}); err != nil { | |
44 | func TestRetryMaxSuccess(t *testing.T) { | |
45 | var ( | |
46 | endpoints = []endpoint.Endpoint{ | |
47 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") }, | |
48 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") }, | |
49 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ }, | |
50 | } | |
51 | retries = len(endpoints) // exactly enough retries | |
52 | p = static.NewPublisher(endpoints) | |
53 | lb = loadbalancer.NewRoundRobin(p) | |
54 | ctx = context.Background() | |
55 | ) | |
56 | if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err != nil { | |
37 | 57 | t.Error(err) |
38 | 58 | } |
39 | 59 | } |
43 | 63 | step = make(chan struct{}) |
44 | 64 | e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil } |
45 | 65 | timeout = time.Millisecond |
46 | retry = loadbalancer.Retry(999, timeout, loadbalancer.RoundRobin(loadbalancer.NewStaticPublisher([]endpoint.Endpoint{e}))) | |
47 | errs = make(chan error) | |
66 | retry = loadbalancer.Retry(999, timeout, loadbalancer.NewRoundRobin(static.NewPublisher([]endpoint.Endpoint{e}))) | |
67 | errs = make(chan error, 1) | |
48 | 68 | invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err } |
49 | 69 | ) |
50 | 70 | |
51 | go invoke() // invoke the endpoint | |
52 | step <- struct{}{} // tell the endpoint to return | |
53 | if err := <-errs; err != nil { // that should succeed | |
71 | go func() { step <- struct{}{} }() // queue up a flush of the endpoint | |
72 | invoke() // invoke the endpoint and trigger the flush | |
73 | if err := <-errs; err != nil { // that should succeed | |
54 | 74 | t.Error(err) |
55 | 75 | } |
56 | 76 | |
57 | go invoke() // invoke the endpoint | |
58 | time.Sleep(2 * timeout) // wait | |
59 | time.Sleep(2 * timeout) // wait again (CI servers!!) | |
60 | step <- struct{}{} // tell the endpoint to return | |
61 | if err := <-errs; err != context.DeadlineExceeded { // that should not succeed | |
62 | t.Errorf("wanted error, got none") | |
77 | go func() { time.Sleep(10 * timeout); step <- struct{}{} }() // a delayed flush | |
78 | invoke() // invoke the endpoint | |
79 | if err := <-errs; err != context.DeadlineExceeded { // that should not succeed | |
80 | t.Errorf("wanted %v, got none", context.DeadlineExceeded) | |
63 | 81 | } |
64 | 82 | } |
5 | 5 | "github.com/go-kit/kit/endpoint" |
6 | 6 | ) |
7 | 7 | |
8 | // RoundRobin returns a load balancer that yields endpoints in sequence. | |
9 | func RoundRobin(p Publisher) LoadBalancer { | |
10 | return &roundRobin{newCache(p), 0} | |
8 | // RoundRobin is a simple load balancer that returns each of the published | |
9 | // endpoints in sequence. | |
10 | type RoundRobin struct { | |
11 | p Publisher | |
12 | counter uint64 | |
11 | 13 | } |
12 | 14 | |
13 | type roundRobin struct { | |
14 | *cache | |
15 | uint64 | |
15 | // NewRoundRobin returns a new RoundRobin load balancer. | |
16 | func NewRoundRobin(p Publisher) *RoundRobin { | |
17 | return &RoundRobin{ | |
18 | p: p, | |
19 | counter: 0, | |
20 | } | |
16 | 21 | } |
17 | 22 | |
18 | func (r *roundRobin) Count() int { return r.cache.count() } | |
19 | ||
20 | func (r *roundRobin) Get() (endpoint.Endpoint, error) { | |
21 | endpoints := r.cache.get() | |
23 | // Endpoint implements the LoadBalancer interface. | |
24 | func (rr *RoundRobin) Endpoint() (endpoint.Endpoint, error) { | |
25 | endpoints, err := rr.p.Endpoints() | |
26 | if err != nil { | |
27 | return nil, err | |
28 | } | |
22 | 29 | if len(endpoints) <= 0 { |
23 | return nil, ErrNoEndpointsAvailable | |
30 | return nil, ErrNoEndpoints | |
24 | 31 | } |
25 | 32 | var old uint64 |
26 | 33 | for { |
27 | old = atomic.LoadUint64(&r.uint64) | |
28 | if atomic.CompareAndSwapUint64(&r.uint64, old, old+1) { | |
34 | old = atomic.LoadUint64(&rr.counter) | |
35 | if atomic.CompareAndSwapUint64(&rr.counter, old, old+1) { | |
29 | 36 | break |
30 | 37 | } |
31 | 38 | } |
5 | 5 | |
6 | 6 | "github.com/go-kit/kit/endpoint" |
7 | 7 | "github.com/go-kit/kit/loadbalancer" |
8 | "github.com/go-kit/kit/loadbalancer/static" | |
8 | 9 | "golang.org/x/net/context" |
9 | 10 | ) |
10 | 11 | |
11 | func TestRoundRobin(t *testing.T) { | |
12 | p := loadbalancer.NewStaticPublisher([]endpoint.Endpoint{}) | |
13 | defer p.Stop() | |
12 | func TestRoundRobinDistribution(t *testing.T) { | |
13 | var ( | |
14 | ctx = context.Background() | |
15 | counts = []int{0, 0, 0} | |
16 | endpoints = []endpoint.Endpoint{ | |
17 | func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil }, | |
18 | func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil }, | |
19 | func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil }, | |
20 | } | |
21 | ) | |
14 | 22 | |
15 | lb := loadbalancer.RoundRobin(p) | |
16 | if _, err := lb.Get(); err == nil { | |
17 | t.Error("want error, got none") | |
18 | } | |
19 | ||
20 | counts := []int{0, 0, 0} | |
21 | p.Replace([]endpoint.Endpoint{ | |
22 | func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil }, | |
23 | func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil }, | |
24 | func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil }, | |
25 | }) | |
26 | assertLoadBalancerNotEmpty(t, lb) | |
23 | lb := loadbalancer.NewRoundRobin(static.NewPublisher(endpoints)) | |
27 | 24 | |
28 | 25 | for i, want := range [][]int{ |
29 | 26 | {1, 0, 0}, |
34 | 31 | {2, 2, 2}, |
35 | 32 | {3, 2, 2}, |
36 | 33 | } { |
37 | e, _ := lb.Get() | |
38 | e(context.Background(), struct{}{}) | |
34 | e, err := lb.Endpoint() | |
35 | if err != nil { | |
36 | t.Fatal(err) | |
37 | } | |
38 | e(ctx, struct{}{}) | |
39 | 39 | if have := counts; !reflect.DeepEqual(want, have) { |
40 | t.Errorf("%d: want %v, have %v", i+1, want, have) | |
40 | t.Fatalf("%d: want %v, have %v", i, want, have) | |
41 | 41 | } |
42 | ||
42 | 43 | } |
43 | 44 | } |
45 | ||
46 | func TestRoundRobinBadPublisher(t *testing.T) { | |
47 | t.Skip("TODO") | |
48 | } |
0 | package static | |
1 | ||
2 | import ( | |
3 | "sync" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
8 | // Publisher yields the same set of static endpoints. | |
9 | type Publisher struct { | |
10 | mtx sync.RWMutex | |
11 | endpoints []endpoint.Endpoint | |
12 | } | |
13 | ||
14 | // NewPublisher returns a static endpoint Publisher. | |
15 | func NewPublisher(endpoints []endpoint.Endpoint) *Publisher { | |
16 | return &Publisher{ | |
17 | endpoints: endpoints, | |
18 | } | |
19 | } | |
20 | ||
21 | // Endpoints implements the Publisher interface. | |
22 | func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) { | |
23 | p.mtx.RLock() | |
24 | defer p.mtx.RUnlock() | |
25 | return p.endpoints, nil | |
26 | } | |
27 | ||
28 | // Replace is a utility method to swap out the underlying endpoints of an | |
29 | // existing static publisher. It's useful mostly for testing. | |
30 | func (p *Publisher) Replace(endpoints []endpoint.Endpoint) { | |
31 | p.mtx.Lock() | |
32 | defer p.mtx.Unlock() | |
33 | p.endpoints = endpoints | |
34 | } |
0 | package static_test | |
1 | ||
2 | import ( | |
3 | "reflect" | |
4 | "testing" | |
5 | ||
6 | "golang.org/x/net/context" | |
7 | ||
8 | "github.com/go-kit/kit/endpoint" | |
9 | "github.com/go-kit/kit/loadbalancer/static" | |
10 | ) | |
11 | ||
12 | func TestStatic(t *testing.T) { | |
13 | var ( | |
14 | e1 = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
15 | e2 = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
16 | endpoints = []endpoint.Endpoint{e1, e2} | |
17 | ) | |
18 | p := static.NewPublisher(endpoints) | |
19 | have, err := p.Endpoints() | |
20 | if err != nil { | |
21 | t.Fatal(err) | |
22 | } | |
23 | if want := endpoints; !reflect.DeepEqual(want, have) { | |
24 | t.Fatalf("want %#+v, have %#+v", want, have) | |
25 | } | |
26 | } | |
27 | ||
28 | func TestStaticReplace(t *testing.T) { | |
29 | p := static.NewPublisher([]endpoint.Endpoint{ | |
30 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, | |
31 | }) | |
32 | have, err := p.Endpoints() | |
33 | if err != nil { | |
34 | t.Fatal(err) | |
35 | } | |
36 | if want, have := 1, len(have); want != have { | |
37 | t.Fatalf("want %d, have %d", want, have) | |
38 | } | |
39 | p.Replace([]endpoint.Endpoint{}) | |
40 | have, err = p.Endpoints() | |
41 | if err != nil { | |
42 | t.Fatal(err) | |
43 | } | |
44 | if want, have := 0, len(have); want != have { | |
45 | t.Fatalf("want %d, have %d", want, have) | |
46 | } | |
47 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "sync" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
8 | // NewStaticPublisher returns a publisher that yields a static set of | |
9 | // endpoints, which can be completely replaced. | |
10 | func NewStaticPublisher(endpoints []endpoint.Endpoint) *StaticPublisher { | |
11 | return &StaticPublisher{ | |
12 | current: endpoints, | |
13 | subscribers: map[chan<- []endpoint.Endpoint]struct{}{}, | |
14 | } | |
15 | } | |
16 | ||
17 | // StaticPublisher holds a static set of endpoints. | |
18 | type StaticPublisher struct { | |
19 | mu sync.Mutex | |
20 | current []endpoint.Endpoint | |
21 | subscribers map[chan<- []endpoint.Endpoint]struct{} | |
22 | } | |
23 | ||
24 | // Subscribe implements Publisher. | |
25 | func (p *StaticPublisher) Subscribe(c chan<- []endpoint.Endpoint) { | |
26 | p.mu.Lock() | |
27 | defer p.mu.Unlock() | |
28 | p.subscribers[c] = struct{}{} | |
29 | c <- p.current | |
30 | } | |
31 | ||
32 | // Unsubscribe implements Publisher. | |
33 | func (p *StaticPublisher) Unsubscribe(c chan<- []endpoint.Endpoint) { | |
34 | p.mu.Lock() | |
35 | defer p.mu.Unlock() | |
36 | delete(p.subscribers, c) | |
37 | } | |
38 | ||
39 | // Stop implements Publisher, but is a no-op. | |
40 | func (p *StaticPublisher) Stop() {} | |
41 | ||
42 | // Replace replaces the endpoints and notifies all subscribers. | |
43 | func (p *StaticPublisher) Replace(endpoints []endpoint.Endpoint) { | |
44 | p.mu.Lock() | |
45 | defer p.mu.Unlock() | |
46 | p.current = endpoints | |
47 | for c := range p.subscribers { | |
48 | c <- p.current | |
49 | } | |
50 | } |
0 | package loadbalancer_test | |
1 | ||
2 | import ( | |
3 | "testing" | |
4 | ||
5 | "golang.org/x/net/context" | |
6 | ||
7 | "github.com/go-kit/kit/endpoint" | |
8 | "github.com/go-kit/kit/loadbalancer" | |
9 | ) | |
10 | ||
11 | func TestStaticPublisher(t *testing.T) { | |
12 | endpoints := []endpoint.Endpoint{ | |
13 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, | |
14 | } | |
15 | p := loadbalancer.NewStaticPublisher(endpoints) | |
16 | defer p.Stop() | |
17 | ||
18 | c := make(chan []endpoint.Endpoint, 1) | |
19 | p.Subscribe(c) | |
20 | if want, have := len(endpoints), len(<-c); want != have { | |
21 | t.Errorf("want %d, have %d", want, have) | |
22 | } | |
23 | ||
24 | endpoints = []endpoint.Endpoint{} | |
25 | p.Replace(endpoints) | |
26 | if want, have := len(endpoints), len(<-c); want != have { | |
27 | t.Errorf("want %d, have %d", want, have) | |
28 | } | |
29 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
8 | // Strategy yields endpoints to consumers according to some algorithm. | |
9 | type Strategy interface { | |
10 | Next() (endpoint.Endpoint, error) | |
11 | Stop() | |
12 | } | |
13 | ||
14 | // ErrNoEndpoints is returned by a strategy when there are no endpoints | |
15 | // available. | |
16 | var ErrNoEndpoints = errors.New("no endpoints available") |
0 | package loadbalancer_test | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "testing" | |
5 | "time" | |
6 | ||
7 | "github.com/go-kit/kit/loadbalancer" | |
8 | ) | |
9 | ||
10 | func assertLoadBalancerNotEmpty(t *testing.T, lb loadbalancer.LoadBalancer) { | |
11 | if err := within(10*time.Millisecond, func() bool { | |
12 | return lb.Count() > 0 | |
13 | }); err != nil { | |
14 | t.Fatal("Publisher never updated endpoints") | |
15 | } | |
16 | } | |
17 | ||
18 | func within(d time.Duration, f func() bool) error { | |
19 | var ( | |
20 | deadline = time.After(d) | |
21 | ticker = time.NewTicker(d / 10) | |
22 | ) | |
23 | defer ticker.Stop() | |
24 | for { | |
25 | select { | |
26 | case <-ticker.C: | |
27 | if f() { | |
28 | return nil | |
29 | } | |
30 | case <-deadline: | |
31 | return fmt.Errorf("deadline exceeded") | |
32 | } | |
33 | } | |
34 | } |