diff --git a/transport/http/jsonrpc/server.go b/transport/http/jsonrpc/server.go index 0310f3e..75666e4 100644 --- a/transport/http/jsonrpc/server.go +++ b/transport/http/jsonrpc/server.go @@ -10,6 +10,10 @@ "github.com/go-kit/kit/log" httptransport "github.com/go-kit/kit/transport/http" ) + +type requestIDKeyType struct{} + +var requestIDKey requestIDKeyType // Server wraps an endpoint and implements http.Handler. type Server struct { @@ -105,6 +109,8 @@ return } + ctx = context.WithValue(ctx, requestIDKey, req.ID) + // Get the endpoint and codecs from the map using the method // defined in the JSON object ecm, ok := s.ecm[req.Method] @@ -160,7 +166,7 @@ // If the error implements ErrorCoder, the provided code will be set on the // response error. // If the error implements Headerer, the given headers will be set. -func DefaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) { +func DefaultErrorEncoder(ctx context.Context, err error, w http.ResponseWriter) { w.Header().Set("Content-Type", ContentType) if headerer, ok := err.(httptransport.Headerer); ok { for k := range headerer.Headers() { @@ -177,7 +183,13 @@ } w.WriteHeader(http.StatusOK) + + var requestID *RequestID + if v := ctx.Value(requestIDKey); v != nil { + requestID = v.(*RequestID) + } _ = json.NewEncoder(w).Encode(Response{ + ID: requestID, JSONRPC: Version, Error: &e, }) diff --git a/transport/http/jsonrpc/server_test.go b/transport/http/jsonrpc/server_test.go index d7960fe..c279d3c 100644 --- a/transport/http/jsonrpc/server_test.go +++ b/transport/http/jsonrpc/server_test.go @@ -24,17 +24,51 @@ return strings.NewReader(in) } +func unmarshalResponse(body []byte) (resp jsonrpc.Response, err error) { + err = json.Unmarshal(body, &resp) + return +} + func expectErrorCode(t *testing.T, want int, body []byte) { - var r jsonrpc.Response - err := json.Unmarshal(body, &r) - if err != nil { - t.Fatalf("Cant' decode response. err=%s, body=%s", err, body) + t.Helper() + + r, err := unmarshalResponse(body) + if err != nil { + t.Fatalf("Can't decode response: %v (%s)", err, body) } if r.Error == nil { t.Fatalf("Expected error on response. Got none: %s", body) } if have := r.Error.Code; want != have { t.Fatalf("Unexpected error code. Want %d, have %d: %s", want, have, body) + } +} + +func expectValidRequestID(t *testing.T, want int, body []byte) { + t.Helper() + + r, err := unmarshalResponse(body) + if err != nil { + t.Fatalf("Can't decode response: %v (%s)", err, body) + } + have, err := r.ID.Int() + if err != nil { + t.Fatalf("Can't get requestID in response. err=%s, body=%s", err, body) + } + if want != have { + t.Fatalf("Request ID: want %d, have %d (%s)", want, have, body) + } +} + +func expectNilRequestID(t *testing.T, body []byte) { + t.Helper() + + r, err := unmarshalResponse(body) + if err != nil { + t.Fatalf("Can't decode response: %v (%s)", err, body) + } + if r.ID != nil { + t.Fatalf("Request ID: want nil, have %v", r.ID) } } @@ -92,6 +126,7 @@ } buf, _ := ioutil.ReadAll(resp.Body) expectErrorCode(t, jsonrpc.InternalError, buf) + expectValidRequestID(t, 1, buf) } func TestServerBadEncode(t *testing.T) { @@ -111,6 +146,7 @@ } buf, _ := ioutil.ReadAll(resp.Body) expectErrorCode(t, jsonrpc.InternalError, buf) + expectValidRequestID(t, 1, buf) } func TestServerErrorEncoder(t *testing.T) { @@ -162,6 +198,7 @@ } buf, _ := ioutil.ReadAll(resp.Body) expectErrorCode(t, jsonrpc.ParseError, buf) + expectNilRequestID(t, buf) } func TestServerUnregisteredMethod(t *testing.T) { @@ -186,10 +223,9 @@ if want, have := http.StatusOK, resp.StatusCode; want != have { t.Errorf("want %d, have %d (%s)", want, have, buf) } - var r jsonrpc.Response - err := json.Unmarshal(buf, &r) - if err != nil { - t.Fatalf("Cant' decode response. err=%s, body=%s", err, buf) + r, err := unmarshalResponse(buf) + if err != nil { + t.Fatalf("Can't decode response. err=%s, body=%s", err, buf) } if r.JSONRPC != jsonrpc.Version { t.Fatalf("JSONRPC Version: want=%s, got=%s", jsonrpc.Version, r.JSONRPC)