Codebase list golang-github-go-kit-kit / 44df43e
Merge pull request #414 from go-kit/enhanced-error-encoder transport/http: enhance the DefaultErrorEncoder Peter Bourgon authored 7 years ago GitHub committed 7 years ago
4 changed file(s) with 250 addition(s) and 9 deletion(s). Raw diff Collapse all Expand all
0 package http
1
2 import (
3 "fmt"
4 "net/http"
5 "net/http/httptest"
6
7 "golang.org/x/net/context"
8 )
9
10 func ExamplePopulateRequestContext() {
11 handler := NewServer(
12 context.Background(),
13 func(ctx context.Context, request interface{}) (response interface{}, err error) {
14 fmt.Println("Method", ctx.Value(ContextKeyRequestMethod).(string))
15 fmt.Println("RequestPath", ctx.Value(ContextKeyRequestPath).(string))
16 fmt.Println("RequestURI", ctx.Value(ContextKeyRequestURI).(string))
17 fmt.Println("X-Request-ID", ctx.Value(ContextKeyRequestXRequestID).(string))
18 return struct{}{}, nil
19 },
20 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
21 func(context.Context, http.ResponseWriter, interface{}) error { return nil },
22 ServerBefore(PopulateRequestContext),
23 )
24
25 server := httptest.NewServer(handler)
26 defer server.Close()
27
28 req, _ := http.NewRequest("PATCH", fmt.Sprintf("%s/search?q=sympatico", server.URL), nil)
29 req.Header.Set("X-Request-Id", "a1b2c3d4e5")
30 http.DefaultClient.Do(req)
31
32 // Output:
33 // Method PATCH
34 // RequestPath /search
35 // RequestURI /search?q=sympatico
36 // X-Request-ID a1b2c3d4e5
37 }
2121 // clients, after a request has been made, but prior to it being decoded.
2222 type ClientResponseFunc func(context.Context, *http.Response) context.Context
2323
24 // SetContentType returns a ResponseFunc that sets the Content-Type header to
25 // the provided value.
24 // SetContentType returns a ServerResponseFunc that sets the Content-Type header
25 // to the provided value.
2626 func SetContentType(contentType string) ServerResponseFunc {
2727 return SetResponseHeader("Content-Type", contentType)
2828 }
2929
30 // SetResponseHeader returns a ResponseFunc that sets the specified header.
30 // SetResponseHeader returns a ServerResponseFunc that sets the given header.
3131 func SetResponseHeader(key, val string) ServerResponseFunc {
3232 return func(ctx context.Context, w http.ResponseWriter) context.Context {
3333 w.Header().Set(key, val)
3535 }
3636 }
3737
38 // SetRequestHeader returns a RequestFunc that sets the specified header.
38 // SetRequestHeader returns a RequestFunc that sets the given header.
3939 func SetRequestHeader(key, val string) RequestFunc {
4040 return func(ctx context.Context, r *http.Request) context.Context {
4141 r.Header.Set(key, val)
4242 return ctx
4343 }
4444 }
45
46 // PopulateRequestContext is a RequestFunc that populates several values into
47 // the context from the HTTP request. Those values may be extracted using the
48 // corresponding ContextKey type in this package.
49 func PopulateRequestContext(ctx context.Context, r *http.Request) context.Context {
50 for k, v := range map[contextKey]string{
51 ContextKeyRequestMethod: r.Method,
52 ContextKeyRequestURI: r.RequestURI,
53 ContextKeyRequestPath: r.URL.Path,
54 ContextKeyRequestProto: r.Proto,
55 ContextKeyRequestHost: r.Host,
56 ContextKeyRequestRemoteAddr: r.RemoteAddr,
57 ContextKeyRequestXForwardedFor: r.Header.Get("X-Forwarded-For"),
58 ContextKeyRequestXForwardedProto: r.Header.Get("X-Forwarded-Proto"),
59 ContextKeyRequestAuthorization: r.Header.Get("Authorization"),
60 ContextKeyRequestReferer: r.Header.Get("Referer"),
61 ContextKeyRequestUserAgent: r.Header.Get("User-Agent"),
62 ContextKeyRequestXRequestID: r.Header.Get("X-Request-Id"),
63 } {
64 ctx = context.WithValue(ctx, k, v)
65 }
66 return ctx
67 }
68
69 type contextKey int
70
71 const (
72 // ContextKeyRequestMethod is populated in the context by
73 // PopulateRequestContext. Its value is r.Method.
74 ContextKeyRequestMethod contextKey = iota
75
76 // ContextKeyRequestURI is populated in the context by
77 // PopulateRequestContext. Its value is r.RequestURI.
78 ContextKeyRequestURI
79
80 // ContextKeyRequestPath is populated in the context by
81 // PopulateRequestContext. Its value is r.URL.Path.
82 ContextKeyRequestPath
83
84 // ContextKeyRequestProto is populated in the context by
85 // PopulateRequestContext. Its value is r.Proto.
86 ContextKeyRequestProto
87
88 // ContextKeyRequestHost is populated in the context by
89 // PopulateRequestContext. Its value is r.Host.
90 ContextKeyRequestHost
91
92 // ContextKeyRequestRemoteAddr is populated in the context by
93 // PopulateRequestContext. Its value is r.RemoteAddr.
94 ContextKeyRequestRemoteAddr
95
96 // ContextKeyRequestXForwardedFor is populated in the context by
97 // PopulateRequestContext. Its value is r.Header.Get("X-Forwarded-For").
98 ContextKeyRequestXForwardedFor
99
100 // ContextKeyRequestXForwardedProto is populated in the context by
101 // PopulateRequestContext. Its value is r.Header.Get("X-Forwarded-Proto").
102 ContextKeyRequestXForwardedProto
103
104 // ContextKeyRequestAuthorization is populated in the context by
105 // PopulateRequestContext. Its value is r.Header.Get("Authorization").
106 ContextKeyRequestAuthorization
107
108 // ContextKeyRequestReferer is populated in the context by
109 // PopulateRequestContext. Its value is r.Header.Get("Referer").
110 ContextKeyRequestReferer
111
112 // ContextKeyRequestUserAgent is populated in the context by
113 // PopulateRequestContext. Its value is r.Header.Get("User-Agent").
114 ContextKeyRequestUserAgent
115
116 // ContextKeyRequestXRequestID is populated in the context by
117 // PopulateRequestContext. Its value is r.Header.Get("X-Request-Id").
118 ContextKeyRequestXRequestID
119 )
00 package http
11
22 import (
3 "encoding/json"
34 "net/http"
45
56 "golang.org/x/net/context"
3536 e: e,
3637 dec: dec,
3738 enc: enc,
38 errorEncoder: defaultErrorEncoder,
39 errorEncoder: DefaultErrorEncoder,
3940 logger: log.NewNopLogger(),
4041 }
4142 for _, option := range options {
6263 // ServerErrorEncoder is used to encode errors to the http.ResponseWriter
6364 // whenever they're encountered in the processing of a request. Clients can
6465 // use this to provide custom error formatting and response codes. By default,
65 // errors will be written as plain text with an appropriate, if generic,
66 // status code.
66 // errors will be written with the DefaultErrorEncoder.
6767 func ServerErrorEncoder(ee ErrorEncoder) ServerOption {
6868 return func(s *Server) { s.errorEncoder = ee }
6969 }
133133 // intended use is for request logging.
134134 type ServerFinalizerFunc func(ctx context.Context, code int, r *http.Request)
135135
136 func defaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) {
137 http.Error(w, err.Error(), http.StatusInternalServerError)
136 // EncodeJSONResponse is a EncodeResponseFunc that serializes the response as a
137 // JSON object to the ResponseWriter. Many JSON-over-HTTP services can use it as
138 // a sensible default. If the response implements Headerer, the provided headers
139 // will be applied to the response. If the response implements StatusCoder, the
140 // provided StatusCode will be used instead of 200.
141 func EncodeJSONResponse(_ context.Context, w http.ResponseWriter, response interface{}) error {
142 w.Header().Set("Content-Type", "application/json; charset=utf-8")
143 if headerer, ok := response.(Headerer); ok {
144 for k := range headerer.Headers() {
145 w.Header().Set(k, headerer.Headers().Get(k))
146 }
147 }
148 code := http.StatusOK
149 if sc, ok := response.(StatusCoder); ok {
150 code = sc.StatusCode()
151 }
152 w.WriteHeader(code)
153 return json.NewEncoder(w).Encode(response)
154 }
155
156 // DefaultErrorEncoder writes the error to the ResponseWriter, by default a
157 // content type of text/plain, a body of the plain text of the error, and a
158 // status code of 500. If the error implements Headerer, the provided headers
159 // will be applied to the response. If the error implements json.Marshaler, and
160 // the marshaling succeeds, a content type of application/json and the JSON
161 // encoded form of the error will be used. If the error implements StatusCoder,
162 // the provided StatusCode will be used instead of 500.
163 func DefaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) {
164 contentType, body := "text/plain; charset=utf-8", []byte(err.Error())
165 if marshaler, ok := err.(json.Marshaler); ok {
166 if jsonBody, marshalErr := marshaler.MarshalJSON(); marshalErr == nil {
167 contentType, body = "application/json; charset=utf-8", jsonBody
168 }
169 }
170 w.Header().Set("Content-Type", contentType)
171 if headerer, ok := err.(Headerer); ok {
172 for k := range headerer.Headers() {
173 w.Header().Set(k, headerer.Headers().Get(k))
174 }
175 }
176 code := http.StatusInternalServerError
177 if sc, ok := err.(StatusCoder); ok {
178 code = sc.StatusCode()
179 }
180 w.WriteHeader(code)
181 w.Write(body)
182 }
183
184 // StatusCoder is checked by DefaultErrorEncoder. If an error value implements
185 // StatusCoder, the StatusCode will be used when encoding the error. By default,
186 // StatusInternalServerError (500) is used.
187 type StatusCoder interface {
188 StatusCode() int
189 }
190
191 // Headerer is checked by DefaultErrorEncoder. If an error value implements
192 // Headerer, the provided headers will be applied to the response writer, after
193 // the Content-Type is set.
194 type Headerer interface {
195 Headers() http.Header
138196 }
139197
140198 type interceptingWriter struct {
44 "io/ioutil"
55 "net/http"
66 "net/http/httptest"
7 "strings"
78 "testing"
89
910 "golang.org/x/net/context"
118119
119120 if want != have {
120121 t.Errorf("want %d, have %d", want, have)
122 }
123 }
124
125 type enhancedResponse struct {
126 Foo string `json:"foo"`
127 }
128
129 func (e enhancedResponse) StatusCode() int { return http.StatusPaymentRequired }
130 func (e enhancedResponse) Headers() http.Header { return http.Header{"X-Edward": []string{"Snowden"}} }
131
132 func TestEncodeJSONResponse(t *testing.T) {
133 handler := httptransport.NewServer(
134 context.Background(),
135 func(context.Context, interface{}) (interface{}, error) { return enhancedResponse{Foo: "bar"}, nil },
136 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
137 httptransport.EncodeJSONResponse,
138 )
139
140 server := httptest.NewServer(handler)
141 defer server.Close()
142
143 resp, err := http.Get(server.URL)
144 if err != nil {
145 t.Fatal(err)
146 }
147 if want, have := http.StatusPaymentRequired, resp.StatusCode; want != have {
148 t.Errorf("StatusCode: want %d, have %d", want, have)
149 }
150 if want, have := "Snowden", resp.Header.Get("X-Edward"); want != have {
151 t.Errorf("X-Edward: want %q, have %q", want, have)
152 }
153 buf, _ := ioutil.ReadAll(resp.Body)
154 if want, have := `{"foo":"bar"}`, strings.TrimSpace(string(buf)); want != have {
155 t.Errorf("Body: want %s, have %s", want, have)
156 }
157 }
158
159 type enhancedError struct{}
160
161 func (e enhancedError) Error() string { return "enhanced error" }
162 func (e enhancedError) StatusCode() int { return http.StatusTeapot }
163 func (e enhancedError) MarshalJSON() ([]byte, error) { return []byte(`{"err":"enhanced"}`), nil }
164 func (e enhancedError) Headers() http.Header { return http.Header{"X-Enhanced": []string{"1"}} }
165
166 func TestEnhancedError(t *testing.T) {
167 handler := httptransport.NewServer(
168 context.Background(),
169 func(context.Context, interface{}) (interface{}, error) { return nil, enhancedError{} },
170 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
171 func(_ context.Context, w http.ResponseWriter, _ interface{}) error { return nil },
172 )
173
174 server := httptest.NewServer(handler)
175 defer server.Close()
176
177 resp, err := http.Get(server.URL)
178 if err != nil {
179 t.Fatal(err)
180 }
181 defer resp.Body.Close()
182 if want, have := http.StatusTeapot, resp.StatusCode; want != have {
183 t.Errorf("StatusCode: want %d, have %d", want, have)
184 }
185 if want, have := "1", resp.Header.Get("X-Enhanced"); want != have {
186 t.Errorf("X-Enhanced: want %q, have %q", want, have)
187 }
188 buf, _ := ioutil.ReadAll(resp.Body)
189 if want, have := `{"err":"enhanced"}`, strings.TrimSpace(string(buf)); want != have {
190 t.Errorf("Body: want %s, have %s", want, have)
121191 }
122192 }
123193