Codebase list golang-github-go-kit-kit / b5ef921
Add SetBufferedStream() ClientOption Add ClientOption "SetBufferedStream()" transport/http, which prevents Endpoint from closing the http.Response.Body. Fixes #170 rolaveric 8 years ago
2 changed file(s) with 102 addition(s) and 16 deletion(s). Raw diff Collapse all Expand all
1212
1313 // Client wraps a URL and provides a method that implements endpoint.Endpoint.
1414 type Client struct {
15 client *http.Client
16 method string
17 tgt *url.URL
18 enc EncodeRequestFunc
19 dec DecodeResponseFunc
20 before []RequestFunc
15 client *http.Client
16 method string
17 tgt *url.URL
18 enc EncodeRequestFunc
19 dec DecodeResponseFunc
20 before []RequestFunc
21 bufferedStream bool
2122 }
2223
2324 // NewClient returns a
2425 func NewClient(method string, tgt *url.URL, enc EncodeRequestFunc, dec DecodeResponseFunc, options ...ClientOption) *Client {
2526 c := &Client{
26 client: http.DefaultClient,
27 method: method,
28 tgt: tgt,
29 enc: enc,
30 dec: dec,
31 before: []RequestFunc{},
27 client: http.DefaultClient,
28 method: method,
29 tgt: tgt,
30 enc: enc,
31 dec: dec,
32 before: []RequestFunc{},
33 bufferedStream: false,
3234 }
3335 for _, option := range options {
3436 option(c)
4951 // request before it's invoked.
5052 func SetClientBefore(before ...RequestFunc) ClientOption {
5153 return func(c *Client) { c.before = before }
54 }
55
56 // SetBufferedStream sets whether the Response.Body is left open, allowing it
57 // to be read from later. Useful for transporting a file as a buffered stream.
58 func SetBufferedStream(buffered bool) ClientOption {
59 return func(c *Client) { c.bufferedStream = buffered }
5260 }
5361
5462 // Endpoint returns a usable endpoint that will invoke the RPC specified by
7583 if err != nil {
7684 return nil, fmt.Errorf("Do: %v", err)
7785 }
78 defer func() { _ = resp.Body.Close() }()
86 if !c.bufferedStream {
87 defer resp.Body.Close()
88 }
7989
8090 response, err := c.dec(resp)
8191 if err != nil {
00 package http_test
11
22 import (
3 "io"
34 "net/http"
45 "net/http/httptest"
56 "net/url"
1112 httptransport "github.com/go-kit/kit/transport/http"
1213 )
1314
15 type TestResponse struct {
16 Body io.ReadCloser
17 String string
18 }
19
1420 func TestHTTPClient(t *testing.T) {
1521 var (
16 encode = func(*http.Request, interface{}) error { return nil }
17 decode = func(*http.Response) (interface{}, error) { return struct{}{}, nil }
22 testbody = "testbody"
23 encode = func(*http.Request, interface{}) error { return nil }
24 decode = func(r *http.Response) (interface{}, error) {
25 buffer := make([]byte, len(testbody))
26 r.Body.Read(buffer)
27 return TestResponse{r.Body, string(buffer)}, nil
28 }
1829 headers = make(chan string, 1)
1930 headerKey = "X-Foo"
2031 headerVal = "abcde"
2334 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2435 headers <- r.Header.Get(headerKey)
2536 w.WriteHeader(http.StatusOK)
37 w.Write([]byte(testbody))
2638 }))
2739
2840 client := httptransport.NewClient(
3345 httptransport.SetClientBefore(httptransport.SetRequestHeader(headerKey, headerVal)),
3446 )
3547
36 _, err := client.Endpoint()(context.Background(), struct{}{})
48 res, err := client.Endpoint()(context.Background(), struct{}{})
3749 if err != nil {
3850 t.Fatal(err)
3951 }
4456 case <-time.After(time.Millisecond):
4557 t.Fatalf("timeout waiting for %s", headerKey)
4658 }
59 // Check that Request Header was successfully received
4760 if want := headerVal; want != have {
61 t.Errorf("want %q, have %q", want, have)
62 }
63
64 // Check that the response was successfully decoded
65 response, ok := res.(TestResponse)
66 if !ok {
67 t.Fatal("response should be TestResponse")
68 }
69 if want, have := testbody, response.String; want != have {
70 t.Errorf("want %q, have %q", want, have)
71 }
72
73 // Check that response body was closed
74 b := make([]byte, 1)
75 _, err = response.Body.Read(b)
76 if err == nil {
77 t.Fatal("wanted error, got none")
78 }
79 if doNotWant, have := io.EOF, err; doNotWant == have {
80 t.Errorf("do not want %q, have %q", doNotWant, have)
81 }
82 }
83
84 func TestHTTPClientBufferedStream(t *testing.T) {
85 var (
86 testbody = "testbody"
87 encode = func(*http.Request, interface{}) error { return nil }
88 decode = func(r *http.Response) (interface{}, error) {
89 return TestResponse{r.Body, ""}, nil
90 }
91 )
92
93 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
94 w.WriteHeader(http.StatusOK)
95 w.Write([]byte(testbody))
96 }))
97
98 client := httptransport.NewClient(
99 "GET",
100 mustParse(server.URL),
101 encode,
102 decode,
103 httptransport.SetBufferedStream(true),
104 )
105
106 res, err := client.Endpoint()(context.Background(), struct{}{})
107 if err != nil {
108 t.Fatal(err)
109 }
110
111 // Check that the response was successfully decoded
112 response, ok := res.(TestResponse)
113 if !ok {
114 t.Fatal("response should be TestResponse")
115 }
116
117 // Check that response body was NOT closed
118 b := make([]byte, len(testbody))
119 _, err = response.Body.Read(b)
120 if want, have := io.EOF, err; have != want {
121 t.Fatal("want %q, have %q", want, have)
122 }
123 if want, have := testbody, string(b); want != have {
48124 t.Errorf("want %q, have %q", want, have)
49125 }
50126 }