sd: port, without service.Service
Peter Bourgon
5 years ago
8 | 8 | // Endpoint is the fundamental building block of servers and clients. |
9 | 9 | // It represents a single RPC method. |
10 | 10 | type Endpoint func(ctx context.Context, request interface{}) (response interface{}, err error) |
11 | ||
12 | // Nop is an endpoint that does nothing and returns a nil error. | |
13 | // Useful for tests. | |
14 | func Nop(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
11 | 15 | |
12 | 16 | // Middleware is a chainable behavior modifier for endpoints. |
13 | 17 | type Middleware func(Endpoint) Endpoint |
0 | package cache | |
1 | ||
2 | import ( | |
3 | "io" | |
4 | "testing" | |
5 | ||
6 | "github.com/go-kit/kit/endpoint" | |
7 | "github.com/go-kit/kit/log" | |
8 | ) | |
9 | ||
10 | func BenchmarkEndpoints(b *testing.B) { | |
11 | var ( | |
12 | ca = make(closer) | |
13 | cb = make(closer) | |
14 | cmap = map[string]io.Closer{"a": ca, "b": cb} | |
15 | factory = func(instance string) (endpoint.Endpoint, io.Closer, error) { return endpoint.Nop, cmap[instance], nil } | |
16 | c = New(factory, log.NewNopLogger()) | |
17 | ) | |
18 | ||
19 | b.ReportAllocs() | |
20 | ||
21 | c.Update([]string{"a", "b"}) | |
22 | ||
23 | b.RunParallel(func(pb *testing.PB) { | |
24 | for pb.Next() { | |
25 | c.Endpoints() | |
26 | } | |
27 | }) | |
28 | } |
0 | package cache | |
1 | ||
2 | import ( | |
3 | "io" | |
4 | "sort" | |
5 | "sync" | |
6 | ||
7 | "github.com/go-kit/kit/endpoint" | |
8 | "github.com/go-kit/kit/log" | |
9 | "github.com/go-kit/kit/sd" | |
10 | ) | |
11 | ||
// Cache collects the most recent set of endpoints from a service discovery
// system via a subscriber, and makes them available to consumers. Cache is
// meant to be embedded inside of a concrete subscriber, and can serve Service
// invocations directly.
type Cache struct {
	mtx     sync.RWMutex              // guards cache and slice
	factory sd.Factory                // manufactures an endpoint (and closer) per instance string
	cache   map[string]endpointCloser // instance string -> live endpoint + closer
	slice   []endpoint.Endpoint       // endpoints ordered lexicographically by instance string
	logger  log.Logger
}

// endpointCloser pairs an endpoint with the io.Closer that releases its
// underlying resources when the instance disappears.
type endpointCloser struct {
	endpoint.Endpoint
	io.Closer
}
28 | ||
29 | // New returns a new, empty endpoint cache. | |
30 | func New(factory sd.Factory, logger log.Logger) *Cache { | |
31 | return &Cache{ | |
32 | factory: factory, | |
33 | cache: map[string]endpointCloser{}, | |
34 | logger: logger, | |
35 | } | |
36 | } | |
37 | ||
38 | // Update should be invoked by clients with a complete set of current instance | |
39 | // strings whenever that set changes. The cache manufactures new endpoints via | |
40 | // the factory, closes old endpoints when they disappear, and persists existing | |
41 | // endpoints if they survive through an update. | |
42 | func (c *Cache) Update(instances []string) { | |
43 | c.mtx.Lock() | |
44 | defer c.mtx.Unlock() | |
45 | ||
46 | // Deterministic order (for later). | |
47 | sort.Strings(instances) | |
48 | ||
49 | // Produce the current set of services. | |
50 | cache := make(map[string]endpointCloser, len(instances)) | |
51 | for _, instance := range instances { | |
52 | // If it already exists, just copy it over. | |
53 | if sc, ok := c.cache[instance]; ok { | |
54 | cache[instance] = sc | |
55 | delete(c.cache, instance) | |
56 | continue | |
57 | } | |
58 | ||
59 | // If it doesn't exist, create it. | |
60 | service, closer, err := c.factory(instance) | |
61 | if err != nil { | |
62 | c.logger.Log("instance", instance, "err", err) | |
63 | continue | |
64 | } | |
65 | cache[instance] = endpointCloser{service, closer} | |
66 | } | |
67 | ||
68 | // Close any leftover endpoints. | |
69 | for _, sc := range c.cache { | |
70 | if sc.Closer != nil { | |
71 | sc.Closer.Close() | |
72 | } | |
73 | } | |
74 | ||
75 | // Populate the slice of endpoints. | |
76 | slice := make([]endpoint.Endpoint, 0, len(cache)) | |
77 | for _, instance := range instances { | |
78 | // A bad factory may mean an instance is not present. | |
79 | if _, ok := cache[instance]; !ok { | |
80 | continue | |
81 | } | |
82 | slice = append(slice, cache[instance].Endpoint) | |
83 | } | |
84 | ||
85 | // Swap and trigger GC for old copies. | |
86 | c.slice = slice | |
87 | c.cache = cache | |
88 | } | |
89 | ||
90 | // Endpoints yields the current set of (presumably identical) endpoints, ordered | |
91 | // lexicographically by the corresponding instance string. | |
92 | func (c *Cache) Endpoints() []endpoint.Endpoint { | |
93 | c.mtx.RLock() | |
94 | defer c.mtx.RUnlock() | |
95 | return c.slice | |
96 | } |
0 | package cache | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "io" | |
5 | "testing" | |
6 | "time" | |
7 | ||
8 | "github.com/go-kit/kit/endpoint" | |
9 | "github.com/go-kit/kit/log" | |
10 | ) | |
11 | ||
// TestCache exercises the full lifecycle — populate, no-op update, and
// incremental deletes — verifying that deleted instances have their closers
// invoked and that surviving instances do not.
func TestCache(t *testing.T) {
	var (
		ca    = make(closer)
		cb    = make(closer)
		c     = map[string]io.Closer{"a": ca, "b": cb}
		f     = func(instance string) (endpoint.Endpoint, io.Closer, error) { return endpoint.Nop, c[instance], nil }
		cache = New(f, log.NewNopLogger())
	)

	// Populate. Neither closer channel should be closed; give it a moment
	// to be sure.
	cache.Update([]string{"a", "b"})
	select {
	case <-ca:
		t.Errorf("endpoint a closed, not good")
	case <-cb:
		t.Errorf("endpoint b closed, not good")
	case <-time.After(time.Millisecond):
		t.Logf("no closures yet, good")
	}
	if want, have := 2, len(cache.Endpoints()); want != have {
		t.Errorf("want %d, have %d", want, have)
	}

	// Duplicate instance set, should be a no-op: existing endpoints must be
	// reused, not closed and recreated.
	cache.Update([]string{"a", "b"})
	select {
	case <-ca:
		t.Errorf("endpoint a closed, not good")
	case <-cb:
		t.Errorf("endpoint b closed, not good")
	case <-time.After(time.Millisecond):
		t.Logf("no closures yet, good")
	}
	if want, have := 2, len(cache.Endpoints()); want != have {
		t.Errorf("want %d, have %d", want, have)
	}

	// Delete b. Update runs in a goroutine so we can block on the closer;
	// the subsequent Endpoints() call blocks on the cache's lock until the
	// in-flight Update completes, so the length check is race-free.
	go cache.Update([]string{"a"})
	select {
	case <-ca:
		t.Errorf("endpoint a closed, not good")
	case <-cb:
		t.Logf("endpoint b closed, good")
	case <-time.After(time.Second):
		t.Errorf("didn't close the deleted instance in time")
	}
	if want, have := 1, len(cache.Endpoints()); want != have {
		t.Errorf("want %d, have %d", want, have)
	}

	// Delete a.
	go cache.Update([]string{})
	select {
	// case <-cb: will succeed, as it's closed
	case <-ca:
		t.Logf("endpoint a closed, good")
	case <-time.After(time.Second):
		t.Errorf("didn't close the deleted instance in time")
	}
	if want, have := 0, len(cache.Endpoints()); want != have {
		t.Errorf("want %d, have %d", want, have)
	}
}
76 | ||
77 | func TestBadFactory(t *testing.T) { | |
78 | cache := New(func(string) (endpoint.Endpoint, io.Closer, error) { | |
79 | return nil, nil, errors.New("bad factory") | |
80 | }, log.NewNopLogger()) | |
81 | ||
82 | cache.Update([]string{"foo:1234", "bar:5678"}) | |
83 | if want, have := 0, len(cache.Endpoints()); want != have { | |
84 | t.Errorf("want %d, have %d", want, have) | |
85 | } | |
86 | } | |
87 | ||
// closer is a test io.Closer whose channel is closed when Close is invoked,
// letting tests block until the cache releases an endpoint.
type closer chan struct{}

func (c closer) Close() error { close(c); return nil }
0 | package consul | |
1 | ||
2 | import consul "github.com/hashicorp/consul/api" | |
3 | ||
// Client is a wrapper around the Consul API.
type Client interface {
	// Register a service with the local agent.
	Register(r *consul.AgentServiceRegistration) error

	// Deregister a service with the local agent.
	Deregister(r *consul.AgentServiceRegistration) error

	// Service returns the entries for the named service, optionally filtered
	// by tag and health (passingOnly), along with query metadata. It mirrors
	// the signature of the Consul health API's Service method.
	Service(service, tag string, passingOnly bool, queryOpts *consul.QueryOptions) ([]*consul.ServiceEntry, *consul.QueryMeta, error)
}
15 | ||
// client adapts a concrete Consul API client to the Client interface.
type client struct {
	consul *consul.Client
}

// NewClient returns an implementation of the Client interface, wrapping a
// concrete Consul client.
func NewClient(c *consul.Client) Client {
	return &client{consul: c}
}
25 | ||
// Register delegates to the local agent's service registration endpoint.
func (c *client) Register(r *consul.AgentServiceRegistration) error {
	return c.consul.Agent().ServiceRegister(r)
}

// Deregister delegates to the local agent, identifying the service by ID only.
func (c *client) Deregister(r *consul.AgentServiceRegistration) error {
	return c.consul.Agent().ServiceDeregister(r.ID)
}

// Service delegates to the health API, so passingOnly filtering applies.
func (c *client) Service(service, tag string, passingOnly bool, queryOpts *consul.QueryOptions) ([]*consul.ServiceEntry, *consul.QueryMeta, error) {
	return c.consul.Health().Service(service, tag, passingOnly, queryOpts)
}
0 | package consul | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "io" | |
5 | "reflect" | |
6 | "testing" | |
7 | ||
8 | stdconsul "github.com/hashicorp/consul/api" | |
9 | "golang.org/x/net/context" | |
10 | ||
11 | "github.com/go-kit/kit/endpoint" | |
12 | ) | |
13 | ||
// TestClientRegistration drives the testClient through a register/deregister
// cycle, checking the visible service count after each step.
//
// NOTE(review): the "second Register must error" and "second Deregister must
// error" expectations rely on testClient's duplicate detection; a real Consul
// agent does not behave this way — confirm intent if porting to integration.
func TestClientRegistration(t *testing.T) {
	c := newTestClient(nil)

	// Initially empty.
	services, _, err := c.Service(testRegistration.Name, "", true, &stdconsul.QueryOptions{})
	if err != nil {
		t.Error(err)
	}
	if want, have := 0, len(services); want != have {
		t.Errorf("want %d, have %d", want, have)
	}

	if err := c.Register(testRegistration); err != nil {
		t.Error(err)
	}

	// Duplicate registration is rejected by testClient.
	if err := c.Register(testRegistration); err == nil {
		t.Errorf("want error, have %v", err)
	}

	services, _, err = c.Service(testRegistration.Name, "", true, &stdconsul.QueryOptions{})
	if err != nil {
		t.Error(err)
	}
	if want, have := 1, len(services); want != have {
		t.Errorf("want %d, have %d", want, have)
	}

	if err := c.Deregister(testRegistration); err != nil {
		t.Error(err)
	}

	// Deregistering a now-absent service is rejected by testClient.
	if err := c.Deregister(testRegistration); err == nil {
		t.Errorf("want error, have %v", err)
	}

	services, _, err = c.Service(testRegistration.Name, "", true, &stdconsul.QueryOptions{})
	if err != nil {
		t.Error(err)
	}
	if want, have := 0, len(services); want != have {
		t.Errorf("want %d, have %d", want, have)
	}
}
57 | ||
// testClient is an in-memory Client implementation backed by a slice of
// service entries, used by the unit tests in this package.
type testClient struct {
	entries []*stdconsul.ServiceEntry
}

// newTestClient seeds a testClient with the given entries (nil is valid).
func newTestClient(entries []*stdconsul.ServiceEntry) *testClient {
	return &testClient{
		entries: entries,
	}
}

