transport/http: upgrade error encoder
- Take a context.Context as a first parameter, in case there is some
necessary information there (h/t @thomshutt)
- Refactor BadRequestError as a TransportError
- Improve defaultErrorEncoder behavior
- Update tests
Peter Bourgon
8 years ago
3 | 3 | "fmt" |
4 | 4 | ) |
5 | 5 | |
6 | // These are some pre-generated constants that can be used to check against | |
7 | // for the DomainErrors. | |
8 | 6 | const ( |
9 | // DomainNewRequest represents an error at the Request Generation | |
10 | // Scope. | |
7 | // DomainNewRequest is an error during request generation. | |
11 | 8 | DomainNewRequest = "NewRequest" |
12 | 9 | |
13 | // DomainEncode represent an error that has occurred at the Encode | |
14 | // level of the request. | |
10 | // DomainEncode is an error during request or response encoding. | |
15 | 11 | DomainEncode = "Encode" |
16 | 12 | |
17 | // DomainDo represents an error that has occurred at the Do, or | |
18 | // execution phase of the request. | |
13 | // DomainDo is an error during the execution phase of the request. | |
19 | 14 | DomainDo = "Do" |
20 | 15 | |
21 | // DomainDecode represents an error that has occurred at the Decode | |
22 | // phase of the request. | |
16 | // DomainDecode is an error during request or response decoding. | |
23 | 17 | DomainDecode = "Decode" |
24 | 18 | ) |
25 | 19 | |
26 | // TransportError represents an Error occurred in the Client transport level. | |
20 | // TransportError is an error that occurred at some phase within the transport. | |
27 | 21 | type TransportError struct { |
28 | // Domain represents the domain of the error encountered. | |
29 | // Simply, this refers to the phase in which the error was | |
30 | // generated | |
22 | // Domain is the phase in which the error was generated. | |
31 | 23 | Domain string |
32 | 24 | |
33 | // Err references the underlying error that caused this error | |
34 | // overall. | |
25 | // Err is the concrete error. | |
35 | 26 | Err error |
36 | 27 | } |
37 | 28 | |
38 | // Error implements the error interface | |
29 | // Error implements the error interface. | |
39 | 30 | func (e TransportError) Error() string { |
40 | 31 | return fmt.Sprintf("%s: %s", e.Domain, e.Err) |
41 | 32 | } |
48 | 48 | |
49 | 49 | func ExampleErrOutput() { |
50 | 50 | sampleErr := errors.New("Oh no, an error") |
51 | err := httptransport.TransportError{"Do", sampleErr} | |
51 | err := httptransport.TransportError{Domain: httptransport.DomainDo, Err: sampleErr} | |
52 | 52 | fmt.Println(err) |
53 | 53 | // Output: |
54 | 54 | // Do: Oh no, an error |
16 | 16 | enc EncodeResponseFunc |
17 | 17 | before []RequestFunc |
18 | 18 | after []ResponseFunc |
19 | errorEncoder func(w http.ResponseWriter, err error) | |
19 | errorEncoder ErrorEncoder | |
20 | 20 | logger log.Logger |
21 | 21 | } |
22 | 22 | |
63 | 63 | // use this to provide custom error formatting and response codes. By default, |
64 | 64 | // errors will be written as plain text with an appropriate, if generic, |
65 | 65 | // status code. |
66 | func ServerErrorEncoder(f func(w http.ResponseWriter, err error)) ServerOption { | |
67 | return func(s *Server) { s.errorEncoder = f } | |
66 | func ServerErrorEncoder(ee ErrorEncoder) ServerOption { | |
67 | return func(s *Server) { s.errorEncoder = ee } | |
68 | 68 | } |
69 | 69 | |
70 | 70 | // ServerErrorLogger is used to log non-terminal errors. By default, no errors |
85 | 85 | request, err := s.dec(r) |
86 | 86 | if err != nil { |
87 | 87 | s.logger.Log("err", err) |
88 | s.errorEncoder(w, BadRequestError{err}) | |
88 | s.errorEncoder(ctx, TransportError{Domain: DomainDecode, Err: err}, w) | |
89 | 89 | return |
90 | 90 | } |
91 | 91 | |
92 | 92 | response, err := s.e(ctx, request) |
93 | 93 | if err != nil { |
94 | 94 | s.logger.Log("err", err) |
95 | s.errorEncoder(w, err) | |
95 | s.errorEncoder(ctx, TransportError{Domain: DomainDo, Err: err}, w) | |
96 | 96 | return |
97 | 97 | } |
98 | 98 | |
102 | 102 | |
103 | 103 | if err := s.enc(w, response); err != nil { |
104 | 104 | s.logger.Log("err", err) |
105 | s.errorEncoder(w, err) | |
105 | s.errorEncoder(ctx, TransportError{Domain: DomainEncode, Err: err}, w) | |
106 | 106 | return |
107 | 107 | } |
108 | 108 | } |
109 | 109 | |
110 | func defaultErrorEncoder(w http.ResponseWriter, err error) { | |
111 | switch err.(type) { | |
112 | case BadRequestError: | |
113 | http.Error(w, err.Error(), http.StatusBadRequest) | |
110 | // ErrorEncoder is a function that's responsible for encoding an error to the ResponseWriter. | |
111 | type ErrorEncoder func(ctx context.Context, err error, w http.ResponseWriter) | |
112 | ||
113 | func defaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) { | |
114 | switch e := err.(type) { | |
115 | case TransportError: | |
116 | switch e.Domain { | |
117 | case DomainDecode: | |
118 | http.Error(w, err.Error(), http.StatusBadRequest) | |
119 | case DomainDo: | |
120 | http.Error(w, err.Error(), http.StatusServiceUnavailable) // too aggressive? | |
121 | default: | |
122 | http.Error(w, err.Error(), http.StatusInternalServerError) | |
123 | } | |
114 | 124 | default: |
115 | 125 | http.Error(w, err.Error(), http.StatusInternalServerError) |
116 | 126 | } |
117 | 127 | } |
118 | ||
119 | // BadRequestError is an error in decoding the request. | |
120 | type BadRequestError struct { | |
121 | Err error | |
122 | } | |
123 | ||
124 | // Error implements the error interface. | |
125 | func (err BadRequestError) Error() string { | |
126 | return err.Err.Error() | |
127 | } |
36 | 36 | server := httptest.NewServer(handler) |
37 | 37 | defer server.Close() |
38 | 38 | resp, _ := http.Get(server.URL) |
39 | if want, have := http.StatusInternalServerError, resp.StatusCode; want != have { | |
39 | if want, have := http.StatusServiceUnavailable, resp.StatusCode; want != have { | |
40 | 40 | t.Errorf("want %d, have %d", want, have) |
41 | 41 | } |
42 | 42 | } |
59 | 59 | func TestServerErrorEncoder(t *testing.T) { |
60 | 60 | errTeapot := errors.New("teapot") |
61 | 61 | code := func(err error) int { |
62 | if err == errTeapot { | |
62 | if e, ok := err.(httptransport.TransportError); ok && e.Err == errTeapot { | |
63 | 63 | return http.StatusTeapot |
64 | 64 | } |
65 | 65 | return http.StatusInternalServerError |
69 | 69 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errTeapot }, |
70 | 70 | func(*http.Request) (interface{}, error) { return struct{}{}, nil }, |
71 | 71 | func(http.ResponseWriter, interface{}) error { return nil }, |
72 | httptransport.ServerErrorEncoder(func(w http.ResponseWriter, err error) { w.WriteHeader(code(err)) }), | |
72 | httptransport.ServerErrorEncoder(func(_ context.Context, err error, w http.ResponseWriter) { w.WriteHeader(code(err)) }), | |
73 | 73 | ) |
74 | 74 | server := httptest.NewServer(handler) |
75 | 75 | defer server.Close() |
117 | 117 | }() |
118 | 118 | return cancelfn, func() { stepch <- true }, response |
119 | 119 | } |
120 | ||
121 | type testBadRequestError struct { | |
122 | code int | |
123 | } | |
124 | ||
125 | func (err testBadRequestError) Error() string { | |
126 | return "Bad Request" | |
127 | } | |
128 | ||
129 | func TestBadRequestError(t *testing.T) { | |
130 | inner := testBadRequestError{1234} | |
131 | var outer error = httptransport.BadRequestError{Err: inner} | |
132 | err := outer.(httptransport.BadRequestError) | |
133 | if want, have := inner, err.Err; want != have { | |
134 | t.Errorf("want %#v, have %#v", want, have) | |
135 | } | |
136 | } |