JSON RPC over HTTP (#576)
* first pass at JSON RPC HTTP transport
* example implementation of JSON RPC over HTTP
* Add ID type for JSON RPC Request, with tests.
* Add basic server testing for JSON RPC.
Add basic server tests, following example from http transport. Switch Response.Error to pointer, to make absence clearer.
* Handle unregistered JSON RPC methods.
* Package tidy-up.
* Test ServerBefore / ServerAfter for JSON RPC.
* More JSON RPC tests.
* Remove JSON RPC from addsvc example, pending full JSON RPC example.
* Remove JSON RPC from addsvc example, pending full JSON RPC example.
* Remove context field from jsonrpc.Server.
* Add JSON content type to all JSON RPC responses.
* Add JSON content type to all JSON RPC responses.
* Remove client-side JSON RPC funcs for now.
* Document interceptingWriter
* Add JSON RPC doc.go.
* Add README for JSON RPC.
* Wire in JSON RPC addsvc.
* Add JSON RPC to Addsvc CLI.
* Set JSONRPC version in responses.
* Add JSON RPC client to addcli example.
* Wire in client middlewares for JSON RPC addsvc example.
* Fix rate limiter dependency.
* Add concat JSON RPC method.
* Improve JSON RPC server test coverage.
* Add error tests.
* Clarify ErrorCoder in comment.
* Make endpoint consistent in README.
* Gofmt handler example in README.
* Auto-increment client IDs. Allow for customisation.
* Add happy-path test for JSON RPC client.
* Provide default encoder/decoder in JSON RPC client.
* Fix comment line.
* RequestIDGenerator tidy-up.
Make auto-incrementing IDs goroutine safe.
Make RequestIDGenerator interface public.
* Fix client ID creation.
The client had been using the RequestID type in requests. Making this
serialize in a deterministic and predictable way was going to be fiddly, so
I decided to allow interface{} for IDs, client-side.
* Test client request ID more effectively.
* Cover client options in test.
* Improve error test coverage.
* Fix format spec in test output.
* Tweaks to satisfy the linter.
Ross McFarlane authored 6 years ago
Peter Bourgon committed 6 years ago
36 | 36 | httpAddr = fs.String("http-addr", "", "HTTP address of addsvc") |
37 | 37 | grpcAddr = fs.String("grpc-addr", "", "gRPC address of addsvc") |
38 | 38 | thriftAddr = fs.String("thrift-addr", "", "Thrift address of addsvc") |
39 | jsonRPCAddr = fs.String("jsonrpc-addr", "", "JSON RPC address of addsvc") | |
39 | 40 | thriftProtocol = fs.String("thrift-protocol", "binary", "binary, compact, json, simplejson") |
40 | 41 | thriftBuffer = fs.Int("thrift-buffer", 0, "0 for unbuffered") |
41 | 42 | thriftFramed = fs.Bool("thrift-framed", false, "true to enable framing") |
101 | 102 | } |
102 | 103 | defer conn.Close() |
103 | 104 | svc = addtransport.NewGRPCClient(conn, tracer, log.NewNopLogger()) |
105 | } else if *jsonRPCAddr != "" { | |
106 | svc, err = addtransport.NewJSONRPCClient(*jsonRPCAddr, tracer, log.NewNopLogger()) | |
104 | 107 | } else if *thriftAddr != "" { |
105 | 108 | // It's necessary to do all of this construction in the func main, |
106 | 109 | // because (among other reasons) we need to control the lifecycle of the |
41 | 41 | httpAddr = fs.String("http-addr", ":8081", "HTTP listen address") |
42 | 42 | grpcAddr = fs.String("grpc-addr", ":8082", "gRPC listen address") |
43 | 43 | thriftAddr = fs.String("thrift-addr", ":8083", "Thrift listen address") |
44 | jsonRPCAddr = fs.String("jsonrpc-addr", ":8084", "JSON RPC listen address") | |
44 | 45 | thriftProtocol = fs.String("thrift-protocol", "binary", "binary, compact, json, simplejson") |
45 | 46 | thriftBuffer = fs.Int("thrift-buffer", 0, "0 for unbuffered") |
46 | 47 | thriftFramed = fs.Bool("thrift-framed", false, "true to enable framing") |
134 | 135 | // the interfaces that the transports expect. Note that we're not binding |
135 | 136 | // them to ports or anything yet; we'll do that next. |
136 | 137 | var ( |
137 | service = addservice.New(logger, ints, chars) | |
138 | endpoints = addendpoint.New(service, logger, duration, tracer) | |
139 | httpHandler = addtransport.NewHTTPHandler(endpoints, tracer, logger) | |
140 | grpcServer = addtransport.NewGRPCServer(endpoints, tracer, logger) | |
141 | thriftServer = addtransport.NewThriftServer(endpoints) | |
138 | service = addservice.New(logger, ints, chars) | |
139 | endpoints = addendpoint.New(service, logger, duration, tracer) | |
140 | httpHandler = addtransport.NewHTTPHandler(endpoints, tracer, logger) | |
141 | grpcServer = addtransport.NewGRPCServer(endpoints, tracer, logger) | |
142 | thriftServer = addtransport.NewThriftServer(endpoints) | |
143 | jsonrpcHandler = addtransport.NewJSONRPCHandler(endpoints, logger) | |
142 | 144 | ) |
143 | 145 | |
144 | 146 | // Now we're to the part of the func main where we want to start actually |
244 | 246 | }) |
245 | 247 | } |
246 | 248 | { |
249 | httpListener, err := net.Listen("tcp", *jsonRPCAddr) | |
250 | if err != nil { | |
251 | logger.Log("transport", "JSONRPC over HTTP", "during", "Listen", "err", err) | |
252 | os.Exit(1) | |
253 | } | |
254 | g.Add(func() error { | |
255 | logger.Log("transport", "JSONRPC over HTTP", "addr", *jsonRPCAddr) | |
256 | return http.Serve(httpListener, jsonrpcHandler) | |
257 | }, func(error) { | |
258 | httpListener.Close() | |
259 | }) | |
260 | } | |
261 | { | |
247 | 262 | // This function just sits and waits for ctrl-C. |
248 | 263 | cancelInterrupt := make(chan struct{}) |
249 | 264 | g.Add(func() error { |
0 | package addtransport | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "encoding/json" | |
5 | "fmt" | |
6 | "net/url" | |
7 | "strings" | |
8 | "time" | |
9 | ||
10 | "golang.org/x/time/rate" | |
11 | ||
12 | "github.com/go-kit/kit/circuitbreaker" | |
13 | "github.com/go-kit/kit/endpoint" | |
14 | "github.com/go-kit/kit/examples/addsvc/pkg/addendpoint" | |
15 | "github.com/go-kit/kit/examples/addsvc/pkg/addservice" | |
16 | "github.com/go-kit/kit/log" | |
17 | "github.com/go-kit/kit/ratelimit" | |
18 | "github.com/go-kit/kit/tracing/opentracing" | |
19 | "github.com/go-kit/kit/transport/http/jsonrpc" | |
20 | stdopentracing "github.com/opentracing/opentracing-go" | |
21 | "github.com/sony/gobreaker" | |
22 | ) | |
23 | ||
24 | // NewJSONRPCHandler returns a JSON RPC Server/Handler that can be passed to http.Handle() | |
25 | func NewJSONRPCHandler(endpoints addendpoint.Set, logger log.Logger) *jsonrpc.Server { | |
26 | handler := jsonrpc.NewServer( | |
27 | makeEndpointCodecMap(endpoints), | |
28 | jsonrpc.ServerErrorLogger(logger), | |
29 | ) | |
30 | return handler | |
31 | } | |
32 | ||
33 | // NewJSONRPCClient returns an addservice backed by a JSON RPC over HTTP server | |
34 | // living at the remote instance. We expect instance to come from a service | |
35 | // discovery system, so likely of the form "host:port". We bake-in certain | |
36 | // middlewares, implementing the client library pattern. | |
37 | func NewJSONRPCClient(instance string, tracer stdopentracing.Tracer, logger log.Logger) (addservice.Service, error) { | |
38 | // Quickly sanitize the instance string. | |
39 | if !strings.HasPrefix(instance, "http") { | |
40 | instance = "http://" + instance | |
41 | } | |
42 | u, err := url.Parse(instance) | |
43 | if err != nil { | |
44 | return nil, err | |
45 | } | |
46 | ||
47 | // We construct a single ratelimiter middleware, to limit the total outgoing | |
48 | // QPS from this client to all methods on the remote instance. We also | |
49 | // construct per-endpoint circuitbreaker middlewares to demonstrate how | |
50 | // that's done, although they could easily be combined into a single breaker | |
51 | // for the entire remote instance, too. | |
52 | limiter := ratelimit.NewErroringLimiter(rate.NewLimiter(rate.Every(time.Second), 100)) | |
53 | ||
54 | var sumEndpoint endpoint.Endpoint | |
55 | { | |
56 | sumEndpoint = jsonrpc.NewClient( | |
57 | u, | |
58 | "sum", | |
59 | jsonrpc.ClientRequestEncoder(encodeSumRequest), | |
60 | jsonrpc.ClientResponseDecoder(decodeSumResponse), | |
61 | ).Endpoint() | |
62 | sumEndpoint = opentracing.TraceClient(tracer, "Sum")(sumEndpoint) | |
63 | sumEndpoint = limiter(sumEndpoint) | |
64 | sumEndpoint = circuitbreaker.Gobreaker(gobreaker.NewCircuitBreaker(gobreaker.Settings{ | |
65 | Name: "Sum", | |
66 | Timeout: 30 * time.Second, | |
67 | }))(sumEndpoint) | |
68 | } | |
69 | ||
70 | var concatEndpoint endpoint.Endpoint | |
71 | { | |
72 | concatEndpoint = jsonrpc.NewClient( | |
73 | u, | |
74 | "concat", | |
75 | jsonrpc.ClientRequestEncoder(encodeConcatRequest), | |
76 | jsonrpc.ClientResponseDecoder(decodeConcatResponse), | |
77 | ).Endpoint() | |
78 | concatEndpoint = opentracing.TraceClient(tracer, "Concat")(concatEndpoint) | |
79 | concatEndpoint = limiter(concatEndpoint) | |
80 | concatEndpoint = circuitbreaker.Gobreaker(gobreaker.NewCircuitBreaker(gobreaker.Settings{ | |
81 | Name: "Concat", | |
82 | Timeout: 30 * time.Second, | |
83 | }))(concatEndpoint) | |
84 | } | |
85 | ||
86 | // Returning the endpoint.Set as a service.Service relies on the | |
87 | // endpoint.Set implementing the Service methods. That's just a simple bit | |
88 | // of glue code. | |
89 | return addendpoint.Set{ | |
90 | SumEndpoint: sumEndpoint, | |
91 | ConcatEndpoint: concatEndpoint, | |
92 | }, nil | |
93 | ||
94 | } | |
95 | ||
96 | // makeEndpointCodecMap returns a codec map configured for the addsvc. | |
97 | func makeEndpointCodecMap(endpoints addendpoint.Set) jsonrpc.EndpointCodecMap { | |
98 | return jsonrpc.EndpointCodecMap{ | |
99 | "sum": jsonrpc.EndpointCodec{ | |
100 | Endpoint: endpoints.SumEndpoint, | |
101 | Decode: decodeSumRequest, | |
102 | Encode: encodeSumResponse, | |
103 | }, | |
104 | "concat": jsonrpc.EndpointCodec{ | |
105 | Endpoint: endpoints.ConcatEndpoint, | |
106 | Decode: decodeConcatRequest, | |
107 | Encode: encodeConcatResponse, | |
108 | }, | |
109 | } | |
110 | } | |
111 | ||
112 | func decodeSumRequest(_ context.Context, msg json.RawMessage) (interface{}, error) { | |
113 | var req addendpoint.SumRequest | |
114 | err := json.Unmarshal(msg, &req) | |
115 | if err != nil { | |
116 | return nil, &jsonrpc.Error{ | |
117 | Code: -32000, | |
118 | Message: fmt.Sprintf("couldn't unmarshal body to sum request: %s", err), | |
119 | } | |
120 | } | |
121 | return req, nil | |
122 | } | |
123 | ||
124 | func encodeSumResponse(_ context.Context, obj interface{}) (json.RawMessage, error) { | |
125 | res, ok := obj.(addendpoint.SumResponse) | |
126 | if !ok { | |
127 | return nil, &jsonrpc.Error{ | |
128 | Code: -32000, | |
129 | Message: fmt.Sprintf("Asserting result to *SumResponse failed. Got %T, %+v", obj, obj), | |
130 | } | |
131 | } | |
132 | b, err := json.Marshal(res) | |
133 | if err != nil { | |
134 | return nil, fmt.Errorf("couldn't marshal response: %s", err) | |
135 | } | |
136 | return b, nil | |
137 | } | |
138 | ||
139 | func decodeSumResponse(_ context.Context, msg json.RawMessage) (interface{}, error) { | |
140 | var res addendpoint.SumResponse | |
141 | err := json.Unmarshal(msg, &res) | |
142 | if err != nil { | |
143 | return nil, fmt.Errorf("couldn't unmarshal body to SumResponse: %s", err) | |
144 | } | |
145 | return res, nil | |
146 | } | |
147 | ||
148 | func encodeSumRequest(_ context.Context, obj interface{}) (json.RawMessage, error) { | |
149 | req, ok := obj.(addendpoint.SumRequest) | |
150 | if !ok { | |
151 | return nil, fmt.Errorf("couldn't assert request as SumRequest, got %T", obj) | |
152 | } | |
153 | b, err := json.Marshal(req) | |
154 | if err != nil { | |
155 | return nil, fmt.Errorf("couldn't marshal request: %s", err) | |
156 | } | |
157 | return b, nil | |
158 | } | |
159 | ||
160 | func decodeConcatRequest(_ context.Context, msg json.RawMessage) (interface{}, error) { | |
161 | var req addendpoint.ConcatRequest | |
162 | err := json.Unmarshal(msg, &req) | |
163 | if err != nil { | |
164 | return nil, &jsonrpc.Error{ | |
165 | Code: -32000, | |
166 | Message: fmt.Sprintf("couldn't unmarshal body to concat request: %s", err), | |
167 | } | |
168 | } | |
169 | return req, nil | |
170 | } | |
171 | ||
172 | func encodeConcatResponse(_ context.Context, obj interface{}) (json.RawMessage, error) { | |
173 | res, ok := obj.(addendpoint.ConcatResponse) | |
174 | if !ok { | |
175 | return nil, &jsonrpc.Error{ | |
176 | Code: -32000, | |
177 | Message: fmt.Sprintf("Asserting result to *ConcatResponse failed. Got %T, %+v", obj, obj), | |
178 | } | |
179 | } | |
180 | b, err := json.Marshal(res) | |
181 | if err != nil { | |
182 | return nil, fmt.Errorf("couldn't marshal response: %s", err) | |
183 | } | |
184 | return b, nil | |
185 | } | |
186 | ||
187 | func decodeConcatResponse(_ context.Context, msg json.RawMessage) (interface{}, error) { | |
188 | var res addendpoint.ConcatResponse | |
189 | err := json.Unmarshal(msg, &res) | |
190 | if err != nil { | |
191 | return nil, fmt.Errorf("couldn't unmarshal body to ConcatResponse: %s", err) | |
192 | } | |
193 | return res, nil | |
194 | } | |
195 | ||
196 | func encodeConcatRequest(_ context.Context, obj interface{}) (json.RawMessage, error) { | |
197 | req, ok := obj.(addendpoint.ConcatRequest) | |
198 | if !ok { | |
199 | return nil, fmt.Errorf("couldn't assert request as ConcatRequest, got %T", obj) | |
200 | } | |
201 | b, err := json.Marshal(req) | |
202 | if err != nil { | |
203 | return nil, fmt.Errorf("couldn't marshal request: %s", err) | |
204 | } | |
205 | return b, nil | |
206 | } |
142 | 142 | // request, after the response is returned. The principal |
143 | 143 | // intended use is for error logging. Additional response parameters are |
144 | 144 | // provided in the context under keys with the ContextKeyResponse prefix. |
145 | // Note: err may be nil. There maybe also no additional response parameters depending on | |
146 | // when an error occurs. | |
145 | // Note: err may be nil. There maybe also no additional response parameters | |
146 | // depending on when an error occurs. | |
147 | 147 | type ClientFinalizerFunc func(ctx context.Context, err error) |
148 | 148 | |
149 | 149 | // EncodeJSONRequest is an EncodeRequestFunc that serializes the request as a |
0 | # JSON RPC | |
1 | ||
2 | [JSON RPC](http://www.jsonrpc.org) is "A light weight remote procedure call protocol". It allows for the creation of simple RPC-style APIs with human-readable messages that are front-end friendly. | |
3 | ||
4 | ## Using JSON RPC with Go-Kit | |
5 | Using JSON RPC and go-kit together is quite simple. | |
6 | ||
7 | A JSON RPC _server_ acts as an [HTTP Handler](https://godoc.org/net/http#Handler), receiving all requests to the JSON RPC's URL. The server looks at the `method` property of the [Request Object](http://www.jsonrpc.org/specification#request_object), and routes it to the corresponding code. | |
8 | ||
9 | Each JSON RPC _method_ is implemented as an `EndpointCodec`, a go-kit [Endpoint](https://godoc.org/github.com/go-kit/kit/endpoint#Endpoint), sandwiched between a decoder and encoder. The decoder picks apart the JSON RPC request params, which can be passed to your endpoint. The encoder receives the output from the endpoint and encodes a JSON-RPC result. | |
10 | ||
11 | ## Example — Add Service | |
12 | Let's say we want a service that adds two ints together. We'll serve this at `http://localhost/rpc`. So a request to our `sum` method will be a POST to `http://localhost/rpc` with a request body of: | |
13 | ||
14 | { | |
15 | "id": 123, | |
16 | "jsonrpc": "2.0", | |
17 | "method": "sum", | |
18 | "params": { | |
19 | "A": 2, | |
20 | "B": 2 | |
21 | } | |
22 | } | |
23 | ||
24 | ### `EndpointCodecMap` | |
25 | The routing table for incoming JSON RPC requests is the `EndpointCodecMap`. The key of the map is the JSON RPC method name. Here, we're routing the `sum` method to an `EndpointCodec` wrapped around `sumEndpoint`. | |
26 | ||
27 | jsonrpc.EndpointCodecMap{ | |
28 | "sum": jsonrpc.EndpointCodec{ | |
29 | Endpoint: sumEndpoint, | |
30 | Decode: decodeSumRequest, | |
31 | Encode: encodeSumResponse, | |
32 | }, | |
33 | } | |
34 | ||
35 | ### Decoder | |
36 | type DecodeRequestFunc func(context.Context, json.RawMessage) (request interface{}, err error) | |
37 | ||
38 | A `DecodeRequestFunc` is given the raw JSON from the `params` property of the Request object, _not_ the whole request object. It returns an object that will be the input to the Endpoint. For our purposes, the output should be a SumRequest, like this: | |
39 | ||
40 | type SumRequest struct { | |
41 | A, B int | |
42 | } | |
43 | ||
44 | So here's our decoder: | |
45 | ||
46 | func decodeSumRequest(ctx context.Context, msg json.RawMessage) (interface{}, error) { | |
47 | var req SumRequest | |
48 | err := json.Unmarshal(msg, &req) | |
49 | if err != nil { | |
50 | return nil, err | |
51 | } | |
52 | return req, nil | |
53 | } | |
54 | ||
55 | So our `SumRequest` will now be passed to the endpoint. Once the endpoint has done its work, we hand over to the… | |
56 | ||
57 | ### Encoder | |
58 | The encoder takes the output of the endpoint, and builds the raw JSON message that will form the `result` field of a [Response Object](http://www.jsonrpc.org/specification#response_object). Our result is going to be a plain int. Here's our encoder: | |
59 | ||
60 | func encodeSumResponse(ctx context.Context, result interface{}) (json.RawMessage, error) { | |
61 | sum, ok := result.(int) | |
62 | if !ok { | |
63 | return nil, errors.New("result is not an int") | |
64 | } | |
65 | b, err := json.Marshal(sum) | |
66 | if err != nil { | |
67 | return nil, err | |
68 | } | |
69 | return b, nil | |
70 | } | |
71 | ||
72 | ### Server | |
73 | Now that we have an EndpointCodec with decoder, endpoint, and encoder, we can wire up the server: | |
74 | ||
75 | handler := jsonrpc.NewServer(jsonrpc.EndpointCodecMap{ | |
76 | "sum": jsonrpc.EndpointCodec{ | |
77 | Endpoint: sumEndpoint, | |
78 | Decode: decodeSumRequest, | |
79 | Encode: encodeSumResponse, | |
80 | }, | |
81 | }) | |
82 | http.Handle("/rpc", handler) | |
83 | http.ListenAndServe(":80", nil) | |
84 | ||
85 | With all of this done, our example request above should result in a response like this: | |
86 | ||
87 | { | |
88 | "jsonrpc": "2.0", | |
89 | "result": 4, | |
90 | "error": null | |
91 | } |
0 | package jsonrpc | |
1 | ||
2 | import ( | |
3 | "bytes" | |
4 | "context" | |
5 | "encoding/json" | |
6 | "io/ioutil" | |
7 | "net/http" | |
8 | "net/url" | |
9 | "sync/atomic" | |
10 | ||
11 | "github.com/go-kit/kit/endpoint" | |
12 | httptransport "github.com/go-kit/kit/transport/http" | |
13 | ) | |
14 | ||
15 | // Client wraps a JSON RPC method and provides a method that implements endpoint.Endpoint. | |
16 | type Client struct { | |
17 | client *http.Client | |
18 | ||
19 | // JSON RPC endpoint URL | |
20 | tgt *url.URL | |
21 | ||
22 | // JSON RPC method name. | |
23 | method string | |
24 | ||
25 | enc EncodeRequestFunc | |
26 | dec DecodeResponseFunc | |
27 | before []httptransport.RequestFunc | |
28 | after []httptransport.ClientResponseFunc | |
29 | finalizer httptransport.ClientFinalizerFunc | |
30 | requestID RequestIDGenerator | |
31 | bufferedStream bool | |
32 | } | |
33 | ||
34 | type clientRequest struct { | |
35 | JSONRPC string `json:"jsonrpc"` | |
36 | Method string `json:"method"` | |
37 | Params json.RawMessage `json:"params"` | |
38 | ID interface{} `json:"id"` | |
39 | } | |
40 | ||
41 | // NewClient constructs a usable Client for a single remote method. | |
42 | func NewClient( | |
43 | tgt *url.URL, | |
44 | method string, | |
45 | options ...ClientOption, | |
46 | ) *Client { | |
47 | c := &Client{ | |
48 | client: http.DefaultClient, | |
49 | method: method, | |
50 | tgt: tgt, | |
51 | enc: DefaultRequestEncoder, | |
52 | dec: DefaultResponseDecoder, | |
53 | before: []httptransport.RequestFunc{}, | |
54 | after: []httptransport.ClientResponseFunc{}, | |
55 | requestID: NewAutoIncrementID(0), | |
56 | bufferedStream: false, | |
57 | } | |
58 | for _, option := range options { | |
59 | option(c) | |
60 | } | |
61 | return c | |
62 | } | |
63 | ||
64 | // DefaultRequestEncoder marshals the given request to JSON. | |
65 | func DefaultRequestEncoder(_ context.Context, req interface{}) (json.RawMessage, error) { | |
66 | return json.Marshal(req) | |
67 | } | |
68 | ||
69 | // DefaultResponseDecoder unmarshals the given JSON to interface{}. | |
70 | func DefaultResponseDecoder(_ context.Context, res json.RawMessage) (interface{}, error) { | |
71 | var result interface{} | |
72 | err := json.Unmarshal(res, &result) | |
73 | if err != nil { | |
74 | return nil, err | |
75 | } | |
76 | return result, nil | |
77 | } | |
78 | ||
79 | // ClientOption sets an optional parameter for clients. | |
80 | type ClientOption func(*Client) | |
81 | ||
82 | // SetClient sets the underlying HTTP client used for requests. | |
83 | // By default, http.DefaultClient is used. | |
84 | func SetClient(client *http.Client) ClientOption { | |
85 | return func(c *Client) { c.client = client } | |
86 | } | |
87 | ||
88 | // ClientBefore sets the RequestFuncs that are applied to the outgoing HTTP | |
89 | // request before it's invoked. | |
90 | func ClientBefore(before ...httptransport.RequestFunc) ClientOption { | |
91 | return func(c *Client) { c.before = append(c.before, before...) } | |
92 | } | |
93 | ||
94 | // ClientAfter sets the ClientResponseFuncs applied to the server's HTTP | |
95 | // response prior to it being decoded. This is useful for obtaining anything | |
96 | // from the response and adding onto the context prior to decoding. | |
97 | func ClientAfter(after ...httptransport.ClientResponseFunc) ClientOption { | |
98 | return func(c *Client) { c.after = append(c.after, after...) } | |
99 | } | |
100 | ||
101 | // ClientFinalizer is executed at the end of every HTTP request. | |
102 | // By default, no finalizer is registered. | |
103 | func ClientFinalizer(f httptransport.ClientFinalizerFunc) ClientOption { | |
104 | return func(c *Client) { c.finalizer = f } | |
105 | } | |
106 | ||
107 | // ClientRequestEncoder sets the func used to encode the request params to JSON. | |
108 | // If not set, DefaultRequestEncoder is used. | |
109 | func ClientRequestEncoder(enc EncodeRequestFunc) ClientOption { | |
110 | return func(c *Client) { c.enc = enc } | |
111 | } | |
112 | ||
113 | // ClientResponseDecoder sets the func used to decode the response params from | |
114 | // JSON. If not set, DefaultResponseDecoder is used. | |
115 | func ClientResponseDecoder(dec DecodeResponseFunc) ClientOption { | |
116 | return func(c *Client) { c.dec = dec } | |
117 | } | |
118 | ||
119 | // RequestIDGenerator returns an ID for the request. | |
120 | type RequestIDGenerator interface { | |
121 | Generate() interface{} | |
122 | } | |
123 | ||
124 | // ClientRequestIDGenerator is executed before each request to generate an ID | |
125 | // for the request. | |
126 | // By default, AutoIncrementRequestID is used. | |
127 | func ClientRequestIDGenerator(g RequestIDGenerator) ClientOption { | |
128 | return func(c *Client) { c.requestID = g } | |
129 | } | |
130 | ||
131 | // BufferedStream sets whether the Response.Body is left open, allowing it | |
132 | // to be read from later. Useful for transporting a file as a buffered stream. | |
133 | func BufferedStream(buffered bool) ClientOption { | |
134 | return func(c *Client) { c.bufferedStream = buffered } | |
135 | } | |
136 | ||
137 | // Endpoint returns a usable endpoint that invokes the remote endpoint. | |
138 | func (c Client) Endpoint() endpoint.Endpoint { | |
139 | return func(ctx context.Context, request interface{}) (interface{}, error) { | |
140 | ctx, cancel := context.WithCancel(ctx) | |
141 | defer cancel() | |
142 | ||
143 | var ( | |
144 | resp *http.Response | |
145 | err error | |
146 | ) | |
147 | if c.finalizer != nil { | |
148 | defer func() { | |
149 | if resp != nil { | |
150 | ctx = context.WithValue(ctx, httptransport.ContextKeyResponseHeaders, resp.Header) | |
151 | ctx = context.WithValue(ctx, httptransport.ContextKeyResponseSize, resp.ContentLength) | |
152 | } | |
153 | c.finalizer(ctx, err) | |
154 | }() | |
155 | } | |
156 | ||
157 | var params json.RawMessage | |
158 | if params, err = c.enc(ctx, request); err != nil { | |
159 | return nil, err | |
160 | } | |
161 | rpcReq := clientRequest{ | |
162 | JSONRPC: "", | |
163 | Method: c.method, | |
164 | Params: params, | |
165 | ID: c.requestID.Generate(), | |
166 | } | |
167 | ||
168 | req, err := http.NewRequest("POST", c.tgt.String(), nil) | |
169 | if err != nil { | |
170 | return nil, err | |
171 | } | |
172 | ||
173 | req.Header.Set("Content-Type", "application/json; charset=utf-8") | |
174 | var b bytes.Buffer | |
175 | req.Body = ioutil.NopCloser(&b) | |
176 | err = json.NewEncoder(&b).Encode(rpcReq) | |
177 | if err != nil { | |
178 | return nil, err | |
179 | } | |
180 | ||
181 | for _, f := range c.before { | |
182 | ctx = f(ctx, req) | |
183 | } | |
184 | ||
185 | resp, err = c.client.Do(req.WithContext(ctx)) | |
186 | if err != nil { | |
187 | return nil, err | |
188 | } | |
189 | ||
190 | if !c.bufferedStream { | |
191 | defer resp.Body.Close() | |
192 | } | |
193 | ||
194 | // Decode the body into an object | |
195 | var rpcRes Response | |
196 | err = json.NewDecoder(resp.Body).Decode(&rpcRes) | |
197 | if err != nil { | |
198 | return nil, err | |
199 | } | |
200 | ||
201 | for _, f := range c.after { | |
202 | ctx = f(ctx, resp) | |
203 | } | |
204 | ||
205 | return c.dec(ctx, rpcRes.Result) | |
206 | } | |
207 | } | |
208 | ||
209 | // ClientFinalizerFunc can be used to perform work at the end of a client HTTP | |
210 | // request, after the response is returned. The principal | |
211 | // intended use is for error logging. Additional response parameters are | |
212 | // provided in the context under keys with the ContextKeyResponse prefix. | |
213 | // Note: err may be nil. There maybe also no additional response parameters | |
214 | // depending on when an error occurs. | |
215 | type ClientFinalizerFunc func(ctx context.Context, err error) | |
216 | ||
217 | // autoIncrementID is a RequestIDGenerator that generates | |
218 | // auto-incrementing integer IDs. | |
219 | type autoIncrementID struct { | |
220 | v *uint64 | |
221 | } | |
222 | ||
223 | // NewAutoIncrementID returns an auto-incrementing request ID generator, | |
224 | // initialised with the given value. | |
225 | func NewAutoIncrementID(init uint64) RequestIDGenerator { | |
226 | // Offset by one so that the first generated value = init. | |
227 | v := init - 1 | |
228 | return &autoIncrementID{v: &v} | |
229 | } | |
230 | ||
231 | // Generate satisfies RequestIDGenerator | |
232 | func (i *autoIncrementID) Generate() interface{} { | |
233 | id := atomic.AddUint64(i.v, 1) | |
234 | return id | |
235 | } |
0 | package jsonrpc_test | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "encoding/json" | |
5 | "io" | |
6 | "io/ioutil" | |
7 | "net/http" | |
8 | "net/http/httptest" | |
9 | "net/url" | |
10 | "testing" | |
11 | ||
12 | "github.com/go-kit/kit/transport/http/jsonrpc" | |
13 | ) | |
14 | ||
15 | type TestResponse struct { | |
16 | Body io.ReadCloser | |
17 | String string | |
18 | } | |
19 | ||
20 | func TestCanCallBeforeFunc(t *testing.T) { | |
21 | called := false | |
22 | u, _ := url.Parse("http://senseye.io/jsonrpc") | |
23 | sut := jsonrpc.NewClient( | |
24 | u, | |
25 | "add", | |
26 | jsonrpc.ClientBefore(func(ctx context.Context, req *http.Request) context.Context { | |
27 | called = true | |
28 | return ctx | |
29 | }), | |
30 | ) | |
31 | ||
32 | sut.Endpoint()(context.TODO(), "foo") | |
33 | ||
34 | if !called { | |
35 | t.Fatal("Expected client before func to be called. Wasn't.") | |
36 | } | |
37 | } | |
38 | ||
39 | type staticIDGenerator int | |
40 | ||
41 | func (g staticIDGenerator) Generate() interface{} { return g } | |
42 | ||
43 | func TestClientHappyPath(t *testing.T) { | |
44 | var ( | |
45 | afterCalledKey = "AC" | |
46 | beforeHeaderKey = "BF" | |
47 | beforeHeaderValue = "beforeFuncWozEre" | |
48 | testbody = `{"jsonrpc":"2.0", "result":5}` | |
49 | requestBody []byte | |
50 | beforeFunc = func(ctx context.Context, r *http.Request) context.Context { | |
51 | r.Header.Add(beforeHeaderKey, beforeHeaderValue) | |
52 | return ctx | |
53 | } | |
54 | encode = func(ctx context.Context, req interface{}) (json.RawMessage, error) { | |
55 | return json.Marshal(req) | |
56 | } | |
57 | afterFunc = func(ctx context.Context, r *http.Response) context.Context { | |
58 | return context.WithValue(ctx, afterCalledKey, true) | |
59 | } | |
60 | finalizerCalled = false | |
61 | fin = func(ctx context.Context, err error) { | |
62 | finalizerCalled = true | |
63 | } | |
64 | decode = func(ctx context.Context, res json.RawMessage) (interface{}, error) { | |
65 | if ac := ctx.Value(afterCalledKey); ac == nil { | |
66 | t.Fatal("after not called") | |
67 | } | |
68 | var result int | |
69 | err := json.Unmarshal(res, &result) | |
70 | if err != nil { | |
71 | return nil, err | |
72 | } | |
73 | return result, nil | |
74 | } | |
75 | ||
76 | wantID = 666 | |
77 | gen = staticIDGenerator(wantID) | |
78 | ) | |
79 | ||
80 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
81 | if r.Header.Get(beforeHeaderKey) != beforeHeaderValue { | |
82 | t.Fatal("Header not set by before func.") | |
83 | } | |
84 | ||
85 | b, err := ioutil.ReadAll(r.Body) | |
86 | if err != nil && err != io.EOF { | |
87 | t.Fatal(err) | |
88 | } | |
89 | requestBody = b | |
90 | ||
91 | w.WriteHeader(http.StatusOK) | |
92 | w.Write([]byte(testbody)) | |
93 | })) | |
94 | ||
95 | sut := jsonrpc.NewClient( | |
96 | mustParse(server.URL), | |
97 | "add", | |
98 | jsonrpc.ClientRequestEncoder(encode), | |
99 | jsonrpc.ClientResponseDecoder(decode), | |
100 | jsonrpc.ClientBefore(beforeFunc), | |
101 | jsonrpc.ClientAfter(afterFunc), | |
102 | jsonrpc.ClientRequestIDGenerator(gen), | |
103 | jsonrpc.ClientFinalizer(fin), | |
104 | jsonrpc.SetClient(http.DefaultClient), | |
105 | jsonrpc.BufferedStream(false), | |
106 | ) | |
107 | ||
108 | type addRequest struct { | |
109 | A int | |
110 | B int | |
111 | } | |
112 | ||
113 | in := addRequest{2, 2} | |
114 | ||
115 | result, err := sut.Endpoint()(context.Background(), in) | |
116 | if err != nil { | |
117 | t.Fatal(err) | |
118 | } | |
119 | ri, ok := result.(int) | |
120 | if !ok { | |
121 | t.Fatalf("result is not int: (%T)%+v", result, result) | |
122 | } | |
123 | if ri != 5 { | |
124 | t.Fatalf("want=5, got=%d", ri) | |
125 | } | |
126 | ||
127 | var requestAtServer jsonrpc.Request | |
128 | err = json.Unmarshal(requestBody, &requestAtServer) | |
129 | if err != nil { | |
130 | t.Fatal(err) | |
131 | } | |
132 | if id, _ := requestAtServer.ID.Int(); id != wantID { | |
133 | t.Fatalf("Request ID at server: want=%d, got=%d", wantID, id) | |
134 | } | |
135 | ||
136 | var paramsAtServer addRequest | |
137 | err = json.Unmarshal(requestAtServer.Params, ¶msAtServer) | |
138 | if err != nil { | |
139 | t.Fatal(err) | |
140 | } | |
141 | ||
142 | if paramsAtServer != in { | |
143 | t.Fatalf("want=%+v, got=%+v", in, paramsAtServer) | |
144 | } | |
145 | ||
146 | if !finalizerCalled { | |
147 | t.Fatal("Expected finalizer to be called. Wasn't.") | |
148 | } | |
149 | } | |
150 | ||
151 | func TestCanUseDefaults(t *testing.T) { | |
152 | var ( | |
153 | testbody = `{"jsonrpc":"2.0", "result":"boogaloo"}` | |
154 | requestBody []byte | |
155 | ) | |
156 | ||
157 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
158 | b, err := ioutil.ReadAll(r.Body) | |
159 | if err != nil && err != io.EOF { | |
160 | t.Fatal(err) | |
161 | } | |
162 | requestBody = b | |
163 | ||
164 | w.WriteHeader(http.StatusOK) | |
165 | w.Write([]byte(testbody)) | |
166 | })) | |
167 | ||
168 | sut := jsonrpc.NewClient( | |
169 | mustParse(server.URL), | |
170 | "add", | |
171 | ) | |
172 | ||
173 | type addRequest struct { | |
174 | A int | |
175 | B int | |
176 | } | |
177 | ||
178 | in := addRequest{2, 2} | |
179 | ||
180 | result, err := sut.Endpoint()(context.Background(), in) | |
181 | if err != nil { | |
182 | t.Fatal(err) | |
183 | } | |
184 | rs, ok := result.(string) | |
185 | if !ok { | |
186 | t.Fatalf("result is not string: (%T)%+v", result, result) | |
187 | } | |
188 | if rs != "boogaloo" { | |
189 | t.Fatalf("want=boogaloo, got=%s", rs) | |
190 | } | |
191 | ||
192 | var requestAtServer jsonrpc.Request | |
193 | err = json.Unmarshal(requestBody, &requestAtServer) | |
194 | if err != nil { | |
195 | t.Fatal(err) | |
196 | } | |
197 | var paramsAtServer addRequest | |
198 | err = json.Unmarshal(requestAtServer.Params, ¶msAtServer) | |
199 | if err != nil { | |
200 | t.Fatal(err) | |
201 | } | |
202 | ||
203 | if paramsAtServer != in { | |
204 | t.Fatalf("want=%+v, got=%+v", in, paramsAtServer) | |
205 | } | |
206 | } | |
207 | ||
208 | func TestDefaultAutoIncrementer(t *testing.T) { | |
209 | sut := jsonrpc.NewAutoIncrementID(0) | |
210 | var want uint64 | |
211 | for ; want < 100; want++ { | |
212 | got := sut.Generate() | |
213 | if got != want { | |
214 | t.Fatalf("want=%d, got=%d", want, got) | |
215 | } | |
216 | } | |
217 | } | |
218 | ||
219 | func mustParse(s string) *url.URL { | |
220 | u, err := url.Parse(s) | |
221 | if err != nil { | |
222 | panic(err) | |
223 | } | |
224 | return u | |
225 | } |
0 | // Package jsonrpc provides a JSON RPC (v2.0) binding for endpoints. | |
1 | // See http://www.jsonrpc.org/specification | |
2 | package jsonrpc |
0 | package jsonrpc | |
1 | ||
2 | import ( | |
3 | "encoding/json" | |
4 | ||
5 | "github.com/go-kit/kit/endpoint" | |
6 | ||
7 | "context" | |
8 | ) | |
9 | ||
10 | // Server-Side Codec | |
11 | ||
12 | // EndpointCodec defines a server Endpoint and its associated codecs | |
13 | type EndpointCodec struct { | |
14 | Endpoint endpoint.Endpoint | |
15 | Decode DecodeRequestFunc | |
16 | Encode EncodeResponseFunc | |
17 | } | |
18 | ||
19 | // EndpointCodecMap maps the Request.Method to the proper EndpointCodec | |
20 | type EndpointCodecMap map[string]EndpointCodec | |
21 | ||
22 | // DecodeRequestFunc extracts a user-domain request object from an raw JSON | |
23 | // It's designed to be used in HTTP servers, for server-side endpoints. | |
24 | // One straightforward DecodeRequestFunc could be something that unmarshals | |
25 | // JSON from the request body to the concrete request type. | |
26 | type DecodeRequestFunc func(context.Context, json.RawMessage) (request interface{}, err error) | |
27 | ||
28 | // EncodeResponseFunc encodes the passed response object to a JSON RPC response. | |
29 | // It's designed to be used in HTTP servers, for server-side endpoints. | |
30 | // One straightforward EncodeResponseFunc could be something that JSON encodes | |
31 | // the object directly. | |
32 | type EncodeResponseFunc func(context.Context, interface{}) (response json.RawMessage, err error) | |
33 | ||
34 | // Client-Side Codec | |
35 | ||
36 | // EncodeRequestFunc encodes the passed request object to raw JSON. | |
37 | // It's designed to be used in JSON RPC clients, for client-side | |
38 | // endpoints. One straightforward EncodeResponseFunc could be something that | |
39 | // JSON encodes the object directly. | |
40 | type EncodeRequestFunc func(context.Context, interface{}) (request json.RawMessage, err error) | |
41 | ||
42 | // DecodeResponseFunc extracts a user-domain response object from an HTTP | |
43 | // request object. It's designed to be used in JSON RPC clients, for | |
44 | // client-side endpoints. One straightforward DecodeRequestFunc could be | |
45 | // something that JSON decodes from the request body to the concrete | |
46 | // response type. | |
47 | type DecodeResponseFunc func(context.Context, json.RawMessage) (response interface{}, err error) |
0 | package jsonrpc | |
1 | ||
2 | // Error defines a JSON RPC error that can be returned | |
3 | // in a Response from the spec | |
4 | // http://www.jsonrpc.org/specification#error_object | |
5 | type Error struct { | |
6 | Code int `json:"code"` | |
7 | Message string `json:"message"` | |
8 | Data interface{} `json:"data,omitempty"` | |
9 | } | |
10 | ||
11 | // Error implements error. | |
12 | func (e Error) Error() string { | |
13 | if e.Message != "" { | |
14 | return e.Message | |
15 | } | |
16 | return errorMessage[e.Code] | |
17 | } | |
18 | ||
19 | // ErrorCode returns the JSON RPC error code associated with the error. | |
20 | func (e Error) ErrorCode() int { | |
21 | return e.Code | |
22 | } | |
23 | ||
24 | const ( | |
25 | // ParseError defines invalid JSON was received by the server. | |
26 | // An error occurred on the server while parsing the JSON text. | |
27 | ParseError int = -32700 | |
28 | ||
29 | // InvalidRequestError defines the JSON sent is not a valid Request object. | |
30 | InvalidRequestError int = -32600 | |
31 | ||
32 | // MethodNotFoundError defines the method does not exist / is not available. | |
33 | MethodNotFoundError int = -32601 | |
34 | ||
35 | // InvalidParamsError defines invalid method parameter(s). | |
36 | InvalidParamsError int = -32602 | |
37 | ||
38 | // InternalError defines a server error | |
39 | InternalError int = -32603 | |
40 | ) | |
41 | ||
42 | var errorMessage = map[int]string{ | |
43 | ParseError: "An error occurred on the server while parsing the JSON text.", | |
44 | InvalidRequestError: "The JSON sent is not a valid Request object.", | |
45 | MethodNotFoundError: "The method does not exist / is not available.", | |
46 | InvalidParamsError: "Invalid method parameter(s).", | |
47 | InternalError: "Internal JSON-RPC error.", | |
48 | } | |
49 | ||
50 | // ErrorMessage returns a message for the JSON RPC error code. It returns the empty | |
51 | // string if the code is unknown. | |
52 | func ErrorMessage(code int) string { | |
53 | return errorMessage[code] | |
54 | } | |
55 | ||
56 | type parseError string | |
57 | ||
58 | func (e parseError) Error() string { | |
59 | return string(e) | |
60 | } | |
61 | func (e parseError) ErrorCode() int { | |
62 | return ParseError | |
63 | } | |
64 | ||
65 | type invalidRequestError string | |
66 | ||
67 | func (e invalidRequestError) Error() string { | |
68 | return string(e) | |
69 | } | |
70 | func (e invalidRequestError) ErrorCode() int { | |
71 | return InvalidRequestError | |
72 | } | |
73 | ||
74 | type methodNotFoundError string | |
75 | ||
76 | func (e methodNotFoundError) Error() string { | |
77 | return string(e) | |
78 | } | |
79 | func (e methodNotFoundError) ErrorCode() int { | |
80 | return MethodNotFoundError | |
81 | } | |
82 | ||
83 | type invalidParamsError string | |
84 | ||
85 | func (e invalidParamsError) Error() string { | |
86 | return string(e) | |
87 | } | |
88 | func (e invalidParamsError) ErrorCode() int { | |
89 | return InvalidParamsError | |
90 | } | |
91 | ||
92 | type internalError string | |
93 | ||
94 | func (e internalError) Error() string { | |
95 | return string(e) | |
96 | } | |
97 | func (e internalError) ErrorCode() int { | |
98 | return InternalError | |
99 | } |
0 | package jsonrpc | |
1 | ||
2 | import "testing" | |
3 | ||
4 | func TestError(t *testing.T) { | |
5 | wantCode := ParseError | |
6 | sut := Error{ | |
7 | Code: wantCode, | |
8 | } | |
9 | ||
10 | gotCode := sut.ErrorCode() | |
11 | if gotCode != wantCode { | |
12 | t.Fatalf("want=%d, got=%d", gotCode, wantCode) | |
13 | } | |
14 | ||
15 | if sut.Error() == "" { | |
16 | t.Fatal("Empty error string.") | |
17 | } | |
18 | ||
19 | want := "override" | |
20 | sut.Message = want | |
21 | got := sut.Error() | |
22 | if sut.Error() != want { | |
23 | t.Fatalf("overridden error message: want=%s, got=%s", want, got) | |
24 | } | |
25 | ||
26 | } | |
27 | func TestErrorsSatisfyError(t *testing.T) { | |
28 | errs := []interface{}{ | |
29 | parseError("parseError"), | |
30 | invalidRequestError("invalidRequestError"), | |
31 | methodNotFoundError("methodNotFoundError"), | |
32 | invalidParamsError("invalidParamsError"), | |
33 | internalError("internalError"), | |
34 | } | |
35 | for _, e := range errs { | |
36 | err, ok := e.(error) | |
37 | if !ok { | |
38 | t.Fatalf("Couldn't assert %s as error.", e) | |
39 | } | |
40 | errString := err.Error() | |
41 | if errString == "" { | |
42 | t.Fatal("Empty error string") | |
43 | } | |
44 | ||
45 | ec, ok := e.(ErrorCoder) | |
46 | if !ok { | |
47 | t.Fatalf("Couldn't assert %s as ErrorCoder.", e) | |
48 | } | |
49 | if ErrorMessage(ec.ErrorCode()) == "" { | |
50 | t.Fatalf("Error type %s returned code of %d, which does not map to error string", e, ec.ErrorCode()) | |
51 | } | |
52 | } | |
53 | } |
0 | package jsonrpc | |
1 | ||
2 | import "encoding/json" | |
3 | ||
4 | // Request defines a JSON RPC request from the spec | |
5 | // http://www.jsonrpc.org/specification#request_object | |
6 | type Request struct { | |
7 | JSONRPC string `json:"jsonrpc"` | |
8 | Method string `json:"method"` | |
9 | Params json.RawMessage `json:"params"` | |
10 | ID *RequestID `json:"id"` | |
11 | } | |
12 | ||
13 | // RequestID defines a request ID that can be string, number, or null. | |
14 | // An identifier established by the Client that MUST contain a String, | |
15 | // Number, or NULL value if included. | |
16 | // If it is not included it is assumed to be a notification. | |
17 | // The value SHOULD normally not be Null and | |
18 | // Numbers SHOULD NOT contain fractional parts. | |
19 | type RequestID struct { | |
20 | intValue int | |
21 | intError error | |
22 | floatValue float32 | |
23 | floatError error | |
24 | stringValue string | |
25 | stringError error | |
26 | } | |
27 | ||
28 | // UnmarshalJSON satisfies json.Unmarshaler | |
29 | func (id *RequestID) UnmarshalJSON(b []byte) error { | |
30 | id.intError = json.Unmarshal(b, &id.intValue) | |
31 | id.floatError = json.Unmarshal(b, &id.floatValue) | |
32 | id.stringError = json.Unmarshal(b, &id.stringValue) | |
33 | ||
34 | return nil | |
35 | } | |
36 | ||
37 | // Int returns the ID as an integer value. | |
38 | // An error is returned if the ID can't be treated as an int. | |
39 | func (id *RequestID) Int() (int, error) { | |
40 | return id.intValue, id.intError | |
41 | } | |
42 | ||
43 | // Float32 returns the ID as a float value. | |
44 | // An error is returned if the ID can't be treated as an float. | |
45 | func (id *RequestID) Float32() (float32, error) { | |
46 | return id.floatValue, id.floatError | |
47 | } | |
48 | ||
49 | // String returns the ID as a string value. | |
50 | // An error is returned if the ID can't be treated as an string. | |
51 | func (id *RequestID) String() (string, error) { | |
52 | return id.stringValue, id.stringError | |
53 | } | |
54 | ||
55 | // Response defines a JSON RPC response from the spec | |
56 | // http://www.jsonrpc.org/specification#response_object | |
57 | type Response struct { | |
58 | JSONRPC string `json:"jsonrpc"` | |
59 | Result json.RawMessage `json:"result,omitempty"` | |
60 | Error *Error `json:"error,omitemty"` | |
61 | } | |
62 | ||
63 | const ( | |
64 | // Version defines the version of the JSON RPC implementation | |
65 | Version string = "2.0" | |
66 | ||
67 | // ContentType defines the content type to be served. | |
68 | ContentType string = "application/json; charset=utf-8" | |
69 | ) |
0 | package jsonrpc_test | |
1 | ||
2 | import ( | |
3 | "encoding/json" | |
4 | "fmt" | |
5 | "testing" | |
6 | ||
7 | "github.com/go-kit/kit/transport/http/jsonrpc" | |
8 | ) | |
9 | ||
10 | func TestCanUnMarshalID(t *testing.T) { | |
11 | cases := []struct { | |
12 | JSON string | |
13 | expType string | |
14 | expValue interface{} | |
15 | }{ | |
16 | {`12345`, "int", 12345}, | |
17 | {`12345.6`, "float", 12345.6}, | |
18 | {`"stringaling"`, "string", "stringaling"}, | |
19 | } | |
20 | ||
21 | for _, c := range cases { | |
22 | r := jsonrpc.Request{} | |
23 | JSON := fmt.Sprintf(`{"id":%s}`, c.JSON) | |
24 | ||
25 | var foo interface{} | |
26 | _ = json.Unmarshal([]byte(JSON), &foo) | |
27 | ||
28 | err := json.Unmarshal([]byte(JSON), &r) | |
29 | if err != nil { | |
30 | t.Fatalf("Unexpected error unmarshaling JSON into request: %s\n", err) | |
31 | } | |
32 | id := r.ID | |
33 | ||
34 | switch c.expType { | |
35 | case "int": | |
36 | want := c.expValue.(int) | |
37 | got, err := id.Int() | |
38 | if err != nil { | |
39 | t.Fatal(err) | |
40 | } | |
41 | if got != want { | |
42 | t.Fatalf("'%s' Int(): want %d, got %d.", c.JSON, want, got) | |
43 | } | |
44 | ||
45 | // Allow an int ID to be interpreted as a float. | |
46 | wantf := float32(c.expValue.(int)) | |
47 | gotf, err := id.Float32() | |
48 | if err != nil { | |
49 | t.Fatal(err) | |
50 | } | |
51 | if gotf != wantf { | |
52 | t.Fatalf("'%s' Int value as Float32(): want %f, got %f.", c.JSON, wantf, gotf) | |
53 | } | |
54 | ||
55 | _, err = id.String() | |
56 | if err == nil { | |
57 | t.Fatal("Expected String() to error for int value. Didn't.") | |
58 | } | |
59 | case "string": | |
60 | want := c.expValue.(string) | |
61 | got, err := id.String() | |
62 | if err != nil { | |
63 | t.Fatal(err) | |
64 | } | |
65 | if got != want { | |
66 | t.Fatalf("'%s' String(): want %s, got %s.", c.JSON, want, got) | |
67 | } | |
68 | ||
69 | _, err = id.Int() | |
70 | if err == nil { | |
71 | t.Fatal("Expected Int() to error for string value. Didn't.") | |
72 | } | |
73 | _, err = id.Float32() | |
74 | if err == nil { | |
75 | t.Fatal("Expected Float32() to error for string value. Didn't.") | |
76 | } | |
77 | case "float32": | |
78 | want := c.expValue.(float32) | |
79 | got, err := id.Float32() | |
80 | if err != nil { | |
81 | t.Fatal(err) | |
82 | } | |
83 | if got != want { | |
84 | t.Fatalf("'%s' Float32(): want %f, got %f.", c.JSON, want, got) | |
85 | } | |
86 | ||
87 | _, err = id.String() | |
88 | if err == nil { | |
89 | t.Fatal("Expected String() to error for float value. Didn't.") | |
90 | } | |
91 | _, err = id.Int() | |
92 | if err == nil { | |
93 | t.Fatal("Expected Int() to error for float value. Didn't.") | |
94 | } | |
95 | } | |
96 | } | |
97 | } | |
98 | ||
99 | func TestCanUnmarshalNullID(t *testing.T) { | |
100 | r := jsonrpc.Request{} | |
101 | JSON := `{"id":null}` | |
102 | err := json.Unmarshal([]byte(JSON), &r) | |
103 | if err != nil { | |
104 | t.Fatalf("Unexpected error unmarshaling JSON into request: %s\n", err) | |
105 | } | |
106 | ||
107 | if r.ID != nil { | |
108 | t.Fatalf("Expected ID to be nil, got %+v.\n", r.ID) | |
109 | } | |
110 | } |
0 | package jsonrpc | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "encoding/json" | |
5 | "fmt" | |
6 | "io" | |
7 | "net/http" | |
8 | ||
9 | "github.com/go-kit/kit/log" | |
10 | httptransport "github.com/go-kit/kit/transport/http" | |
11 | ) | |
12 | ||
13 | // Server wraps an endpoint and implements http.Handler. | |
14 | type Server struct { | |
15 | ecm EndpointCodecMap | |
16 | before []httptransport.RequestFunc | |
17 | after []httptransport.ServerResponseFunc | |
18 | errorEncoder httptransport.ErrorEncoder | |
19 | finalizer httptransport.ServerFinalizerFunc | |
20 | logger log.Logger | |
21 | } | |
22 | ||
23 | // NewServer constructs a new server, which implements http.Server. | |
24 | func NewServer( | |
25 | ecm EndpointCodecMap, | |
26 | options ...ServerOption, | |
27 | ) *Server { | |
28 | s := &Server{ | |
29 | ecm: ecm, | |
30 | errorEncoder: DefaultErrorEncoder, | |
31 | logger: log.NewNopLogger(), | |
32 | } | |
33 | for _, option := range options { | |
34 | option(s) | |
35 | } | |
36 | return s | |
37 | } | |
38 | ||
39 | // ServerOption sets an optional parameter for servers. | |
40 | type ServerOption func(*Server) | |
41 | ||
42 | // ServerBefore functions are executed on the HTTP request object before the | |
43 | // request is decoded. | |
44 | func ServerBefore(before ...httptransport.RequestFunc) ServerOption { | |
45 | return func(s *Server) { s.before = append(s.before, before...) } | |
46 | } | |
47 | ||
48 | // ServerAfter functions are executed on the HTTP response writer after the | |
49 | // endpoint is invoked, but before anything is written to the client. | |
50 | func ServerAfter(after ...httptransport.ServerResponseFunc) ServerOption { | |
51 | return func(s *Server) { s.after = append(s.after, after...) } | |
52 | } | |
53 | ||
54 | // ServerErrorEncoder is used to encode errors to the http.ResponseWriter | |
55 | // whenever they're encountered in the processing of a request. Clients can | |
56 | // use this to provide custom error formatting and response codes. By default, | |
57 | // errors will be written with the DefaultErrorEncoder. | |
58 | func ServerErrorEncoder(ee httptransport.ErrorEncoder) ServerOption { | |
59 | return func(s *Server) { s.errorEncoder = ee } | |
60 | } | |
61 | ||
62 | // ServerErrorLogger is used to log non-terminal errors. By default, no errors | |
63 | // are logged. This is intended as a diagnostic measure. Finer-grained control | |
64 | // of error handling, including logging in more detail, should be performed in a | |
65 | // custom ServerErrorEncoder or ServerFinalizer, both of which have access to | |
66 | // the context. | |
67 | func ServerErrorLogger(logger log.Logger) ServerOption { | |
68 | return func(s *Server) { s.logger = logger } | |
69 | } | |
70 | ||
71 | // ServerFinalizer is executed at the end of every HTTP request. | |
72 | // By default, no finalizer is registered. | |
73 | func ServerFinalizer(f httptransport.ServerFinalizerFunc) ServerOption { | |
74 | return func(s *Server) { s.finalizer = f } | |
75 | } | |
76 | ||
77 | // ServeHTTP implements http.Handler. | |
78 | func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |
79 | if r.Method != http.MethodPost { | |
80 | w.Header().Set("Content-Type", "text/plain; charset=utf-8") | |
81 | w.WriteHeader(http.StatusMethodNotAllowed) | |
82 | _, _ = io.WriteString(w, "405 must POST\n") | |
83 | return | |
84 | } | |
85 | ctx := r.Context() | |
86 | ||
87 | if s.finalizer != nil { | |
88 | iw := &interceptingWriter{w, http.StatusOK} | |
89 | defer func() { s.finalizer(ctx, iw.code, r) }() | |
90 | w = iw | |
91 | } | |
92 | ||
93 | for _, f := range s.before { | |
94 | ctx = f(ctx, r) | |
95 | } | |
96 | ||
97 | // Decode the body into an object | |
98 | var req Request | |
99 | err := json.NewDecoder(r.Body).Decode(&req) | |
100 | if err != nil { | |
101 | rpcerr := parseError("JSON could not be decoded: " + err.Error()) | |
102 | s.logger.Log("err", rpcerr) | |
103 | s.errorEncoder(ctx, rpcerr, w) | |
104 | return | |
105 | } | |
106 | ||
107 | // Get the endpoint and codecs from the map using the method | |
108 | // defined in the JSON object | |
109 | ecm, ok := s.ecm[req.Method] | |
110 | if !ok { | |
111 | err := methodNotFoundError(fmt.Sprintf("Method %s was not found.", req.Method)) | |
112 | s.logger.Log("err", err) | |
113 | s.errorEncoder(ctx, err, w) | |
114 | return | |
115 | } | |
116 | ||
117 | // Decode the JSON "params" | |
118 | reqParams, err := ecm.Decode(ctx, req.Params) | |
119 | if err != nil { | |
120 | s.logger.Log("err", err) | |
121 | s.errorEncoder(ctx, err, w) | |
122 | return | |
123 | } | |
124 | ||
125 | // Call the Endpoint with the params | |
126 | response, err := ecm.Endpoint(ctx, reqParams) | |
127 | if err != nil { | |
128 | s.logger.Log("err", err) | |
129 | s.errorEncoder(ctx, err, w) | |
130 | return | |
131 | } | |
132 | ||
133 | for _, f := range s.after { | |
134 | ctx = f(ctx, w) | |
135 | } | |
136 | ||
137 | res := Response{ | |
138 | JSONRPC: Version, | |
139 | } | |
140 | ||
141 | // Encode the response from the Endpoint | |
142 | resParams, err := ecm.Encode(ctx, response) | |
143 | if err != nil { | |
144 | s.logger.Log("err", err) | |
145 | s.errorEncoder(ctx, err, w) | |
146 | return | |
147 | } | |
148 | ||
149 | res.Result = resParams | |
150 | ||
151 | w.Header().Set("Content-Type", ContentType) | |
152 | _ = json.NewEncoder(w).Encode(res) | |
153 | } | |
154 | ||
155 | // DefaultErrorEncoder writes the error to the ResponseWriter, | |
156 | // as a json-rpc error response, with an InternalError status code. | |
157 | // The Error() string of the error will be used as the response error message. | |
158 | // If the error implements ErrorCoder, the provided code will be set on the | |
159 | // response error. | |
160 | // If the error implements Headerer, the given headers will be set. | |
161 | func DefaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) { | |
162 | w.Header().Set("Content-Type", ContentType) | |
163 | if headerer, ok := err.(httptransport.Headerer); ok { | |
164 | for k := range headerer.Headers() { | |
165 | w.Header().Set(k, headerer.Headers().Get(k)) | |
166 | } | |
167 | } | |
168 | ||
169 | e := Error{ | |
170 | Code: InternalError, | |
171 | Message: err.Error(), | |
172 | } | |
173 | if sc, ok := err.(ErrorCoder); ok { | |
174 | e.Code = sc.ErrorCode() | |
175 | } | |
176 | ||
177 | w.WriteHeader(http.StatusOK) | |
178 | _ = json.NewEncoder(w).Encode(Response{ | |
179 | JSONRPC: Version, | |
180 | Error: &e, | |
181 | }) | |
182 | } | |
183 | ||
184 | // ErrorCoder is checked by DefaultErrorEncoder. If an error value implements | |
185 | // ErrorCoder, the integer result of ErrorCode() will be used as the JSONRPC | |
186 | // error code when encoding the error. | |
187 | // | |
188 | // By default, InternalError (-32603) is used. | |
189 | type ErrorCoder interface { | |
190 | ErrorCode() int | |
191 | } | |
192 | ||
193 | // interceptingWriter intercepts calls to WriteHeader, so that a finalizer | |
194 | // can be given the correct status code. | |
195 | type interceptingWriter struct { | |
196 | http.ResponseWriter | |
197 | code int | |
198 | } | |
199 | ||
200 | // WriteHeader may not be explicitly called, so care must be taken to | |
201 | // initialize w.code to its default value of http.StatusOK. | |
202 | func (w *interceptingWriter) WriteHeader(code int) { | |
203 | w.code = code | |
204 | w.ResponseWriter.WriteHeader(code) | |
205 | } |
0 | package jsonrpc_test | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "encoding/json" | |
5 | "errors" | |
6 | "io" | |
7 | "io/ioutil" | |
8 | "net/http" | |
9 | "net/http/httptest" | |
10 | "strings" | |
11 | "testing" | |
12 | "time" | |
13 | ||
14 | "github.com/go-kit/kit/endpoint" | |
15 | "github.com/go-kit/kit/transport/http/jsonrpc" | |
16 | ) | |
17 | ||
18 | func addBody() io.Reader { | |
19 | return body(`{"jsonrpc": "2.0", "method": "add", "params": [3, 2], "id": 1}`) | |
20 | } | |
21 | ||
22 | func body(in string) io.Reader { | |
23 | return strings.NewReader(in) | |
24 | } | |
25 | ||
26 | func expectErrorCode(t *testing.T, want int, body []byte) { | |
27 | var r jsonrpc.Response | |
28 | err := json.Unmarshal(body, &r) | |
29 | if err != nil { | |
30 | t.Fatalf("Cant' decode response. err=%s, body=%s", err, body) | |
31 | } | |
32 | if r.Error == nil { | |
33 | t.Fatalf("Expected error on response. Got none: %s", body) | |
34 | } | |
35 | if have := r.Error.Code; want != have { | |
36 | t.Fatalf("Unexpected error code. Want %d, have %d: %s", want, have, body) | |
37 | } | |
38 | } | |
39 | ||
40 | func nopDecoder(context.Context, json.RawMessage) (interface{}, error) { return struct{}{}, nil } | |
41 | func nopEncoder(context.Context, interface{}) (json.RawMessage, error) { return []byte("[]"), nil } | |
42 | ||
43 | type mockLogger struct { | |
44 | Called bool | |
45 | LastArgs []interface{} | |
46 | } | |
47 | ||
48 | func (l *mockLogger) Log(keyvals ...interface{}) error { | |
49 | l.Called = true | |
50 | l.LastArgs = append(l.LastArgs, keyvals) | |
51 | return nil | |
52 | } | |
53 | ||
54 | func TestServerBadDecode(t *testing.T) { | |
55 | ecm := jsonrpc.EndpointCodecMap{ | |
56 | "add": jsonrpc.EndpointCodec{ | |
57 | Endpoint: endpoint.Nop, | |
58 | Decode: func(context.Context, json.RawMessage) (interface{}, error) { return struct{}{}, errors.New("oof") }, | |
59 | Encode: nopEncoder, | |
60 | }, | |
61 | } | |
62 | logger := mockLogger{} | |
63 | handler := jsonrpc.NewServer(ecm, jsonrpc.ServerErrorLogger(&logger)) | |
64 | server := httptest.NewServer(handler) | |
65 | defer server.Close() | |
66 | resp, _ := http.Post(server.URL, "application/json", addBody()) | |
67 | buf, _ := ioutil.ReadAll(resp.Body) | |
68 | if want, have := http.StatusOK, resp.StatusCode; want != have { | |
69 | t.Errorf("want %d, have %d: %s", want, have, buf) | |
70 | } | |
71 | expectErrorCode(t, jsonrpc.InternalError, buf) | |
72 | if !logger.Called { | |
73 | t.Fatal("Expected logger to be called with error. Wasn't.") | |
74 | } | |
75 | } | |
76 | ||
77 | func TestServerBadEndpoint(t *testing.T) { | |
78 | ecm := jsonrpc.EndpointCodecMap{ | |
79 | "add": jsonrpc.EndpointCodec{ | |
80 | Endpoint: func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errors.New("oof") }, | |
81 | Decode: nopDecoder, | |
82 | Encode: nopEncoder, | |
83 | }, | |
84 | } | |
85 | handler := jsonrpc.NewServer(ecm) | |
86 | server := httptest.NewServer(handler) | |
87 | defer server.Close() | |
88 | resp, _ := http.Post(server.URL, "application/json", addBody()) | |
89 | if want, have := http.StatusOK, resp.StatusCode; want != have { | |
90 | t.Errorf("want %d, have %d", want, have) | |
91 | } | |
92 | buf, _ := ioutil.ReadAll(resp.Body) | |
93 | expectErrorCode(t, jsonrpc.InternalError, buf) | |
94 | } | |
95 | ||
96 | func TestServerBadEncode(t *testing.T) { | |
97 | ecm := jsonrpc.EndpointCodecMap{ | |
98 | "add": jsonrpc.EndpointCodec{ | |
99 | Endpoint: endpoint.Nop, | |
100 | Decode: nopDecoder, | |
101 | Encode: func(context.Context, interface{}) (json.RawMessage, error) { return []byte{}, errors.New("oof") }, | |
102 | }, | |
103 | } | |
104 | handler := jsonrpc.NewServer(ecm) | |
105 | server := httptest.NewServer(handler) | |
106 | defer server.Close() | |
107 | resp, _ := http.Post(server.URL, "application/json", addBody()) | |
108 | if want, have := http.StatusOK, resp.StatusCode; want != have { | |
109 | t.Errorf("want %d, have %d", want, have) | |
110 | } | |
111 | buf, _ := ioutil.ReadAll(resp.Body) | |
112 | expectErrorCode(t, jsonrpc.InternalError, buf) | |
113 | } | |
114 | ||
115 | func TestServerErrorEncoder(t *testing.T) { | |
116 | errTeapot := errors.New("teapot") | |
117 | code := func(err error) int { | |
118 | if err == errTeapot { | |
119 | return http.StatusTeapot | |
120 | } | |
121 | return http.StatusInternalServerError | |
122 | } | |
123 | ecm := jsonrpc.EndpointCodecMap{ | |
124 | "add": jsonrpc.EndpointCodec{ | |
125 | Endpoint: func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errTeapot }, | |
126 | Decode: nopDecoder, | |
127 | Encode: nopEncoder, | |
128 | }, | |
129 | } | |
130 | handler := jsonrpc.NewServer( | |
131 | ecm, | |
132 | jsonrpc.ServerErrorEncoder(func(_ context.Context, err error, w http.ResponseWriter) { w.WriteHeader(code(err)) }), | |
133 | ) | |
134 | server := httptest.NewServer(handler) | |
135 | defer server.Close() | |
136 | resp, _ := http.Post(server.URL, "application/json", addBody()) | |
137 | if want, have := http.StatusTeapot, resp.StatusCode; want != have { | |
138 | t.Errorf("want %d, have %d", want, have) | |
139 | } | |
140 | } | |
141 | ||
142 | func TestCanRejectNonPostRequest(t *testing.T) { | |
143 | ecm := jsonrpc.EndpointCodecMap{} | |
144 | handler := jsonrpc.NewServer(ecm) | |
145 | server := httptest.NewServer(handler) | |
146 | defer server.Close() | |
147 | resp, _ := http.Get(server.URL) | |
148 | if want, have := http.StatusMethodNotAllowed, resp.StatusCode; want != have { | |
149 | t.Errorf("want %d, have %d", want, have) | |
150 | } | |
151 | } | |
152 | ||
153 | func TestCanRejectInvalidJSON(t *testing.T) { | |
154 | ecm := jsonrpc.EndpointCodecMap{} | |
155 | handler := jsonrpc.NewServer(ecm) | |
156 | server := httptest.NewServer(handler) | |
157 | defer server.Close() | |
158 | resp, _ := http.Post(server.URL, "application/json", body("clearlynotjson")) | |
159 | if want, have := http.StatusOK, resp.StatusCode; want != have { | |
160 | t.Errorf("want %d, have %d", want, have) | |
161 | } | |
162 | buf, _ := ioutil.ReadAll(resp.Body) | |
163 | expectErrorCode(t, jsonrpc.ParseError, buf) | |
164 | } | |
165 | ||
166 | func TestServerUnregisteredMethod(t *testing.T) { | |
167 | ecm := jsonrpc.EndpointCodecMap{} | |
168 | handler := jsonrpc.NewServer(ecm) | |
169 | server := httptest.NewServer(handler) | |
170 | defer server.Close() | |
171 | resp, _ := http.Post(server.URL, "application/json", addBody()) | |
172 | if want, have := http.StatusOK, resp.StatusCode; want != have { | |
173 | t.Errorf("want %d, have %d", want, have) | |
174 | } | |
175 | buf, _ := ioutil.ReadAll(resp.Body) | |
176 | expectErrorCode(t, jsonrpc.MethodNotFoundError, buf) | |
177 | } | |
178 | ||
179 | func TestServerHappyPath(t *testing.T) { | |
180 | step, response := testServer(t) | |
181 | step() | |
182 | resp := <-response | |
183 | defer resp.Body.Close() // nolint | |
184 | buf, _ := ioutil.ReadAll(resp.Body) | |
185 | if want, have := http.StatusOK, resp.StatusCode; want != have { | |
186 | t.Errorf("want %d, have %d (%s)", want, have, buf) | |
187 | } | |
188 | var r jsonrpc.Response | |
189 | err := json.Unmarshal(buf, &r) | |
190 | if err != nil { | |
191 | t.Fatalf("Cant' decode response. err=%s, body=%s", err, buf) | |
192 | } | |
193 | if r.JSONRPC != jsonrpc.Version { | |
194 | t.Fatalf("JSONRPC Version: want=%s, got=%s", jsonrpc.Version, r.JSONRPC) | |
195 | } | |
196 | if r.Error != nil { | |
197 | t.Fatalf("Unxpected error on response: %s", buf) | |
198 | } | |
199 | } | |
200 | ||
201 | func TestMultipleServerBefore(t *testing.T) { | |
202 | var done = make(chan struct{}) | |
203 | ecm := jsonrpc.EndpointCodecMap{ | |
204 | "add": jsonrpc.EndpointCodec{ | |
205 | Endpoint: endpoint.Nop, | |
206 | Decode: nopDecoder, | |
207 | Encode: nopEncoder, | |
208 | }, | |
209 | } | |
210 | handler := jsonrpc.NewServer( | |
211 | ecm, | |
212 | jsonrpc.ServerBefore(func(ctx context.Context, r *http.Request) context.Context { | |
213 | ctx = context.WithValue(ctx, "one", 1) | |
214 | ||
215 | return ctx | |
216 | }), | |
217 | jsonrpc.ServerBefore(func(ctx context.Context, r *http.Request) context.Context { | |
218 | if _, ok := ctx.Value("one").(int); !ok { | |
219 | t.Error("Value was not set properly when multiple ServerBefores are used") | |
220 | } | |
221 | ||
222 | close(done) | |
223 | return ctx | |
224 | }), | |
225 | ) | |
226 | server := httptest.NewServer(handler) | |
227 | defer server.Close() | |
228 | http.Post(server.URL, "application/json", addBody()) // nolint | |
229 | ||
230 | select { | |
231 | case <-done: | |
232 | case <-time.After(time.Second): | |
233 | t.Fatal("timeout waiting for finalizer") | |
234 | } | |
235 | } | |
236 | ||
237 | func TestMultipleServerAfter(t *testing.T) { | |
238 | var done = make(chan struct{}) | |
239 | ecm := jsonrpc.EndpointCodecMap{ | |
240 | "add": jsonrpc.EndpointCodec{ | |
241 | Endpoint: endpoint.Nop, | |
242 | Decode: nopDecoder, | |
243 | Encode: nopEncoder, | |
244 | }, | |
245 | } | |
246 | handler := jsonrpc.NewServer( | |
247 | ecm, | |
248 | jsonrpc.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context { | |
249 | ctx = context.WithValue(ctx, "one", 1) | |
250 | ||
251 | return ctx | |
252 | }), | |
253 | jsonrpc.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context { | |
254 | if _, ok := ctx.Value("one").(int); !ok { | |
255 | t.Error("Value was not set properly when multiple ServerAfters are used") | |
256 | } | |
257 | ||
258 | close(done) | |
259 | return ctx | |
260 | }), | |
261 | ) | |
262 | server := httptest.NewServer(handler) | |
263 | defer server.Close() | |
264 | http.Post(server.URL, "application/json", addBody()) // nolint | |
265 | ||
266 | select { | |
267 | case <-done: | |
268 | case <-time.After(time.Second): | |
269 | t.Fatal("timeout waiting for finalizer") | |
270 | } | |
271 | } | |
272 | ||
273 | func TestCanFinalize(t *testing.T) { | |
274 | var done = make(chan struct{}) | |
275 | var finalizerCalled bool | |
276 | ecm := jsonrpc.EndpointCodecMap{ | |
277 | "add": jsonrpc.EndpointCodec{ | |
278 | Endpoint: endpoint.Nop, | |
279 | Decode: nopDecoder, | |
280 | Encode: nopEncoder, | |
281 | }, | |
282 | } | |
283 | handler := jsonrpc.NewServer( | |
284 | ecm, | |
285 | jsonrpc.ServerFinalizer(func(ctx context.Context, code int, req *http.Request) { | |
286 | finalizerCalled = true | |
287 | close(done) | |
288 | }), | |
289 | ) | |
290 | server := httptest.NewServer(handler) | |
291 | defer server.Close() | |
292 | http.Post(server.URL, "application/json", addBody()) // nolint | |
293 | ||
294 | select { | |
295 | case <-done: | |
296 | case <-time.After(time.Second): | |
297 | t.Fatal("timeout waiting for finalizer") | |
298 | } | |
299 | ||
300 | if !finalizerCalled { | |
301 | t.Fatal("Finalizer was not called.") | |
302 | } | |
303 | } | |
304 | ||
305 | func testServer(t *testing.T) (step func(), resp <-chan *http.Response) { | |
306 | var ( | |
307 | stepch = make(chan bool) | |
308 | endpoint = func(ctx context.Context, request interface{}) (response interface{}, err error) { | |
309 | <-stepch | |
310 | return struct{}{}, nil | |
311 | } | |
312 | response = make(chan *http.Response) | |
313 | ecm = jsonrpc.EndpointCodecMap{ | |
314 | "add": jsonrpc.EndpointCodec{ | |
315 | Endpoint: endpoint, | |
316 | Decode: nopDecoder, | |
317 | Encode: nopEncoder, | |
318 | }, | |
319 | } | |
320 | handler = jsonrpc.NewServer(ecm) | |
321 | ) | |
322 | go func() { | |
323 | server := httptest.NewServer(handler) | |
324 | defer server.Close() | |
325 | rb := strings.NewReader(`{"jsonrpc": "2.0", "method": "add", "params": [3, 2], "id": 1}`) | |
326 | resp, err := http.Post(server.URL, "application/json", rb) | |
327 | if err != nil { | |
328 | t.Error(err) | |
329 | return | |
330 | } | |
331 | response <- resp | |
332 | }() | |
333 | return func() { stepch <- true }, response | |
334 | } |