// Compile-time check that testClient satisfies Client.
var _ Client = &testClient{}
69 | ||
70 | func (c *testClient) Service(service, tag string, _ bool, opts *stdconsul.QueryOptions) ([]*stdconsul.ServiceEntry, *stdconsul.QueryMeta, error) { | |
71 | var results []*stdconsul.ServiceEntry | |
72 | ||
73 | for _, entry := range c.entries { | |
74 | if entry.Service.Service != service { | |
75 | continue | |
76 | } | |
77 | if tag != "" { | |
78 | tagMap := map[string]struct{}{} | |
79 | ||
80 | for _, t := range entry.Service.Tags { | |
81 | tagMap[t] = struct{}{} | |
82 | } | |
83 | ||
84 | if _, ok := tagMap[tag]; !ok { | |
85 | continue | |
86 | } | |
87 | } | |
88 | ||
89 | results = append(results, entry) | |
90 | } | |
91 | ||
92 | return results, &stdconsul.QueryMeta{}, nil | |
93 | } | |
94 | ||
95 | func (c *testClient) Register(r *stdconsul.AgentServiceRegistration) error { | |
96 | toAdd := registration2entry(r) | |
97 | ||
98 | for _, entry := range c.entries { | |
99 | if reflect.DeepEqual(*entry, *toAdd) { | |
100 | return errors.New("duplicate") | |
101 | } | |
102 | } | |
103 | ||
104 | c.entries = append(c.entries, toAdd) | |
105 | return nil | |
106 | } | |
107 | ||
108 | func (c *testClient) Deregister(r *stdconsul.AgentServiceRegistration) error { | |
109 | toDelete := registration2entry(r) | |
110 | ||
111 | var newEntries []*stdconsul.ServiceEntry | |
112 | for _, entry := range c.entries { | |
113 | if reflect.DeepEqual(*entry, *toDelete) { | |
114 | continue | |
115 | } | |
116 | newEntries = append(newEntries, entry) | |
117 | } | |
118 | if len(newEntries) == len(c.entries) { | |
119 | return errors.New("not found") | |
120 | } | |
121 | ||
122 | c.entries = newEntries | |
123 | return nil | |
124 | } | |
125 | ||
// registration2entry converts an agent registration into the ServiceEntry
// shape returned by Service queries, fabricating a fixed node name. Health
// checks are intentionally not carried over.
func registration2entry(r *stdconsul.AgentServiceRegistration) *stdconsul.ServiceEntry {
	return &stdconsul.ServiceEntry{
		Node: &stdconsul.Node{
			Node:    "some-node",
			Address: r.Address,
		},
		Service: &stdconsul.AgentService{
			ID:      r.ID,
			Service: r.Name,
			Tags:    r.Tags,
			Port:    r.Port,
			Address: r.Address,
		},
		// Checks ignored
	}
}
142 | ||
// testFactory returns an endpoint that simply echoes its instance string,
// letting tests verify which instance an endpoint was built for. No closer
// is produced.
func testFactory(instance string) (endpoint.Endpoint, io.Closer, error) {
	return func(context.Context, interface{}) (interface{}, error) {
		return instance, nil
	}, nil, nil
}

// testRegistration is the fixture registration shared by tests in this file.
var testRegistration = &stdconsul.AgentServiceRegistration{
	ID:      "my-id",
	Name:    "my-name",
	Tags:    []string{"my-tag-1", "my-tag-2"},
	Port:    12345,
	Address: "my-address",
}
0 | // +build integration | |
1 | ||
2 | package consul | |
3 | ||
4 | import ( | |
5 | "io" | |
6 | "os" | |
7 | "testing" | |
8 | "time" | |
9 | ||
10 | "github.com/go-kit/kit/log" | |
11 | "github.com/go-kit/kit/service" | |
12 | stdconsul "github.com/hashicorp/consul/api" | |
13 | ) | |
14 | ||
// TestIntegration exercises registration and discovery against a live Consul
// agent addressed by CONSUL_ADDRESS. It is guarded by the "integration" build
// tag and skipped (fatally) when the environment variable is unset.
func TestIntegration(t *testing.T) {
	// Connect to Consul.
	// docker run -p 8500:8500 progrium/consul -server -bootstrap
	consulAddr := os.Getenv("CONSUL_ADDRESS")
	if consulAddr == "" {
		t.Fatal("CONSUL_ADDRESS is not set")
	}
	stdClient, err := stdconsul.NewClient(&stdconsul.Config{
		Address: consulAddr,
	})
	if err != nil {
		t.Fatal(err)
	}
	client := NewClient(stdClient)
	logger := log.NewLogfmtLogger(os.Stderr)

	// Produce a fake service registration.
	r := &stdconsul.AgentServiceRegistration{
		ID:                "my-service-ID",
		Name:              "my-service-name",
		Tags:              []string{"alpha", "beta"},
		Port:              12345,
		Address:           "my-address",
		EnableTagOverride: false,
		// skipping check(s)
	}

	// Build a subscriber on r.Name + r.Tags.
	factory := func(instance string) (service.Service, io.Closer, error) {
		t.Logf("factory invoked for %q", instance)
		return service.Fixed{}, nil, nil
	}
	subscriber, err := NewSubscriber(
		client,
		factory,
		log.NewContext(logger).With("component", "subscriber"),
		r.Name,
		r.Tags,
		true,
	)
	if err != nil {
		t.Fatal(err)
	}

	// Give the subscriber's watch loop a moment to settle.
	time.Sleep(time.Second)

	// Before we publish, we should have no services.
	services, err := subscriber.Services()
	if err != nil {
		t.Error(err)
	}
	if want, have := 0, len(services); want != have {
		t.Errorf("want %d, have %d", want, have)
	}

	// Build a registrar for r.
	registrar := NewRegistrar(client, r, log.NewContext(logger).With("component", "registrar"))
	registrar.Register()
	defer registrar.Deregister()

	// Allow the registration to propagate to the watch loop.
	time.Sleep(time.Second)

	// Now we should have one active service.
	services, err = subscriber.Services()
	if err != nil {
		t.Error(err)
	}
	if want, have := 1, len(services); want != have {
		t.Errorf("want %d, have %d", want, have)
	}
}
0 | package consul | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | ||
5 | stdconsul "github.com/hashicorp/consul/api" | |
6 | ||
7 | "github.com/go-kit/kit/log" | |
8 | ) | |
9 | ||
// Registrar registers service instance liveness information to Consul.
type Registrar struct {
	client       Client
	registration *stdconsul.AgentServiceRegistration
	logger       log.Logger // pre-annotated with service, tags, and address
}

// NewRegistrar returns a Consul Registrar acting on the provided catalog
// registration. The logger is wrapped so every log line carries the
// registration's identifying fields.
func NewRegistrar(client Client, r *stdconsul.AgentServiceRegistration, logger log.Logger) *Registrar {
	return &Registrar{
		client:       client,
		registration: r,
		logger:       log.NewContext(logger).With("service", r.Name, "tags", fmt.Sprint(r.Tags), "address", r.Address),
	}
}
26 | ||
27 | // Register implements sd.Registrar interface. | |
28 | func (p *Registrar) Register() { | |
29 | if err := p.client.Register(p.registration); err != nil { | |
30 | p.logger.Log("err", err) | |
31 | } else { | |
32 | p.logger.Log("action", "register") | |
33 | } | |
34 | } | |
35 | ||
36 | // Deregister implements sd.Registrar interface. | |
37 | func (p *Registrar) Deregister() { | |
38 | if err := p.client.Deregister(p.registration); err != nil { | |
39 | p.logger.Log("err", err) | |
40 | } else { | |
41 | p.logger.Log("action", "deregister") | |
42 | } | |
43 | } |
0 | package consul | |
1 | ||
2 | import ( | |
3 | "testing" | |
4 | ||
5 | stdconsul "github.com/hashicorp/consul/api" | |
6 | ||
7 | "github.com/go-kit/kit/log" | |
8 | ) | |
9 | ||
// TestRegistrar verifies that Register and Deregister add and remove exactly
// one entry in the backing testClient.
func TestRegistrar(t *testing.T) {
	client := newTestClient([]*stdconsul.ServiceEntry{})
	p := NewRegistrar(client, testRegistration, log.NewNopLogger())
	if want, have := 0, len(client.entries); want != have {
		t.Errorf("want %d, have %d", want, have)
	}

	p.Register()
	if want, have := 1, len(client.entries); want != have {
		t.Errorf("want %d, have %d", want, have)
	}

	p.Deregister()
	if want, have := 0, len(client.entries); want != have {
		t.Errorf("want %d, have %d", want, have)
	}
}
0 | package consul | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "io" | |
5 | ||
6 | consul "github.com/hashicorp/consul/api" | |
7 | ||
8 | "github.com/go-kit/kit/endpoint" | |
9 | "github.com/go-kit/kit/log" | |
10 | "github.com/go-kit/kit/sd" | |
11 | "github.com/go-kit/kit/sd/cache" | |
12 | ) | |
13 | ||
// defaultIndex (zero WaitIndex) makes the initial Consul query non-blocking.
const defaultIndex = 0

// Subscriber yields endpoints for a service in Consul. Updates to the service
// are watched and will update the Subscriber endpoints.
type Subscriber struct {
	cache       *cache.Cache
	client      Client
	logger      log.Logger
	service     string
	tags        []string
	passingOnly bool
	endpointsc  chan []endpoint.Endpoint // NOTE(review): never read or written in this file — possibly dead; confirm before removing
	quitc       chan struct{}            // closed by Stop to terminate the watch loop
}

