diff --git a/transport/http/client.go b/transport/http/client.go index 08f1b88..88581a3 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -23,6 +23,7 @@ dec DecodeResponseFunc before []RequestFunc after []ClientResponseFunc + finalizer ClientFinalizerFunc bufferedStream bool } @@ -72,6 +73,12 @@ return func(c *Client) { c.after = append(c.after, after...) } } +// ClientFinalizer is executed at the end of every HTTP request. +// By default, no finalizer is registered. +func ClientFinalizer(f ClientFinalizerFunc) ClientOption { + return func(s *Client) { s.finalizer = f } +} + // BufferedStream sets whether the Response.Body is left open, allowing it // to be read from later. Useful for transporting a file as a buffered stream. func BufferedStream(buffered bool) ClientOption { @@ -84,7 +91,21 @@ ctx, cancel := context.WithCancel(ctx) defer cancel() - req, err := http.NewRequest(c.method, c.tgt.String(), nil) + // Vars used for client finalizer to ensure there are no nil values + var ( + req *http.Request = &http.Request{} + resp *http.Response = &http.Response{} + err error + ) + if c.finalizer != nil { + defer func() { + ctx = context.WithValue(ctx, ContextKeyResponseHeaders, resp.Header) + ctx = context.WithValue(ctx, ContextKeyResponseSize, resp.ContentLength) + c.finalizer(ctx, resp.StatusCode, req) + }() + } + + req, err = http.NewRequest(c.method, c.tgt.String(), nil) if err != nil { return nil, err } @@ -97,10 +118,11 @@ ctx = f(ctx, req) } - resp, err := ctxhttp.Do(ctx, c.client, req) + resp, err = ctxhttp.Do(ctx, c.client, req) if err != nil { return nil, err } + if !c.bufferedStream { defer resp.Body.Close() } @@ -117,6 +139,13 @@ return response, nil } } + +// ClientFinalizerFunc can be used to perform work at the end of a client HTTP +// request, after the response is returned. The principal +// intended use is for request logging. In addition to the response code +// provided in the function signature, additional response parameters are +// provided in the context under keys with the ContextKeyResponse prefix. +type ClientFinalizerFunc func(ctx context.Context, code int, r *http.Request) // EncodeJSONRequest is an EncodeRequestFunc that serializes the request as a // JSON object to the Request body. Many JSON-over-HTTP services can use it as diff --git a/transport/http/client_test.go b/transport/http/client_test.go index ad366d1..9a50cc8 100644 --- a/transport/http/client_test.go +++ b/transport/http/client_test.go @@ -137,6 +137,62 @@ } if want, have := testbody, string(b); want != have { t.Errorf("want %q, have %q", want, have) + } +} + +func TestClientFinalizer(t *testing.T) { + var ( + headerKey = "X-Henlo-Lizer" + headerVal = "Helllo you stinky lizard" + statusCode = http.StatusTeapot + responseBody = "go eat a fly ugly\n" + done = make(chan struct{}) + encode = func(context.Context, *http.Request, interface{}) error { return nil } + decode = func(_ context.Context, r *http.Response) (interface{}, error) { + return TestResponse{r.Body, ""}, nil + } + ) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(headerKey, headerVal) + w.WriteHeader(statusCode) + w.Write([]byte(responseBody)) + })) + defer server.Close() + + client := httptransport.NewClient( + "GET", + mustParse(server.URL), + encode, + decode, + httptransport.ClientFinalizer(func(ctx context.Context, code int, _ *http.Request) { + if want, have := statusCode, code; want != have { + t.Errorf("StatusCode: want %d, have %d", want, have) + } + + responseHeader := ctx.Value(httptransport.ContextKeyResponseHeaders).(http.Header) + if want, have := headerVal, responseHeader.Get(headerKey); want != have { + t.Errorf("%s: want %q, have %q", headerKey, want, have) + } + + responseSize := ctx.Value(httptransport.ContextKeyResponseSize).(int64) + if want, have := int64(len(responseBody)), responseSize; want != have { + t.Errorf("response size: want %d, have %d", want, have) + } + + close(done) + }), + ) + + _, err := client.Endpoint()(context.Background(), struct{}{}) + if err != nil { + t.Fatal(err) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout waiting for finalizer") } }