Merge pull request #414 from go-kit/enhanced-error-encoder
transport/http: enhance the DefaultErrorEncoder
Peter Bourgon authored 7 years ago
GitHub committed 7 years ago
0 | package http | |
1 | ||
2 | import ( | |
3 | "fmt" | |
4 | "net/http" | |
5 | "net/http/httptest" | |
6 | ||
7 | "golang.org/x/net/context" | |
8 | ) | |
9 | ||
10 | func ExamplePopulateRequestContext() { | |
11 | handler := NewServer( | |
12 | context.Background(), | |
13 | func(ctx context.Context, request interface{}) (response interface{}, err error) { | |
14 | fmt.Println("Method", ctx.Value(ContextKeyRequestMethod).(string)) | |
15 | fmt.Println("RequestPath", ctx.Value(ContextKeyRequestPath).(string)) | |
16 | fmt.Println("RequestURI", ctx.Value(ContextKeyRequestURI).(string)) | |
17 | fmt.Println("X-Request-ID", ctx.Value(ContextKeyRequestXRequestID).(string)) | |
18 | return struct{}{}, nil | |
19 | }, | |
20 | func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, | |
21 | func(context.Context, http.ResponseWriter, interface{}) error { return nil }, | |
22 | ServerBefore(PopulateRequestContext), | |
23 | ) | |
24 | ||
25 | server := httptest.NewServer(handler) | |
26 | defer server.Close() | |
27 | ||
28 | req, _ := http.NewRequest("PATCH", fmt.Sprintf("%s/search?q=sympatico", server.URL), nil) | |
29 | req.Header.Set("X-Request-Id", "a1b2c3d4e5") | |
30 | http.DefaultClient.Do(req) | |
31 | ||
32 | // Output: | |
33 | // Method PATCH | |
34 | // RequestPath /search | |
35 | // RequestURI /search?q=sympatico | |
36 | // X-Request-ID a1b2c3d4e5 | |
37 | } |
21 | 21 | // clients, after a request has been made, but prior to it being decoded. |
22 | 22 | type ClientResponseFunc func(context.Context, *http.Response) context.Context |
23 | 23 | |
24 | // SetContentType returns a ResponseFunc that sets the Content-Type header to | |
25 | // the provided value. | |
24 | // SetContentType returns a ServerResponseFunc that sets the Content-Type header | |
25 | // to the provided value. | |
26 | 26 | func SetContentType(contentType string) ServerResponseFunc { |
27 | 27 | return SetResponseHeader("Content-Type", contentType) |
28 | 28 | } |
29 | 29 | |
30 | // SetResponseHeader returns a ResponseFunc that sets the specified header. | |
30 | // SetResponseHeader returns a ServerResponseFunc that sets the given header. | |
31 | 31 | func SetResponseHeader(key, val string) ServerResponseFunc { |
32 | 32 | return func(ctx context.Context, w http.ResponseWriter) context.Context { |
33 | 33 | w.Header().Set(key, val) |
35 | 35 | } |
36 | 36 | } |
37 | 37 | |
38 | // SetRequestHeader returns a RequestFunc that sets the specified header. | |
38 | // SetRequestHeader returns a RequestFunc that sets the given header. | |
39 | 39 | func SetRequestHeader(key, val string) RequestFunc { |
40 | 40 | return func(ctx context.Context, r *http.Request) context.Context { |
41 | 41 | r.Header.Set(key, val) |
42 | 42 | return ctx |
43 | 43 | } |
44 | 44 | } |
45 | ||
46 | // PopulateRequestContext is a RequestFunc that populates several values into | |
47 | // the context from the HTTP request. Those values may be extracted using the | |
48 | // corresponding ContextKey type in this package. | |
49 | func PopulateRequestContext(ctx context.Context, r *http.Request) context.Context { | |
50 | for k, v := range map[contextKey]string{ | |
51 | ContextKeyRequestMethod: r.Method, | |
52 | ContextKeyRequestURI: r.RequestURI, | |
53 | ContextKeyRequestPath: r.URL.Path, | |
54 | ContextKeyRequestProto: r.Proto, | |
55 | ContextKeyRequestHost: r.Host, | |
56 | ContextKeyRequestRemoteAddr: r.RemoteAddr, | |
57 | ContextKeyRequestXForwardedFor: r.Header.Get("X-Forwarded-For"), | |
58 | ContextKeyRequestXForwardedProto: r.Header.Get("X-Forwarded-Proto"), | |
59 | ContextKeyRequestAuthorization: r.Header.Get("Authorization"), | |
60 | ContextKeyRequestReferer: r.Header.Get("Referer"), | |
61 | ContextKeyRequestUserAgent: r.Header.Get("User-Agent"), | |
62 | ContextKeyRequestXRequestID: r.Header.Get("X-Request-Id"), | |
63 | } { | |
64 | ctx = context.WithValue(ctx, k, v) | |
65 | } | |
66 | return ctx | |
67 | } | |
68 | ||
69 | type contextKey int | |
70 | ||
71 | const ( | |
72 | // ContextKeyRequestMethod is populated in the context by | |
73 | // PopulateRequestContext. Its value is r.Method. | |
74 | ContextKeyRequestMethod contextKey = iota | |
75 | ||
76 | // ContextKeyRequestURI is populated in the context by | |
77 | // PopulateRequestContext. Its value is r.RequestURI. | |
78 | ContextKeyRequestURI | |
79 | ||
80 | // ContextKeyRequestPath is populated in the context by | |
81 | // PopulateRequestContext. Its value is r.URL.Path. | |
82 | ContextKeyRequestPath | |
83 | ||
84 | // ContextKeyRequestProto is populated in the context by | |
85 | // PopulateRequestContext. Its value is r.Proto. | |
86 | ContextKeyRequestProto | |
87 | ||
88 | // ContextKeyRequestHost is populated in the context by | |
89 | // PopulateRequestContext. Its value is r.Host. | |
90 | ContextKeyRequestHost | |
91 | ||
92 | // ContextKeyRequestRemoteAddr is populated in the context by | |
93 | // PopulateRequestContext. Its value is r.RemoteAddr. | |
94 | ContextKeyRequestRemoteAddr | |
95 | ||
96 | // ContextKeyRequestXForwardedFor is populated in the context by | |
97 | // PopulateRequestContext. Its value is r.Header.Get("X-Forwarded-For"). | |
98 | ContextKeyRequestXForwardedFor | |
99 | ||
100 | // ContextKeyRequestXForwardedProto is populated in the context by | |
101 | // PopulateRequestContext. Its value is r.Header.Get("X-Forwarded-Proto"). | |
102 | ContextKeyRequestXForwardedProto | |
103 | ||
104 | // ContextKeyRequestAuthorization is populated in the context by | |
105 | // PopulateRequestContext. Its value is r.Header.Get("Authorization"). | |
106 | ContextKeyRequestAuthorization | |
107 | ||
108 | // ContextKeyRequestReferer is populated in the context by | |
109 | // PopulateRequestContext. Its value is r.Header.Get("Referer"). | |
110 | ContextKeyRequestReferer | |
111 | ||
112 | // ContextKeyRequestUserAgent is populated in the context by | |
113 | // PopulateRequestContext. Its value is r.Header.Get("User-Agent"). | |
114 | ContextKeyRequestUserAgent | |
115 | ||
116 | // ContextKeyRequestXRequestID is populated in the context by | |
117 | // PopulateRequestContext. Its value is r.Header.Get("X-Request-Id"). | |
118 | ContextKeyRequestXRequestID | |
119 | ) |
0 | 0 | package http |
1 | 1 | |
2 | 2 | import ( |
3 | "encoding/json" | |
3 | 4 | "net/http" |
4 | 5 | |
5 | 6 | "golang.org/x/net/context" |
35 | 36 | e: e, |
36 | 37 | dec: dec, |
37 | 38 | enc: enc, |
38 | errorEncoder: defaultErrorEncoder, | |
39 | errorEncoder: DefaultErrorEncoder, | |
39 | 40 | logger: log.NewNopLogger(), |
40 | 41 | } |
41 | 42 | for _, option := range options { |
62 | 63 | // ServerErrorEncoder is used to encode errors to the http.ResponseWriter |
63 | 64 | // whenever they're encountered in the processing of a request. Clients can |
64 | 65 | // use this to provide custom error formatting and response codes. By default, |
65 | // errors will be written as plain text with an appropriate, if generic, | |
66 | // status code. | |
66 | // errors will be written with the DefaultErrorEncoder. | |
67 | 67 | func ServerErrorEncoder(ee ErrorEncoder) ServerOption { |
68 | 68 | return func(s *Server) { s.errorEncoder = ee } |
69 | 69 | } |
133 | 133 | // intended use is for request logging. |
134 | 134 | type ServerFinalizerFunc func(ctx context.Context, code int, r *http.Request) |
135 | 135 | |
136 | func defaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) { | |
137 | http.Error(w, err.Error(), http.StatusInternalServerError) | |
136 | // EncodeJSONResponse is a EncodeResponseFunc that serializes the response as a | |
137 | // JSON object to the ResponseWriter. Many JSON-over-HTTP services can use it as | |
138 | // a sensible default. If the response implements Headerer, the provided headers | |
139 | // will be applied to the response. If the response implements StatusCoder, the | |
140 | // provided StatusCode will be used instead of 200. | |
141 | func EncodeJSONResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { | |
142 | w.Header().Set("Content-Type", "application/json; charset=utf-8") | |
143 | if headerer, ok := response.(Headerer); ok { | |
144 | for k := range headerer.Headers() { | |
145 | w.Header().Set(k, headerer.Headers().Get(k)) | |
146 | } | |
147 | } | |
148 | code := http.StatusOK | |
149 | if sc, ok := response.(StatusCoder); ok { | |
150 | code = sc.StatusCode() | |
151 | } | |
152 | w.WriteHeader(code) | |
153 | return json.NewEncoder(w).Encode(response) | |
154 | } | |
155 | ||
156 | // DefaultErrorEncoder writes the error to the ResponseWriter, by default a | |
157 | // content type of text/plain, a body of the plain text of the error, and a | |
158 | // status code of 500. If the error implements Headerer, the provided headers | |
159 | // will be applied to the response. If the error implements json.Marshaler, and | |
160 | // the marshaling succeeds, a content type of application/json and the JSON | |
161 | // encoded form of the error will be used. If the error implements StatusCoder, | |
162 | // the provided StatusCode will be used instead of 500. | |
163 | func DefaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) { | |
164 | contentType, body := "text/plain; charset=utf-8", []byte(err.Error()) | |
165 | if marshaler, ok := err.(json.Marshaler); ok { | |
166 | if jsonBody, marshalErr := marshaler.MarshalJSON(); marshalErr == nil { | |
167 | contentType, body = "application/json; charset=utf-8", jsonBody | |
168 | } | |
169 | } | |
170 | w.Header().Set("Content-Type", contentType) | |
171 | if headerer, ok := err.(Headerer); ok { | |
172 | for k := range headerer.Headers() { | |
173 | w.Header().Set(k, headerer.Headers().Get(k)) | |
174 | } | |
175 | } | |
176 | code := http.StatusInternalServerError | |
177 | if sc, ok := err.(StatusCoder); ok { | |
178 | code = sc.StatusCode() | |
179 | } | |
180 | w.WriteHeader(code) | |
181 | w.Write(body) | |
182 | } | |
183 | ||
184 | // StatusCoder is checked by DefaultErrorEncoder. If an error value implements | |
185 | // StatusCoder, the StatusCode will be used when encoding the error. By default, | |
186 | // StatusInternalServerError (500) is used. | |
187 | type StatusCoder interface { | |
188 | StatusCode() int | |
189 | } | |
190 | ||
191 | // Headerer is checked by DefaultErrorEncoder. If an error value implements | |
192 | // Headerer, the provided headers will be applied to the response writer, after | |
193 | // the Content-Type is set. | |
194 | type Headerer interface { | |
195 | Headers() http.Header | |
138 | 196 | } |
139 | 197 | |
140 | 198 | type interceptingWriter struct { |
4 | 4 | "io/ioutil" |
5 | 5 | "net/http" |
6 | 6 | "net/http/httptest" |
7 | "strings" | |
7 | 8 | "testing" |
8 | 9 | |
9 | 10 | "golang.org/x/net/context" |
118 | 119 | |
119 | 120 | if want != have { |
120 | 121 | t.Errorf("want %d, have %d", want, have) |
122 | } | |
123 | } | |
124 | ||
125 | type enhancedResponse struct { | |
126 | Foo string `json:"foo"` | |
127 | } | |
128 | ||
129 | func (e enhancedResponse) StatusCode() int { return http.StatusPaymentRequired } | |
130 | func (e enhancedResponse) Headers() http.Header { return http.Header{"X-Edward": []string{"Snowden"}} } | |
131 | ||
132 | func TestEncodeJSONResponse(t *testing.T) { | |
133 | handler := httptransport.NewServer( | |
134 | context.Background(), | |
135 | func(context.Context, interface{}) (interface{}, error) { return enhancedResponse{Foo: "bar"}, nil }, | |
136 | func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, | |
137 | httptransport.EncodeJSONResponse, | |
138 | ) | |
139 | ||
140 | server := httptest.NewServer(handler) | |
141 | defer server.Close() | |
142 | ||
143 | resp, err := http.Get(server.URL) | |
144 | if err != nil { | |
145 | t.Fatal(err) | |
146 | } | |
147 | if want, have := http.StatusPaymentRequired, resp.StatusCode; want != have { | |
148 | t.Errorf("StatusCode: want %d, have %d", want, have) | |
149 | } | |
150 | if want, have := "Snowden", resp.Header.Get("X-Edward"); want != have { | |
151 | t.Errorf("X-Edward: want %q, have %q", want, have) | |
152 | } | |
153 | buf, _ := ioutil.ReadAll(resp.Body) | |
154 | if want, have := `{"foo":"bar"}`, strings.TrimSpace(string(buf)); want != have { | |
155 | t.Errorf("Body: want %s, have %s", want, have) | |
156 | } | |
157 | } | |
158 | ||
159 | type enhancedError struct{} | |
160 | ||
161 | func (e enhancedError) Error() string { return "enhanced error" } | |
162 | func (e enhancedError) StatusCode() int { return http.StatusTeapot } | |
163 | func (e enhancedError) MarshalJSON() ([]byte, error) { return []byte(`{"err":"enhanced"}`), nil } | |
164 | func (e enhancedError) Headers() http.Header { return http.Header{"X-Enhanced": []string{"1"}} } | |
165 | ||
166 | func TestEnhancedError(t *testing.T) { | |
167 | handler := httptransport.NewServer( | |
168 | context.Background(), | |
169 | func(context.Context, interface{}) (interface{}, error) { return nil, enhancedError{} }, | |
170 | func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, | |
171 | func(_ context.Context, w http.ResponseWriter, _ interface{}) error { return nil }, | |
172 | ) | |
173 | ||
174 | server := httptest.NewServer(handler) | |
175 | defer server.Close() | |
176 | ||
177 | resp, err := http.Get(server.URL) | |
178 | if err != nil { | |
179 | t.Fatal(err) | |
180 | } | |
181 | defer resp.Body.Close() | |
182 | if want, have := http.StatusTeapot, resp.StatusCode; want != have { | |
183 | t.Errorf("StatusCode: want %d, have %d", want, have) | |
184 | } | |
185 | if want, have := "1", resp.Header.Get("X-Enhanced"); want != have { | |
186 | t.Errorf("X-Enhanced: want %q, have %q", want, have) | |
187 | } | |
188 | buf, _ := ioutil.ReadAll(resp.Body) | |
189 | if want, have := `{"err":"enhanced"}`, strings.TrimSpace(string(buf)); want != have { | |
190 | t.Errorf("Body: want %s, have %s", want, have) | |
121 | 191 | } |
122 | 192 | } |
123 | 193 |