// Compile-time check that Subscriber satisfies sd.Subscriber.
var _ sd.Subscriber = &Subscriber{}
30 | ||
31 | // NewSubscriber returns a Consul subscriber which returns endpoints for the | |
32 | // requested service. It only returns instances for which all of the passed tags | |
33 | // are present. | |
34 | func NewSubscriber(client Client, factory sd.Factory, logger log.Logger, service string, tags []string, passingOnly bool) (*Subscriber, error) { | |
35 | s := &Subscriber{ | |
36 | cache: cache.New(factory, logger), | |
37 | client: client, | |
38 | logger: log.NewContext(logger).With("service", service, "tags", fmt.Sprint(tags)), | |
39 | service: service, | |
40 | tags: tags, | |
41 | passingOnly: passingOnly, | |
42 | quitc: make(chan struct{}), | |
43 | } | |
44 | ||
45 | instances, index, err := s.getInstances(defaultIndex, nil) | |
46 | if err == nil { | |
47 | s.logger.Log("instances", len(instances)) | |
48 | } else { | |
49 | s.logger.Log("err", err) | |
50 | } | |
51 | ||
52 | s.cache.Update(instances) | |
53 | go s.loop(index) | |
54 | return s, nil | |
55 | } | |
56 | ||
57 | // Endpoints implements the Subscriber interface. | |
58 | func (s *Subscriber) Endpoints() ([]endpoint.Endpoint, error) { | |
59 | return s.cache.Endpoints(), nil | |
60 | } | |
61 | ||
62 | // Stop terminates the subscriber. | |
63 | func (s *Subscriber) Stop() { | |
64 | close(s.quitc) | |
65 | } | |
66 | ||
67 | func (s *Subscriber) loop(lastIndex uint64) { | |
68 | var ( | |
69 | instances []string | |
70 | err error | |
71 | ) | |
72 | for { | |
73 | instances, lastIndex, err = s.getInstances(lastIndex, s.quitc) | |
74 | switch { | |
75 | case err == io.EOF: | |
76 | return // stopped via quitc | |
77 | case err != nil: | |
78 | s.logger.Log("err", err) | |
79 | default: | |
80 | s.cache.Update(instances) | |
81 | } | |
82 | } | |
83 | } | |
84 | ||
// getInstances performs one (possibly blocking) Consul service query and
// returns the instance strings plus the index to use in the next blocking
// query. A non-nil interruptc lets the wait be abandoned, in which case the
// error is io.EOF — loop interprets that as a stop signal. The query
// goroutine cannot leak: errc and resc are buffered, so it completes even if
// nobody is listening.
func (s *Subscriber) getInstances(lastIndex uint64, interruptc chan struct{}) ([]string, uint64, error) {
	// Query on the first tag only; the rest are filtered manually below.
	tag := ""
	if len(s.tags) > 0 {
		tag = s.tags[0]
	}

	// Consul doesn't support more than one tag in its service query method.
	// https://github.com/hashicorp/consul/issues/294
	// Hashi suggest prepared queries, but they don't support blocking.
	// https://www.consul.io/docs/agent/http/query.html#execute
	// If we want blocking for efficiency, we must filter tags manually.

	type response struct {
		instances []string
		index     uint64
	}

	var (
		errc = make(chan error, 1)
		resc = make(chan response, 1)
	)

	go func() {
		// A nonzero WaitIndex makes this a blocking query: it returns only
		// once the service's index advances past lastIndex (or times out).
		entries, meta, err := s.client.Service(s.service, tag, s.passingOnly, &consul.QueryOptions{
			WaitIndex: lastIndex,
		})
		if err != nil {
			errc <- err
			return
		}
		if len(s.tags) > 1 {
			entries = filterEntries(entries, s.tags[1:]...)
		}
		resc <- response{
			instances: makeInstances(entries),
			index:     meta.LastIndex,
		}
	}()

	select {
	case err := <-errc:
		return nil, 0, err
	case res := <-resc:
		return res.instances, res.index, nil
	case <-interruptc:
		return nil, 0, io.EOF // interpreted by loop as a stop request
	}
}
133 | ||
134 | func filterEntries(entries []*consul.ServiceEntry, tags ...string) []*consul.ServiceEntry { | |
135 | var es []*consul.ServiceEntry | |
136 | ||
137 | ENTRIES: | |
138 | for _, entry := range entries { | |
139 | ts := make(map[string]struct{}, len(entry.Service.Tags)) | |
140 | for _, tag := range entry.Service.Tags { | |
141 | ts[tag] = struct{}{} | |
142 | } | |
143 | ||
144 | for _, tag := range tags { | |
145 | if _, ok := ts[tag]; !ok { | |
146 | continue ENTRIES | |
147 | } | |
148 | } | |
149 | es = append(es, entry) | |
150 | } | |
151 | ||
152 | return es | |
153 | } | |
154 | ||
155 | func makeInstances(entries []*consul.ServiceEntry) []string { | |
156 | instances := make([]string, len(entries)) | |
157 | for i, entry := range entries { | |
158 | addr := entry.Node.Address | |
159 | if entry.Service.Address != "" { | |
160 | addr = entry.Service.Address | |
161 | } | |
162 | instances[i] = fmt.Sprintf("%s:%d", addr, entry.Service.Port) | |
163 | } | |
164 | return instances | |
165 | } |
0 | package consul | |
1 | ||
2 | import ( | |
3 | "testing" | |
4 | ||
5 | consul "github.com/hashicorp/consul/api" | |
6 | "golang.org/x/net/context" | |
7 | ||
8 | "github.com/go-kit/kit/log" | |
9 | ) | |
10 | ||
// consulState is the fixture catalog shared by the subscriber tests: two
// "search" API instances (tags api/v1 and api/v2) plus one "search" DB
// instance whose service-level address overrides the node address.
var consulState = []*consul.ServiceEntry{
	{
		Node: &consul.Node{
			Address: "10.0.0.0",
			Node:    "app00.local",
		},
		Service: &consul.AgentService{
			ID:      "search-api-0",
			Port:    8000,
			Service: "search",
			Tags: []string{
				"api",
				"v1",
			},
		},
	},
	{
		Node: &consul.Node{
			Address: "10.0.0.1",
			Node:    "app01.local",
		},
		Service: &consul.AgentService{
			ID:      "search-api-1",
			Port:    8001,
			Service: "search",
			Tags: []string{
				"api",
				"v2",
			},
		},
	},
	{
		Node: &consul.Node{
			Address: "10.0.0.1",
			Node:    "app01.local",
		},
		Service: &consul.AgentService{
			// Service-level address takes precedence over the node address.
			Address: "10.0.0.10",
			ID:      "search-db-0",
			Port:    9000,
			Service: "search",
			Tags: []string{
				"db",
			},
		},
	},
}
58 | ||
// TestSubscriber checks that a single-tag query ("api") matches the two API
// instances in consulState.
func TestSubscriber(t *testing.T) {
	var (
		logger = log.NewNopLogger()
		client = newTestClient(consulState)
	)

	s, err := NewSubscriber(client, testFactory, logger, "search", []string{"api"}, true)
	if err != nil {
		t.Fatal(err)
	}
	defer s.Stop()

	endpoints, err := s.Endpoints()
	if err != nil {
		t.Fatal(err)
	}

	if want, have := 2, len(endpoints); want != have {
		t.Errorf("want %d, have %d", want, have)
	}
}
80 | ||
// TestSubscriberNoService checks that querying a service absent from the
// catalog yields zero endpoints.
func TestSubscriberNoService(t *testing.T) {
	var (
		logger = log.NewNopLogger()
		client = newTestClient(consulState)
	)

	s, err := NewSubscriber(client, testFactory, logger, "feed", []string{}, true)
	if err != nil {
		t.Fatal(err)
	}
	defer s.Stop()

	endpoints, err := s.Endpoints()
	if err != nil {
		t.Fatal(err)
	}

	if want, have := 0, len(endpoints); want != have {
		t.Fatalf("want %d, have %d", want, have)
	}
}
102 | ||
// TestSubscriberWithTags checks that multiple tags are ANDed: only the
// instance tagged both "api" and "v2" survives the manual tag filter.
func TestSubscriberWithTags(t *testing.T) {
	var (
		logger = log.NewNopLogger()
		client = newTestClient(consulState)
	)

	s, err := NewSubscriber(client, testFactory, logger, "search", []string{"api", "v2"}, true)
	if err != nil {
		t.Fatal(err)
	}
	defer s.Stop()

	endpoints, err := s.Endpoints()
	if err != nil {
		t.Fatal(err)
	}

	if want, have := 1, len(endpoints); want != have {
		t.Fatalf("want %d, have %d", want, have)
	}
}
124 | ||
// TestSubscriberAddressOverride checks that a service-level address overrides
// the node address when building the instance string, by invoking the echo
// endpoint produced by testFactory.
func TestSubscriberAddressOverride(t *testing.T) {
	s, err := NewSubscriber(newTestClient(consulState), testFactory, log.NewNopLogger(), "search", []string{"db"}, true)
	if err != nil {
		t.Fatal(err)
	}
	defer s.Stop()

	endpoints, err := s.Endpoints()
	if err != nil {
		t.Fatal(err)
	}

	if want, have := 1, len(endpoints); want != have {
		t.Fatalf("want %d, have %d", want, have)
	}

	// testFactory endpoints echo the instance string they were built for.
	response, err := endpoints[0](context.Background(), struct{}{})
	if err != nil {
		t.Fatal(err)
	}

	if want, have := "10.0.0.10:9000", response.(string); want != have {
		t.Errorf("want %q, have %q", want, have)
	}
}
0 | package dnssrv | |
1 | ||
2 | import "net" | |
3 | ||
// Lookup is a function that resolves a DNS SRV record to multiple addresses.
// It has the same signature as net.LookupSRV, allowing tests to substitute a
// fake resolver.
type Lookup func(service, proto, name string) (cname string, addrs []*net.SRV, err error)
0 | package dnssrv | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "net" | |
5 | "time" | |
6 | ||
7 | "github.com/go-kit/kit/endpoint" | |
8 | "github.com/go-kit/kit/log" | |
9 | "github.com/go-kit/kit/sd" | |
10 | "github.com/go-kit/kit/sd/cache" | |
11 | ) | |
12 | ||
// Subscriber yields endpoints taken from the named DNS SRV record. The name is
// resolved on a fixed schedule. Priorities and weights are ignored.
type Subscriber struct {
	name   string        // the SRV record name to resolve
	cache  *cache.Cache  // holds the current endpoint set
	logger log.Logger
	quit   chan struct{} // closed by Stop to terminate the refresh loop
}
21 | ||
// NewSubscriber returns a DNS SRV subscriber that re-resolves the name every
// ttl, using net.LookupSRV as the resolver.
func NewSubscriber(
	name string,
	ttl time.Duration,
	factory sd.Factory,
	logger log.Logger,
) *Subscriber {
	return NewSubscriberDetailed(name, time.NewTicker(ttl), net.LookupSRV, factory, logger)
}
31 | ||
32 | // NewSubscriberDetailed is the same as NewSubscriber, but allows users to | |
33 | // provide an explicit lookup refresh ticker instead of a TTL, and specify the | |
34 | // lookup function instead of using net.LookupSRV. | |
35 | func NewSubscriberDetailed( | |
36 | name string, | |
37 | refresh *time.Ticker, | |
38 | lookup Lookup, | |
39 | factory sd.Factory, | |
40 | logger log.Logger, | |
41 | ) *Subscriber { | |
42 | p := &Subscriber{ | |
43 | name: name, | |
44 | cache: cache.New(factory, logger), | |
45 | logger: logger, | |
46 | quit: make(chan struct{}), | |
47 | } | |
48 | ||
49 | instances, err := p.resolve(lookup) | |
50 | if err == nil { | |
51 | logger.Log("name", name, "instances", len(instances)) | |
52 | } else { | |
53 | logger.Log("name", name, "err", err) | |
54 | } | |
55 | p.cache.Update(instances) | |
56 | ||
57 | go p.loop(refresh, lookup) | |
58 | return p | |
59 | } | |
60 | ||
// Stop terminates the Subscriber's background refresh loop.
func (p *Subscriber) Stop() {
	close(p.quit)
}
65 | ||
66 | func (p *Subscriber) loop(t *time.Ticker, lookup Lookup) { | |
67 | defer t.Stop() | |
68 | for { | |
69 | select { | |
70 | case <-t.C: | |
71 | instances, err := p.resolve(lookup) | |
72 | if err != nil { | |
73 | p.logger.Log("name", p.name, "err", err) | |
74 | continue // don't replace potentially-good with bad | |
75 | } | |
76 | p.cache.Update(instances) | |
77 | ||
78 | case <-p.quit: | |
79 | return | |
80 | } | |
81 | } | |
82 | } | |
83 | ||
// Endpoints implements the Subscriber interface, serving the cached set.
func (p *Subscriber) Endpoints() ([]endpoint.Endpoint, error) {
	return p.cache.Endpoints(), nil
}
88 | ||
89 | func (p *Subscriber) resolve(lookup Lookup) ([]string, error) { | |
90 | _, addrs, err := lookup("", "", p.name) | |
91 | if err != nil { | |
92 | return []string{}, err | |
93 | } | |
94 | instances := make([]string, len(addrs)) | |
95 | for i, addr := range addrs { | |
96 | instances[i] = net.JoinHostPort(addr.Target, fmt.Sprint(addr.Port)) | |
97 | } | |
98 | return instances, nil | |
99 | } |
0 | package dnssrv | |
1 | ||
2 | import ( | |
3 | "io" | |
4 | "net" | |
5 | "sync/atomic" | |
6 | "testing" | |
7 | "time" | |
8 | ||
9 | "github.com/go-kit/kit/endpoint" | |
10 | "github.com/go-kit/kit/log" | |
11 | ) | |
12 | ||
// TestRefresh drives the refresh loop with a hand-fed ticker channel and a
// fake lookup, verifying that each tick triggers exactly one lookup and that
// the factory runs once per newly discovered record.
func TestRefresh(t *testing.T) {
	name := "some.service.internal"

	// Replace the ticker's channel so the test controls tick timing.
	ticker := time.NewTicker(time.Second)
	ticker.Stop()
	tickc := make(chan time.Time)
	ticker.C = tickc

	// Counters are accessed from both the test and the loop goroutine.
	var lookups uint64
	records := []*net.SRV{}
	lookup := func(service, proto, name string) (string, []*net.SRV, error) {
		t.Logf("lookup(%q, %q, %q)", service, proto, name)
		atomic.AddUint64(&lookups, 1)
		return "cname", records, nil
	}

	var generates uint64
	factory := func(instance string) (endpoint.Endpoint, io.Closer, error) {
		t.Logf("factory(%q)", instance)
		atomic.AddUint64(&generates, 1)
		return endpoint.Nop, nopCloser{}, nil
	}

	subscriber := NewSubscriberDetailed(name, ticker, lookup, factory, log.NewNopLogger())
	defer subscriber.Stop()

	// First lookup happens synchronously in the constructor; no records yet,
	// so no endpoints and no factory invocations.
	endpoints, err := subscriber.Endpoints()
	if err != nil {
		t.Error(err)
	}
	if want, have := 0, len(endpoints); want != have {
		t.Errorf("want %d, have %d", want, have)
	}
	if want, have := uint64(1), atomic.LoadUint64(&lookups); want != have {
		t.Errorf("want %d, have %d", want, have)
	}
	if want, have := uint64(0), atomic.LoadUint64(&generates); want != have {
		t.Errorf("want %d, have %d", want, have)
	}

	// Load some records and trigger a refresh tick.
	records = []*net.SRV{
		&net.SRV{Target: "1.0.0.1", Port: 1001},
		&net.SRV{Target: "1.0.0.2", Port: 1002},
		&net.SRV{Target: "1.0.0.3", Port: 1003},
	}
	tickc <- time.Now()

	// There is a race condition where the subscriber.Endpoints call below
	// invokes the cache before it is updated by the tick above.
	// TODO(pb): solve by running the read through the loop goroutine.
	time.Sleep(100 * time.Millisecond)

	endpoints, err = subscriber.Endpoints()
	if err != nil {
		t.Error(err)
	}
	if want, have := 3, len(endpoints); want != have {
		t.Errorf("want %d, have %d", want, have)
	}
	if want, have := uint64(2), atomic.LoadUint64(&lookups); want != have {
		t.Errorf("want %d, have %d", want, have)
	}
	if want, have := uint64(len(records)), atomic.LoadUint64(&generates); want != have {
		t.Errorf("want %d, have %d", want, have)
	}
}
81 | ||
// nopCloser is a no-op io.Closer, used as the Closer half of test factories.
type nopCloser struct{}

