Codebase list golang-github-go-kit-kit / f809d3d
Added RetyWithCallback allowing retries to call a function before iterating for early termination. Morgan Hein 7 years ago
2 changed file(s) with 240 addition(s) and 42 deletion(s). Raw diff Collapse all Expand all
00 package lb
11
22 import (
3 "fmt"
4 "strings"
5 "time"
3 "fmt"
4 "strings"
5 "time"
66
7 "golang.org/x/net/context"
7 "golang.org/x/net/context"
88
9 "github.com/go-kit/kit/endpoint"
9 "github.com/go-kit/kit/endpoint"
1010 )
11
12 // Callback function that indicates the current attempt count and the error encountered.
13 // Should return whether the Retry function should continue trying, and a custom
14 // error message if desired. The error message may be nil, but a true/false
15 // is always expected. In all cases if the error message is supplied, the
16 // current error will be replaced.
17 type callback func(int, string) (bool, *string)
1118
1219 // Retry wraps a service load balancer and returns an endpoint oriented load
1320 // balancer for the specified service method.
1522 // balancer. Requests that return errors will be retried until they succeed,
1623 // up to max times, or until the timeout is elapsed, whichever comes first.
1724 func Retry(max int, timeout time.Duration, b Balancer) endpoint.Endpoint {
18 if b == nil {
19 panic("nil Balancer")
20 }
21 return func(ctx context.Context, request interface{}) (response interface{}, err error) {
22 var (
23 newctx, cancel = context.WithTimeout(ctx, timeout)
24 responses = make(chan interface{}, 1)
25 errs = make(chan error, 1)
26 a = []string{}
27 )
28 defer cancel()
29 for i := 1; i <= max; i++ {
30 go func() {
31 e, err := b.Endpoint()
32 if err != nil {
33 errs <- err
34 return
35 }
36 response, err := e(newctx, request)
37 if err != nil {
38 errs <- err
39 return
40 }
41 responses <- response
42 }()
25 if b == nil {
26 panic("nil Balancer")
27 }
28 return func(ctx context.Context, request interface{}) (response interface{}, err error) {
29 var (
30 newctx, cancel = context.WithTimeout(ctx, timeout)
31 responses = make(chan interface{}, 1)
32 errs = make(chan error, 1)
33 a = []string{}
34 )
35 defer cancel()
36 for i := 1; i <= max; i++ {
37 go func() {
38 e, err := b.Endpoint()
39 if err != nil {
40 errs <- err
41 return
42 }
43 response, err := e(newctx, request)
44 if err != nil {
45 errs <- err
46 return
47 }
48 responses <- response
49 }()
4350
44 select {
45 case <-newctx.Done():
46 return nil, newctx.Err()
47 case response := <-responses:
48 return response, nil
49 case err := <-errs:
50 a = append(a, err.Error())
51 continue
52 }
53 }
54 return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; "))
55 }
51 select {
52 case <-newctx.Done():
53 return nil, newctx.Err()
54 case response := <-responses:
55 return response, nil
56 case err := <-errs:
57 a = append(a, err.Error())
58 continue
59 }
60 }
61 return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; "))
62 }
5663 }
64
65 func RetryWithCallback(max int, timeout time.Duration, b Balancer, cb callback) endpoint.Endpoint {
66 if cb == nil {
67 panic("nil callback")
68 }
69 if b == nil {
70 panic("nil Balancer")
71 }
72 return func(ctx context.Context, request interface{}) (response interface{}, err error) {
73 var (
74 newctx, cancel = context.WithTimeout(ctx, timeout)
75 responses = make(chan interface{}, 1)
76 errs = make(chan error, 1)
77 a = []string{}
78 )
79 defer cancel()
80 for i := 1; i <= max; i++ {
81 go func() {
82 e, err := b.Endpoint()
83 if err != nil {
84 errs <- err
85 return
86 }
87 response, err := e(newctx, request)
88 if err != nil {
89 errs <- err
90 return
91 }
92 responses <- response
93 }()
94
95 select {
96 case <-newctx.Done():
97 return nil, newctx.Err()
98 case response := <-responses:
99 return response, nil
100 case err := <-errs:
101 cont, cbErr := cb(i, err.Error())
102 if !cont {
103 if cbErr == nil {
104 return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; "))
105 } else {
106 return nil, fmt.Errorf(*cbErr)
107 }
108 }
109 currentErr := err.Error()
110 if cbErr != nil {
111 currentErr = *cbErr
112 }
113 a = append(a, currentErr)
114 continue
115 }
116 }
117 return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; "))
118 }
119 }
0 package lb_test
1
2 import (
3 "errors"
4 "testing"
5 "time"
6
7 "golang.org/x/net/context"
8
9 "github.com/go-kit/kit/endpoint"
10 "github.com/go-kit/kit/sd"
11 loadbalancer "github.com/go-kit/kit/sd/lb"
12 )
13
14 func TestRetryMaxTotalFail_WCB(t *testing.T) {
15 var (
16 cb = func(count int, msg string) (bool, *string) { return true, nil }
17 endpoints = sd.FixedSubscriber{} // no endpoints
18 lb = loadbalancer.NewRoundRobin(endpoints)
19 retry = loadbalancer.RetryWithCallback(999, time.Second, lb, cb) // lots of retries
20 ctx = context.Background()
21 )
22 if _, err := retry(ctx, struct{}{}); err == nil {
23 t.Errorf("expected error, got none") // should fail
24 }
25 }
26
27 func TestRetryMaxPartialFail_WCB(t *testing.T) {
28 var (
29 cb = func(count int, msg string) (bool, *string) { return true, nil }
30 endpoints = []endpoint.Endpoint{
31 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") },
32 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") },
33 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
34 }
35 subscriber = sd.FixedSubscriber{
36 0: endpoints[0],
37 1: endpoints[1],
38 2: endpoints[2],
39 }
40 retries = len(endpoints) - 1 // not quite enough retries
41 lb = loadbalancer.NewRoundRobin(subscriber)
42 ctx = context.Background()
43 )
44 if _, err := loadbalancer.RetryWithCallback(retries, time.Second, lb, cb)(ctx, struct{}{}); err == nil {
45 t.Errorf("expected error, got none")
46 }
47 }
48
49 func TestRetryMaxSuccess_WCB(t *testing.T) {
50 var (
51 cb = func(count int, msg string) (bool, *string) { return true, nil }
52 endpoints = []endpoint.Endpoint{
53 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") },
54 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") },
55 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
56 }
57 subscriber = sd.FixedSubscriber{
58 0: endpoints[0],
59 1: endpoints[1],
60 2: endpoints[2],
61 }
62 retries = len(endpoints) // exactly enough retries
63 lb = loadbalancer.NewRoundRobin(subscriber)
64 ctx = context.Background()
65 )
66 if _, err := loadbalancer.RetryWithCallback(retries, time.Second, lb, cb)(ctx, struct{}{}); err != nil {
67 t.Error(err)
68 }
69 }
70
71 func TestRetryTimeout_WCB(t *testing.T) {
72 var (
73 cb = func(count int, msg string) (bool, *string) { return true, nil }
74 step = make(chan struct{})
75 e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil }
76 timeout = time.Millisecond
77 retry = loadbalancer.RetryWithCallback(999, timeout, loadbalancer.NewRoundRobin(sd.FixedSubscriber{0: e}), cb)
78 errs = make(chan error, 1)
79 invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err }
80 )
81
82 go func() { step <- struct{}{} }() // queue up a flush of the endpoint
83 invoke() // invoke the endpoint and trigger the flush
84 if err := <-errs; err != nil { // that should succeed
85 t.Error(err)
86 }
87
88 go func() { time.Sleep(10 * timeout); step <- struct{}{} }() // a delayed flush
89 invoke() // invoke the endpoint
90 if err := <-errs; err != context.DeadlineExceeded { // that should not succeed
91 t.Errorf("wanted %v, got none", context.DeadlineExceeded)
92 }
93 }
94
95 func AbortEarlyCustomMessage_WCB(t *testing.T) {
96 var (
97 cb = func(count int, msg string) (bool, *string) {
98 ret := "Aborting early"
99 return false, &ret
100 }
101 endpoints = sd.FixedSubscriber{} // no endpoints
102 lb = loadbalancer.NewRoundRobin(endpoints)
103 retry = loadbalancer.RetryWithCallback(999, time.Second, lb, cb) // lots of retries
104 ctx = context.Background()
105 )
106 _, err := retry(ctx, struct{}{})
107 if err == nil {
108 t.Errorf("expected error, got none") // should fail
109 }
110 if err.Error() != "Aborting early" {
111 t.Errorf("expected custom error message, got %v", err)
112 }
113 }
114
115 func AbortEarlyOnNTries_WCB(t *testing.T) {
116 var (
117 cb = func(count int, msg string) (bool, *string) {
118 if (count >= 4) {
119 t.Errorf("expected retries to abort at 3 but continued to %v", count)
120 }
121 if (count == 3) {
122 return false, nil
123 }
124 return true, nil
125 }
126 endpoints = sd.FixedSubscriber{} // no endpoints
127 lb = loadbalancer.NewRoundRobin(endpoints)
128 retry = loadbalancer.RetryWithCallback(999, time.Second, lb, cb) // lots of retries
129 ctx = context.Background()
130 )
131 if _, err := retry(ctx, struct{}{}); err == nil {
132 t.Errorf("expected error, got none") // should fail
133 }
134 }