diff --git a/examples/README.md b/examples/README.md index cf1812f..2a92a1b 100644 --- a/examples/README.md +++ b/examples/README.md @@ -507,14 +507,12 @@ } func makeUppercaseEndpoint(ctx context.Context, proxyURL string) endpoint.Endpoint { - return (httptransport.Client{ - Client: http.DefaultClient, - Method: "GET", - URL: mustParseURL(proxyURL), - Context: ctx, - EncodeFunc: encodeUppercaseRequest, - DecodeFunc: decodeUppercaseResponse, - }).Endpoint() + return httptransport.NewClient( + "GET", + mustParseURL(proxyURL), + encodeUppercaseRequest, + decodeUppercaseResponse, + ).Endpoint() } ``` diff --git a/examples/addsvc/client/httpjson/client.go b/examples/addsvc/client/httpjson/client.go index aa98d07..3a2b2f5 100644 --- a/examples/addsvc/client/httpjson/client.go +++ b/examples/addsvc/client/httpjson/client.go @@ -29,20 +29,18 @@ return client{ Context: ctx, Logger: logger, - sum: (httptransport.Client{ - Client: c, - Method: "GET", - URL: sumURL, - EncodeRequestFunc: server.EncodeSumRequest, - DecodeResponseFunc: server.DecodeSumResponse, - }).Endpoint(), - concat: (httptransport.Client{ - Client: c, - Method: "GET", - URL: concatURL, - EncodeRequestFunc: server.EncodeConcatRequest, - DecodeResponseFunc: server.DecodeConcatResponse, - }).Endpoint(), + sum: httptransport.NewClient( + "GET", + sumURL, + server.EncodeSumRequest, + server.DecodeSumResponse, + ).Endpoint(), + concat: httptransport.NewClient( + "GET", + concatURL, + server.EncodeConcatRequest, + server.DecodeConcatResponse, + ).Endpoint(), } } diff --git a/examples/stringsvc3/proxying.go b/examples/stringsvc3/proxying.go index 6a10afd..6427b06 100644 --- a/examples/stringsvc3/proxying.go +++ b/examples/stringsvc3/proxying.go @@ -3,7 +3,6 @@ import ( "errors" "fmt" - "net/http" "net/url" "strings" "time" @@ -85,13 +84,12 @@ if u.Path == "" { u.Path = "/uppercase" } - return (httptransport.Client{ - Client: http.DefaultClient, - Method: "GET", - URL: u, - DecodeResponseFunc: decodeUppercaseResponse, - EncodeRequestFunc: encodeRequest, - }).Endpoint() + return httptransport.NewClient( + "GET", + u, + encodeRequest, + decodeUppercaseResponse, + ).Endpoint() } func split(s string) []string { diff --git a/transport/http/client.go b/transport/http/client.go index 1531db9..155d4f8 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -12,26 +12,43 @@ // Client wraps a URL and provides a method that implements endpoint.Endpoint. type Client struct { - // If client is nil, http.DefaultClient will be used. - *http.Client + client *http.Client + method string + tgt *url.URL + enc EncodeRequestFunc + dec DecodeResponseFunc + before []RequestFunc +} - // Method must be provided. - Method string +// NewClient returns a +func NewClient(method string, tgt *url.URL, enc EncodeRequestFunc, dec DecodeResponseFunc, options ...ClientOption) *Client { + c := &Client{ + client: http.DefaultClient, + method: method, + tgt: tgt, + enc: enc, + dec: dec, + before: []RequestFunc{}, + } + for _, option := range options { + option(c) + } + return c +} - // URL must be provided. - URL *url.URL +// ClientOption sets an optional parameter for clients. +type ClientOption func(*Client) - // EncodeRequestFunc must be provided. The HTTP request passed to the - // EncodeRequestFunc will have a nil body. - EncodeRequestFunc +// SetClient sets the underlying HTTP client used for requests. +// By default, http.DefaultClient is used. +func SetClient(client *http.Client) ClientOption { + return func(c *Client) { c.client = client } +} - // DecodeResponseFunc must be provided. - DecodeResponseFunc - - // Before functions are executed on the outgoing request after it is - // created, but before it's sent to the HTTP client. Clients have no After - // ResponseFuncs, as they don't work with ResponseWriters. - Before []RequestFunc +// SetClientBefore sets the RequestFuncs that are applied to the outgoing HTTP +// request before it's invoked. +func SetClientBefore(before ...RequestFunc) ClientOption { + return func(c *Client) { c.before = before } } // Endpoint returns a usable endpoint that will invoke the RPC specified by @@ -41,31 +58,26 @@ ctx, cancel := context.WithCancel(ctx) defer cancel() - req, err := http.NewRequest(c.Method, c.URL.String(), nil) + req, err := http.NewRequest(c.method, c.tgt.String(), nil) if err != nil { return nil, fmt.Errorf("NewRequest: %v", err) } - if err = c.EncodeRequestFunc(req, request); err != nil { + if err = c.enc(req, request); err != nil { return nil, fmt.Errorf("Encode: %v", err) } - for _, f := range c.Before { + for _, f := range c.before { ctx = f(ctx, req) } - var resp *http.Response - if c.Client != nil { - resp, err = c.Client.Do(req) - } else { - resp, err = http.DefaultClient.Do(req) - } + resp, err := c.client.Do(req) if err != nil { return nil, fmt.Errorf("Do: %v", err) } defer func() { _ = resp.Body.Close() }() - response, err := c.DecodeResponseFunc(resp) + response, err := c.dec(resp) if err != nil { return nil, fmt.Errorf("Decode: %v", err) } diff --git a/transport/http/client_test.go b/transport/http/client_test.go index 19d2b79..5135e7d 100644 --- a/transport/http/client_test.go +++ b/transport/http/client_test.go @@ -26,13 +26,13 @@ w.WriteHeader(http.StatusOK) })) - client := httptransport.Client{ - Method: "GET", - URL: mustParse(server.URL), - EncodeRequestFunc: encode, - DecodeResponseFunc: decode, - Before: []httptransport.RequestFunc{httptransport.SetRequestHeader(headerKey, headerVal)}, - } + client := httptransport.NewClient( + "GET", + mustParse(server.URL), + encode, + decode, + httptransport.SetClientBefore(httptransport.SetRequestHeader(headerKey, headerVal)), + ) _, err := client.Endpoint()(context.Background(), struct{}{}) if err != nil {