loadbalancer/dnssrv: use EndpointCache
Peter Bourgon
8 years ago
0 | 0 | package dnssrv |
1 | 1 | |
2 | 2 | import ( |
3 | "crypto/md5" | |
4 | 3 | "fmt" |
5 | 4 | "net" |
6 | "sort" | |
7 | 5 | "time" |
8 | 6 | |
9 | 7 | "github.com/go-kit/kit/endpoint" |
16 | 14 | type Publisher struct { |
17 | 15 | name string |
18 | 16 | ttl time.Duration |
19 | factory loadbalancer.Factory | |
17 | cache *loadbalancer.EndpointCache | |
20 | 18 | logger log.Logger |
21 | 19 | endpoints chan []endpoint.Endpoint |
22 | 20 | quit chan struct{} |
27 | 25 | // constructor will return an error. The factory is used to convert a |
28 | 26 | // host:port to a usable endpoint. The logger is used to report DNS and |
29 | 27 | // factory errors. |
30 | func NewPublisher(name string, ttl time.Duration, f loadbalancer.Factory, logger log.Logger) (*Publisher, error) { | |
31 | logger = log.NewContext(logger).With("component", "DNS SRV Publisher") | |
32 | addrs, md5, err := resolve(name) | |
33 | if err != nil { | |
34 | return nil, err | |
35 | } | |
28 | func NewPublisher(name string, ttl time.Duration, factory loadbalancer.Factory, logger log.Logger) *Publisher { | |
36 | 29 | p := &Publisher{ |
37 | 30 | name: name, |
38 | 31 | ttl: ttl, |
39 | factory: f, | |
32 | cache: loadbalancer.NewEndpointCache(factory, logger), | |
40 | 33 | logger: logger, |
41 | 34 | endpoints: make(chan []endpoint.Endpoint), |
42 | 35 | quit: make(chan struct{}), |
43 | 36 | } |
44 | go p.loop(makeEndpoints(addrs, f, logger), md5) | |
45 | return p, nil | |
37 | ||
38 | instances, err := p.resolve() | |
39 | if err != nil { | |
40 | logger.Log(name, len(instances)) | |
41 | } else { | |
42 | logger.Log(name, err) | |
43 | } | |
44 | p.cache.Replace(instances) | |
45 | ||
46 | go p.loop() | |
47 | return p | |
46 | 48 | } |
47 | 49 | |
48 | 50 | // Stop terminates the publisher. |
50 | 52 | close(p.quit) |
51 | 53 | } |
52 | 54 | |
53 | func (p *Publisher) loop(m map[string]endpointCloser, md5 string) { | |
55 | func (p *Publisher) loop() { | |
54 | 56 | t := newTicker(p.ttl) |
55 | 57 | defer t.Stop() |
56 | 58 | for { |
57 | 59 | select { |
58 | case p.endpoints <- flatten(m): | |
60 | case p.endpoints <- p.cache.Endpoints(): | |
59 | 61 | |
60 | 62 | case <-t.C: |
61 | // TODO should we do this out-of-band? | |
62 | addrs, newmd5, err := resolve(p.name) | |
63 | instances, err := p.resolve() | |
63 | 64 | if err != nil { |
64 | p.logger.Log("name", p.name, "err", err) | |
65 | continue // don't replace probably-good endpoints with bad ones | |
65 | p.logger.Log(p.name, err) | |
66 | continue // don't replace potentially-good with bad | |
66 | 67 | } |
67 | if newmd5 == md5 { | |
68 | continue // optimization: no change | |
69 | } | |
70 | m = migrate(m, makeEndpoints(addrs, p.factory, p.logger)) | |
71 | md5 = newmd5 | |
68 | p.cache.Replace(instances) | |
72 | 69 | |
73 | 70 | case <-p.quit: |
74 | 71 | return |
91 | 88 | newTicker = time.NewTicker |
92 | 89 | ) |
93 | 90 | |
94 | func resolve(name string) (addrs []*net.SRV, md5sum string, err error) { | |
95 | _, addrs, err = lookupSRV("", "", name) | |
91 | func (p *Publisher) resolve() ([]string, error) { | |
92 | _, addrs, err := lookupSRV("", "", p.name) | |
96 | 93 | if err != nil { |
97 | return addrs, "", err | |
94 | return []string{}, err | |
98 | 95 | } |
99 | 96 | instances := make([]string, len(addrs)) |
100 | 97 | for i, addr := range addrs { |
101 | instances[i] = addr2instance(addr) | |
98 | instances[i] = net.JoinHostPort(addr.Target, fmt.Sprint(addr.Port)) | |
102 | 99 | } |
103 | sort.Sort(sort.StringSlice(instances)) | |
104 | h := md5.New() | |
105 | for _, instance := range instances { | |
106 | fmt.Fprintf(h, instance) | |
107 | } | |
108 | return addrs, fmt.Sprintf("%x", h.Sum(nil)), nil | |
100 | return instances, nil | |
109 | 101 | } |
110 | ||
111 | func makeEndpoints(addrs []*net.SRV, f loadbalancer.Factory, logger log.Logger) map[string]endpointCloser { | |
112 | m := make(map[string]endpointCloser, len(addrs)) | |
113 | for _, addr := range addrs { | |
114 | instance := addr2instance(addr) | |
115 | endpoint, closer, err := f(instance) | |
116 | if err != nil { | |
117 | logger.Log("instance", addr2instance(addr), "err", err) | |
118 | continue | |
119 | } | |
120 | m[instance] = endpointCloser{endpoint, closer} | |
121 | } | |
122 | return m | |
123 | } | |
124 | ||
125 | func migrate(prev, curr map[string]endpointCloser) map[string]endpointCloser { | |
126 | for instance, ec := range prev { | |
127 | if _, ok := curr[instance]; !ok { | |
128 | close(ec.Closer) | |
129 | } | |
130 | } | |
131 | return curr | |
132 | } | |
133 | ||
134 | func addr2instance(addr *net.SRV) string { | |
135 | return net.JoinHostPort(addr.Target, fmt.Sprint(addr.Port)) | |
136 | } | |
137 | ||
138 | func flatten(m map[string]endpointCloser) []endpoint.Endpoint { | |
139 | a := make([]endpoint.Endpoint, 0, len(m)) | |
140 | for _, ec := range m { | |
141 | a = append(a, ec.Endpoint) | |
142 | } | |
143 | return a | |
144 | } | |
145 | ||
146 | type endpointCloser struct { | |
147 | endpoint.Endpoint | |
148 | loadbalancer.Closer | |
149 | } |
15 | 15 | |
16 | 16 | func TestPublisher(t *testing.T) { |
17 | 17 | var ( |
18 | target = "my-target" | |
19 | port = uint16(1234) | |
20 | addr = &net.SRV{Target: target, Port: port} | |
21 | addrs = []*net.SRV{addr} | |
22 | name = "my-name" | |
23 | ttl = time.Second | |
24 | logger = log.NewNopLogger() | |
25 | e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
18 | name = "foo" | |
19 | ttl = time.Second | |
20 | e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
21 | c = make(chan struct{}) | |
22 | factory = func(string) (endpoint.Endpoint, loadbalancer.Closer, error) { return e, c, nil } | |
23 | logger = log.NewNopLogger() | |
26 | 24 | ) |
27 | 25 | |
28 | oldLookup := lookupSRV | |
29 | defer func() { lookupSRV = oldLookup }() | |
30 | lookupSRV = mockLookupSRV(addrs, nil, nil) | |
31 | ||
32 | factory := func(instance string) (endpoint.Endpoint, loadbalancer.Closer, error) { | |
33 | if want, have := addr2instance(addr), instance; want != have { | |
34 | t.Errorf("want %q, have %q", want, have) | |
35 | } | |
36 | return e, make(loadbalancer.Closer), nil | |
37 | } | |
38 | ||
39 | p, err := NewPublisher(name, ttl, factory, logger) | |
40 | if err != nil { | |
41 | t.Fatal(err) | |
42 | } | |
26 | p := NewPublisher(name, ttl, factory, logger) | |
43 | 27 | defer p.Stop() |
44 | 28 | |
45 | 29 | if _, err := p.Endpoints(); err != nil { |
55 | 39 | var ( |
56 | 40 | name = "some-name" |
57 | 41 | ttl = time.Second |
58 | factory = func(string) (endpoint.Endpoint, loadbalancer.Closer, error) { return nil, nil, errors.New("false") } | |
42 | e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
43 | c = make(chan struct{}) | |
44 | factory = func(string) (endpoint.Endpoint, loadbalancer.Closer, error) { return e, c, nil } | |
59 | 45 | logger = log.NewNopLogger() |
60 | 46 | ) |
61 | 47 | |
62 | if _, err := NewPublisher(name, ttl, factory, logger); err == nil { | |
63 | t.Fatal("wanted error, got none") | |
48 | p := NewPublisher(name, ttl, factory, logger) | |
49 | defer p.Stop() | |
50 | ||
51 | endpoints, err := p.Endpoints() | |
52 | if err != nil { | |
53 | t.Error(err) | |
54 | } | |
55 | if want, have := 0, len(endpoints); want != have { | |
56 | t.Errorf("want %d, have %d", want, have) | |
64 | 57 | } |
65 | 58 | } |
66 | 59 | |
78 | 71 | defer func() { lookupSRV = oldLookup }() |
79 | 72 | lookupSRV = mockLookupSRV(addrs, nil, nil) |
80 | 73 | |
81 | p, err := NewPublisher(name, ttl, factory, logger) | |
82 | if err != nil { | |
83 | t.Fatal(err) | |
84 | } | |
74 | p := NewPublisher(name, ttl, factory, logger) | |
85 | 75 | defer p.Stop() |
86 | 76 | |
87 | 77 | endpoints, err := p.Endpoints() |
88 | 78 | if err != nil { |
89 | t.Fatal(err) | |
79 | t.Error(err) | |
90 | 80 | } |
91 | 81 | if want, have := 0, len(endpoints); want != have { |
92 | 82 | t.Errorf("want %q, have %q", want, have) |
119 | 109 | defer func() { lookupSRV = oldLookup }() |
120 | 110 | lookupSRV = mockLookupSRV(addrs, nil, &resolves) |
121 | 111 | |
122 | p, err := NewPublisher(name, ttl, factory, logger) | |
123 | if err != nil { | |
124 | t.Fatal(err) | |
125 | } | |
112 | p := NewPublisher(name, ttl, factory, logger) | |
126 | 113 | defer p.Stop() |
127 | 114 | |
128 | 115 | tick <- time.Now() |
147 | 134 | defer func() { lookupSRV = oldLookup }() |
148 | 135 | lookupSRV = mockLookupSRV([]*net.SRV{}, nil, nil) |
149 | 136 | |
150 | p, err := NewPublisher(name, ttl, factory, logger) | |
151 | if err != nil { | |
152 | t.Fatal(err) | |
153 | } | |
137 | p := NewPublisher(name, ttl, factory, logger) | |
154 | 138 | |
155 | 139 | p.Stop() |
156 | 140 | _, have := p.Endpoints() |
6 | 6 | "github.com/go-kit/kit/log" |
7 | 7 | ) |
8 | 8 | |
9 | // EndpointCache caches resource-managed endpoints. Clients update the cache | |
10 | // by providing a current set of instance strings. The cache converts each | |
11 | // instance string to an endpoint and a closer via the factory function. | |
9 | // EndpointCache caches endpoints that need to be deallocated when they're no | |
10 | // longer useful. Clients update the cache by providing a current set of | |
11 | // instance strings. The cache converts each instance string to an endpoint | |
12 | // and a closer via the factory function. | |
12 | 13 | // |
13 | // Instance strings are assumed to be unique and used as keys. Endpoints that | |
14 | // were in the previous set of instances and removed from the current set are | |
15 | // considered invalid and closed. | |
14 | // Instance strings are assumed to be unique and are used as keys. Endpoints | |
15 | // that were in the previous set of instances and are not in the current set | |
16 | // are considered invalid and closed. | |
16 | 17 | // |
17 | 18 | // EndpointCache is designed to be used in your publisher implementation. |
18 | 19 | type EndpointCache struct { |
29 | 30 | return &EndpointCache{ |
30 | 31 | f: f, |
31 | 32 | m: map[string]endpointCloser{}, |
32 | logger: logger, | |
33 | logger: log.NewContext(logger).With("component", "Endpoint Cache"), | |
33 | 34 | } |
34 | 35 | } |
35 | 36 |
11 | 11 | |
12 | 12 | // NewPublisher returns a static endpoint Publisher. |
13 | 13 | func NewPublisher(instances []string, factory loadbalancer.Factory, logger log.Logger) Publisher { |
14 | logger = log.NewContext(logger).With("component", "Fixed Publisher") | |
14 | logger = log.NewContext(logger).With("component", "Static Publisher") | |
15 | 15 | endpoints := []endpoint.Endpoint{} |
16 | 16 | for _, instance := range instances { |
17 | 17 | e, _, err := factory(instance) // never close |