Added RetryWithCallback, allowing retries to call a function before iterating for early termination.
Morgan Hein
7 years ago
0 | 0 | package lb |
1 | 1 | |
2 | 2 | import ( |
3 | "fmt" | |
4 | "strings" | |
5 | "time" | |
3 | "fmt" | |
4 | "strings" | |
5 | "time" | |
6 | 6 | |
7 | "golang.org/x/net/context" | |
7 | "golang.org/x/net/context" | |
8 | 8 | |
9 | "github.com/go-kit/kit/endpoint" | |
9 | "github.com/go-kit/kit/endpoint" | |
10 | 10 | ) |
11 | ||
// callback is the hook RetryWithCallback invokes after a failed attempt.
//
// attempt is the current attempt count and errMsg is the text of the error
// that attempt produced. The first result reports whether retrying should
// continue and must always be meaningful. The second result is an optional
// replacement error message: it may be nil, and whenever it is non-nil it
// replaces the current error.
type callback func(attempt int, errMsg string) (keepTrying bool, replacement *string)
11 | 18 | |
12 | 19 | // Retry wraps a service load balancer and returns an endpoint oriented load |
13 | 20 | // balancer for the specified service method. |
15 | 22 | // balancer. Requests that return errors will be retried until they succeed, |
16 | 23 | // up to max times, or until the timeout is elapsed, whichever comes first. |
17 | 24 | func Retry(max int, timeout time.Duration, b Balancer) endpoint.Endpoint { |
18 | if b == nil { | |
19 | panic("nil Balancer") | |
20 | } | |
21 | return func(ctx context.Context, request interface{}) (response interface{}, err error) { | |
22 | var ( | |
23 | newctx, cancel = context.WithTimeout(ctx, timeout) | |
24 | responses = make(chan interface{}, 1) | |
25 | errs = make(chan error, 1) | |
26 | a = []string{} | |
27 | ) | |
28 | defer cancel() | |
29 | for i := 1; i <= max; i++ { | |
30 | go func() { | |
31 | e, err := b.Endpoint() | |
32 | if err != nil { | |
33 | errs <- err | |
34 | return | |
35 | } | |
36 | response, err := e(newctx, request) | |
37 | if err != nil { | |
38 | errs <- err | |
39 | return | |
40 | } | |
41 | responses <- response | |
42 | }() | |
25 | if b == nil { | |
26 | panic("nil Balancer") | |
27 | } | |
28 | return func(ctx context.Context, request interface{}) (response interface{}, err error) { | |
29 | var ( | |
30 | newctx, cancel = context.WithTimeout(ctx, timeout) | |
31 | responses = make(chan interface{}, 1) | |
32 | errs = make(chan error, 1) | |
33 | a = []string{} | |
34 | ) | |
35 | defer cancel() | |
36 | for i := 1; i <= max; i++ { | |
37 | go func() { | |
38 | e, err := b.Endpoint() | |
39 | if err != nil { | |
40 | errs <- err | |
41 | return | |
42 | } | |
43 | response, err := e(newctx, request) | |
44 | if err != nil { | |
45 | errs <- err | |
46 | return | |
47 | } | |
48 | responses <- response | |
49 | }() | |
43 | 50 | |
44 | select { | |
45 | case <-newctx.Done(): | |
46 | return nil, newctx.Err() | |
47 | case response := <-responses: | |
48 | return response, nil | |
49 | case err := <-errs: | |
50 | a = append(a, err.Error()) | |
51 | continue | |
52 | } | |
53 | } | |
54 | return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; ")) | |
55 | } | |
51 | select { | |
52 | case <-newctx.Done(): | |
53 | return nil, newctx.Err() | |
54 | case response := <-responses: | |
55 | return response, nil | |
56 | case err := <-errs: | |
57 | a = append(a, err.Error()) | |
58 | continue | |
59 | } | |
60 | } | |
61 | return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; ")) | |
62 | } | |
56 | 63 | } |
64 | ||
65 | func RetryWithCallback(max int, timeout time.Duration, b Balancer, cb callback) endpoint.Endpoint { | |
66 | if cb == nil { | |
67 | panic("nil callback") | |
68 | } | |
69 | if b == nil { | |
70 | panic("nil Balancer") | |
71 | } | |
72 | return func(ctx context.Context, request interface{}) (response interface{}, err error) { | |
73 | var ( | |
74 | newctx, cancel = context.WithTimeout(ctx, timeout) | |
75 | responses = make(chan interface{}, 1) | |
76 | errs = make(chan error, 1) | |
77 | a = []string{} | |
78 | ) | |
79 | defer cancel() | |
80 | for i := 1; i <= max; i++ { | |
81 | go func() { | |
82 | e, err := b.Endpoint() | |
83 | if err != nil { | |
84 | errs <- err | |
85 | return | |
86 | } | |
87 | response, err := e(newctx, request) | |
88 | if err != nil { | |
89 | errs <- err | |
90 | return | |
91 | } | |
92 | responses <- response | |
93 | }() | |
94 | ||
95 | select { | |
96 | case <-newctx.Done(): | |
97 | return nil, newctx.Err() | |
98 | case response := <-responses: | |
99 | return response, nil | |
100 | case err := <-errs: | |
101 | cont, cbErr := cb(i, err.Error()) | |
102 | if !cont { | |
103 | if cbErr == nil { | |
104 | return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; ")) | |
105 | } else { | |
106 | return nil, fmt.Errorf(*cbErr) | |
107 | } | |
108 | } | |
109 | currentErr := err.Error() | |
110 | if cbErr != nil { | |
111 | currentErr = *cbErr | |
112 | } | |
113 | a = append(a, currentErr) | |
114 | continue | |
115 | } | |
116 | } | |
117 | return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; ")) | |
118 | } | |
119 | } |
0 | package lb_test | |
1 | ||
2 | import ( | |
3 | "errors" | |
4 | "testing" | |
5 | "time" | |
6 | ||
7 | "golang.org/x/net/context" | |
8 | ||
9 | "github.com/go-kit/kit/endpoint" | |
10 | "github.com/go-kit/kit/sd" | |
11 | loadbalancer "github.com/go-kit/kit/sd/lb" | |
12 | ) | |
13 | ||
14 | func TestRetryMaxTotalFail_WCB(t *testing.T) { | |
15 | var ( | |
16 | cb = func(count int, msg string) (bool, *string) { return true, nil } | |
17 | endpoints = sd.FixedSubscriber{} // no endpoints | |
18 | lb = loadbalancer.NewRoundRobin(endpoints) | |
19 | retry = loadbalancer.RetryWithCallback(999, time.Second, lb, cb) // lots of retries | |
20 | ctx = context.Background() | |
21 | ) | |
22 | if _, err := retry(ctx, struct{}{}); err == nil { | |
23 | t.Errorf("expected error, got none") // should fail | |
24 | } | |
25 | } | |
26 | ||
27 | func TestRetryMaxPartialFail_WCB(t *testing.T) { | |
28 | var ( | |
29 | cb = func(count int, msg string) (bool, *string) { return true, nil } | |
30 | endpoints = []endpoint.Endpoint{ | |
31 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") }, | |
32 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") }, | |
33 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ }, | |
34 | } | |
35 | subscriber = sd.FixedSubscriber{ | |
36 | 0: endpoints[0], | |
37 | 1: endpoints[1], | |
38 | 2: endpoints[2], | |
39 | } | |
40 | retries = len(endpoints) - 1 // not quite enough retries | |
41 | lb = loadbalancer.NewRoundRobin(subscriber) | |
42 | ctx = context.Background() | |
43 | ) | |
44 | if _, err := loadbalancer.RetryWithCallback(retries, time.Second, lb, cb)(ctx, struct{}{}); err == nil { | |
45 | t.Errorf("expected error, got none") | |
46 | } | |
47 | } | |
48 | ||
49 | func TestRetryMaxSuccess_WCB(t *testing.T) { | |
50 | var ( | |
51 | cb = func(count int, msg string) (bool, *string) { return true, nil } | |
52 | endpoints = []endpoint.Endpoint{ | |
53 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") }, | |
54 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") }, | |
55 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ }, | |
56 | } | |
57 | subscriber = sd.FixedSubscriber{ | |
58 | 0: endpoints[0], | |
59 | 1: endpoints[1], | |
60 | 2: endpoints[2], | |
61 | } | |
62 | retries = len(endpoints) // exactly enough retries | |
63 | lb = loadbalancer.NewRoundRobin(subscriber) | |
64 | ctx = context.Background() | |
65 | ) | |
66 | if _, err := loadbalancer.RetryWithCallback(retries, time.Second, lb, cb)(ctx, struct{}{}); err != nil { | |
67 | t.Error(err) | |
68 | } | |
69 | } | |
70 | ||
71 | func TestRetryTimeout_WCB(t *testing.T) { | |
72 | var ( | |
73 | cb = func(count int, msg string) (bool, *string) { return true, nil } | |
74 | step = make(chan struct{}) | |
75 | e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil } | |
76 | timeout = time.Millisecond | |
77 | retry = loadbalancer.RetryWithCallback(999, timeout, loadbalancer.NewRoundRobin(sd.FixedSubscriber{0: e}), cb) | |
78 | errs = make(chan error, 1) | |
79 | invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err } | |
80 | ) | |
81 | ||
82 | go func() { step <- struct{}{} }() // queue up a flush of the endpoint | |
83 | invoke() // invoke the endpoint and trigger the flush | |
84 | if err := <-errs; err != nil { // that should succeed | |
85 | t.Error(err) | |
86 | } | |
87 | ||
88 | go func() { time.Sleep(10 * timeout); step <- struct{}{} }() // a delayed flush | |
89 | invoke() // invoke the endpoint | |
90 | if err := <-errs; err != context.DeadlineExceeded { // that should not succeed | |
91 | t.Errorf("wanted %v, got none", context.DeadlineExceeded) | |
92 | } | |
93 | } | |
94 | ||
95 | func AbortEarlyCustomMessage_WCB(t *testing.T) { | |
96 | var ( | |
97 | cb = func(count int, msg string) (bool, *string) { | |
98 | ret := "Aborting early" | |
99 | return false, &ret | |
100 | } | |
101 | endpoints = sd.FixedSubscriber{} // no endpoints | |
102 | lb = loadbalancer.NewRoundRobin(endpoints) | |
103 | retry = loadbalancer.RetryWithCallback(999, time.Second, lb, cb) // lots of retries | |
104 | ctx = context.Background() | |
105 | ) | |
106 | _, err := retry(ctx, struct{}{}) | |
107 | if err == nil { | |
108 | t.Errorf("expected error, got none") // should fail | |
109 | } | |
110 | if err.Error() != "Aborting early" { | |
111 | t.Errorf("expected custom error message, got %v", err) | |
112 | } | |
113 | } | |
114 | ||
115 | func AbortEarlyOnNTries_WCB(t *testing.T) { | |
116 | var ( | |
117 | cb = func(count int, msg string) (bool, *string) { | |
118 | if (count >= 4) { | |
119 | t.Errorf("expected retries to abort at 3 but continued to %v", count) | |
120 | } | |
121 | if (count == 3) { | |
122 | return false, nil | |
123 | } | |
124 | return true, nil | |
125 | } | |
126 | endpoints = sd.FixedSubscriber{} // no endpoints | |
127 | lb = loadbalancer.NewRoundRobin(endpoints) | |
128 | retry = loadbalancer.RetryWithCallback(999, time.Second, lb, cb) // lots of retries | |
129 | ctx = context.Background() | |
130 | ) | |
131 | if _, err := retry(ctx, struct{}{}); err == nil { | |
132 | t.Errorf("expected error, got none") // should fail | |
133 | } | |
134 | } |