Codebase list golang-github-go-kit-kit / 9a19822 / sd / zk / util_test.go
9a19822

Tree @9a19822 (Download .tar.gz)

util_test.go @9a19822 · raw · history · blame

package zk

import (
	"errors"
	"fmt"
	"io"
	"sync"
	"time"

	"github.com/samuel/go-zookeeper/zk"
	"golang.org/x/net/context"

	"github.com/go-kit/kit/endpoint"
	"github.com/go-kit/kit/log"
	"github.com/go-kit/kit/sd"
)

// Shared fixtures for the tests in this package.
var (
	// path is the base ZooKeeper path the tests operate on.
	path = "/gokit.test/service.name"

	// e is a stub endpoint that always succeeds with an empty struct.
	e = func(_ context.Context, _ interface{}) (interface{}, error) {
		return struct{}{}, nil
	}

	// logger is a no-op logger so tests produce no log output.
	logger = log.NewNopLogger()
)

// fakeClient is an in-memory stand-in for a ZooKeeper client. Registered
// services are kept in responses, watch notifications are delivered on ch,
// and setting result to false arms a one-shot failure of GetEntries.
type fakeClient struct {
	// mtx guards responses and result.
	mtx       sync.Mutex
	// ch is the watch channel returned by GetEntries; AddService,
	// RemoveService and SendErrorOnWatch each push one event onto it.
	ch        chan zk.Event
	// responses maps a node name to the data stored for it.
	responses map[string]string
	// result == false makes the next GetEntries call fail; GetEntries then
	// sets it back to true.
	result    bool
}

// newFakeClient returns a fakeClient with no registered services, a watch
// channel that can hold one pending event, and no failure armed.
func newFakeClient() *fakeClient {
	c := &fakeClient{result: true}
	c.responses = map[string]string{}
	c.ch = make(chan zk.Event, 1)
	return c
}

// CreateParentNodes pretends to create the node hierarchy for path. It
// succeeds for every path except the sentinel "BadPath", which lets callers
// trigger a failure on demand.
func (*fakeClient) CreateParentNodes(path string) error {
	switch path {
	case "BadPath":
		return errors.New("dummy error")
	default:
		return nil
	}
}

// GetEntries returns the data of every registered service (in map iteration
// order) along with the watch channel. If a failure was armed by
// SendErrorOnWatch, it instead returns an empty non-nil slice, the watch
// channel and an error, and disarms the failure.
func (c *fakeClient) GetEntries(path string) ([]string, <-chan zk.Event, error) {
	c.mtx.Lock()
	defer c.mtx.Unlock()

	if !c.result {
		c.result = true // the injected failure is one-shot
		return []string{}, c.ch, errors.New("dummy error")
	}

	entries := make([]string, 0, len(c.responses))
	for node := range c.responses {
		entries = append(entries, c.responses[node])
	}
	return entries, c.ch, nil
}

// AddService stores data under node and then pushes a watch event onto ch.
// The send happens while mtx is held and blocks if ch's buffer is full.
func (c *fakeClient) AddService(node, data string) {
	c.mtx.Lock()
	defer c.mtx.Unlock()

	c.responses[node] = data

	var notification zk.Event
	c.ch <- notification
}

// RemoveService forgets node (a no-op if it is unknown) and then pushes a
// watch event onto ch. The send happens while mtx is held and blocks if ch's
// buffer is full.
func (c *fakeClient) RemoveService(node string) {
	c.mtx.Lock()
	defer c.mtx.Unlock()

	delete(c.responses, node)

	var notification zk.Event
	c.ch <- notification
}

// SendErrorOnWatch arms a one-shot failure for the next GetEntries call and
// then pushes a watch event onto ch. The send happens while mtx is held and
// blocks if ch's buffer is full.
func (c *fakeClient) SendErrorOnWatch() {
	c.mtx.Lock()
	defer c.mtx.Unlock()

	c.result = false // GetEntries consumes this and flips it back

	var notification zk.Event
	c.ch <- notification
}

// ErrorIsConsumedWithin waits up to timeout for the failure armed by
// SendErrorOnWatch to be consumed. GetEntries consumes it by returning the
// error and setting result back to true. ErrorIsConsumedWithin returns nil
// once result is true again, or an error if the timeout elapses first.
func (c *fakeClient) ErrorIsConsumedWithin(timeout time.Duration) error {
	deadline := time.After(timeout)

	// Poll at a modest interval instead of spinning, so the consumer
	// goroutine is not starved of mtx while we wait.
	interval := timeout / 10
	if interval <= 0 {
		interval = time.Millisecond
	}

	for {
		c.mtx.Lock()
		// result == true means no failure is pending any more, i.e. the
		// armed error has been handed out by GetEntries.
		consumed := c.result
		c.mtx.Unlock()
		if consumed {
			return nil
		}

		select {
		case <-deadline:
			return fmt.Errorf("expected error not consumed after timeout %s", timeout)
		case <-time.After(interval):
		}
	}
}

// Stop is a no-op; in particular it leaves ch open.
func (*fakeClient) Stop() {}

// newFactory returns an sd.Factory that yields endpoint.Nop for every
// instance except the one equal to fakeError. For that instance it returns
// an error whose message is fakeError. The io.Closer it returns is always nil.
func newFactory(fakeError string) sd.Factory {
	factory := func(instance string) (endpoint.Endpoint, io.Closer, error) {
		if instance != fakeError {
			return endpoint.Nop, nil, nil
		}
		return nil, nil, errors.New(fakeError)
	}
	return factory
}

// asyncTest polls s.Endpoints until it reports exactly want endpoints, it
// returns an error, or timeout elapses. It sleeps timeout/10 between polls.
// On timeout the returned error reports the last observed count.
func asyncTest(timeout time.Duration, want int, s *Subscriber) (err error) {
	deadline := time.After(timeout)
	have := -1 // sentinel: a length is never negative, so it cannot match want

	for {
		select {
		case <-deadline:
			return fmt.Errorf("want %d, have %d (timeout %s)", want, have, timeout.String())
		default:
		}

		var current []endpoint.Endpoint
		if current, err = s.Endpoints(); err != nil {
			return err
		}
		if have = len(current); have == want {
			return nil
		}

		time.Sleep(timeout / 10)
	}
}