func (nopCloser) Close() error { return nil }
0 | // Package sd provides utilities related to service discovery. That includes | |
1 | // subscribing to service discovery systems in order to reach remote instances, | |
2 | // and publishing to service discovery systems to make an instance available. | |
3 | // Implementations are provided for most common systems. | |
4 | package sd |
0 | package etcd | |
1 | ||
2 | import ( | |
3 | "crypto/tls" | |
4 | "crypto/x509" | |
5 | "io/ioutil" | |
6 | "net" | |
7 | "net/http" | |
8 | "time" | |
9 | ||
10 | etcd "github.com/coreos/etcd/client" | |
11 | "golang.org/x/net/context" | |
12 | ) | |
13 | ||
// Client is a wrapper around the etcd client.
type Client interface {
	// GetEntries will query the given prefix in etcd and returns a set of entries.
	GetEntries(prefix string) ([]string, error)

	// WatchPrefix starts watching every change for the given prefix in etcd.
	// When a change is detected, it will populate responseChan with an
	// *etcd.Response.
	WatchPrefix(prefix string, responseChan chan *etcd.Response)
}

// client implements Client on top of an etcd KeysAPI, using the supplied
// context for all calls.
type client struct {
	keysAPI etcd.KeysAPI
	ctx     context.Context
}

// ClientOptions defines options for the etcd client.
type ClientOptions struct {
	Cert                    string        // path to a client certificate (PEM); with Key, enables TLS
	Key                     string        // path to the client certificate's private key (PEM)
	CaCert                  string        // path to the CA certificate used to verify the server
	DialTimeout             time.Duration // timeout for establishing a connection
	DialKeepAline           time.Duration // TCP keep-alive period; NOTE(review): name has a typo ("Aline") but the field is exported, renaming would break callers
	HeaderTimeoutPerRequest time.Duration // per-request response-header timeout
}
38 | ||
39 | // NewClient returns an *etcd.Client with a connection to the named machines. | |
40 | // It will return an error if a connection to the cluster cannot be made. | |
41 | // The parameter machines needs to be a full URL with schemas. | |
42 | // e.g. "http://localhost:2379" will work, but "localhost:2379" will not. | |
43 | func NewClient(ctx context.Context, machines []string, options ClientOptions) (Client, error) { | |
44 | var ( | |
45 | c etcd.KeysAPI | |
46 | err error | |
47 | caCertCt []byte | |
48 | tlsCert tls.Certificate | |
49 | ) | |
50 | ||
51 | if options.Cert != "" && options.Key != "" { | |
52 | tlsCert, err = tls.LoadX509KeyPair(options.Cert, options.Key) | |
53 | if err != nil { | |
54 | return nil, err | |
55 | } | |
56 | ||
57 | caCertCt, err = ioutil.ReadFile(options.CaCert) | |
58 | if err != nil { | |
59 | return nil, err | |
60 | } | |
61 | caCertPool := x509.NewCertPool() | |
62 | caCertPool.AppendCertsFromPEM(caCertCt) | |
63 | ||
64 | tlsConfig := &tls.Config{ | |
65 | Certificates: []tls.Certificate{tlsCert}, | |
66 | RootCAs: caCertPool, | |
67 | } | |
68 | ||
69 | transport := &http.Transport{ | |
70 | TLSClientConfig: tlsConfig, | |
71 | Dial: func(network, addr string) (net.Conn, error) { | |
72 | dial := &net.Dialer{ | |
73 | Timeout: options.DialTimeout, | |
74 | KeepAlive: options.DialKeepAline, | |
75 | } | |
76 | return dial.Dial(network, addr) | |
77 | }, | |
78 | } | |
79 | ||
80 | cfg := etcd.Config{ | |
81 | Endpoints: machines, | |
82 | Transport: transport, | |
83 | HeaderTimeoutPerRequest: options.HeaderTimeoutPerRequest, | |
84 | } | |
85 | ce, err := etcd.New(cfg) | |
86 | if err != nil { | |
87 | return nil, err | |
88 | } | |
89 | c = etcd.NewKeysAPI(ce) | |
90 | } else { | |
91 | cfg := etcd.Config{ | |
92 | Endpoints: machines, | |
93 | Transport: etcd.DefaultTransport, | |
94 | HeaderTimeoutPerRequest: options.HeaderTimeoutPerRequest, | |
95 | } | |
96 | ce, err := etcd.New(cfg) | |
97 | if err != nil { | |
98 | return nil, err | |
99 | } | |
100 | c = etcd.NewKeysAPI(ce) | |
101 | } | |
102 | ||
103 | return &client{c, ctx}, nil | |
104 | } | |
105 | ||
106 | // GetEntries implements the etcd Client interface. | |
107 | func (c *client) GetEntries(key string) ([]string, error) { | |
108 | resp, err := c.keysAPI.Get(c.ctx, key, &etcd.GetOptions{Recursive: true}) | |
109 | if err != nil { | |
110 | return nil, err | |
111 | } | |
112 | ||
113 | entries := make([]string, len(resp.Node.Nodes)) | |
114 | for i, node := range resp.Node.Nodes { | |
115 | entries[i] = node.Value | |
116 | } | |
117 | return entries, nil | |
118 | } | |
119 | ||
120 | // WatchPrefix implements the etcd Client interface. | |
121 | func (c *client) WatchPrefix(prefix string, responseChan chan *etcd.Response) { | |
122 | watch := c.keysAPI.Watcher(prefix, &etcd.WatcherOptions{AfterIndex: 0, Recursive: true}) | |
123 | for { | |
124 | res, err := watch.Next(c.ctx) | |
125 | if err != nil { | |
126 | return | |
127 | } | |
128 | responseChan <- res | |
129 | } | |
130 | } |
0 | package etcd | |
1 | ||
2 | import ( | |
3 | etcd "github.com/coreos/etcd/client" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | "github.com/go-kit/kit/log" | |
7 | "github.com/go-kit/kit/sd" | |
8 | "github.com/go-kit/kit/sd/cache" | |
9 | ) | |
10 | ||
// Subscriber yields endpoints stored in a certain etcd keyspace. Any kind of
// change in that keyspace is watched and will update the Subscriber endpoints.
type Subscriber struct {
	client Client       // etcd client used for reads and watches
	prefix string       // keyspace prefix holding the instance entries
	cache  *cache.Cache // instance -> endpoint cache served to consumers
	logger log.Logger
	quitc  chan struct{} // closed by Stop to terminate the watch loop
}

// Compile-time assertion that Subscriber satisfies sd.Subscriber.
var _ sd.Subscriber = &Subscriber{}
22 | ||
23 | // NewSubscriber returns an etcd subscriber. It will start watching the given | |
24 | // prefix for changes, and update the endpoints. | |
25 | func NewSubscriber(c Client, prefix string, factory sd.Factory, logger log.Logger) (*Subscriber, error) { | |
26 | s := &Subscriber{ | |
27 | client: c, | |
28 | prefix: prefix, | |
29 | cache: cache.New(factory, logger), | |
30 | logger: logger, | |
31 | quitc: make(chan struct{}), | |
32 | } | |
33 | ||
34 | instances, err := s.client.GetEntries(s.prefix) | |
35 | if err == nil { | |
36 | logger.Log("prefix", s.prefix, "instances", len(instances)) | |
37 | } else { | |
38 | logger.Log("prefix", s.prefix, "err", err) | |
39 | } | |
40 | s.cache.Update(instances) | |
41 | ||
42 | go s.loop() | |
43 | return s, nil | |
44 | } | |
45 | ||
46 | func (s *Subscriber) loop() { | |
47 | responseChan := make(chan *etcd.Response) | |
48 | go s.client.WatchPrefix(s.prefix, responseChan) | |
49 | for { | |
50 | select { | |
51 | case <-responseChan: | |
52 | instances, err := s.client.GetEntries(s.prefix) | |
53 | if err != nil { | |
54 | s.logger.Log("msg", "failed to retrieve entries", "err", err) | |
55 | continue | |
56 | } | |
57 | s.cache.Update(instances) | |
58 | ||
59 | case <-s.quitc: | |
60 | return | |
61 | } | |
62 | } | |
63 | } | |
64 | ||
65 | // Endpoints implements the Subscriber interface. | |
66 | func (s *Subscriber) Endpoints() ([]endpoint.Endpoint, error) { | |
67 | return s.cache.Endpoints(), nil | |
68 | } | |
69 | ||
// Stop terminates the Subscriber. Calling Stop more than once panics, since
// quitc is closed unconditionally.
func (s *Subscriber) Stop() {
	close(s.quitc)
}
0 | package etcd | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "io" | |
5 | "testing" | |
6 | ||
7 | stdetcd "github.com/coreos/etcd/client" | |
8 | ||
9 | "github.com/go-kit/kit/endpoint" | |
10 | "github.com/go-kit/kit/log" | |
11 | ) | |
12 | ||
// Shared test fixtures: an etcd node with two instance children, and a
// canned response wrapping it, served by fakeClient below.
var (
	node = &stdetcd.Node{
		Key: "/foo",
		Nodes: []*stdetcd.Node{
			{Key: "/foo/1", Value: "1:1"},
			{Key: "/foo/2", Value: "1:2"},
		},
	}
	fakeResponse = &stdetcd.Response{
		Node: node,
	}
)
25 | ||
26 | func TestSubscriber(t *testing.T) { | |
27 | factory := func(string) (endpoint.Endpoint, io.Closer, error) { | |
28 | return endpoint.Nop, nil, nil | |
29 | } | |
30 | ||
31 | client := &fakeClient{ | |
32 | responses: map[string]*stdetcd.Response{"/foo": fakeResponse}, | |
33 | } | |
34 | ||
35 | s, err := NewSubscriber(client, "/foo", factory, log.NewNopLogger()) | |
36 | if err != nil { | |
37 | t.Fatal(err) | |
38 | } | |
39 | defer s.Stop() | |
40 | ||
41 | if _, err := s.Endpoints(); err != nil { | |
42 | t.Fatal(err) | |
43 | } | |
44 | } | |
45 | ||
46 | func TestBadFactory(t *testing.T) { | |
47 | factory := func(string) (endpoint.Endpoint, io.Closer, error) { | |
48 | return nil, nil, errors.New("kaboom") | |
49 | } | |
50 | ||
51 | client := &fakeClient{ | |
52 | responses: map[string]*stdetcd.Response{"/foo": fakeResponse}, | |
53 | } | |
54 | ||
55 | s, err := NewSubscriber(client, "/foo", factory, log.NewNopLogger()) | |
56 | if err != nil { | |
57 | t.Fatal(err) | |
58 | } | |
59 | defer s.Stop() | |
60 | ||
61 | endpoints, err := s.Endpoints() | |
62 | if err != nil { | |
63 | t.Fatal(err) | |
64 | } | |
65 | ||
66 | if want, have := 0, len(endpoints); want != have { | |
67 | t.Errorf("want %d, have %d", want, have) | |
68 | } | |
69 | } | |
70 | ||
71 | type fakeClient struct { | |
72 | responses map[string]*stdetcd.Response | |
73 | } | |
74 | ||
75 | func (c *fakeClient) GetEntries(prefix string) ([]string, error) { | |
76 | response, ok := c.responses[prefix] | |
77 | if !ok { | |
78 | return nil, errors.New("key not exist") | |
79 | } | |
80 | ||
81 | entries := make([]string, len(response.Node.Nodes)) | |
82 | for i, node := range response.Node.Nodes { | |
83 | entries[i] = node.Value | |
84 | } | |
85 | return entries, nil | |
86 | } | |
87 | ||
88 | func (c *fakeClient) WatchPrefix(prefix string, responseChan chan *stdetcd.Response) {} |
0 | package sd | |
1 | ||
2 | import ( | |
3 | "io" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
// Factory is a function that converts an instance string (e.g. host:port) to a
// specific endpoint. Instances that provide multiple endpoints require multiple
// factories. A factory also returns an io.Closer that's invoked when the
// instance goes away and needs to be cleaned up.
//
// Users are expected to provide their own factory functions that assume
// specific transports, or can deduce transports by parsing the instance string.
//
// NOTE(review): callers in this repository sometimes return a nil Closer;
// presumably consumers must tolerate that — confirm before relying on it.
type Factory func(instance string) (endpoint.Endpoint, io.Closer, error)
0 | package sd | |
1 | ||
2 | import "github.com/go-kit/kit/endpoint" | |
3 | ||
// FixedSubscriber yields a fixed set of services.
type FixedSubscriber []endpoint.Endpoint

// Endpoints implements Subscriber. It never fails, and returns the slice
// itself, so callers share the backing array.
func (s FixedSubscriber) Endpoints() ([]endpoint.Endpoint, error) { return s, nil }
0 | package lb | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ) | |
7 | ||
// Balancer yields endpoints according to some heuristic.
type Balancer interface {
	// Endpoint returns the next endpoint to invoke, or an error — for
	// example ErrNoEndpoints when the underlying subscriber is empty.
	Endpoint() (endpoint.Endpoint, error)
}

