Codebase list golang-github-go-kit-kit / d2f2902b-79c4-43cd-8c09-341e33fc6017/v0.8.0
Add AMQP transport (#746) (#746) amqp transport publisher amqp transport tests lint fixes for amqp transport fixed formatting and punctuation zzz default ContentType is null, increased max length of correlationId to 255 refractored subscriber EncodeResponseFunc into encode func, and send func zzz Matthew Fung authored 5 years ago Peter Bourgon committed 5 years ago
8 changed file(s) with 1249 addition(s) and 0 deletion(s). Raw diff Collapse all Expand all
0 // Package amqp implements an AMQP transport.
1 package amqp
0 package amqp
1
2 import (
3 "context"
4 "github.com/streadway/amqp"
5 )
6
7 // DecodeRequestFunc extracts a user-domain request object from
8 // an AMQP Delivery object. It is designed to be used in AMQP Subscribers.
9 type DecodeRequestFunc func(context.Context, *amqp.Delivery) (request interface{}, err error)
10
11 // EncodeRequestFunc encodes the passed request object into
12 // an AMQP Publishing object. It is designed to be used in AMQP Publishers.
13 type EncodeRequestFunc func(context.Context, *amqp.Publishing, interface{}) error
14
15 // EncodeResponseFunc encodes the passed reponse object to
16 // an AMQP Publishing object. It is designed to be used in AMQP Subscribers.
17 type EncodeResponseFunc func(context.Context, *amqp.Publishing, interface{}) error
18
19 // DecodeResponseFunc extracts a user-domain response object from
20 // an AMQP Delivery object. It is designed to be used in AMQP Publishers.
21 type DecodeResponseFunc func(context.Context, *amqp.Delivery) (response interface{}, err error)
0 package amqp
1
2 import (
3 "context"
4 "time"
5
6 "github.com/go-kit/kit/endpoint"
7 "github.com/streadway/amqp"
8 )
9
10 // The golang AMQP implementation requires the []byte representation of
11 // correlation id strings to have a maximum length of 255 bytes.
12 const maxCorrelationIdLength = 255
13
14 // Publisher wraps an AMQP channel and queue, and provides a method that
15 // implements endpoint.Endpoint.
16 type Publisher struct {
17 ch Channel
18 q *amqp.Queue
19 enc EncodeRequestFunc
20 dec DecodeResponseFunc
21 before []RequestFunc
22 after []PublisherResponseFunc
23 timeout time.Duration
24 }
25
26 // NewPublisher constructs a usable Publisher for a single remote method.
27 func NewPublisher(
28 ch Channel,
29 q *amqp.Queue,
30 enc EncodeRequestFunc,
31 dec DecodeResponseFunc,
32 options ...PublisherOption,
33 ) *Publisher {
34 p := &Publisher{
35 ch: ch,
36 q: q,
37 enc: enc,
38 dec: dec,
39 timeout: 10 * time.Second,
40 }
41 for _, option := range options {
42 option(p)
43 }
44 return p
45 }
46
47 // PublisherOption sets an optional parameter for clients.
48 type PublisherOption func(*Publisher)
49
50 // PublisherBefore sets the RequestFuncs that are applied to the outgoing AMQP
51 // request before it's invoked.
52 func PublisherBefore(before ...RequestFunc) PublisherOption {
53 return func(p *Publisher) { p.before = append(p.before, before...) }
54 }
55
56 // PublisherAfter sets the ClientResponseFuncs applied to the incoming AMQP
57 // request prior to it being decoded. This is useful for obtaining anything off
58 // of the response and adding onto the context prior to decoding.
59 func PublisherAfter(after ...PublisherResponseFunc) PublisherOption {
60 return func(p *Publisher) { p.after = append(p.after, after...) }
61 }
62
63 // PublisherTimeout sets the available timeout for an AMQP request.
64 func PublisherTimeout(timeout time.Duration) PublisherOption {
65 return func(p *Publisher) { p.timeout = timeout }
66 }
67
68 // Endpoint returns a usable endpoint that invokes the remote endpoint.
69 func (p Publisher) Endpoint() endpoint.Endpoint {
70 return func(ctx context.Context, request interface{}) (interface{}, error) {
71 ctx, cancel := context.WithTimeout(ctx, p.timeout)
72 defer cancel()
73
74 pub := amqp.Publishing{
75 ReplyTo: p.q.Name,
76 CorrelationId: randomString(randInt(5, maxCorrelationIdLength)),
77 }
78
79 if err := p.enc(ctx, &pub, request); err != nil {
80 return nil, err
81 }
82
83 for _, f := range p.before {
84 ctx = f(ctx, &pub)
85 }
86
87 deliv, err := p.publishAndConsumeFirstMatchingResponse(ctx, &pub)
88 if err != nil {
89 return nil, err
90 }
91
92 for _, f := range p.after {
93 ctx = f(ctx, deliv)
94 }
95 response, err := p.dec(ctx, deliv)
96 if err != nil {
97 return nil, err
98 }
99
100 return response, nil
101 }
102 }
103
104 // publishAndConsumeFirstMatchingResponse publishes the specified Publishing
105 // and returns the first Delivery object with the matching correlationId.
106 // If the context times out while waiting for a reply, an error will be returned.
107 func (p Publisher) publishAndConsumeFirstMatchingResponse(
108 ctx context.Context,
109 pub *amqp.Publishing,
110 ) (*amqp.Delivery, error) {
111 err := p.ch.Publish(
112 getPublishExchange(ctx),
113 getPublishKey(ctx),
114 false, //mandatory
115 false, //immediate
116 *pub,
117 )
118 if err != nil {
119 return nil, err
120 }
121 autoAck := getConsumeAutoAck(ctx)
122
123 msg, err := p.ch.Consume(
124 p.q.Name,
125 "", //consumer
126 autoAck,
127 false, //exclusive
128 false, //noLocal
129 false, //noWait
130 getConsumeArgs(ctx),
131 )
132 if err != nil {
133 return nil, err
134 }
135
136 for {
137 select {
138 case d := <-msg:
139 if d.CorrelationId == pub.CorrelationId {
140 if !autoAck {
141 d.Ack(false) //multiple
142 }
143 return &d, nil
144 }
145
146 case <-ctx.Done():
147 return nil, ctx.Err()
148 }
149 }
150
151 }
0 package amqp_test
1
2 import (
3 "context"
4 "encoding/json"
5 "errors"
6 "testing"
7 "time"
8
9 amqptransport "github.com/go-kit/kit/transport/amqp"
10 "github.com/streadway/amqp"
11 )
12
13 var (
14 defaultContentType = ""
15 defaultContentEncoding = ""
16 )
17
18 // TestBadEncode tests if encode errors are handled properly.
19 func TestBadEncode(t *testing.T) {
20 ch := &mockChannel{f: nullFunc}
21 q := &amqp.Queue{Name: "some queue"}
22 pub := amqptransport.NewPublisher(
23 ch,
24 q,
25 func(context.Context, *amqp.Publishing, interface{}) error { return errors.New("err!") },
26 func(context.Context, *amqp.Delivery) (response interface{}, err error) { return struct{}{}, nil },
27 )
28 errChan := make(chan error, 1)
29 var err error
30 go func() {
31 _, err := pub.Endpoint()(context.Background(), struct{}{})
32 errChan <- err
33
34 }()
35 select {
36 case err = <-errChan:
37 break
38
39 case <-time.After(100 * time.Millisecond):
40 t.Fatal("Timed out waiting for result")
41 }
42 if err == nil {
43 t.Error("expected error")
44 }
45 if want, have := "err!", err.Error(); want != have {
46 t.Errorf("want %s, have %s", want, have)
47 }
48 }
49
50 // TestBadDecode tests if decode errors are handled properly.
51 func TestBadDecode(t *testing.T) {
52 cid := "correlation"
53 ch := &mockChannel{
54 f: nullFunc,
55 c: make(chan amqp.Publishing, 1),
56 deliveries: []amqp.Delivery{
57 amqp.Delivery{
58 CorrelationId: cid,
59 },
60 },
61 }
62 q := &amqp.Queue{Name: "some queue"}
63
64 pub := amqptransport.NewPublisher(
65 ch,
66 q,
67 func(context.Context, *amqp.Publishing, interface{}) error { return nil },
68 func(context.Context, *amqp.Delivery) (response interface{}, err error) {
69 return struct{}{}, errors.New("err!")
70 },
71 amqptransport.PublisherBefore(
72 amqptransport.SetCorrelationID(cid),
73 ),
74 )
75
76 var err error
77 errChan := make(chan error, 1)
78 go func() {
79 _, err := pub.Endpoint()(context.Background(), struct{}{})
80 errChan <- err
81
82 }()
83
84 select {
85 case err = <-errChan:
86 break
87
88 case <-time.After(100 * time.Millisecond):
89 t.Fatal("Timed out waiting for result")
90 }
91
92 if err == nil {
93 t.Error("expected error")
94 }
95 if want, have := "err!", err.Error(); want != have {
96 t.Errorf("want %s, have %s", want, have)
97 }
98 }
99
100 // TestPublisherTimeout ensures that the publisher timeout mechanism works.
101 func TestPublisherTimeout(t *testing.T) {
102 ch := &mockChannel{
103 f: nullFunc,
104 c: make(chan amqp.Publishing, 1),
105 deliveries: []amqp.Delivery{}, // no reply from mock subscriber
106 }
107 q := &amqp.Queue{Name: "some queue"}
108
109 pub := amqptransport.NewPublisher(
110 ch,
111 q,
112 func(context.Context, *amqp.Publishing, interface{}) error { return nil },
113 func(context.Context, *amqp.Delivery) (response interface{}, err error) {
114 return struct{}{}, nil
115 },
116 amqptransport.PublisherTimeout(50*time.Millisecond),
117 )
118
119 var err error
120 errChan := make(chan error, 1)
121 go func() {
122 _, err := pub.Endpoint()(context.Background(), struct{}{})
123 errChan <- err
124
125 }()
126
127 select {
128 case err = <-errChan:
129 break
130
131 case <-time.After(100 * time.Millisecond):
132 t.Fatal("timed out waiting for result")
133 }
134
135 if err == nil {
136 t.Error("expected error")
137 }
138 if want, have := context.DeadlineExceeded.Error(), err.Error(); want != have {
139 t.Errorf("want %s, have %s", want, have)
140 }
141 }
142
143 func TestSuccessfulPublisher(t *testing.T) {
144 cid := "correlation"
145 mockReq := testReq{437}
146 mockRes := testRes{
147 Squadron: mockReq.Squadron,
148 Name: names[mockReq.Squadron],
149 }
150 b, err := json.Marshal(mockRes)
151 if err != nil {
152 t.Fatal(err)
153 }
154 reqChan := make(chan amqp.Publishing, 1)
155 ch := &mockChannel{
156 f: nullFunc,
157 c: reqChan,
158 deliveries: []amqp.Delivery{
159 amqp.Delivery{
160 CorrelationId: cid,
161 Body: b,
162 },
163 },
164 }
165 q := &amqp.Queue{Name: "some queue"}
166
167 pub := amqptransport.NewPublisher(
168 ch,
169 q,
170 testReqEncoder,
171 testResDeliveryDecoder,
172 amqptransport.PublisherBefore(
173 amqptransport.SetCorrelationID(cid),
174 ),
175 )
176 var publishing amqp.Publishing
177 var res testRes
178 var ok bool
179 resChan := make(chan interface{}, 1)
180 errChan := make(chan error, 1)
181 go func() {
182 res, err := pub.Endpoint()(context.Background(), mockReq)
183 if err != nil {
184 errChan <- err
185 } else {
186 resChan <- res
187 }
188 }()
189
190 select {
191 case publishing = <-reqChan:
192 break
193
194 case <-time.After(100 * time.Millisecond):
195 t.Fatal("timed out waiting for request")
196 }
197 if want, have := defaultContentType, publishing.ContentType; want != have {
198 t.Errorf("want %s, have %s", want, have)
199 }
200 if want, have := defaultContentEncoding, publishing.ContentEncoding; want != have {
201 t.Errorf("want %s, have %s", want, have)
202 }
203
204 select {
205 case response := <-resChan:
206 res, ok = response.(testRes)
207 if !ok {
208 t.Error("failed to assert endpoint response type")
209 }
210 break
211
212 case err = <-errChan:
213 break
214
215 case <-time.After(100 * time.Millisecond):
216 t.Fatal("timed out waiting for result")
217 }
218
219 if err != nil {
220 t.Fatal(err)
221 }
222 if want, have := mockRes.Name, res.Name; want != have {
223 t.Errorf("want %s, have %s", want, have)
224 }
225 }
0 package amqp
1
2 import (
3 "context"
4 "time"
5
6 "github.com/streadway/amqp"
7 )
8
9 // RequestFunc may take information from a publisher request and put it into a
10 // request context. In Subscribers, RequestFuncs are executed prior to invoking
11 // the endpoint.
12 type RequestFunc func(context.Context, *amqp.Publishing) context.Context
13
14 // SubscriberResponseFunc may take information from a request context and use it to
15 // manipulate a Publisher. SubscriberResponseFuncs are only executed in
16 // subscribers, after invoking the endpoint but prior to publishing a reply.
17 type SubscriberResponseFunc func(context.Context,
18 *amqp.Delivery,
19 Channel,
20 *amqp.Publishing,
21 ) context.Context
22
23 // PublisherResponseFunc may take information from an AMQP request and make the
24 // response available for consumption. PublisherResponseFunc are only executed
25 // in publishers, after a request has been made, but prior to it being decoded.
26 type PublisherResponseFunc func(context.Context, *amqp.Delivery) context.Context
27
28 // SetPublishExchange returns a RequestFunc that sets the Exchange field
29 // of an AMQP Publish call.
30 func SetPublishExchange(publishExchange string) RequestFunc {
31 return func(ctx context.Context, pub *amqp.Publishing) context.Context {
32 return context.WithValue(ctx, ContextKeyExchange, publishExchange)
33 }
34 }
35
36 // SetPublishKey returns a RequestFunc that sets the Key field
37 // of an AMQP Publish call.
38 func SetPublishKey(publishKey string) RequestFunc {
39 return func(ctx context.Context, pub *amqp.Publishing) context.Context {
40 return context.WithValue(ctx, ContextKeyPublishKey, publishKey)
41 }
42 }
43
44 // SetPublishDeliveryMode sets the delivery mode of a Publishing.
45 // Please refer to AMQP delivery mode constants in the AMQP package.
46 func SetPublishDeliveryMode(dmode uint8) RequestFunc {
47 return func(ctx context.Context, pub *amqp.Publishing) context.Context {
48 pub.DeliveryMode = dmode
49 return ctx
50 }
51 }
52
53 // SetNackSleepDuration returns a RequestFunc that sets the amount of time
54 // to sleep in the event of a Nack.
55 // This has to be used in conjunction with an error encoder that Nack and sleeps.
56 // One example is the SingleNackRequeueErrorEncoder.
57 // It is designed to be used by Subscribers.
58 func SetNackSleepDuration(duration time.Duration) RequestFunc {
59 return func(ctx context.Context, pub *amqp.Publishing) context.Context {
60 return context.WithValue(ctx, ContextKeyNackSleepDuration, duration)
61 }
62 }
63
64 // SetConsumeAutoAck returns a RequestFunc that sets whether or not to autoAck
65 // messages when consuming.
66 // When set to false, the publisher will Ack the first message it receives with
67 // a matching correlationId.
68 // It is designed to be used by Publishers.
69 func SetConsumeAutoAck(autoAck bool) RequestFunc {
70 return func(ctx context.Context, pub *amqp.Publishing) context.Context {
71 return context.WithValue(ctx, ContextKeyAutoAck, autoAck)
72 }
73 }
74
75 // SetConsumeArgs returns a RequestFunc that set the arguments for amqp Consume
76 // function.
77 // It is designed to be used by Publishers.
78 func SetConsumeArgs(args amqp.Table) RequestFunc {
79 return func(ctx context.Context, pub *amqp.Publishing) context.Context {
80 return context.WithValue(ctx, ContextKeyConsumeArgs, args)
81 }
82 }
83
84 // SetContentType returns a RequestFunc that sets the ContentType field of
85 // an AMQP Publishing.
86 func SetContentType(contentType string) RequestFunc {
87 return func(ctx context.Context, pub *amqp.Publishing) context.Context {
88 pub.ContentType = contentType
89 return ctx
90 }
91 }
92
93 // SetContentEncoding returns a RequestFunc that sets the ContentEncoding field
94 // of an AMQP Publishing.
95 func SetContentEncoding(contentEncoding string) RequestFunc {
96 return func(ctx context.Context, pub *amqp.Publishing) context.Context {
97 pub.ContentEncoding = contentEncoding
98 return ctx
99 }
100 }
101
102 // SetCorrelationID returns a RequestFunc that sets the CorrelationId field
103 // of an AMQP Publishing.
104 func SetCorrelationID(cid string) RequestFunc {
105 return func(ctx context.Context, pub *amqp.Publishing) context.Context {
106 pub.CorrelationId = cid
107 return ctx
108 }
109 }
110
111 // SetAckAfterEndpoint returns a SubscriberResponseFunc that prompts the service
112 // to Ack the Delivery object after successfully evaluating the endpoint,
113 // and before it encodes the response.
114 // It is designed to be used by Subscribers.
115 func SetAckAfterEndpoint(multiple bool) SubscriberResponseFunc {
116 return func(ctx context.Context,
117 deliv *amqp.Delivery,
118 ch Channel,
119 pub *amqp.Publishing,
120 ) context.Context {
121 deliv.Ack(multiple)
122 return ctx
123 }
124 }
125
126 func getPublishExchange(ctx context.Context) string {
127 if exchange := ctx.Value(ContextKeyExchange); exchange != nil {
128 return exchange.(string)
129 }
130 return ""
131 }
132
133 func getPublishKey(ctx context.Context) string {
134 if publishKey := ctx.Value(ContextKeyPublishKey); publishKey != nil {
135 return publishKey.(string)
136 }
137 return ""
138 }
139
140 func getNackSleepDuration(ctx context.Context) time.Duration {
141 if duration := ctx.Value(ContextKeyNackSleepDuration); duration != nil {
142 return duration.(time.Duration)
143 }
144 return 0
145 }
146
147 func getConsumeAutoAck(ctx context.Context) bool {
148 if autoAck := ctx.Value(ContextKeyAutoAck); autoAck != nil {
149 return autoAck.(bool)
150 }
151 return false
152 }
153
154 func getConsumeArgs(ctx context.Context) amqp.Table {
155 if args := ctx.Value(ContextKeyConsumeArgs); args != nil {
156 return args.(amqp.Table)
157 }
158 return nil
159 }
160
161 type contextKey int
162
163 const (
164 // ContextKeyExchange is the value of the reply Exchange in
165 // amqp.Publish.
166 ContextKeyExchange contextKey = iota
167 // ContextKeyPublishKey is the value of the ReplyTo field in
168 // amqp.Publish.
169 ContextKeyPublishKey
170 // ContextKeyNackSleepDuration is the duration to sleep for if the
171 // service Nack and requeues a message.
172 // This is to prevent sporadic send-resending of message
173 // when a message is constantly Nack'd and requeued.
174 ContextKeyNackSleepDuration
175 // ContextKeyAutoAck is the value of autoAck field when calling
176 // amqp.Channel.Consume.
177 ContextKeyAutoAck
178 // ContextKeyConsumeArgs is the value of consumeArgs field when calling
179 // amqp.Channel.Consume.
180 ContextKeyConsumeArgs
181 )
0 package amqp
1
2 import (
3 "context"
4 "encoding/json"
5 "time"
6
7 "github.com/go-kit/kit/endpoint"
8 "github.com/go-kit/kit/log"
9 "github.com/streadway/amqp"
10 )
11
12 // Subscriber wraps an endpoint and provides a handler for AMQP Delivery messages.
13 type Subscriber struct {
14 e endpoint.Endpoint
15 dec DecodeRequestFunc
16 enc EncodeResponseFunc
17 before []RequestFunc
18 after []SubscriberResponseFunc
19 errorEncoder ErrorEncoder
20 logger log.Logger
21 }
22
23 // NewSubscriber constructs a new subscriber, which provides a handler
24 // for AMQP Delivery messages.
25 func NewSubscriber(
26 e endpoint.Endpoint,
27 dec DecodeRequestFunc,
28 enc EncodeResponseFunc,
29 options ...SubscriberOption,
30 ) *Subscriber {
31 s := &Subscriber{
32 e: e,
33 dec: dec,
34 enc: enc,
35 errorEncoder: DefaultErrorEncoder,
36 logger: log.NewNopLogger(),
37 }
38 for _, option := range options {
39 option(s)
40 }
41 return s
42 }
43
44 // SubscriberOption sets an optional parameter for subscribers.
45 type SubscriberOption func(*Subscriber)
46
47 // SubscriberBefore functions are executed on the publisher delivery object
48 // before the request is decoded.
49 func SubscriberBefore(before ...RequestFunc) SubscriberOption {
50 return func(s *Subscriber) { s.before = append(s.before, before...) }
51 }
52
53 // SubscriberAfter functions are executed on the subscriber reply after the
54 // endpoint is invoked, but before anything is published to the reply.
55 func SubscriberAfter(after ...SubscriberResponseFunc) SubscriberOption {
56 return func(s *Subscriber) { s.after = append(s.after, after...) }
57 }
58
59 // SubscriberErrorEncoder is used to encode errors to the subscriber reply
60 // whenever they're encountered in the processing of a request. Clients can
61 // use this to provide custom error formatting. By default,
62 // errors will be published with the DefaultErrorEncoder.
63 func SubscriberErrorEncoder(ee ErrorEncoder) SubscriberOption {
64 return func(s *Subscriber) { s.errorEncoder = ee }
65 }
66
67 // SubscriberErrorLogger is used to log non-terminal errors. By default, no errors
68 // are logged. This is intended as a diagnostic measure. Finer-grained control
69 // of error handling, including logging in more detail, should be performed in a
70 // custom SubscriberErrorEncoder which has access to the context.
71 func SubscriberErrorLogger(logger log.Logger) SubscriberOption {
72 return func(s *Subscriber) { s.logger = logger }
73 }
74
75 // ServeDelivery handles AMQP Delivery messages
76 // It is strongly recommended to use *amqp.Channel as the
77 // Channel interface implementation.
78 func (s Subscriber) ServeDelivery(ch Channel) func(deliv *amqp.Delivery) {
79 return func(deliv *amqp.Delivery) {
80 ctx, cancel := context.WithCancel(context.Background())
81 defer cancel()
82
83 pub := amqp.Publishing{}
84
85 for _, f := range s.before {
86 ctx = f(ctx, &pub)
87 }
88
89 request, err := s.dec(ctx, deliv)
90 if err != nil {
91 s.logger.Log("err", err)
92 s.errorEncoder(ctx, err, deliv, ch, &pub)
93 return
94 }
95
96 response, err := s.e(ctx, request)
97 if err != nil {
98 s.logger.Log("err", err)
99 s.errorEncoder(ctx, err, deliv, ch, &pub)
100 return
101 }
102
103 for _, f := range s.after {
104 ctx = f(ctx, deliv, ch, &pub)
105 }
106
107 if err := s.enc(ctx, &pub, response); err != nil {
108 s.logger.Log("err", err)
109 s.errorEncoder(ctx, err, deliv, ch, &pub)
110 return
111 }
112
113 if err := s.publishResponse(ctx, deliv, ch, &pub); err != nil {
114 s.logger.Log("err", err)
115 s.errorEncoder(ctx, err, deliv, ch, &pub)
116 return
117 }
118 }
119
120 }
121
122 func (s Subscriber) publishResponse(
123 ctx context.Context,
124 deliv *amqp.Delivery,
125 ch Channel,
126 pub *amqp.Publishing,
127 ) error {
128 if pub.CorrelationId == "" {
129 pub.CorrelationId = deliv.CorrelationId
130 }
131
132 replyExchange := getPublishExchange(ctx)
133 replyTo := getPublishKey(ctx)
134 if replyTo == "" {
135 replyTo = deliv.ReplyTo
136 }
137
138 return ch.Publish(
139 replyExchange,
140 replyTo,
141 false, // mandatory
142 false, // immediate
143 *pub,
144 )
145 }
146
147 // EncodeJSONResponse marshals the response as JSON as part of the
148 // payload of the AMQP Publishing object.
149 func EncodeJSONResponse(
150 ctx context.Context,
151 pub *amqp.Publishing,
152 response interface{},
153 ) error {
154 b, err := json.Marshal(response)
155 if err != nil {
156 return err
157 }
158 pub.Body = b
159 return nil
160 }
161
162 // EncodeNopResponse is a response function that does nothing.
163 func EncodeNopResponse(
164 ctx context.Context,
165 pub *amqp.Publishing,
166 response interface{},
167 ) error {
168 return nil
169 }
170
171 // ErrorEncoder is responsible for encoding an error to the subscriber reply.
172 // Users are encouraged to use custom ErrorEncoders to encode errors to
173 // their replies, and will likely want to pass and check for their own error
174 // types.
175 type ErrorEncoder func(ctx context.Context,
176 err error, deliv *amqp.Delivery, ch Channel, pub *amqp.Publishing)
177
178 // DefaultErrorEncoder simply ignores the message. It does not reply
179 // nor Ack/Nack the message.
180 func DefaultErrorEncoder(ctx context.Context,
181 err error, deliv *amqp.Delivery, ch Channel, pub *amqp.Publishing) {
182 }
183
184 // SingleNackRequeueErrorEncoder issues a Nack to the delivery with multiple flag set as false
185 // and requeue flag set as true. It does not reply the message.
186 func SingleNackRequeueErrorEncoder(ctx context.Context,
187 err error, deliv *amqp.Delivery, ch Channel, pub *amqp.Publishing) {
188 deliv.Nack(
189 false, //multiple
190 true, //requeue
191 )
192 duration := getNackSleepDuration(ctx)
193 time.Sleep(duration)
194 }
195
196 // ReplyErrorEncoder serializes the error message as a DefaultErrorResponse
197 // JSON and sends the message to the ReplyTo address.
198 func ReplyErrorEncoder(
199 ctx context.Context,
200 err error,
201 deliv *amqp.Delivery,
202 ch Channel,
203 pub *amqp.Publishing,
204 ) {
205
206 if pub.CorrelationId == "" {
207 pub.CorrelationId = deliv.CorrelationId
208 }
209
210 replyExchange := getPublishExchange(ctx)
211 replyTo := getPublishKey(ctx)
212 if replyTo == "" {
213 replyTo = deliv.ReplyTo
214 }
215
216 response := DefaultErrorResponse{err.Error()}
217
218 b, err := json.Marshal(response)
219 if err != nil {
220 return
221 }
222 pub.Body = b
223
224 ch.Publish(
225 replyExchange,
226 replyTo,
227 false, // mandatory
228 false, // immediate
229 *pub,
230 )
231 }
232
233 // ReplyAndAckErrorEncoder serializes the error message as a DefaultErrorResponse
234 // JSON and sends the message to the ReplyTo address then Acks the original
235 // message.
236 func ReplyAndAckErrorEncoder(ctx context.Context, err error, deliv *amqp.Delivery, ch Channel, pub *amqp.Publishing) {
237 ReplyErrorEncoder(ctx, err, deliv, ch, pub)
238 deliv.Ack(false)
239 }
240
241 // DefaultErrorResponse is the default structure of responses in the event
242 // of an error.
243 type DefaultErrorResponse struct {
244 Error string `json:"err"`
245 }
246
247 // Channel is a channel interface to make testing possible.
248 // It is highly recommended to use *amqp.Channel as the interface implementation.
249 type Channel interface {
250 Publish(exchange, key string, mandatory, immediate bool, msg amqp.Publishing) error
251 Consume(queue, consumer string, autoAck, exclusive, noLocal, noWail bool, args amqp.Table) (<-chan amqp.Delivery, error)
252 }
0 package amqp_test
1
2 import (
3 "context"
4 "encoding/json"
5 "errors"
6 "testing"
7 "time"
8
9 amqptransport "github.com/go-kit/kit/transport/amqp"
10 "github.com/streadway/amqp"
11 )
12
13 var (
14 typeAssertionError = errors.New("type assertion error")
15 )
16
17 // mockChannel is a mock of *amqp.Channel.
18 type mockChannel struct {
19 f func(exchange, key string, mandatory, immediate bool)
20 c chan<- amqp.Publishing
21 deliveries []amqp.Delivery
22 }
23
24 // Publish runs a test function f and sends resultant message to a channel.
25 func (ch *mockChannel) Publish(exchange, key string, mandatory, immediate bool, msg amqp.Publishing) error {
26 ch.f(exchange, key, mandatory, immediate)
27 ch.c <- msg
28 return nil
29 }
30
31 var nullFunc = func(exchange, key string, mandatory, immediate bool) {
32 }
33
34 func (ch *mockChannel) Consume(queue, consumer string, autoAck, exclusive, noLocal, noWail bool, args amqp.Table) (<-chan amqp.Delivery, error) {
35 c := make(chan amqp.Delivery, len(ch.deliveries))
36 for _, d := range ch.deliveries {
37 c <- d
38 }
39 return c, nil
40 }
41
42 // TestSubscriberBadDecode checks if decoder errors are handled properly.
43 func TestSubscriberBadDecode(t *testing.T) {
44 sub := amqptransport.NewSubscriber(
45 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
46 func(context.Context, *amqp.Delivery) (interface{}, error) { return nil, errors.New("err!") },
47 func(context.Context, *amqp.Publishing, interface{}) error {
48 return nil
49 },
50 amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder),
51 )
52
53 outputChan := make(chan amqp.Publishing, 1)
54 ch := &mockChannel{f: nullFunc, c: outputChan}
55 sub.ServeDelivery(ch)(&amqp.Delivery{})
56
57 var msg amqp.Publishing
58 select {
59 case msg = <-outputChan:
60 break
61
62 case <-time.After(100 * time.Millisecond):
63 t.Fatal("Timed out waiting for publishing")
64 }
65 res, err := decodeSubscriberError(msg)
66 if err != nil {
67 t.Fatal(err)
68 }
69 if want, have := "err!", res.Error; want != have {
70 t.Errorf("want %s, have %s", want, have)
71 }
72 }
73
74 // TestSubscriberBadEndpoint checks if endpoint errors are handled properly.
75 func TestSubscriberBadEndpoint(t *testing.T) {
76 sub := amqptransport.NewSubscriber(
77 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("err!") },
78 func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil },
79 func(context.Context, *amqp.Publishing, interface{}) error {
80 return nil
81 },
82 amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder),
83 )
84
85 outputChan := make(chan amqp.Publishing, 1)
86 ch := &mockChannel{f: nullFunc, c: outputChan}
87 sub.ServeDelivery(ch)(&amqp.Delivery{})
88
89 var msg amqp.Publishing
90
91 select {
92 case msg = <-outputChan:
93 break
94
95 case <-time.After(100 * time.Millisecond):
96 t.Fatal("Timed out waiting for publishing")
97 }
98
99 res, err := decodeSubscriberError(msg)
100 if err != nil {
101 t.Fatal(err)
102 }
103 if want, have := "err!", res.Error; want != have {
104 t.Errorf("want %s, have %s", want, have)
105 }
106 }
107
108 // TestSubscriberBadEncoder checks if encoder errors are handled properly.
109 func TestSubscriberBadEncoder(t *testing.T) {
110 sub := amqptransport.NewSubscriber(
111 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
112 func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil },
113 func(context.Context, *amqp.Publishing, interface{}) error {
114 return errors.New("err!")
115 },
116 amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder),
117 )
118
119 outputChan := make(chan amqp.Publishing, 1)
120 ch := &mockChannel{f: nullFunc, c: outputChan}
121 sub.ServeDelivery(ch)(&amqp.Delivery{})
122
123 var msg amqp.Publishing
124
125 select {
126 case msg = <-outputChan:
127 break
128
129 case <-time.After(100 * time.Millisecond):
130 t.Fatal("Timed out waiting for publishing")
131 }
132
133 res, err := decodeSubscriberError(msg)
134 if err != nil {
135 t.Fatal(err)
136 }
137 if want, have := "err!", res.Error; want != have {
138 t.Errorf("want %s, have %s", want, have)
139 }
140 }
141
142 // TestSubscriberSuccess checks if CorrelationId and ReplyTo are set properly
143 // and if the payload is encoded properly.
144 func TestSubscriberSuccess(t *testing.T) {
145 cid := "correlation"
146 replyTo := "sender"
147 obj := testReq{
148 Squadron: 436,
149 }
150 b, err := json.Marshal(obj)
151 if err != nil {
152 t.Fatal(err)
153 }
154
155 sub := amqptransport.NewSubscriber(
156 testEndpoint,
157 testReqDecoder,
158 amqptransport.EncodeJSONResponse,
159 amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder),
160 )
161
162 checkReplyToFunc := func(exchange, key string, mandatory, immediate bool) {
163 if want, have := replyTo, key; want != have {
164 t.Errorf("want %s, have %s", want, have)
165 }
166 }
167
168 outputChan := make(chan amqp.Publishing, 1)
169 ch := &mockChannel{f: checkReplyToFunc, c: outputChan}
170 sub.ServeDelivery(ch)(&amqp.Delivery{
171 CorrelationId: cid,
172 ReplyTo: replyTo,
173 Body: b,
174 })
175
176 var msg amqp.Publishing
177
178 select {
179 case msg = <-outputChan:
180 break
181
182 case <-time.After(100 * time.Millisecond):
183 t.Fatal("Timed out waiting for publishing")
184 }
185
186 if want, have := cid, msg.CorrelationId; want != have {
187 t.Errorf("want %s, have %s", want, have)
188 }
189
190 // check if error is not thrown
191 errRes, err := decodeSubscriberError(msg)
192 if err != nil {
193 t.Fatal(err)
194 }
195 if errRes.Error != "" {
196 t.Error("Received error from subscriber", errRes.Error)
197 return
198 }
199
200 // check obj vals
201 response, err := testResDecoder(msg.Body)
202 if err != nil {
203 t.Fatal(err)
204 }
205 res, ok := response.(testRes)
206 if !ok {
207 t.Error(typeAssertionError)
208 }
209
210 if want, have := obj.Squadron, res.Squadron; want != have {
211 t.Errorf("want %d, have %d", want, have)
212 }
213 if want, have := names[obj.Squadron], res.Name; want != have {
214 t.Errorf("want %s, have %s", want, have)
215 }
216 }
217
218 // TestSubscriberMultipleBefore checks if options to set exchange, key, deliveryMode
219 // are working.
220 func TestSubscriberMultipleBefore(t *testing.T) {
221 exchange := "some exchange"
222 key := "some key"
223 deliveryMode := uint8(127)
224 contentType := "some content type"
225 contentEncoding := "some content encoding"
226 sub := amqptransport.NewSubscriber(
227 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
228 func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil },
229 amqptransport.EncodeJSONResponse,
230 amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder),
231 amqptransport.SubscriberBefore(
232 amqptransport.SetPublishExchange(exchange),
233 amqptransport.SetPublishKey(key),
234 amqptransport.SetPublishDeliveryMode(deliveryMode),
235 amqptransport.SetContentType(contentType),
236 amqptransport.SetContentEncoding(contentEncoding),
237 ),
238 )
239 checkReplyToFunc := func(exch, k string, mandatory, immediate bool) {
240 if want, have := exchange, exch; want != have {
241 t.Errorf("want %s, have %s", want, have)
242 }
243 if want, have := key, k; want != have {
244 t.Errorf("want %s, have %s", want, have)
245 }
246 }
247
248 outputChan := make(chan amqp.Publishing, 1)
249 ch := &mockChannel{f: checkReplyToFunc, c: outputChan}
250 sub.ServeDelivery(ch)(&amqp.Delivery{})
251
252 var msg amqp.Publishing
253
254 select {
255 case msg = <-outputChan:
256 break
257
258 case <-time.After(100 * time.Millisecond):
259 t.Fatal("Timed out waiting for publishing")
260 }
261
262 // check if error is not thrown
263 errRes, err := decodeSubscriberError(msg)
264 if err != nil {
265 t.Fatal(err)
266 }
267 if errRes.Error != "" {
268 t.Error("Received error from subscriber", errRes.Error)
269 return
270 }
271
272 if want, have := contentType, msg.ContentType; want != have {
273 t.Errorf("want %s, have %s", want, have)
274 }
275
276 if want, have := contentEncoding, msg.ContentEncoding; want != have {
277 t.Errorf("want %s, have %s", want, have)
278 }
279
280 if want, have := deliveryMode, msg.DeliveryMode; want != have {
281 t.Errorf("want %d, have %d", want, have)
282 }
283 }
284
285 // TestDefaultContentMetaData checks that default ContentType and Content-Encoding
286 // is not set as mentioned by AMQP specification.
287 func TestDefaultContentMetaData(t *testing.T) {
288 defaultContentType := ""
289 defaultContentEncoding := ""
290 sub := amqptransport.NewSubscriber(
291 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
292 func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil },
293 amqptransport.EncodeJSONResponse,
294 amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder),
295 )
296 checkReplyToFunc := func(exch, k string, mandatory, immediate bool) { return }
297 outputChan := make(chan amqp.Publishing, 1)
298 ch := &mockChannel{f: checkReplyToFunc, c: outputChan}
299 sub.ServeDelivery(ch)(&amqp.Delivery{})
300
301 var msg amqp.Publishing
302
303 select {
304 case msg = <-outputChan:
305 break
306
307 case <-time.After(100 * time.Millisecond):
308 t.Fatal("Timed out waiting for publishing")
309 }
310
311 // check if error is not thrown
312 errRes, err := decodeSubscriberError(msg)
313 if err != nil {
314 t.Fatal(err)
315 }
316 if errRes.Error != "" {
317 t.Error("Received error from subscriber", errRes.Error)
318 return
319 }
320
321 if want, have := defaultContentType, msg.ContentType; want != have {
322 t.Errorf("want %s, have %s", want, have)
323 }
324 if want, have := defaultContentEncoding, msg.ContentEncoding; want != have {
325 t.Errorf("want %s, have %s", want, have)
326 }
327 }
328
329 func decodeSubscriberError(pub amqp.Publishing) (amqptransport.DefaultErrorResponse, error) {
330 var res amqptransport.DefaultErrorResponse
331 err := json.Unmarshal(pub.Body, &res)
332 return res, err
333 }
334
335 type testReq struct {
336 Squadron int `json:"s"`
337 }
338 type testRes struct {
339 Squadron int `json:"s"`
340 Name string `json:"n"`
341 }
342
343 func testEndpoint(_ context.Context, request interface{}) (interface{}, error) {
344 req, ok := request.(testReq)
345 if !ok {
346 return nil, typeAssertionError
347 }
348 name, prs := names[req.Squadron]
349 if !prs {
350 return nil, errors.New("unknown squadron name")
351 }
352 res := testRes{
353 Squadron: req.Squadron,
354 Name: name,
355 }
356 return res, nil
357 }
358
359 func testReqDecoder(_ context.Context, d *amqp.Delivery) (interface{}, error) {
360 var obj testReq
361 err := json.Unmarshal(d.Body, &obj)
362 return obj, err
363 }
364
365 func testReqEncoder(_ context.Context, p *amqp.Publishing, request interface{}) error {
366 req, ok := request.(testReq)
367 if !ok {
368 return errors.New("type assertion failure")
369 }
370 b, err := json.Marshal(req)
371 if err != nil {
372 return err
373 }
374 p.Body = b
375 return nil
376 }
377
378 func testResDeliveryDecoder(_ context.Context, d *amqp.Delivery) (interface{}, error) {
379 return testResDecoder(d.Body)
380 }
381
382 func testResDecoder(b []byte) (interface{}, error) {
383 var obj testRes
384 err := json.Unmarshal(b, &obj)
385 return obj, err
386 }
387
388 var names = map[int]string{
389 424: "tiger",
390 426: "thunderbird",
391 429: "bison",
392 436: "tusker",
393 437: "husky",
394 }
0 package amqp
1
2 import (
3 "math/rand"
4 )
5
6 func randomString(l int) string {
7 bytes := make([]byte, l)
8 for i := 0; i < l; i++ {
9 bytes[i] = byte(randInt(65, 90))
10 }
11 return string(bytes)
12 }
13
14 func randInt(min int, max int) int {
15 return min + rand.Intn(max-min)
16 }