Codebase list golang-github-go-kit-kit / fc19884
loadbalancer/dnssrv: fix racy tests Peter Bourgon 8 years ago
3 changed file(s) with 162 addition(s) and 147 deletion(s). Raw diff Collapse all Expand all
1313 // resolved on a fixed schedule. Priorities and weights are ignored.
1414 type Publisher struct {
1515 name string
16 ttl time.Duration
1716 cache *loadbalancer.EndpointCache
1817 logger log.Logger
1918 quit chan struct{}
2423 // constructor will return an error. The factory is used to convert a
2524 // host:port to a usable endpoint. The logger is used to report DNS and
2625 // factory errors.
27 func NewPublisher(name string, ttl time.Duration, factory loadbalancer.Factory, logger log.Logger) *Publisher {
26 func NewPublisher(
27 name string,
28 ttl time.Duration,
29 factory loadbalancer.Factory,
30 logger log.Logger,
31 ) *Publisher {
32 return NewPublisherDetailed(name, time.NewTicker(ttl), net.LookupSRV, factory, logger)
33 }
34
35 // NewPublisherDetailed is the same as NewPublisher, but allows users to provide
36 // an explicit lookup refresh ticker instead of a TTL, and specify the function
37 // used to perform lookups instead of using net.LookupSRV.
38 func NewPublisherDetailed(
39 name string,
40 refreshTicker *time.Ticker,
41 lookupSRV func(service, proto, name string) (cname string, addrs []*net.SRV, err error),
42 factory loadbalancer.Factory,
43 logger log.Logger,
44 ) *Publisher {
2845 p := &Publisher{
2946 name: name,
30 ttl: ttl,
3147 cache: loadbalancer.NewEndpointCache(factory, logger),
3248 logger: logger,
3349 quit: make(chan struct{}),
3450 }
3551
36 instances, err := p.resolve()
52 instances, err := p.resolve(lookupSRV)
3753 if err == nil {
3854 logger.Log("name", name, "instances", len(instances))
3955 } else {
4157 }
4258 p.cache.Replace(instances)
4359
44 go p.loop()
60 go p.loop(refreshTicker, lookupSRV)
4561 return p
4662 }
4763
5066 close(p.quit)
5167 }
5268
53 func (p *Publisher) loop() {
54 t := newTicker(p.ttl)
55 defer t.Stop()
69 func (p *Publisher) loop(
70 refreshTicker *time.Ticker,
71 lookupSRV func(service, proto, name string) (cname string, addrs []*net.SRV, err error),
72 ) {
73 defer refreshTicker.Stop()
5674 for {
5775 select {
58 case <-t.C:
59 instances, err := p.resolve()
76 case <-refreshTicker.C:
77 instances, err := p.resolve(lookupSRV)
6078 if err != nil {
6179 p.logger.Log(p.name, err)
6280 continue // don't replace potentially-good with bad
7492 return p.cache.Endpoints()
7593 }
7694
77 var (
78 lookupSRV = net.LookupSRV
79 newTicker = time.NewTicker
80 )
81
82 func (p *Publisher) resolve() ([]string, error) {
95 func (p *Publisher) resolve(lookupSRV func(service, proto, name string) (cname string, addrs []*net.SRV, err error)) ([]string, error) {
8396 _, addrs, err := lookupSRV("", "", p.name)
8497 if err != nil {
8598 return []string{}, err
+0
-131
loadbalancer/dnssrv/publisher_internal_test.go less more
0 package dnssrv
1
2 import (
3 "errors"
4 "io"
5 "net"
6 "sync/atomic"
7 "testing"
8 "time"
9
10 "golang.org/x/net/context"
11
12 "github.com/go-kit/kit/endpoint"
13 "github.com/go-kit/kit/log"
14 )
15
16 func TestPublisher(t *testing.T) {
17 var (
18 name = "foo"
19 ttl = time.Second
20 e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
21 factory = func(string) (endpoint.Endpoint, io.Closer, error) { return e, nil, nil }
22 logger = log.NewNopLogger()
23 )
24
25 p := NewPublisher(name, ttl, factory, logger)
26 defer p.Stop()
27
28 if _, err := p.Endpoints(); err != nil {
29 t.Fatal(err)
30 }
31 }
32
33 func TestBadLookup(t *testing.T) {
34 oldLookup := lookupSRV
35 defer func() { lookupSRV = oldLookup }()
36 lookupSRV = mockLookupSRV([]*net.SRV{}, errors.New("kaboom"), nil)
37
38 var (
39 name = "some-name"
40 ttl = time.Second
41 e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
42 factory = func(string) (endpoint.Endpoint, io.Closer, error) { return e, nil, nil }
43 logger = log.NewNopLogger()
44 )
45
46 p := NewPublisher(name, ttl, factory, logger)
47 defer p.Stop()
48
49 endpoints, err := p.Endpoints()
50 if err != nil {
51 t.Error(err)
52 }
53 if want, have := 0, len(endpoints); want != have {
54 t.Errorf("want %d, have %d", want, have)
55 }
56 }
57
58 func TestBadFactory(t *testing.T) {
59 var (
60 addr = &net.SRV{Target: "foo", Port: 1234}
61 addrs = []*net.SRV{addr}
62 name = "some-name"
63 ttl = time.Second
64 factory = func(string) (endpoint.Endpoint, io.Closer, error) { return nil, nil, errors.New("kaboom") }
65 logger = log.NewNopLogger()
66 )
67
68 oldLookup := lookupSRV
69 defer func() { lookupSRV = oldLookup }()
70 lookupSRV = mockLookupSRV(addrs, nil, nil)
71
72 p := NewPublisher(name, ttl, factory, logger)
73 defer p.Stop()
74
75 endpoints, err := p.Endpoints()
76 if err != nil {
77 t.Error(err)
78 }
79 if want, have := 0, len(endpoints); want != have {
80 t.Errorf("want %q, have %q", want, have)
81 }
82 }
83
84 func TestRefreshWithChange(t *testing.T) {
85 t.Skip("TODO")
86 }
87
88 func TestRefreshNoChange(t *testing.T) {
89 var (
90 tick = make(chan time.Time)
91 target = "my-target"
92 port = uint16(5678)
93 addr = &net.SRV{Target: target, Port: port}
94 addrs = []*net.SRV{addr}
95 name = "my-name"
96 ttl = time.Second
97 factory = func(string) (endpoint.Endpoint, io.Closer, error) { return nil, nil, errors.New("kaboom") }
98 logger = log.NewNopLogger()
99 )
100
101 oldTicker := newTicker
102 defer func() { newTicker = oldTicker }()
103 newTicker = func(time.Duration) *time.Ticker { return &time.Ticker{C: tick} }
104
105 var resolves uint64
106 oldLookup := lookupSRV
107 defer func() { lookupSRV = oldLookup }()
108 lookupSRV = mockLookupSRV(addrs, nil, &resolves)
109
110 p := NewPublisher(name, ttl, factory, logger)
111 defer p.Stop()
112
113 tick <- time.Now()
114 if want, have := uint64(2), resolves; want != have {
115 t.Errorf("want %d, have %d", want, have)
116 }
117 }
118
119 func TestRefreshResolveError(t *testing.T) {
120 t.Skip("TODO")
121 }
122
123 func mockLookupSRV(addrs []*net.SRV, err error, count *uint64) func(service, proto, name string) (string, []*net.SRV, error) {
124 return func(service, proto, name string) (string, []*net.SRV, error) {
125 if count != nil {
126 atomic.AddUint64(count, 1)
127 }
128 return "", addrs, err
129 }
130 }
0 package dnssrv
1
2 import (
3 "errors"
4 "io"
5 "net"
6 "sync/atomic"
7 "testing"
8 "time"
9
10 "golang.org/x/net/context"
11
12 "github.com/go-kit/kit/endpoint"
13 "github.com/go-kit/kit/log"
14 )
15
16 func TestPublisher(t *testing.T) {
17 var (
18 name = "foo"
19 ttl = time.Second
20 e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
21 factory = func(string) (endpoint.Endpoint, io.Closer, error) { return e, nil, nil }
22 logger = log.NewNopLogger()
23 )
24
25 p := NewPublisher(name, ttl, factory, logger)
26 defer p.Stop()
27
28 if _, err := p.Endpoints(); err != nil {
29 t.Fatal(err)
30 }
31 }
32
33 func TestBadLookup(t *testing.T) {
34 var (
35 name = "some-name"
36 ticker = time.NewTicker(time.Second)
37 lookups = uint32(0)
38 lookupSRV = func(string, string, string) (string, []*net.SRV, error) {
39 atomic.AddUint32(&lookups, 1)
40 return "", nil, errors.New("kaboom")
41 }
42 e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
43 factory = func(string) (endpoint.Endpoint, io.Closer, error) { return e, nil, nil }
44 logger = log.NewNopLogger()
45 )
46
47 p := NewPublisherDetailed(name, ticker, lookupSRV, factory, logger)
48 defer p.Stop()
49
50 endpoints, err := p.Endpoints()
51 if err != nil {
52 t.Error(err)
53 }
54 if want, have := 0, len(endpoints); want != have {
55 t.Errorf("want %d, have %d", want, have)
56 }
57 if want, have := uint32(1), atomic.LoadUint32(&lookups); want != have {
58 t.Errorf("want %d, have %d", want, have)
59 }
60 }
61
62 func TestBadFactory(t *testing.T) {
63 var (
64 name = "some-name"
65 ticker = time.NewTicker(time.Second)
66 addr = &net.SRV{Target: "foo", Port: 1234}
67 addrs = []*net.SRV{addr}
68 lookupSRV = func(a, b, c string) (string, []*net.SRV, error) { return "", addrs, nil }
69 creates = uint32(0)
70 factory = func(s string) (endpoint.Endpoint, io.Closer, error) {
71 atomic.AddUint32(&creates, 1)
72 return nil, nil, errors.New("kaboom")
73 }
74 logger = log.NewNopLogger()
75 )
76
77 p := NewPublisherDetailed(name, ticker, lookupSRV, factory, logger)
78 defer p.Stop()
79
80 endpoints, err := p.Endpoints()
81 if err != nil {
82 t.Error(err)
83 }
84 if want, have := 0, len(endpoints); want != have {
85 t.Errorf("want %q, have %q", want, have)
86 }
87 if want, have := uint32(1), atomic.LoadUint32(&creates); want != have {
88 t.Errorf("want %d, have %d", want, have)
89 }
90 }
91
92 func TestRefreshWithChange(t *testing.T) {
93 t.Skip("TODO")
94 }
95
96 func TestRefreshNoChange(t *testing.T) {
97 var (
98 addr = &net.SRV{Target: "my-target", Port: 5678}
99 addrs = []*net.SRV{addr}
100 name = "my-name"
101 ticker = time.NewTicker(time.Second)
102 lookups = uint32(0)
103 lookupSRV = func(string, string, string) (string, []*net.SRV, error) {
104 atomic.AddUint32(&lookups, 1)
105 return "", addrs, nil
106 }
107 e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
108 factory = func(string) (endpoint.Endpoint, io.Closer, error) { return e, nil, nil }
109 logger = log.NewNopLogger()
110 )
111
112 ticker.Stop()
113 tickc := make(chan time.Time)
114 ticker.C = tickc
115
116 p := NewPublisherDetailed(name, ticker, lookupSRV, factory, logger)
117 defer p.Stop()
118
119 if want, have := uint32(1), atomic.LoadUint32(&lookups); want != have {
120 t.Errorf("want %d, have %d", want, have)
121 }
122
123 tickc <- time.Now()
124
125 if want, have := uint32(2), atomic.LoadUint32(&lookups); want != have {
126 t.Errorf("want %d, have %d", want, have)
127 }
128 }
129
130 func TestRefreshResolveError(t *testing.T) {
131 t.Skip("TODO")
132 }