diff --git a/transport/amqp/doc.go b/transport/amqp/doc.go new file mode 100644 index 0000000..0dd4d3d --- /dev/null +++ b/transport/amqp/doc.go @@ -0,0 +1,2 @@ +// Package amqp implements an AMQP transport. +package amqp diff --git a/transport/amqp/encode-decode.go b/transport/amqp/encode-decode.go new file mode 100644 index 0000000..3047318 --- /dev/null +++ b/transport/amqp/encode-decode.go @@ -0,0 +1,22 @@ +package amqp + +import ( + "context" + "github.com/streadway/amqp" +) + +// DecodeRequestFunc extracts a user-domain request object from +// an AMQP Delivery object. It is designed to be used in AMQP Subscribers. +type DecodeRequestFunc func(context.Context, *amqp.Delivery) (request interface{}, err error) + +// EncodeRequestFunc encodes the passed request object into +// an AMQP Publishing object. It is designed to be used in AMQP Publishers. +type EncodeRequestFunc func(context.Context, *amqp.Publishing, interface{}) error + +// EncodeResponseFunc encodes the passed reponse object to +// an AMQP Publishing object. It is designed to be used in AMQP Subscribers. +type EncodeResponseFunc func(context.Context, *amqp.Publishing, interface{}) error + +// DecodeResponseFunc extracts a user-domain response object from +// an AMQP Delivery object. It is designed to be used in AMQP Publishers. +type DecodeResponseFunc func(context.Context, *amqp.Delivery) (response interface{}, err error) diff --git a/transport/amqp/publisher.go b/transport/amqp/publisher.go new file mode 100644 index 0000000..a28ee94 --- /dev/null +++ b/transport/amqp/publisher.go @@ -0,0 +1,152 @@ +package amqp + +import ( + "context" + "time" + + "github.com/go-kit/kit/endpoint" + "github.com/streadway/amqp" +) + +// The golang AMQP implementation requires the []byte representation of +// correlation id strings to have a maximum length of 255 bytes. +const maxCorrelationIdLength = 255 + +// Publisher wraps an AMQP channel and queue, and provides a method that +// implements endpoint.Endpoint. +type Publisher struct { + ch Channel + q *amqp.Queue + enc EncodeRequestFunc + dec DecodeResponseFunc + before []RequestFunc + after []PublisherResponseFunc + timeout time.Duration +} + +// NewPublisher constructs a usable Publisher for a single remote method. +func NewPublisher( + ch Channel, + q *amqp.Queue, + enc EncodeRequestFunc, + dec DecodeResponseFunc, + options ...PublisherOption, +) *Publisher { + p := &Publisher{ + ch: ch, + q: q, + enc: enc, + dec: dec, + timeout: 10 * time.Second, + } + for _, option := range options { + option(p) + } + return p +} + +// PublisherOption sets an optional parameter for clients. +type PublisherOption func(*Publisher) + +// PublisherBefore sets the RequestFuncs that are applied to the outgoing AMQP +// request before it's invoked. +func PublisherBefore(before ...RequestFunc) PublisherOption { + return func(p *Publisher) { p.before = append(p.before, before...) } +} + +// PublisherAfter sets the ClientResponseFuncs applied to the incoming AMQP +// request prior to it being decoded. This is useful for obtaining anything off +// of the response and adding onto the context prior to decoding. +func PublisherAfter(after ...PublisherResponseFunc) PublisherOption { + return func(p *Publisher) { p.after = append(p.after, after...) } +} + +// PublisherTimeout sets the available timeout for an AMQP request. +func PublisherTimeout(timeout time.Duration) PublisherOption { + return func(p *Publisher) { p.timeout = timeout } +} + +// Endpoint returns a usable endpoint that invokes the remote endpoint. +func (p Publisher) Endpoint() endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + ctx, cancel := context.WithTimeout(ctx, p.timeout) + defer cancel() + + pub := amqp.Publishing{ + ReplyTo: p.q.Name, + CorrelationId: randomString(randInt(5, maxCorrelationIdLength)), + } + + if err := p.enc(ctx, &pub, request); err != nil { + return nil, err + } + + for _, f := range p.before { + ctx = f(ctx, &pub) + } + + deliv, err := p.publishAndConsumeFirstMatchingResponse(ctx, &pub) + if err != nil { + return nil, err + } + + for _, f := range p.after { + ctx = f(ctx, deliv) + } + response, err := p.dec(ctx, deliv) + if err != nil { + return nil, err + } + + return response, nil + } +} + +// publishAndConsumeFirstMatchingResponse publishes the specified Publishing +// and returns the first Delivery object with the matching correlationId. +// If the context times out while waiting for a reply, an error will be returned. +func (p Publisher) publishAndConsumeFirstMatchingResponse( + ctx context.Context, + pub *amqp.Publishing, +) (*amqp.Delivery, error) { + err := p.ch.Publish( + getPublishExchange(ctx), + getPublishKey(ctx), + false, //mandatory + false, //immediate + *pub, + ) + if err != nil { + return nil, err + } + autoAck := getConsumeAutoAck(ctx) + + msg, err := p.ch.Consume( + p.q.Name, + "", //consumer + autoAck, + false, //exclusive + false, //noLocal + false, //noWait + getConsumeArgs(ctx), + ) + if err != nil { + return nil, err + } + + for { + select { + case d := <-msg: + if d.CorrelationId == pub.CorrelationId { + if !autoAck { + d.Ack(false) //multiple + } + return &d, nil + } + + case <-ctx.Done(): + return nil, ctx.Err() + } + } + +} diff --git a/transport/amqp/publisher_test.go b/transport/amqp/publisher_test.go new file mode 100644 index 0000000..5b6785c --- /dev/null +++ b/transport/amqp/publisher_test.go @@ -0,0 +1,226 @@ +package amqp_test + +import ( + "context" + "encoding/json" + "errors" + "testing" + "time" + + amqptransport "github.com/go-kit/kit/transport/amqp" + "github.com/streadway/amqp" +) + +var ( + defaultContentType = "" + defaultContentEncoding = "" +) + +// TestBadEncode tests if encode errors are handled properly. +func TestBadEncode(t *testing.T) { + ch := &mockChannel{f: nullFunc} + q := &amqp.Queue{Name: "some queue"} + pub := amqptransport.NewPublisher( + ch, + q, + func(context.Context, *amqp.Publishing, interface{}) error { return errors.New("err!") }, + func(context.Context, *amqp.Delivery) (response interface{}, err error) { return struct{}{}, nil }, + ) + errChan := make(chan error, 1) + var err error + go func() { + _, err := pub.Endpoint()(context.Background(), struct{}{}) + errChan <- err + + }() + select { + case err = <-errChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for result") + } + if err == nil { + t.Error("expected error") + } + if want, have := "err!", err.Error(); want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestBadDecode tests if decode errors are handled properly. +func TestBadDecode(t *testing.T) { + cid := "correlation" + ch := &mockChannel{ + f: nullFunc, + c: make(chan amqp.Publishing, 1), + deliveries: []amqp.Delivery{ + amqp.Delivery{ + CorrelationId: cid, + }, + }, + } + q := &amqp.Queue{Name: "some queue"} + + pub := amqptransport.NewPublisher( + ch, + q, + func(context.Context, *amqp.Publishing, interface{}) error { return nil }, + func(context.Context, *amqp.Delivery) (response interface{}, err error) { + return struct{}{}, errors.New("err!") + }, + amqptransport.PublisherBefore( + amqptransport.SetCorrelationID(cid), + ), + ) + + var err error + errChan := make(chan error, 1) + go func() { + _, err := pub.Endpoint()(context.Background(), struct{}{}) + errChan <- err + + }() + + select { + case err = <-errChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for result") + } + + if err == nil { + t.Error("expected error") + } + if want, have := "err!", err.Error(); want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestPublisherTimeout ensures that the publisher timeout mechanism works. +func TestPublisherTimeout(t *testing.T) { + ch := &mockChannel{ + f: nullFunc, + c: make(chan amqp.Publishing, 1), + deliveries: []amqp.Delivery{}, // no reply from mock subscriber + } + q := &amqp.Queue{Name: "some queue"} + + pub := amqptransport.NewPublisher( + ch, + q, + func(context.Context, *amqp.Publishing, interface{}) error { return nil }, + func(context.Context, *amqp.Delivery) (response interface{}, err error) { + return struct{}{}, nil + }, + amqptransport.PublisherTimeout(50*time.Millisecond), + ) + + var err error + errChan := make(chan error, 1) + go func() { + _, err := pub.Endpoint()(context.Background(), struct{}{}) + errChan <- err + + }() + + select { + case err = <-errChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for result") + } + + if err == nil { + t.Error("expected error") + } + if want, have := context.DeadlineExceeded.Error(), err.Error(); want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +func TestSuccessfulPublisher(t *testing.T) { + cid := "correlation" + mockReq := testReq{437} + mockRes := testRes{ + Squadron: mockReq.Squadron, + Name: names[mockReq.Squadron], + } + b, err := json.Marshal(mockRes) + if err != nil { + t.Fatal(err) + } + reqChan := make(chan amqp.Publishing, 1) + ch := &mockChannel{ + f: nullFunc, + c: reqChan, + deliveries: []amqp.Delivery{ + amqp.Delivery{ + CorrelationId: cid, + Body: b, + }, + }, + } + q := &amqp.Queue{Name: "some queue"} + + pub := amqptransport.NewPublisher( + ch, + q, + testReqEncoder, + testResDeliveryDecoder, + amqptransport.PublisherBefore( + amqptransport.SetCorrelationID(cid), + ), + ) + var publishing amqp.Publishing + var res testRes + var ok bool + resChan := make(chan interface{}, 1) + errChan := make(chan error, 1) + go func() { + res, err := pub.Endpoint()(context.Background(), mockReq) + if err != nil { + errChan <- err + } else { + resChan <- res + } + }() + + select { + case publishing = <-reqChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for request") + } + if want, have := defaultContentType, publishing.ContentType; want != have { + t.Errorf("want %s, have %s", want, have) + } + if want, have := defaultContentEncoding, publishing.ContentEncoding; want != have { + t.Errorf("want %s, have %s", want, have) + } + + select { + case response := <-resChan: + res, ok = response.(testRes) + if !ok { + t.Error("failed to assert endpoint response type") + } + break + + case err = <-errChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for result") + } + + if err != nil { + t.Fatal(err) + } + if want, have := mockRes.Name, res.Name; want != have { + t.Errorf("want %s, have %s", want, have) + } +} diff --git a/transport/amqp/request_response_func.go b/transport/amqp/request_response_func.go new file mode 100644 index 0000000..a6f730f --- /dev/null +++ b/transport/amqp/request_response_func.go @@ -0,0 +1,182 @@ +package amqp + +import ( + "context" + "time" + + "github.com/streadway/amqp" +) + +// RequestFunc may take information from a publisher request and put it into a +// request context. In Subscribers, RequestFuncs are executed prior to invoking +// the endpoint. +type RequestFunc func(context.Context, *amqp.Publishing) context.Context + +// SubscriberResponseFunc may take information from a request context and use it to +// manipulate a Publisher. SubscriberResponseFuncs are only executed in +// subscribers, after invoking the endpoint but prior to publishing a reply. +type SubscriberResponseFunc func(context.Context, + *amqp.Delivery, + Channel, + *amqp.Publishing, +) context.Context + +// PublisherResponseFunc may take information from an AMQP request and make the +// response available for consumption. PublisherResponseFunc are only executed +// in publishers, after a request has been made, but prior to it being decoded. +type PublisherResponseFunc func(context.Context, *amqp.Delivery) context.Context + +// SetPublishExchange returns a RequestFunc that sets the Exchange field +// of an AMQP Publish call. +func SetPublishExchange(publishExchange string) RequestFunc { + return func(ctx context.Context, pub *amqp.Publishing) context.Context { + return context.WithValue(ctx, ContextKeyExchange, publishExchange) + } +} + +// SetPublishKey returns a RequestFunc that sets the Key field +// of an AMQP Publish call. +func SetPublishKey(publishKey string) RequestFunc { + return func(ctx context.Context, pub *amqp.Publishing) context.Context { + return context.WithValue(ctx, ContextKeyPublishKey, publishKey) + } +} + +// SetPublishDeliveryMode sets the delivery mode of a Publishing. +// Please refer to AMQP delivery mode constants in the AMQP package. +func SetPublishDeliveryMode(dmode uint8) RequestFunc { + return func(ctx context.Context, pub *amqp.Publishing) context.Context { + pub.DeliveryMode = dmode + return ctx + } +} + +// SetNackSleepDuration returns a RequestFunc that sets the amount of time +// to sleep in the event of a Nack. +// This has to be used in conjunction with an error encoder that Nack and sleeps. +// One example is the SingleNackRequeueErrorEncoder. +// It is designed to be used by Subscribers. +func SetNackSleepDuration(duration time.Duration) RequestFunc { + return func(ctx context.Context, pub *amqp.Publishing) context.Context { + return context.WithValue(ctx, ContextKeyNackSleepDuration, duration) + } +} + +// SetConsumeAutoAck returns a RequestFunc that sets whether or not to autoAck +// messages when consuming. +// When set to false, the publisher will Ack the first message it receives with +// a matching correlationId. +// It is designed to be used by Publishers. +func SetConsumeAutoAck(autoAck bool) RequestFunc { + return func(ctx context.Context, pub *amqp.Publishing) context.Context { + return context.WithValue(ctx, ContextKeyAutoAck, autoAck) + } +} + +// SetConsumeArgs returns a RequestFunc that set the arguments for amqp Consume +// function. +// It is designed to be used by Publishers. +func SetConsumeArgs(args amqp.Table) RequestFunc { + return func(ctx context.Context, pub *amqp.Publishing) context.Context { + return context.WithValue(ctx, ContextKeyConsumeArgs, args) + } +} + +// SetContentType returns a RequestFunc that sets the ContentType field of +// an AMQP Publishing. +func SetContentType(contentType string) RequestFunc { + return func(ctx context.Context, pub *amqp.Publishing) context.Context { + pub.ContentType = contentType + return ctx + } +} + +// SetContentEncoding returns a RequestFunc that sets the ContentEncoding field +// of an AMQP Publishing. +func SetContentEncoding(contentEncoding string) RequestFunc { + return func(ctx context.Context, pub *amqp.Publishing) context.Context { + pub.ContentEncoding = contentEncoding + return ctx + } +} + +// SetCorrelationID returns a RequestFunc that sets the CorrelationId field +// of an AMQP Publishing. +func SetCorrelationID(cid string) RequestFunc { + return func(ctx context.Context, pub *amqp.Publishing) context.Context { + pub.CorrelationId = cid + return ctx + } +} + +// SetAckAfterEndpoint returns a SubscriberResponseFunc that prompts the service +// to Ack the Delivery object after successfully evaluating the endpoint, +// and before it encodes the response. +// It is designed to be used by Subscribers. +func SetAckAfterEndpoint(multiple bool) SubscriberResponseFunc { + return func(ctx context.Context, + deliv *amqp.Delivery, + ch Channel, + pub *amqp.Publishing, + ) context.Context { + deliv.Ack(multiple) + return ctx + } +} + +func getPublishExchange(ctx context.Context) string { + if exchange := ctx.Value(ContextKeyExchange); exchange != nil { + return exchange.(string) + } + return "" +} + +func getPublishKey(ctx context.Context) string { + if publishKey := ctx.Value(ContextKeyPublishKey); publishKey != nil { + return publishKey.(string) + } + return "" +} + +func getNackSleepDuration(ctx context.Context) time.Duration { + if duration := ctx.Value(ContextKeyNackSleepDuration); duration != nil { + return duration.(time.Duration) + } + return 0 +} + +func getConsumeAutoAck(ctx context.Context) bool { + if autoAck := ctx.Value(ContextKeyAutoAck); autoAck != nil { + return autoAck.(bool) + } + return false +} + +func getConsumeArgs(ctx context.Context) amqp.Table { + if args := ctx.Value(ContextKeyConsumeArgs); args != nil { + return args.(amqp.Table) + } + return nil +} + +type contextKey int + +const ( + // ContextKeyExchange is the value of the reply Exchange in + // amqp.Publish. + ContextKeyExchange contextKey = iota + // ContextKeyPublishKey is the value of the ReplyTo field in + // amqp.Publish. + ContextKeyPublishKey + // ContextKeyNackSleepDuration is the duration to sleep for if the + // service Nack and requeues a message. + // This is to prevent sporadic send-resending of message + // when a message is constantly Nack'd and requeued. + ContextKeyNackSleepDuration + // ContextKeyAutoAck is the value of autoAck field when calling + // amqp.Channel.Consume. + ContextKeyAutoAck + // ContextKeyConsumeArgs is the value of consumeArgs field when calling + // amqp.Channel.Consume. + ContextKeyConsumeArgs +) diff --git a/transport/amqp/subscriber.go b/transport/amqp/subscriber.go new file mode 100644 index 0000000..17e1b0f --- /dev/null +++ b/transport/amqp/subscriber.go @@ -0,0 +1,253 @@ +package amqp + +import ( + "context" + "encoding/json" + "time" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + "github.com/streadway/amqp" +) + +// Subscriber wraps an endpoint and provides a handler for AMQP Delivery messages. +type Subscriber struct { + e endpoint.Endpoint + dec DecodeRequestFunc + enc EncodeResponseFunc + before []RequestFunc + after []SubscriberResponseFunc + errorEncoder ErrorEncoder + logger log.Logger +} + +// NewSubscriber constructs a new subscriber, which provides a handler +// for AMQP Delivery messages. +func NewSubscriber( + e endpoint.Endpoint, + dec DecodeRequestFunc, + enc EncodeResponseFunc, + options ...SubscriberOption, +) *Subscriber { + s := &Subscriber{ + e: e, + dec: dec, + enc: enc, + errorEncoder: DefaultErrorEncoder, + logger: log.NewNopLogger(), + } + for _, option := range options { + option(s) + } + return s +} + +// SubscriberOption sets an optional parameter for subscribers. +type SubscriberOption func(*Subscriber) + +// SubscriberBefore functions are executed on the publisher delivery object +// before the request is decoded. +func SubscriberBefore(before ...RequestFunc) SubscriberOption { + return func(s *Subscriber) { s.before = append(s.before, before...) } +} + +// SubscriberAfter functions are executed on the subscriber reply after the +// endpoint is invoked, but before anything is published to the reply. +func SubscriberAfter(after ...SubscriberResponseFunc) SubscriberOption { + return func(s *Subscriber) { s.after = append(s.after, after...) } +} + +// SubscriberErrorEncoder is used to encode errors to the subscriber reply +// whenever they're encountered in the processing of a request. Clients can +// use this to provide custom error formatting. By default, +// errors will be published with the DefaultErrorEncoder. +func SubscriberErrorEncoder(ee ErrorEncoder) SubscriberOption { + return func(s *Subscriber) { s.errorEncoder = ee } +} + +// SubscriberErrorLogger is used to log non-terminal errors. By default, no errors +// are logged. This is intended as a diagnostic measure. Finer-grained control +// of error handling, including logging in more detail, should be performed in a +// custom SubscriberErrorEncoder which has access to the context. +func SubscriberErrorLogger(logger log.Logger) SubscriberOption { + return func(s *Subscriber) { s.logger = logger } +} + +// ServeDelivery handles AMQP Delivery messages +// It is strongly recommended to use *amqp.Channel as the +// Channel interface implementation. +func (s Subscriber) ServeDelivery(ch Channel) func(deliv *amqp.Delivery) { + return func(deliv *amqp.Delivery) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pub := amqp.Publishing{} + + for _, f := range s.before { + ctx = f(ctx, &pub) + } + + request, err := s.dec(ctx, deliv) + if err != nil { + s.logger.Log("err", err) + s.errorEncoder(ctx, err, deliv, ch, &pub) + return + } + + response, err := s.e(ctx, request) + if err != nil { + s.logger.Log("err", err) + s.errorEncoder(ctx, err, deliv, ch, &pub) + return + } + + for _, f := range s.after { + ctx = f(ctx, deliv, ch, &pub) + } + + if err := s.enc(ctx, &pub, response); err != nil { + s.logger.Log("err", err) + s.errorEncoder(ctx, err, deliv, ch, &pub) + return + } + + if err := s.publishResponse(ctx, deliv, ch, &pub); err != nil { + s.logger.Log("err", err) + s.errorEncoder(ctx, err, deliv, ch, &pub) + return + } + } + +} + +func (s Subscriber) publishResponse( + ctx context.Context, + deliv *amqp.Delivery, + ch Channel, + pub *amqp.Publishing, +) error { + if pub.CorrelationId == "" { + pub.CorrelationId = deliv.CorrelationId + } + + replyExchange := getPublishExchange(ctx) + replyTo := getPublishKey(ctx) + if replyTo == "" { + replyTo = deliv.ReplyTo + } + + return ch.Publish( + replyExchange, + replyTo, + false, // mandatory + false, // immediate + *pub, + ) +} + +// EncodeJSONResponse marshals the response as JSON as part of the +// payload of the AMQP Publishing object. +func EncodeJSONResponse( + ctx context.Context, + pub *amqp.Publishing, + response interface{}, +) error { + b, err := json.Marshal(response) + if err != nil { + return err + } + pub.Body = b + return nil +} + +// EncodeNopResponse is a response function that does nothing. +func EncodeNopResponse( + ctx context.Context, + pub *amqp.Publishing, + response interface{}, +) error { + return nil +} + +// ErrorEncoder is responsible for encoding an error to the subscriber reply. +// Users are encouraged to use custom ErrorEncoders to encode errors to +// their replies, and will likely want to pass and check for their own error +// types. +type ErrorEncoder func(ctx context.Context, + err error, deliv *amqp.Delivery, ch Channel, pub *amqp.Publishing) + +// DefaultErrorEncoder simply ignores the message. It does not reply +// nor Ack/Nack the message. +func DefaultErrorEncoder(ctx context.Context, + err error, deliv *amqp.Delivery, ch Channel, pub *amqp.Publishing) { +} + +// SingleNackRequeueErrorEncoder issues a Nack to the delivery with multiple flag set as false +// and requeue flag set as true. It does not reply the message. +func SingleNackRequeueErrorEncoder(ctx context.Context, + err error, deliv *amqp.Delivery, ch Channel, pub *amqp.Publishing) { + deliv.Nack( + false, //multiple + true, //requeue + ) + duration := getNackSleepDuration(ctx) + time.Sleep(duration) +} + +// ReplyErrorEncoder serializes the error message as a DefaultErrorResponse +// JSON and sends the message to the ReplyTo address. +func ReplyErrorEncoder( + ctx context.Context, + err error, + deliv *amqp.Delivery, + ch Channel, + pub *amqp.Publishing, +) { + + if pub.CorrelationId == "" { + pub.CorrelationId = deliv.CorrelationId + } + + replyExchange := getPublishExchange(ctx) + replyTo := getPublishKey(ctx) + if replyTo == "" { + replyTo = deliv.ReplyTo + } + + response := DefaultErrorResponse{err.Error()} + + b, err := json.Marshal(response) + if err != nil { + return + } + pub.Body = b + + ch.Publish( + replyExchange, + replyTo, + false, // mandatory + false, // immediate + *pub, + ) +} + +// ReplyAndAckErrorEncoder serializes the error message as a DefaultErrorResponse +// JSON and sends the message to the ReplyTo address then Acks the original +// message. +func ReplyAndAckErrorEncoder(ctx context.Context, err error, deliv *amqp.Delivery, ch Channel, pub *amqp.Publishing) { + ReplyErrorEncoder(ctx, err, deliv, ch, pub) + deliv.Ack(false) +} + +// DefaultErrorResponse is the default structure of responses in the event +// of an error. +type DefaultErrorResponse struct { + Error string `json:"err"` +} + +// Channel is a channel interface to make testing possible. +// It is highly recommended to use *amqp.Channel as the interface implementation. +type Channel interface { + Publish(exchange, key string, mandatory, immediate bool, msg amqp.Publishing) error + Consume(queue, consumer string, autoAck, exclusive, noLocal, noWail bool, args amqp.Table) (<-chan amqp.Delivery, error) +} diff --git a/transport/amqp/subscriber_test.go b/transport/amqp/subscriber_test.go new file mode 100644 index 0000000..5aece6b --- /dev/null +++ b/transport/amqp/subscriber_test.go @@ -0,0 +1,395 @@ +package amqp_test + +import ( + "context" + "encoding/json" + "errors" + "testing" + "time" + + amqptransport "github.com/go-kit/kit/transport/amqp" + "github.com/streadway/amqp" +) + +var ( + typeAssertionError = errors.New("type assertion error") +) + +// mockChannel is a mock of *amqp.Channel. +type mockChannel struct { + f func(exchange, key string, mandatory, immediate bool) + c chan<- amqp.Publishing + deliveries []amqp.Delivery +} + +// Publish runs a test function f and sends resultant message to a channel. +func (ch *mockChannel) Publish(exchange, key string, mandatory, immediate bool, msg amqp.Publishing) error { + ch.f(exchange, key, mandatory, immediate) + ch.c <- msg + return nil +} + +var nullFunc = func(exchange, key string, mandatory, immediate bool) { +} + +func (ch *mockChannel) Consume(queue, consumer string, autoAck, exclusive, noLocal, noWail bool, args amqp.Table) (<-chan amqp.Delivery, error) { + c := make(chan amqp.Delivery, len(ch.deliveries)) + for _, d := range ch.deliveries { + c <- d + } + return c, nil +} + +// TestSubscriberBadDecode checks if decoder errors are handled properly. +func TestSubscriberBadDecode(t *testing.T) { + sub := amqptransport.NewSubscriber( + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, *amqp.Delivery) (interface{}, error) { return nil, errors.New("err!") }, + func(context.Context, *amqp.Publishing, interface{}) error { + return nil + }, + amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder), + ) + + outputChan := make(chan amqp.Publishing, 1) + ch := &mockChannel{f: nullFunc, c: outputChan} + sub.ServeDelivery(ch)(&amqp.Delivery{}) + + var msg amqp.Publishing + select { + case msg = <-outputChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + res, err := decodeSubscriberError(msg) + if err != nil { + t.Fatal(err) + } + if want, have := "err!", res.Error; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestSubscriberBadEndpoint checks if endpoint errors are handled properly. +func TestSubscriberBadEndpoint(t *testing.T) { + sub := amqptransport.NewSubscriber( + func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("err!") }, + func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, *amqp.Publishing, interface{}) error { + return nil + }, + amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder), + ) + + outputChan := make(chan amqp.Publishing, 1) + ch := &mockChannel{f: nullFunc, c: outputChan} + sub.ServeDelivery(ch)(&amqp.Delivery{}) + + var msg amqp.Publishing + + select { + case msg = <-outputChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + + res, err := decodeSubscriberError(msg) + if err != nil { + t.Fatal(err) + } + if want, have := "err!", res.Error; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestSubscriberBadEncoder checks if encoder errors are handled properly. +func TestSubscriberBadEncoder(t *testing.T) { + sub := amqptransport.NewSubscriber( + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, *amqp.Publishing, interface{}) error { + return errors.New("err!") + }, + amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder), + ) + + outputChan := make(chan amqp.Publishing, 1) + ch := &mockChannel{f: nullFunc, c: outputChan} + sub.ServeDelivery(ch)(&amqp.Delivery{}) + + var msg amqp.Publishing + + select { + case msg = <-outputChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + + res, err := decodeSubscriberError(msg) + if err != nil { + t.Fatal(err) + } + if want, have := "err!", res.Error; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestSubscriberSuccess checks if CorrelationId and ReplyTo are set properly +// and if the payload is encoded properly. +func TestSubscriberSuccess(t *testing.T) { + cid := "correlation" + replyTo := "sender" + obj := testReq{ + Squadron: 436, + } + b, err := json.Marshal(obj) + if err != nil { + t.Fatal(err) + } + + sub := amqptransport.NewSubscriber( + testEndpoint, + testReqDecoder, + amqptransport.EncodeJSONResponse, + amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder), + ) + + checkReplyToFunc := func(exchange, key string, mandatory, immediate bool) { + if want, have := replyTo, key; want != have { + t.Errorf("want %s, have %s", want, have) + } + } + + outputChan := make(chan amqp.Publishing, 1) + ch := &mockChannel{f: checkReplyToFunc, c: outputChan} + sub.ServeDelivery(ch)(&amqp.Delivery{ + CorrelationId: cid, + ReplyTo: replyTo, + Body: b, + }) + + var msg amqp.Publishing + + select { + case msg = <-outputChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + + if want, have := cid, msg.CorrelationId; want != have { + t.Errorf("want %s, have %s", want, have) + } + + // check if error is not thrown + errRes, err := decodeSubscriberError(msg) + if err != nil { + t.Fatal(err) + } + if errRes.Error != "" { + t.Error("Received error from subscriber", errRes.Error) + return + } + + // check obj vals + response, err := testResDecoder(msg.Body) + if err != nil { + t.Fatal(err) + } + res, ok := response.(testRes) + if !ok { + t.Error(typeAssertionError) + } + + if want, have := obj.Squadron, res.Squadron; want != have { + t.Errorf("want %d, have %d", want, have) + } + if want, have := names[obj.Squadron], res.Name; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestSubscriberMultipleBefore checks if options to set exchange, key, deliveryMode +// are working. +func TestSubscriberMultipleBefore(t *testing.T) { + exchange := "some exchange" + key := "some key" + deliveryMode := uint8(127) + contentType := "some content type" + contentEncoding := "some content encoding" + sub := amqptransport.NewSubscriber( + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil }, + amqptransport.EncodeJSONResponse, + amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder), + amqptransport.SubscriberBefore( + amqptransport.SetPublishExchange(exchange), + amqptransport.SetPublishKey(key), + amqptransport.SetPublishDeliveryMode(deliveryMode), + amqptransport.SetContentType(contentType), + amqptransport.SetContentEncoding(contentEncoding), + ), + ) + checkReplyToFunc := func(exch, k string, mandatory, immediate bool) { + if want, have := exchange, exch; want != have { + t.Errorf("want %s, have %s", want, have) + } + if want, have := key, k; want != have { + t.Errorf("want %s, have %s", want, have) + } + } + + outputChan := make(chan amqp.Publishing, 1) + ch := &mockChannel{f: checkReplyToFunc, c: outputChan} + sub.ServeDelivery(ch)(&amqp.Delivery{}) + + var msg amqp.Publishing + + select { + case msg = <-outputChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + + // check if error is not thrown + errRes, err := decodeSubscriberError(msg) + if err != nil { + t.Fatal(err) + } + if errRes.Error != "" { + t.Error("Received error from subscriber", errRes.Error) + return + } + + if want, have := contentType, msg.ContentType; want != have { + t.Errorf("want %s, have %s", want, have) + } + + if want, have := contentEncoding, msg.ContentEncoding; want != have { + t.Errorf("want %s, have %s", want, have) + } + + if want, have := deliveryMode, msg.DeliveryMode; want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +// TestDefaultContentMetaData checks that default ContentType and Content-Encoding +// is not set as mentioned by AMQP specification. +func TestDefaultContentMetaData(t *testing.T) { + defaultContentType := "" + defaultContentEncoding := "" + sub := amqptransport.NewSubscriber( + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil }, + amqptransport.EncodeJSONResponse, + amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder), + ) + checkReplyToFunc := func(exch, k string, mandatory, immediate bool) { return } + outputChan := make(chan amqp.Publishing, 1) + ch := &mockChannel{f: checkReplyToFunc, c: outputChan} + sub.ServeDelivery(ch)(&amqp.Delivery{}) + + var msg amqp.Publishing + + select { + case msg = <-outputChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + + // check if error is not thrown + errRes, err := decodeSubscriberError(msg) + if err != nil { + t.Fatal(err) + } + if errRes.Error != "" { + t.Error("Received error from subscriber", errRes.Error) + return + } + + if want, have := defaultContentType, msg.ContentType; want != have { + t.Errorf("want %s, have %s", want, have) + } + if want, have := defaultContentEncoding, msg.ContentEncoding; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +func decodeSubscriberError(pub amqp.Publishing) (amqptransport.DefaultErrorResponse, error) { + var res amqptransport.DefaultErrorResponse + err := json.Unmarshal(pub.Body, &res) + return res, err +} + +type testReq struct { + Squadron int `json:"s"` +} +type testRes struct { + Squadron int `json:"s"` + Name string `json:"n"` +} + +func testEndpoint(_ context.Context, request interface{}) (interface{}, error) { + req, ok := request.(testReq) + if !ok { + return nil, typeAssertionError + } + name, prs := names[req.Squadron] + if !prs { + return nil, errors.New("unknown squadron name") + } + res := testRes{ + Squadron: req.Squadron, + Name: name, + } + return res, nil +} + +func testReqDecoder(_ context.Context, d *amqp.Delivery) (interface{}, error) { + var obj testReq + err := json.Unmarshal(d.Body, &obj) + return obj, err +} + +func testReqEncoder(_ context.Context, p *amqp.Publishing, request interface{}) error { + req, ok := request.(testReq) + if !ok { + return errors.New("type assertion failure") + } + b, err := json.Marshal(req) + if err != nil { + return err + } + p.Body = b + return nil +} + +func testResDeliveryDecoder(_ context.Context, d *amqp.Delivery) (interface{}, error) { + return testResDecoder(d.Body) +} + +func testResDecoder(b []byte) (interface{}, error) { + var obj testRes + err := json.Unmarshal(b, &obj) + return obj, err +} + +var names = map[int]string{ + 424: "tiger", + 426: "thunderbird", + 429: "bison", + 436: "tusker", + 437: "husky", +} diff --git a/transport/amqp/util.go b/transport/amqp/util.go new file mode 100644 index 0000000..020051e --- /dev/null +++ b/transport/amqp/util.go @@ -0,0 +1,17 @@ +package amqp + +import ( + "math/rand" +) + +func randomString(l int) string { + bytes := make([]byte, l) + for i := 0; i < l; i++ { + bytes[i] = byte(randInt(65, 90)) + } + return string(bytes) +} + +func randInt(min int, max int) int { + return min + rand.Intn(max-min) +}