// ErrNoEndpoints is returned when no qualifying endpoints are available.
var ErrNoEndpoints = errors.New("no endpoints available")
0 | // Package lb deals with client-side load balancing across multiple identical | |
1 | // instances of services and endpoints. When combined with a service discovery | |
2 | // system of record, it enables a more decentralized architecture, removing the | |
3 | // need for separate load balancers like HAProxy. | |
4 | package lb |
0 | package lb | |
1 | ||
2 | import ( | |
3 | "math/rand" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | "github.com/go-kit/kit/sd" | |
7 | ) | |
8 | ||
9 | // NewRandom returns a load balancer that selects services randomly. | |
10 | func NewRandom(s sd.Subscriber, seed int64) Balancer { | |
11 | return &random{ | |
12 | s: s, | |
13 | r: rand.New(rand.NewSource(seed)), | |
14 | } | |
15 | } | |
16 | ||
// random selects an endpoint uniformly at random from the subscriber's set.
type random struct {
	s sd.Subscriber
	r *rand.Rand // seeded PRNG; NOTE(review): *rand.Rand is not safe for concurrent use — confirm callers serialize Endpoint
}
21 | ||
22 | func (r *random) Endpoint() (endpoint.Endpoint, error) { | |
23 | endpoints, err := r.s.Endpoints() | |
24 | if err != nil { | |
25 | return nil, err | |
26 | } | |
27 | if len(endpoints) <= 0 { | |
28 | return nil, ErrNoEndpoints | |
29 | } | |
30 | return endpoints[r.r.Intn(len(endpoints))], nil | |
31 | } |
0 | package lb | |
1 | ||
2 | import ( | |
3 | "math" | |
4 | "testing" | |
5 | ||
6 | "github.com/go-kit/kit/endpoint" | |
7 | "github.com/go-kit/kit/sd" | |
8 | "golang.org/x/net/context" | |
9 | ) | |
10 | ||
// TestRandom checks that the random balancer distributes calls approximately
// uniformly: after many iterations each endpoint's hit count must fall
// within 1% of the expected per-endpoint share. The fixed seed makes the
// sequence deterministic.
func TestRandom(t *testing.T) {
	var (
		n          = 7
		endpoints  = make([]endpoint.Endpoint, n)
		counts     = make([]int, n)
		seed       = int64(12345)
		iterations = 1000000
		want       = iterations / n
		tolerance  = want / 100 // 1%
	)

	for i := 0; i < n; i++ {
		i0 := i // capture the loop variable for the closure below
		endpoints[i] = func(context.Context, interface{}) (interface{}, error) { counts[i0]++; return struct{}{}, nil }
	}

	subscriber := sd.FixedSubscriber(endpoints)
	balancer := NewRandom(subscriber, seed)

	for i := 0; i < iterations; i++ {
		endpoint, _ := balancer.Endpoint()
		endpoint(context.Background(), struct{}{})
	}

	for i, have := range counts {
		delta := int(math.Abs(float64(want - have)))
		if delta > tolerance {
			t.Errorf("%d: want %d, have %d, delta %d > %d tolerance", i, want, have, delta, tolerance)
		}
	}
}
42 | ||
43 | func TestRandomNoEndpoints(t *testing.T) { | |
44 | subscriber := sd.FixedSubscriber{} | |
45 | balancer := NewRandom(subscriber, 1415926) | |
46 | _, err := balancer.Endpoint() | |
47 | if want, have := ErrNoEndpoints, err; want != have { | |
48 | t.Errorf("want %v, have %v", want, have) | |
49 | } | |
50 | ||
51 | } |
0 | package lb | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "strings" | |
5 | "time" | |
6 | ||
7 | "golang.org/x/net/context" | |
8 | ||
9 | "github.com/go-kit/kit/endpoint" | |
10 | ) | |
11 | ||
12 | // Retry wraps a service load balancer and returns an endpoint oriented load | |
13 | // balancer for the specified service method. | |
14 | // Requests to the endpoint will be automatically load balanced via the load | |
15 | // balancer. Requests that return errors will be retried until they succeed, | |
16 | // up to max times, or until the timeout is elapsed, whichever comes first. | |
17 | func Retry(max int, timeout time.Duration, b Balancer) endpoint.Endpoint { | |
18 | if b == nil { | |
19 | panic("nil Balancer") | |
20 | } | |
21 | return func(ctx context.Context, request interface{}) (response interface{}, err error) { | |
22 | var ( | |
23 | newctx, cancel = context.WithTimeout(ctx, timeout) | |
24 | responses = make(chan interface{}, 1) | |
25 | errs = make(chan error, 1) | |
26 | a = []string{} | |
27 | ) | |
28 | defer cancel() | |
29 | for i := 1; i <= max; i++ { | |
30 | go func() { | |
31 | e, err := b.Endpoint() | |
32 | if err != nil { | |
33 | errs <- err | |
34 | return | |
35 | } | |
36 | response, err := e(newctx, request) | |
37 | if err != nil { | |
38 | errs <- err | |
39 | return | |
40 | } | |
41 | responses <- response | |
42 | }() | |
43 | ||
44 | select { | |
45 | case <-newctx.Done(): | |
46 | return nil, newctx.Err() | |
47 | case response := <-responses: | |
48 | return response, nil | |
49 | case err := <-errs: | |
50 | a = append(a, err.Error()) | |
51 | continue | |
52 | } | |
53 | } | |
54 | return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; ")) | |
55 | } | |
56 | } |
0 | package lb_test | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "testing" | |
5 | "time" | |
6 | ||
7 | "golang.org/x/net/context" | |
8 | ||
9 | "github.com/go-kit/kit/endpoint" | |
10 | "github.com/go-kit/kit/sd" | |
11 | loadbalancer "github.com/go-kit/kit/sd/lb" | |
12 | ) | |
13 | ||
14 | func TestRetryMaxTotalFail(t *testing.T) { | |
15 | var ( | |
16 | endpoints = sd.FixedSubscriber{} // no endpoints | |
17 | lb = loadbalancer.NewRoundRobin(endpoints) | |
18 | retry = loadbalancer.Retry(999, time.Second, lb) // lots of retries | |
19 | ctx = context.Background() | |
20 | ) | |
21 | if _, err := retry(ctx, struct{}{}); err == nil { | |
22 | t.Errorf("expected error, got none") // should fail | |
23 | } | |
24 | } | |
25 | ||
26 | func TestRetryMaxPartialFail(t *testing.T) { | |
27 | var ( | |
28 | endpoints = []endpoint.Endpoint{ | |
29 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") }, | |
30 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") }, | |
31 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ }, | |
32 | } | |
33 | subscriber = sd.FixedSubscriber{ | |
34 | 0: endpoints[0], | |
35 | 1: endpoints[1], | |
36 | 2: endpoints[2], | |
37 | } | |
38 | retries = len(endpoints) - 1 // not quite enough retries | |
39 | lb = loadbalancer.NewRoundRobin(subscriber) | |
40 | ctx = context.Background() | |
41 | ) | |
42 | if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err == nil { | |
43 | t.Errorf("expected error, got none") | |
44 | } | |
45 | } | |
46 | ||
47 | func TestRetryMaxSuccess(t *testing.T) { | |
48 | var ( | |
49 | endpoints = []endpoint.Endpoint{ | |
50 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") }, | |
51 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") }, | |
52 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ }, | |
53 | } | |
54 | subscriber = sd.FixedSubscriber{ | |
55 | 0: endpoints[0], | |
56 | 1: endpoints[1], | |
57 | 2: endpoints[2], | |
58 | } | |
59 | retries = len(endpoints) // exactly enough retries | |
60 | lb = loadbalancer.NewRoundRobin(subscriber) | |
61 | ctx = context.Background() | |
62 | ) | |
63 | if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err != nil { | |
64 | t.Error(err) | |
65 | } | |
66 | } | |
67 | ||
// TestRetryTimeout checks that the retry wrapper returns
// context.DeadlineExceeded when the endpoint blocks past the configured
// timeout. NOTE(review): timing-based; a heavily loaded machine could in
// principle flake the second half.
func TestRetryTimeout(t *testing.T) {
	var (
		step    = make(chan struct{}) // controls when the endpoint completes
		e       = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil }
		timeout = time.Millisecond
		retry   = loadbalancer.Retry(999, timeout, loadbalancer.NewRoundRobin(sd.FixedSubscriber{0: e}))
		errs    = make(chan error, 1)
		invoke  = func() { _, err := retry(context.Background(), struct{}{}); errs <- err }
	)

	go func() { step <- struct{}{} }() // queue up a flush of the endpoint
	invoke()                           // invoke the endpoint and trigger the flush
	if err := <-errs; err != nil {     // that should succeed
		t.Error(err)
	}

	go func() { time.Sleep(10 * timeout); step <- struct{}{} }() // a delayed flush
	invoke()                                                     // invoke the endpoint
	if err := <-errs; err != context.DeadlineExceeded {          // that should not succeed
		t.Errorf("wanted %v, got none", context.DeadlineExceeded)
	}
}
0 | package lb | |
1 | ||
2 | import ( | |
3 | "sync/atomic" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | "github.com/go-kit/kit/sd" | |
7 | ) | |
8 | ||
9 | // NewRoundRobin returns a load balancer that returns services in sequence. | |
10 | func NewRoundRobin(s sd.Subscriber) Balancer { | |
11 | return &roundRobin{ | |
12 | s: s, | |
13 | c: 0, | |
14 | } | |
15 | } | |
16 | ||
// roundRobin cycles deterministically through the subscriber's endpoints.
type roundRobin struct {
	s sd.Subscriber
	c uint64 // monotonically increasing ticket counter, accessed atomically
}
21 | ||
22 | func (rr *roundRobin) Endpoint() (endpoint.Endpoint, error) { | |
23 | endpoints, err := rr.s.Endpoints() | |
24 | if err != nil { | |
25 | return nil, err | |
26 | } | |
27 | if len(endpoints) <= 0 { | |
28 | return nil, ErrNoEndpoints | |
29 | } | |
30 | old := atomic.AddUint64(&rr.c, 1) - 1 | |
31 | idx := old % uint64(len(endpoints)) | |
32 | return endpoints[idx], nil | |
33 | } |
0 | package lb | |
1 | ||
2 | import ( | |
3 | "reflect" | |
4 | "sync" | |
5 | "sync/atomic" | |
6 | "testing" | |
7 | "time" | |
8 | ||
9 | "golang.org/x/net/context" | |
10 | ||
11 | "github.com/go-kit/kit/endpoint" | |
12 | "github.com/go-kit/kit/sd" | |
13 | ) | |
14 | ||
15 | func TestRoundRobin(t *testing.T) { | |
16 | var ( | |
17 | counts = []int{0, 0, 0} | |
18 | endpoints = []endpoint.Endpoint{ | |
19 | func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil }, | |
20 | func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil }, | |
21 | func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil }, | |
22 | } | |
23 | ) | |
24 | ||
25 | subscriber := sd.FixedSubscriber(endpoints) | |
26 | balancer := NewRoundRobin(subscriber) | |
27 | ||
28 | for i, want := range [][]int{ | |
29 | {1, 0, 0}, | |
30 | {1, 1, 0}, | |
31 | {1, 1, 1}, | |
32 | {2, 1, 1}, | |
33 | {2, 2, 1}, | |
34 | {2, 2, 2}, | |
35 | {3, 2, 2}, | |
36 | } { | |
37 | endpoint, err := balancer.Endpoint() | |
38 | if err != nil { | |
39 | t.Fatal(err) | |
40 | } | |
41 | endpoint(context.Background(), struct{}{}) | |
42 | if have := counts; !reflect.DeepEqual(want, have) { | |
43 | t.Fatalf("%d: want %v, have %v", i, want, have) | |
44 | } | |
45 | } | |
46 | } | |
47 | ||
48 | func TestRoundRobinNoEndpoints(t *testing.T) { | |
49 | subscriber := sd.FixedSubscriber{} | |
50 | balancer := NewRoundRobin(subscriber) | |
51 | _, err := balancer.Endpoint() | |
52 | if want, have := ErrNoEndpoints, err; want != have { | |
53 | t.Errorf("want %v, have %v", want, have) | |
54 | } | |
55 | } | |
56 | ||
// TestRoundRobinNoRace hammers the balancer from many goroutines for one
// second; run with -race to detect unsynchronized access to the counter.
func TestRoundRobinNoRace(t *testing.T) {
	balancer := NewRoundRobin(sd.FixedSubscriber([]endpoint.Endpoint{
		endpoint.Nop,
		endpoint.Nop,
		endpoint.Nop,
		endpoint.Nop,
		endpoint.Nop,
	}))

	var (
		n     = 100                 // concurrent callers
		done  = make(chan struct{}) // closed to stop all workers
		wg    sync.WaitGroup
		count uint64 // total calls made, reported at the end
	)

	wg.Add(n)

	for i := 0; i < n; i++ {
		go func() {
			defer wg.Done()
			for {
				select {
				case <-done:
					return
				default:
					_, _ = balancer.Endpoint()
					atomic.AddUint64(&count, 1)
				}
			}
		}()
	}

	time.Sleep(time.Second)
	close(done)
	wg.Wait()

	t.Logf("made %d calls", atomic.LoadUint64(&count))
}
0 | package sd | |
1 | ||
// Registrar registers instance information to a service discovery system when
// an instance becomes alive and healthy, and deregisters that information when
// the service becomes unhealthy or goes away.
//
// Registrar implementations exist for various service discovery systems. Note
// that identifying instance information (e.g. host:port) must be given via the
// concrete constructor; this interface merely signals lifecycle changes.
type Registrar interface {
	// Register makes this instance visible in the discovery system.
	Register()
	// Deregister removes this instance from the discovery system.
	Deregister()
}
0 | package sd | |
1 | ||
2 | import "github.com/go-kit/kit/endpoint" | |
3 | ||
// Subscriber listens to a service discovery system and yields a set of
// identical endpoints on demand. An error indicates a problem with connectivity
// to the service discovery system, or within the system itself; a subscriber
// may yield no endpoints without error.
type Subscriber interface {
	// Endpoints returns the current set of endpoints; the empty set is a
	// valid, error-free result.
	Endpoints() ([]endpoint.Endpoint, error)
}
0 | package zk | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "net" | |
5 | "strings" | |
6 | "time" | |
7 | ||
8 | "github.com/samuel/go-zookeeper/zk" | |
9 | ||
10 | "github.com/go-kit/kit/log" | |
11 | ) | |
12 | ||
// DefaultACL is the default ACL to use for creating znodes.
var (
	DefaultACL = zk.WorldACL(zk.PermAll)
	// ErrInvalidCredentials is returned by the Credentials option when the
	// user or password is empty.
	ErrInvalidCredentials = errors.New("invalid credentials provided")
	// ErrClientClosed is returned by client operations after Stop has been
	// called.
	ErrClientClosed = errors.New("client service closed")
)

