loadbalancer: rm
Peter Bourgon
7 years ago
0 | # package loadbalancer | |
1 | ||
2 | `package loadbalancer` provides a client-side load balancer abstraction. | |
3 | ||
4 | A publisher is responsible for emitting the most recent set of endpoints for a | |
5 | single logical service. Publishers exist for static endpoints, and endpoints | |
6 | discovered via periodic DNS SRV lookups on a single logical name, as well as | |
7 | instances registered in Consul and etcd. | |
8 | ||
9 | Different load balancers are implemented on top of publishers. Go kit | |
10 | currently provides random and round-robin load balancers. Smarter behaviors, | |
11 | e.g. load balancing based on underlying endpoint priority/weight, are planned. | |
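For instance, both provided balancers wrap any publisher the same way. A minimal sketch (not from this README; `endpoints`, the seed choice, and the use of the repo's `fixed` publisher are placeholders for brevity):

```go
p := fixed.NewPublisher(endpoints)                      // any Publisher works here
rr := loadbalancer.NewRoundRobin(p)                     // strict rotation through the set
rnd := loadbalancer.NewRandom(p, time.Now().UnixNano()) // uniform random selection, seeded
```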
12 | ||
13 | ## Rationale | |
14 | ||
15 | TODO | |
16 | ||
17 | ## Usage | |
18 | ||
19 | In your client, construct a publisher for a specific remote service, and pass | |
20 | it to a load balancer. Then, request an endpoint from the load balancer | |
21 | whenever you need to make a request to that remote service. | |
22 | ||
23 | ```go | |
24 | import ( | |
25 | "io" | |
25 | "time" | |
25 | "github.com/go-kit/kit/endpoint" | |
25 | "github.com/go-kit/kit/loadbalancer" | |
26 | "github.com/go-kit/kit/loadbalancer/dnssrv" | |
27 | ) | |
28 | ||
29 | func main() { | |
30 | // Construct a publisher for foosvc, which gets foosvc instances by | |
31 | // polling a specific DNS SRV name. | |
32 | p := dnssrv.NewPublisher("foosvc.internal.domain", 5*time.Second, fooFactory, logger) | |
36 | ||
37 | lb := loadbalancer.NewRoundRobin(p) | |
38 | ||
39 | // Get a new endpoint from the load balancer. | |
40 | endpoint, err := lb.Endpoint() | |
41 | if err != nil { | |
42 | panic(err) | |
43 | } | |
44 | ||
45 | // Use the endpoint to make a request. | |
46 | response, err := endpoint(ctx, request) | |
47 | } | |
48 | ||
49 | func fooFactory(instance string) (endpoint.Endpoint, io.Closer, error) { | |
50 | // Convert an instance (host:port) to an endpoint, via a defined transport binding. | |
51 | } | |
52 | ``` | |
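As a concrete illustration of the factory, here is a hedged sketch using only the standard library rather than a go-kit transport binding; the `/health` path, the string response, and the omitted imports (`io`, `net/http`, `io/ioutil`, `golang.org/x/net/context`) are assumptions for the example:

```go
func fooFactory(instance string) (endpoint.Endpoint, io.Closer, error) {
	url := "http://" + instance + "/health" // hypothetical path served by each instance
	e := func(ctx context.Context, request interface{}) (interface{}, error) {
		resp, err := http.Get(url)
		if err != nil {
			return nil, err
		}
		defer resp.Body.Close()
		body, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			return nil, err
		}
		return string(body), nil
	}
	return e, nil, nil // nothing to clean up per instance, so no Closer
}
```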
53 | ||
54 | It's also possible to wrap a load balancer with a retry strategy, so that it | |
55 | can be used as an endpoint directly. This may make load balancers more | |
56 | convenient to use, at the cost of fine-grained control of failures. | |
57 | ||
58 | ```go | |
59 | func main() { | |
60 | p := dnssrv.NewPublisher("foosvc.internal.domain", 5*time.Second, fooFactory, logger) | |
61 | lb := loadbalancer.NewRoundRobin(p) | |
62 | endpoint := loadbalancer.Retry(3, 5*time.Second, lb) | |
63 | ||
64 | response, err := endpoint(ctx, request) // requests will be automatically load balanced | |
65 | } | |
66 | ``` |
0 | package consul | |
1 | ||
2 | import consul "github.com/hashicorp/consul/api" | |
3 | ||
4 | // Client is a wrapper around the Consul API. | |
5 | type Client interface { | |
6 | Service(service string, tag string, queryOpts *consul.QueryOptions) ([]*consul.ServiceEntry, *consul.QueryMeta, error) | |
7 | } | |
8 | ||
9 | type client struct { | |
10 | consul *consul.Client | |
11 | } | |
12 | ||
13 | // NewClient returns an implementation of the Client interface, expecting a | |
14 | // fully set-up Consul client. | |
15 | func NewClient(c *consul.Client) Client { | |
16 | return &client{ | |
17 | consul: c, | |
18 | } | |
19 | } | |
20 | ||
21 | // Service returns the list of healthy entries for a given service, filtered | |
22 | // by tag. | |
23 | func (c *client) Service( | |
24 | service string, | |
25 | tag string, | |
26 | opts *consul.QueryOptions, | |
27 | ) ([]*consul.ServiceEntry, *consul.QueryMeta, error) { | |
28 | return c.consul.Health().Service(service, tag, true, opts) | |
29 | } |
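A minimal sketch (not part of this file) of wiring the wrapper to a real Consul agent via the upstream API client; the `consulapi` alias and reliance on the default agent address are assumptions:

```go
import consulapi "github.com/hashicorp/consul/api"

// newConsulClient connects to the local Consul agent (127.0.0.1:8500 via
// DefaultConfig) and wraps it in this package's Client interface.
func newConsulClient() (Client, error) {
	apiClient, err := consulapi.NewClient(consulapi.DefaultConfig())
	if err != nil {
		return nil, err
	}
	return NewClient(apiClient), nil
}
```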
0 | package consul | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "strings" | |
5 | ||
6 | consul "github.com/hashicorp/consul/api" | |
7 | ||
8 | "github.com/go-kit/kit/endpoint" | |
9 | "github.com/go-kit/kit/loadbalancer" | |
10 | "github.com/go-kit/kit/log" | |
11 | ) | |
12 | ||
13 | const defaultIndex = 0 | |
14 | ||
15 | // Publisher yields endpoints for a service in Consul. Updates to the service | |
16 | // are watched and will update the Publisher endpoints. | |
17 | type Publisher struct { | |
18 | cache *loadbalancer.EndpointCache | |
19 | client Client | |
20 | logger log.Logger | |
21 | service string | |
22 | tags []string | |
23 | endpointsc chan []endpoint.Endpoint | |
24 | quitc chan struct{} | |
25 | } | |
26 | ||
27 | // NewPublisher returns a Consul publisher which returns Endpoints for the | |
28 | // requested service. It only returns instances for which all of the passed | |
29 | // tags are present. | |
30 | func NewPublisher( | |
31 | client Client, | |
32 | factory loadbalancer.Factory, | |
33 | logger log.Logger, | |
34 | service string, | |
35 | tags ...string, | |
36 | ) (*Publisher, error) { | |
37 | p := &Publisher{ | |
38 | cache: loadbalancer.NewEndpointCache(factory, logger), | |
39 | client: client, | |
40 | logger: logger, | |
41 | service: service, | |
42 | tags: tags, | |
43 | quitc: make(chan struct{}), | |
44 | } | |
45 | ||
46 | instances, index, err := p.getInstances(defaultIndex) | |
47 | if err == nil { | |
48 | logger.Log("service", service, "tags", strings.Join(tags, ", "), "instances", len(instances)) | |
49 | } else { | |
50 | logger.Log("service", service, "tags", strings.Join(tags, ", "), "err", err) | |
51 | } | |
52 | p.cache.Replace(instances) | |
53 | ||
54 | go p.loop(index) | |
55 | ||
56 | return p, nil | |
57 | } | |
58 | ||
59 | // Endpoints implements the Publisher interface. | |
60 | func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) { | |
61 | return p.cache.Endpoints() | |
62 | } | |
63 | ||
64 | // Stop terminates the publisher. | |
65 | func (p *Publisher) Stop() { | |
66 | close(p.quitc) | |
67 | } | |
68 | ||
69 | func (p *Publisher) loop(lastIndex uint64) { | |
70 | var ( | |
71 | errc = make(chan error, 1) | |
72 | resc = make(chan response, 1) | |
73 | ) | |
74 | ||
75 | for { | |
76 | go func() { | |
77 | instances, index, err := p.getInstances(lastIndex) | |
78 | if err != nil { | |
79 | errc <- err | |
80 | return | |
81 | } | |
82 | resc <- response{ | |
83 | index: index, | |
84 | instances: instances, | |
85 | } | |
86 | }() | |
87 | ||
88 | select { | |
89 | case err := <-errc: | |
90 | p.logger.Log("service", p.service, "err", err) | |
91 | case res := <-resc: | |
92 | p.cache.Replace(res.instances) | |
93 | lastIndex = res.index | |
94 | case <-p.quitc: | |
95 | return | |
96 | } | |
97 | } | |
98 | } | |
99 | ||
100 | func (p *Publisher) getInstances(lastIndex uint64) ([]string, uint64, error) { | |
101 | tag := "" | |
102 | ||
103 | if len(p.tags) > 0 { | |
104 | tag = p.tags[0] | |
105 | } | |
106 | ||
107 | entries, meta, err := p.client.Service( | |
108 | p.service, | |
109 | tag, | |
110 | &consul.QueryOptions{ | |
111 | WaitIndex: lastIndex, | |
112 | }, | |
113 | ) | |
114 | if err != nil { | |
115 | return nil, 0, err | |
116 | } | |
117 | ||
118 | // If more than one tag is passed, we need to filter the entries in the | |
119 | // publisher until Consul supports multiple tags [0]. | |
120 | // | |
121 | // [0] https://github.com/hashicorp/consul/issues/294 | |
122 | if len(p.tags) > 1 { | |
123 | entries = filterEntries(entries, p.tags[1:]...) | |
124 | } | |
125 | ||
126 | return makeInstances(entries), meta.LastIndex, nil | |
127 | } | |
128 | ||
129 | // response is used as a container to transport instances as well as the | |
130 | // updated index. | |
131 | type response struct { | |
132 | index uint64 | |
133 | instances []string | |
134 | } | |
135 | ||
136 | func filterEntries(entries []*consul.ServiceEntry, tags ...string) []*consul.ServiceEntry { | |
137 | var es []*consul.ServiceEntry | |
138 | ||
139 | ENTRIES: | |
140 | for _, entry := range entries { | |
141 | ts := make(map[string]struct{}, len(entry.Service.Tags)) | |
142 | ||
143 | for _, tag := range entry.Service.Tags { | |
144 | ts[tag] = struct{}{} | |
145 | } | |
146 | ||
147 | for _, tag := range tags { | |
148 | if _, ok := ts[tag]; !ok { | |
149 | continue ENTRIES | |
150 | } | |
151 | } | |
152 | ||
153 | es = append(es, entry) | |
154 | } | |
155 | ||
156 | return es | |
157 | } | |
158 | ||
159 | func makeInstances(entries []*consul.ServiceEntry) []string { | |
160 | instances := make([]string, len(entries)) | |
161 | ||
162 | for i, entry := range entries { | |
163 | addr := entry.Node.Address | |
164 | ||
165 | if entry.Service.Address != "" { | |
166 | addr = entry.Service.Address | |
167 | } | |
168 | ||
169 | instances[i] = fmt.Sprintf("%s:%d", addr, entry.Service.Port) | |
170 | } | |
171 | ||
172 | return instances | |
173 | } |
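To show how the pieces fit together, a hedged usage sketch mirroring the tests below; `client`, `searchFactory`, `logger`, and the service/tag names are placeholders:

```go
p, err := NewPublisher(client, searchFactory, logger, "search", "api")
if err != nil {
	panic(err)
}
defer p.Stop()

lb := loadbalancer.NewRoundRobin(p) // rotate over healthy "search" instances tagged "api"
e, err := lb.Endpoint()
```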
0 | package consul | |
1 | ||
2 | import ( | |
3 | "io" | |
4 | "testing" | |
5 | ||
6 | consul "github.com/hashicorp/consul/api" | |
7 | "golang.org/x/net/context" | |
8 | ||
9 | "github.com/go-kit/kit/endpoint" | |
10 | "github.com/go-kit/kit/log" | |
11 | ) | |
12 | ||
13 | var consulState = []*consul.ServiceEntry{ | |
14 | { | |
15 | Node: &consul.Node{ | |
16 | Address: "10.0.0.0", | |
17 | Node: "app00.local", | |
18 | }, | |
19 | Service: &consul.AgentService{ | |
20 | ID: "search-api-0", | |
21 | Port: 8000, | |
22 | Service: "search", | |
23 | Tags: []string{ | |
24 | "api", | |
25 | "v1", | |
26 | }, | |
27 | }, | |
28 | }, | |
29 | { | |
30 | Node: &consul.Node{ | |
31 | Address: "10.0.0.1", | |
32 | Node: "app01.local", | |
33 | }, | |
34 | Service: &consul.AgentService{ | |
35 | ID: "search-api-1", | |
36 | Port: 8001, | |
37 | Service: "search", | |
38 | Tags: []string{ | |
39 | "api", | |
40 | "v2", | |
41 | }, | |
42 | }, | |
43 | }, | |
44 | { | |
45 | Node: &consul.Node{ | |
46 | Address: "10.0.0.1", | |
47 | Node: "app01.local", | |
48 | }, | |
49 | Service: &consul.AgentService{ | |
50 | Address: "10.0.0.10", | |
51 | ID: "search-db-0", | |
52 | Port: 9000, | |
53 | Service: "search", | |
54 | Tags: []string{ | |
55 | "db", | |
56 | }, | |
57 | }, | |
58 | }, | |
59 | } | |
60 | ||
61 | func TestPublisher(t *testing.T) { | |
62 | var ( | |
63 | logger = log.NewNopLogger() | |
64 | client = newTestClient(consulState) | |
65 | ) | |
66 | ||
67 | p, err := NewPublisher(client, testFactory, logger, "search", "api") | |
68 | if err != nil { | |
69 | t.Fatalf("publisher setup failed: %s", err) | |
70 | } | |
71 | defer p.Stop() | |
72 | ||
73 | eps, err := p.Endpoints() | |
74 | if err != nil { | |
75 | t.Fatalf("endpoints failed: %s", err) | |
76 | } | |
77 | ||
78 | if have, want := len(eps), 2; have != want { | |
79 | t.Errorf("have %v, want %v", have, want) | |
80 | } | |
81 | } | |
82 | ||
83 | func TestPublisherNoService(t *testing.T) { | |
84 | var ( | |
85 | logger = log.NewNopLogger() | |
86 | client = newTestClient(consulState) | |
87 | ) | |
88 | ||
89 | p, err := NewPublisher(client, testFactory, logger, "feed") | |
90 | if err != nil { | |
91 | t.Fatalf("publisher setup failed: %s", err) | |
92 | } | |
93 | defer p.Stop() | |
94 | ||
95 | eps, err := p.Endpoints() | |
96 | if err != nil { | |
97 | t.Fatalf("endpoints failed: %s", err) | |
98 | } | |
99 | ||
100 | if have, want := len(eps), 0; have != want { | |
101 | t.Fatalf("have %v, want %v", have, want) | |
102 | } | |
103 | } | |
104 | ||
105 | func TestPublisherWithTags(t *testing.T) { | |
106 | var ( | |
107 | logger = log.NewNopLogger() | |
108 | client = newTestClient(consulState) | |
109 | ) | |
110 | ||
111 | p, err := NewPublisher(client, testFactory, logger, "search", "api", "v2") | |
112 | if err != nil { | |
113 | t.Fatalf("publisher setup failed: %s", err) | |
114 | } | |
115 | defer p.Stop() | |
116 | ||
117 | eps, err := p.Endpoints() | |
118 | if err != nil { | |
119 | t.Fatalf("endpoints failed: %s", err) | |
120 | } | |
121 | ||
122 | if have, want := len(eps), 1; have != want { | |
123 | t.Fatalf("have %v, want %v", have, want) | |
124 | } | |
125 | } | |
126 | ||
127 | func TestPublisherAddressOverride(t *testing.T) { | |
128 | var ( | |
129 | ctx = context.Background() | |
130 | logger = log.NewNopLogger() | |
131 | client = newTestClient(consulState) | |
132 | ) | |
133 | ||
134 | p, err := NewPublisher(client, testFactory, logger, "search", "db") | |
135 | if err != nil { | |
136 | t.Fatalf("publisher setup failed: %s", err) | |
137 | } | |
138 | defer p.Stop() | |
139 | ||
140 | eps, err := p.Endpoints() | |
141 | if err != nil { | |
142 | t.Fatalf("endpoints failed: %s", err) | |
143 | } | |
144 | ||
145 | if have, want := len(eps), 1; have != want { | |
146 | t.Fatalf("have %v, want %v", have, want) | |
147 | } | |
148 | ||
149 | ins, err := eps[0](ctx, struct{}{}) | |
150 | if err != nil { | |
151 | t.Fatal(err) | |
152 | } | |
153 | ||
154 | if have, want := ins.(string), "10.0.0.10:9000"; have != want { | |
155 | t.Errorf("have %#v, want %#v", have, want) | |
156 | } | |
157 | } | |
158 | ||
159 | type testClient struct { | |
160 | entries []*consul.ServiceEntry | |
161 | } | |
162 | ||
163 | func newTestClient(entries []*consul.ServiceEntry) Client { | |
164 | if entries == nil { | |
165 | entries = []*consul.ServiceEntry{} | |
166 | } | |
167 | ||
168 | return &testClient{ | |
169 | entries: entries, | |
170 | } | |
171 | } | |
172 | ||
173 | func (c *testClient) Service( | |
174 | service string, | |
175 | tag string, | |
176 | opts *consul.QueryOptions, | |
177 | ) ([]*consul.ServiceEntry, *consul.QueryMeta, error) { | |
178 | es := []*consul.ServiceEntry{} | |
179 | ||
180 | for _, e := range c.entries { | |
181 | if e.Service.Service != service { | |
182 | continue | |
183 | } | |
184 | if tag != "" { | |
185 | tagMap := map[string]struct{}{} | |
186 | ||
187 | for _, t := range e.Service.Tags { | |
188 | tagMap[t] = struct{}{} | |
189 | } | |
190 | ||
191 | if _, ok := tagMap[tag]; !ok { | |
192 | continue | |
193 | } | |
194 | } | |
195 | ||
196 | es = append(es, e) | |
197 | } | |
198 | ||
199 | return es, &consul.QueryMeta{}, nil | |
200 | } | |
201 | ||
202 | func testFactory(ins string) (endpoint.Endpoint, io.Closer, error) { | |
203 | return func(context.Context, interface{}) (interface{}, error) { | |
204 | return ins, nil | |
205 | }, nil, nil | |
206 | } |
0 | package dnssrv | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "net" | |
5 | "time" | |
6 | ||
7 | "github.com/go-kit/kit/endpoint" | |
8 | "github.com/go-kit/kit/loadbalancer" | |
9 | "github.com/go-kit/kit/log" | |
10 | ) | |
11 | ||
12 | // Publisher yields endpoints taken from the named DNS SRV record. The name is | |
13 | // resolved on a fixed schedule. Priorities and weights are ignored. | |
14 | type Publisher struct { | |
15 | name string | |
16 | cache *loadbalancer.EndpointCache | |
17 | logger log.Logger | |
18 | quit chan struct{} | |
19 | } | |
20 | ||
21 | // NewPublisher returns a DNS SRV publisher. The name is resolved | |
22 | // synchronously as part of construction; if that resolution fails, the | |
23 | // error is logged and the publisher starts with an empty set of endpoints. | |
24 | // The factory is used to convert a host:port to a usable endpoint. The | |
25 | // logger is used to report DNS and factory errors. | |
26 | func NewPublisher( | |
27 | name string, | |
28 | ttl time.Duration, | |
29 | factory loadbalancer.Factory, | |
30 | logger log.Logger, | |
31 | ) *Publisher { | |
32 | return NewPublisherDetailed(name, time.NewTicker(ttl), net.LookupSRV, factory, logger) | |
33 | } | |
34 | ||
35 | // NewPublisherDetailed is the same as NewPublisher, but allows users to provide | |
36 | // an explicit lookup refresh ticker instead of a TTL, and specify the function | |
37 | // used to perform lookups instead of using net.LookupSRV. | |
38 | func NewPublisherDetailed( | |
39 | name string, | |
40 | refreshTicker *time.Ticker, | |
41 | lookupSRV func(service, proto, name string) (cname string, addrs []*net.SRV, err error), | |
42 | factory loadbalancer.Factory, | |
43 | logger log.Logger, | |
44 | ) *Publisher { | |
45 | p := &Publisher{ | |
46 | name: name, | |
47 | cache: loadbalancer.NewEndpointCache(factory, logger), | |
48 | logger: logger, | |
49 | quit: make(chan struct{}), | |
50 | } | |
51 | ||
52 | instances, err := p.resolve(lookupSRV) | |
53 | if err == nil { | |
54 | logger.Log("name", name, "instances", len(instances)) | |
55 | } else { | |
56 | logger.Log("name", name, "err", err) | |
57 | } | |
58 | p.cache.Replace(instances) | |
59 | ||
60 | go p.loop(refreshTicker, lookupSRV) | |
61 | return p | |
62 | } | |
63 | ||
64 | // Stop terminates the publisher. | |
65 | func (p *Publisher) Stop() { | |
66 | close(p.quit) | |
67 | } | |
68 | ||
69 | func (p *Publisher) loop( | |
70 | refreshTicker *time.Ticker, | |
71 | lookupSRV func(service, proto, name string) (cname string, addrs []*net.SRV, err error), | |
72 | ) { | |
73 | defer refreshTicker.Stop() | |
74 | for { | |
75 | select { | |
76 | case <-refreshTicker.C: | |
77 | instances, err := p.resolve(lookupSRV) | |
78 | if err != nil { | |
79 | p.logger.Log("name", p.name, "err", err) | |
80 | continue // don't replace potentially-good with bad | |
81 | } | |
82 | p.cache.Replace(instances) | |
83 | ||
84 | case <-p.quit: | |
85 | return | |
86 | } | |
87 | } | |
88 | } | |
89 | ||
90 | // Endpoints implements the Publisher interface. | |
91 | func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) { | |
92 | return p.cache.Endpoints() | |
93 | } | |
94 | ||
95 | func (p *Publisher) resolve(lookupSRV func(service, proto, name string) (cname string, addrs []*net.SRV, err error)) ([]string, error) { | |
96 | _, addrs, err := lookupSRV("", "", p.name) | |
97 | if err != nil { | |
98 | return []string{}, err | |
99 | } | |
100 | instances := make([]string, len(addrs)) | |
101 | for i, addr := range addrs { | |
102 | instances[i] = net.JoinHostPort(addr.Target, fmt.Sprint(addr.Port)) | |
103 | } | |
104 | return instances, nil | |
105 | } |
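A short sketch of the NewPublisherDetailed variant described above, injecting a custom SRV lookup (useful in tests or with a non-default resolver); the stub answer, ticker interval, `fooFactory`, and `logger` are assumptions:

```go
lookup := func(service, proto, name string) (string, []*net.SRV, error) {
	// Fixed answer for illustration; production code would call net.LookupSRV.
	return "", []*net.SRV{{Target: "foo1.internal.domain", Port: 8080}}, nil
}
p := NewPublisherDetailed("foosvc.internal.domain", time.NewTicker(30*time.Second), lookup, fooFactory, logger)
defer p.Stop()
```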
0 | package dnssrv | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "io" | |
5 | "net" | |
6 | "sync/atomic" | |
7 | "testing" | |
8 | "time" | |
9 | ||
10 | "golang.org/x/net/context" | |
11 | ||
12 | "github.com/go-kit/kit/endpoint" | |
13 | "github.com/go-kit/kit/log" | |
14 | ) | |
15 | ||
16 | func TestPublisher(t *testing.T) { | |
17 | var ( | |
18 | name = "foo" | |
19 | ttl = time.Second | |
20 | e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
21 | factory = func(string) (endpoint.Endpoint, io.Closer, error) { return e, nil, nil } | |
22 | logger = log.NewNopLogger() | |
23 | ) | |
24 | ||
25 | p := NewPublisher(name, ttl, factory, logger) | |
26 | defer p.Stop() | |
27 | ||
28 | if _, err := p.Endpoints(); err != nil { | |
29 | t.Fatal(err) | |
30 | } | |
31 | } | |
32 | ||
33 | func TestBadLookup(t *testing.T) { | |
34 | var ( | |
35 | name = "some-name" | |
36 | ticker = time.NewTicker(time.Second) | |
37 | lookups = uint32(0) | |
38 | lookupSRV = func(string, string, string) (string, []*net.SRV, error) { | |
39 | atomic.AddUint32(&lookups, 1) | |
40 | return "", nil, errors.New("kaboom") | |
41 | } | |
42 | e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
43 | factory = func(string) (endpoint.Endpoint, io.Closer, error) { return e, nil, nil } | |
44 | logger = log.NewNopLogger() | |
45 | ) | |
46 | ||
47 | p := NewPublisherDetailed(name, ticker, lookupSRV, factory, logger) | |
48 | defer p.Stop() | |
49 | ||
50 | endpoints, err := p.Endpoints() | |
51 | if err != nil { | |
52 | t.Error(err) | |
53 | } | |
54 | if want, have := 0, len(endpoints); want != have { | |
55 | t.Errorf("want %d, have %d", want, have) | |
56 | } | |
57 | if want, have := uint32(1), atomic.LoadUint32(&lookups); want != have { | |
58 | t.Errorf("want %d, have %d", want, have) | |
59 | } | |
60 | } | |
61 | ||
62 | func TestBadFactory(t *testing.T) { | |
63 | var ( | |
64 | name = "some-name" | |
65 | ticker = time.NewTicker(time.Second) | |
66 | addr = &net.SRV{Target: "foo", Port: 1234} | |
67 | addrs = []*net.SRV{addr} | |
68 | lookupSRV = func(a, b, c string) (string, []*net.SRV, error) { return "", addrs, nil } | |
69 | creates = uint32(0) | |
70 | factory = func(s string) (endpoint.Endpoint, io.Closer, error) { | |
71 | atomic.AddUint32(&creates, 1) | |
72 | return nil, nil, errors.New("kaboom") | |
73 | } | |
74 | logger = log.NewNopLogger() | |
75 | ) | |
76 | ||
77 | p := NewPublisherDetailed(name, ticker, lookupSRV, factory, logger) | |
78 | defer p.Stop() | |
79 | ||
80 | endpoints, err := p.Endpoints() | |
81 | if err != nil { | |
82 | t.Error(err) | |
83 | } | |
84 | if want, have := 0, len(endpoints); want != have { | |
85 | t.Errorf("want %q, have %q", want, have) | |
86 | } | |
87 | if want, have := uint32(1), atomic.LoadUint32(&creates); want != have { | |
88 | t.Errorf("want %d, have %d", want, have) | |
89 | } | |
90 | } | |
91 | ||
92 | func TestRefreshWithChange(t *testing.T) { | |
93 | t.Skip("TODO") | |
94 | } | |
95 | ||
96 | func TestRefreshNoChange(t *testing.T) { | |
97 | var ( | |
98 | addr = &net.SRV{Target: "my-target", Port: 5678} | |
99 | addrs = []*net.SRV{addr} | |
100 | name = "my-name" | |
101 | ticker = time.NewTicker(time.Second) | |
102 | lookups = uint32(0) | |
103 | lookupSRV = func(string, string, string) (string, []*net.SRV, error) { | |
104 | atomic.AddUint32(&lookups, 1) | |
105 | return "", addrs, nil | |
106 | } | |
107 | e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
108 | factory = func(string) (endpoint.Endpoint, io.Closer, error) { return e, nil, nil } | |
109 | logger = log.NewNopLogger() | |
110 | ) | |
111 | ||
112 | ticker.Stop() | |
113 | tickc := make(chan time.Time) | |
114 | ticker.C = tickc | |
115 | ||
116 | p := NewPublisherDetailed(name, ticker, lookupSRV, factory, logger) | |
117 | defer p.Stop() | |
118 | ||
119 | if want, have := uint32(1), atomic.LoadUint32(&lookups); want != have { | |
120 | t.Errorf("want %d, have %d", want, have) | |
121 | } | |
122 | ||
123 | tickc <- time.Now() | |
124 | ||
125 | if want, have := uint32(2), atomic.LoadUint32(&lookups); want != have { | |
126 | t.Errorf("want %d, have %d", want, have) | |
127 | } | |
128 | } | |
129 | ||
130 | func TestRefreshResolveError(t *testing.T) { | |
131 | t.Skip("TODO") | |
132 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "io" | |
4 | "sort" | |
5 | "sync" | |
6 | "sync/atomic" | |
7 | ||
8 | "github.com/go-kit/kit/endpoint" | |
9 | "github.com/go-kit/kit/log" | |
10 | ) | |
11 | ||
12 | // EndpointCache caches endpoints that need to be deallocated when they're no | |
13 | // longer useful. Clients update the cache by providing a current set of | |
14 | // instance strings. The cache converts each instance string to an endpoint | |
15 | // and a closer via the factory function. | |
16 | // | |
17 | // Instance strings are assumed to be unique and are used as keys. Endpoints | |
18 | // that were in the previous set of instances and are not in the current set | |
19 | // are considered invalid and closed. | |
20 | // | |
21 | // EndpointCache is designed to be used in your publisher implementation. | |
22 | type EndpointCache struct { | |
23 | mtx sync.Mutex | |
24 | f Factory | |
25 | m map[string]endpointCloser | |
26 | cache atomic.Value //[]endpoint.Endpoint | |
27 | logger log.Logger | |
28 | } | |
29 | ||
30 | // NewEndpointCache produces a new EndpointCache, ready for use. Instance | |
31 | // strings will be converted to endpoints via the provided factory function. | |
32 | // The logger is used to log errors. | |
33 | func NewEndpointCache(f Factory, logger log.Logger) *EndpointCache { | |
34 | endpointCache := &EndpointCache{ | |
35 | f: f, | |
36 | m: map[string]endpointCloser{}, | |
37 | logger: log.NewContext(logger).With("component", "Endpoint Cache"), | |
38 | } | |
39 | ||
40 | endpointCache.cache.Store(make([]endpoint.Endpoint, 0)) | |
41 | ||
42 | return endpointCache | |
43 | } | |
44 | ||
45 | type endpointCloser struct { | |
46 | endpoint.Endpoint | |
47 | io.Closer | |
48 | } | |
49 | ||
50 | // Replace replaces the current set of endpoints with endpoints manufactured | |
51 | // by the passed instances. If the same instance exists in both the existing | |
52 | // and new sets, it's left untouched. | |
53 | func (t *EndpointCache) Replace(instances []string) { | |
54 | t.mtx.Lock() | |
55 | defer t.mtx.Unlock() | |
56 | ||
57 | // Produce the current set of endpoints. | |
58 | oldMap := t.m | |
59 | t.m = make(map[string]endpointCloser, len(instances)) | |
60 | for _, instance := range instances { | |
61 | // If it already exists, just copy it over. | |
62 | if ec, ok := oldMap[instance]; ok { | |
63 | t.m[instance] = ec | |
64 | delete(oldMap, instance) | |
65 | continue | |
66 | } | |
67 | ||
68 | // If it doesn't exist, create it. | |
69 | endpoint, closer, err := t.f(instance) | |
70 | if err != nil { | |
71 | t.logger.Log("instance", instance, "err", err) | |
72 | continue | |
73 | } | |
74 | t.m[instance] = endpointCloser{endpoint, closer} | |
75 | } | |
76 | ||
77 | t.refreshCache() | |
78 | ||
79 | // Close any leftover endpoints. | |
80 | for _, ec := range oldMap { | |
81 | if ec.Closer != nil { | |
82 | ec.Closer.Close() | |
83 | } | |
84 | } | |
85 | } | |
86 | ||
87 | func (t *EndpointCache) refreshCache() { | |
88 | var ( | |
89 | length = len(t.m) | |
90 | instances = make([]string, 0, length) | |
91 | newCache = make([]endpoint.Endpoint, 0, length) | |
92 | ) | |
93 | ||
94 | for instance := range t.m { | |
95 | instances = append(instances, instance) | |
96 | } | |
97 | // Sort the instances to ensure Endpoints are returned in a stable order when the set is unmodified. | |
98 | sort.Strings(instances) | |
99 | ||
100 | for _, instance := range instances { | |
101 | newCache = append(newCache, t.m[instance].Endpoint) | |
102 | } | |
103 | ||
104 | t.cache.Store(newCache) | |
105 | } | |
106 | ||
107 | // Endpoints returns the current set of endpoints, sorted by instance string. | |
108 | // Satisfies the Publisher interface. | |
109 | func (t *EndpointCache) Endpoints() ([]endpoint.Endpoint, error) { | |
110 | return t.cache.Load().([]endpoint.Endpoint), nil | |
111 | } |
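As the doc comment suggests, EndpointCache is meant to back a publisher implementation. A hypothetical publisher driven by a channel of instance sets (the `instancesc` channel and its producer are assumptions, not part of this package) might look like:

```go
type chanPublisher struct {
	cache *EndpointCache
	quit  chan struct{}
}

func newChanPublisher(factory Factory, logger log.Logger, instancesc <-chan []string) *chanPublisher {
	p := &chanPublisher{
		cache: NewEndpointCache(factory, logger),
		quit:  make(chan struct{}),
	}
	go func() {
		for {
			select {
			case instances := <-instancesc:
				p.cache.Replace(instances) // closes removed endpoints, builds new ones
			case <-p.quit:
				return
			}
		}
	}()
	return p
}

func (p *chanPublisher) Endpoints() ([]endpoint.Endpoint, error) { return p.cache.Endpoints() }
func (p *chanPublisher) Stop()                                   { close(p.quit) }
```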
0 | package loadbalancer_test | |
1 | ||
2 | import ( | |
3 | "io" | |
4 | "testing" | |
5 | "time" | |
6 | ||
7 | "golang.org/x/net/context" | |
8 | ||
9 | "github.com/go-kit/kit/endpoint" | |
10 | "github.com/go-kit/kit/loadbalancer" | |
11 | "github.com/go-kit/kit/log" | |
12 | ) | |
13 | ||
14 | func TestEndpointCache(t *testing.T) { | |
15 | var ( | |
16 | e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
17 | ca = make(closer) | |
18 | cb = make(closer) | |
19 | c = map[string]io.Closer{"a": ca, "b": cb} | |
20 | f = func(s string) (endpoint.Endpoint, io.Closer, error) { return e, c[s], nil } | |
21 | ec = loadbalancer.NewEndpointCache(f, log.NewNopLogger()) | |
22 | ) | |
23 | ||
24 | // Populate | |
25 | ec.Replace([]string{"a", "b"}) | |
26 | select { | |
27 | case <-ca: | |
28 | t.Errorf("endpoint a closed, not good") | |
29 | case <-cb: | |
30 | t.Errorf("endpoint b closed, not good") | |
31 | case <-time.After(time.Millisecond): | |
32 | t.Logf("no closures yet, good") | |
33 | } | |
34 | ||
35 | // Duplicate, should be no-op | |
36 | ec.Replace([]string{"a", "b"}) | |
37 | select { | |
38 | case <-ca: | |
39 | t.Errorf("endpoint a closed, not good") | |
40 | case <-cb: | |
41 | t.Errorf("endpoint b closed, not good") | |
42 | case <-time.After(time.Millisecond): | |
43 | t.Logf("no closures yet, good") | |
44 | } | |
45 | ||
46 | // Delete b | |
47 | go ec.Replace([]string{"a"}) | |
48 | select { | |
49 | case <-ca: | |
50 | t.Errorf("endpoint a closed, not good") | |
51 | case <-cb: | |
52 | t.Logf("endpoint b closed, good") | |
53 | case <-time.After(time.Millisecond): | |
54 | t.Errorf("didn't close the deleted instance in time") | |
55 | } | |
56 | ||
57 | // Delete a | |
58 | go ec.Replace([]string{""}) | |
59 | select { | |
60 | // case <-cb: will succeed, as it's closed | |
61 | case <-ca: | |
62 | t.Logf("endpoint a closed, good") | |
63 | case <-time.After(time.Millisecond): | |
64 | t.Errorf("didn't close the deleted instance in time") | |
65 | } | |
66 | } | |
67 | ||
68 | type closer chan struct{} | |
69 | ||
70 | func (c closer) Close() error { close(c); return nil } | |
71 | ||
72 | func BenchmarkEndpoints(b *testing.B) { | |
73 | var ( | |
74 | e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
75 | ca = make(closer) | |
76 | cb = make(closer) | |
77 | c = map[string]io.Closer{"a": ca, "b": cb} | |
78 | f = func(s string) (endpoint.Endpoint, io.Closer, error) { return e, c[s], nil } | |
79 | ec = loadbalancer.NewEndpointCache(f, log.NewNopLogger()) | |
80 | ) | |
81 | ||
82 | b.ReportAllocs() | |
83 | ||
84 | ec.Replace([]string{"a", "b"}) | |
85 | ||
86 | b.RunParallel(func(pb *testing.PB) { | |
87 | for pb.Next() { | |
88 | ec.Endpoints() | |
89 | } | |
90 | }) | |
91 | }⏎ |
0 | package etcd | |
1 | ||
2 | import ( | |
3 | "crypto/tls" | |
4 | "crypto/x509" | |
5 | "io/ioutil" | |
6 | "net" | |
7 | "net/http" | |
8 | "time" | |
9 | ||
10 | etcd "github.com/coreos/etcd/client" | |
11 | "golang.org/x/net/context" | |
12 | ) | |
13 | ||
14 | // Client is a wrapper around the etcd client. | |
15 | type Client interface { | |
16 | // GetEntries will query the given prefix in etcd and return a set of entries. | |
17 | GetEntries(prefix string) ([]string, error) | |
18 | // WatchPrefix starts watching every change for the given prefix in etcd. When | |
19 | // a change is detected it will populate responseChan with an *etcd.Response. | |
20 | WatchPrefix(prefix string, responseChan chan *etcd.Response) | |
21 | } | |
22 | ||
23 | type client struct { | |
24 | keysAPI etcd.KeysAPI | |
25 | ctx context.Context | |
26 | } | |
27 | ||
28 | // ClientOptions defines options for the etcd client. | |
28 | type ClientOptions struct { | |
29 | Cert string | |
30 | Key string | |
31 | CaCert string | |
32 | DialTimeout time.Duration | |
33 | DialKeepAlive time.Duration | |
34 | HeaderTimeoutPerRequest time.Duration | |
35 | } | |
36 | ||
37 | // NewClient returns a Client with a connection to the named machines. It | |
38 | // will return an error if a connection to the cluster cannot be made. Each | |
39 | // entry in machines needs to be a full URL with scheme, e.g. | |
40 | // "http://localhost:2379" will work, but "localhost:2379" will not. | |
41 | func NewClient(ctx context.Context, machines []string, options *ClientOptions) (Client, error) { | |
42 | var ( | |
43 | c etcd.KeysAPI | |
44 | err error | |
45 | caCertCt []byte | |
46 | tlsCert tls.Certificate | |
47 | ) | |
48 | if options == nil { | |
49 | options = &ClientOptions{} | |
50 | } | |
51 | ||
52 | if options.Cert != "" && options.Key != "" { | |
53 | tlsCert, err = tls.LoadX509KeyPair(options.Cert, options.Key) | |
54 | if err != nil { | |
55 | return nil, err | |
56 | } | |
57 | ||
58 | caCertCt, err = ioutil.ReadFile(options.CaCert) | |
59 | if err != nil { | |
60 | return nil, err | |
61 | } | |
62 | caCertPool := x509.NewCertPool() | |
63 | caCertPool.AppendCertsFromPEM(caCertCt) | |
64 | ||
65 | tlsConfig := &tls.Config{ | |
66 | Certificates: []tls.Certificate{tlsCert}, | |
67 | RootCAs: caCertPool, | |
68 | } | |
69 | ||
70 | transport := &http.Transport{ | |
71 | TLSClientConfig: tlsConfig, | |
72 | Dial: func(network, addr string) (net.Conn, error) { | |
73 | dial := &net.Dialer{ | |
74 | Timeout: options.DialTimeout, | |
75 | KeepAlive: options.DialKeepAlive, | |
76 | } | |
77 | return dial.Dial(network, addr) | |
78 | }, | |
79 | } | |
80 | ||
81 | cfg := etcd.Config{ | |
82 | Endpoints: machines, | |
83 | Transport: transport, | |
84 | HeaderTimeoutPerRequest: options.HeaderTimeoutPerRequest, | |
85 | } | |
86 | ce, err := etcd.New(cfg) | |
87 | if err != nil { | |
88 | return nil, err | |
89 | } | |
90 | c = etcd.NewKeysAPI(ce) | |
91 | } else { | |
92 | cfg := etcd.Config{ | |
93 | Endpoints: machines, | |
94 | Transport: etcd.DefaultTransport, | |
95 | HeaderTimeoutPerRequest: options.HeaderTimeoutPerRequest, | |
96 | } | |
97 | ce, err := etcd.New(cfg) | |
98 | if err != nil { | |
99 | return nil, err | |
100 | } | |
101 | c = etcd.NewKeysAPI(ce) | |
102 | } | |
103 | return &client{c, ctx}, nil | |
104 | } | |
105 | ||
106 | // GetEntries implements the etcd Client interface. | |
107 | func (c *client) GetEntries(key string) ([]string, error) { | |
108 | resp, err := c.keysAPI.Get(c.ctx, key, &etcd.GetOptions{Recursive: true}) | |
109 | if err != nil { | |
110 | return nil, err | |
111 | } | |
112 | ||
113 | entries := make([]string, len(resp.Node.Nodes)) | |
114 | for i, node := range resp.Node.Nodes { | |
115 | entries[i] = node.Value | |
116 | } | |
117 | return entries, nil | |
118 | } | |
119 | ||
120 | // WatchPrefix implements the etcd Client interface. | |
121 | func (c *client) WatchPrefix(prefix string, responseChan chan *etcd.Response) { | |
122 | watch := c.keysAPI.Watcher(prefix, &etcd.WatcherOptions{AfterIndex: 0, Recursive: true}) | |
123 | for { | |
124 | res, err := watch.Next(c.ctx) | |
125 | if err != nil { | |
126 | return | |
127 | } | |
128 | responseChan <- res | |
129 | } | |
130 | } |
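A brief sketch of constructing the client against a single local etcd member, per the NewClient comment above; the `/services/foosvc` prefix is a placeholder:

```go
client, err := NewClient(context.Background(), []string{"http://localhost:2379"}, nil)
if err != nil {
	panic(err)
}
entries, err := client.GetEntries("/services/foosvc") // each entry is a node value, e.g. a host:port
```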
0 | package etcd | |
1 | ||
2 | import ( | |
3 | etcd "github.com/coreos/etcd/client" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | "github.com/go-kit/kit/loadbalancer" | |
7 | "github.com/go-kit/kit/log" | |
8 | ) | |
9 | ||
10 | // Publisher yields endpoints stored in a certain etcd keyspace. Any kind of | |
11 | // change in that keyspace is watched and will update the Publisher's endpoints. | |
12 | type Publisher struct { | |
13 | client Client | |
14 | prefix string | |
15 | cache *loadbalancer.EndpointCache | |
16 | logger log.Logger | |
17 | quit chan struct{} | |
18 | } | |
19 | ||
20 | // NewPublisher returns an etcd publisher. The publisher will watch the given | |
21 | // prefix for changes and update its endpoints accordingly. | |
22 | func NewPublisher(c Client, prefix string, f loadbalancer.Factory, logger log.Logger) (*Publisher, error) { | |
23 | p := &Publisher{ | |
24 | client: c, | |
25 | prefix: prefix, | |
26 | cache: loadbalancer.NewEndpointCache(f, logger), | |
27 | logger: logger, | |
28 | quit: make(chan struct{}), | |
29 | } | |
30 | ||
31 | instances, err := p.client.GetEntries(p.prefix) | |
32 | if err == nil { | |
33 | logger.Log("prefix", p.prefix, "instances", len(instances)) | |
34 | } else { | |
35 | logger.Log("prefix", p.prefix, "err", err) | |
36 | } | |
37 | p.cache.Replace(instances) | |
38 | ||
39 | go p.loop() | |
40 | return p, nil | |
41 | } | |
42 | ||
43 | func (p *Publisher) loop() { | |
44 | responseChan := make(chan *etcd.Response) | |
45 | go p.client.WatchPrefix(p.prefix, responseChan) | |
46 | for { | |
47 | select { | |
48 | case <-responseChan: | |
49 | instances, err := p.client.GetEntries(p.prefix) | |
50 | if err != nil { | |
51 | p.logger.Log("msg", "failed to retrieve entries", "err", err) | |
52 | continue | |
53 | } | |
54 | p.cache.Replace(instances) | |
55 | ||
56 | case <-p.quit: | |
57 | return | |
58 | } | |
59 | } | |
60 | } | |
61 | ||
62 | // Endpoints implements the Publisher interface. | |
63 | func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) { | |
64 | return p.cache.Endpoints() | |
65 | } | |
66 | ||
67 | // Stop terminates the Publisher. | |
68 | func (p *Publisher) Stop() { | |
69 | close(p.quit) | |
70 | } |
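Continuing the client sketch above: feed the watched prefix into a publisher and balance over it; `fooFactory`, `logger`, and the prefix are placeholders:

```go
p, err := NewPublisher(client, "/services/foosvc", fooFactory, logger)
if err != nil {
	panic(err)
}
defer p.Stop()

lb := loadbalancer.NewRoundRobin(p)
e, err := lb.Endpoint()
```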
0 | package etcd_test | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "io" | |
5 | "testing" | |
6 | ||
7 | stdetcd "github.com/coreos/etcd/client" | |
8 | "golang.org/x/net/context" | |
9 | ||
10 | "github.com/go-kit/kit/endpoint" | |
11 | kitetcd "github.com/go-kit/kit/loadbalancer/etcd" | |
12 | "github.com/go-kit/kit/log" | |
13 | ) | |
14 | ||
15 | var ( | |
16 | node = &stdetcd.Node{ | |
17 | Key: "/foo", | |
18 | Nodes: []*stdetcd.Node{ | |
19 | {Key: "/foo/1", Value: "1:1"}, | |
20 | {Key: "/foo/2", Value: "1:2"}, | |
21 | }, | |
22 | } | |
23 | fakeResponse = &stdetcd.Response{ | |
24 | Node: node, | |
25 | } | |
26 | ) | |
27 | ||
28 | func TestPublisher(t *testing.T) { | |
29 | var ( | |
30 | logger = log.NewNopLogger() | |
31 | e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
32 | ) | |
33 | ||
34 | factory := func(string) (endpoint.Endpoint, io.Closer, error) { | |
35 | return e, nil, nil | |
36 | } | |
37 | ||
38 | client := &fakeClient{ | |
39 | responses: map[string]*stdetcd.Response{"/foo": fakeResponse}, | |
40 | } | |
41 | ||
42 | p, err := kitetcd.NewPublisher(client, "/foo", factory, logger) | |
43 | if err != nil { | |
44 | t.Fatalf("failed to create new publisher: %v", err) | |
45 | } | |
46 | defer p.Stop() | |
47 | ||
48 | if _, err := p.Endpoints(); err != nil { | |
49 | t.Fatal(err) | |
50 | } | |
51 | } | |
52 | ||
53 | func TestBadFactory(t *testing.T) { | |
54 | logger := log.NewNopLogger() | |
55 | ||
56 | factory := func(string) (endpoint.Endpoint, io.Closer, error) { | |
57 | return nil, nil, errors.New("kaboom") | |
58 | } | |
59 | ||
60 | client := &fakeClient{ | |
61 | responses: map[string]*stdetcd.Response{"/foo": fakeResponse}, | |
62 | } | |
63 | ||
64 | p, err := kitetcd.NewPublisher(client, "/foo", factory, logger) | |
65 | if err != nil { | |
66 | t.Fatalf("failed to create new publisher: %v", err) | |
67 | } | |
68 | defer p.Stop() | |
69 | ||
70 | endpoints, err := p.Endpoints() | |
71 | if err != nil { | |
72 | t.Fatal(err) | |
73 | } | |
74 | ||
75 | if want, have := 0, len(endpoints); want != have { | |
76 | t.Errorf("want %q, have %q", want, have) | |
77 | } | |
78 | } | |
79 | ||
80 | type fakeClient struct { | |
81 | responses map[string]*stdetcd.Response | |
82 | } | |
83 | ||
84 | func (c *fakeClient) GetEntries(prefix string) ([]string, error) { | |
85 | response, ok := c.responses[prefix] | |
86 | if !ok { | |
87 | return nil, errors.New("key does not exist") | |
88 | } | |
89 | ||
90 | entries := make([]string, len(response.Node.Nodes)) | |
91 | for i, node := range response.Node.Nodes { | |
92 | entries[i] = node.Value | |
93 | } | |
94 | return entries, nil | |
95 | } | |
96 | ||
97 | func (c *fakeClient) WatchPrefix(prefix string, responseChan chan *stdetcd.Response) {} |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "io" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
8 | // Factory is a function that converts an instance string, e.g. a host:port, | |
9 | // to a usable endpoint. Factories are used by load balancers to convert | |
10 | // instances returned by Publishers (typically host:port strings) into | |
11 | // endpoints. Users are expected to provide their own factory functions that | |
12 | // assume specific transports, or can deduce transports by parsing the | |
13 | // instance string. | |
14 | type Factory func(instance string) (endpoint.Endpoint, io.Closer, error) |
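The io.Closer return value lets a publisher (via EndpointCache) release per-instance resources when an instance disappears. A hypothetical factory that keeps a dedicated HTTP transport per instance, and closes its idle connections on removal, might look like this; the request target and the `closerFunc` helper are assumptions:

```go
// closerFunc adapts a plain func() to io.Closer.
type closerFunc func()

func (f closerFunc) Close() error { f(); return nil }

func httpFactory(instance string) (endpoint.Endpoint, io.Closer, error) {
	tr := &http.Transport{}               // dedicated transport for this instance
	client := &http.Client{Transport: tr} // its idle connections are released on Close
	url := "http://" + instance + "/"     // hypothetical request target
	e := func(ctx context.Context, request interface{}) (interface{}, error) {
		resp, err := client.Get(url)
		if err != nil {
			return nil, err
		}
		defer resp.Body.Close()
		return resp.StatusCode, nil
	}
	return e, closerFunc(tr.CloseIdleConnections), nil
}
```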
0 | package fixed | |
1 | ||
2 | import ( | |
3 | "sync" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
8 | // Publisher yields the same set of fixed endpoints. | |
9 | type Publisher struct { | |
10 | mtx sync.RWMutex | |
11 | endpoints []endpoint.Endpoint | |
12 | } | |
13 | ||
14 | // NewPublisher returns a fixed endpoint Publisher. | |
15 | func NewPublisher(endpoints []endpoint.Endpoint) *Publisher { | |
16 | return &Publisher{ | |
17 | endpoints: endpoints, | |
18 | } | |
19 | } | |
20 | ||
21 | // Endpoints implements the Publisher interface. | |
22 | func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) { | |
23 | p.mtx.RLock() | |
24 | defer p.mtx.RUnlock() | |
25 | return p.endpoints, nil | |
26 | } | |
27 | ||
28 | // Replace is a utility method to swap out the underlying endpoints of an | |
29 | // existing fixed publisher. It's useful mostly for testing. | |
30 | func (p *Publisher) Replace(endpoints []endpoint.Endpoint) { | |
31 | p.mtx.Lock() | |
32 | defer p.mtx.Unlock() | |
33 | p.endpoints = endpoints | |
34 | } |
0 | package fixed_test | |
1 | ||
2 | import ( | |
3 | "reflect" | |
4 | "testing" | |
5 | ||
6 | "golang.org/x/net/context" | |
7 | ||
8 | "github.com/go-kit/kit/endpoint" | |
9 | "github.com/go-kit/kit/loadbalancer/fixed" | |
10 | ) | |
11 | ||
12 | func TestFixed(t *testing.T) { | |
13 | var ( | |
14 | e1 = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
15 | e2 = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
16 | endpoints = []endpoint.Endpoint{e1, e2} | |
17 | ) | |
18 | p := fixed.NewPublisher(endpoints) | |
19 | have, err := p.Endpoints() | |
20 | if err != nil { | |
21 | t.Fatal(err) | |
22 | } | |
23 | if want := endpoints; !reflect.DeepEqual(want, have) { | |
24 | t.Fatalf("want %#+v, have %#+v", want, have) | |
25 | } | |
26 | } | |
27 | ||
28 | func TestFixedReplace(t *testing.T) { | |
29 | p := fixed.NewPublisher([]endpoint.Endpoint{ | |
30 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, | |
31 | }) | |
32 | have, err := p.Endpoints() | |
33 | if err != nil { | |
34 | t.Fatal(err) | |
35 | } | |
36 | if want, have := 1, len(have); want != have { | |
37 | t.Fatalf("want %d, have %d", want, have) | |
38 | } | |
39 | p.Replace([]endpoint.Endpoint{}) | |
40 | have, err = p.Endpoints() | |
41 | if err != nil { | |
42 | t.Fatal(err) | |
43 | } | |
44 | if want, have := 0, len(have); want != have { | |
45 | t.Fatalf("want %d, have %d", want, have) | |
46 | } | |
47 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
8 | // LoadBalancer describes something that can yield endpoints for a remote | |
9 | // service method. | |
10 | type LoadBalancer interface { | |
11 | Endpoint() (endpoint.Endpoint, error) | |
12 | } | |
13 | ||
14 | // ErrNoEndpoints is returned when a load balancer (or one of its components) | |
15 | // has no endpoints to return. In a request lifecycle, this is usually a fatal | |
16 | // error. | |
17 | var ErrNoEndpoints = errors.New("no endpoints available") |
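For callers, a sketch of treating ErrNoEndpoints separately from other publisher errors; `lb`, `ctx`, and `request` are placeholders for any LoadBalancer and request in scope:

```go
e, err := lb.Endpoint()
if err == ErrNoEndpoints {
	return nil, err // no instances known right now: fail fast or use a fallback
}
if err != nil {
	return nil, err // some other publisher error
}
return e(ctx, request)
```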
0 | package loadbalancer | |
1 | ||
2 | import "github.com/go-kit/kit/endpoint" | |
3 | ||
4 | // Publisher describes something that provides a set of identical endpoints. | |
5 | // Different publisher implementations exist for different kinds of service | |
6 | // discovery systems. | |
7 | type Publisher interface { | |
8 | Endpoints() ([]endpoint.Endpoint, error) | |
9 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "math/rand" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
8 | // Random is a completely stateless load balancer that chooses a random | |
9 | // endpoint to return each time. | |
10 | type Random struct { | |
11 | p Publisher | |
12 | r *rand.Rand | |
13 | } | |
14 | ||
15 | // NewRandom returns a new Random load balancer. | |
16 | func NewRandom(p Publisher, seed int64) *Random { | |
17 | return &Random{ | |
18 | p: p, | |
19 | r: rand.New(rand.NewSource(seed)), | |
20 | } | |
21 | } | |
22 | ||
23 | // Endpoint implements the LoadBalancer interface. | |
24 | func (r *Random) Endpoint() (endpoint.Endpoint, error) { | |
25 | endpoints, err := r.p.Endpoints() | |
26 | if err != nil { | |
27 | return nil, err | |
28 | } | |
29 | if len(endpoints) == 0 { | |
30 | return nil, ErrNoEndpoints | |
31 | } | |
32 | return endpoints[r.r.Intn(len(endpoints))], nil | |
33 | } |
0 | package loadbalancer_test | |
1 | ||
2 | import ( | |
3 | "math" | |
4 | "testing" | |
5 | ||
6 | "golang.org/x/net/context" | |
7 | ||
8 | "github.com/go-kit/kit/endpoint" | |
9 | "github.com/go-kit/kit/loadbalancer" | |
10 | "github.com/go-kit/kit/loadbalancer/fixed" | |
11 | ) | |
12 | ||
13 | func TestRandomDistribution(t *testing.T) { | |
14 | var ( | |
15 | n = 3 | |
16 | endpoints = make([]endpoint.Endpoint, n) | |
17 | counts = make([]int, n) | |
18 | seed = int64(123) | |
19 | ctx = context.Background() | |
20 | iterations = 100000 | |
21 | want = iterations / n | |
22 | tolerance = want / 100 // 1% | |
23 | ) | |
24 | ||
25 | for i := 0; i < n; i++ { | |
26 | i0 := i | |
27 | endpoints[i] = func(context.Context, interface{}) (interface{}, error) { counts[i0]++; return struct{}{}, nil } | |
28 | } | |
29 | ||
30 | lb := loadbalancer.NewRandom(fixed.NewPublisher(endpoints), seed) | |
31 | ||
32 | for i := 0; i < iterations; i++ { | |
33 | e, err := lb.Endpoint() | |
34 | if err != nil { | |
35 | t.Fatal(err) | |
36 | } | |
37 | if _, err := e(ctx, struct{}{}); err != nil { | |
38 | t.Error(err) | |
39 | } | |
40 | } | |
41 | ||
42 | for i, have := range counts { | |
43 | if math.Abs(float64(want-have)) > float64(tolerance) { | |
44 | t.Errorf("%d: want %d, have %d", i, want, have) | |
45 | } | |
46 | } | |
47 | } | |
48 | ||
49 | func TestRandomBadPublisher(t *testing.T) { | |
50 | t.Skip("TODO") | |
51 | } | |
52 | ||
53 | func TestRandomNoEndpoints(t *testing.T) { | |
54 | lb := loadbalancer.NewRandom(fixed.NewPublisher([]endpoint.Endpoint{}), 123) | |
55 | _, have := lb.Endpoint() | |
56 | if want := loadbalancer.ErrNoEndpoints; want != have { | |
57 | t.Errorf("want %q, have %q", want, have) | |
58 | } | |
59 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "strings" | |
5 | "time" | |
6 | ||
7 | "golang.org/x/net/context" | |
8 | ||
9 | "github.com/go-kit/kit/endpoint" | |
10 | ) | |
11 | ||
12 | // Retry wraps the load balancer to make it behave like a simple endpoint. | |
13 | // Requests to the endpoint will be automatically load balanced via the load | |
14 | // balancer. Requests that return errors will be retried until they succeed, | |
15 | // up to max times, or until the timeout is elapsed, whichever comes first. | |
16 | func Retry(max int, timeout time.Duration, lb LoadBalancer) endpoint.Endpoint { | |
17 | if lb == nil { | |
18 | panic("nil LoadBalancer") | |
19 | } | |
20 | ||
21 | return func(ctx context.Context, request interface{}) (interface{}, error) { | |
22 | var ( | |
23 | newctx, cancel = context.WithTimeout(ctx, timeout) | |
24 | responses = make(chan interface{}, 1) | |
25 | errs = make(chan error, 1) | |
26 | a = []string{} | |
27 | ) | |
28 | defer cancel() | |
29 | for i := 1; i <= max; i++ { | |
30 | go func() { | |
31 | e, err := lb.Endpoint() | |
32 | if err != nil { | |
33 | errs <- err | |
34 | return | |
35 | } | |
36 | response, err := e(newctx, request) | |
37 | if err != nil { | |
38 | errs <- err | |
39 | return | |
40 | } | |
41 | responses <- response | |
42 | }() | |
43 | ||
44 | select { | |
45 | case <-newctx.Done(): | |
46 | return nil, newctx.Err() | |
47 | case response := <-responses: | |
48 | return response, nil | |
49 | case err := <-errs: | |
50 | a = append(a, err.Error()) | |
51 | continue | |
52 | } | |
53 | } | |
54 | return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; ")) | |
55 | } | |
56 | } |
0 | package loadbalancer_test | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "testing" | |
5 | "time" | |
6 | ||
7 | "golang.org/x/net/context" | |
8 | ||
9 | "github.com/go-kit/kit/endpoint" | |
10 | "github.com/go-kit/kit/loadbalancer" | |
11 | "github.com/go-kit/kit/loadbalancer/fixed" | |
12 | ) | |
13 | ||
14 | func TestRetryMaxTotalFail(t *testing.T) { | |
15 | var ( | |
16 | endpoints = []endpoint.Endpoint{} // no endpoints | |
17 | p = fixed.NewPublisher(endpoints) | |
18 | lb = loadbalancer.NewRoundRobin(p) | |
19 | retry = loadbalancer.Retry(999, time.Second, lb) // lots of retries | |
20 | ctx = context.Background() | |
21 | ) | |
22 | if _, err := retry(ctx, struct{}{}); err == nil { | |
23 | t.Errorf("expected error, got none") // should fail | |
24 | } | |
25 | } | |
26 | ||
27 | func TestRetryMaxPartialFail(t *testing.T) { | |
28 | var ( | |
29 | endpoints = []endpoint.Endpoint{ | |
30 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") }, | |
31 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") }, | |
32 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ }, | |
33 | } | |
34 | retries = len(endpoints) - 1 // not quite enough retries | |
35 | p = fixed.NewPublisher(endpoints) | |
36 | lb = loadbalancer.NewRoundRobin(p) | |
37 | ctx = context.Background() | |
38 | ) | |
39 | if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err == nil { | |
40 | t.Errorf("expected error, got none") | |
41 | } | |
42 | } | |
43 | ||
44 | func TestRetryMaxSuccess(t *testing.T) { | |
45 | var ( | |
46 | endpoints = []endpoint.Endpoint{ | |
47 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") }, | |
48 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") }, | |
49 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ }, | |
50 | } | |
51 | retries = len(endpoints) // exactly enough retries | |
52 | p = fixed.NewPublisher(endpoints) | |
53 | lb = loadbalancer.NewRoundRobin(p) | |
54 | ctx = context.Background() | |
55 | ) | |
56 | if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err != nil { | |
57 | t.Error(err) | |
58 | } | |
59 | } | |
60 | ||
61 | func TestRetryTimeout(t *testing.T) { | |
62 | var ( | |
63 | step = make(chan struct{}) | |
64 | e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil } | |
65 | timeout = time.Millisecond | |
66 | retry = loadbalancer.Retry(999, timeout, loadbalancer.NewRoundRobin(fixed.NewPublisher([]endpoint.Endpoint{e}))) | |
67 | errs = make(chan error, 1) | |
68 | invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err } | |
69 | ) | |
70 | ||
71 | go func() { step <- struct{}{} }() // queue up a flush of the endpoint | |
72 | invoke() // invoke the endpoint and trigger the flush | |
73 | if err := <-errs; err != nil { // that should succeed | |
74 | t.Error(err) | |
75 | } | |
76 | ||
77 | go func() { time.Sleep(10 * timeout); step <- struct{}{} }() // a delayed flush | |
78 | invoke() // invoke the endpoint | |
79 | if err := <-errs; err != context.DeadlineExceeded { // that should not succeed | |
80 | t.Errorf("wanted %v, got none", context.DeadlineExceeded) | |
81 | } | |
82 | } |
0 | package loadbalancer | |
1 | ||
2 | import ( | |
3 | "sync/atomic" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
8 | // RoundRobin is a simple load balancer that returns each of the published | |
9 | // endpoints in sequence. | |
10 | type RoundRobin struct { | |
11 | p Publisher | |
12 | counter uint64 | |
13 | } | |
14 | ||
15 | // NewRoundRobin returns a new RoundRobin load balancer. | |
16 | func NewRoundRobin(p Publisher) *RoundRobin { | |
17 | return &RoundRobin{ | |
18 | p: p, | |
19 | counter: 0, | |
20 | } | |
21 | } | |
22 | ||
23 | // Endpoint implements the LoadBalancer interface. | |
24 | func (rr *RoundRobin) Endpoint() (endpoint.Endpoint, error) { | |
25 | endpoints, err := rr.p.Endpoints() | |
26 | if err != nil { | |
27 | return nil, err | |
28 | } | |
29 | if len(endpoints) == 0 { | |
30 | return nil, ErrNoEndpoints | |
31 | } | |
32 | old := atomic.AddUint64(&rr.counter, 1) - 1 | |
39 | return endpoints[old%uint64(len(endpoints))], nil | |
40 | } |
0 | package loadbalancer_test | |
1 | ||
2 | import ( | |
3 | "reflect" | |
4 | "testing" | |
5 | ||
6 | "github.com/go-kit/kit/endpoint" | |
7 | "github.com/go-kit/kit/loadbalancer" | |
8 | "github.com/go-kit/kit/loadbalancer/fixed" | |
9 | "golang.org/x/net/context" | |
10 | ) | |
11 | ||
12 | func TestRoundRobinDistribution(t *testing.T) { | |
13 | var ( | |
14 | ctx = context.Background() | |
15 | counts = []int{0, 0, 0} | |
16 | endpoints = []endpoint.Endpoint{ | |
17 | func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil }, | |
18 | func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil }, | |
19 | func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil }, | |
20 | } | |
21 | ) | |
22 | ||
23 | lb := loadbalancer.NewRoundRobin(fixed.NewPublisher(endpoints)) | |
24 | ||
25 | for i, want := range [][]int{ | |
26 | {1, 0, 0}, | |
27 | {1, 1, 0}, | |
28 | {1, 1, 1}, | |
29 | {2, 1, 1}, | |
30 | {2, 2, 1}, | |
31 | {2, 2, 2}, | |
32 | {3, 2, 2}, | |
33 | } { | |
34 | e, err := lb.Endpoint() | |
35 | if err != nil { | |
36 | t.Fatal(err) | |
37 | } | |
38 | if _, err := e(ctx, struct{}{}); err != nil { | |
39 | t.Error(err) | |
40 | } | |
41 | if have := counts; !reflect.DeepEqual(want, have) { | |
42 | t.Fatalf("%d: want %v, have %v", i, want, have) | |
43 | } | |
44 | ||
45 | } | |
46 | } | |
47 | ||
48 | func TestRoundRobinBadPublisher(t *testing.T) { | |
49 | t.Skip("TODO") | |
50 | } |
0 | package static | |
1 | ||
2 | import ( | |
3 | "github.com/go-kit/kit/endpoint" | |
4 | "github.com/go-kit/kit/loadbalancer" | |
5 | "github.com/go-kit/kit/loadbalancer/fixed" | |
6 | "github.com/go-kit/kit/log" | |
7 | ) | |
8 | ||
9 | // Publisher yields a set of static endpoints as produced by the passed factory. | |
10 | type Publisher struct{ publisher *fixed.Publisher } | |
11 | ||
12 | // NewPublisher returns a static endpoint Publisher. | |
13 | func NewPublisher(instances []string, factory loadbalancer.Factory, logger log.Logger) Publisher { | |
14 | logger = log.NewContext(logger).With("component", "Static Publisher") | |
15 | endpoints := []endpoint.Endpoint{} | |
16 | for _, instance := range instances { | |
17 | e, _, err := factory(instance) // never close | |
18 | if err != nil { | |
19 | logger.Log("instance", instance, "err", err) | |
20 | continue | |
21 | } | |
22 | endpoints = append(endpoints, e) | |
23 | } | |
24 | return Publisher{publisher: fixed.NewPublisher(endpoints)} | |
25 | } | |
26 | ||
27 | // Endpoints implements Publisher. | |
28 | func (p Publisher) Endpoints() ([]endpoint.Endpoint, error) { | |
29 | return p.publisher.Endpoints() | |
30 | } |
0 | package static_test | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "io" | |
5 | "testing" | |
6 | ||
7 | "golang.org/x/net/context" | |
8 | ||
9 | "github.com/go-kit/kit/endpoint" | |
10 | "github.com/go-kit/kit/loadbalancer/static" | |
11 | "github.com/go-kit/kit/log" | |
12 | ) | |
13 | ||
14 | func TestStatic(t *testing.T) { | |
15 | var ( | |
16 | instances = []string{"foo", "bar", "baz"} | |
17 | endpoints = map[string]endpoint.Endpoint{ | |
18 | "foo": func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, | |
19 | "bar": func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, | |
20 | "baz": func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, | |
21 | } | |
22 | factory = func(instance string) (endpoint.Endpoint, io.Closer, error) { | |
23 | if e, ok := endpoints[instance]; ok { | |
24 | return e, nil, nil | |
25 | } | |
26 | return nil, nil, fmt.Errorf("%s: not found", instance) | |
27 | } | |
28 | ) | |
29 | p := static.NewPublisher(instances, factory, log.NewNopLogger()) | |
30 | have, err := p.Endpoints() | |
31 | if err != nil { | |
32 | t.Fatal(err) | |
33 | } | |
34 | want := []endpoint.Endpoint{endpoints["foo"], endpoints["bar"], endpoints["baz"]} | |
35 | if fmt.Sprint(want) != fmt.Sprint(have) { | |
36 | t.Fatalf("want %v, have %v", want, have) | |
37 | } | |
38 | } |
0 | package zk | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "net" | |
5 | "strings" | |
6 | "time" | |
7 | ||
8 | "github.com/samuel/go-zookeeper/zk" | |
9 | ||
10 | "github.com/go-kit/kit/log" | |
11 | ) | |
12 | ||
13 | var ( | |
14 | // DefaultACL is the default ACL to use for creating znodes. | |
15 | DefaultACL = zk.WorldACL(zk.PermAll) | |
16 | // ErrInvalidCredentials is returned when an invalid user/password combination is provided. | |
16 | ErrInvalidCredentials = errors.New("invalid credentials provided") | |
17 | // ErrClientClosed is returned when the client service has been closed. | |
17 | ErrClientClosed = errors.New("client service closed") | |
18 | ) | |
19 | ||
20 | const ( | |
21 | // DefaultConnectTimeout is the default timeout to establish a connection to | |
22 | // a ZooKeeper node. | |
23 | DefaultConnectTimeout = 2 * time.Second | |
24 | // DefaultSessionTimeout is the default timeout to keep the current | |
25 | // ZooKeeper session alive during a temporary disconnect. | |
26 | DefaultSessionTimeout = 5 * time.Second | |
27 | ) | |
28 | ||
29 | // Client is a wrapper around a lower level ZooKeeper client implementation. | |
30 | type Client interface { | |
31 | // GetEntries should query the provided path in ZooKeeper, place a watch on | |
32 | // it and retrieve data from its current child nodes. | |
33 | GetEntries(path string) ([]string, <-chan zk.Event, error) | |
34 | // CreateParentNodes should try to create the path if it does not yet exist | |
35 | // on ZooKeeper. | |
36 | CreateParentNodes(path string) error | |
37 | // Stop should properly shut down the client implementation. | |
38 | Stop() | |
39 | } | |
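40 | ||
41 | // Note that ZooKeeper watches are one-time triggers: after an event arrives on | |
42 | // the channel returned by GetEntries, the caller must call GetEntries again to | |
43 | // obtain fresh data and a new watch. A rough sketch, with error handling | |
44 | // elided and a placeholder path (c is a Client): | |
45 | // | |
46 | //	instances, eventc, _ := c.GetEntries("/services/foosvc") | |
47 | //	<-eventc                                                  // wait for a change notification | |
48 | //	instances, eventc, _ = c.GetEntries("/services/foosvc")   // re-read and re-watch | |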
40 | ||
41 | type clientConfig struct { | |
42 | logger log.Logger | |
43 | acl []zk.ACL | |
44 | credentials []byte | |
45 | connectTimeout time.Duration | |
46 | sessionTimeout time.Duration | |
47 | rootNodePayload [][]byte | |
48 | eventHandler func(zk.Event) | |
49 | } | |
50 | ||
51 | // Option functions enable friendly APIs. | |
52 | type Option func(*clientConfig) error | |
53 | ||
54 | type client struct { | |
55 | *zk.Conn | |
56 | clientConfig | |
57 | active bool | |
58 | quit chan struct{} | |
59 | } | |
60 | ||
61 | // ACL returns an Option specifying a non-default ACL for creating parent nodes. | |
62 | func ACL(acl []zk.ACL) Option { | |
63 | return func(c *clientConfig) error { | |
64 | c.acl = acl | |
65 | return nil | |
66 | } | |
67 | } | |
68 | ||
69 | // Credentials returns an Option specifying a user/password combination with | |
70 | // which the client will authenticate itself. | |
71 | func Credentials(user, pass string) Option { | |
72 | return func(c *clientConfig) error { | |
73 | if user == "" || pass == "" { | |
74 | return ErrInvalidCredentials | |
75 | } | |
76 | c.credentials = []byte(user + ":" + pass) | |
77 | return nil | |
78 | } | |
79 | } | |
80 | ||
81 | // ConnectTimeout returns an Option specifying a non-default connection timeout | |
82 | // when we try to establish a connection to a ZooKeeper server. | |
83 | func ConnectTimeout(t time.Duration) Option { | |
84 | return func(c *clientConfig) error { | |
85 | if t.Seconds() < 1 { | |
86 | return errors.New("invalid connect timeout (minimum value is 1 second)") | |
87 | } | |
88 | c.connectTimeout = t | |
89 | return nil | |
90 | } | |
91 | } | |
92 | ||
93 | // SessionTimeout returns an Option specifying a non-default session timeout. | |
94 | func SessionTimeout(t time.Duration) Option { | |
95 | return func(c *clientConfig) error { | |
96 | if t.Seconds() < 1 { | |
97 | return errors.New("invalid session timeout (minimum value is 1 second)") | |
98 | } | |
99 | c.sessionTimeout = t | |
100 | return nil | |
101 | } | |
102 | } | |
103 | ||
104 | // Payload returns an Option specifying non-default data values for each znode | |
105 | // created by CreateParentNodes. | |
106 | func Payload(payload [][]byte) Option { | |
107 | return func(c *clientConfig) error { | |
108 | c.rootNodePayload = payload | |
109 | return nil | |
110 | } | |
111 | } | |
112 | ||
113 | // EventHandler returns an Option specifying a callback function to handle | |
114 | // incoming zk.Event payloads (ZooKeeper connection events). | |
115 | func EventHandler(handler func(zk.Event)) Option { | |
116 | return func(c *clientConfig) error { | |
117 | c.eventHandler = handler | |
118 | return nil | |
119 | } | |
120 | } | |
121 | ||
122 | // NewClient returns a ZooKeeper client with a connection to the server cluster. | |
123 | // It will return an error if the server cluster cannot be resolved. | |
124 | func NewClient(servers []string, logger log.Logger, options ...Option) (Client, error) { | |
125 | defaultEventHandler := func(event zk.Event) { | |
126 | logger.Log("eventtype", event.Type.String(), "server", event.Server, "state", event.State.String(), "err", event.Err) | |
127 | } | |
128 | config := clientConfig{ | |
129 | acl: DefaultACL, | |
130 | connectTimeout: DefaultConnectTimeout, | |
131 | sessionTimeout: DefaultSessionTimeout, | |
132 | eventHandler: defaultEventHandler, | |
133 | logger: logger, | |
134 | } | |
135 | for _, option := range options { | |
136 | if err := option(&config); err != nil { | |
137 | return nil, err | |
138 | } | |
139 | } | |
140 | // dialer overrides the default ZooKeeper library Dialer so we can configure | |
141 | // the connectTimeout. The current library has a hardcoded value of 1 second, | |
142 | // and there are reports of race conditions due to slow DNS resolvers and | |
143 | // other network latency issues. | |
144 | dialer := func(network, address string, _ time.Duration) (net.Conn, error) { | |
145 | return net.DialTimeout(network, address, config.connectTimeout) | |
146 | } | |
147 | conn, eventc, err := zk.Connect(servers, config.sessionTimeout, withLogger(logger), zk.WithDialer(dialer)) | |
148 | ||
149 | if err != nil { | |
150 | return nil, err | |
151 | } | |
152 | ||
153 | if len(config.credentials) > 0 { | |
154 | err = conn.AddAuth("digest", config.credentials) | |
155 | if err != nil { | |
156 | return nil, err | |
157 | } | |
158 | } | |
159 | ||
160 | c := &client{conn, config, true, make(chan struct{})} | |
161 | ||
162 | // Start listening for incoming Event payloads and invoke the configured | |
163 | // eventHandler. | |
164 | go func() { | |
165 | for { | |
166 | select { | |
167 | case event := <-eventc: | |
168 | config.eventHandler(event) | |
169 | case <-c.quit: | |
170 | return | |
171 | } | |
172 | } | |
173 | }() | |
174 | return c, nil | |
175 | } | |
176 | ||
177 | // CreateParentNodes implements the ZooKeeper Client interface. | |
178 | func (c *client) CreateParentNodes(path string) error { | |
179 | if !c.active { | |
180 | return ErrClientClosed | |
181 | } | |
182 | if path == "" || path[0] != '/' { | |
183 | return zk.ErrInvalidPath | |
184 | } | |
185 | payload := []byte("") | |
186 | pathString := "" | |
187 | pathNodes := strings.Split(path, "/") | |
188 | for i := 1; i < len(pathNodes); i++ { | |
189 | if i <= len(c.rootNodePayload) { | |
190 | payload = c.rootNodePayload[i-1] | |
191 | } else { | |
192 | payload = []byte("") | |
193 | } | |
194 | pathString += "/" + pathNodes[i] | |
195 | _, err := c.Create(pathString, payload, 0, c.acl) | |
196 | // Failing to create the node because it already exists, or because we lack | |
197 | // sufficient rights, is not a problem: it is fine for the node to already | |
198 | // exist and/or for us to have only read rights. | |
199 | if err != nil && err != zk.ErrNodeExists && err != zk.ErrNoAuth { | |
200 | return err | |
201 | } | |
202 | } | |
203 | return nil | |
204 | } | |
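205 | ||
206 | // For illustration (the path and payloads are placeholders): with a client | |
207 | // constructed via Payload([][]byte{p1, p2}), the call | |
208 | // | |
209 | //	_ = c.CreateParentNodes("/a/b/c") | |
210 | // | |
211 | // attempts to create "/a" with p1, "/a/b" with p2, and "/a/b/c" with an empty | |
212 | // payload, ignoring zk.ErrNodeExists and zk.ErrNoAuth along the way. | |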
205 | ||
206 | // GetEntries implements the ZooKeeper Client interface. | |
207 | func (c *client) GetEntries(path string) ([]string, <-chan zk.Event, error) { | |
208 | // retrieve list of child nodes for given path and add watch to path | |
209 | znodes, _, eventc, err := c.ChildrenW(path) | |
210 | ||
211 | if err != nil { | |
212 | return nil, eventc, err | |
213 | } | |
214 | ||
215 | var resp []string | |
216 | for _, znode := range znodes { | |
217 | // retrieve payload for child znode and add to response array | |
218 | if data, _, err := c.Get(path + "/" + znode); err == nil { | |
219 | resp = append(resp, string(data)) | |
220 | } | |
221 | } | |
222 | return resp, eventc, nil | |
223 | } | |
224 | ||
225 | // Stop implements the ZooKeeper Client interface. | |
226 | func (c *client) Stop() { | |
227 | c.active = false | |
228 | close(c.quit) | |
229 | c.Close() | |
230 | } |
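231 | ||
232 | // A construction sketch. The server addresses, timeouts, and credentials below | |
233 | // are placeholder values, not recommendations: | |
234 | // | |
235 | //	c, err := NewClient( | |
236 | //		[]string{"zk1.internal:2181", "zk2.internal:2181"}, | |
237 | //		logger, | |
238 | //		ConnectTimeout(3*time.Second), | |
239 | //		SessionTimeout(20*time.Second), | |
240 | //		ACL(zk.DigestACL(zk.PermAll, "user", "secret")), | |
241 | //		Credentials("user", "secret"), // must match the digest ACL above | |
242 | //		EventHandler(func(e zk.Event) { logger.Log("zk_event", e.Type.String()) }), | |
243 | //	) | |
244 | //	if err != nil { | |
245 | //		// handle error | |
246 | //	} | |
247 | //	defer c.Stop() | |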
0 | package zk | |
1 | ||
2 | import ( | |
3 | "bytes" | |
4 | "testing" | |
5 | "time" | |
6 | ||
7 | stdzk "github.com/samuel/go-zookeeper/zk" | |
8 | ||
9 | "github.com/go-kit/kit/log" | |
10 | ) | |
11 | ||
12 | func TestNewClient(t *testing.T) { | |
13 | var ( | |
14 | acl = stdzk.WorldACL(stdzk.PermRead) | |
15 | connectTimeout = 3 * time.Second | |
16 | sessionTimeout = 20 * time.Second | |
17 | payload = [][]byte{[]byte("Payload"), []byte("Test")} | |
18 | ) | |
19 | ||
20 | c, err := NewClient( | |
21 | []string{"FailThisInvalidHost!!!"}, | |
22 | log.NewNopLogger(), | |
23 | ) | |
24 | if err == nil { | |
25 | t.Errorf("expected error, got nil") | |
26 | } | |
27 | ||
28 | hasFired := false | |
29 | calledEventHandler := make(chan struct{}) | |
30 | eventHandler := func(event stdzk.Event) { | |
31 | if !hasFired { | |
32 | // test is successful if this function has fired at least once | |
33 | hasFired = true | |
34 | close(calledEventHandler) | |
35 | } | |
36 | } | |
37 | ||
38 | c, err = NewClient( | |
39 | []string{"localhost"}, | |
40 | log.NewNopLogger(), | |
41 | ACL(acl), | |
42 | ConnectTimeout(connectTimeout), | |
43 | SessionTimeout(sessionTimeout), | |
44 | Payload(payload), | |
45 | EventHandler(eventHandler), | |
46 | ) | |
47 | if err != nil { | |
48 | t.Fatal(err) | |
49 | } | |
50 | defer c.Stop() | |
51 | ||
52 | clientImpl, ok := c.(*client) | |
53 | if !ok { | |
54 | t.Fatal("retrieved incorrect Client implementation") | |
55 | } | |
56 | if want, have := acl, clientImpl.acl; want[0] != have[0] { | |
57 | t.Errorf("want %+v, have %+v", want, have) | |
58 | } | |
59 | if want, have := connectTimeout, clientImpl.connectTimeout; want != have { | |
60 | t.Errorf("want %d, have %d", want, have) | |
61 | } | |
62 | if want, have := sessionTimeout, clientImpl.sessionTimeout; want != have { | |
63 | t.Errorf("want %d, have %d", want, have) | |
64 | } | |
65 | if want, have := payload, clientImpl.rootNodePayload; !bytes.Equal(want[0], have[0]) || !bytes.Equal(want[1], have[1]) { | |
66 | t.Errorf("want %s, have %s", want, have) | |
67 | } | |
68 | ||
69 | select { | |
70 | case <-calledEventHandler: | |
71 | case <-time.After(100 * time.Millisecond): | |
72 | t.Errorf("event handler never called") | |
73 | } | |
74 | } | |
75 | ||
76 | func TestOptions(t *testing.T) { | |
77 | _, err := NewClient([]string{"localhost"}, log.NewNopLogger(), Credentials("valid", "credentials")) | |
78 | if err != nil && err != stdzk.ErrNoServer { | |
79 | t.Errorf("unexpected error: %v", err) | |
80 | } | |
81 | ||
82 | _, err = NewClient([]string{"localhost"}, log.NewNopLogger(), Credentials("nopass", "")) | |
83 | if want, have := err, ErrInvalidCredentials; want != have { | |
84 | t.Errorf("want %v, have %v", want, have) | |
85 | } | |
86 | ||
87 | _, err = NewClient([]string{"localhost"}, log.NewNopLogger(), ConnectTimeout(0)) | |
88 | if err == nil { | |
89 | t.Errorf("expected connect timeout error") | |
90 | } | |
91 | ||
92 | _, err = NewClient([]string{"localhost"}, log.NewNopLogger(), SessionTimeout(0)) | |
93 | if err == nil { | |
94 | t.Errorf("expected connect timeout error") | |
95 | } | |
96 | } | |
97 | ||
98 | func TestCreateParentNodes(t *testing.T) { | |
99 | payload := [][]byte{[]byte("Payload"), []byte("Test")} | |
100 | ||
101 | c, err := NewClient([]string{"localhost:65500"}, log.NewNopLogger()) | |
102 | if err != nil { | |
103 | t.Errorf("unexpected error: %v", err) | |
104 | } | |
105 | if c == nil { | |
106 | t.Fatal("expected new Client, got nil") | |
107 | } | |
108 | ||
109 | p, err := NewPublisher(c, "/validpath", newFactory(""), log.NewNopLogger()) | |
110 | if err != stdzk.ErrNoServer { | |
111 | t.Errorf("unexpected error: %v", err) | |
112 | } | |
113 | if p != nil { | |
114 | t.Error("expected failed new Publisher") | |
115 | } | |
116 | ||
117 | p, err = NewPublisher(c, "invalidpath", newFactory(""), log.NewNopLogger()) | |
118 | if err != stdzk.ErrInvalidPath { | |
119 | t.Errorf("unexpected error: %v", err) | |
120 | } | |
121 | _, _, err = c.GetEntries("/validpath") | |
122 | if err != stdzk.ErrNoServer { | |
123 | t.Errorf("unexpected error: %v", err) | |
124 | } | |
125 | ||
126 | c.Stop() | |
127 | ||
128 | err = c.CreateParentNodes("/validpath") | |
129 | if err != ErrClientClosed { | |
130 | t.Errorf("unexpected error: %v", err) | |
131 | } | |
132 | ||
133 | p, err = NewPublisher(c, "/validpath", newFactory(""), log.NewNopLogger()) | |
134 | if err != ErrClientClosed { | |
135 | t.Errorf("unexpected error: %v", err) | |
136 | } | |
137 | if p != nil { | |
138 | t.Error("expected failed new Publisher") | |
139 | } | |
140 | ||
141 | c, err = NewClient([]string{"localhost:65500"}, log.NewNopLogger(), Payload(payload)) | |
142 | if err != nil { | |
143 | t.Errorf("unexpected error: %v", err) | |
144 | } | |
145 | if c == nil { | |
146 | t.Fatal("expected new Client, got nil") | |
147 | } | |
148 | ||
149 | p, err = NewPublisher(c, "/validpath", newFactory(""), log.NewNopLogger()) | |
150 | if err != stdzk.ErrNoServer { | |
151 | t.Errorf("unexpected error: %v", err) | |
152 | } | |
153 | if p != nil { | |
154 | t.Error("expected failed new Publisher") | |
155 | } | |
156 | } |
0 | // +build integration | |
1 | ||
2 | package zk | |
3 | ||
4 | import ( | |
5 | "bytes" | |
6 | "flag" | |
7 | "fmt" | |
8 | "os" | |
9 | "testing" | |
10 | "time" | |
11 | ||
12 | stdzk "github.com/samuel/go-zookeeper/zk" | |
13 | ) | |
14 | ||
15 | var ( | |
16 | host []string | |
17 | ) | |
18 | ||
19 | func TestMain(m *testing.M) { | |
20 | flag.Parse() | |
21 | ||
22 | fmt.Println("Starting ZooKeeper server...") | |
23 | ||
24 | ts, err := stdzk.StartTestCluster(1, nil, nil) | |
25 | if err != nil { | |
26 | fmt.Printf("ZooKeeper server error: %v\n", err) | |
27 | os.Exit(1) | |
28 | } | |
29 | ||
30 | host = []string{fmt.Sprintf("localhost:%d", ts.Servers[0].Port)} | |
31 | code := m.Run() | |
32 | ||
33 | ts.Stop() | |
34 | os.Exit(code) | |
35 | } | |
36 | ||
37 | func TestCreateParentNodesOnServer(t *testing.T) { | |
38 | payload := [][]byte{[]byte("Payload"), []byte("Test")} | |
39 | c1, err := NewClient(host, logger, Payload(payload)) | |
40 | if err != nil { | |
41 | t.Fatalf("Connect returned error: %v", err) | |
42 | } | |
43 | if c1 == nil { | |
44 | t.Fatal("Expected pointer to client, got nil") | |
45 | } | |
46 | defer c1.Stop() | |
47 | ||
48 | p, err := NewPublisher(c1, path, newFactory(""), logger) | |
49 | if err != nil { | |
50 | t.Fatalf("Unable to create Publisher: %v", err) | |
51 | } | |
52 | defer p.Stop() | |
53 | ||
54 | endpoints, err := p.Endpoints() | |
55 | if err != nil { | |
56 | t.Fatal(err) | |
57 | } | |
58 | if want, have := 0, len(endpoints); want != have { | |
59 | t.Errorf("want %d, have %d", want, have) | |
60 | } | |
61 | ||
62 | c2, err := NewClient(host, logger) | |
63 | if err != nil { | |
64 | t.Fatalf("Connect returned error: %v", err) | |
65 | } | |
66 | defer c2.Stop() | |
67 | data, _, err := c2.(*client).Get(path) | |
68 | if err != nil { | |
69 | t.Fatal(err) | |
70 | } | |
71 | // Test the Client implementation of CreateParentNodes: it should have written | |
72 | // our payload to the created node. | |
73 | if !bytes.Equal(data, payload[1]) { | |
74 | t.Errorf("want %s, have %s", payload[1], data) | |
75 | } | |
77 | } | |
78 | ||
79 | func TestCreateBadParentNodesOnServer(t *testing.T) { | |
80 | c, _ := NewClient(host, logger) | |
81 | defer c.Stop() | |
82 | ||
83 | _, err := NewPublisher(c, "invalid/path", newFactory(""), logger) | |
84 | ||
85 | if want, have := stdzk.ErrInvalidPath, err; want != have { | |
86 | t.Errorf("want %v, have %v", want, have) | |
87 | } | |
88 | } | |
89 | ||
90 | func TestCredentials1(t *testing.T) { | |
91 | acl := stdzk.DigestACL(stdzk.PermAll, "user", "secret") | |
92 | c, _ := NewClient(host, logger, ACL(acl), Credentials("user", "secret")) | |
93 | defer c.Stop() | |
94 | ||
95 | _, err := NewPublisher(c, "/acl-issue-test", newFactory(""), logger) | |
96 | ||
97 | if err != nil { | |
98 | t.Fatal(err) | |
99 | } | |
100 | } | |
101 | ||
102 | func TestCredentials2(t *testing.T) { | |
103 | acl := stdzk.DigestACL(stdzk.PermAll, "user", "secret") | |
104 | c, _ := NewClient(host, logger, ACL(acl)) | |
105 | defer c.Stop() | |
106 | ||
107 | _, err := NewPublisher(c, "/acl-issue-test", newFactory(""), logger) | |
108 | ||
109 | if err != stdzk.ErrNoAuth { | |
110 | t.Errorf("want %v, have %v", stdzk.ErrNoAuth, err) | |
111 | } | |
112 | } | |
113 | ||
114 | func TestConnection(t *testing.T) { | |
115 | c, _ := NewClient(host, logger) | |
116 | c.Stop() | |
117 | ||
118 | _, err := NewPublisher(c, "/acl-issue-test", newFactory(""), logger) | |
119 | ||
120 | if err != ErrClientClosed { | |
121 | t.Errorf("want %v, have %v", ErrClientClosed, err) | |
122 | } | |
123 | } | |
124 | ||
125 | func TestGetEntriesOnServer(t *testing.T) { | |
126 | var instancePayload = "protocol://hostname:port/routing" | |
127 | ||
128 | c1, err := NewClient(host, logger) | |
129 | if err != nil { | |
130 | t.Fatalf("Connect returned error: %v", err) | |
131 | } | |
132 | ||
133 | defer c1.Stop() | |
134 | ||
135 | c2, err := NewClient(host, logger) | |
136 | if err != nil { | |
137 | t.Fatalf("Connect returned error: %v", err) | |
138 | } | |
139 | p, err := NewPublisher(c2, path, newFactory(""), logger) | |
137 | if err != nil { | |
138 | t.Fatal(err) | |
139 | } | |
140 | defer c2.Stop() | |
141 | ||
142 | c2impl, _ := c2.(*client) | |
143 | _, err = c2impl.Create( | |
144 | path+"/instance1", | |
145 | []byte(instancePayload), | |
146 | stdzk.FlagEphemeral|stdzk.FlagSequence, | |
147 | stdzk.WorldACL(stdzk.PermAll), | |
148 | ) | |
149 | if err != nil { | |
150 | t.Fatalf("Unable to create test ephemeral znode 1: %v", err) | |
151 | } | |
152 | _, err = c2impl.Create( | |
153 | path+"/instance2", | |
154 | []byte(instancePayload+"2"), | |
155 | stdzk.FlagEphemeral|stdzk.FlagSequence, | |
156 | stdzk.WorldACL(stdzk.PermAll), | |
157 | ) | |
158 | if err != nil { | |
159 | t.Fatalf("Unable to create test ephemeral znode 2: %v", err) | |
160 | } | |
161 | ||
162 | time.Sleep(50 * time.Millisecond) | |
163 | ||
164 | endpoints, err := p.Endpoints() | |
165 | if err != nil { | |
166 | t.Fatal(err) | |
167 | } | |
168 | if want, have := 2, len(endpoints); want != have { | |
169 | t.Errorf("want %d, have %d", want, have) | |
170 | } | |
171 | } | |
172 | ||
173 | func TestGetEntriesPayloadOnServer(t *testing.T) { | |
174 | c, err := NewClient(host, logger) | |
175 | if err != nil { | |
176 | t.Fatalf("Connect returned error: %v", err) | |
177 | } | |
178 | defer c.Stop() | |
178 | _, eventc, err := c.GetEntries(path) | |
179 | if err != nil { | |
180 | t.Fatal(err) | |
181 | } | |
182 | _, err = c.(*client).Create( | |
183 | path+"/instance3", | |
184 | []byte("just some payload"), | |
185 | stdzk.FlagEphemeral|stdzk.FlagSequence, | |
186 | stdzk.WorldACL(stdzk.PermAll), | |
187 | ) | |
188 | if err != nil { | |
189 | t.Fatalf("Unable to create test ephemeral znode: %v", err) | |
190 | } | |
191 | select { | |
192 | case event := <-eventc: | |
193 | if want, have := stdzk.EventNodeChildrenChanged.String(), event.Type.String(); want != have { | |
194 | t.Errorf("want %s, have %s", want, have) | |
195 | } | |
196 | case <-time.After(20 * time.Millisecond): | |
197 | t.Errorf("expected incoming watch event, timeout occurred") | |
198 | } | |
200 | } | |
0 | package zk | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | ||
5 | "github.com/samuel/go-zookeeper/zk" | |
6 | ||
7 | "github.com/go-kit/kit/log" | |
8 | ) | |
9 | ||
10 | // wrapLogger wraps a go-kit logger so we can use it as the logging service for | |
11 | // the ZooKeeper library (which expects a Printf method to be available). | |
12 | type wrapLogger struct { | |
13 | log.Logger | |
14 | } | |
15 | ||
16 | func (logger wrapLogger) Printf(str string, vars ...interface{}) { | |
17 | logger.Log("msg", fmt.Sprintf(str, vars...)) | |
18 | } | |
19 | ||
20 | // withLogger replaces the ZooKeeper library's default logging service with | |
21 | // our own go-kit logger. | |
22 | func withLogger(logger log.Logger) func(c *zk.Conn) { | |
23 | return func(c *zk.Conn) { | |
24 | c.SetLogger(wrapLogger{logger}) | |
25 | } | |
26 | } |
0 | package zk | |
1 | ||
2 | import ( | |
3 | "github.com/go-kit/kit/endpoint" | |
4 | "github.com/go-kit/kit/loadbalancer" | |
5 | "github.com/go-kit/kit/log" | |
6 | "github.com/samuel/go-zookeeper/zk" | |
7 | ) | |
8 | ||
9 | // Publisher yields endpoints stored in a certain ZooKeeper path. Any change | |
10 | // in that path is watched and updates the Publisher's endpoints. | |
11 | type Publisher struct { | |
12 | client Client | |
13 | path string | |
14 | cache *loadbalancer.EndpointCache | |
15 | logger log.Logger | |
16 | quit chan struct{} | |
17 | } | |
18 | ||
19 | // NewPublisher returns a ZooKeeper publisher. It starts watching the given | |
20 | // path for changes and updates the Publisher's endpoints accordingly. | |
21 | func NewPublisher(c Client, path string, f loadbalancer.Factory, logger log.Logger) (*Publisher, error) { | |
22 | p := &Publisher{ | |
23 | client: c, | |
24 | path: path, | |
25 | cache: loadbalancer.NewEndpointCache(f, logger), | |
26 | logger: logger, | |
27 | quit: make(chan struct{}), | |
28 | } | |
29 | ||
30 | err := p.client.CreateParentNodes(p.path) | |
31 | if err != nil { | |
32 | return nil, err | |
33 | } | |
34 | ||
35 | // initial node retrieval and cache fill | |
36 | instances, eventc, err := p.client.GetEntries(p.path) | |
37 | if err != nil { | |
38 | logger.Log("path", p.path, "msg", "failed to retrieve entries", "err", err) | |
39 | return nil, err | |
40 | } | |
41 | logger.Log("path", p.path, "instances", len(instances)) | |
42 | p.cache.Replace(instances) | |
43 | ||
44 | // handle incoming path updates | |
45 | go p.loop(eventc) | |
46 | ||
47 | return p, nil | |
48 | } | |
49 | ||
50 | func (p *Publisher) loop(eventc <-chan zk.Event) { | |
51 | var ( | |
52 | instances []string | |
53 | err error | |
54 | ) | |
55 | for { | |
56 | select { | |
57 | case <-eventc: | |
58 | // We received a path update notification. Call GetEntries to | |
59 | // retrieve child node data and to set a new watch, since ZooKeeper | |
60 | // watches are one-time triggers. | |
61 | instances, eventc, err = p.client.GetEntries(p.path) | |
62 | if err != nil { | |
63 | p.logger.Log("path", p.path, "msg", "failed to retrieve entries", "err", err) | |
64 | continue | |
65 | } | |
66 | p.logger.Log("path", p.path, "instances", len(instances)) | |
67 | p.cache.Replace(instances) | |
68 | case <-p.quit: | |
69 | return | |
70 | } | |
71 | } | |
72 | } | |
73 | ||
74 | // Endpoints implements the Publisher interface. | |
75 | func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) { | |
76 | return p.cache.Endpoints() | |
77 | } | |
78 | ||
79 | // Stop terminates the Publisher. | |
80 | func (p *Publisher) Stop() { | |
81 | close(p.quit) | |
82 | } |
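83 | ||
84 | // A usage sketch tying client and publisher together. The server address, | |
85 | // path, and fooFactory below are placeholders: | |
86 | // | |
87 | //	client, err := NewClient([]string{"zk1.internal:2181"}, logger) | |
88 | //	if err != nil { /* handle error */ } | |
89 | //	defer client.Stop() | |
90 | // | |
91 | //	p, err := NewPublisher(client, "/services/foosvc", fooFactory, logger) | |
92 | //	if err != nil { /* handle error */ } | |
93 | //	defer p.Stop() | |
94 | // | |
95 | //	endpoints, err := p.Endpoints() // current set, kept fresh by the watch loop | |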
0 | package zk | |
1 | ||
2 | import ( | |
3 | "testing" | |
4 | "time" | |
5 | ) | |
6 | ||
7 | func TestPublisher(t *testing.T) { | |
8 | client := newFakeClient() | |
9 | ||
10 | p, err := NewPublisher(client, path, newFactory(""), logger) | |
11 | if err != nil { | |
12 | t.Fatalf("failed to create new publisher: %v", err) | |
13 | } | |
14 | defer p.Stop() | |
15 | ||
16 | if _, err := p.Endpoints(); err != nil { | |
17 | t.Fatal(err) | |
18 | } | |
19 | } | |
20 | ||
21 | func TestBadFactory(t *testing.T) { | |
22 | client := newFakeClient() | |
23 | ||
24 | p, err := NewPublisher(client, path, newFactory("kaboom"), logger) | |
25 | if err != nil { | |
26 | t.Fatalf("failed to create new publisher: %v", err) | |
27 | } | |
28 | defer p.Stop() | |
29 | ||
30 | // instance1 came online | |
31 | client.AddService(path+"/instance1", "kaboom") | |
32 | ||
33 | // instance2 came online | |
34 | client.AddService(path+"/instance2", "zookeeper_node_data") | |
35 | ||
36 | if err = asyncTest(100*time.Millisecond, 1, p); err != nil { | |
37 | t.Error(err) | |
38 | } | |
39 | } | |
40 | ||
41 | func TestServiceUpdate(t *testing.T) { | |
42 | client := newFakeClient() | |
43 | ||
44 | p, err := NewPublisher(client, path, newFactory(""), logger) | |
45 | if err != nil { | |
46 | t.Fatalf("failed to create new publisher: %v", err) | |
47 | } | |
48 | defer p.Stop() | |
49 | ||
50 | endpoints, err := p.Endpoints() | |
51 | if err != nil { | |
52 | t.Fatal(err) | |
53 | } | |
54 | ||
55 | if want, have := 0, len(endpoints); want != have { | |
56 | t.Errorf("want %d, have %d", want, have) | |
57 | } | |
58 | ||
59 | // instance1 came online | |
60 | client.AddService(path+"/instance1", "zookeeper_node_data") | |
61 | ||
62 | // instance2 came online | |
63 | client.AddService(path+"/instance2", "zookeeper_node_data2") | |
64 | ||
65 | // we should have 2 instances | |
66 | if err = asyncTest(100*time.Millisecond, 2, p); err != nil { | |
67 | t.Error(err) | |
68 | } | |
69 | ||
70 | // watch triggers an error... | |
71 | client.SendErrorOnWatch() | |
72 | ||
73 | // test if error was consumed | |
74 | if err = client.ErrorIsConsumed(100 * time.Millisecond); err != nil { | |
75 | t.Error(err) | |
76 | } | |
77 | ||
78 | // instance3 came online | |
79 | client.AddService(path+"/instance3", "zookeeper_node_data3") | |
80 | ||
81 | // we should have 3 instances | |
82 | if err = asyncTest(100*time.Millisecond, 3, p); err != nil { | |
83 | t.Error(err) | |
84 | } | |
85 | ||
86 | // instance1 goes offline | |
87 | client.RemoveService(path + "/instance1") | |
88 | ||
89 | // instance2 goes offline | |
90 | client.RemoveService(path + "/instance2") | |
91 | ||
92 | // we should have 1 instance | |
93 | if err = asyncTest(100*time.Millisecond, 1, p); err != nil { | |
94 | t.Error(err) | |
95 | } | |
96 | } | |
97 | ||
98 | func TestBadPublisherCreate(t *testing.T) { | |
99 | client := newFakeClient() | |
100 | client.SendErrorOnWatch() | |
101 | p, err := NewPublisher(client, path, newFactory(""), logger) | |
102 | if err == nil { | |
103 | t.Error("expected error on new publisher") | |
104 | } | |
105 | if p != nil { | |
106 | t.Error("expected publisher not to be created") | |
107 | } | |
108 | p, err = NewPublisher(client, "BadPath", newFactory(""), logger) | |
109 | if err == nil { | |
110 | t.Error("expected error on new publisher") | |
111 | } | |
112 | if p != nil { | |
113 | t.Error("expected publisher not to be created") | |
114 | } | |
115 | } |
0 | package zk | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "fmt" | |
5 | "io" | |
6 | "sync" | |
7 | "time" | |
8 | ||
9 | "github.com/samuel/go-zookeeper/zk" | |
10 | "golang.org/x/net/context" | |
11 | ||
12 | "github.com/go-kit/kit/endpoint" | |
13 | "github.com/go-kit/kit/loadbalancer" | |
14 | "github.com/go-kit/kit/log" | |
15 | ) | |
16 | ||
17 | var ( | |
18 | path = "/gokit.test/service.name" | |
19 | e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
20 | logger = log.NewNopLogger() | |
21 | ) | |
22 | ||
23 | type fakeClient struct { | |
24 | mtx sync.Mutex | |
25 | ch chan zk.Event | |
26 | responses map[string]string | |
27 | result bool | |
28 | } | |
29 | ||
30 | func newFakeClient() *fakeClient { | |
31 | return &fakeClient{ | |
32 | ch: make(chan zk.Event, 5), | |
33 | responses: make(map[string]string), | |
34 | result: true, | |
35 | } | |
36 | } | |
37 | ||
38 | func (c *fakeClient) CreateParentNodes(path string) error { | |
39 | if path == "BadPath" { | |
40 | return errors.New("Dummy Error") | |
41 | } | |
42 | return nil | |
43 | } | |
44 | ||
45 | func (c *fakeClient) GetEntries(path string) ([]string, <-chan zk.Event, error) { | |
46 | c.mtx.Lock() | |
47 | defer c.mtx.Unlock() | |
48 | if !c.result { | |
49 | c.result = true | |
50 | return []string{}, c.ch, errors.New("Dummy Error") | |
51 | } | |
52 | responses := []string{} | |
53 | for _, data := range c.responses { | |
54 | responses = append(responses, data) | |
55 | } | |
56 | return responses, c.ch, nil | |
57 | } | |
58 | ||
59 | func (c *fakeClient) AddService(node, data string) { | |
60 | c.mtx.Lock() | |
61 | defer c.mtx.Unlock() | |
62 | c.responses[node] = data | |
63 | c.ch <- zk.Event{} | |
64 | } | |
65 | ||
66 | func (c *fakeClient) RemoveService(node string) { | |
67 | c.mtx.Lock() | |
68 | defer c.mtx.Unlock() | |
69 | delete(c.responses, node) | |
70 | c.ch <- zk.Event{} | |
71 | } | |
72 | ||
73 | func (c *fakeClient) SendErrorOnWatch() { | |
74 | c.mtx.Lock() | |
75 | defer c.mtx.Unlock() | |
76 | c.result = false | |
77 | c.ch <- zk.Event{} | |
78 | } | |
79 | ||
80 | func (c *fakeClient) ErrorIsConsumed(t time.Duration) error { | |
81 | timeout := time.After(t) | |
82 | for { | |
83 | select { | |
84 | case <-timeout: | |
85 | return fmt.Errorf("expected error not consumed after timeout %s", t.String()) | |
86 | default: | |
87 | c.mtx.Lock() | |
88 | if !c.result { | |
89 | c.mtx.Unlock() | |
90 | return nil | |
91 | } | |
92 | c.mtx.Unlock() | |
93 | } | |
94 | } | |
95 | } | |
96 | ||
97 | func (c *fakeClient) Stop() {} | |
98 | ||
99 | func newFactory(fakeError string) loadbalancer.Factory { | |
100 | return func(instance string) (endpoint.Endpoint, io.Closer, error) { | |
101 | if fakeError == instance { | |
102 | return nil, nil, errors.New(fakeError) | |
103 | } | |
104 | return e, nil, nil | |
105 | } | |
106 | } | |
107 | ||
108 | func asyncTest(timeout time.Duration, want int, p *Publisher) (err error) { | |
109 | var endpoints []endpoint.Endpoint | |
110 | // want can never be -1 | |
111 | have := -1 | |
112 | t := time.After(timeout) | |
113 | for { | |
114 | select { | |
115 | case <-t: | |
116 | return fmt.Errorf("want %d, have %d after timeout %s", want, have, timeout.String()) | |
117 | default: | |
118 | endpoints, err = p.Endpoints() | |
119 | have = len(endpoints) | |
120 | if err != nil || want == have { | |
121 | return | |
122 | } | |
123 | time.Sleep(time.Millisecond) | |
124 | } | |
125 | } | |
126 | } |