Codebase list golang-github-go-kit-kit / 1eba920 loadbalancer / dnssrv / publisher_internal_test.go
1eba920

Tree @1eba920 (Download .tar.gz)

publisher_internal_test.go @1eba920 · raw · history · blame

package dnssrv

import (
	"errors"
	"net"
	"sync/atomic"
	"testing"
	"time"

	"golang.org/x/net/context"

	"github.com/go-kit/kit/endpoint"
	"github.com/go-kit/kit/loadbalancer"
	"github.com/go-kit/kit/log"
)

func TestPublisher(t *testing.T) {
	var (
		target = "my-target"
		port   = uint16(1234)
		addr   = &net.SRV{Target: target, Port: port}
		addrs  = []*net.SRV{addr}
		name   = "my-name"
		ttl    = time.Second
		logger = log.NewNopLogger()
		e      = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
	)

	oldLookup := lookupSRV
	defer func() { lookupSRV = oldLookup }()
	lookupSRV = mockLookupSRV(addrs, nil, nil)

	factory := func(instance string) (endpoint.Endpoint, error) {
		if want, have := addr2instance(addr), instance; want != have {
			t.Errorf("want %q, have %q", want, have)
		}
		return e, nil
	}

	p, err := NewPublisher(name, ttl, factory, logger)
	if err != nil {
		t.Fatal(err)
	}
	defer p.Stop()

	if _, err := p.Endpoints(); err != nil {
		t.Fatal(err)
	}
}

// TestBadLookup checks that NewPublisher returns an error when the SRV
// lookup it performs fails.
func TestBadLookup(t *testing.T) {
	saved := lookupSRV
	defer func() { lookupSRV = saved }()
	lookupSRV = mockLookupSRV([]*net.SRV{}, errors.New("kaboom"), nil)

	// The factory should never be invoked because construction fails first.
	unreachable := func(string) (endpoint.Endpoint, error) {
		return nil, errors.New("unreachable")
	}

	_, err := NewPublisher("some-name", time.Second, unreachable, log.NewNopLogger())
	if err == nil {
		t.Fatal("wanted error, got none")
	}
}

// TestBadFactory checks that a factory which fails for every instance does not
// make NewPublisher or Endpoints return an error; instead the publisher
// exposes zero endpoints.
func TestBadFactory(t *testing.T) {
	var (
		addr    = &net.SRV{Target: "foo", Port: 1234}
		addrs   = []*net.SRV{addr}
		name    = "some-name"
		ttl     = time.Second
		factory = func(string) (endpoint.Endpoint, error) { return nil, errors.New("kaboom") }
		logger  = log.NewNopLogger()
	)

	oldLookup := lookupSRV
	defer func() { lookupSRV = oldLookup }()
	lookupSRV = mockLookupSRV(addrs, nil, nil)

	p, err := NewPublisher(name, ttl, factory, logger)
	if err != nil {
		t.Fatal(err)
	}
	defer p.Stop()

	endpoints, err := p.Endpoints()
	if err != nil {
		t.Fatal(err)
	}
	// Lengths are ints: use %d (the previous %q rendered them as quoted rune
	// literals such as '\x00', making failure output unreadable).
	if want, have := 0, len(endpoints); want != have {
		t.Errorf("want %d, have %d", want, have)
	}
}

// TestRefreshWithChange is a placeholder that is currently always skipped.
// NOTE(review): presumably intended to cover a refresh whose lookup returns a
// different SRV set; not yet implemented.
func TestRefreshWithChange(t *testing.T) {
	const pending = "TODO"
	t.Skip(pending)
}

// TestRefreshNoChange checks that one tick of the publisher's refresh ticker
// causes one more SRV lookup even when the lookup result is unchanged: the
// count is expected to reach 2 (the lookup done while constructing the
// publisher, plus the one triggered by the tick).
func TestRefreshNoChange(t *testing.T) {
	var (
		tick    = make(chan time.Time)
		target  = "my-target"
		port    = uint16(5678)
		addr    = &net.SRV{Target: target, Port: port}
		addrs   = []*net.SRV{addr}
		name    = "my-name"
		ttl     = time.Second
		factory = func(string) (endpoint.Endpoint, error) { return nil, errors.New("kaboom") }
		logger  = log.NewNopLogger()
	)

	// Replace the ticker with one whose channel the test drives by hand.
	oldTicker := newTicker
	defer func() { newTicker = oldTicker }()
	newTicker = func(time.Duration) *time.Ticker { return &time.Ticker{C: tick} }

	// resolves is incremented atomically by the mock on every lookup.
	var resolves uint64
	oldLookup := lookupSRV
	defer func() { lookupSRV = oldLookup }()
	lookupSRV = mockLookupSRV(addrs, nil, &resolves)

	p, err := NewPublisher(name, ttl, factory, logger)
	if err != nil {
		t.Fatal(err)
	}
	defer p.Stop()

	tick <- time.Now()

	// Completing the unbuffered send only proves the refresh goroutine received
	// the tick, not that its lookup has run, so wait (bounded) for the counter.
	// All reads are atomic because the mock writes from another goroutine.
	want := uint64(2)
	deadline := time.Now().Add(time.Second)
	for atomic.LoadUint64(&resolves) < want && time.Now().Before(deadline) {
		time.Sleep(time.Millisecond)
	}
	if have := atomic.LoadUint64(&resolves); want != have {
		t.Errorf("want %d, have %d", want, have)
	}
}

// TestRefreshResolveError is a placeholder that is currently always skipped.
// NOTE(review): presumably intended to cover a refresh whose lookup fails;
// not yet implemented.
func TestRefreshResolveError(t *testing.T) {
	const pending = "TODO"
	t.Skip(pending)
}

func TestErrPublisherStopped(t *testing.T) {
	var (
		name    = "my-name"
		ttl     = time.Second
		factory = func(string) (endpoint.Endpoint, error) { return nil, errors.New("kaboom") }
		logger  = log.NewNopLogger()
	)

	oldLookup := lookupSRV
	defer func() { lookupSRV = oldLookup }()
	lookupSRV = mockLookupSRV([]*net.SRV{}, nil, nil)

	p, err := NewPublisher(name, ttl, factory, logger)
	if err != nil {
		t.Fatal(err)
	}

	p.Stop()
	_, have := p.Endpoints()
	if want := loadbalancer.ErrPublisherStopped; want != have {
		t.Fatalf("want %v, have %v", want, have)
	}
}

// mockLookupSRV returns a stand-in for lookupSRV that ignores its arguments
// and always yields an empty canonical name, addrs, and err. When count is
// non-nil, every call atomically increments *count, so the counter may be
// shared with other goroutines as long as they read it atomically.
func mockLookupSRV(addrs []*net.SRV, err error, count *uint64) func(service, proto, name string) (string, []*net.SRV, error) {
	lookup := func(_, _, _ string) (string, []*net.SRV, error) {
		if count == nil {
			return "", addrs, err
		}
		atomic.AddUint64(count, 1)
		return "", addrs, err
	}
	return lookup
}