const (
	// DefaultConnectTimeout is the default timeout to establish a connection to
	// a ZooKeeper node.
	DefaultConnectTimeout = 2 * time.Second
	// DefaultSessionTimeout is the default timeout to keep the current
	// ZooKeeper session alive during a temporary disconnect.
	DefaultSessionTimeout = 5 * time.Second
)
28 | ||
// Client is a wrapper around a lower level ZooKeeper client implementation.
type Client interface {
	// GetEntries should query the provided path in ZooKeeper, place a watch on
	// it and retrieve data from its current child nodes.
	GetEntries(path string) ([]string, <-chan zk.Event, error)
	// CreateParentNodes should try to create the path in case it does not exist
	// yet on ZooKeeper.
	CreateParentNodes(path string) error
	// Stop should properly shut down the client implementation.
	Stop()
}
40 | ||
// clientConfig collects the tunable settings assembled by Option functions.
type clientConfig struct {
	logger          log.Logger
	acl             []zk.ACL // ACL attached to nodes created by CreateParentNodes
	credentials     []byte   // "user:pass" digest credentials; empty means no auth
	connectTimeout  time.Duration
	sessionTimeout  time.Duration
	rootNodePayload [][]byte       // per-level payloads used by CreateParentNodes
	eventHandler    func(zk.Event) // callback for ZooKeeper connection events
}
50 | ||
// Option functions enable friendly APIs.
type Option func(*clientConfig) error

