Codebase list golang-github-go-kit-kit / 6a3f952
loadbalancer: rm Peter Bourgon 7 years ago
31 changed file(s) with 0 addition(s) and 2685 deletion(s). Raw diff Collapse all Expand all
+0
-67
loadbalancer/README.md less more
0 # package loadbalancer
1
2 `package loadbalancer` provides a client-side load balancer abstraction.
3
4 A publisher is responsible for emitting the most recent set of endpoints for a
5 single logical service. Publishers exist for static endpoints, and endpoints
6 discovered via periodic DNS SRV lookups on a single logical name. Consul and
7 etcd publishers are planned.
8
9 Different load balancers are implemented on top of publishers. Go kit
10 currently provides random and round-robin load balancers. Smarter behaviors,
11 e.g. load balancing based on underlying endpoint priority/weight, is planned.
12
13 ## Rationale
14
15 TODO
16
17 ## Usage
18
19 In your client, construct a publisher for a specific remote service, and pass
20 it to a load balancer. Then, request an endpoint from the load balancer
21 whenever you need to make a request to that remote service.
22
23 ```go
24 import (
25 "github.com/go-kit/kit/loadbalancer"
26 "github.com/go-kit/kit/loadbalancer/dnssrv"
27 )
28
29 func main() {
30 // Construct a load balancer for foosvc, which gets foosvc instances by
31 // polling a specific DNS SRV name.
32 p, err := dnssrv.NewPublisher("foosvc.internal.domain", 5*time.Second, fooFactory, logger)
33 if err != nil {
34 panic(err)
35 }
36
37 lb := loadbalancer.NewRoundRobin(p)
38
39 // Get a new endpoint from the load balancer.
40 endpoint, err := lb.Endpoint()
41 if err != nil {
42 panic(err)
43 }
44
45 // Use the endpoint to make a request.
46 response, err := endpoint(ctx, request)
47 }
48
49 func fooFactory(instance string) (endpoint.Endpoint, error) {
50 // Convert an instance (host:port) to an endpoint, via a defined transport binding.
51 }
52 ```
53
54 It's also possible to wrap a load balancer with a retry strategy, so that it
55 can be used as an endpoint directly. This may make load balancers more
56 convenient to use, at the cost of fine-grained control of failures.
57
58 ```go
59 func main() {
60 p := dnssrv.NewPublisher("foosvc.internal.domain", 5*time.Second, fooFactory, logger)
61 lb := loadbalancer.NewRoundRobin(p)
62 endpoint := loadbalancer.Retry(3, 5*time.Seconds, lb)
63
64 response, err := endpoint(ctx, request) // requests will be automatically load balanced
65 }
66 ```
+0
-30
loadbalancer/consul/client.go less more
0 package consul
1
2 import consul "github.com/hashicorp/consul/api"
3
4 // Client is a wrapper around the Consul API.
5 type Client interface {
6 Service(service string, tag string, queryOpts *consul.QueryOptions) ([]*consul.ServiceEntry, *consul.QueryMeta, error)
7 }
8
9 type client struct {
10 consul *consul.Client
11 }
12
13 // NewClient returns an implementation of the Client interface expecting a fully
14 // setup Consul Client.
15 func NewClient(c *consul.Client) Client {
16 return &client{
17 consul: c,
18 }
19 }
20
21 // GetInstances returns the list of healthy entries for a given service filtered
22 // by tag.
23 func (c *client) Service(
24 service string,
25 tag string,
26 opts *consul.QueryOptions,
27 ) ([]*consul.ServiceEntry, *consul.QueryMeta, error) {
28 return c.consul.Health().Service(service, tag, true, opts)
29 }
+0
-174
loadbalancer/consul/publisher.go less more
0 package consul
1
2 import (
3 "fmt"
4 "strings"
5
6 consul "github.com/hashicorp/consul/api"
7
8 "github.com/go-kit/kit/endpoint"
9 "github.com/go-kit/kit/loadbalancer"
10 "github.com/go-kit/kit/log"
11 )
12
13 const defaultIndex = 0
14
15 // Publisher yields endpoints for a service in Consul. Updates to the service
16 // are watched and will update the Publisher endpoints.
17 type Publisher struct {
18 cache *loadbalancer.EndpointCache
19 client Client
20 logger log.Logger
21 service string
22 tags []string
23 endpointsc chan []endpoint.Endpoint
24 quitc chan struct{}
25 }
26
27 // NewPublisher returns a Consul publisher which returns Endpoints for the
28 // requested service. It only returns instances for which all of the passed
29 // tags are present.
30 func NewPublisher(
31 client Client,
32 factory loadbalancer.Factory,
33 logger log.Logger,
34 service string,
35 tags ...string,
36 ) (*Publisher, error) {
37 p := &Publisher{
38 cache: loadbalancer.NewEndpointCache(factory, logger),
39 client: client,
40 logger: logger,
41 service: service,
42 tags: tags,
43 quitc: make(chan struct{}),
44 }
45
46 instances, index, err := p.getInstances(defaultIndex)
47 if err == nil {
48 logger.Log("service", service, "tags", strings.Join(tags, ", "), "instances", len(instances))
49 } else {
50 logger.Log("service", service, "tags", strings.Join(tags, ", "), "err", err)
51 }
52 p.cache.Replace(instances)
53
54 go p.loop(index)
55
56 return p, nil
57 }
58
59 // Endpoints implements the Publisher interface.
60 func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) {
61 return p.cache.Endpoints()
62 }
63
64 // Stop terminates the publisher.
65 func (p *Publisher) Stop() {
66 close(p.quitc)
67 }
68
69 func (p *Publisher) loop(lastIndex uint64) {
70 var (
71 errc = make(chan error, 1)
72 resc = make(chan response, 1)
73 )
74
75 for {
76 go func() {
77 instances, index, err := p.getInstances(lastIndex)
78 if err != nil {
79 errc <- err
80 return
81 }
82 resc <- response{
83 index: index,
84 instances: instances,
85 }
86 }()
87
88 select {
89 case err := <-errc:
90 p.logger.Log("service", p.service, "err", err)
91 case res := <-resc:
92 p.cache.Replace(res.instances)
93 lastIndex = res.index
94 case <-p.quitc:
95 return
96 }
97 }
98 }
99
100 func (p *Publisher) getInstances(lastIndex uint64) ([]string, uint64, error) {
101 tag := ""
102
103 if len(p.tags) > 0 {
104 tag = p.tags[0]
105 }
106
107 entries, meta, err := p.client.Service(
108 p.service,
109 tag,
110 &consul.QueryOptions{
111 WaitIndex: lastIndex,
112 },
113 )
114 if err != nil {
115 return nil, 0, err
116 }
117
118 // If more than one tag is passed we need to filter it in the publisher until
119 // Consul supports multiple tags[0].
120 //
121 // [0] https://github.com/hashicorp/consul/issues/294
122 if len(p.tags) > 1 {
123 entries = filterEntries(entries, p.tags[1:]...)
124 }
125
126 return makeInstances(entries), meta.LastIndex, nil
127 }
128
129 // response is used as container to transport instances as well as the updated
130 // index.
131 type response struct {
132 index uint64
133 instances []string
134 }
135
136 func filterEntries(entries []*consul.ServiceEntry, tags ...string) []*consul.ServiceEntry {
137 var es []*consul.ServiceEntry
138
139 ENTRIES:
140 for _, entry := range entries {
141 ts := make(map[string]struct{}, len(entry.Service.Tags))
142
143 for _, tag := range entry.Service.Tags {
144 ts[tag] = struct{}{}
145 }
146
147 for _, tag := range tags {
148 if _, ok := ts[tag]; !ok {
149 continue ENTRIES
150 }
151 }
152
153 es = append(es, entry)
154 }
155
156 return es
157 }
158
159 func makeInstances(entries []*consul.ServiceEntry) []string {
160 instances := make([]string, len(entries))
161
162 for i, entry := range entries {
163 addr := entry.Node.Address
164
165 if entry.Service.Address != "" {
166 addr = entry.Service.Address
167 }
168
169 instances[i] = fmt.Sprintf("%s:%d", addr, entry.Service.Port)
170 }
171
172 return instances
173 }
+0
-207
loadbalancer/consul/publisher_test.go less more
0 package consul
1
2 import (
3 "io"
4 "testing"
5
6 consul "github.com/hashicorp/consul/api"
7 "golang.org/x/net/context"
8
9 "github.com/go-kit/kit/endpoint"
10 "github.com/go-kit/kit/log"
11 )
12
13 var consulState = []*consul.ServiceEntry{
14 {
15 Node: &consul.Node{
16 Address: "10.0.0.0",
17 Node: "app00.local",
18 },
19 Service: &consul.AgentService{
20 ID: "search-api-0",
21 Port: 8000,
22 Service: "search",
23 Tags: []string{
24 "api",
25 "v1",
26 },
27 },
28 },
29 {
30 Node: &consul.Node{
31 Address: "10.0.0.1",
32 Node: "app01.local",
33 },
34 Service: &consul.AgentService{
35 ID: "search-api-1",
36 Port: 8001,
37 Service: "search",
38 Tags: []string{
39 "api",
40 "v2",
41 },
42 },
43 },
44 {
45 Node: &consul.Node{
46 Address: "10.0.0.1",
47 Node: "app01.local",
48 },
49 Service: &consul.AgentService{
50 Address: "10.0.0.10",
51 ID: "search-db-0",
52 Port: 9000,
53 Service: "search",
54 Tags: []string{
55 "db",
56 },
57 },
58 },
59 }
60
61 func TestPublisher(t *testing.T) {
62 var (
63 logger = log.NewNopLogger()
64 client = newTestClient(consulState)
65 )
66
67 p, err := NewPublisher(client, testFactory, logger, "search", "api")
68 if err != nil {
69 t.Fatalf("publisher setup failed: %s", err)
70 }
71 defer p.Stop()
72
73 eps, err := p.Endpoints()
74 if err != nil {
75 t.Fatalf("endpoints failed: %s", err)
76 }
77
78 if have, want := len(eps), 2; have != want {
79 t.Errorf("have %v, want %v", have, want)
80 }
81 }
82
83 func TestPublisherNoService(t *testing.T) {
84 var (
85 logger = log.NewNopLogger()
86 client = newTestClient(consulState)
87 )
88
89 p, err := NewPublisher(client, testFactory, logger, "feed")
90 if err != nil {
91 t.Fatalf("publisher setup failed: %s", err)
92 }
93 defer p.Stop()
94
95 eps, err := p.Endpoints()
96 if err != nil {
97 t.Fatalf("endpoints failed: %s", err)
98 }
99
100 if have, want := len(eps), 0; have != want {
101 t.Fatalf("have %v, want %v", have, want)
102 }
103 }
104
105 func TestPublisherWithTags(t *testing.T) {
106 var (
107 logger = log.NewNopLogger()
108 client = newTestClient(consulState)
109 )
110
111 p, err := NewPublisher(client, testFactory, logger, "search", "api", "v2")
112 if err != nil {
113 t.Fatalf("publisher setup failed: %s", err)
114 }
115 defer p.Stop()
116
117 eps, err := p.Endpoints()
118 if err != nil {
119 t.Fatalf("endpoints failed: %s", err)
120 }
121
122 if have, want := len(eps), 1; have != want {
123 t.Fatalf("have %v, want %v", have, want)
124 }
125 }
126
127 func TestPublisherAddressOverride(t *testing.T) {
128 var (
129 ctx = context.Background()
130 logger = log.NewNopLogger()
131 client = newTestClient(consulState)
132 )
133
134 p, err := NewPublisher(client, testFactory, logger, "search", "db")
135 if err != nil {
136 t.Fatalf("publisher setup failed: %s", err)
137 }
138 defer p.Stop()
139
140 eps, err := p.Endpoints()
141 if err != nil {
142 t.Fatalf("endpoints failed: %s", err)
143 }
144
145 if have, want := len(eps), 1; have != want {
146 t.Fatalf("have %v, want %v", have, want)
147 }
148
149 ins, err := eps[0](ctx, struct{}{})
150 if err != nil {
151 t.Fatal(err)
152 }
153
154 if have, want := ins.(string), "10.0.0.10:9000"; have != want {
155 t.Errorf("have %#v, want %#v", have, want)
156 }
157 }
158
159 type testClient struct {
160 entries []*consul.ServiceEntry
161 }
162
163 func newTestClient(entries []*consul.ServiceEntry) Client {
164 if entries == nil {
165 entries = []*consul.ServiceEntry{}
166 }
167
168 return &testClient{
169 entries: entries,
170 }
171 }
172
173 func (c *testClient) Service(
174 service string,
175 tag string,
176 opts *consul.QueryOptions,
177 ) ([]*consul.ServiceEntry, *consul.QueryMeta, error) {
178 es := []*consul.ServiceEntry{}
179
180 for _, e := range c.entries {
181 if e.Service.Service != service {
182 continue
183 }
184 if tag != "" {
185 tagMap := map[string]struct{}{}
186
187 for _, t := range e.Service.Tags {
188 tagMap[t] = struct{}{}
189 }
190
191 if _, ok := tagMap[tag]; !ok {
192 continue
193 }
194 }
195
196 es = append(es, e)
197 }
198
199 return es, &consul.QueryMeta{}, nil
200 }
201
202 func testFactory(ins string) (endpoint.Endpoint, io.Closer, error) {
203 return func(context.Context, interface{}) (interface{}, error) {
204 return ins, nil
205 }, nil, nil
206 }
+0
-106
loadbalancer/dnssrv/publisher.go less more
0 package dnssrv
1
2 import (
3 "fmt"
4 "net"
5 "time"
6
7 "github.com/go-kit/kit/endpoint"
8 "github.com/go-kit/kit/loadbalancer"
9 "github.com/go-kit/kit/log"
10 )
11
12 // Publisher yields endpoints taken from the named DNS SRV record. The name is
13 // resolved on a fixed schedule. Priorities and weights are ignored.
14 type Publisher struct {
15 name string
16 cache *loadbalancer.EndpointCache
17 logger log.Logger
18 quit chan struct{}
19 }
20
21 // NewPublisher returns a DNS SRV publisher. The name is resolved
22 // synchronously as part of construction; if that resolution fails, the
23 // constructor will return an error. The factory is used to convert a
24 // host:port to a usable endpoint. The logger is used to report DNS and
25 // factory errors.
26 func NewPublisher(
27 name string,
28 ttl time.Duration,
29 factory loadbalancer.Factory,
30 logger log.Logger,
31 ) *Publisher {
32 return NewPublisherDetailed(name, time.NewTicker(ttl), net.LookupSRV, factory, logger)
33 }
34
35 // NewPublisherDetailed is the same as NewPublisher, but allows users to provide
36 // an explicit lookup refresh ticker instead of a TTL, and specify the function
37 // used to perform lookups instead of using net.LookupSRV.
38 func NewPublisherDetailed(
39 name string,
40 refreshTicker *time.Ticker,
41 lookupSRV func(service, proto, name string) (cname string, addrs []*net.SRV, err error),
42 factory loadbalancer.Factory,
43 logger log.Logger,
44 ) *Publisher {
45 p := &Publisher{
46 name: name,
47 cache: loadbalancer.NewEndpointCache(factory, logger),
48 logger: logger,
49 quit: make(chan struct{}),
50 }
51
52 instances, err := p.resolve(lookupSRV)
53 if err == nil {
54 logger.Log("name", name, "instances", len(instances))
55 } else {
56 logger.Log("name", name, "err", err)
57 }
58 p.cache.Replace(instances)
59
60 go p.loop(refreshTicker, lookupSRV)
61 return p
62 }
63
64 // Stop terminates the publisher.
65 func (p *Publisher) Stop() {
66 close(p.quit)
67 }
68
69 func (p *Publisher) loop(
70 refreshTicker *time.Ticker,
71 lookupSRV func(service, proto, name string) (cname string, addrs []*net.SRV, err error),
72 ) {
73 defer refreshTicker.Stop()
74 for {
75 select {
76 case <-refreshTicker.C:
77 instances, err := p.resolve(lookupSRV)
78 if err != nil {
79 p.logger.Log(p.name, err)
80 continue // don't replace potentially-good with bad
81 }
82 p.cache.Replace(instances)
83
84 case <-p.quit:
85 return
86 }
87 }
88 }
89
90 // Endpoints implements the Publisher interface.
91 func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) {
92 return p.cache.Endpoints()
93 }
94
95 func (p *Publisher) resolve(lookupSRV func(service, proto, name string) (cname string, addrs []*net.SRV, err error)) ([]string, error) {
96 _, addrs, err := lookupSRV("", "", p.name)
97 if err != nil {
98 return []string{}, err
99 }
100 instances := make([]string, len(addrs))
101 for i, addr := range addrs {
102 instances[i] = net.JoinHostPort(addr.Target, fmt.Sprint(addr.Port))
103 }
104 return instances, nil
105 }
+0
-133
loadbalancer/dnssrv/publisher_test.go less more
0 package dnssrv
1
2 import (
3 "errors"
4 "io"
5 "net"
6 "sync/atomic"
7 "testing"
8 "time"
9
10 "golang.org/x/net/context"
11
12 "github.com/go-kit/kit/endpoint"
13 "github.com/go-kit/kit/log"
14 )
15
16 func TestPublisher(t *testing.T) {
17 var (
18 name = "foo"
19 ttl = time.Second
20 e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
21 factory = func(string) (endpoint.Endpoint, io.Closer, error) { return e, nil, nil }
22 logger = log.NewNopLogger()
23 )
24
25 p := NewPublisher(name, ttl, factory, logger)
26 defer p.Stop()
27
28 if _, err := p.Endpoints(); err != nil {
29 t.Fatal(err)
30 }
31 }
32
33 func TestBadLookup(t *testing.T) {
34 var (
35 name = "some-name"
36 ticker = time.NewTicker(time.Second)
37 lookups = uint32(0)
38 lookupSRV = func(string, string, string) (string, []*net.SRV, error) {
39 atomic.AddUint32(&lookups, 1)
40 return "", nil, errors.New("kaboom")
41 }
42 e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
43 factory = func(string) (endpoint.Endpoint, io.Closer, error) { return e, nil, nil }
44 logger = log.NewNopLogger()
45 )
46
47 p := NewPublisherDetailed(name, ticker, lookupSRV, factory, logger)
48 defer p.Stop()
49
50 endpoints, err := p.Endpoints()
51 if err != nil {
52 t.Error(err)
53 }
54 if want, have := 0, len(endpoints); want != have {
55 t.Errorf("want %d, have %d", want, have)
56 }
57 if want, have := uint32(1), atomic.LoadUint32(&lookups); want != have {
58 t.Errorf("want %d, have %d", want, have)
59 }
60 }
61
62 func TestBadFactory(t *testing.T) {
63 var (
64 name = "some-name"
65 ticker = time.NewTicker(time.Second)
66 addr = &net.SRV{Target: "foo", Port: 1234}
67 addrs = []*net.SRV{addr}
68 lookupSRV = func(a, b, c string) (string, []*net.SRV, error) { return "", addrs, nil }
69 creates = uint32(0)
70 factory = func(s string) (endpoint.Endpoint, io.Closer, error) {
71 atomic.AddUint32(&creates, 1)
72 return nil, nil, errors.New("kaboom")
73 }
74 logger = log.NewNopLogger()
75 )
76
77 p := NewPublisherDetailed(name, ticker, lookupSRV, factory, logger)
78 defer p.Stop()
79
80 endpoints, err := p.Endpoints()
81 if err != nil {
82 t.Error(err)
83 }
84 if want, have := 0, len(endpoints); want != have {
85 t.Errorf("want %q, have %q", want, have)
86 }
87 if want, have := uint32(1), atomic.LoadUint32(&creates); want != have {
88 t.Errorf("want %d, have %d", want, have)
89 }
90 }
91
92 func TestRefreshWithChange(t *testing.T) {
93 t.Skip("TODO")
94 }
95
96 func TestRefreshNoChange(t *testing.T) {
97 var (
98 addr = &net.SRV{Target: "my-target", Port: 5678}
99 addrs = []*net.SRV{addr}
100 name = "my-name"
101 ticker = time.NewTicker(time.Second)
102 lookups = uint32(0)
103 lookupSRV = func(string, string, string) (string, []*net.SRV, error) {
104 atomic.AddUint32(&lookups, 1)
105 return "", addrs, nil
106 }
107 e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
108 factory = func(string) (endpoint.Endpoint, io.Closer, error) { return e, nil, nil }
109 logger = log.NewNopLogger()
110 )
111
112 ticker.Stop()
113 tickc := make(chan time.Time)
114 ticker.C = tickc
115
116 p := NewPublisherDetailed(name, ticker, lookupSRV, factory, logger)
117 defer p.Stop()
118
119 if want, have := uint32(1), atomic.LoadUint32(&lookups); want != have {
120 t.Errorf("want %d, have %d", want, have)
121 }
122
123 tickc <- time.Now()
124
125 if want, have := uint32(2), atomic.LoadUint32(&lookups); want != have {
126 t.Errorf("want %d, have %d", want, have)
127 }
128 }
129
130 func TestRefreshResolveError(t *testing.T) {
131 t.Skip("TODO")
132 }
+0
-112
loadbalancer/endpoint_cache.go less more
0 package loadbalancer
1
2 import (
3 "io"
4 "sort"
5 "sync"
6 "sync/atomic"
7
8 "github.com/go-kit/kit/endpoint"
9 "github.com/go-kit/kit/log"
10 )
11
12 // EndpointCache caches endpoints that need to be deallocated when they're no
13 // longer useful. Clients update the cache by providing a current set of
14 // instance strings. The cache converts each instance string to an endpoint
15 // and a closer via the factory function.
16 //
17 // Instance strings are assumed to be unique and are used as keys. Endpoints
18 // that were in the previous set of instances and are not in the current set
19 // are considered invalid and closed.
20 //
21 // EndpointCache is designed to be used in your publisher implementation.
22 type EndpointCache struct {
23 mtx sync.Mutex
24 f Factory
25 m map[string]endpointCloser
26 cache atomic.Value //[]endpoint.Endpoint
27 logger log.Logger
28 }
29
30 // NewEndpointCache produces a new EndpointCache, ready for use. Instance
31 // strings will be converted to endpoints via the provided factory function.
32 // The logger is used to log errors.
33 func NewEndpointCache(f Factory, logger log.Logger) *EndpointCache {
34 endpointCache := &EndpointCache{
35 f: f,
36 m: map[string]endpointCloser{},
37 logger: log.NewContext(logger).With("component", "Endpoint Cache"),
38 }
39
40 endpointCache.cache.Store(make([]endpoint.Endpoint, 0))
41
42 return endpointCache
43 }
44
45 type endpointCloser struct {
46 endpoint.Endpoint
47 io.Closer
48 }
49
50 // Replace replaces the current set of endpoints with endpoints manufactured
51 // by the passed instances. If the same instance exists in both the existing
52 // and new sets, it's left untouched.
53 func (t *EndpointCache) Replace(instances []string) {
54 t.mtx.Lock()
55 defer t.mtx.Unlock()
56
57 // Produce the current set of endpoints.
58 oldMap := t.m
59 t.m = make(map[string]endpointCloser, len(instances))
60 for _, instance := range instances {
61 // If it already exists, just copy it over.
62 if ec, ok := oldMap[instance]; ok {
63 t.m[instance] = ec
64 delete(oldMap, instance)
65 continue
66 }
67
68 // If it doesn't exist, create it.
69 endpoint, closer, err := t.f(instance)
70 if err != nil {
71 t.logger.Log("instance", instance, "err", err)
72 continue
73 }
74 t.m[instance] = endpointCloser{endpoint, closer}
75 }
76
77 t.refreshCache()
78
79 // Close any leftover endpoints.
80 for _, ec := range oldMap {
81 if ec.Closer != nil {
82 ec.Closer.Close()
83 }
84 }
85 }
86
87 func (t *EndpointCache) refreshCache() {
88 var (
89 length = len(t.m)
90 instances = make([]string, 0, length)
91 newCache = make([]endpoint.Endpoint, 0, length)
92 )
93
94 for instance, _ := range t.m {
95 instances = append(instances, instance)
96 }
97 // Sort the instances for ensuring that Endpoints are returned into the same order if no modified.
98 sort.Strings(instances)
99
100 for _, instance := range instances {
101 newCache = append(newCache, t.m[instance].Endpoint)
102 }
103
104 t.cache.Store(newCache)
105 }
106
107 // Endpoints returns the current set of endpoints in undefined order. Satisfies
108 // Publisher interface.
109 func (t *EndpointCache) Endpoints() ([]endpoint.Endpoint, error) {
110 return t.cache.Load().([]endpoint.Endpoint), nil
111 }
+0
-92
loadbalancer/endpoint_cache_test.go less more
0 package loadbalancer_test
1
2 import (
3 "io"
4 "testing"
5 "time"
6
7 "golang.org/x/net/context"
8
9 "github.com/go-kit/kit/endpoint"
10 "github.com/go-kit/kit/loadbalancer"
11 "github.com/go-kit/kit/log"
12 )
13
14 func TestEndpointCache(t *testing.T) {
15 var (
16 e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
17 ca = make(closer)
18 cb = make(closer)
19 c = map[string]io.Closer{"a": ca, "b": cb}
20 f = func(s string) (endpoint.Endpoint, io.Closer, error) { return e, c[s], nil }
21 ec = loadbalancer.NewEndpointCache(f, log.NewNopLogger())
22 )
23
24 // Populate
25 ec.Replace([]string{"a", "b"})
26 select {
27 case <-ca:
28 t.Errorf("endpoint a closed, not good")
29 case <-cb:
30 t.Errorf("endpoint b closed, not good")
31 case <-time.After(time.Millisecond):
32 t.Logf("no closures yet, good")
33 }
34
35 // Duplicate, should be no-op
36 ec.Replace([]string{"a", "b"})
37 select {
38 case <-ca:
39 t.Errorf("endpoint a closed, not good")
40 case <-cb:
41 t.Errorf("endpoint b closed, not good")
42 case <-time.After(time.Millisecond):
43 t.Logf("no closures yet, good")
44 }
45
46 // Delete b
47 go ec.Replace([]string{"a"})
48 select {
49 case <-ca:
50 t.Errorf("endpoint a closed, not good")
51 case <-cb:
52 t.Logf("endpoint b closed, good")
53 case <-time.After(time.Millisecond):
54 t.Errorf("didn't close the deleted instance in time")
55 }
56
57 // Delete a
58 go ec.Replace([]string{""})
59 select {
60 // case <-cb: will succeed, as it's closed
61 case <-ca:
62 t.Logf("endpoint a closed, good")
63 case <-time.After(time.Millisecond):
64 t.Errorf("didn't close the deleted instance in time")
65 }
66 }
67
68 type closer chan struct{}
69
70 func (c closer) Close() error { close(c); return nil }
71
72 func BenchmarkEndpoints(b *testing.B) {
73 var (
74 e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
75 ca = make(closer)
76 cb = make(closer)
77 c = map[string]io.Closer{"a": ca, "b": cb}
78 f = func(s string) (endpoint.Endpoint, io.Closer, error) { return e, c[s], nil }
79 ec = loadbalancer.NewEndpointCache(f, log.NewNopLogger())
80 )
81
82 b.ReportAllocs()
83
84 ec.Replace([]string{"a", "b"})
85
86 b.RunParallel(func(pb *testing.PB) {
87 for pb.Next() {
88 ec.Endpoints()
89 }
90 })
91 }
+0
-131
loadbalancer/etcd/client.go less more
0 package etcd
1
2 import (
3 "crypto/tls"
4 "crypto/x509"
5 "io/ioutil"
6 "net"
7 "net/http"
8 "time"
9
10 etcd "github.com/coreos/etcd/client"
11 "golang.org/x/net/context"
12 )
13
14 // Client is a wrapper around the etcd client.
15 type Client interface {
16 // GetEntries will query the given prefix in etcd and returns a set of entries.
17 GetEntries(prefix string) ([]string, error)
18 // WatchPrefix starts watching every change for given prefix in etcd. When an
19 // change is detected it will populate the responseChan when an *etcd.Response.
20 WatchPrefix(prefix string, responseChan chan *etcd.Response)
21 }
22
23 type client struct {
24 keysAPI etcd.KeysAPI
25 ctx context.Context
26 }
27
28 type ClientOptions struct {
29 Cert string
30 Key string
31 CaCert string
32 DialTimeout time.Duration
33 DialKeepAline time.Duration
34 HeaderTimeoutPerRequest time.Duration
35 }
36
37 // NewClient returns an *etcd.Client with a connection to the named machines.
38 // It will return an error if a connection to the cluster cannot be made.
39 // The parameter machines needs to be a full URL with schemas.
40 // e.g. "http://localhost:2379" will work, but "localhost:2379" will not.
41 func NewClient(ctx context.Context, machines []string, options *ClientOptions) (Client, error) {
42 var (
43 c etcd.KeysAPI
44 err error
45 caCertCt []byte
46 tlsCert tls.Certificate
47 )
48 if options == nil {
49 options = &ClientOptions{}
50 }
51
52 if options.Cert != "" && options.Key != "" {
53 tlsCert, err = tls.LoadX509KeyPair(options.Cert, options.Key)
54 if err != nil {
55 return nil, err
56 }
57
58 caCertCt, err = ioutil.ReadFile(options.CaCert)
59 if err != nil {
60 return nil, err
61 }
62 caCertPool := x509.NewCertPool()
63 caCertPool.AppendCertsFromPEM(caCertCt)
64
65 tlsConfig := &tls.Config{
66 Certificates: []tls.Certificate{tlsCert},
67 RootCAs: caCertPool,
68 }
69
70 transport := &http.Transport{
71 TLSClientConfig: tlsConfig,
72 Dial: func(network, addr string) (net.Conn, error) {
73 dial := &net.Dialer{
74 Timeout: options.DialTimeout,
75 KeepAlive: options.DialKeepAline,
76 }
77 return dial.Dial(network, addr)
78 },
79 }
80
81 cfg := etcd.Config{
82 Endpoints: machines,
83 Transport: transport,
84 HeaderTimeoutPerRequest: options.HeaderTimeoutPerRequest,
85 }
86 ce, err := etcd.New(cfg)
87 if err != nil {
88 return nil, err
89 }
90 c = etcd.NewKeysAPI(ce)
91 } else {
92 cfg := etcd.Config{
93 Endpoints: machines,
94 Transport: etcd.DefaultTransport,
95 HeaderTimeoutPerRequest: options.HeaderTimeoutPerRequest,
96 }
97 ce, err := etcd.New(cfg)
98 if err != nil {
99 return nil, err
100 }
101 c = etcd.NewKeysAPI(ce)
102 }
103 return &client{c, ctx}, nil
104 }
105
106 // GetEntries implements the etcd Client interface.
107 func (c *client) GetEntries(key string) ([]string, error) {
108 resp, err := c.keysAPI.Get(c.ctx, key, &etcd.GetOptions{Recursive: true})
109 if err != nil {
110 return nil, err
111 }
112
113 entries := make([]string, len(resp.Node.Nodes))
114 for i, node := range resp.Node.Nodes {
115 entries[i] = node.Value
116 }
117 return entries, nil
118 }
119
120 // WatchPrefix implements the etcd Client interface.
121 func (c *client) WatchPrefix(prefix string, responseChan chan *etcd.Response) {
122 watch := c.keysAPI.Watcher(prefix, &etcd.WatcherOptions{AfterIndex: 0, Recursive: true})
123 for {
124 res, err := watch.Next(c.ctx)
125 if err != nil {
126 return
127 }
128 responseChan <- res
129 }
130 }
+0
-71
loadbalancer/etcd/publisher.go less more
0 package etcd
1
2 import (
3 etcd "github.com/coreos/etcd/client"
4
5 "github.com/go-kit/kit/endpoint"
6 "github.com/go-kit/kit/loadbalancer"
7 "github.com/go-kit/kit/log"
8 )
9
10 // Publisher yield endpoints stored in a certain etcd keyspace. Any kind of
11 // change in that keyspace is watched and will update the Publisher endpoints.
12 type Publisher struct {
13 client Client
14 prefix string
15 cache *loadbalancer.EndpointCache
16 logger log.Logger
17 quit chan struct{}
18 }
19
20 // NewPublisher returs a etcd publisher. Etcd will start watching the given
21 // prefix for changes and update the Publisher endpoints.
22 func NewPublisher(c Client, prefix string, f loadbalancer.Factory, logger log.Logger) (*Publisher, error) {
23 p := &Publisher{
24 client: c,
25 prefix: prefix,
26 cache: loadbalancer.NewEndpointCache(f, logger),
27 logger: logger,
28 quit: make(chan struct{}),
29 }
30
31 instances, err := p.client.GetEntries(p.prefix)
32 if err == nil {
33 logger.Log("prefix", p.prefix, "instances", len(instances))
34 } else {
35 logger.Log("prefix", p.prefix, "err", err)
36 }
37 p.cache.Replace(instances)
38
39 go p.loop()
40 return p, nil
41 }
42
43 func (p *Publisher) loop() {
44 responseChan := make(chan *etcd.Response)
45 go p.client.WatchPrefix(p.prefix, responseChan)
46 for {
47 select {
48 case <-responseChan:
49 instances, err := p.client.GetEntries(p.prefix)
50 if err != nil {
51 p.logger.Log("msg", "failed to retrieve entries", "err", err)
52 continue
53 }
54 p.cache.Replace(instances)
55
56 case <-p.quit:
57 return
58 }
59 }
60 }
61
62 // Endpoints implements the Publisher interface.
63 func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) {
64 return p.cache.Endpoints()
65 }
66
67 // Stop terminates the Publisher.
68 func (p *Publisher) Stop() {
69 close(p.quit)
70 }
+0
-98
loadbalancer/etcd/publisher_test.go less more
0 package etcd_test
1
2 import (
3 "errors"
4 "io"
5 "testing"
6
7 stdetcd "github.com/coreos/etcd/client"
8 "golang.org/x/net/context"
9
10 "github.com/go-kit/kit/endpoint"
11 kitetcd "github.com/go-kit/kit/loadbalancer/etcd"
12 "github.com/go-kit/kit/log"
13 )
14
15 var (
16 node = &stdetcd.Node{
17 Key: "/foo",
18 Nodes: []*stdetcd.Node{
19 {Key: "/foo/1", Value: "1:1"},
20 {Key: "/foo/2", Value: "1:2"},
21 },
22 }
23 fakeResponse = &stdetcd.Response{
24 Node: node,
25 }
26 )
27
28 func TestPublisher(t *testing.T) {
29 var (
30 logger = log.NewNopLogger()
31 e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
32 )
33
34 factory := func(string) (endpoint.Endpoint, io.Closer, error) {
35 return e, nil, nil
36 }
37
38 client := &fakeClient{
39 responses: map[string]*stdetcd.Response{"/foo": fakeResponse},
40 }
41
42 p, err := kitetcd.NewPublisher(client, "/foo", factory, logger)
43 if err != nil {
44 t.Fatalf("failed to create new publisher: %v", err)
45 }
46 defer p.Stop()
47
48 if _, err := p.Endpoints(); err != nil {
49 t.Fatal(err)
50 }
51 }
52
53 func TestBadFactory(t *testing.T) {
54 logger := log.NewNopLogger()
55
56 factory := func(string) (endpoint.Endpoint, io.Closer, error) {
57 return nil, nil, errors.New("kaboom")
58 }
59
60 client := &fakeClient{
61 responses: map[string]*stdetcd.Response{"/foo": fakeResponse},
62 }
63
64 p, err := kitetcd.NewPublisher(client, "/foo", factory, logger)
65 if err != nil {
66 t.Fatalf("failed to create new publisher: %v", err)
67 }
68 defer p.Stop()
69
70 endpoints, err := p.Endpoints()
71 if err != nil {
72 t.Fatal(err)
73 }
74
75 if want, have := 0, len(endpoints); want != have {
76 t.Errorf("want %q, have %q", want, have)
77 }
78 }
79
80 type fakeClient struct {
81 responses map[string]*stdetcd.Response
82 }
83
84 func (c *fakeClient) GetEntries(prefix string) ([]string, error) {
85 response, ok := c.responses[prefix]
86 if !ok {
87 return nil, errors.New("key not exist")
88 }
89
90 entries := make([]string, len(response.Node.Nodes))
91 for i, node := range response.Node.Nodes {
92 entries[i] = node.Value
93 }
94 return entries, nil
95 }
96
97 func (c *fakeClient) WatchPrefix(prefix string, responseChan chan *stdetcd.Response) {}
+0
-15
loadbalancer/factory.go less more
0 package loadbalancer
1
2 import (
3 "io"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // Factory is a function that converts an instance string, e.g. a host:port,
9 // to a usable endpoint. Factories are used by load balancers to convert
10 // instances returned by Publishers (typically host:port strings) into
11 // endpoints. Users are expected to provide their own factory functions that
12 // assume specific transports, or can deduce transports by parsing the
13 // instance string.
14 type Factory func(instance string) (endpoint.Endpoint, io.Closer, error)
+0
-35
loadbalancer/fixed/publisher.go less more
0 package fixed
1
2 import (
3 "sync"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // Publisher yields the same set of fixed endpoints.
9 type Publisher struct {
10 mtx sync.RWMutex
11 endpoints []endpoint.Endpoint
12 }
13
14 // NewPublisher returns a fixed endpoint Publisher.
15 func NewPublisher(endpoints []endpoint.Endpoint) *Publisher {
16 return &Publisher{
17 endpoints: endpoints,
18 }
19 }
20
21 // Endpoints implements the Publisher interface.
22 func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) {
23 p.mtx.RLock()
24 defer p.mtx.RUnlock()
25 return p.endpoints, nil
26 }
27
28 // Replace is a utility method to swap out the underlying endpoints of an
29 // existing fixed publisher. It's useful mostly for testing.
30 func (p *Publisher) Replace(endpoints []endpoint.Endpoint) {
31 p.mtx.Lock()
32 defer p.mtx.Unlock()
33 p.endpoints = endpoints
34 }
+0
-48
loadbalancer/fixed/publisher_test.go less more
0 package fixed_test
1
2 import (
3 "reflect"
4 "testing"
5
6 "golang.org/x/net/context"
7
8 "github.com/go-kit/kit/endpoint"
9 "github.com/go-kit/kit/loadbalancer/fixed"
10 )
11
12 func TestFixed(t *testing.T) {
13 var (
14 e1 = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
15 e2 = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
16 endpoints = []endpoint.Endpoint{e1, e2}
17 )
18 p := fixed.NewPublisher(endpoints)
19 have, err := p.Endpoints()
20 if err != nil {
21 t.Fatal(err)
22 }
23 if want := endpoints; !reflect.DeepEqual(want, have) {
24 t.Fatalf("want %#+v, have %#+v", want, have)
25 }
26 }
27
28 func TestFixedReplace(t *testing.T) {
29 p := fixed.NewPublisher([]endpoint.Endpoint{
30 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
31 })
32 have, err := p.Endpoints()
33 if err != nil {
34 t.Fatal(err)
35 }
36 if want, have := 1, len(have); want != have {
37 t.Fatalf("want %d, have %d", want, have)
38 }
39 p.Replace([]endpoint.Endpoint{})
40 have, err = p.Endpoints()
41 if err != nil {
42 t.Fatal(err)
43 }
44 if want, have := 0, len(have); want != have {
45 t.Fatalf("want %d, have %d", want, have)
46 }
47 }
+0
-18
loadbalancer/loadbalancer.go less more
0 package loadbalancer
1
2 import (
3 "errors"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // LoadBalancer describes something that can yield endpoints for a remote
9 // service method.
10 type LoadBalancer interface {
11 Endpoint() (endpoint.Endpoint, error)
12 }
13
14 // ErrNoEndpoints is returned when a load balancer (or one of its components)
15 // has no endpoints to return. In a request lifecycle, this is usually a fatal
16 // error.
17 var ErrNoEndpoints = errors.New("no endpoints available")
+0
-10
loadbalancer/publisher.go less more
0 package loadbalancer
1
2 import "github.com/go-kit/kit/endpoint"
3
4 // Publisher describes something that provides a set of identical endpoints.
5 // Different publisher implementations exist for different kinds of service
6 // discovery systems.
7 type Publisher interface {
8 Endpoints() ([]endpoint.Endpoint, error)
9 }
+0
-34
loadbalancer/random.go less more
0 package loadbalancer
1
2 import (
3 "math/rand"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // Random is a completely stateless load balancer that chooses a random
9 // endpoint to return each time.
10 type Random struct {
11 p Publisher
12 r *rand.Rand
13 }
14
15 // NewRandom returns a new Random load balancer.
16 func NewRandom(p Publisher, seed int64) *Random {
17 return &Random{
18 p: p,
19 r: rand.New(rand.NewSource(seed)),
20 }
21 }
22
23 // Endpoint implements the LoadBalancer interface.
24 func (r *Random) Endpoint() (endpoint.Endpoint, error) {
25 endpoints, err := r.p.Endpoints()
26 if err != nil {
27 return nil, err
28 }
29 if len(endpoints) <= 0 {
30 return nil, ErrNoEndpoints
31 }
32 return endpoints[r.r.Intn(len(endpoints))], nil
33 }
+0
-60
loadbalancer/random_test.go less more
0 package loadbalancer_test
1
2 import (
3 "math"
4 "testing"
5
6 "golang.org/x/net/context"
7
8 "github.com/go-kit/kit/endpoint"
9 "github.com/go-kit/kit/loadbalancer"
10 "github.com/go-kit/kit/loadbalancer/fixed"
11 )
12
13 func TestRandomDistribution(t *testing.T) {
14 var (
15 n = 3
16 endpoints = make([]endpoint.Endpoint, n)
17 counts = make([]int, n)
18 seed = int64(123)
19 ctx = context.Background()
20 iterations = 100000
21 want = iterations / n
22 tolerance = want / 100 // 1%
23 )
24
25 for i := 0; i < n; i++ {
26 i0 := i
27 endpoints[i] = func(context.Context, interface{}) (interface{}, error) { counts[i0]++; return struct{}{}, nil }
28 }
29
30 lb := loadbalancer.NewRandom(fixed.NewPublisher(endpoints), seed)
31
32 for i := 0; i < iterations; i++ {
33 e, err := lb.Endpoint()
34 if err != nil {
35 t.Fatal(err)
36 }
37 if _, err := e(ctx, struct{}{}); err != nil {
38 t.Error(err)
39 }
40 }
41
42 for i, have := range counts {
43 if math.Abs(float64(want-have)) > float64(tolerance) {
44 t.Errorf("%d: want %d, have %d", i, want, have)
45 }
46 }
47 }
48
49 func TestRandomBadPublisher(t *testing.T) {
50 t.Skip("TODO")
51 }
52
53 func TestRandomNoEndpoints(t *testing.T) {
54 lb := loadbalancer.NewRandom(fixed.NewPublisher([]endpoint.Endpoint{}), 123)
55 _, have := lb.Endpoint()
56 if want := loadbalancer.ErrNoEndpoints; want != have {
57 t.Errorf("want %q, have %q", want, have)
58 }
59 }
+0
-57
loadbalancer/retry.go less more
0 package loadbalancer
1
2 import (
3 "fmt"
4 "strings"
5 "time"
6
7 "golang.org/x/net/context"
8
9 "github.com/go-kit/kit/endpoint"
10 )
11
12 // Retry wraps the load balancer to make it behave like a simple endpoint.
13 // Requests to the endpoint will be automatically load balanced via the load
14 // balancer. Requests that return errors will be retried until they succeed,
15 // up to max times, or until the timeout is elapsed, whichever comes first.
16 func Retry(max int, timeout time.Duration, lb LoadBalancer) endpoint.Endpoint {
17 if lb == nil {
18 panic("nil LoadBalancer")
19 }
20
21 return func(ctx context.Context, request interface{}) (interface{}, error) {
22 var (
23 newctx, cancel = context.WithTimeout(ctx, timeout)
24 responses = make(chan interface{}, 1)
25 errs = make(chan error, 1)
26 a = []string{}
27 )
28 defer cancel()
29 for i := 1; i <= max; i++ {
30 go func() {
31 e, err := lb.Endpoint()
32 if err != nil {
33 errs <- err
34 return
35 }
36 response, err := e(newctx, request)
37 if err != nil {
38 errs <- err
39 return
40 }
41 responses <- response
42 }()
43
44 select {
45 case <-newctx.Done():
46 return nil, newctx.Err()
47 case response := <-responses:
48 return response, nil
49 case err := <-errs:
50 a = append(a, err.Error())
51 continue
52 }
53 }
54 return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; "))
55 }
56 }
+0
-83
loadbalancer/retry_test.go less more
0 package loadbalancer_test
1
2 import (
3 "errors"
4 "testing"
5 "time"
6
7 "golang.org/x/net/context"
8
9 "github.com/go-kit/kit/endpoint"
10 "github.com/go-kit/kit/loadbalancer"
11 "github.com/go-kit/kit/loadbalancer/fixed"
12 )
13
14 func TestRetryMaxTotalFail(t *testing.T) {
15 var (
16 endpoints = []endpoint.Endpoint{} // no endpoints
17 p = fixed.NewPublisher(endpoints)
18 lb = loadbalancer.NewRoundRobin(p)
19 retry = loadbalancer.Retry(999, time.Second, lb) // lots of retries
20 ctx = context.Background()
21 )
22 if _, err := retry(ctx, struct{}{}); err == nil {
23 t.Errorf("expected error, got none") // should fail
24 }
25 }
26
27 func TestRetryMaxPartialFail(t *testing.T) {
28 var (
29 endpoints = []endpoint.Endpoint{
30 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") },
31 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") },
32 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
33 }
34 retries = len(endpoints) - 1 // not quite enough retries
35 p = fixed.NewPublisher(endpoints)
36 lb = loadbalancer.NewRoundRobin(p)
37 ctx = context.Background()
38 )
39 if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err == nil {
40 t.Errorf("expected error, got none")
41 }
42 }
43
44 func TestRetryMaxSuccess(t *testing.T) {
45 var (
46 endpoints = []endpoint.Endpoint{
47 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") },
48 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") },
49 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
50 }
51 retries = len(endpoints) // exactly enough retries
52 p = fixed.NewPublisher(endpoints)
53 lb = loadbalancer.NewRoundRobin(p)
54 ctx = context.Background()
55 )
56 if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err != nil {
57 t.Error(err)
58 }
59 }
60
61 func TestRetryTimeout(t *testing.T) {
62 var (
63 step = make(chan struct{})
64 e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil }
65 timeout = time.Millisecond
66 retry = loadbalancer.Retry(999, timeout, loadbalancer.NewRoundRobin(fixed.NewPublisher([]endpoint.Endpoint{e})))
67 errs = make(chan error, 1)
68 invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err }
69 )
70
71 go func() { step <- struct{}{} }() // queue up a flush of the endpoint
72 invoke() // invoke the endpoint and trigger the flush
73 if err := <-errs; err != nil { // that should succeed
74 t.Error(err)
75 }
76
77 go func() { time.Sleep(10 * timeout); step <- struct{}{} }() // a delayed flush
78 invoke() // invoke the endpoint
79 if err := <-errs; err != context.DeadlineExceeded { // that should not succeed
80 t.Errorf("wanted %v, got none", context.DeadlineExceeded)
81 }
82 }
+0
-41
loadbalancer/round_robin.go less more
0 package loadbalancer
1
2 import (
3 "sync/atomic"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // RoundRobin is a simple load balancer that returns each of the published
9 // endpoints in sequence.
10 type RoundRobin struct {
11 p Publisher
12 counter uint64
13 }
14
15 // NewRoundRobin returns a new RoundRobin load balancer.
16 func NewRoundRobin(p Publisher) *RoundRobin {
17 return &RoundRobin{
18 p: p,
19 counter: 0,
20 }
21 }
22
23 // Endpoint implements the LoadBalancer interface.
24 func (rr *RoundRobin) Endpoint() (endpoint.Endpoint, error) {
25 endpoints, err := rr.p.Endpoints()
26 if err != nil {
27 return nil, err
28 }
29 if len(endpoints) <= 0 {
30 return nil, ErrNoEndpoints
31 }
32 var old uint64
33 for {
34 old = atomic.LoadUint64(&rr.counter)
35 if atomic.CompareAndSwapUint64(&rr.counter, old, old+1) {
36 break
37 }
38 }
39 return endpoints[old%uint64(len(endpoints))], nil
40 }
+0
-51
loadbalancer/round_robin_test.go less more
0 package loadbalancer_test
1
2 import (
3 "reflect"
4 "testing"
5
6 "github.com/go-kit/kit/endpoint"
7 "github.com/go-kit/kit/loadbalancer"
8 "github.com/go-kit/kit/loadbalancer/fixed"
9 "golang.org/x/net/context"
10 )
11
12 func TestRoundRobinDistribution(t *testing.T) {
13 var (
14 ctx = context.Background()
15 counts = []int{0, 0, 0}
16 endpoints = []endpoint.Endpoint{
17 func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil },
18 func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil },
19 func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil },
20 }
21 )
22
23 lb := loadbalancer.NewRoundRobin(fixed.NewPublisher(endpoints))
24
25 for i, want := range [][]int{
26 {1, 0, 0},
27 {1, 1, 0},
28 {1, 1, 1},
29 {2, 1, 1},
30 {2, 2, 1},
31 {2, 2, 2},
32 {3, 2, 2},
33 } {
34 e, err := lb.Endpoint()
35 if err != nil {
36 t.Fatal(err)
37 }
38 if _, err := e(ctx, struct{}{}); err != nil {
39 t.Error(err)
40 }
41 if have := counts; !reflect.DeepEqual(want, have) {
42 t.Fatalf("%d: want %v, have %v", i, want, have)
43 }
44
45 }
46 }
47
48 func TestRoundRobinBadPublisher(t *testing.T) {
49 t.Skip("TODO")
50 }
+0
-31
loadbalancer/static/publisher.go less more
0 package static
1
2 import (
3 "github.com/go-kit/kit/endpoint"
4 "github.com/go-kit/kit/loadbalancer"
5 "github.com/go-kit/kit/loadbalancer/fixed"
6 "github.com/go-kit/kit/log"
7 )
8
9 // Publisher yields a set of static endpoints as produced by the passed factory.
10 type Publisher struct{ publisher *fixed.Publisher }
11
12 // NewPublisher returns a static endpoint Publisher.
13 func NewPublisher(instances []string, factory loadbalancer.Factory, logger log.Logger) Publisher {
14 logger = log.NewContext(logger).With("component", "Static Publisher")
15 endpoints := []endpoint.Endpoint{}
16 for _, instance := range instances {
17 e, _, err := factory(instance) // never close
18 if err != nil {
19 logger.Log("instance", instance, "err", err)
20 continue
21 }
22 endpoints = append(endpoints, e)
23 }
24 return Publisher{publisher: fixed.NewPublisher(endpoints)}
25 }
26
27 // Endpoints implements Publisher.
28 func (p Publisher) Endpoints() ([]endpoint.Endpoint, error) {
29 return p.publisher.Endpoints()
30 }
+0
-39
loadbalancer/static/publisher_test.go less more
0 package static_test
1
2 import (
3 "fmt"
4 "io"
5 "testing"
6
7 "golang.org/x/net/context"
8
9 "github.com/go-kit/kit/endpoint"
10 "github.com/go-kit/kit/loadbalancer/static"
11 "github.com/go-kit/kit/log"
12 )
13
14 func TestStatic(t *testing.T) {
15 var (
16 instances = []string{"foo", "bar", "baz"}
17 endpoints = map[string]endpoint.Endpoint{
18 "foo": func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
19 "bar": func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
20 "baz": func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
21 }
22 factory = func(instance string) (endpoint.Endpoint, io.Closer, error) {
23 if e, ok := endpoints[instance]; ok {
24 return e, nil, nil
25 }
26 return nil, nil, fmt.Errorf("%s: not found", instance)
27 }
28 )
29 p := static.NewPublisher(instances, factory, log.NewNopLogger())
30 have, err := p.Endpoints()
31 if err != nil {
32 t.Fatal(err)
33 }
34 want := []endpoint.Endpoint{endpoints["foo"], endpoints["bar"], endpoints["baz"]}
35 if fmt.Sprint(want) != fmt.Sprint(have) {
36 t.Fatalf("want %v, have %v", want, have)
37 }
38 }
+0
-231
loadbalancer/zk/client.go less more
0 package zk
1
2 import (
3 "errors"
4 "net"
5 "strings"
6 "time"
7
8 "github.com/samuel/go-zookeeper/zk"
9
10 "github.com/go-kit/kit/log"
11 )
12
13 // DefaultACL is the default ACL to use for creating znodes.
14 var (
15 DefaultACL = zk.WorldACL(zk.PermAll)
16 ErrInvalidCredentials = errors.New("invalid credentials provided")
17 ErrClientClosed = errors.New("client service closed")
18 )
19
20 const (
21 // DefaultConnectTimeout is the default timeout to establish a connection to
22 // a ZooKeeper node.
23 DefaultConnectTimeout = 2 * time.Second
24 // DefaultSessionTimeout is the default timeout to keep the current
25 // ZooKeeper session alive during a temporary disconnect.
26 DefaultSessionTimeout = 5 * time.Second
27 )
28
29 // Client is a wrapper around a lower level ZooKeeper client implementation.
30 type Client interface {
31 // GetEntries should query the provided path in ZooKeeper, place a watch on
32 // it and retrieve data from its current child nodes.
33 GetEntries(path string) ([]string, <-chan zk.Event, error)
34 // CreateParentNodes should try to create the path in case it does not exist
35 // yet on ZooKeeper.
36 CreateParentNodes(path string) error
37 // Stop should properly shutdown the client implementation
38 Stop()
39 }
40
41 type clientConfig struct {
42 logger log.Logger
43 acl []zk.ACL
44 credentials []byte
45 connectTimeout time.Duration
46 sessionTimeout time.Duration
47 rootNodePayload [][]byte
48 eventHandler func(zk.Event)
49 }
50
51 // Option functions enable friendly APIs.
52 type Option func(*clientConfig) error
53
54 type client struct {
55 *zk.Conn
56 clientConfig
57 active bool
58 quit chan struct{}
59 }
60
61 // ACL returns an Option specifying a non-default ACL for creating parent nodes.
62 func ACL(acl []zk.ACL) Option {
63 return func(c *clientConfig) error {
64 c.acl = acl
65 return nil
66 }
67 }
68
69 // Credentials returns an Option specifying a user/password combination which
70 // the client will use to authenticate itself with.
71 func Credentials(user, pass string) Option {
72 return func(c *clientConfig) error {
73 if user == "" || pass == "" {
74 return ErrInvalidCredentials
75 }
76 c.credentials = []byte(user + ":" + pass)
77 return nil
78 }
79 }
80
81 // ConnectTimeout returns an Option specifying a non-default connection timeout
82 // when we try to establish a connection to a ZooKeeper server.
83 func ConnectTimeout(t time.Duration) Option {
84 return func(c *clientConfig) error {
85 if t.Seconds() < 1 {
86 return errors.New("invalid connect timeout (minimum value is 1 second)")
87 }
88 c.connectTimeout = t
89 return nil
90 }
91 }
92
93 // SessionTimeout returns an Option specifying a non-default session timeout.
94 func SessionTimeout(t time.Duration) Option {
95 return func(c *clientConfig) error {
96 if t.Seconds() < 1 {
97 return errors.New("invalid session timeout (minimum value is 1 second)")
98 }
99 c.sessionTimeout = t
100 return nil
101 }
102 }
103
104 // Payload returns an Option specifying non-default data values for each znode
105 // created by CreateParentNodes.
106 func Payload(payload [][]byte) Option {
107 return func(c *clientConfig) error {
108 c.rootNodePayload = payload
109 return nil
110 }
111 }
112
113 // EventHandler returns an Option specifying a callback function to handle
114 // incoming zk.Event payloads (ZooKeeper connection events).
115 func EventHandler(handler func(zk.Event)) Option {
116 return func(c *clientConfig) error {
117 c.eventHandler = handler
118 return nil
119 }
120 }
121
122 // NewClient returns a ZooKeeper client with a connection to the server cluster.
123 // It will return an error if the server cluster cannot be resolved.
124 func NewClient(servers []string, logger log.Logger, options ...Option) (Client, error) {
125 defaultEventHandler := func(event zk.Event) {
126 logger.Log("eventtype", event.Type.String(), "server", event.Server, "state", event.State.String(), "err", event.Err)
127 }
128 config := clientConfig{
129 acl: DefaultACL,
130 connectTimeout: DefaultConnectTimeout,
131 sessionTimeout: DefaultSessionTimeout,
132 eventHandler: defaultEventHandler,
133 logger: logger,
134 }
135 for _, option := range options {
136 if err := option(&config); err != nil {
137 return nil, err
138 }
139 }
140 // dialer overrides the default ZooKeeper library Dialer so we can configure
141 // the connectTimeout. The current library has a hardcoded value of 1 second
142 // and there are reports of race conditions, due to slow DNS resolvers and
143 // other network latency issues.
144 dialer := func(network, address string, _ time.Duration) (net.Conn, error) {
145 return net.DialTimeout(network, address, config.connectTimeout)
146 }
147 conn, eventc, err := zk.Connect(servers, config.sessionTimeout, withLogger(logger), zk.WithDialer(dialer))
148
149 if err != nil {
150 return nil, err
151 }
152
153 if len(config.credentials) > 0 {
154 err = conn.AddAuth("digest", config.credentials)
155 if err != nil {
156 return nil, err
157 }
158 }
159
160 c := &client{conn, config, true, make(chan struct{})}
161
162 // Start listening for incoming Event payloads and callback the set
163 // eventHandler.
164 go func() {
165 for {
166 select {
167 case event := <-eventc:
168 config.eventHandler(event)
169 case <-c.quit:
170 return
171 }
172 }
173 }()
174 return c, nil
175 }
176
177 // CreateParentNodes implements the ZooKeeper Client interface.
178 func (c *client) CreateParentNodes(path string) error {
179 if !c.active {
180 return ErrClientClosed
181 }
182 if path[0] != '/' {
183 return zk.ErrInvalidPath
184 }
185 payload := []byte("")
186 pathString := ""
187 pathNodes := strings.Split(path, "/")
188 for i := 1; i < len(pathNodes); i++ {
189 if i <= len(c.rootNodePayload) {
190 payload = c.rootNodePayload[i-1]
191 } else {
192 payload = []byte("")
193 }
194 pathString += "/" + pathNodes[i]
195 _, err := c.Create(pathString, payload, 0, c.acl)
196 // not being able to create the node because it exists or not having
197 // sufficient rights is not an issue. It is ok for the node to already
198 // exist and/or us to only have read rights
199 if err != nil && err != zk.ErrNodeExists && err != zk.ErrNoAuth {
200 return err
201 }
202 }
203 return nil
204 }
205
206 // GetEntries implements the ZooKeeper Client interface.
207 func (c *client) GetEntries(path string) ([]string, <-chan zk.Event, error) {
208 // retrieve list of child nodes for given path and add watch to path
209 znodes, _, eventc, err := c.ChildrenW(path)
210
211 if err != nil {
212 return nil, eventc, err
213 }
214
215 var resp []string
216 for _, znode := range znodes {
217 // retrieve payload for child znode and add to response array
218 if data, _, err := c.Get(path + "/" + znode); err == nil {
219 resp = append(resp, string(data))
220 }
221 }
222 return resp, eventc, nil
223 }
224
225 // Stop implements the ZooKeeper Client interface.
226 func (c *client) Stop() {
227 c.active = false
228 close(c.quit)
229 c.Close()
230 }
+0
-157
loadbalancer/zk/client_test.go less more
0 package zk
1
2 import (
3 "bytes"
4 "testing"
5 "time"
6
7 stdzk "github.com/samuel/go-zookeeper/zk"
8
9 "github.com/go-kit/kit/log"
10 )
11
12 func TestNewClient(t *testing.T) {
13 var (
14 acl = stdzk.WorldACL(stdzk.PermRead)
15 connectTimeout = 3 * time.Second
16 sessionTimeout = 20 * time.Second
17 payload = [][]byte{[]byte("Payload"), []byte("Test")}
18 )
19
20 c, err := NewClient(
21 []string{"FailThisInvalidHost!!!"},
22 log.NewNopLogger(),
23 )
24 if err == nil {
25 t.Errorf("expected error, got nil")
26 }
27
28 hasFired := false
29 calledEventHandler := make(chan struct{})
30 eventHandler := func(event stdzk.Event) {
31 if !hasFired {
32 // test is successful if this function has fired at least once
33 hasFired = true
34 close(calledEventHandler)
35 }
36 }
37
38 c, err = NewClient(
39 []string{"localhost"},
40 log.NewNopLogger(),
41 ACL(acl),
42 ConnectTimeout(connectTimeout),
43 SessionTimeout(sessionTimeout),
44 Payload(payload),
45 EventHandler(eventHandler),
46 )
47 if err != nil {
48 t.Fatal(err)
49 }
50 defer c.Stop()
51
52 clientImpl, ok := c.(*client)
53 if !ok {
54 t.Fatal("retrieved incorrect Client implementation")
55 }
56 if want, have := acl, clientImpl.acl; want[0] != have[0] {
57 t.Errorf("want %+v, have %+v", want, have)
58 }
59 if want, have := connectTimeout, clientImpl.connectTimeout; want != have {
60 t.Errorf("want %d, have %d", want, have)
61 }
62 if want, have := sessionTimeout, clientImpl.sessionTimeout; want != have {
63 t.Errorf("want %d, have %d", want, have)
64 }
65 if want, have := payload, clientImpl.rootNodePayload; bytes.Compare(want[0], have[0]) != 0 || bytes.Compare(want[1], have[1]) != 0 {
66 t.Errorf("want %s, have %s", want, have)
67 }
68
69 select {
70 case <-calledEventHandler:
71 case <-time.After(100 * time.Millisecond):
72 t.Errorf("event handler never called")
73 }
74 }
75
76 func TestOptions(t *testing.T) {
77 _, err := NewClient([]string{"localhost"}, log.NewNopLogger(), Credentials("valid", "credentials"))
78 if err != nil && err != stdzk.ErrNoServer {
79 t.Errorf("unexpected error: %v", err)
80 }
81
82 _, err = NewClient([]string{"localhost"}, log.NewNopLogger(), Credentials("nopass", ""))
83 if want, have := err, ErrInvalidCredentials; want != have {
84 t.Errorf("want %v, have %v", want, have)
85 }
86
87 _, err = NewClient([]string{"localhost"}, log.NewNopLogger(), ConnectTimeout(0))
88 if err == nil {
89 t.Errorf("expected connect timeout error")
90 }
91
92 _, err = NewClient([]string{"localhost"}, log.NewNopLogger(), SessionTimeout(0))
93 if err == nil {
94 t.Errorf("expected connect timeout error")
95 }
96 }
97
98 func TestCreateParentNodes(t *testing.T) {
99 payload := [][]byte{[]byte("Payload"), []byte("Test")}
100
101 c, err := NewClient([]string{"localhost:65500"}, log.NewNopLogger())
102 if err != nil {
103 t.Errorf("unexpected error: %v", err)
104 }
105 if c == nil {
106 t.Fatal("expected new Client, got nil")
107 }
108
109 p, err := NewPublisher(c, "/validpath", newFactory(""), log.NewNopLogger())
110 if err != stdzk.ErrNoServer {
111 t.Errorf("unexpected error: %v", err)
112 }
113 if p != nil {
114 t.Error("expected failed new Publisher")
115 }
116
117 p, err = NewPublisher(c, "invalidpath", newFactory(""), log.NewNopLogger())
118 if err != stdzk.ErrInvalidPath {
119 t.Errorf("unexpected error: %v", err)
120 }
121 _, _, err = c.GetEntries("/validpath")
122 if err != stdzk.ErrNoServer {
123 t.Errorf("unexpected error: %v", err)
124 }
125
126 c.Stop()
127
128 err = c.CreateParentNodes("/validpath")
129 if err != ErrClientClosed {
130 t.Errorf("unexpected error: %v", err)
131 }
132
133 p, err = NewPublisher(c, "/validpath", newFactory(""), log.NewNopLogger())
134 if err != ErrClientClosed {
135 t.Errorf("unexpected error: %v", err)
136 }
137 if p != nil {
138 t.Error("expected failed new Publisher")
139 }
140
141 c, err = NewClient([]string{"localhost:65500"}, log.NewNopLogger(), Payload(payload))
142 if err != nil {
143 t.Errorf("unexpected error: %v", err)
144 }
145 if c == nil {
146 t.Fatal("expected new Client, got nil")
147 }
148
149 p, err = NewPublisher(c, "/validpath", newFactory(""), log.NewNopLogger())
150 if err != stdzk.ErrNoServer {
151 t.Errorf("unexpected error: %v", err)
152 }
153 if p != nil {
154 t.Error("expected failed new Publisher")
155 }
156 }
+0
-201
loadbalancer/zk/integration_test.go less more
0 // +build integration
1
2 package zk
3
4 import (
5 "bytes"
6 "flag"
7 "fmt"
8 "os"
9 "testing"
10 "time"
11
12 stdzk "github.com/samuel/go-zookeeper/zk"
13 )
14
15 var (
16 host []string
17 )
18
19 func TestMain(m *testing.M) {
20 flag.Parse()
21
22 fmt.Println("Starting ZooKeeper server...")
23
24 ts, err := stdzk.StartTestCluster(1, nil, nil)
25 if err != nil {
26 fmt.Printf("ZooKeeper server error: %v\n", err)
27 os.Exit(1)
28 }
29
30 host = []string{fmt.Sprintf("localhost:%d", ts.Servers[0].Port)}
31 code := m.Run()
32
33 ts.Stop()
34 os.Exit(code)
35 }
36
37 func TestCreateParentNodesOnServer(t *testing.T) {
38 payload := [][]byte{[]byte("Payload"), []byte("Test")}
39 c1, err := NewClient(host, logger, Payload(payload))
40 if err != nil {
41 t.Fatalf("Connect returned error: %v", err)
42 }
43 if c1 == nil {
44 t.Fatal("Expected pointer to client, got nil")
45 }
46 defer c1.Stop()
47
48 p, err := NewPublisher(c1, path, newFactory(""), logger)
49 if err != nil {
50 t.Fatalf("Unable to create Publisher: %v", err)
51 }
52 defer p.Stop()
53
54 endpoints, err := p.Endpoints()
55 if err != nil {
56 t.Fatal(err)
57 }
58 if want, have := 0, len(endpoints); want != have {
59 t.Errorf("want %d, have %d", want, have)
60 }
61
62 c2, err := NewClient(host, logger)
63 if err != nil {
64 t.Fatalf("Connect returned error: %v", err)
65 }
66 defer c2.Stop()
67 data, _, err := c2.(*client).Get(path)
68 if err != nil {
69 t.Fatal(err)
70 }
71 // test Client implementation of CreateParentNodes. It should have created
72 // our payload
73 if bytes.Compare(data, payload[1]) != 0 {
74 t.Errorf("want %s, have %s", payload[1], data)
75 }
76
77 }
78
79 func TestCreateBadParentNodesOnServer(t *testing.T) {
80 c, _ := NewClient(host, logger)
81 defer c.Stop()
82
83 _, err := NewPublisher(c, "invalid/path", newFactory(""), logger)
84
85 if want, have := stdzk.ErrInvalidPath, err; want != have {
86 t.Errorf("want %v, have %v", want, have)
87 }
88 }
89
90 func TestCredentials1(t *testing.T) {
91 acl := stdzk.DigestACL(stdzk.PermAll, "user", "secret")
92 c, _ := NewClient(host, logger, ACL(acl), Credentials("user", "secret"))
93 defer c.Stop()
94
95 _, err := NewPublisher(c, "/acl-issue-test", newFactory(""), logger)
96
97 if err != nil {
98 t.Fatal(err)
99 }
100 }
101
102 func TestCredentials2(t *testing.T) {
103 acl := stdzk.DigestACL(stdzk.PermAll, "user", "secret")
104 c, _ := NewClient(host, logger, ACL(acl))
105 defer c.Stop()
106
107 _, err := NewPublisher(c, "/acl-issue-test", newFactory(""), logger)
108
109 if err != stdzk.ErrNoAuth {
110 t.Errorf("want %v, have %v", stdzk.ErrNoAuth, err)
111 }
112 }
113
114 func TestConnection(t *testing.T) {
115 c, _ := NewClient(host, logger)
116 c.Stop()
117
118 _, err := NewPublisher(c, "/acl-issue-test", newFactory(""), logger)
119
120 if err != ErrClientClosed {
121 t.Errorf("want %v, have %v", ErrClientClosed, err)
122 }
123 }
124
125 func TestGetEntriesOnServer(t *testing.T) {
126 var instancePayload = "protocol://hostname:port/routing"
127
128 c1, err := NewClient(host, logger)
129 if err != nil {
130 t.Fatalf("Connect returned error: %v", err)
131 }
132
133 defer c1.Stop()
134
135 c2, err := NewClient(host, logger)
136 p, err := NewPublisher(c2, path, newFactory(""), logger)
137 if err != nil {
138 t.Fatal(err)
139 }
140 defer c2.Stop()
141
142 c2impl, _ := c2.(*client)
143 _, err = c2impl.Create(
144 path+"/instance1",
145 []byte(instancePayload),
146 stdzk.FlagEphemeral|stdzk.FlagSequence,
147 stdzk.WorldACL(stdzk.PermAll),
148 )
149 if err != nil {
150 t.Fatalf("Unable to create test ephemeral znode 1: %v", err)
151 }
152 _, err = c2impl.Create(
153 path+"/instance2",
154 []byte(instancePayload+"2"),
155 stdzk.FlagEphemeral|stdzk.FlagSequence,
156 stdzk.WorldACL(stdzk.PermAll),
157 )
158 if err != nil {
159 t.Fatalf("Unable to create test ephemeral znode 2: %v", err)
160 }
161
162 time.Sleep(50 * time.Millisecond)
163
164 endpoints, err := p.Endpoints()
165 if err != nil {
166 t.Fatal(err)
167 }
168 if want, have := 2, len(endpoints); want != have {
169 t.Errorf("want %d, have %d", want, have)
170 }
171 }
172
173 func TestGetEntriesPayloadOnServer(t *testing.T) {
174 c, err := NewClient(host, logger)
175 if err != nil {
176 t.Fatalf("Connect returned error: %v", err)
177 }
178 _, eventc, err := c.GetEntries(path)
179 if err != nil {
180 t.Fatal(err)
181 }
182 _, err = c.(*client).Create(
183 path+"/instance3",
184 []byte("just some payload"),
185 stdzk.FlagEphemeral|stdzk.FlagSequence,
186 stdzk.WorldACL(stdzk.PermAll),
187 )
188 if err != nil {
189 t.Fatalf("Unable to create test ephemeral znode: %v", err)
190 }
191 select {
192 case event := <-eventc:
193 if want, have := stdzk.EventNodeChildrenChanged.String(), event.Type.String(); want != have {
194 t.Errorf("want %s, have %s", want, have)
195 }
196 case <-time.After(20 * time.Millisecond):
197 t.Errorf("expected incoming watch event, timeout occurred")
198 }
199
200 }
+0
-27
loadbalancer/zk/logwrapper.go less more
0 package zk
1
2 import (
3 "fmt"
4
5 "github.com/samuel/go-zookeeper/zk"
6
7 "github.com/go-kit/kit/log"
8 )
9
10 // wrapLogger wraps a go-kit logger so we can use it as the logging service for
11 // the ZooKeeper library (which expects a Printf method to be available)
12 type wrapLogger struct {
13 log.Logger
14 }
15
16 func (logger wrapLogger) Printf(str string, vars ...interface{}) {
17 logger.Log("msg", fmt.Sprintf(str, vars...))
18 }
19
20 // withLogger replaces the ZooKeeper library's default logging service for our
21 // own go-kit logger
22 func withLogger(logger log.Logger) func(c *zk.Conn) {
23 return func(c *zk.Conn) {
24 c.SetLogger(wrapLogger{logger})
25 }
26 }
+0
-83
loadbalancer/zk/publisher.go less more
0 package zk
1
2 import (
3 "github.com/go-kit/kit/endpoint"
4 "github.com/go-kit/kit/loadbalancer"
5 "github.com/go-kit/kit/log"
6 "github.com/samuel/go-zookeeper/zk"
7 )
8
9 // Publisher yield endpoints stored in a certain ZooKeeper path. Any kind of
10 // change in that path is watched and will update the Publisher endpoints.
11 type Publisher struct {
12 client Client
13 path string
14 cache *loadbalancer.EndpointCache
15 logger log.Logger
16 quit chan struct{}
17 }
18
19 // NewPublisher returns a ZooKeeper publisher. ZooKeeper will start watching the
20 // given path for changes and update the Publisher endpoints.
21 func NewPublisher(c Client, path string, f loadbalancer.Factory, logger log.Logger) (*Publisher, error) {
22 p := &Publisher{
23 client: c,
24 path: path,
25 cache: loadbalancer.NewEndpointCache(f, logger),
26 logger: logger,
27 quit: make(chan struct{}),
28 }
29
30 err := p.client.CreateParentNodes(p.path)
31 if err != nil {
32 return nil, err
33 }
34
35 // initial node retrieval and cache fill
36 instances, eventc, err := p.client.GetEntries(p.path)
37 if err != nil {
38 logger.Log("path", p.path, "msg", "failed to retrieve entries", "err", err)
39 return nil, err
40 }
41 logger.Log("path", p.path, "instances", len(instances))
42 p.cache.Replace(instances)
43
44 // handle incoming path updates
45 go p.loop(eventc)
46
47 return p, nil
48 }
49
50 func (p *Publisher) loop(eventc <-chan zk.Event) {
51 var (
52 instances []string
53 err error
54 )
55 for {
56 select {
57 case <-eventc:
58 // we received a path update notification, call GetEntries to
59 // retrieve child node data and set new watch as zk watches are one
60 // time triggers
61 instances, eventc, err = p.client.GetEntries(p.path)
62 if err != nil {
63 p.logger.Log("path", p.path, "msg", "failed to retrieve entries", "err", err)
64 continue
65 }
66 p.logger.Log("path", p.path, "instances", len(instances))
67 p.cache.Replace(instances)
68 case <-p.quit:
69 return
70 }
71 }
72 }
73
74 // Endpoints implements the Publisher interface.
75 func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) {
76 return p.cache.Endpoints()
77 }
78
79 // Stop terminates the Publisher.
80 func (p *Publisher) Stop() {
81 close(p.quit)
82 }
+0
-116
loadbalancer/zk/publisher_test.go less more
0 package zk
1
2 import (
3 "testing"
4 "time"
5 )
6
7 func TestPublisher(t *testing.T) {
8 client := newFakeClient()
9
10 p, err := NewPublisher(client, path, newFactory(""), logger)
11 if err != nil {
12 t.Fatalf("failed to create new publisher: %v", err)
13 }
14 defer p.Stop()
15
16 if _, err := p.Endpoints(); err != nil {
17 t.Fatal(err)
18 }
19 }
20
21 func TestBadFactory(t *testing.T) {
22 client := newFakeClient()
23
24 p, err := NewPublisher(client, path, newFactory("kaboom"), logger)
25 if err != nil {
26 t.Fatalf("failed to create new publisher: %v", err)
27 }
28 defer p.Stop()
29
30 // instance1 came online
31 client.AddService(path+"/instance1", "kaboom")
32
33 // instance2 came online
34 client.AddService(path+"/instance2", "zookeeper_node_data")
35
36 if err = asyncTest(100*time.Millisecond, 1, p); err != nil {
37 t.Error(err)
38 }
39 }
40
41 func TestServiceUpdate(t *testing.T) {
42 client := newFakeClient()
43
44 p, err := NewPublisher(client, path, newFactory(""), logger)
45 if err != nil {
46 t.Fatalf("failed to create new publisher: %v", err)
47 }
48 defer p.Stop()
49
50 endpoints, err := p.Endpoints()
51 if err != nil {
52 t.Fatal(err)
53 }
54
55 if want, have := 0, len(endpoints); want != have {
56 t.Errorf("want %d, have %d", want, have)
57 }
58
59 // instance1 came online
60 client.AddService(path+"/instance1", "zookeeper_node_data")
61
62 // instance2 came online
63 client.AddService(path+"/instance2", "zookeeper_node_data2")
64
65 // we should have 2 instances
66 if err = asyncTest(100*time.Millisecond, 2, p); err != nil {
67 t.Error(err)
68 }
69
70 // watch triggers an error...
71 client.SendErrorOnWatch()
72
73 // test if error was consumed
74 if err = client.ErrorIsConsumed(100 * time.Millisecond); err != nil {
75 t.Error(err)
76 }
77
78 // instance3 came online
79 client.AddService(path+"/instance3", "zookeeper_node_data3")
80
81 // we should have 3 instances
82 if err = asyncTest(100*time.Millisecond, 3, p); err != nil {
83 t.Error(err)
84 }
85
86 // instance1 goes offline
87 client.RemoveService(path + "/instance1")
88
89 // instance2 goes offline
90 client.RemoveService(path + "/instance2")
91
92 // we should have 1 instance
93 if err = asyncTest(100*time.Millisecond, 1, p); err != nil {
94 t.Error(err)
95 }
96 }
97
98 func TestBadPublisherCreate(t *testing.T) {
99 client := newFakeClient()
100 client.SendErrorOnWatch()
101 p, err := NewPublisher(client, path, newFactory(""), logger)
102 if err == nil {
103 t.Error("expected error on new publisher")
104 }
105 if p != nil {
106 t.Error("expected publisher not to be created")
107 }
108 p, err = NewPublisher(client, "BadPath", newFactory(""), logger)
109 if err == nil {
110 t.Error("expected error on new publisher")
111 }
112 if p != nil {
113 t.Error("expected publisher not to be created")
114 }
115 }
+0
-127
loadbalancer/zk/util_test.go less more
0 package zk
1
2 import (
3 "errors"
4 "fmt"
5 "io"
6 "sync"
7 "time"
8
9 "github.com/samuel/go-zookeeper/zk"
10 "golang.org/x/net/context"
11
12 "github.com/go-kit/kit/endpoint"
13 "github.com/go-kit/kit/loadbalancer"
14 "github.com/go-kit/kit/log"
15 )
16
17 var (
18 path = "/gokit.test/service.name"
19 e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
20 logger = log.NewNopLogger()
21 )
22
23 type fakeClient struct {
24 mtx sync.Mutex
25 ch chan zk.Event
26 responses map[string]string
27 result bool
28 }
29
30 func newFakeClient() *fakeClient {
31 return &fakeClient{
32 ch: make(chan zk.Event, 5),
33 responses: make(map[string]string),
34 result: true,
35 }
36 }
37
38 func (c *fakeClient) CreateParentNodes(path string) error {
39 if path == "BadPath" {
40 return errors.New("Dummy Error")
41 }
42 return nil
43 }
44
45 func (c *fakeClient) GetEntries(path string) ([]string, <-chan zk.Event, error) {
46 c.mtx.Lock()
47 defer c.mtx.Unlock()
48 if c.result == false {
49 c.result = true
50 return []string{}, c.ch, errors.New("Dummy Error")
51 }
52 responses := []string{}
53 for _, data := range c.responses {
54 responses = append(responses, data)
55 }
56 return responses, c.ch, nil
57 }
58
59 func (c *fakeClient) AddService(node, data string) {
60 c.mtx.Lock()
61 defer c.mtx.Unlock()
62 c.responses[node] = data
63 c.ch <- zk.Event{}
64 }
65
66 func (c *fakeClient) RemoveService(node string) {
67 c.mtx.Lock()
68 defer c.mtx.Unlock()
69 delete(c.responses, node)
70 c.ch <- zk.Event{}
71 }
72
73 func (c *fakeClient) SendErrorOnWatch() {
74 c.mtx.Lock()
75 defer c.mtx.Unlock()
76 c.result = false
77 c.ch <- zk.Event{}
78 }
79
80 func (c *fakeClient) ErrorIsConsumed(t time.Duration) error {
81 timeout := time.After(t)
82 for {
83 select {
84 case <-timeout:
85 return fmt.Errorf("expected error not consumed after timeout %s", t.String())
86 default:
87 c.mtx.Lock()
88 if c.result == false {
89 c.mtx.Unlock()
90 return nil
91 }
92 c.mtx.Unlock()
93 }
94 }
95 }
96
97 func (c *fakeClient) Stop() {}
98
99 func newFactory(fakeError string) loadbalancer.Factory {
100 return func(instance string) (endpoint.Endpoint, io.Closer, error) {
101 if fakeError == instance {
102 return nil, nil, errors.New(fakeError)
103 }
104 return e, nil, nil
105 }
106 }
107
108 func asyncTest(timeout time.Duration, want int, p *Publisher) (err error) {
109 var endpoints []endpoint.Endpoint
110 // want can never be -1
111 have := -1
112 t := time.After(timeout)
113 for {
114 select {
115 case <-t:
116 return fmt.Errorf("want %d, have %d after timeout %s", want, have, timeout.String())
117 default:
118 endpoints, err = p.Endpoints()
119 have = len(endpoints)
120 if err != nil || want == have {
121 return
122 }
123 time.Sleep(time.Millisecond)
124 }
125 }
126 }