diff --git a/examples/README.md b/examples/README.md index 2adbec4..97ad33e 100644 --- a/examples/README.md +++ b/examples/README.md @@ -170,7 +170,7 @@ log.Fatal(http.ListenAndServe(":8080", nil)) } -func decodeUppercaseRequest(r *http.Request) (interface{}, error) { +func decodeUppercaseRequest(_ context.Context, r *http.Request) (interface{}, error) { var request uppercaseRequest if err := json.NewDecoder(r.Body).Decode(&request); err != nil { return nil, err @@ -178,7 +178,7 @@ return request, nil } -func decodeCountRequest(r *http.Request) (interface{}, error) { +func decodeCountRequest(_ context.Context, r *http.Request) (interface{}, error) { var request countRequest if err := json.NewDecoder(r.Body).Decode(&request); err != nil { return nil, err @@ -186,7 +186,7 @@ return request, nil } -func encodeResponse(w http.ResponseWriter, response interface{}) error { +func encodeResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { return json.NewEncoder(w).Encode(response) } ``` diff --git a/examples/addsvc/server/encode_decode.go b/examples/addsvc/server/encode_decode.go index 5c8b470..bf3e803 100644 --- a/examples/addsvc/server/encode_decode.go +++ b/examples/addsvc/server/encode_decode.go @@ -5,12 +5,14 @@ "encoding/json" "io/ioutil" "net/http" + + "golang.org/x/net/context" ) // DecodeSumRequest decodes the request from the provided HTTP request, simply // by JSON decoding from the request body. It's designed to be used in // transport/http.Server. -func DecodeSumRequest(r *http.Request) (interface{}, error) { +func DecodeSumRequest(_ context.Context, r *http.Request) (interface{}, error) { var request SumRequest err := json.NewDecoder(r.Body).Decode(&request) return &request, err @@ -19,14 +21,14 @@ // EncodeSumResponse encodes the response to the provided HTTP response // writer, simply by JSON encoding to the writer. It's designed to be used in // transport/http.Server. -func EncodeSumResponse(w http.ResponseWriter, response interface{}) error { +func EncodeSumResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { return json.NewEncoder(w).Encode(response) } // DecodeConcatRequest decodes the request from the provided HTTP request, // simply by JSON decoding from the request body. It's designed to be used in // transport/http.Server. -func DecodeConcatRequest(r *http.Request) (interface{}, error) { +func DecodeConcatRequest(_ context.Context, r *http.Request) (interface{}, error) { var request ConcatRequest err := json.NewDecoder(r.Body).Decode(&request) return &request, err @@ -35,14 +37,14 @@ // EncodeConcatResponse encodes the response to the provided HTTP response // writer, simply by JSON encoding to the writer. It's designed to be used in // transport/http.Server. -func EncodeConcatResponse(w http.ResponseWriter, response interface{}) error { +func EncodeConcatResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { return json.NewEncoder(w).Encode(response) } // EncodeSumRequest encodes the request to the provided HTTP request, simply // by JSON encoding to the request body. It's designed to be used in // transport/http.Client. -func EncodeSumRequest(r *http.Request, request interface{}) error { +func EncodeSumRequest(_ context.Context, r *http.Request, request interface{}) error { var buf bytes.Buffer if err := json.NewEncoder(&buf).Encode(request); err != nil { return err @@ -54,7 +56,7 @@ // DecodeSumResponse decodes the response from the provided HTTP response, // simply by JSON decoding from the response body. It's designed to be used in // transport/http.Client. -func DecodeSumResponse(resp *http.Response) (interface{}, error) { +func DecodeSumResponse(_ context.Context, resp *http.Response) (interface{}, error) { var response SumResponse err := json.NewDecoder(resp.Body).Decode(&response) return response, err @@ -63,7 +65,7 @@ // EncodeConcatRequest encodes the request to the provided HTTP request, // simply by JSON encoding to the request body. It's designed to be used in // transport/http.Client. -func EncodeConcatRequest(r *http.Request, request interface{}) error { +func EncodeConcatRequest(_ context.Context, r *http.Request, request interface{}) error { var buf bytes.Buffer if err := json.NewEncoder(&buf).Encode(request); err != nil { return err @@ -75,7 +77,7 @@ // DecodeConcatResponse decodes the response from the provided HTTP response, // simply by JSON decoding from the response body. It's designed to be used in // transport/http.Client. -func DecodeConcatResponse(resp *http.Response) (interface{}, error) { +func DecodeConcatResponse(_ context.Context, resp *http.Response) (interface{}, error) { var response ConcatResponse err := json.NewDecoder(resp.Body).Decode(&response) return response, err diff --git a/examples/apigateway/main.go b/examples/apigateway/main.go index e6ba593..683ae65 100644 --- a/examples/apigateway/main.go +++ b/examples/apigateway/main.go @@ -173,12 +173,12 @@ } } -func passEncode(r *http.Request, request interface{}) error { +func passEncode(_ context.Context, r *http.Request, request interface{}) error { r.Body = request.(io.ReadCloser) return nil } -func passDecode(r *http.Response) (interface{}, error) { +func passDecode(_ context.Context, r *http.Response) (interface{}, error) { return ioutil.ReadAll(r.Body) } diff --git a/examples/profilesvc/transport.go b/examples/profilesvc/transport.go index c1951af..054bfee 100644 --- a/examples/profilesvc/transport.go +++ b/examples/profilesvc/transport.go @@ -101,7 +101,7 @@ return r } -func decodePostProfileRequest(r *stdhttp.Request) (request interface{}, err error) { +func decodePostProfileRequest(_ context.Context, r *stdhttp.Request) (request interface{}, err error) { var req postProfileRequest if e := json.NewDecoder(r.Body).Decode(&req.Profile); e != nil { return nil, e @@ -109,7 +109,7 @@ return req, nil } -func decodeGetProfileRequest(r *stdhttp.Request) (request interface{}, err error) { +func decodeGetProfileRequest(_ context.Context, r *stdhttp.Request) (request interface{}, err error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { @@ -118,7 +118,7 @@ return getProfileRequest{ID: id}, nil } -func decodePutProfileRequest(r *stdhttp.Request) (request interface{}, err error) { +func decodePutProfileRequest(_ context.Context, r *stdhttp.Request) (request interface{}, err error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { @@ -134,7 +134,7 @@ }, nil } -func decodePatchProfileRequest(r *stdhttp.Request) (request interface{}, err error) { +func decodePatchProfileRequest(_ context.Context, r *stdhttp.Request) (request interface{}, err error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { @@ -150,7 +150,7 @@ }, nil } -func decodeDeleteProfileRequest(r *stdhttp.Request) (request interface{}, err error) { +func decodeDeleteProfileRequest(_ context.Context, r *stdhttp.Request) (request interface{}, err error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { @@ -159,7 +159,7 @@ return deleteProfileRequest{ID: id}, nil } -func decodeGetAddressesRequest(r *stdhttp.Request) (request interface{}, err error) { +func decodeGetAddressesRequest(_ context.Context, r *stdhttp.Request) (request interface{}, err error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { @@ -168,7 +168,7 @@ return getAddressesRequest{ProfileID: id}, nil } -func decodeGetAddressRequest(r *stdhttp.Request) (request interface{}, err error) { +func decodeGetAddressRequest(_ context.Context, r *stdhttp.Request) (request interface{}, err error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { @@ -184,7 +184,7 @@ }, nil } -func decodePostAddressRequest(r *stdhttp.Request) (request interface{}, err error) { +func decodePostAddressRequest(_ context.Context, r *stdhttp.Request) (request interface{}, err error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { @@ -200,7 +200,7 @@ }, nil } -func decodeDeleteAddressRequest(r *stdhttp.Request) (request interface{}, err error) { +func decodeDeleteAddressRequest(_ context.Context, r *stdhttp.Request) (request interface{}, err error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { @@ -228,17 +228,17 @@ // client. I chose to do it this way because I didn't know if something more // specific was necessary. It's certainly possible to specialize on a // per-response (per-method) basis. -func encodeResponse(w stdhttp.ResponseWriter, response interface{}) error { +func encodeResponse(ctx context.Context, w stdhttp.ResponseWriter, response interface{}) error { if e, ok := response.(errorer); ok && e.error() != nil { // Not a Go kit transport error, but a business-logic error. // Provide those as HTTP errors. - encodeError(w, e.error()) + encodeError(ctx, e.error(), w) return nil } return json.NewEncoder(w).Encode(response) } -func encodeError(w stdhttp.ResponseWriter, err error) { +func encodeError(_ context.Context, err error, w stdhttp.ResponseWriter) { if err == nil { panic("encodeError with nil error") } @@ -255,8 +255,15 @@ case errAlreadyExists, errInconsistentIDs: return stdhttp.StatusBadRequest default: - if _, ok := err.(kithttp.BadRequestError); ok { - return stdhttp.StatusBadRequest + if e, ok := err.(kithttp.TransportError); ok { + switch e.Domain { + case kithttp.DomainDecode: + return stdhttp.StatusBadRequest + case kithttp.DomainDo: + return stdhttp.StatusServiceUnavailable + default: + return stdhttp.StatusInternalServerError + } } return stdhttp.StatusInternalServerError } diff --git a/examples/shipping/booking/transport.go b/examples/shipping/booking/transport.go index 9999e17..7cf5994 100644 --- a/examples/shipping/booking/transport.go +++ b/examples/shipping/booking/transport.go @@ -88,7 +88,7 @@ var errBadRoute = errors.New("bad route") -func decodeBookCargoRequest(r *http.Request) (interface{}, error) { +func decodeBookCargoRequest(_ context.Context, r *http.Request) (interface{}, error) { var body struct { Origin string `json:"origin"` Destination string `json:"destination"` @@ -106,7 +106,7 @@ }, nil } -func decodeLoadCargoRequest(r *http.Request) (interface{}, error) { +func decodeLoadCargoRequest(_ context.Context, r *http.Request) (interface{}, error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { @@ -115,7 +115,7 @@ return loadCargoRequest{ID: cargo.TrackingID(id)}, nil } -func decodeRequestRoutesRequest(r *http.Request) (interface{}, error) { +func decodeRequestRoutesRequest(_ context.Context, r *http.Request) (interface{}, error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { @@ -124,7 +124,7 @@ return requestRoutesRequest{ID: cargo.TrackingID(id)}, nil } -func decodeAssignToRouteRequest(r *http.Request) (interface{}, error) { +func decodeAssignToRouteRequest(_ context.Context, r *http.Request) (interface{}, error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { @@ -142,7 +142,7 @@ }, nil } -func decodeChangeDestinationRequest(r *http.Request) (interface{}, error) { +func decodeChangeDestinationRequest(_ context.Context, r *http.Request) (interface{}, error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { @@ -163,17 +163,17 @@ }, nil } -func decodeListCargosRequest(r *http.Request) (interface{}, error) { +func decodeListCargosRequest(_ context.Context, r *http.Request) (interface{}, error) { return listCargosRequest{}, nil } -func decodeListLocationsRequest(r *http.Request) (interface{}, error) { +func decodeListLocationsRequest(_ context.Context, r *http.Request) (interface{}, error) { return listLocationsRequest{}, nil } -func encodeResponse(w http.ResponseWriter, response interface{}) error { +func encodeResponse(ctx context.Context, w http.ResponseWriter, response interface{}) error { if e, ok := response.(errorer); ok && e.error() != nil { - encodeError(w, e.error()) + encodeError(ctx, e.error(), w) return nil } w.Header().Set("Content-Type", "application/json; charset=utf-8") @@ -185,7 +185,7 @@ } // encode errors from business-logic -func encodeError(w http.ResponseWriter, err error) { +func encodeError(_ context.Context, err error, w http.ResponseWriter) { switch err { case cargo.ErrUnknown: w.WriteHeader(http.StatusNotFound) diff --git a/examples/shipping/handling/transport.go b/examples/shipping/handling/transport.go index ff9ece9..1777ad6 100644 --- a/examples/shipping/handling/transport.go +++ b/examples/shipping/handling/transport.go @@ -37,7 +37,7 @@ return r } -func decodeRegisterIncidentRequest(r *http.Request) (interface{}, error) { +func decodeRegisterIncidentRequest(_ context.Context, r *http.Request) (interface{}, error) { var body struct { CompletionTime time.Time `json:"completion_time"` TrackingID string `json:"tracking_id"` @@ -70,9 +70,9 @@ return types[s] } -func encodeResponse(w http.ResponseWriter, response interface{}) error { +func encodeResponse(ctx context.Context, w http.ResponseWriter, response interface{}) error { if e, ok := response.(errorer); ok && e.error() != nil { - encodeError(w, e.error()) + encodeError(ctx, e.error(), w) return nil } w.Header().Set("Content-Type", "application/json; charset=utf-8") @@ -84,7 +84,7 @@ } // encode errors from business-logic -func encodeError(w http.ResponseWriter, err error) { +func encodeError(_ context.Context, err error, w http.ResponseWriter) { switch err { case cargo.ErrUnknown: w.WriteHeader(http.StatusNotFound) diff --git a/examples/shipping/routing/proxying.go b/examples/shipping/routing/proxying.go index f9e536a..3051caf 100644 --- a/examples/shipping/routing/proxying.go +++ b/examples/shipping/routing/proxying.go @@ -97,7 +97,7 @@ ).Endpoint() } -func decodeFetchRoutesResponse(resp *http.Response) (interface{}, error) { +func decodeFetchRoutesResponse(_ context.Context, resp *http.Response) (interface{}, error) { var response fetchRoutesResponse if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { return nil, err @@ -105,7 +105,7 @@ return response, nil } -func encodeFetchRoutesRequest(r *http.Request, request interface{}) error { +func encodeFetchRoutesRequest(_ context.Context, r *http.Request, request interface{}) error { req := request.(fetchRoutesRequest) vals := r.URL.Query() diff --git a/examples/shipping/tracking/transport.go b/examples/shipping/tracking/transport.go index 428db89..9cac1ec 100644 --- a/examples/shipping/tracking/transport.go +++ b/examples/shipping/tracking/transport.go @@ -35,7 +35,7 @@ return r } -func decodeTrackCargoRequest(r *http.Request) (interface{}, error) { +func decodeTrackCargoRequest(_ context.Context, r *http.Request) (interface{}, error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { @@ -44,9 +44,9 @@ return trackCargoRequest{ID: id}, nil } -func encodeResponse(w http.ResponseWriter, response interface{}) error { +func encodeResponse(ctx context.Context, w http.ResponseWriter, response interface{}) error { if e, ok := response.(errorer); ok && e.error() != nil { - encodeError(w, e.error()) + encodeError(ctx, e.error(), w) return nil } w.Header().Set("Content-Type", "application/json; charset=utf-8") @@ -58,7 +58,7 @@ } // encode errors from business-logic -func encodeError(w http.ResponseWriter, err error) { +func encodeError(_ context.Context, err error, w http.ResponseWriter) { switch err { case cargo.ErrUnknown: w.WriteHeader(http.StatusNotFound) diff --git a/examples/stringsvc1/main.go b/examples/stringsvc1/main.go index 684182b..876eb9c 100644 --- a/examples/stringsvc1/main.go +++ b/examples/stringsvc1/main.go @@ -74,7 +74,7 @@ } } -func decodeUppercaseRequest(r *http.Request) (interface{}, error) { +func decodeUppercaseRequest(_ context.Context, r *http.Request) (interface{}, error) { var request uppercaseRequest if err := json.NewDecoder(r.Body).Decode(&request); err != nil { return nil, err @@ -82,7 +82,7 @@ return request, nil } -func decodeCountRequest(r *http.Request) (interface{}, error) { +func decodeCountRequest(_ context.Context, r *http.Request) (interface{}, error) { var request countRequest if err := json.NewDecoder(r.Body).Decode(&request); err != nil { return nil, err @@ -90,7 +90,7 @@ return request, nil } -func encodeResponse(w http.ResponseWriter, response interface{}) error { +func encodeResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { return json.NewEncoder(w).Encode(response) } diff --git a/examples/stringsvc2/transport.go b/examples/stringsvc2/transport.go index 0c29140..a70ad3f 100644 --- a/examples/stringsvc2/transport.go +++ b/examples/stringsvc2/transport.go @@ -28,7 +28,7 @@ } } -func decodeUppercaseRequest(r *http.Request) (interface{}, error) { +func decodeUppercaseRequest(_ context.Context, r *http.Request) (interface{}, error) { var request uppercaseRequest if err := json.NewDecoder(r.Body).Decode(&request); err != nil { return nil, err @@ -36,7 +36,7 @@ return request, nil } -func decodeCountRequest(r *http.Request) (interface{}, error) { +func decodeCountRequest(_ context.Context, r *http.Request) (interface{}, error) { var request countRequest if err := json.NewDecoder(r.Body).Decode(&request); err != nil { return nil, err @@ -44,7 +44,7 @@ return request, nil } -func encodeResponse(w http.ResponseWriter, response interface{}) error { +func encodeResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { return json.NewEncoder(w).Encode(response) } diff --git a/examples/stringsvc3/transport.go b/examples/stringsvc3/transport.go index b6d3dfb..c6341c1 100644 --- a/examples/stringsvc3/transport.go +++ b/examples/stringsvc3/transport.go @@ -30,7 +30,7 @@ } } -func decodeUppercaseRequest(r *http.Request) (interface{}, error) { +func decodeUppercaseRequest(_ context.Context, r *http.Request) (interface{}, error) { var request uppercaseRequest if err := json.NewDecoder(r.Body).Decode(&request); err != nil { return nil, err @@ -38,7 +38,7 @@ return request, nil } -func decodeCountRequest(r *http.Request) (interface{}, error) { +func decodeCountRequest(_ context.Context, r *http.Request) (interface{}, error) { var request countRequest if err := json.NewDecoder(r.Body).Decode(&request); err != nil { return nil, err @@ -46,7 +46,7 @@ return request, nil } -func decodeUppercaseResponse(r *http.Response) (interface{}, error) { +func decodeUppercaseResponse(_ context.Context, r *http.Response) (interface{}, error) { var response uppercaseResponse if err := json.NewDecoder(r.Body).Decode(&response); err != nil { return nil, err @@ -54,11 +54,11 @@ return response, nil } -func encodeResponse(w http.ResponseWriter, response interface{}) error { +func encodeResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { return json.NewEncoder(w).Encode(response) } -func encodeRequest(r *http.Request, request interface{}) error { +func encodeRequest(_ context.Context, r *http.Request, request interface{}) error { var buf bytes.Buffer if err := json.NewEncoder(&buf).Encode(request); err != nil { return err diff --git a/transport/http/client.go b/transport/http/client.go index ae7b7c8..ae130bb 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -74,11 +74,11 @@ req, err := http.NewRequest(c.method, c.tgt.String(), nil) if err != nil { - return nil, TransportError{DomainNewRequest, err} + return nil, TransportError{Domain: DomainNewRequest, Err: err} } - if err = c.enc(req, request); err != nil { - return nil, TransportError{DomainEncode, err} + if err = c.enc(ctx, req, request); err != nil { + return nil, TransportError{Domain: DomainEncode, Err: err} } for _, f := range c.before { @@ -87,15 +87,15 @@ resp, err := ctxhttp.Do(ctx, c.client, req) if err != nil { - return nil, TransportError{DomainDo, err} + return nil, TransportError{Domain: DomainDo, Err: err} } if !c.bufferedStream { defer resp.Body.Close() } - response, err := c.dec(resp) + response, err := c.dec(ctx, resp) if err != nil { - return nil, TransportError{DomainDecode, err} + return nil, TransportError{Domain: DomainDecode, Err: err} } return response, nil diff --git a/transport/http/client_test.go b/transport/http/client_test.go index 24a8919..26d4634 100644 --- a/transport/http/client_test.go +++ b/transport/http/client_test.go @@ -21,8 +21,8 @@ func TestHTTPClient(t *testing.T) { var ( testbody = "testbody" - encode = func(*http.Request, interface{}) error { return nil } - decode = func(r *http.Response) (interface{}, error) { + encode = func(context.Context, *http.Request, interface{}) error { return nil } + decode = func(_ context.Context, r *http.Response) (interface{}, error) { buffer := make([]byte, len(testbody)) r.Body.Read(buffer) return TestResponse{r.Body, string(buffer)}, nil @@ -85,8 +85,8 @@ func TestHTTPClientBufferedStream(t *testing.T) { var ( testbody = "testbody" - encode = func(*http.Request, interface{}) error { return nil } - decode = func(r *http.Response) (interface{}, error) { + encode = func(context.Context, *http.Request, interface{}) error { return nil } + decode = func(_ context.Context, r *http.Response) (interface{}, error) { return TestResponse{r.Body, ""}, nil } ) diff --git a/transport/http/encode_decode.go b/transport/http/encode_decode.go index d2e84d9..8b8f6d7 100644 --- a/transport/http/encode_decode.go +++ b/transport/http/encode_decode.go @@ -1,27 +1,31 @@ package http -import "net/http" +import ( + "net/http" + + "golang.org/x/net/context" +) // DecodeRequestFunc extracts a user-domain request object from an HTTP // request object. It's designed to be used in HTTP servers, for server-side // endpoints. One straightforward DecodeRequestFunc could be something that // JSON decodes from the request body to the concrete response type. -type DecodeRequestFunc func(*http.Request) (request interface{}, err error) +type DecodeRequestFunc func(context.Context, *http.Request) (request interface{}, err error) // EncodeRequestFunc encodes the passed request object into the HTTP request // object. It's designed to be used in HTTP clients, for client-side // endpoints. One straightforward EncodeRequestFunc could something that JSON // encodes the object directly to the request body. -type EncodeRequestFunc func(*http.Request, interface{}) error +type EncodeRequestFunc func(context.Context, *http.Request, interface{}) error // EncodeResponseFunc encodes the passed response object to the HTTP response // writer. It's designed to be used in HTTP servers, for server-side // endpoints. One straightforward EncodeResponseFunc could be something that // JSON encodes the object directly to the response body. -type EncodeResponseFunc func(http.ResponseWriter, interface{}) error +type EncodeResponseFunc func(context.Context, http.ResponseWriter, interface{}) error // DecodeResponseFunc extracts a user-domain response object from an HTTP // response object. It's designed to be used in HTTP clients, for client-side // endpoints. One straightforward DecodeResponseFunc could be something that // JSON decodes from the response body to the concrete response type. -type DecodeResponseFunc func(*http.Response) (response interface{}, err error) +type DecodeResponseFunc func(context.Context, *http.Response) (response interface{}, err error) diff --git a/transport/http/err_test.go b/transport/http/err_test.go index 9682875..8a7c6b0 100644 --- a/transport/http/err_test.go +++ b/transport/http/err_test.go @@ -15,8 +15,8 @@ func TestClientEndpointEncodeError(t *testing.T) { var ( sampleErr = errors.New("Oh no, an error") - enc = func(r *http.Request, request interface{}) error { return sampleErr } - dec = func(r *http.Response) (response interface{}, err error) { return nil, nil } + enc = func(context.Context, *http.Request, interface{}) error { return sampleErr } + dec = func(context.Context, *http.Response) (interface{}, error) { return nil, nil } ) u := &url.URL{ diff --git a/transport/http/server.go b/transport/http/server.go index e9218d6..0c98632 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -83,7 +83,7 @@ ctx = f(ctx, r) } - request, err := s.dec(r) + request, err := s.dec(ctx, r) if err != nil { s.logger.Log("err", err) s.errorEncoder(ctx, TransportError{Domain: DomainDecode, Err: err}, w) @@ -101,7 +101,7 @@ f(ctx, w) } - if err := s.enc(w, response); err != nil { + if err := s.enc(ctx, w, response); err != nil { s.logger.Log("err", err) s.errorEncoder(ctx, TransportError{Domain: DomainEncode, Err: err}, w) return diff --git a/transport/http/server_test.go b/transport/http/server_test.go index afa304f..14e6fe3 100644 --- a/transport/http/server_test.go +++ b/transport/http/server_test.go @@ -16,8 +16,8 @@ handler := httptransport.NewServer( context.Background(), func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, - func(*http.Request) (interface{}, error) { return struct{}{}, errors.New("dang") }, - func(http.ResponseWriter, interface{}) error { return nil }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, errors.New("dang") }, + func(context.Context, http.ResponseWriter, interface{}) error { return nil }, ) server := httptest.NewServer(handler) defer server.Close() @@ -31,8 +31,8 @@ handler := httptransport.NewServer( context.Background(), func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errors.New("dang") }, - func(*http.Request) (interface{}, error) { return struct{}{}, nil }, - func(http.ResponseWriter, interface{}) error { return nil }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, http.ResponseWriter, interface{}) error { return nil }, ) server := httptest.NewServer(handler) defer server.Close() @@ -46,8 +46,8 @@ handler := httptransport.NewServer( context.Background(), func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, - func(*http.Request) (interface{}, error) { return struct{}{}, nil }, - func(http.ResponseWriter, interface{}) error { return errors.New("dang") }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, http.ResponseWriter, interface{}) error { return errors.New("dang") }, ) server := httptest.NewServer(handler) defer server.Close() @@ -68,8 +68,8 @@ handler := httptransport.NewServer( context.Background(), func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errTeapot }, - func(*http.Request) (interface{}, error) { return struct{}{}, nil }, - func(http.ResponseWriter, interface{}) error { return nil }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, http.ResponseWriter, interface{}) error { return nil }, httptransport.ServerErrorEncoder(func(_ context.Context, err error, w http.ResponseWriter) { w.WriteHeader(code(err)) }), ) server := httptest.NewServer(handler) @@ -100,8 +100,8 @@ handler = httptransport.NewServer( ctx, endpoint, - func(*http.Request) (interface{}, error) { return struct{}{}, nil }, - func(http.ResponseWriter, interface{}) error { return nil }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, http.ResponseWriter, interface{}) error { return nil }, httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context { return ctx }), httptransport.ServerAfter(func(ctx context.Context, w http.ResponseWriter) { return }), )