// client implements Client by embedding a live *zk.Conn together with the
// resolved configuration.
type client struct {
	*zk.Conn
	clientConfig
	active bool          // set false by Stop; NOTE(review): read/written without synchronization — possible data race between Stop and CreateParentNodes, confirm callers
	quit   chan struct{} // closed by Stop to terminate the event-handler goroutine
}
60 | ||
61 | // ACL returns an Option specifying a non-default ACL for creating parent nodes. | |
62 | func ACL(acl []zk.ACL) Option { | |
63 | return func(c *clientConfig) error { | |
64 | c.acl = acl | |
65 | return nil | |
66 | } | |
67 | } | |
68 | ||
69 | // Credentials returns an Option specifying a user/password combination which | |
70 | // the client will use to authenticate itself with. | |
71 | func Credentials(user, pass string) Option { | |
72 | return func(c *clientConfig) error { | |
73 | if user == "" || pass == "" { | |
74 | return ErrInvalidCredentials | |
75 | } | |
76 | c.credentials = []byte(user + ":" + pass) | |
77 | return nil | |
78 | } | |
79 | } | |
80 | ||
81 | // ConnectTimeout returns an Option specifying a non-default connection timeout | |
82 | // when we try to establish a connection to a ZooKeeper server. | |
83 | func ConnectTimeout(t time.Duration) Option { | |
84 | return func(c *clientConfig) error { | |
85 | if t.Seconds() < 1 { | |
86 | return errors.New("invalid connect timeout (minimum value is 1 second)") | |
87 | } | |
88 | c.connectTimeout = t | |
89 | return nil | |
90 | } | |
91 | } | |
92 | ||
93 | // SessionTimeout returns an Option specifying a non-default session timeout. | |
94 | func SessionTimeout(t time.Duration) Option { | |
95 | return func(c *clientConfig) error { | |
96 | if t.Seconds() < 1 { | |
97 | return errors.New("invalid session timeout (minimum value is 1 second)") | |
98 | } | |
99 | c.sessionTimeout = t | |
100 | return nil | |
101 | } | |
102 | } | |
103 | ||
104 | // Payload returns an Option specifying non-default data values for each znode | |
105 | // created by CreateParentNodes. | |
106 | func Payload(payload [][]byte) Option { | |
107 | return func(c *clientConfig) error { | |
108 | c.rootNodePayload = payload | |
109 | return nil | |
110 | } | |
111 | } | |
112 | ||
113 | // EventHandler returns an Option specifying a callback function to handle | |
114 | // incoming zk.Event payloads (ZooKeeper connection events). | |
115 | func EventHandler(handler func(zk.Event)) Option { | |
116 | return func(c *clientConfig) error { | |
117 | c.eventHandler = handler | |
118 | return nil | |
119 | } | |
120 | } | |
121 | ||
// NewClient returns a ZooKeeper client with a connection to the server cluster.
// It will return an error if the server cluster cannot be resolved.
func NewClient(servers []string, logger log.Logger, options ...Option) (Client, error) {
	// Default event handler: log every connection event.
	defaultEventHandler := func(event zk.Event) {
		logger.Log("eventtype", event.Type.String(), "server", event.Server, "state", event.State.String(), "err", event.Err)
	}
	config := clientConfig{
		acl:            DefaultACL,
		connectTimeout: DefaultConnectTimeout,
		sessionTimeout: DefaultSessionTimeout,
		eventHandler:   defaultEventHandler,
		logger:         logger,
	}
	// Apply caller-supplied options; the first failing option aborts.
	for _, option := range options {
		if err := option(&config); err != nil {
			return nil, err
		}
	}
	// dialer overrides the default ZooKeeper library Dialer so we can configure
	// the connectTimeout. The current library has a hardcoded value of 1 second
	// and there are reports of race conditions, due to slow DNS resolvers and
	// other network latency issues.
	dialer := func(network, address string, _ time.Duration) (net.Conn, error) {
		return net.DialTimeout(network, address, config.connectTimeout)
	}
	conn, eventc, err := zk.Connect(servers, config.sessionTimeout, withLogger(logger), zk.WithDialer(dialer))

	if err != nil {
		return nil, err
	}

	// Authenticate with digest credentials, if any were configured.
	if len(config.credentials) > 0 {
		err = conn.AddAuth("digest", config.credentials)
		if err != nil {
			return nil, err
		}
	}

	c := &client{conn, config, true, make(chan struct{})}

	// Start listening for incoming Event payloads and callback the set
	// eventHandler. The goroutine exits when Stop closes c.quit.
	go func() {
		for {
			select {
			case event := <-eventc:
				config.eventHandler(event)
			case <-c.quit:
				return
			}
		}
	}()
	return c, nil
}
176 | ||
177 | // CreateParentNodes implements the ZooKeeper Client interface. | |
178 | func (c *client) CreateParentNodes(path string) error { | |
179 | if !c.active { | |
180 | return ErrClientClosed | |
181 | } | |
182 | if path[0] != '/' { | |
183 | return zk.ErrInvalidPath | |
184 | } | |
185 | payload := []byte("") | |
186 | pathString := "" | |
187 | pathNodes := strings.Split(path, "/") | |
188 | for i := 1; i < len(pathNodes); i++ { | |
189 | if i <= len(c.rootNodePayload) { | |
190 | payload = c.rootNodePayload[i-1] | |
191 | } else { | |
192 | payload = []byte("") | |
193 | } | |
194 | pathString += "/" + pathNodes[i] | |
195 | _, err := c.Create(pathString, payload, 0, c.acl) | |
196 | // not being able to create the node because it exists or not having | |
197 | // sufficient rights is not an issue. It is ok for the node to already | |
198 | // exist and/or us to only have read rights | |
199 | if err != nil && err != zk.ErrNodeExists && err != zk.ErrNoAuth { | |
200 | return err | |
201 | } | |
202 | } | |
203 | return nil | |
204 | } | |
205 | ||
206 | // GetEntries implements the ZooKeeper Client interface. | |
207 | func (c *client) GetEntries(path string) ([]string, <-chan zk.Event, error) { | |
208 | // retrieve list of child nodes for given path and add watch to path | |
209 | znodes, _, eventc, err := c.ChildrenW(path) | |
210 | ||
211 | if err != nil { | |
212 | return nil, eventc, err | |
213 | } | |
214 | ||
215 | var resp []string | |
216 | for _, znode := range znodes { | |
217 | // retrieve payload for child znode and add to response array | |
218 | if data, _, err := c.Get(path + "/" + znode); err == nil { | |
219 | resp = append(resp, string(data)) | |
220 | } | |
221 | } | |
222 | return resp, eventc, nil | |
223 | } | |
224 | ||
// Stop implements the ZooKeeper Client interface. It marks the client
// inactive, terminates the event-forwarding goroutine started in NewClient,
// and closes the underlying ZooKeeper connection.
//
// NOTE(review): Stop is not idempotent — a second call panics closing the
// already-closed quit channel — and the write to c.active is unsynchronized
// with respect to concurrent readers; confirm callers invoke it exactly once.
func (c *client) Stop() {
	c.active = false
	close(c.quit)
	c.Close()
}
0 | package zk | |
1 | ||
2 | import ( | |
3 | "bytes" | |
4 | "testing" | |
5 | "time" | |
6 | ||
7 | stdzk "github.com/samuel/go-zookeeper/zk" | |
8 | ||
9 | "github.com/go-kit/kit/log" | |
10 | ) | |
11 | ||
// TestNewClient exercises client construction: an unresolvable host must make
// NewClient fail, and a fully-optioned client must retain every configured
// option and invoke the custom event handler at least once.
func TestNewClient(t *testing.T) {
	var (
		acl            = stdzk.WorldACL(stdzk.PermRead)
		connectTimeout = 3 * time.Second
		sessionTimeout = 20 * time.Second
		payload        = [][]byte{[]byte("Payload"), []byte("Test")}
	)

	// An invalid host must produce a construction error.
	c, err := NewClient(
		[]string{"FailThisInvalidHost!!!"},
		log.NewNopLogger(),
	)
	if err == nil {
		t.Errorf("expected error, got nil")
	}

	hasFired := false
	calledEventHandler := make(chan struct{})
	eventHandler := func(event stdzk.Event) {
		if !hasFired {
			// test is successful if this function has fired at least once
			hasFired = true
			close(calledEventHandler)
		}
	}

	// Construct a client with every option set; connection setup itself is
	// lazy, so this succeeds even without a reachable server.
	c, err = NewClient(
		[]string{"localhost"},
		log.NewNopLogger(),
		ACL(acl),
		ConnectTimeout(connectTimeout),
		SessionTimeout(sessionTimeout),
		Payload(payload),
		EventHandler(eventHandler),
	)
	if err != nil {
		t.Fatal(err)
	}
	defer c.Stop()

	// Reach into the concrete implementation to verify each option took hold.
	clientImpl, ok := c.(*client)
	if !ok {
		t.Fatal("retrieved incorrect Client implementation")
	}
	if want, have := acl, clientImpl.acl; want[0] != have[0] {
		t.Errorf("want %+v, have %+v", want, have)
	}
	if want, have := connectTimeout, clientImpl.connectTimeout; want != have {
		t.Errorf("want %d, have %d", want, have)
	}
	if want, have := sessionTimeout, clientImpl.sessionTimeout; want != have {
		t.Errorf("want %d, have %d", want, have)
	}
	if want, have := payload, clientImpl.rootNodePayload; bytes.Compare(want[0], have[0]) != 0 || bytes.Compare(want[1], have[1]) != 0 {
		t.Errorf("want %s, have %s", want, have)
	}

	// The connection attempt should emit at least one state event.
	select {
	case <-calledEventHandler:
	case <-time.After(100 * time.Millisecond):
		t.Errorf("event handler never called")
	}
}
75 | ||
76 | func TestOptions(t *testing.T) { | |
77 | _, err := NewClient([]string{"localhost"}, log.NewNopLogger(), Credentials("valid", "credentials")) | |
78 | if err != nil && err != stdzk.ErrNoServer { | |
79 | t.Errorf("unexpected error: %v", err) | |
80 | } | |
81 | ||
82 | _, err = NewClient([]string{"localhost"}, log.NewNopLogger(), Credentials("nopass", "")) | |
83 | if want, have := err, ErrInvalidCredentials; want != have { | |
84 | t.Errorf("want %v, have %v", want, have) | |
85 | } | |
86 | ||
87 | _, err = NewClient([]string{"localhost"}, log.NewNopLogger(), ConnectTimeout(0)) | |
88 | if err == nil { | |
89 | t.Errorf("expected connect timeout error") | |
90 | } | |
91 | ||
92 | _, err = NewClient([]string{"localhost"}, log.NewNopLogger(), SessionTimeout(0)) | |
93 | if err == nil { | |
94 | t.Errorf("expected connect timeout error") | |
95 | } | |
96 | } | |
97 | ||
// TestCreateParentNodes exercises Subscriber construction and client calls
// against an unreachable server (port 65500): operations must fail with
// ErrNoServer or ErrInvalidPath as appropriate, and all calls after Stop must
// fail with ErrClientClosed.
func TestCreateParentNodes(t *testing.T) {
	payload := [][]byte{[]byte("Payload"), []byte("Test")}

	// Client construction is lazy, so it succeeds even with no server.
	c, err := NewClient([]string{"localhost:65500"}, log.NewNopLogger())
	if err != nil {
		t.Errorf("unexpected error: %v", err)
	}
	if c == nil {
		t.Fatal("expected new Client, got nil")
	}

	// Subscriber creation needs a live server, so it must fail.
	s, err := NewSubscriber(c, "/validpath", newFactory(""), log.NewNopLogger())
	if err != stdzk.ErrNoServer {
		t.Errorf("unexpected error: %v", err)
	}
	if s != nil {
		t.Error("expected failed new Subscriber")
	}

	// A path without a leading slash is rejected before any network I/O.
	s, err = NewSubscriber(c, "invalidpath", newFactory(""), log.NewNopLogger())
	if err != stdzk.ErrInvalidPath {
		t.Errorf("unexpected error: %v", err)
	}
	_, _, err = c.GetEntries("/validpath")
	if err != stdzk.ErrNoServer {
		t.Errorf("unexpected error: %v", err)
	}

	c.Stop()

	// Once stopped, every client operation must report ErrClientClosed.
	err = c.CreateParentNodes("/validpath")
	if err != ErrClientClosed {
		t.Errorf("unexpected error: %v", err)
	}

	s, err = NewSubscriber(c, "/validpath", newFactory(""), log.NewNopLogger())
	if err != ErrClientClosed {
		t.Errorf("unexpected error: %v", err)
	}
	if s != nil {
		t.Error("expected failed new Subscriber")
	}

	// A fresh client with payloads configured behaves the same without a
	// server: Subscriber creation still fails with ErrNoServer.
	c, err = NewClient([]string{"localhost:65500"}, log.NewNopLogger(), Payload(payload))
	if err != nil {
		t.Errorf("unexpected error: %v", err)
	}
	if c == nil {
		t.Fatal("expected new Client, got nil")
	}

	s, err = NewSubscriber(c, "/validpath", newFactory(""), log.NewNopLogger())
	if err != stdzk.ErrNoServer {
		t.Errorf("unexpected error: %v", err)
	}
	if s != nil {
		t.Error("expected failed new Subscriber")
	}
}
0 | // +build integration | |
1 | ||
2 | package zk | |
3 | ||
4 | import ( | |
5 | "bytes" | |
6 | "flag" | |
7 | "fmt" | |
8 | "os" | |
9 | "testing" | |
10 | "time" | |
11 | ||
12 | stdzk "github.com/samuel/go-zookeeper/zk" | |
13 | ) | |
14 | ||
var (
	// host holds the "host:port" address of the ZooKeeper test cluster
	// started by TestMain; it is shared by all integration tests in this file.
	host []string
)
18 | ||
19 | func TestMain(m *testing.M) { | |
20 | flag.Parse() | |
21 | ||
22 | fmt.Println("Starting ZooKeeper server...") | |
23 | ||
24 | ts, err := stdzk.StartTestCluster(1, nil, nil) | |
25 | if err != nil { | |
26 | fmt.Printf("ZooKeeper server error: %v\n", err) | |
27 | os.Exit(1) | |
28 | } | |
29 | ||
30 | host = []string{fmt.Sprintf("localhost:%d", ts.Servers[0].Port)} | |
31 | code := m.Run() | |
32 | ||
33 | ts.Stop() | |
34 | os.Exit(code) | |
35 | } | |
36 | ||
// TestCreateParentNodesOnServer verifies against a live cluster that creating
// a Subscriber creates the parent znodes with the configured payloads, and
// that a fresh client can read those payloads back.
func TestCreateParentNodesOnServer(t *testing.T) {
	payload := [][]byte{[]byte("Payload"), []byte("Test")}
	c1, err := NewClient(host, logger, Payload(payload))
	if err != nil {
		t.Fatalf("Connect returned error: %v", err)
	}
	if c1 == nil {
		t.Fatal("Expected pointer to client, got nil")
	}
	defer c1.Stop()

	// NewSubscriber triggers CreateParentNodes for the watched path.
	s, err := NewSubscriber(c1, path, newFactory(""), logger)
	if err != nil {
		t.Fatalf("Unable to create Subscriber: %v", err)
	}
	defer s.Stop()

	// The freshly created path has no service instances yet.
	services, err := s.Services()
	if err != nil {
		t.Fatal(err)
	}
	if want, have := 0, len(services); want != have {
		t.Errorf("want %d, have %d", want, have)
	}

	// Use a second, independent client to read back the znode payload.
	c2, err := NewClient(host, logger)
	if err != nil {
		t.Fatalf("Connect returned error: %v", err)
	}
	defer c2.Stop()
	data, _, err := c2.(*client).Get(path)
	if err != nil {
		t.Fatal(err)
	}
	// test Client implementation of CreateParentNodes. It should have created
	// our payload
	if bytes.Compare(data, payload[1]) != 0 {
		t.Errorf("want %s, have %s", payload[1], data)
	}

}
78 | ||
79 | func TestCreateBadParentNodesOnServer(t *testing.T) { | |
80 | c, _ := NewClient(host, logger) | |
81 | defer c.Stop() | |
82 | ||
83 | _, err := NewSubscriber(c, "invalid/path", newFactory(""), logger) | |
84 | ||
85 | if want, have := stdzk.ErrInvalidPath, err; want != have { | |
86 | t.Errorf("want %v, have %v", want, have) | |
87 | } | |
88 | } | |
89 | ||
90 | func TestCredentials1(t *testing.T) { | |
91 | acl := stdzk.DigestACL(stdzk.PermAll, "user", "secret") | |
92 | c, _ := NewClient(host, logger, ACL(acl), Credentials("user", "secret")) | |
93 | defer c.Stop() | |
94 | ||
95 | _, err := NewSubscriber(c, "/acl-issue-test", newFactory(""), logger) | |
96 | ||
97 | if err != nil { | |
98 | t.Fatal(err) | |
99 | } | |
100 | } | |
101 | ||
102 | func TestCredentials2(t *testing.T) { | |
103 | acl := stdzk.DigestACL(stdzk.PermAll, "user", "secret") | |
104 | c, _ := NewClient(host, logger, ACL(acl)) | |
105 | defer c.Stop() | |
106 | ||
107 | _, err := NewSubscriber(c, "/acl-issue-test", newFactory(""), logger) | |
108 | ||
109 | if err != stdzk.ErrNoAuth { | |
110 | t.Errorf("want %v, have %v", stdzk.ErrNoAuth, err) | |
111 | } | |
112 | } | |
113 | ||
114 | func TestConnection(t *testing.T) { | |
115 | c, _ := NewClient(host, logger) | |
116 | c.Stop() | |
117 | ||
118 | _, err := NewSubscriber(c, "/acl-issue-test", newFactory(""), logger) | |
119 | ||
120 | if err != ErrClientClosed { | |
121 | t.Errorf("want %v, have %v", ErrClientClosed, err) | |
122 | } | |
123 | } | |
124 | ||
125 | func TestGetEntriesOnServer(t *testing.T) { | |
126 | var instancePayload = "protocol://hostname:port/routing" | |
127 | ||
128 | c1, err := NewClient(host, logger) | |
129 | if err != nil { | |
130 | t.Fatalf("Connect returned error: %v", err) | |
131 | } | |
132 | ||
133 | defer c1.Stop() | |
134 | ||
135 | c2, err := NewClient(host, logger) | |
136 | s, err := NewSubscriber(c2, path, newFactory(""), logger) | |
137 | if err != nil { | |
138 | t.Fatal(err) | |
139 | } | |
140 | defer c2.Stop() | |
141 | ||
142 | c2impl, _ := c2.(*client) | |
143 | _, err = c2impl.Create( | |
144 | path+"/instance1", | |
145 | []byte(instancePayload), | |
146 | stdzk.FlagEphemeral|stdzk.FlagSequence, | |
147 | stdzk.WorldACL(stdzk.PermAll), | |
148 | ) | |
149 | if err != nil { | |
150 | t.Fatalf("Unable to create test ephemeral znode 1: %v", err) | |
151 | } | |
152 | _, err = c2impl.Create( | |
153 | path+"/instance2", | |
154 | []byte(instancePayload+"2"), | |
155 | stdzk.FlagEphemeral|stdzk.FlagSequence, | |
156 | stdzk.WorldACL(stdzk.PermAll), | |
157 | ) | |
158 | if err != nil { | |
159 | t.Fatalf("Unable to create test ephemeral znode 2: %v", err) | |
160 | } | |
161 | ||
162 | time.Sleep(50 * time.Millisecond) | |
163 | ||
164 | services, err := s.Services() | |
165 | if err != nil { | |
166 | t.Fatal(err) | |
167 | } | |
168 | if want, have := 2, len(services); want != have { | |
169 | t.Errorf("want %d, have %d", want, have) | |
170 | } | |
171 | } | |
172 | ||
173 | func TestGetEntriesPayloadOnServer(t *testing.T) { | |
174 | c, err := NewClient(host, logger) | |
175 | if err != nil { | |
176 | t.Fatalf("Connect returned error: %v", err) | |
177 | } | |
178 | _, eventc, err := c.GetEntries(path) | |
179 | if err != nil { | |
180 | t.Fatal(err) | |
181 | } | |
182 | _, err = c.(*client).Create( | |
183 | path+"/instance3", | |
184 | []byte("just some payload"), | |
185 | stdzk.FlagEphemeral|stdzk.FlagSequence, | |
186 | stdzk.WorldACL(stdzk.PermAll), | |
187 | ) | |
188 | if err != nil { | |
189 | t.Fatalf("Unable to create test ephemeral znode: %v", err) | |
190 | } | |
191 | select { | |
192 | case event := <-eventc: | |
193 | if want, have := stdzk.EventNodeChildrenChanged.String(), event.Type.String(); want != have { | |
194 | t.Errorf("want %s, have %s", want, have) | |
195 | } | |
196 | case <-time.After(20 * time.Millisecond): | |
197 | t.Errorf("expected incoming watch event, timeout occurred") | |
198 | } | |
199 | ||
200 | } |
0 | package zk | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | ||
5 | "github.com/samuel/go-zookeeper/zk" | |
6 | ||
7 | "github.com/go-kit/kit/log" | |
8 | ) | |
9 | ||
// wrapLogger adapts a Go kit logger to the logging interface expected by the
// ZooKeeper library, which requires a Printf method. The embedded Logger
// supplies the underlying structured Log method.
type wrapLogger struct {
	log.Logger
}
15 | ||
16 | func (logger wrapLogger) Printf(format string, args ...interface{}) { | |
17 | logger.Log("msg", fmt.Sprintf(format, args...)) | |
18 | } | |
19 | ||
20 | // withLogger replaces the ZooKeeper library's default logging service with our | |
21 | // own Go kit logger. | |
22 | func withLogger(logger log.Logger) func(c *zk.Conn) { | |
23 | return func(c *zk.Conn) { | |
24 | c.SetLogger(wrapLogger{logger}) | |
25 | } | |
26 | } |
0 | package zk | |
1 | ||
2 | import ( | |
3 | "github.com/samuel/go-zookeeper/zk" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | "github.com/go-kit/kit/log" | |
7 | "github.com/go-kit/kit/sd" | |
8 | "github.com/go-kit/kit/sd/cache" | |
9 | ) | |
10 | ||
// Subscriber yield endpoints stored in a certain ZooKeeper path. Any kind of
// change in that path is watched and will update the Subscriber endpoints.
type Subscriber struct {
	client Client        // ZooKeeper client used for discovery
	path   string        // znode path being watched
	cache  *cache.Cache  // endpoint cache, rebuilt on every path change
	logger log.Logger
	quitc  chan struct{} // closed by Stop to terminate the watch loop
}

