Package list golang-github-go-kit-kit / 9a19822
sd: port, without service.Service Peter Bourgon 5 years ago
37 changed file(s) with 2753 addition(s) and 0 deletion(s). Raw diff Collapse all Expand all
88 // Endpoint is the fundamental building block of servers and clients.
99 // It represents a single RPC method.
1010 type Endpoint func(ctx context.Context, request interface{}) (response interface{}, err error)
11
12 // Nop is an endpoint that does nothing and returns a nil error.
13 // Useful for tests.
14 func Nop(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
1115
1216 // Middleware is a chainable behavior modifier for endpoints.
1317 type Middleware func(Endpoint) Endpoint
0 package cache
1
2 import (
3 "io"
4 "testing"
5
6 "github.com/go-kit/kit/endpoint"
7 "github.com/go-kit/kit/log"
8 )
9
10 func BenchmarkEndpoints(b *testing.B) {
11 var (
12 ca = make(closer)
13 cb = make(closer)
14 cmap = map[string]io.Closer{"a": ca, "b": cb}
15 factory = func(instance string) (endpoint.Endpoint, io.Closer, error) { return endpoint.Nop, cmap[instance], nil }
16 c = New(factory, log.NewNopLogger())
17 )
18
19 b.ReportAllocs()
20
21 c.Update([]string{"a", "b"})
22
23 b.RunParallel(func(pb *testing.PB) {
24 for pb.Next() {
25 c.Endpoints()
26 }
27 })
28 }
0 package cache
1
2 import (
3 "io"
4 "sort"
5 "sync"
6
7 "github.com/go-kit/kit/endpoint"
8 "github.com/go-kit/kit/log"
9 "github.com/go-kit/kit/sd"
10 )
11
12 // Cache collects the most recent set of endpoints from a service discovery
13 // system via a subscriber, and makes them available to consumers. Cache is
14 // meant to be embedded inside of a concrete subscriber, and can serve Service
15 // invocations directly.
16 type Cache struct {
17 mtx sync.RWMutex
18 factory sd.Factory
19 cache map[string]endpointCloser
20 slice []endpoint.Endpoint
21 logger log.Logger
22 }
23
24 type endpointCloser struct {
25 endpoint.Endpoint
26 io.Closer
27 }
28
29 // New returns a new, empty endpoint cache.
30 func New(factory sd.Factory, logger log.Logger) *Cache {
31 return &Cache{
32 factory: factory,
33 cache: map[string]endpointCloser{},
34 logger: logger,
35 }
36 }
37
38 // Update should be invoked by clients with a complete set of current instance
39 // strings whenever that set changes. The cache manufactures new endpoints via
40 // the factory, closes old endpoints when they disappear, and persists existing
41 // endpoints if they survive through an update.
42 func (c *Cache) Update(instances []string) {
43 c.mtx.Lock()
44 defer c.mtx.Unlock()
45
46 // Deterministic order (for later).
47 sort.Strings(instances)
48
49 // Produce the current set of services.
50 cache := make(map[string]endpointCloser, len(instances))
51 for _, instance := range instances {
52 // If it already exists, just copy it over.
53 if sc, ok := c.cache[instance]; ok {
54 cache[instance] = sc
55 delete(c.cache, instance)
56 continue
57 }
58
59 // If it doesn't exist, create it.
60 service, closer, err := c.factory(instance)
61 if err != nil {
62 c.logger.Log("instance", instance, "err", err)
63 continue
64 }
65 cache[instance] = endpointCloser{service, closer}
66 }
67
68 // Close any leftover endpoints.
69 for _, sc := range c.cache {
70 if sc.Closer != nil {
71 sc.Closer.Close()
72 }
73 }
74
75 // Populate the slice of endpoints.
76 slice := make([]endpoint.Endpoint, 0, len(cache))
77 for _, instance := range instances {
78 // A bad factory may mean an instance is not present.
79 if _, ok := cache[instance]; !ok {
80 continue
81 }
82 slice = append(slice, cache[instance].Endpoint)
83 }
84
85 // Swap and trigger GC for old copies.
86 c.slice = slice
87 c.cache = cache
88 }
89
90 // Endpoints yields the current set of (presumably identical) endpoints, ordered
91 // lexicographically by the corresponding instance string.
92 func (c *Cache) Endpoints() []endpoint.Endpoint {
93 c.mtx.RLock()
94 defer c.mtx.RUnlock()
95 return c.slice
96 }
0 package cache
1
2 import (
3 "errors"
4 "io"
5 "testing"
6 "time"
7
8 "github.com/go-kit/kit/endpoint"
9 "github.com/go-kit/kit/log"
10 )
11
12 func TestCache(t *testing.T) {
13 var (
14 ca = make(closer)
15 cb = make(closer)
16 c = map[string]io.Closer{"a": ca, "b": cb}
17 f = func(instance string) (endpoint.Endpoint, io.Closer, error) { return endpoint.Nop, c[instance], nil }
18 cache = New(f, log.NewNopLogger())
19 )
20
21 // Populate
22 cache.Update([]string{"a", "b"})
23 select {
24 case <-ca:
25 t.Errorf("endpoint a closed, not good")
26 case <-cb:
27 t.Errorf("endpoint b closed, not good")
28 case <-time.After(time.Millisecond):
29 t.Logf("no closures yet, good")
30 }
31 if want, have := 2, len(cache.Endpoints()); want != have {
32 t.Errorf("want %d, have %d", want, have)
33 }
34
35 // Duplicate, should be no-op
36 cache.Update([]string{"a", "b"})
37 select {
38 case <-ca:
39 t.Errorf("endpoint a closed, not good")
40 case <-cb:
41 t.Errorf("endpoint b closed, not good")
42 case <-time.After(time.Millisecond):
43 t.Logf("no closures yet, good")
44 }
45 if want, have := 2, len(cache.Endpoints()); want != have {
46 t.Errorf("want %d, have %d", want, have)
47 }
48
49 // Delete b
50 go cache.Update([]string{"a"})
51 select {
52 case <-ca:
53 t.Errorf("endpoint a closed, not good")
54 case <-cb:
55 t.Logf("endpoint b closed, good")
56 case <-time.After(time.Second):
57 t.Errorf("didn't close the deleted instance in time")
58 }
59 if want, have := 1, len(cache.Endpoints()); want != have {
60 t.Errorf("want %d, have %d", want, have)
61 }
62
63 // Delete a
64 go cache.Update([]string{})
65 select {
66 // case <-cb: will succeed, as it's closed
67 case <-ca:
68 t.Logf("endpoint a closed, good")
69 case <-time.After(time.Second):
70 t.Errorf("didn't close the deleted instance in time")
71 }
72 if want, have := 0, len(cache.Endpoints()); want != have {
73 t.Errorf("want %d, have %d", want, have)
74 }
75 }
76
77 func TestBadFactory(t *testing.T) {
78 cache := New(func(string) (endpoint.Endpoint, io.Closer, error) {
79 return nil, nil, errors.New("bad factory")
80 }, log.NewNopLogger())
81
82 cache.Update([]string{"foo:1234", "bar:5678"})
83 if want, have := 0, len(cache.Endpoints()); want != have {
84 t.Errorf("want %d, have %d", want, have)
85 }
86 }
87
88 type closer chan struct{}
89
90 func (c closer) Close() error { close(c); return nil }
0 package consul
1
2 import consul "github.com/hashicorp/consul/api"
3
4 // Client is a wrapper around the Consul API.
5 type Client interface {
6 // Register a service with the local agent.
7 Register(r *consul.AgentServiceRegistration) error
8
9 // Deregister a service with the local agent.
10 Deregister(r *consul.AgentServiceRegistration) error
11
12 // Service
13 Service(service, tag string, passingOnly bool, queryOpts *consul.QueryOptions) ([]*consul.ServiceEntry, *consul.QueryMeta, error)
14 }
15
16 type client struct {
17 consul *consul.Client
18 }
19
20 // NewClient returns an implementation of the Client interface, wrapping a
21 // concrete Consul client.
22 func NewClient(c *consul.Client) Client {
23 return &client{consul: c}
24 }
25
26 func (c *client) Register(r *consul.AgentServiceRegistration) error {
27 return c.consul.Agent().ServiceRegister(r)
28 }
29
30 func (c *client) Deregister(r *consul.AgentServiceRegistration) error {
31 return c.consul.Agent().ServiceDeregister(r.ID)
32 }
33
34 func (c *client) Service(service, tag string, passingOnly bool, queryOpts *consul.QueryOptions) ([]*consul.ServiceEntry, *consul.QueryMeta, error) {
35 return c.consul.Health().Service(service, tag, passingOnly, queryOpts)
36 }
0 package consul
1
2 import (
3 "errors"
4 "io"
5 "reflect"
6 "testing"
7
8 stdconsul "github.com/hashicorp/consul/api"
9 "golang.org/x/net/context"
10
11 "github.com/go-kit/kit/endpoint"
12 )
13
14 func TestClientRegistration(t *testing.T) {
15 c := newTestClient(nil)
16
17 services, _, err := c.Service(testRegistration.Name, "", true, &stdconsul.QueryOptions{})
18 if err != nil {
19 t.Error(err)
20 }
21 if want, have := 0, len(services); want != have {
22 t.Errorf("want %d, have %d", want, have)
23 }
24
25 if err := c.Register(testRegistration); err != nil {
26 t.Error(err)
27 }
28
29 if err := c.Register(testRegistration); err == nil {
30 t.Errorf("want error, have %v", err)
31 }
32
33 services, _, err = c.Service(testRegistration.Name, "", true, &stdconsul.QueryOptions{})
34 if err != nil {
35 t.Error(err)
36 }
37 if want, have := 1, len(services); want != have {
38 t.Errorf("want %d, have %d", want, have)
39 }
40
41 if err := c.Deregister(testRegistration); err != nil {
42 t.Error(err)
43 }
44
45 if err := c.Deregister(testRegistration); err == nil {
46 t.Errorf("want error, have %v", err)
47 }
48
49 services, _, err = c.Service(testRegistration.Name, "", true, &stdconsul.QueryOptions{})
50 if err != nil {
51 t.Error(err)
52 }
53 if want, have := 0, len(services); want != have {
54 t.Errorf("want %d, have %d", want, have)
55 }
56 }
57
58 type testClient struct {
59 entries []*stdconsul.ServiceEntry
60 }
61
62 func newTestClient(entries []*stdconsul.ServiceEntry) *testClient {
63 return &testClient{
64 entries: entries,
65 }
66 }
67
68 var _ Client = &testClient{}
69
70 func (c *testClient) Service(service, tag string, _ bool, opts *stdconsul.QueryOptions) ([]*stdconsul.ServiceEntry, *stdconsul.QueryMeta, error) {
71 var results []*stdconsul.ServiceEntry
72
73 for _, entry := range c.entries {
74 if entry.Service.Service != service {
75 continue
76 }
77 if tag != "" {
78 tagMap := map[string]struct{}{}
79
80 for _, t := range entry.Service.Tags {
81 tagMap[t] = struct{}{}
82 }
83
84 if _, ok := tagMap[tag]; !ok {
85 continue
86 }
87 }
88
89 results = append(results, entry)
90 }
91
92 return results, &stdconsul.QueryMeta{}, nil
93 }
94
95 func (c *testClient) Register(r *stdconsul.AgentServiceRegistration) error {
96 toAdd := registration2entry(r)
97
98 for _, entry := range c.entries {
99 if reflect.DeepEqual(*entry, *toAdd) {
100 return errors.New("duplicate")
101 }
102 }
103
104 c.entries = append(c.entries, toAdd)
105 return nil
106 }
107
108 func (c *testClient) Deregister(r *stdconsul.AgentServiceRegistration) error {
109 toDelete := registration2entry(r)
110
111 var newEntries []*stdconsul.ServiceEntry
112 for _, entry := range c.entries {
113 if reflect.DeepEqual(*entry, *toDelete) {
114 continue
115 }
116 newEntries = append(newEntries, entry)
117 }
118 if len(newEntries) == len(c.entries) {
119 return errors.New("not found")
120 }
121
122 c.entries = newEntries
123 return nil
124 }
125
126 func registration2entry(r *stdconsul.AgentServiceRegistration) *stdconsul.ServiceEntry {
127 return &stdconsul.ServiceEntry{
128 Node: &stdconsul.Node{
129 Node: "some-node",
130 Address: r.Address,
131 },
132 Service: &stdconsul.AgentService{
133 ID: r.ID,
134 Service: r.Name,
135 Tags: r.Tags,
136 Port: r.Port,
137 Address: r.Address,
138 },
139 // Checks ignored
140 }
141 }
142
143 func testFactory(instance string) (endpoint.Endpoint, io.Closer, error) {
144 return func(context.Context, interface{}) (interface{}, error) {
145 return instance, nil
146 }, nil, nil
147 }
148
149 var testRegistration = &stdconsul.AgentServiceRegistration{
150 ID: "my-id",
151 Name: "my-name",
152 Tags: []string{"my-tag-1", "my-tag-2"},
153 Port: 12345,
154 Address: "my-address",
155 }
0 // +build integration
1
2 package consul
3
4 import (
5 "io"
6 "os"
7 "testing"
8 "time"
9
10 "github.com/go-kit/kit/log"
11 "github.com/go-kit/kit/service"
12 stdconsul "github.com/hashicorp/consul/api"
13 )
14
15 func TestIntegration(t *testing.T) {
16 // Connect to Consul.
17 // docker run -p 8500:8500 progrium/consul -server -bootstrap
18 consulAddr := os.Getenv("CONSUL_ADDRESS")
19 if consulAddr == "" {
20 t.Fatal("CONSUL_ADDRESS is not set")
21 }
22 stdClient, err := stdconsul.NewClient(&stdconsul.Config{
23 Address: consulAddr,
24 })
25 if err != nil {
26 t.Fatal(err)
27 }
28 client := NewClient(stdClient)
29 logger := log.NewLogfmtLogger(os.Stderr)
30
31 // Produce a fake service registration.
32 r := &stdconsul.AgentServiceRegistration{
33 ID: "my-service-ID",
34 Name: "my-service-name",
35 Tags: []string{"alpha", "beta"},
36 Port: 12345,
37 Address: "my-address",
38 EnableTagOverride: false,
39 // skipping check(s)
40 }
41
42 // Build a subscriber on r.Name + r.Tags.
43 factory := func(instance string) (service.Service, io.Closer, error) {
44 t.Logf("factory invoked for %q", instance)
45 return service.Fixed{}, nil, nil
46 }
47 subscriber, err := NewSubscriber(
48 client,
49 factory,
50 log.NewContext(logger).With("component", "subscriber"),
51 r.Name,
52 r.Tags,
53 true,
54 )
55 if err != nil {
56 t.Fatal(err)
57 }
58
59 time.Sleep(time.Second)
60
61 // Before we publish, we should have no services.
62 services, err := subscriber.Services()
63 if err != nil {
64 t.Error(err)
65 }
66 if want, have := 0, len(services); want != have {
67 t.Errorf("want %d, have %d", want, have)
68 }
69
70 // Build a registrar for r.
71 registrar := NewRegistrar(client, r, log.NewContext(logger).With("component", "registrar"))
72 registrar.Register()
73 defer registrar.Deregister()
74
75 time.Sleep(time.Second)
76
77 // Now we should have one active service.
78 services, err = subscriber.Services()
79 if err != nil {
80 t.Error(err)
81 }
82 if want, have := 1, len(services); want != have {
83 t.Errorf("want %d, have %d", want, have)
84 }
85 }
0 package consul
1
2 import (
3 "fmt"
4
5 stdconsul "github.com/hashicorp/consul/api"
6
7 "github.com/go-kit/kit/log"
8 )
9
10 // Registrar registers service instance liveness information to Consul.
11 type Registrar struct {
12 client Client
13 registration *stdconsul.AgentServiceRegistration
14 logger log.Logger
15 }
16
17 // NewRegistrar returns a Consul Registrar acting on the provided catalog
18 // registration.
19 func NewRegistrar(client Client, r *stdconsul.AgentServiceRegistration, logger log.Logger) *Registrar {
20 return &Registrar{
21 client: client,
22 registration: r,
23 logger: log.NewContext(logger).With("service", r.Name, "tags", fmt.Sprint(r.Tags), "address", r.Address),
24 }
25 }
26
27 // Register implements sd.Registrar interface.
28 func (p *Registrar) Register() {
29 if err := p.client.Register(p.registration); err != nil {
30 p.logger.Log("err", err)
31 } else {
32 p.logger.Log("action", "register")
33 }
34 }
35
36 // Deregister implements sd.Registrar interface.
37 func (p *Registrar) Deregister() {
38 if err := p.client.Deregister(p.registration); err != nil {
39 p.logger.Log("err", err)
40 } else {
41 p.logger.Log("action", "deregister")
42 }
43 }
0 package consul
1
2 import (
3 "testing"
4
5 stdconsul "github.com/hashicorp/consul/api"
6
7 "github.com/go-kit/kit/log"
8 )
9
10 func TestRegistrar(t *testing.T) {
11 client := newTestClient([]*stdconsul.ServiceEntry{})
12 p := NewRegistrar(client, testRegistration, log.NewNopLogger())
13 if want, have := 0, len(client.entries); want != have {
14 t.Errorf("want %d, have %d", want, have)
15 }
16
17 p.Register()
18 if want, have := 1, len(client.entries); want != have {
19 t.Errorf("want %d, have %d", want, have)
20 }
21
22 p.Deregister()
23 if want, have := 0, len(client.entries); want != have {
24 t.Errorf("want %d, have %d", want, have)
25 }
26 }
0 package consul
1
2 import (
3 "fmt"
4 "io"
5
6 consul "github.com/hashicorp/consul/api"
7
8 "github.com/go-kit/kit/endpoint"
9 "github.com/go-kit/kit/log"
10 "github.com/go-kit/kit/sd"
11 "github.com/go-kit/kit/sd/cache"
12 )
13
14 const defaultIndex = 0
15
16 // Subscriber yields endpoints for a service in Consul. Updates to the service
17 // are watched and will update the Subscriber endpoints.
18 type Subscriber struct {
19 cache *cache.Cache
20 client Client
21 logger log.Logger
22 service string
23 tags []string
24 passingOnly bool
25 endpointsc chan []endpoint.Endpoint
26 quitc chan struct{}
27 }
28
29 var _ sd.Subscriber = &Subscriber{}
30
31 // NewSubscriber returns a Consul subscriber which returns endpoints for the
32 // requested service. It only returns instances for which all of the passed tags
33 // are present.
34 func NewSubscriber(client Client, factory sd.Factory, logger log.Logger, service string, tags []string, passingOnly bool) (*Subscriber, error) {
35 s := &Subscriber{
36 cache: cache.New(factory, logger),
37 client: client,
38 logger: log.NewContext(logger).With("service", service, "tags", fmt.Sprint(tags)),
39 service: service,
40 tags: tags,
41 passingOnly: passingOnly,
42 quitc: make(chan struct{}),
43 }
44
45 instances, index, err := s.getInstances(defaultIndex, nil)
46 if err == nil {
47 s.logger.Log("instances", len(instances))
48 } else {
49 s.logger.Log("err", err)
50 }
51
52 s.cache.Update(instances)
53 go s.loop(index)
54 return s, nil
55 }
56
57 // Endpoints implements the Subscriber interface.
58 func (s *Subscriber) Endpoints() ([]endpoint.Endpoint, error) {
59 return s.cache.Endpoints(), nil
60 }
61
62 // Stop terminates the subscriber.
63 func (s *Subscriber) Stop() {
64 close(s.quitc)
65 }
66
67 func (s *Subscriber) loop(lastIndex uint64) {
68 var (
69 instances []string
70 err error
71 )
72 for {
73 instances, lastIndex, err = s.getInstances(lastIndex, s.quitc)
74 switch {
75 case err == io.EOF:
76 return // stopped via quitc
77 case err != nil:
78 s.logger.Log("err", err)
79 default:
80 s.cache.Update(instances)
81 }
82 }
83 }
84
85 func (s *Subscriber) getInstances(lastIndex uint64, interruptc chan struct{}) ([]string, uint64, error) {
86 tag := ""
87 if len(s.tags) > 0 {
88 tag = s.tags[0]
89 }
90
91 // Consul doesn't support more than one tag in its service query method.
92 // https://github.com/hashicorp/consul/issues/294
93 // Hashi suggest prepared queries, but they don't support blocking.
94 // https://www.consul.io/docs/agent/http/query.html#execute
95 // If we want blocking for efficiency, we must filter tags manually.
96
97 type response struct {
98 instances []string
99 index uint64
100 }
101
102 var (
103 errc = make(chan error, 1)
104 resc = make(chan response, 1)
105 )
106
107 go func() {
108 entries, meta, err := s.client.Service(s.service, tag, s.passingOnly, &consul.QueryOptions{
109 WaitIndex: lastIndex,
110 })
111 if err != nil {
112 errc <- err
113 return
114 }
115 if len(s.tags) > 1 {
116 entries = filterEntries(entries, s.tags[1:]...)
117 }
118 resc <- response{
119 instances: makeInstances(entries),
120 index: meta.LastIndex,
121 }
122 }()
123
124 select {
125 case err := <-errc:
126 return nil, 0, err
127 case res := <-resc:
128 return res.instances, res.index, nil
129 case <-interruptc:
130 return nil, 0, io.EOF
131 }
132 }
133
134 func filterEntries(entries []*consul.ServiceEntry, tags ...string) []*consul.ServiceEntry {
135 var es []*consul.ServiceEntry
136
137 ENTRIES:
138 for _, entry := range entries {
139 ts := make(map[string]struct{}, len(entry.Service.Tags))
140 for _, tag := range entry.Service.Tags {
141 ts[tag] = struct{}{}
142 }
143
144 for _, tag := range tags {
145 if _, ok := ts[tag]; !ok {
146 continue ENTRIES
147 }
148 }
149 es = append(es, entry)
150 }
151
152 return es
153 }
154
155 func makeInstances(entries []*consul.ServiceEntry) []string {
156 instances := make([]string, len(entries))
157 for i, entry := range entries {
158 addr := entry.Node.Address
159 if entry.Service.Address != "" {
160 addr = entry.Service.Address
161 }
162 instances[i] = fmt.Sprintf("%s:%d", addr, entry.Service.Port)
163 }
164 return instances
165 }
0 package consul
1
2 import (
3 "testing"
4
5 consul "github.com/hashicorp/consul/api"
6 "golang.org/x/net/context"
7
8 "github.com/go-kit/kit/log"
9 )
10
11 var consulState = []*consul.ServiceEntry{
12 {
13 Node: &consul.Node{
14 Address: "10.0.0.0",
15 Node: "app00.local",
16 },
17 Service: &consul.AgentService{
18 ID: "search-api-0",
19 Port: 8000,
20 Service: "search",
21 Tags: []string{
22 "api",
23 "v1",
24 },
25 },
26 },
27 {
28 Node: &consul.Node{
29 Address: "10.0.0.1",
30 Node: "app01.local",
31 },
32 Service: &consul.AgentService{
33 ID: "search-api-1",
34 Port: 8001,
35 Service: "search",
36 Tags: []string{
37 "api",
38 "v2",
39 },
40 },
41 },
42 {
43 Node: &consul.Node{
44 Address: "10.0.0.1",
45 Node: "app01.local",
46 },
47 Service: &consul.AgentService{
48 Address: "10.0.0.10",
49 ID: "search-db-0",
50 Port: 9000,
51 Service: "search",
52 Tags: []string{
53 "db",
54 },
55 },
56 },
57 }
58
59 func TestSubscriber(t *testing.T) {
60 var (
61 logger = log.NewNopLogger()
62 client = newTestClient(consulState)
63 )
64
65 s, err := NewSubscriber(client, testFactory, logger, "search", []string{"api"}, true)
66 if err != nil {
67 t.Fatal(err)
68 }
69 defer s.Stop()
70
71 endpoints, err := s.Endpoints()
72 if err != nil {
73 t.Fatal(err)
74 }
75
76 if want, have := 2, len(endpoints); want != have {
77 t.Errorf("want %d, have %d", want, have)
78 }
79 }
80
81 func TestSubscriberNoService(t *testing.T) {
82 var (
83 logger = log.NewNopLogger()
84 client = newTestClient(consulState)
85 )
86
87 s, err := NewSubscriber(client, testFactory, logger, "feed", []string{}, true)
88 if err != nil {
89 t.Fatal(err)
90 }
91 defer s.Stop()
92
93 endpoints, err := s.Endpoints()
94 if err != nil {
95 t.Fatal(err)
96 }
97
98 if want, have := 0, len(endpoints); want != have {
99 t.Fatalf("want %d, have %d", want, have)
100 }
101 }
102
103 func TestSubscriberWithTags(t *testing.T) {
104 var (
105 logger = log.NewNopLogger()
106 client = newTestClient(consulState)
107 )
108
109 s, err := NewSubscriber(client, testFactory, logger, "search", []string{"api", "v2"}, true)
110 if err != nil {
111 t.Fatal(err)
112 }
113 defer s.Stop()
114
115 endpoints, err := s.Endpoints()
116 if err != nil {
117 t.Fatal(err)
118 }
119
120 if want, have := 1, len(endpoints); want != have {
121 t.Fatalf("want %d, have %d", want, have)
122 }
123 }
124
125 func TestSubscriberAddressOverride(t *testing.T) {
126 s, err := NewSubscriber(newTestClient(consulState), testFactory, log.NewNopLogger(), "search", []string{"db"}, true)
127 if err != nil {
128 t.Fatal(err)
129 }
130 defer s.Stop()
131
132 endpoints, err := s.Endpoints()
133 if err != nil {
134 t.Fatal(err)
135 }
136
137 if want, have := 1, len(endpoints); want != have {
138 t.Fatalf("want %d, have %d", want, have)
139 }
140
141 response, err := endpoints[0](context.Background(), struct{}{})
142 if err != nil {
143 t.Fatal(err)
144 }
145
146 if want, have := "10.0.0.10:9000", response.(string); want != have {
147 t.Errorf("want %q, have %q", want, have)
148 }
149 }
0 package dnssrv
1
2 import "net"
3
4 // Lookup is a function that resolves a DNS SRV record to multiple addresses.
5 // It has the same signature as net.LookupSRV.
6 type Lookup func(service, proto, name string) (cname string, addrs []*net.SRV, err error)
0 package dnssrv
1
2 import (
3 "fmt"
4 "net"
5 "time"
6
7 "github.com/go-kit/kit/endpoint"
8 "github.com/go-kit/kit/log"
9 "github.com/go-kit/kit/sd"
10 "github.com/go-kit/kit/sd/cache"
11 )
12
13 // Subscriber yields endpoints taken from the named DNS SRV record. The name is
14 // resolved on a fixed schedule. Priorities and weights are ignored.
15 type Subscriber struct {
16 name string
17 cache *cache.Cache
18 logger log.Logger
19 quit chan struct{}
20 }
21
22 // NewSubscriber returns a DNS SRV subscriber.
23 func NewSubscriber(
24 name string,
25 ttl time.Duration,
26 factory sd.Factory,
27 logger log.Logger,
28 ) *Subscriber {
29 return NewSubscriberDetailed(name, time.NewTicker(ttl), net.LookupSRV, factory, logger)
30 }
31
32 // NewSubscriberDetailed is the same as NewSubscriber, but allows users to
33 // provide an explicit lookup refresh ticker instead of a TTL, and specify the
34 // lookup function instead of using net.LookupSRV.
35 func NewSubscriberDetailed(
36 name string,
37 refresh *time.Ticker,
38 lookup Lookup,
39 factory sd.Factory,
40 logger log.Logger,
41 ) *Subscriber {
42 p := &Subscriber{
43 name: name,
44 cache: cache.New(factory, logger),
45 logger: logger,
46 quit: make(chan struct{}),
47 }
48
49 instances, err := p.resolve(lookup)
50 if err == nil {
51 logger.Log("name", name, "instances", len(instances))
52 } else {
53 logger.Log("name", name, "err", err)
54 }
55 p.cache.Update(instances)
56
57 go p.loop(refresh, lookup)
58 return p
59 }
60
61 // Stop terminates the Subscriber.
62 func (p *Subscriber) Stop() {
63 close(p.quit)
64 }
65
66 func (p *Subscriber) loop(t *time.Ticker, lookup Lookup) {
67 defer t.Stop()
68 for {
69 select {
70 case <-t.C:
71 instances, err := p.resolve(lookup)
72 if err != nil {
73 p.logger.Log("name", p.name, "err", err)
74 continue // don't replace potentially-good with bad
75 }
76 p.cache.Update(instances)
77
78 case <-p.quit:
79 return
80 }
81 }
82 }
83
84 // Endpoints implements the Subscriber interface.
85 func (p *Subscriber) Endpoints() ([]endpoint.Endpoint, error) {
86 return p.cache.Endpoints(), nil
87 }
88
89 func (p *Subscriber) resolve(lookup Lookup) ([]string, error) {
90 _, addrs, err := lookup("", "", p.name)
91 if err != nil {
92 return []string{}, err
93 }
94 instances := make([]string, len(addrs))
95 for i, addr := range addrs {
96 instances[i] = net.JoinHostPort(addr.Target, fmt.Sprint(addr.Port))
97 }
98 return instances, nil
99 }
0 package dnssrv
1
2 import (
3 "io"
4 "net"
5 "sync/atomic"
6 "testing"
7 "time"
8
9 "github.com/go-kit/kit/endpoint"
10 "github.com/go-kit/kit/log"
11 )
12
13 func TestRefresh(t *testing.T) {
14 name := "some.service.internal"
15
16 ticker := time.NewTicker(time.Second)
17 ticker.Stop()
18 tickc := make(chan time.Time)
19 ticker.C = tickc
20
21 var lookups uint64
22 records := []*net.SRV{}
23 lookup := func(service, proto, name string) (string, []*net.SRV, error) {
24 t.Logf("lookup(%q, %q, %q)", service, proto, name)
25 atomic.AddUint64(&lookups, 1)
26 return "cname", records, nil
27 }
28
29 var generates uint64
30 factory := func(instance string) (endpoint.Endpoint, io.Closer, error) {
31 t.Logf("factory(%q)", instance)
32 atomic.AddUint64(&generates, 1)
33 return endpoint.Nop, nopCloser{}, nil
34 }
35
36 subscriber := NewSubscriberDetailed(name, ticker, lookup, factory, log.NewNopLogger())
37 defer subscriber.Stop()
38
39 // First lookup, empty
40 endpoints, err := subscriber.Endpoints()
41 if err != nil {
42 t.Error(err)
43 }
44 if want, have := 0, len(endpoints); want != have {
45 t.Errorf("want %d, have %d", want, have)
46 }
47 if want, have := uint64(1), atomic.LoadUint64(&lookups); want != have {
48 t.Errorf("want %d, have %d", want, have)
49 }
50 if want, have := uint64(0), atomic.LoadUint64(&generates); want != have {
51 t.Errorf("want %d, have %d", want, have)
52 }
53
54 // Load some records and lookup again
55 records = []*net.SRV{
56 &net.SRV{Target: "1.0.0.1", Port: 1001},
57 &net.SRV{Target: "1.0.0.2", Port: 1002},
58 &net.SRV{Target: "1.0.0.3", Port: 1003},
59 }
60 tickc <- time.Now()
61
62 // There is a race condition where the subscriber.Endpoints call below
63 // invokes the cache before it is updated by the tick above.
64 // TODO(pb): solve by running the read through the loop goroutine.
65 time.Sleep(100 * time.Millisecond)
66
67 endpoints, err = subscriber.Endpoints()
68 if err != nil {
69 t.Error(err)
70 }
71 if want, have := 3, len(endpoints); want != have {
72 t.Errorf("want %d, have %d", want, have)
73 }
74 if want, have := uint64(2), atomic.LoadUint64(&lookups); want != have {
75 t.Errorf("want %d, have %d", want, have)
76 }
77 if want, have := uint64(len(records)), atomic.LoadUint64(&generates); want != have {
78 t.Errorf("want %d, have %d", want, have)
79 }
80 }
81
82 type nopCloser struct{}
83
84 func (nopCloser) Close() error { return nil }
0 // Package sd provides utilities related to service discovery. That includes
1 // subscribing to service discovery systems in order to reach remote instances,
2 // and publishing to service discovery systems to make an instance available.
3 // Implementations are provided for most common systems.
4 package sd
0 package etcd
1
2 import (
3 "crypto/tls"
4 "crypto/x509"
5 "io/ioutil"
6 "net"
7 "net/http"
8 "time"
9
10 etcd "github.com/coreos/etcd/client"
11 "golang.org/x/net/context"
12 )
13
14 // Client is a wrapper around the etcd client.
15 type Client interface {
16 // GetEntries will query the given prefix in etcd and returns a set of entries.
17 GetEntries(prefix string) ([]string, error)
18
19 // WatchPrefix starts watching every change for given prefix in etcd. When an
20 // change is detected it will populate the responseChan when an *etcd.Response.
21 WatchPrefix(prefix string, responseChan chan *etcd.Response)
22 }
23
24 type client struct {
25 keysAPI etcd.KeysAPI
26 ctx context.Context
27 }
28
29 // ClientOptions defines options for the etcd client.
30 type ClientOptions struct {
31 Cert string
32 Key string
33 CaCert string
34 DialTimeout time.Duration
35 DialKeepAline time.Duration
36 HeaderTimeoutPerRequest time.Duration
37 }
38
39 // NewClient returns an *etcd.Client with a connection to the named machines.
40 // It will return an error if a connection to the cluster cannot be made.
41 // The parameter machines needs to be a full URL with schemas.
42 // e.g. "http://localhost:2379" will work, but "localhost:2379" will not.
43 func NewClient(ctx context.Context, machines []string, options ClientOptions) (Client, error) {
44 var (
45 c etcd.KeysAPI
46 err error
47 caCertCt []byte
48 tlsCert tls.Certificate
49 )
50
51 if options.Cert != "" && options.Key != "" {
52 tlsCert, err = tls.LoadX509KeyPair(options.Cert, options.Key)
53 if err != nil {
54 return nil, err
55 }
56
57 caCertCt, err = ioutil.ReadFile(options.CaCert)
58 if err != nil {
59 return nil, err
60 }
61 caCertPool := x509.NewCertPool()
62 caCertPool.AppendCertsFromPEM(caCertCt)
63
64 tlsConfig := &tls.Config{
65 Certificates: []tls.Certificate{tlsCert},
66 RootCAs: caCertPool,
67 }
68
69 transport := &http.Transport{
70 TLSClientConfig: tlsConfig,
71 Dial: func(network, addr string) (net.Conn, error) {
72 dial := &net.Dialer{
73 Timeout: options.DialTimeout,
74 KeepAlive: options.DialKeepAline,
75 }
76 return dial.Dial(network, addr)
77 },
78 }
79
80 cfg := etcd.Config{
81 Endpoints: machines,
82 Transport: transport,
83 HeaderTimeoutPerRequest: options.HeaderTimeoutPerRequest,
84 }
85 ce, err := etcd.New(cfg)
86 if err != nil {
87 return nil, err
88 }
89 c = etcd.NewKeysAPI(ce)
90 } else {
91 cfg := etcd.Config{
92 Endpoints: machines,
93 Transport: etcd.DefaultTransport,
94 HeaderTimeoutPerRequest: options.HeaderTimeoutPerRequest,
95 }
96 ce, err := etcd.New(cfg)
97 if err != nil {
98 return nil, err
99 }
100 c = etcd.NewKeysAPI(ce)
101 }
102
103 return &client{c, ctx}, nil
104 }
105
106 // GetEntries implements the etcd Client interface.
107 func (c *client) GetEntries(key string) ([]string, error) {
108 resp, err := c.keysAPI.Get(c.ctx, key, &etcd.GetOptions{Recursive: true})
109 if err != nil {
110 return nil, err
111 }
112
113 entries := make([]string, len(resp.Node.Nodes))
114 for i, node := range resp.Node.Nodes {
115 entries[i] = node.Value
116 }
117 return entries, nil
118 }
119
120 // WatchPrefix implements the etcd Client interface.
121 func (c *client) WatchPrefix(prefix string, responseChan chan *etcd.Response) {
122 watch := c.keysAPI.Watcher(prefix, &etcd.WatcherOptions{AfterIndex: 0, Recursive: true})
123 for {
124 res, err := watch.Next(c.ctx)
125 if err != nil {
126 return
127 }
128 responseChan <- res
129 }
130 }
0 package etcd
1
2 import (
3 etcd "github.com/coreos/etcd/client"
4
5 "github.com/go-kit/kit/endpoint"
6 "github.com/go-kit/kit/log"
7 "github.com/go-kit/kit/sd"
8 "github.com/go-kit/kit/sd/cache"
9 )
10
11 // Subscriber yield endpoints stored in a certain etcd keyspace. Any kind of
12 // change in that keyspace is watched and will update the Subscriber endpoints.
13 type Subscriber struct {
14 client Client
15 prefix string
16 cache *cache.Cache
17 logger log.Logger
18 quitc chan struct{}
19 }
20
21 var _ sd.Subscriber = &Subscriber{}
22
23 // NewSubscriber returns an etcd subscriber. It will start watching the given
24 // prefix for changes, and update the endpoints.
25 func NewSubscriber(c Client, prefix string, factory sd.Factory, logger log.Logger) (*Subscriber, error) {
26 s := &Subscriber{
27 client: c,
28 prefix: prefix,
29 cache: cache.New(factory, logger),
30 logger: logger,
31 quitc: make(chan struct{}),
32 }
33
34 instances, err := s.client.GetEntries(s.prefix)
35 if err == nil {
36 logger.Log("prefix", s.prefix, "instances", len(instances))
37 } else {
38 logger.Log("prefix", s.prefix, "err", err)
39 }
40 s.cache.Update(instances)
41
42 go s.loop()
43 return s, nil
44 }
45
46 func (s *Subscriber) loop() {
47 responseChan := make(chan *etcd.Response)
48 go s.client.WatchPrefix(s.prefix, responseChan)
49 for {
50 select {
51 case <-responseChan:
52 instances, err := s.client.GetEntries(s.prefix)
53 if err != nil {
54 s.logger.Log("msg", "failed to retrieve entries", "err", err)
55 continue
56 }
57 s.cache.Update(instances)
58
59 case <-s.quitc:
60 return
61 }
62 }
63 }
64
65 // Endpoints implements the Subscriber interface.
66 func (s *Subscriber) Endpoints() ([]endpoint.Endpoint, error) {
67 return s.cache.Endpoints(), nil
68 }
69
70 // Stop terminates the Subscriber.
71 func (s *Subscriber) Stop() {
72 close(s.quitc)
73 }
0 package etcd
1
2 import (
3 "errors"
4 "io"
5 "testing"
6
7 stdetcd "github.com/coreos/etcd/client"
8
9 "github.com/go-kit/kit/endpoint"
10 "github.com/go-kit/kit/log"
11 )
12
13 var (
14 node = &stdetcd.Node{
15 Key: "/foo",
16 Nodes: []*stdetcd.Node{
17 {Key: "/foo/1", Value: "1:1"},
18 {Key: "/foo/2", Value: "1:2"},
19 },
20 }
21 fakeResponse = &stdetcd.Response{
22 Node: node,
23 }
24 )
25
26 func TestSubscriber(t *testing.T) {
27 factory := func(string) (endpoint.Endpoint, io.Closer, error) {
28 return endpoint.Nop, nil, nil
29 }
30
31 client := &fakeClient{
32 responses: map[string]*stdetcd.Response{"/foo": fakeResponse},
33 }
34
35 s, err := NewSubscriber(client, "/foo", factory, log.NewNopLogger())
36 if err != nil {
37 t.Fatal(err)
38 }
39 defer s.Stop()
40
41 if _, err := s.Endpoints(); err != nil {
42 t.Fatal(err)
43 }
44 }
45
46 func TestBadFactory(t *testing.T) {
47 factory := func(string) (endpoint.Endpoint, io.Closer, error) {
48 return nil, nil, errors.New("kaboom")
49 }
50
51 client := &fakeClient{
52 responses: map[string]*stdetcd.Response{"/foo": fakeResponse},
53 }
54
55 s, err := NewSubscriber(client, "/foo", factory, log.NewNopLogger())
56 if err != nil {
57 t.Fatal(err)
58 }
59 defer s.Stop()
60
61 endpoints, err := s.Endpoints()
62 if err != nil {
63 t.Fatal(err)
64 }
65
66 if want, have := 0, len(endpoints); want != have {
67 t.Errorf("want %d, have %d", want, have)
68 }
69 }
70
71 type fakeClient struct {
72 responses map[string]*stdetcd.Response
73 }
74
75 func (c *fakeClient) GetEntries(prefix string) ([]string, error) {
76 response, ok := c.responses[prefix]
77 if !ok {
78 return nil, errors.New("key not exist")
79 }
80
81 entries := make([]string, len(response.Node.Nodes))
82 for i, node := range response.Node.Nodes {
83 entries[i] = node.Value
84 }
85 return entries, nil
86 }
87
88 func (c *fakeClient) WatchPrefix(prefix string, responseChan chan *stdetcd.Response) {}
0 package sd
1
2 import (
3 "io"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // Factory is a function that converts an instance string (e.g. host:port) to a
9 // specific endpoint. Instances that provide multiple endpoints require multiple
10 // factories. A factory also returns an io.Closer that's invoked when the
11 // instance goes away and needs to be cleaned up.
12 //
13 // Users are expected to provide their own factory functions that assume
14 // specific transports, or can deduce transports by parsing the instance string.
15 type Factory func(instance string) (endpoint.Endpoint, io.Closer, error)
0 package sd
1
2 import "github.com/go-kit/kit/endpoint"
3
4 // FixedSubscriber yields a fixed set of services.
5 type FixedSubscriber []endpoint.Endpoint
6
7 // Endpoints implements Subscriber.
8 func (s FixedSubscriber) Endpoints() ([]endpoint.Endpoint, error) { return s, nil }
0 package lb
1
2 import (
3 "errors"
4
5 "github.com/go-kit/kit/endpoint"
6 )
7
8 // Balancer yields endpoints according to some heuristic.
9 type Balancer interface {
10 Endpoint() (endpoint.Endpoint, error)
11 }
12
13 // ErrNoEndpoints is returned when no qualifying endpoints are available.
14 var ErrNoEndpoints = errors.New("no endpoints available")
0 // Package lb deals with client-side load balancing across multiple identical
1 // instances of services and endpoints. When combined with a service discovery
2 // system of record, it enables a more decentralized architecture, removing the
3 // need for separate load balancers like HAProxy.
4 package lb
0 package lb
1
2 import (
3 "math/rand"
4
5 "github.com/go-kit/kit/endpoint"
6 "github.com/go-kit/kit/sd"
7 )
8
9 // NewRandom returns a load balancer that selects services randomly.
10 func NewRandom(s sd.Subscriber, seed int64) Balancer {
11 return &random{
12 s: s,
13 r: rand.New(rand.NewSource(seed)),
14 }
15 }
16
17 type random struct {
18 s sd.Subscriber
19 r *rand.Rand
20 }
21
22 func (r *random) Endpoint() (endpoint.Endpoint, error) {
23 endpoints, err := r.s.Endpoints()
24 if err != nil {
25 return nil, err
26 }
27 if len(endpoints) <= 0 {
28 return nil, ErrNoEndpoints
29 }
30 return endpoints[r.r.Intn(len(endpoints))], nil
31 }
0 package lb
1
2 import (
3 "math"
4 "testing"
5
6 "github.com/go-kit/kit/endpoint"
7 "github.com/go-kit/kit/sd"
8 "golang.org/x/net/context"
9 )
10
11 func TestRandom(t *testing.T) {
12 var (
13 n = 7
14 endpoints = make([]endpoint.Endpoint, n)
15 counts = make([]int, n)
16 seed = int64(12345)
17 iterations = 1000000
18 want = iterations / n
19 tolerance = want / 100 // 1%
20 )
21
22 for i := 0; i < n; i++ {
23 i0 := i
24 endpoints[i] = func(context.Context, interface{}) (interface{}, error) { counts[i0]++; return struct{}{}, nil }
25 }
26
27 subscriber := sd.FixedSubscriber(endpoints)
28 balancer := NewRandom(subscriber, seed)
29
30 for i := 0; i < iterations; i++ {
31 endpoint, _ := balancer.Endpoint()
32 endpoint(context.Background(), struct{}{})
33 }
34
35 for i, have := range counts {
36 delta := int(math.Abs(float64(want - have)))
37 if delta > tolerance {
38 t.Errorf("%d: want %d, have %d, delta %d > %d tolerance", i, want, have, delta, tolerance)
39 }
40 }
41 }
42
43 func TestRandomNoEndpoints(t *testing.T) {
44 subscriber := sd.FixedSubscriber{}
45 balancer := NewRandom(subscriber, 1415926)
46 _, err := balancer.Endpoint()
47 if want, have := ErrNoEndpoints, err; want != have {
48 t.Errorf("want %v, have %v", want, have)
49 }
50
51 }
0 package lb
1
2 import (
3 "fmt"
4 "strings"
5 "time"
6
7 "golang.org/x/net/context"
8
9 "github.com/go-kit/kit/endpoint"
10 )
11
12 // Retry wraps a service load balancer and returns an endpoint oriented load
13 // balancer for the specified service method.
14 // Requests to the endpoint will be automatically load balanced via the load
15 // balancer. Requests that return errors will be retried until they succeed,
16 // up to max times, or until the timeout is elapsed, whichever comes first.
17 func Retry(max int, timeout time.Duration, b Balancer) endpoint.Endpoint {
18 if b == nil {
19 panic("nil Balancer")
20 }
21 return func(ctx context.Context, request interface{}) (response interface{}, err error) {
22 var (
23 newctx, cancel = context.WithTimeout(ctx, timeout)
24 responses = make(chan interface{}, 1)
25 errs = make(chan error, 1)
26 a = []string{}
27 )
28 defer cancel()
29 for i := 1; i <= max; i++ {
30 go func() {
31 e, err := b.Endpoint()
32 if err != nil {
33 errs <- err
34 return
35 }
36 response, err := e(newctx, request)
37 if err != nil {
38 errs <- err
39 return
40 }
41 responses <- response
42 }()
43
44 select {
45 case <-newctx.Done():
46 return nil, newctx.Err()
47 case response := <-responses:
48 return response, nil
49 case err := <-errs:
50 a = append(a, err.Error())
51 continue
52 }
53 }
54 return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; "))
55 }
56 }
0 package lb_test
1
2 import (
3 "errors"
4 "testing"
5 "time"
6
7 "golang.org/x/net/context"
8
9 "github.com/go-kit/kit/endpoint"
10 "github.com/go-kit/kit/sd"
11 loadbalancer "github.com/go-kit/kit/sd/lb"
12 )
13
14 func TestRetryMaxTotalFail(t *testing.T) {
15 var (
16 endpoints = sd.FixedSubscriber{} // no endpoints
17 lb = loadbalancer.NewRoundRobin(endpoints)
18 retry = loadbalancer.Retry(999, time.Second, lb) // lots of retries
19 ctx = context.Background()
20 )
21 if _, err := retry(ctx, struct{}{}); err == nil {
22 t.Errorf("expected error, got none") // should fail
23 }
24 }
25
26 func TestRetryMaxPartialFail(t *testing.T) {
27 var (
28 endpoints = []endpoint.Endpoint{
29 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") },
30 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") },
31 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
32 }
33 subscriber = sd.FixedSubscriber{
34 0: endpoints[0],
35 1: endpoints[1],
36 2: endpoints[2],
37 }
38 retries = len(endpoints) - 1 // not quite enough retries
39 lb = loadbalancer.NewRoundRobin(subscriber)
40 ctx = context.Background()
41 )
42 if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err == nil {
43 t.Errorf("expected error, got none")
44 }
45 }
46
47 func TestRetryMaxSuccess(t *testing.T) {
48 var (
49 endpoints = []endpoint.Endpoint{
50 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") },
51 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") },
52 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
53 }
54 subscriber = sd.FixedSubscriber{
55 0: endpoints[0],
56 1: endpoints[1],
57 2: endpoints[2],
58 }
59 retries = len(endpoints) // exactly enough retries
60 lb = loadbalancer.NewRoundRobin(subscriber)
61 ctx = context.Background()
62 )
63 if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err != nil {
64 t.Error(err)
65 }
66 }
67
68 func TestRetryTimeout(t *testing.T) {
69 var (
70 step = make(chan struct{})
71 e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil }
72 timeout = time.Millisecond
73 retry = loadbalancer.Retry(999, timeout, loadbalancer.NewRoundRobin(sd.FixedSubscriber{0: e}))
74 errs = make(chan error, 1)
75 invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err }
76 )
77
78 go func() { step <- struct{}{} }() // queue up a flush of the endpoint
79 invoke() // invoke the endpoint and trigger the flush
80 if err := <-errs; err != nil { // that should succeed
81 t.Error(err)
82 }
83
84 go func() { time.Sleep(10 * timeout); step <- struct{}{} }() // a delayed flush
85 invoke() // invoke the endpoint
86 if err := <-errs; err != context.DeadlineExceeded { // that should not succeed
87 t.Errorf("wanted %v, got none", context.DeadlineExceeded)
88 }
89 }
0 package lb
1
2 import (
3 "sync/atomic"
4
5 "github.com/go-kit/kit/endpoint"
6 "github.com/go-kit/kit/sd"
7 )
8
9 // NewRoundRobin returns a load balancer that returns services in sequence.
10 func NewRoundRobin(s sd.Subscriber) Balancer {
11 return &roundRobin{
12 s: s,
13 c: 0,
14 }
15 }
16
17 type roundRobin struct {
18 s sd.Subscriber
19 c uint64
20 }
21
22 func (rr *roundRobin) Endpoint() (endpoint.Endpoint, error) {
23 endpoints, err := rr.s.Endpoints()
24 if err != nil {
25 return nil, err
26 }
27 if len(endpoints) <= 0 {
28 return nil, ErrNoEndpoints
29 }
30 old := atomic.AddUint64(&rr.c, 1) - 1
31 idx := old % uint64(len(endpoints))
32 return endpoints[idx], nil
33 }
0 package lb
1
2 import (
3 "reflect"
4 "sync"
5 "sync/atomic"
6 "testing"
7 "time"
8
9 "golang.org/x/net/context"
10
11 "github.com/go-kit/kit/endpoint"
12 "github.com/go-kit/kit/sd"
13 )
14
15 func TestRoundRobin(t *testing.T) {
16 var (
17 counts = []int{0, 0, 0}
18 endpoints = []endpoint.Endpoint{
19 func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil },
20 func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil },
21 func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil },
22 }
23 )
24
25 subscriber := sd.FixedSubscriber(endpoints)
26 balancer := NewRoundRobin(subscriber)
27
28 for i, want := range [][]int{
29 {1, 0, 0},
30 {1, 1, 0},
31 {1, 1, 1},
32 {2, 1, 1},
33 {2, 2, 1},
34 {2, 2, 2},
35 {3, 2, 2},
36 } {
37 endpoint, err := balancer.Endpoint()
38 if err != nil {
39 t.Fatal(err)
40 }
41 endpoint(context.Background(), struct{}{})
42 if have := counts; !reflect.DeepEqual(want, have) {
43 t.Fatalf("%d: want %v, have %v", i, want, have)
44 }
45 }
46 }
47
48 func TestRoundRobinNoEndpoints(t *testing.T) {
49 subscriber := sd.FixedSubscriber{}
50 balancer := NewRoundRobin(subscriber)
51 _, err := balancer.Endpoint()
52 if want, have := ErrNoEndpoints, err; want != have {
53 t.Errorf("want %v, have %v", want, have)
54 }
55 }
56
57 func TestRoundRobinNoRace(t *testing.T) {
58 balancer := NewRoundRobin(sd.FixedSubscriber([]endpoint.Endpoint{
59 endpoint.Nop,
60 endpoint.Nop,
61 endpoint.Nop,
62 endpoint.Nop,
63 endpoint.Nop,
64 }))
65
66 var (
67 n = 100
68 done = make(chan struct{})
69 wg sync.WaitGroup
70 count uint64
71 )
72
73 wg.Add(n)
74
75 for i := 0; i < n; i++ {
76 go func() {
77 defer wg.Done()
78 for {
79 select {
80 case <-done:
81 return
82 default:
83 _, _ = balancer.Endpoint()
84 atomic.AddUint64(&count, 1)
85 }
86 }
87 }()
88 }
89
90 time.Sleep(time.Second)
91 close(done)
92 wg.Wait()
93
94 t.Logf("made %d calls", atomic.LoadUint64(&count))
95 }
0 package sd
1
2 // Registrar registers instance information to a service discovery system when
3 // an instance becomes alive and healthy, and deregisters that information when
4 // the service becomes unhealthy or goes away.
5 //
6 // Registrar implementations exist for various service discovery systems. Note
7 // that identifying instance information (e.g. host:port) must be given via the
8 // concrete constructor; this interface merely signals lifecycle changes.
9 type Registrar interface {
10 Register()
11 Deregister()
12 }
0 package sd
1
2 import "github.com/go-kit/kit/endpoint"
3
4 // Subscriber listens to a service discovery system and yields a set of
5 // identical endpoints on demand. An error indicates a problem with connectivity
6 // to the service discovery system, or within the system itself; a subscriber
7 // may yield no endpoints without error.
8 type Subscriber interface {
9 Endpoints() ([]endpoint.Endpoint, error)
10 }
0 package zk
1
2 import (
3 "errors"
4 "net"
5 "strings"
6 "time"
7
8 "github.com/samuel/go-zookeeper/zk"
9
10 "github.com/go-kit/kit/log"
11 )
12
13 // DefaultACL is the default ACL to use for creating znodes.
14 var (
15 DefaultACL = zk.WorldACL(zk.PermAll)
16 ErrInvalidCredentials = errors.New("invalid credentials provided")
17 ErrClientClosed = errors.New("client service closed")
18 )
19
20 const (
21 // DefaultConnectTimeout is the default timeout to establish a connection to
22 // a ZooKeeper node.
23 DefaultConnectTimeout = 2 * time.Second
24 // DefaultSessionTimeout is the default timeout to keep the current
25 // ZooKeeper session alive during a temporary disconnect.
26 DefaultSessionTimeout = 5 * time.Second
27 )
28
29 // Client is a wrapper around a lower level ZooKeeper client implementation.
30 type Client interface {
31 // GetEntries should query the provided path in ZooKeeper, place a watch on
32 // it and retrieve data from its current child nodes.
33 GetEntries(path string) ([]string, <-chan zk.Event, error)
34 // CreateParentNodes should try to create the path in case it does not exist
35 // yet on ZooKeeper.
36 CreateParentNodes(path string) error
37 // Stop should properly shutdown the client implementation
38 Stop()
39 }
40
41 type clientConfig struct {
42 logger log.Logger
43 acl []zk.ACL
44 credentials []byte
45 connectTimeout time.Duration
46 sessionTimeout time.Duration
47 rootNodePayload [][]byte
48 eventHandler func(zk.Event)
49 }
50
51 // Option functions enable friendly APIs.
52 type Option func(*clientConfig) error
53
54 type client struct {
55 *zk.Conn
56 clientConfig
57 active bool
58 quit chan struct{}
59 }
60
61 // ACL returns an Option specifying a non-default ACL for creating parent nodes.
62 func ACL(acl []zk.ACL) Option {
63 return func(c *clientConfig) error {
64 c.acl = acl
65 return nil
66 }
67 }
68
69 // Credentials returns an Option specifying a user/password combination which
70 // the client will use to authenticate itself with.
71 func Credentials(user, pass string) Option {
72 return func(c *clientConfig) error {
73 if user == "" || pass == "" {
74 return ErrInvalidCredentials
75 }
76 c.credentials = []byte(user + ":" + pass)
77 return nil
78 }
79 }
80
81 // ConnectTimeout returns an Option specifying a non-default connection timeout
82 // when we try to establish a connection to a ZooKeeper server.
83 func ConnectTimeout(t time.Duration) Option {
84 return func(c *clientConfig) error {
85 if t.Seconds() < 1 {
86 return errors.New("invalid connect timeout (minimum value is 1 second)")
87 }
88 c.connectTimeout = t
89 return nil
90 }
91 }
92
93 // SessionTimeout returns an Option specifying a non-default session timeout.
94 func SessionTimeout(t time.Duration) Option {
95 return func(c *clientConfig) error {
96 if t.Seconds() < 1 {
97 return errors.New("invalid session timeout (minimum value is 1 second)")
98 }
99 c.sessionTimeout = t
100 return nil
101 }
102 }
103
104 // Payload returns an Option specifying non-default data values for each znode
105 // created by CreateParentNodes.
106 func Payload(payload [][]byte) Option {
107 return func(c *clientConfig) error {
108 c.rootNodePayload = payload
109 return nil
110 }
111 }
112
113 // EventHandler returns an Option specifying a callback function to handle
114 // incoming zk.Event payloads (ZooKeeper connection events).
115 func EventHandler(handler func(zk.Event)) Option {
116 return func(c *clientConfig) error {
117 c.eventHandler = handler
118 return nil
119 }
120 }
121
122 // NewClient returns a ZooKeeper client with a connection to the server cluster.
123 // It will return an error if the server cluster cannot be resolved.
124 func NewClient(servers []string, logger log.Logger, options ...Option) (Client, error) {
125 defaultEventHandler := func(event zk.Event) {
126 logger.Log("eventtype", event.Type.String(), "server", event.Server, "state", event.State.String(), "err", event.Err)
127 }
128 config := clientConfig{
129 acl: DefaultACL,
130 connectTimeout: DefaultConnectTimeout,
131 sessionTimeout: DefaultSessionTimeout,
132 eventHandler: defaultEventHandler,
133 logger: logger,
134 }
135 for _, option := range options {
136 if err := option(&config); err != nil {
137 return nil, err
138 }
139 }
140 // dialer overrides the default ZooKeeper library Dialer so we can configure
141 // the connectTimeout. The current library has a hardcoded value of 1 second
142 // and there are reports of race conditions, due to slow DNS resolvers and
143 // other network latency issues.
144 dialer := func(network, address string, _ time.Duration) (net.Conn, error) {
145 return net.DialTimeout(network, address, config.connectTimeout)
146 }
147 conn, eventc, err := zk.Connect(servers, config.sessionTimeout, withLogger(logger), zk.WithDialer(dialer))
148
149 if err != nil {
150 return nil, err
151 }
152
153 if len(config.credentials) > 0 {
154 err = conn.AddAuth("digest", config.credentials)
155 if err != nil {
156 return nil, err
157 }
158 }
159
160 c := &client{conn, config, true, make(chan struct{})}
161
162 // Start listening for incoming Event payloads and callback the set
163 // eventHandler.
164 go func() {
165 for {
166 select {
167 case event := <-eventc:
168 config.eventHandler(event)
169 case <-c.quit:
170 return
171 }
172 }
173 }()
174 return c, nil
175 }
176
177 // CreateParentNodes implements the ZooKeeper Client interface.
178 func (c *client) CreateParentNodes(path string) error {
179 if !c.active {
180 return ErrClientClosed
181 }
182 if path[0] != '/' {
183 return zk.ErrInvalidPath
184 }
185 payload := []byte("")
186 pathString := ""
187 pathNodes := strings.Split(path, "/")
188 for i := 1; i < len(pathNodes); i++ {
189 if i <= len(c.rootNodePayload) {
190 payload = c.rootNodePayload[i-1]
191 } else {
192 payload = []byte("")
193 }
194 pathString += "/" + pathNodes[i]
195 _, err := c.Create(pathString, payload, 0, c.acl)
196 // not being able to create the node because it exists or not having
197 // sufficient rights is not an issue. It is ok for the node to already
198 // exist and/or us to only have read rights
199 if err != nil && err != zk.ErrNodeExists && err != zk.ErrNoAuth {
200 return err
201 }
202 }
203 return nil
204 }
205
206 // GetEntries implements the ZooKeeper Client interface.
207 func (c *client) GetEntries(path string) ([]string, <-chan zk.Event, error) {
208 // retrieve list of child nodes for given path and add watch to path
209 znodes, _, eventc, err := c.ChildrenW(path)
210
211 if err != nil {
212 return nil, eventc, err
213 }
214
215 var resp []string
216 for _, znode := range znodes {
217 // retrieve payload for child znode and add to response array
218 if data, _, err := c.Get(path + "/" + znode); err == nil {
219 resp = append(resp, string(data))
220 }
221 }
222 return resp, eventc, nil
223 }
224
225 // Stop implements the ZooKeeper Client interface.
226 func (c *client) Stop() {
227 c.active = false
228 close(c.quit)
229 c.Close()
230 }
0 package zk
1
2 import (
3 "bytes"
4 "testing"
5 "time"
6
7 stdzk "github.com/samuel/go-zookeeper/zk"
8
9 "github.com/go-kit/kit/log"
10 )
11
12 func TestNewClient(t *testing.T) {
13 var (
14 acl = stdzk.WorldACL(stdzk.PermRead)
15 connectTimeout = 3 * time.Second
16 sessionTimeout = 20 * time.Second
17 payload = [][]byte{[]byte("Payload"), []byte("Test")}
18 )
19
20 c, err := NewClient(
21 []string{"FailThisInvalidHost!!!"},
22 log.NewNopLogger(),
23 )
24 if err == nil {
25 t.Errorf("expected error, got nil")
26 }
27
28 hasFired := false
29 calledEventHandler := make(chan struct{})
30 eventHandler := func(event stdzk.Event) {
31 if !hasFired {
32 // test is successful if this function has fired at least once
33 hasFired = true
34 close(calledEventHandler)
35 }
36 }
37
38 c, err = NewClient(
39 []string{"localhost"},
40 log.NewNopLogger(),
41 ACL(acl),
42 ConnectTimeout(connectTimeout),
43 SessionTimeout(sessionTimeout),
44 Payload(payload),
45 EventHandler(eventHandler),
46 )
47 if err != nil {
48 t.Fatal(err)
49 }
50 defer c.Stop()
51
52 clientImpl, ok := c.(*client)
53 if !ok {
54 t.Fatal("retrieved incorrect Client implementation")
55 }
56 if want, have := acl, clientImpl.acl; want[0] != have[0] {
57 t.Errorf("want %+v, have %+v", want, have)
58 }
59 if want, have := connectTimeout, clientImpl.connectTimeout; want != have {
60 t.Errorf("want %d, have %d", want, have)
61 }
62 if want, have := sessionTimeout, clientImpl.sessionTimeout; want != have {
63 t.Errorf("want %d, have %d", want, have)
64 }
65 if want, have := payload, clientImpl.rootNodePayload; bytes.Compare(want[0], have[0]) != 0 || bytes.Compare(want[1], have[1]) != 0 {
66 t.Errorf("want %s, have %s", want, have)
67 }
68
69 select {
70 case <-calledEventHandler:
71 case <-time.After(100 * time.Millisecond):
72 t.Errorf("event handler never called")
73 }
74 }
75
76 func TestOptions(t *testing.T) {
77 _, err := NewClient([]string{"localhost"}, log.NewNopLogger(), Credentials("valid", "credentials"))
78 if err != nil && err != stdzk.ErrNoServer {
79 t.Errorf("unexpected error: %v", err)
80 }
81
82 _, err = NewClient([]string{"localhost"}, log.NewNopLogger(), Credentials("nopass", ""))
83 if want, have := err, ErrInvalidCredentials; want != have {
84 t.Errorf("want %v, have %v", want, have)
85 }
86
87 _, err = NewClient([]string{"localhost"}, log.NewNopLogger(), ConnectTimeout(0))
88 if err == nil {
89 t.Errorf("expected connect timeout error")
90 }
91
92 _, err = NewClient([]string{"localhost"}, log.NewNopLogger(), SessionTimeout(0))
93 if err == nil {
94 t.Errorf("expected connect timeout error")
95 }
96 }
97
98 func TestCreateParentNodes(t *testing.T) {
99 payload := [][]byte{[]byte("Payload"), []byte("Test")}
100
101 c, err := NewClient([]string{"localhost:65500"}, log.NewNopLogger())
102 if err != nil {
103 t.Errorf("unexpected error: %v", err)
104 }
105 if c == nil {
106 t.Fatal("expected new Client, got nil")
107 }
108
109 s, err := NewSubscriber(c, "/validpath", newFactory(""), log.NewNopLogger())
110 if err != stdzk.ErrNoServer {
111 t.Errorf("unexpected error: %v", err)
112 }
113 if s != nil {
114 t.Error("expected failed new Subscriber")
115 }
116
117 s, err = NewSubscriber(c, "invalidpath", newFactory(""), log.NewNopLogger())
118 if err != stdzk.ErrInvalidPath {
119 t.Errorf("unexpected error: %v", err)
120 }
121 _, _, err = c.GetEntries("/validpath")
122 if err != stdzk.ErrNoServer {
123 t.Errorf("unexpected error: %v", err)
124 }
125
126 c.Stop()
127
128 err = c.CreateParentNodes("/validpath")
129 if err != ErrClientClosed {
130 t.Errorf("unexpected error: %v", err)
131 }
132
133 s, err = NewSubscriber(c, "/validpath", newFactory(""), log.NewNopLogger())
134 if err != ErrClientClosed {
135 t.Errorf("unexpected error: %v", err)
136 }
137 if s != nil {
138 t.Error("expected failed new Subscriber")
139 }
140
141 c, err = NewClient([]string{"localhost:65500"}, log.NewNopLogger(), Payload(payload))
142 if err != nil {
143 t.Errorf("unexpected error: %v", err)
144 }
145 if c == nil {
146 t.Fatal("expected new Client, got nil")
147 }
148
149 s, err = NewSubscriber(c, "/validpath", newFactory(""), log.NewNopLogger())
150 if err != stdzk.ErrNoServer {
151 t.Errorf("unexpected error: %v", err)
152 }
153 if s != nil {
154 t.Error("expected failed new Subscriber")
155 }
156 }
0 // +build integration
1
2 package zk
3
4 import (
5 "bytes"
6 "flag"
7 "fmt"
8 "os"
9 "testing"
10 "time"
11
12 stdzk "github.com/samuel/go-zookeeper/zk"
13 )
14
15 var (
16 host []string
17 )
18
19 func TestMain(m *testing.M) {
20 flag.Parse()
21
22 fmt.Println("Starting ZooKeeper server...")
23
24 ts, err := stdzk.StartTestCluster(1, nil, nil)
25 if err != nil {
26 fmt.Printf("ZooKeeper server error: %v\n", err)
27 os.Exit(1)
28 }
29
30 host = []string{fmt.Sprintf("localhost:%d", ts.Servers[0].Port)}
31 code := m.Run()
32
33 ts.Stop()
34 os.Exit(code)
35 }
36
37 func TestCreateParentNodesOnServer(t *testing.T) {
38 payload := [][]byte{[]byte("Payload"), []byte("Test")}
39 c1, err := NewClient(host, logger, Payload(payload))
40 if err != nil {
41 t.Fatalf("Connect returned error: %v", err)
42 }
43 if c1 == nil {
44 t.Fatal("Expected pointer to client, got nil")
45 }
46 defer c1.Stop()
47
48 s, err := NewSubscriber(c1, path, newFactory(""), logger)
49 if err != nil {
50 t.Fatalf("Unable to create Subscriber: %v", err)
51 }
52 defer s.Stop()
53
54 services, err := s.Services()
55 if err != nil {
56 t.Fatal(err)
57 }
58 if want, have := 0, len(services); want != have {
59 t.Errorf("want %d, have %d", want, have)
60 }
61
62 c2, err := NewClient(host, logger)
63 if err != nil {
64 t.Fatalf("Connect returned error: %v", err)
65 }
66 defer c2.Stop()
67 data, _, err := c2.(*client).Get(path)
68 if err != nil {
69 t.Fatal(err)
70 }
71 // test Client implementation of CreateParentNodes. It should have created
72 // our payload
73 if bytes.Compare(data, payload[1]) != 0 {
74 t.Errorf("want %s, have %s", payload[1], data)
75 }
76
77 }
78
79 func TestCreateBadParentNodesOnServer(t *testing.T) {
80 c, _ := NewClient(host, logger)
81 defer c.Stop()
82
83 _, err := NewSubscriber(c, "invalid/path", newFactory(""), logger)
84
85 if want, have := stdzk.ErrInvalidPath, err; want != have {
86 t.Errorf("want %v, have %v", want, have)
87 }
88 }
89
90 func TestCredentials1(t *testing.T) {
91 acl := stdzk.DigestACL(stdzk.PermAll, "user", "secret")
92 c, _ := NewClient(host, logger, ACL(acl), Credentials("user", "secret"))
93 defer c.Stop()
94
95 _, err := NewSubscriber(c, "/acl-issue-test", newFactory(""), logger)
96
97 if err != nil {
98 t.Fatal(err)
99 }
100 }
101
102 func TestCredentials2(t *testing.T) {
103 acl := stdzk.DigestACL(stdzk.PermAll, "user", "secret")
104 c, _ := NewClient(host, logger, ACL(acl))
105 defer c.Stop()
106
107 _, err := NewSubscriber(c, "/acl-issue-test", newFactory(""), logger)
108
109 if err != stdzk.ErrNoAuth {
110 t.Errorf("want %v, have %v", stdzk.ErrNoAuth, err)
111 }
112 }
113
114 func TestConnection(t *testing.T) {
115 c, _ := NewClient(host, logger)
116 c.Stop()
117
118 _, err := NewSubscriber(c, "/acl-issue-test", newFactory(""), logger)
119
120 if err != ErrClientClosed {
121 t.Errorf("want %v, have %v", ErrClientClosed, err)
122 }
123 }
124
125 func TestGetEntriesOnServer(t *testing.T) {
126 var instancePayload = "protocol://hostname:port/routing"
127
128 c1, err := NewClient(host, logger)
129 if err != nil {
130 t.Fatalf("Connect returned error: %v", err)
131 }
132
133 defer c1.Stop()
134
135 c2, err := NewClient(host, logger)
136 s, err := NewSubscriber(c2, path, newFactory(""), logger)
137 if err != nil {
138 t.Fatal(err)
139 }
140 defer c2.Stop()
141
142 c2impl, _ := c2.(*client)
143 _, err = c2impl.Create(
144 path+"/instance1",
145 []byte(instancePayload),
146 stdzk.FlagEphemeral|stdzk.FlagSequence,
147 stdzk.WorldACL(stdzk.PermAll),
148 )
149 if err != nil {
150 t.Fatalf("Unable to create test ephemeral znode 1: %v", err)
151 }
152 _, err = c2impl.Create(
153 path+"/instance2",
154 []byte(instancePayload+"2"),
155 stdzk.FlagEphemeral|stdzk.FlagSequence,
156 stdzk.WorldACL(stdzk.PermAll),
157 )
158 if err != nil {
159 t.Fatalf("Unable to create test ephemeral znode 2: %v", err)
160 }
161
162 time.Sleep(50 * time.Millisecond)
163
164 services, err := s.Services()
165 if err != nil {
166 t.Fatal(err)
167 }
168 if want, have := 2, len(services); want != have {
169 t.Errorf("want %d, have %d", want, have)
170 }
171 }
172
173 func TestGetEntriesPayloadOnServer(t *testing.T) {
174 c, err := NewClient(host, logger)
175 if err != nil {
176 t.Fatalf("Connect returned error: %v", err)
177 }
178 _, eventc, err := c.GetEntries(path)
179 if err != nil {
180 t.Fatal(err)
181 }
182 _, err = c.(*client).Create(
183 path+"/instance3",
184 []byte("just some payload"),
185 stdzk.FlagEphemeral|stdzk.FlagSequence,
186 stdzk.WorldACL(stdzk.PermAll),
187 )
188 if err != nil {
189 t.Fatalf("Unable to create test ephemeral znode: %v", err)
190 }
191 select {
192 case event := <-eventc:
193 if want, have := stdzk.EventNodeChildrenChanged.String(), event.Type.String(); want != have {
194 t.Errorf("want %s, have %s", want, have)
195 }
196 case <-time.After(20 * time.Millisecond):
197 t.Errorf("expected incoming watch event, timeout occurred")
198 }
199
200 }
0 package zk
1
2 import (
3 "fmt"
4
5 "github.com/samuel/go-zookeeper/zk"
6
7 "github.com/go-kit/kit/log"
8 )
9
10 // wrapLogger wraps a Go kit logger so we can use it as the logging service for
11 // the ZooKeeper library, which expects a Printf method to be available.
12 type wrapLogger struct {
13 log.Logger
14 }
15
16 func (logger wrapLogger) Printf(format string, args ...interface{}) {
17 logger.Log("msg", fmt.Sprintf(format, args...))
18 }
19
20 // withLogger replaces the ZooKeeper library's default logging service with our
21 // own Go kit logger.
22 func withLogger(logger log.Logger) func(c *zk.Conn) {
23 return func(c *zk.Conn) {
24 c.SetLogger(wrapLogger{logger})
25 }
26 }
0 package zk
1
2 import (
3 "github.com/samuel/go-zookeeper/zk"
4
5 "github.com/go-kit/kit/endpoint"
6 "github.com/go-kit/kit/log"
7 "github.com/go-kit/kit/sd"
8 "github.com/go-kit/kit/sd/cache"
9 )
10
11 // Subscriber yield endpoints stored in a certain ZooKeeper path. Any kind of
12 // change in that path is watched and will update the Subscriber endpoints.
13 type Subscriber struct {
14 client Client
15 path string
16 cache *cache.Cache
17 logger log.Logger
18 quitc chan struct{}
19 }
20
21 var _ sd.Subscriber = &Subscriber{}
22
23 // NewSubscriber returns a ZooKeeper subscriber. ZooKeeper will start watching
24 // the given path for changes and update the Subscriber endpoints.
25 func NewSubscriber(c Client, path string, factory sd.Factory, logger log.Logger) (*Subscriber, error) {
26 s := &Subscriber{
27 client: c,
28 path: path,
29 cache: cache.New(factory, logger),
30 logger: logger,
31 quitc: make(chan struct{}),
32 }
33
34 err := s.client.CreateParentNodes(s.path)
35 if err != nil {
36 return nil, err
37 }
38
39 instances, eventc, err := s.client.GetEntries(s.path)
40 if err != nil {
41 logger.Log("path", s.path, "msg", "failed to retrieve entries", "err", err)
42 return nil, err
43 }
44 logger.Log("path", s.path, "instances", len(instances))
45 s.cache.Update(instances)
46
47 go s.loop(eventc)
48
49 return s, nil
50 }
51
52 func (s *Subscriber) loop(eventc <-chan zk.Event) {
53 var (
54 instances []string
55 err error
56 )
57 for {
58 select {
59 case <-eventc:
60 // We received a path update notification. Call GetEntries to
61 // retrieve child node data, and set a new watch, as ZK watches are
62 // one-time triggers.
63 instances, eventc, err = s.client.GetEntries(s.path)
64 if err != nil {
65 s.logger.Log("path", s.path, "msg", "failed to retrieve entries", "err", err)
66 continue
67 }
68 s.logger.Log("path", s.path, "instances", len(instances))
69 s.cache.Update(instances)
70
71 case <-s.quitc:
72 return
73 }
74 }
75 }
76
77 // Endpoints implements the Subscriber interface.
78 func (s *Subscriber) Endpoints() ([]endpoint.Endpoint, error) {
79 return s.cache.Endpoints(), nil
80 }
81
82 // Stop terminates the Subscriber.
83 func (s *Subscriber) Stop() {
84 close(s.quitc)
85 }
0 package zk
1
2 import (
3 "testing"
4 "time"
5 )
6
7 func TestSubscriber(t *testing.T) {
8 client := newFakeClient()
9
10 s, err := NewSubscriber(client, path, newFactory(""), logger)
11 if err != nil {
12 t.Fatalf("failed to create new Subscriber: %v", err)
13 }
14 defer s.Stop()
15
16 if _, err := s.Endpoints(); err != nil {
17 t.Fatal(err)
18 }
19 }
20
21 func TestBadFactory(t *testing.T) {
22 client := newFakeClient()
23
24 s, err := NewSubscriber(client, path, newFactory("kaboom"), logger)
25 if err != nil {
26 t.Fatalf("failed to create new Subscriber: %v", err)
27 }
28 defer s.Stop()
29
30 // instance1 came online
31 client.AddService(path+"/instance1", "kaboom")
32
33 // instance2 came online
34 client.AddService(path+"/instance2", "zookeeper_node_data")
35
36 if err = asyncTest(100*time.Millisecond, 1, s); err != nil {
37 t.Error(err)
38 }
39 }
40
41 func TestServiceUpdate(t *testing.T) {
42 client := newFakeClient()
43
44 s, err := NewSubscriber(client, path, newFactory(""), logger)
45 if err != nil {
46 t.Fatalf("failed to create new Subscriber: %v", err)
47 }
48 defer s.Stop()
49
50 endpoints, err := s.Endpoints()
51 if err != nil {
52 t.Fatal(err)
53 }
54 if want, have := 0, len(endpoints); want != have {
55 t.Errorf("want %d, have %d", want, have)
56 }
57
58 // instance1 came online
59 client.AddService(path+"/instance1", "zookeeper_node_data1")
60
61 // instance2 came online
62 client.AddService(path+"/instance2", "zookeeper_node_data2")
63
64 // we should have 2 instances
65 if err = asyncTest(100*time.Millisecond, 2, s); err != nil {
66 t.Error(err)
67 }
68
69 // TODO(pb): this bit is flaky
70 //
71 //// watch triggers an error...
72 //client.SendErrorOnWatch()
73 //
74 //// test if error was consumed
75 //if err = client.ErrorIsConsumedWithin(100 * time.Millisecond); err != nil {
76 // t.Error(err)
77 //}
78
79 // instance3 came online
80 client.AddService(path+"/instance3", "zookeeper_node_data3")
81
82 // we should have 3 instances
83 if err = asyncTest(100*time.Millisecond, 3, s); err != nil {
84 t.Error(err)
85 }
86
87 // instance1 goes offline
88 client.RemoveService(path + "/instance1")
89
90 // instance2 goes offline
91 client.RemoveService(path + "/instance2")
92
93 // we should have 1 instance
94 if err = asyncTest(100*time.Millisecond, 1, s); err != nil {
95 t.Error(err)
96 }
97 }
98
99 func TestBadSubscriberCreate(t *testing.T) {
100 client := newFakeClient()
101 client.SendErrorOnWatch()
102 s, err := NewSubscriber(client, path, newFactory(""), logger)
103 if err == nil {
104 t.Error("expected error on new Subscriber")
105 }
106 if s != nil {
107 t.Error("expected Subscriber not to be created")
108 }
109 s, err = NewSubscriber(client, "BadPath", newFactory(""), logger)
110 if err == nil {
111 t.Error("expected error on new Subscriber")
112 }
113 if s != nil {
114 t.Error("expected Subscriber not to be created")
115 }
116 }
0 package zk
1
2 import (
3 "errors"
4 "fmt"
5 "io"
6 "sync"
7 "time"
8
9 "github.com/samuel/go-zookeeper/zk"
10 "golang.org/x/net/context"
11
12 "github.com/go-kit/kit/endpoint"
13 "github.com/go-kit/kit/log"
14 "github.com/go-kit/kit/sd"
15 )
16
17 var (
18 path = "/gokit.test/service.name"
19 e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
20 logger = log.NewNopLogger()
21 )
22
23 type fakeClient struct {
24 mtx sync.Mutex
25 ch chan zk.Event
26 responses map[string]string
27 result bool
28 }
29
30 func newFakeClient() *fakeClient {
31 return &fakeClient{
32 ch: make(chan zk.Event, 1),
33 responses: make(map[string]string),
34 result: true,
35 }
36 }
37
38 func (c *fakeClient) CreateParentNodes(path string) error {
39 if path == "BadPath" {
40 return errors.New("dummy error")
41 }
42 return nil
43 }
44
45 func (c *fakeClient) GetEntries(path string) ([]string, <-chan zk.Event, error) {
46 c.mtx.Lock()
47 defer c.mtx.Unlock()
48 if c.result == false {
49 c.result = true
50 return []string{}, c.ch, errors.New("dummy error")
51 }
52 responses := []string{}
53 for _, data := range c.responses {
54 responses = append(responses, data)
55 }
56 return responses, c.ch, nil
57 }
58
59 func (c *fakeClient) AddService(node, data string) {
60 c.mtx.Lock()
61 defer c.mtx.Unlock()
62 c.responses[node] = data
63 c.ch <- zk.Event{}
64 }
65
66 func (c *fakeClient) RemoveService(node string) {
67 c.mtx.Lock()
68 defer c.mtx.Unlock()
69 delete(c.responses, node)
70 c.ch <- zk.Event{}
71 }
72
73 func (c *fakeClient) SendErrorOnWatch() {
74 c.mtx.Lock()
75 defer c.mtx.Unlock()
76 c.result = false
77 c.ch <- zk.Event{}
78 }
79
80 func (c *fakeClient) ErrorIsConsumedWithin(timeout time.Duration) error {
81 t := time.After(timeout)
82 for {
83 select {
84 case <-t:
85 return fmt.Errorf("expected error not consumed after timeout %s", timeout)
86 default:
87 c.mtx.Lock()
88 if c.result == false {
89 c.mtx.Unlock()
90 return nil
91 }
92 c.mtx.Unlock()
93 }
94 }
95 }
96
97 func (c *fakeClient) Stop() {}
98
99 func newFactory(fakeError string) sd.Factory {
100 return func(instance string) (endpoint.Endpoint, io.Closer, error) {
101 if fakeError == instance {
102 return nil, nil, errors.New(fakeError)
103 }
104 return endpoint.Nop, nil, nil
105 }
106 }
107
108 func asyncTest(timeout time.Duration, want int, s *Subscriber) (err error) {
109 var endpoints []endpoint.Endpoint
110 have := -1 // want can never be <0
111 t := time.After(timeout)
112 for {
113 select {
114 case <-t:
115 return fmt.Errorf("want %d, have %d (timeout %s)", want, have, timeout.String())
116 default:
117 endpoints, err = s.Endpoints()
118 have = len(endpoints)
119 if err != nil || want == have {
120 return
121 }
122 time.Sleep(timeout / 10)
123 }
124 }
125 }