Merge pull request #370 from rossmcf/retry_cb
Retry callbacks passing errors.
Peter Bourgon authored 7 years ago
GitHub committed 7 years ago
9 | 9 | "github.com/go-kit/kit/endpoint" |
10 | 10 | ) |
11 | 11 | |
// RetryError is an error wrapper that is used by the retry mechanism. Errors
// produced when the retry callback stops further attempts are returned as
// RetryErrors; when the retry context is done first, the context's own error
// is returned instead.
type RetryError struct {
	RawErrors []error // all errors encountered from endpoints directly
	Final     error   // the final, terminating error
}

// Error renders the final error. When more than one raw error was recorded,
// every raw error except the last is appended in a "(previously: ...)" suffix;
// the last one is left out because it is normally the error that became Final.
// Nil entries are rendered as "<nil>" instead of being dereferenced.
func (e RetryError) Error() string {
	if len(e.RawErrors) <= 1 {
		return fmt.Sprintf("%v", e.Final)
	}
	earlier := e.RawErrors[:len(e.RawErrors)-1]
	msgs := make([]string, len(earlier))
	for i, err := range earlier {
		// fmt.Sprint uses err.Error() for non-nil errors and tolerates nil,
		// matching how Final is formatted with %v below.
		msgs[i] = fmt.Sprint(err)
	}
	return fmt.Sprintf("%v (previously: %s)", e.Final, strings.Join(msgs, "; "))
}
30 | ||
31 | // Callback is a function that is given the current attempt count (counted from | |
32 | // 1) and the error received from the underlying endpoint. It reports whether | |
33 | // the retry loop should keep trying to get a working endpoint, and may return a | |
34 | // replacement error if desired. The replacement error may be nil, but the | |
35 | // keepTrying result is always used. Whenever a non-nil replacement is returned, | |
36 | // it takes the place of the received error as the final error if retrying stops. | |
37 | type Callback func(n int, received error) (keepTrying bool, replacement error) | |
38 | ||
12 | 39 | // Retry wraps a service load balancer and returns an endpoint oriented load |
13 | // balancer for the specified service method. | |
14 | // Requests to the endpoint will be automatically load balanced via the load | |
15 | // balancer. Requests that return errors will be retried until they succeed, | |
16 | // up to max times, or until the timeout is elapsed, whichever comes first. | |
40 | // balancer for the specified service method. Requests to the endpoint will be | |
41 | // automatically load balanced via the load balancer. Requests that return | |
42 | // errors will be retried until they succeed, up to max times, or until the | |
43 | // timeout is elapsed, whichever comes first. | |
17 | 44 | func Retry(max int, timeout time.Duration, b Balancer) endpoint.Endpoint { |
45 | return RetryWithCallback(timeout, b, maxRetries(max)) | |
46 | } | |
47 | ||
48 | func maxRetries(max int) Callback { | |
49 | return func(n int, err error) (keepTrying bool, replacement error) { | |
50 | return n < max, nil | |
51 | } | |
52 | } | |
53 | ||
// alwaysRetry is a Callback that ignores its arguments, always asks for
// another attempt, and never supplies a replacement error.
func alwaysRetry(_ int, _ error) (bool, error) {
	const keepGoing = true
	return keepGoing, nil
}
57 | ||
58 | // RetryWithCallback wraps a service load balancer and returns an endpoint | |
59 | // oriented load balancer for the specified service method. Requests to the | |
60 | // endpoint will be automatically load balanced via the load balancer. Requests | |
61 | // that return errors will be retried until they succeed, until the callback | |
62 | // returns false, or until the timeout elapses, whichever comes first. A nil | |
63 | // callback retries on every error; a nil Balancer causes a panic. | |
64 | func RetryWithCallback(timeout time.Duration, b Balancer, cb Callback) endpoint.Endpoint { | |
65 | if cb == nil { | |
66 | cb = alwaysRetry | |
67 | } | |
18 | 68 | if b == nil { |
19 | 69 | panic("nil Balancer") |
20 | 70 | } |
71 | ||
21 | 72 | return func(ctx context.Context, request interface{}) (response interface{}, err error) { |
22 | 73 | var ( |
23 | 74 | newctx, cancel = context.WithTimeout(ctx, timeout) |
24 | 75 | responses = make(chan interface{}, 1) |
25 | 76 | errs = make(chan error, 1) |
26 | a = []string{} | |
77 | final RetryError | |
27 | 78 | ) |
28 | 79 | defer cancel() |
29 | for i := 1; i <= max; i++ { | |
80 | ||
81 | for i := 1; ; i++ { | |
30 | 82 | go func() { |
31 | 83 | e, err := b.Endpoint() |
32 | 84 | if err != nil { |
44 | 96 | select { |
45 | 97 | case <-newctx.Done(): |
46 | 98 | return nil, newctx.Err() |
99 | ||
47 | 100 | case response := <-responses: |
48 | 101 | return response, nil |
102 | ||
49 | 103 | case err := <-errs: |
50 | a = append(a, err.Error()) | |
104 | final.RawErrors = append(final.RawErrors, err) | |
105 | keepTrying, replacement := cb(i, err) | |
106 | if replacement != nil { | |
107 | err = replacement | |
108 | } | |
109 | if !keepTrying { | |
110 | final.Final = err | |
111 | return nil, final | |
112 | } | |
51 | 113 | continue |
52 | 114 | } |
53 | 115 | } |
54 | return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; ")) | |
55 | 116 | } |
56 | 117 | } |
8 | 8 | |
9 | 9 | "github.com/go-kit/kit/endpoint" |
10 | 10 | "github.com/go-kit/kit/sd" |
11 | loadbalancer "github.com/go-kit/kit/sd/lb" | |
11 | "github.com/go-kit/kit/sd/lb" | |
12 | 12 | ) |
13 | 13 | |
14 | 14 | func TestRetryMaxTotalFail(t *testing.T) { |
15 | 15 | var ( |
16 | 16 | endpoints = sd.FixedSubscriber{} // no endpoints |
17 | lb = loadbalancer.NewRoundRobin(endpoints) | |
18 | retry = loadbalancer.Retry(999, time.Second, lb) // lots of retries | |
17 | rr = lb.NewRoundRobin(endpoints) | |
18 | retry = lb.Retry(999, time.Second, rr) // lots of retries | |
19 | 19 | ctx = context.Background() |
20 | 20 | ) |
21 | 21 | if _, err := retry(ctx, struct{}{}); err == nil { |
36 | 36 | 2: endpoints[2], |
37 | 37 | } |
38 | 38 | retries = len(endpoints) - 1 // not quite enough retries |
39 | lb = loadbalancer.NewRoundRobin(subscriber) | |
39 | rr = lb.NewRoundRobin(subscriber) | |
40 | 40 | ctx = context.Background() |
41 | 41 | ) |
42 | if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err == nil { | |
43 | t.Errorf("expected error, got none") | |
42 | if _, err := lb.Retry(retries, time.Second, rr)(ctx, struct{}{}); err == nil { | |
43 | t.Errorf("expected error two, got none") | |
44 | 44 | } |
45 | 45 | } |
46 | 46 | |
57 | 57 | 2: endpoints[2], |
58 | 58 | } |
59 | 59 | retries = len(endpoints) // exactly enough retries |
60 | lb = loadbalancer.NewRoundRobin(subscriber) | |
60 | rr = lb.NewRoundRobin(subscriber) | |
61 | 61 | ctx = context.Background() |
62 | 62 | ) |
63 | if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err != nil { | |
63 | if _, err := lb.Retry(retries, time.Second, rr)(ctx, struct{}{}); err != nil { | |
64 | 64 | t.Error(err) |
65 | 65 | } |
66 | 66 | } |
70 | 70 | step = make(chan struct{}) |
71 | 71 | e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil } |
72 | 72 | timeout = time.Millisecond |
73 | retry = loadbalancer.Retry(999, timeout, loadbalancer.NewRoundRobin(sd.FixedSubscriber{0: e})) | |
73 | retry = lb.Retry(999, timeout, lb.NewRoundRobin(sd.FixedSubscriber{0: e})) | |
74 | 74 | errs = make(chan error, 1) |
75 | 75 | invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err } |
76 | 76 | ) |
87 | 87 | t.Errorf("wanted %v, got none", context.DeadlineExceeded) |
88 | 88 | } |
89 | 89 | } |
90 | ||
91 | func TestAbortEarlyCustomMessage(t *testing.T) { | |
92 | var ( | |
93 | myErr = errors.New("aborting early") | |
94 | cb = func(int, error) (bool, error) { return false, myErr } | |
95 | endpoints = sd.FixedSubscriber{} // no endpoints | |
96 | rr = lb.NewRoundRobin(endpoints) | |
97 | retry = lb.RetryWithCallback(time.Second, rr, cb) // lots of retries | |
98 | ctx = context.Background() | |
99 | ) | |
100 | _, err := retry(ctx, struct{}{}) | |
101 | if want, have := myErr, err.(lb.RetryError).Final; want != have { | |
102 | t.Errorf("want %v, have %v", want, have) | |
103 | } | |
104 | } | |
105 | ||
106 | func TestErrorPassedUnchangedToCallback(t *testing.T) { | |
107 | var ( | |
108 | myErr = errors.New("my custom error") | |
109 | cb = func(_ int, err error) (bool, error) { | |
110 | if want, have := myErr, err; want != have { | |
111 | t.Errorf("want %v, have %v", want, have) | |
112 | } | |
113 | return false, nil | |
114 | } | |
115 | endpoint = func(ctx context.Context, request interface{}) (interface{}, error) { | |
116 | return nil, myErr | |
117 | } | |
118 | endpoints = sd.FixedSubscriber{endpoint} // no endpoints | |
119 | rr = lb.NewRoundRobin(endpoints) | |
120 | retry = lb.RetryWithCallback(time.Second, rr, cb) // lots of retries | |
121 | ctx = context.Background() | |
122 | ) | |
123 | _, err := retry(ctx, struct{}{}) | |
124 | if want, have := myErr, err.(lb.RetryError).Final; want != have { | |
125 | t.Errorf("want %v, have %v", want, have) | |
126 | } | |
127 | } | |
128 | ||
129 | func TestHandleNilCallback(t *testing.T) { | |
130 | var ( | |
131 | subscriber = sd.FixedSubscriber{ | |
132 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ }, | |
133 | } | |
134 | rr = lb.NewRoundRobin(subscriber) | |
135 | ctx = context.Background() | |
136 | ) | |
137 | retry := lb.RetryWithCallback(time.Second, rr, nil) | |
138 | if _, err := retry(ctx, struct{}{}); err != nil { | |
139 | t.Error(err) | |
140 | } | |
141 | } |