// Compile-time assertion that Subscriber implements sd.Subscriber.
var _ sd.Subscriber = &Subscriber{}
22 | ||
23 | // NewSubscriber returns a ZooKeeper subscriber. ZooKeeper will start watching | |
24 | // the given path for changes and update the Subscriber endpoints. | |
25 | func NewSubscriber(c Client, path string, factory sd.Factory, logger log.Logger) (*Subscriber, error) { | |
26 | s := &Subscriber{ | |
27 | client: c, | |
28 | path: path, | |
29 | cache: cache.New(factory, logger), | |
30 | logger: logger, | |
31 | quitc: make(chan struct{}), | |
32 | } | |
33 | ||
34 | err := s.client.CreateParentNodes(s.path) | |
35 | if err != nil { | |
36 | return nil, err | |
37 | } | |
38 | ||
39 | instances, eventc, err := s.client.GetEntries(s.path) | |
40 | if err != nil { | |
41 | logger.Log("path", s.path, "msg", "failed to retrieve entries", "err", err) | |
42 | return nil, err | |
43 | } | |
44 | logger.Log("path", s.path, "instances", len(instances)) | |
45 | s.cache.Update(instances) | |
46 | ||
47 | go s.loop(eventc) | |
48 | ||
49 | return s, nil | |
50 | } | |
51 | ||
// loop processes watch notifications until Stop is called. Each notification
// triggers GetEntries, which both fetches the current instances and replaces
// eventc with a freshly armed watch channel (ZK watches are one-time).
func (s *Subscriber) loop(eventc <-chan zk.Event) {
	var (
		instances []string
		err       error
	)
	for {
		select {
		case <-eventc:
			// We received a path update notification. Call GetEntries to
			// retrieve child node data, and set a new watch, as ZK watches are
			// one-time triggers.
			instances, eventc, err = s.client.GetEntries(s.path)
			if err != nil {
				// Keep the previous watch channel; a later event may succeed.
				s.logger.Log("path", s.path, "msg", "failed to retrieve entries", "err", err)
				continue
			}
			s.logger.Log("path", s.path, "instances", len(instances))
			s.cache.Update(instances)

		case <-s.quitc:
			return
		}
	}
}
76 | ||
77 | // Endpoints implements the Subscriber interface. | |
78 | func (s *Subscriber) Endpoints() ([]endpoint.Endpoint, error) { | |
79 | return s.cache.Endpoints(), nil | |
80 | } | |
81 | ||
// Stop terminates the Subscriber by shutting down its watch loop.
// NOTE(review): Stop must be called at most once; a second call panics
// closing the already-closed quit channel.
func (s *Subscriber) Stop() {
	close(s.quitc)
}
0 | package zk | |
1 | ||
2 | import ( | |
3 | "testing" | |
4 | "time" | |
5 | ) | |
6 | ||
7 | func TestSubscriber(t *testing.T) { | |
8 | client := newFakeClient() | |
9 | ||
10 | s, err := NewSubscriber(client, path, newFactory(""), logger) | |
11 | if err != nil { | |
12 | t.Fatalf("failed to create new Subscriber: %v", err) | |
13 | } | |
14 | defer s.Stop() | |
15 | ||
16 | if _, err := s.Endpoints(); err != nil { | |
17 | t.Fatal(err) | |
18 | } | |
19 | } | |
20 | ||
21 | func TestBadFactory(t *testing.T) { | |
22 | client := newFakeClient() | |
23 | ||
24 | s, err := NewSubscriber(client, path, newFactory("kaboom"), logger) | |
25 | if err != nil { | |
26 | t.Fatalf("failed to create new Subscriber: %v", err) | |
27 | } | |
28 | defer s.Stop() | |
29 | ||
30 | // instance1 came online | |
31 | client.AddService(path+"/instance1", "kaboom") | |
32 | ||
33 | // instance2 came online | |
34 | client.AddService(path+"/instance2", "zookeeper_node_data") | |
35 | ||
36 | if err = asyncTest(100*time.Millisecond, 1, s); err != nil { | |
37 | t.Error(err) | |
38 | } | |
39 | } | |
40 | ||
// TestServiceUpdate drives the fake client through a sequence of instance
// additions and removals and asserts the Subscriber's endpoint count tracks
// each change.
func TestServiceUpdate(t *testing.T) {
	client := newFakeClient()

	s, err := NewSubscriber(client, path, newFactory(""), logger)
	if err != nil {
		t.Fatalf("failed to create new Subscriber: %v", err)
	}
	defer s.Stop()

	// Initially there are no registered instances.
	endpoints, err := s.Endpoints()
	if err != nil {
		t.Fatal(err)
	}
	if want, have := 0, len(endpoints); want != have {
		t.Errorf("want %d, have %d", want, have)
	}

	// instance1 came online
	client.AddService(path+"/instance1", "zookeeper_node_data1")

	// instance2 came online
	client.AddService(path+"/instance2", "zookeeper_node_data2")

	// we should have 2 instances
	if err = asyncTest(100*time.Millisecond, 2, s); err != nil {
		t.Error(err)
	}

	// TODO(pb): this bit is flaky
	//
	//// watch triggers an error...
	//client.SendErrorOnWatch()
	//
	//// test if error was consumed
	//if err = client.ErrorIsConsumedWithin(100 * time.Millisecond); err != nil {
	//	t.Error(err)
	//}

	// instance3 came online
	client.AddService(path+"/instance3", "zookeeper_node_data3")

	// we should have 3 instances
	if err = asyncTest(100*time.Millisecond, 3, s); err != nil {
		t.Error(err)
	}

	// instance1 goes offline
	client.RemoveService(path + "/instance1")

	// instance2 goes offline
	client.RemoveService(path + "/instance2")

	// we should have 1 instance
	if err = asyncTest(100*time.Millisecond, 1, s); err != nil {
		t.Error(err)
	}
}
98 | ||
99 | func TestBadSubscriberCreate(t *testing.T) { | |
100 | client := newFakeClient() | |
101 | client.SendErrorOnWatch() | |
102 | s, err := NewSubscriber(client, path, newFactory(""), logger) | |
103 | if err == nil { | |
104 | t.Error("expected error on new Subscriber") | |
105 | } | |
106 | if s != nil { | |
107 | t.Error("expected Subscriber not to be created") | |
108 | } | |
109 | s, err = NewSubscriber(client, "BadPath", newFactory(""), logger) | |
110 | if err == nil { | |
111 | t.Error("expected error on new Subscriber") | |
112 | } | |
113 | if s != nil { | |
114 | t.Error("expected Subscriber not to be created") | |
115 | } | |
116 | } |
0 | package zk | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "fmt" | |
5 | "io" | |
6 | "sync" | |
7 | "time" | |
8 | ||
9 | "github.com/samuel/go-zookeeper/zk" | |
10 | "golang.org/x/net/context" | |
11 | ||
12 | "github.com/go-kit/kit/endpoint" | |
13 | "github.com/go-kit/kit/log" | |
14 | "github.com/go-kit/kit/sd" | |
15 | ) | |
16 | ||
var (
	// path is the znode path used throughout the package's unit tests.
	path = "/gokit.test/service.name"
	// e is a no-op endpoint.
	// NOTE(review): e appears unused in this file now that endpoint.Nop
	// exists — confirm no other file in the package references it before
	// removing.
	e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
	// logger discards all test log output.
	logger = log.NewNopLogger()
)
22 | ||
// fakeClient is an in-memory Client implementation for unit tests. Service
// payloads live in the responses map, and watch notifications are signaled
// on ch. All fields are guarded by mtx.
type fakeClient struct {
	mtx       sync.Mutex
	ch        chan zk.Event     // buffered (size 1) watch event channel
	responses map[string]string // znode path -> payload
	result    bool              // false: next GetEntries returns an error once
}
29 | ||
30 | func newFakeClient() *fakeClient { | |
31 | return &fakeClient{ | |
32 | ch: make(chan zk.Event, 1), | |
33 | responses: make(map[string]string), | |
34 | result: true, | |
35 | } | |
36 | } | |
37 | ||
38 | func (c *fakeClient) CreateParentNodes(path string) error { | |
39 | if path == "BadPath" { | |
40 | return errors.New("dummy error") | |
41 | } | |
42 | return nil | |
43 | } | |
44 | ||
45 | func (c *fakeClient) GetEntries(path string) ([]string, <-chan zk.Event, error) { | |
46 | c.mtx.Lock() | |
47 | defer c.mtx.Unlock() | |
48 | if c.result == false { | |
49 | c.result = true | |
50 | return []string{}, c.ch, errors.New("dummy error") | |
51 | } | |
52 | responses := []string{} | |
53 | for _, data := range c.responses { | |
54 | responses = append(responses, data) | |
55 | } | |
56 | return responses, c.ch, nil | |
57 | } | |
58 | ||
// AddService registers data under node and emits a watch event so watchers
// refresh their entries. The send relies on ch's one-slot buffer to avoid
// blocking while the mutex is held.
func (c *fakeClient) AddService(node, data string) {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	c.responses[node] = data
	c.ch <- zk.Event{}
}
65 | ||
// RemoveService deregisters node and emits a watch event so watchers refresh
// their entries.
func (c *fakeClient) RemoveService(node string) {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	delete(c.responses, node)
	c.ch <- zk.Event{}
}
72 | ||
// SendErrorOnWatch arranges for the next GetEntries call to fail once, and
// emits a watch event to provoke that call.
func (c *fakeClient) SendErrorOnWatch() {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	c.result = false
	c.ch <- zk.Event{}
}
79 | ||
80 | func (c *fakeClient) ErrorIsConsumedWithin(timeout time.Duration) error { | |
81 | t := time.After(timeout) | |
82 | for { | |
83 | select { | |
84 | case <-t: | |
85 | return fmt.Errorf("expected error not consumed after timeout %s", timeout) | |
86 | default: | |
87 | c.mtx.Lock() | |
88 | if c.result == false { | |
89 | c.mtx.Unlock() | |
90 | return nil | |
91 | } | |
92 | c.mtx.Unlock() | |
93 | } | |
94 | } | |
95 | } | |
96 | ||
// Stop implements the Client interface; it is a no-op for the fake client.
func (c *fakeClient) Stop() {}
98 | ||
99 | func newFactory(fakeError string) sd.Factory { | |
100 | return func(instance string) (endpoint.Endpoint, io.Closer, error) { | |
101 | if fakeError == instance { | |
102 | return nil, nil, errors.New(fakeError) | |
103 | } | |
104 | return endpoint.Nop, nil, nil | |
105 | } | |
106 | } | |
107 | ||
108 | func asyncTest(timeout time.Duration, want int, s *Subscriber) (err error) { | |
109 | var endpoints []endpoint.Endpoint | |
110 | have := -1 // want can never be <0 | |
111 | t := time.After(timeout) | |
112 | for { | |
113 | select { | |
114 | case <-t: | |
115 | return fmt.Errorf("want %d, have %d (timeout %s)", want, have, timeout.String()) | |
116 | default: | |
117 | endpoints, err = s.Endpoints() | |
118 | have = len(endpoints) | |
119 | if err != nil || want == have { | |
120 | return | |
121 | } | |
122 | time.Sleep(timeout / 10) | |
123 | } | |
124 | } | |
125 | } |