loadbalancer/dnssrv: fix racy tests
Peter Bourgon
8 years ago
13 | 13 | // resolved on a fixed schedule. Priorities and weights are ignored. |
14 | 14 | type Publisher struct { |
15 | 15 | name string |
16 | ttl time.Duration | |
17 | 16 | cache *loadbalancer.EndpointCache |
18 | 17 | logger log.Logger |
19 | 18 | quit chan struct{} |
24 | 23 | // constructor will return an error. The factory is used to convert a |
25 | 24 | // host:port to a usable endpoint. The logger is used to report DNS and |
26 | 25 | // factory errors. |
27 | func NewPublisher(name string, ttl time.Duration, factory loadbalancer.Factory, logger log.Logger) *Publisher { | |
26 | func NewPublisher( | |
27 | name string, | |
28 | ttl time.Duration, | |
29 | factory loadbalancer.Factory, | |
30 | logger log.Logger, | |
31 | ) *Publisher { | |
32 | return NewPublisherDetailed(name, time.NewTicker(ttl), net.LookupSRV, factory, logger) | |
33 | } | |
34 | ||
35 | // NewPublisherDetailed is the same as NewPublisher, but allows users to provide | |
36 | // an explicit lookup refresh ticker instead of a TTL, and specify the function | |
37 | // used to perform lookups instead of using net.LookupSRV. | |
38 | func NewPublisherDetailed( | |
39 | name string, | |
40 | refreshTicker *time.Ticker, | |
41 | lookupSRV func(service, proto, name string) (cname string, addrs []*net.SRV, err error), | |
42 | factory loadbalancer.Factory, | |
43 | logger log.Logger, | |
44 | ) *Publisher { | |
28 | 45 | p := &Publisher{ |
29 | 46 | name: name, |
30 | ttl: ttl, | |
31 | 47 | cache: loadbalancer.NewEndpointCache(factory, logger), |
32 | 48 | logger: logger, |
33 | 49 | quit: make(chan struct{}), |
34 | 50 | } |
35 | 51 | |
36 | instances, err := p.resolve() | |
52 | instances, err := p.resolve(lookupSRV) | |
37 | 53 | if err == nil { |
38 | 54 | logger.Log("name", name, "instances", len(instances)) |
39 | 55 | } else { |
41 | 57 | } |
42 | 58 | p.cache.Replace(instances) |
43 | 59 | |
44 | go p.loop() | |
60 | go p.loop(refreshTicker, lookupSRV) | |
45 | 61 | return p |
46 | 62 | } |
47 | 63 | |
50 | 66 | close(p.quit) |
51 | 67 | } |
52 | 68 | |
53 | func (p *Publisher) loop() { | |
54 | t := newTicker(p.ttl) | |
55 | defer t.Stop() | |
69 | func (p *Publisher) loop( | |
70 | refreshTicker *time.Ticker, | |
71 | lookupSRV func(service, proto, name string) (cname string, addrs []*net.SRV, err error), | |
72 | ) { | |
73 | defer refreshTicker.Stop() | |
56 | 74 | for { |
57 | 75 | select { |
58 | case <-t.C: | |
59 | instances, err := p.resolve() | |
76 | case <-refreshTicker.C: | |
77 | instances, err := p.resolve(lookupSRV) | |
60 | 78 | if err != nil { |
61 | 79 | p.logger.Log(p.name, err) |
62 | 80 | continue // don't replace potentially-good with bad |
74 | 92 | return p.cache.Endpoints() |
75 | 93 | } |
76 | 94 | |
77 | var ( | |
78 | lookupSRV = net.LookupSRV | |
79 | newTicker = time.NewTicker | |
80 | ) | |
81 | ||
82 | func (p *Publisher) resolve() ([]string, error) { | |
95 | func (p *Publisher) resolve(lookupSRV func(service, proto, name string) (cname string, addrs []*net.SRV, err error)) ([]string, error) { | |
83 | 96 | _, addrs, err := lookupSRV("", "", p.name) |
84 | 97 | if err != nil { |
85 | 98 | return []string{}, err |
0 | package dnssrv | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "io" | |
5 | "net" | |
6 | "sync/atomic" | |
7 | "testing" | |
8 | "time" | |
9 | ||
10 | "golang.org/x/net/context" | |
11 | ||
12 | "github.com/go-kit/kit/endpoint" | |
13 | "github.com/go-kit/kit/log" | |
14 | ) | |
15 | ||
16 | func TestPublisher(t *testing.T) { | |
17 | var ( | |
18 | name = "foo" | |
19 | ttl = time.Second | |
20 | e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
21 | factory = func(string) (endpoint.Endpoint, io.Closer, error) { return e, nil, nil } | |
22 | logger = log.NewNopLogger() | |
23 | ) | |
24 | ||
25 | p := NewPublisher(name, ttl, factory, logger) | |
26 | defer p.Stop() | |
27 | ||
28 | if _, err := p.Endpoints(); err != nil { | |
29 | t.Fatal(err) | |
30 | } | |
31 | } | |
32 | ||
33 | func TestBadLookup(t *testing.T) { | |
34 | oldLookup := lookupSRV | |
35 | defer func() { lookupSRV = oldLookup }() | |
36 | lookupSRV = mockLookupSRV([]*net.SRV{}, errors.New("kaboom"), nil) | |
37 | ||
38 | var ( | |
39 | name = "some-name" | |
40 | ttl = time.Second | |
41 | e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
42 | factory = func(string) (endpoint.Endpoint, io.Closer, error) { return e, nil, nil } | |
43 | logger = log.NewNopLogger() | |
44 | ) | |
45 | ||
46 | p := NewPublisher(name, ttl, factory, logger) | |
47 | defer p.Stop() | |
48 | ||
49 | endpoints, err := p.Endpoints() | |
50 | if err != nil { | |
51 | t.Error(err) | |
52 | } | |
53 | if want, have := 0, len(endpoints); want != have { | |
54 | t.Errorf("want %d, have %d", want, have) | |
55 | } | |
56 | } | |
57 | ||
58 | func TestBadFactory(t *testing.T) { | |
59 | var ( | |
60 | addr = &net.SRV{Target: "foo", Port: 1234} | |
61 | addrs = []*net.SRV{addr} | |
62 | name = "some-name" | |
63 | ttl = time.Second | |
64 | factory = func(string) (endpoint.Endpoint, io.Closer, error) { return nil, nil, errors.New("kaboom") } | |
65 | logger = log.NewNopLogger() | |
66 | ) | |
67 | ||
68 | oldLookup := lookupSRV | |
69 | defer func() { lookupSRV = oldLookup }() | |
70 | lookupSRV = mockLookupSRV(addrs, nil, nil) | |
71 | ||
72 | p := NewPublisher(name, ttl, factory, logger) | |
73 | defer p.Stop() | |
74 | ||
75 | endpoints, err := p.Endpoints() | |
76 | if err != nil { | |
77 | t.Error(err) | |
78 | } | |
79 | if want, have := 0, len(endpoints); want != have { | |
80 | t.Errorf("want %q, have %q", want, have) | |
81 | } | |
82 | } | |
83 | ||
84 | func TestRefreshWithChange(t *testing.T) { | |
85 | t.Skip("TODO") | |
86 | } | |
87 | ||
88 | func TestRefreshNoChange(t *testing.T) { | |
89 | var ( | |
90 | tick = make(chan time.Time) | |
91 | target = "my-target" | |
92 | port = uint16(5678) | |
93 | addr = &net.SRV{Target: target, Port: port} | |
94 | addrs = []*net.SRV{addr} | |
95 | name = "my-name" | |
96 | ttl = time.Second | |
97 | factory = func(string) (endpoint.Endpoint, io.Closer, error) { return nil, nil, errors.New("kaboom") } | |
98 | logger = log.NewNopLogger() | |
99 | ) | |
100 | ||
101 | oldTicker := newTicker | |
102 | defer func() { newTicker = oldTicker }() | |
103 | newTicker = func(time.Duration) *time.Ticker { return &time.Ticker{C: tick} } | |
104 | ||
105 | var resolves uint64 | |
106 | oldLookup := lookupSRV | |
107 | defer func() { lookupSRV = oldLookup }() | |
108 | lookupSRV = mockLookupSRV(addrs, nil, &resolves) | |
109 | ||
110 | p := NewPublisher(name, ttl, factory, logger) | |
111 | defer p.Stop() | |
112 | ||
113 | tick <- time.Now() | |
114 | if want, have := uint64(2), resolves; want != have { | |
115 | t.Errorf("want %d, have %d", want, have) | |
116 | } | |
117 | } | |
118 | ||
119 | func TestRefreshResolveError(t *testing.T) { | |
120 | t.Skip("TODO") | |
121 | } | |
122 | ||
123 | func mockLookupSRV(addrs []*net.SRV, err error, count *uint64) func(service, proto, name string) (string, []*net.SRV, error) { | |
124 | return func(service, proto, name string) (string, []*net.SRV, error) { | |
125 | if count != nil { | |
126 | atomic.AddUint64(count, 1) | |
127 | } | |
128 | return "", addrs, err | |
129 | } | |
130 | } |
0 | package dnssrv | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "io" | |
5 | "net" | |
6 | "sync/atomic" | |
7 | "testing" | |
8 | "time" | |
9 | ||
10 | "golang.org/x/net/context" | |
11 | ||
12 | "github.com/go-kit/kit/endpoint" | |
13 | "github.com/go-kit/kit/log" | |
14 | ) | |
15 | ||
16 | func TestPublisher(t *testing.T) { | |
17 | var ( | |
18 | name = "foo" | |
19 | ttl = time.Second | |
20 | e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
21 | factory = func(string) (endpoint.Endpoint, io.Closer, error) { return e, nil, nil } | |
22 | logger = log.NewNopLogger() | |
23 | ) | |
24 | ||
25 | p := NewPublisher(name, ttl, factory, logger) | |
26 | defer p.Stop() | |
27 | ||
28 | if _, err := p.Endpoints(); err != nil { | |
29 | t.Fatal(err) | |
30 | } | |
31 | } | |
32 | ||
33 | func TestBadLookup(t *testing.T) { | |
34 | var ( | |
35 | name = "some-name" | |
36 | ticker = time.NewTicker(time.Second) | |
37 | lookups = uint32(0) | |
38 | lookupSRV = func(string, string, string) (string, []*net.SRV, error) { | |
39 | atomic.AddUint32(&lookups, 1) | |
40 | return "", nil, errors.New("kaboom") | |
41 | } | |
42 | e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
43 | factory = func(string) (endpoint.Endpoint, io.Closer, error) { return e, nil, nil } | |
44 | logger = log.NewNopLogger() | |
45 | ) | |
46 | ||
47 | p := NewPublisherDetailed(name, ticker, lookupSRV, factory, logger) | |
48 | defer p.Stop() | |
49 | ||
50 | endpoints, err := p.Endpoints() | |
51 | if err != nil { | |
52 | t.Error(err) | |
53 | } | |
54 | if want, have := 0, len(endpoints); want != have { | |
55 | t.Errorf("want %d, have %d", want, have) | |
56 | } | |
57 | if want, have := uint32(1), atomic.LoadUint32(&lookups); want != have { | |
58 | t.Errorf("want %d, have %d", want, have) | |
59 | } | |
60 | } | |
61 | ||
62 | func TestBadFactory(t *testing.T) { | |
63 | var ( | |
64 | name = "some-name" | |
65 | ticker = time.NewTicker(time.Second) | |
66 | addr = &net.SRV{Target: "foo", Port: 1234} | |
67 | addrs = []*net.SRV{addr} | |
68 | lookupSRV = func(a, b, c string) (string, []*net.SRV, error) { return "", addrs, nil } | |
69 | creates = uint32(0) | |
70 | factory = func(s string) (endpoint.Endpoint, io.Closer, error) { | |
71 | atomic.AddUint32(&creates, 1) | |
72 | return nil, nil, errors.New("kaboom") | |
73 | } | |
74 | logger = log.NewNopLogger() | |
75 | ) | |
76 | ||
77 | p := NewPublisherDetailed(name, ticker, lookupSRV, factory, logger) | |
78 | defer p.Stop() | |
79 | ||
80 | endpoints, err := p.Endpoints() | |
81 | if err != nil { | |
82 | t.Error(err) | |
83 | } | |
84 | if want, have := 0, len(endpoints); want != have { | |
85 | t.Errorf("want %q, have %q", want, have) | |
86 | } | |
87 | if want, have := uint32(1), atomic.LoadUint32(&creates); want != have { | |
88 | t.Errorf("want %d, have %d", want, have) | |
89 | } | |
90 | } | |
91 | ||
92 | func TestRefreshWithChange(t *testing.T) { | |
93 | t.Skip("TODO") | |
94 | } | |
95 | ||
96 | func TestRefreshNoChange(t *testing.T) { | |
97 | var ( | |
98 | addr = &net.SRV{Target: "my-target", Port: 5678} | |
99 | addrs = []*net.SRV{addr} | |
100 | name = "my-name" | |
101 | ticker = time.NewTicker(time.Second) | |
102 | lookups = uint32(0) | |
103 | lookupSRV = func(string, string, string) (string, []*net.SRV, error) { | |
104 | atomic.AddUint32(&lookups, 1) | |
105 | return "", addrs, nil | |
106 | } | |
107 | e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } | |
108 | factory = func(string) (endpoint.Endpoint, io.Closer, error) { return e, nil, nil } | |
109 | logger = log.NewNopLogger() | |
110 | ) | |
111 | ||
112 | ticker.Stop() | |
113 | tickc := make(chan time.Time) | |
114 | ticker.C = tickc | |
115 | ||
116 | p := NewPublisherDetailed(name, ticker, lookupSRV, factory, logger) | |
117 | defer p.Stop() | |
118 | ||
119 | if want, have := uint32(1), atomic.LoadUint32(&lookups); want != have { | |
120 | t.Errorf("want %d, have %d", want, have) | |
121 | } | |
122 | ||
123 | tickc <- time.Now() | |
124 | ||
125 | if want, have := uint32(2), atomic.LoadUint32(&lookups); want != have { | |
126 | t.Errorf("want %d, have %d", want, have) | |
127 | } | |
128 | } | |
129 | ||
130 | func TestRefreshResolveError(t *testing.T) { | |
131 | t.Skip("TODO") | |
132 | } |