diff --git a/loadbalancer/dnssrv/publisher.go b/loadbalancer/dnssrv/publisher.go index 77329fc..dc4b8c8 100644 --- a/loadbalancer/dnssrv/publisher.go +++ b/loadbalancer/dnssrv/publisher.go @@ -79,7 +79,12 @@ // Endpoints implements the Publisher interface. func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) { - return <-p.endpoints, nil + select { + case endpoints := <-p.endpoints: + return endpoints, nil + case <-p.quit: + return nil, loadbalancer.ErrPublisherStopped + } } var ( diff --git a/loadbalancer/dnssrv/publisher_internal_test.go b/loadbalancer/dnssrv/publisher_internal_test.go index 0613e44..49680f7 100644 --- a/loadbalancer/dnssrv/publisher_internal_test.go +++ b/loadbalancer/dnssrv/publisher_internal_test.go @@ -10,6 +10,7 @@ "golang.org/x/net/context" "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/loadbalancer" "github.com/go-kit/kit/log" ) @@ -135,6 +136,30 @@ t.Skip("TODO") } +func TestErrPublisherStopped(t *testing.T) { + var ( + name = "my-name" + ttl = time.Second + factory = func(string) (endpoint.Endpoint, error) { return nil, errors.New("kaboom") } + logger = log.NewNopLogger() + ) + + oldLookup := lookupSRV + defer func() { lookupSRV = oldLookup }() + lookupSRV = mockLookupSRV([]*net.SRV{}, nil, nil) + + p, err := NewPublisher(name, ttl, factory, logger) + if err != nil { + t.Fatal(err) + } + + p.Stop() + _, have := p.Endpoints() + if want := loadbalancer.ErrPublisherStopped; want != have { + t.Fatalf("want %v, have %v", want, have) + } +} + func mockLookupSRV(addrs []*net.SRV, err error, count *uint64) func(service, proto, name string) (string, []*net.SRV, error) { return func(service, proto, name string) (string, []*net.SRV, error) { if count != nil { diff --git a/loadbalancer/publisher.go b/loadbalancer/publisher.go index ec17d7e..d175da8 100644 --- a/loadbalancer/publisher.go +++ b/loadbalancer/publisher.go @@ -1,6 +1,10 @@ package loadbalancer -import "github.com/go-kit/kit/endpoint" +import ( + "errors" + + "github.com/go-kit/kit/endpoint" +) // Publisher describes something that provides a set of identical endpoints. // Different publisher implementations exist for different kinds of service @@ -8,3 +12,7 @@ type Publisher interface { Endpoints() ([]endpoint.Endpoint, error) } + +// ErrPublisherStopped is returned by publishers when the underlying +// implementation has been terminated and can no longer serve requests. +var ErrPublisherStopped = errors.New("publisher stopped")