Merge tag 'upstream/0.8.0' into debian/sid
Upstream version 0.8.0
Daniel Swarbrick
1 year, 4 months ago
106 | 106 | - [Martini](https://github.com/go-martini/martini) |
107 | 107 | - [Beego](http://beego.me/) |
108 | 108 | - [Revel](https://revel.github.io/) (considered [harmful](https://github.com/go-kit/kit/issues/350)) |
109 | - [GoBuffalo](https://gobuffalo.io/) | |
109 | 110 | |
110 | 111 | ## Additional reading |
111 | 112 | |
112 | - [Architecting for the Cloud](http://fr.slideshare.net/stonse/architecting-for-the-cloud-using-netflixoss-codemash-workshop-29852233) — Netflix | |
113 | - [Architecting for the Cloud](https://slideshare.net/stonse/architecting-for-the-cloud-using-netflixoss-codemash-workshop-29852233) — Netflix | |
113 | 114 | - [Dapper, a Large-Scale Distributed Systems Tracing Infrastructure](http://research.google.com/pubs/pub36356.html) — Google |
114 | 115 | - [Your Server as a Function](http://monkey.org/~marius/funsrv.pdf) (PDF) — Twitter |
115 | 116 |
0 | package casbin | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "errors" | |
5 | ||
6 | stdcasbin "github.com/casbin/casbin" | |
7 | "github.com/go-kit/kit/endpoint" | |
8 | ) | |
9 | ||
10 | type contextKey string | |
11 | ||
12 | const ( | |
13 | // CasbinModelContextKey holds the key to store the access control model | |
14 | // in context, it can be a path to configuration file or a casbin/model | |
15 | // Model. | |
16 | CasbinModelContextKey contextKey = "CasbinModel" | |
17 | ||
18 | // CasbinPolicyContextKey holds the key to store the access control policy | |
19 | // in context, it can be a path to policy file or an implementation of | |
20 | // casbin/persist Adapter interface. | |
21 | CasbinPolicyContextKey contextKey = "CasbinPolicy" | |
22 | ||
23 | // CasbinEnforcerContextKey holds the key to retrieve the active casbin | |
24 | // Enforcer. | |
25 | CasbinEnforcerContextKey contextKey = "CasbinEnforcer" | |
26 | ) | |
27 | ||
28 | var ( | |
29 | // ErrModelContextMissing denotes a casbin model was not passed into | |
30 | // the parsing of middleware's context. | |
31 | ErrModelContextMissing = errors.New("CasbinModel is required in context") | |
32 | ||
33 | // ErrPolicyContextMissing denotes a casbin policy was not passed into | |
34 | // the parsing of middleware's context. | |
35 | ErrPolicyContextMissing = errors.New("CasbinPolicy is required in context") | |
36 | ||
37 | // ErrUnauthorized denotes the subject is not authorized to do the action | |
38 | // intended on the given object, based on the context model and policy. | |
39 | ErrUnauthorized = errors.New("Unauthorized Access") | |
40 | ) | |
41 | ||
42 | // NewEnforcer checks whether the subject is authorized to do the specified | |
43 | // action on the given object. If a valid access control model and policy | |
44 | // is given, then the generated casbin Enforcer is stored in the context | |
45 | // with CasbinEnforcer as the key. | |
46 | func NewEnforcer( | |
47 | subject string, object interface{}, action string, | |
48 | ) endpoint.Middleware { | |
49 | return func(next endpoint.Endpoint) endpoint.Endpoint { | |
50 | return func(ctx context.Context, request interface{}) ( | |
51 | response interface{}, err error, | |
52 | ) { | |
53 | casbinModel := ctx.Value(CasbinModelContextKey) | |
54 | casbinPolicy := ctx.Value(CasbinPolicyContextKey) | |
55 | ||
56 | enforcer := stdcasbin.NewEnforcer(casbinModel, casbinPolicy) | |
57 | ctx = context.WithValue(ctx, CasbinEnforcerContextKey, enforcer) | |
58 | if !enforcer.Enforce(subject, object, action) { | |
59 | return nil, ErrUnauthorized | |
60 | } | |
61 | return next(ctx, request) | |
62 | } | |
63 | } | |
64 | } |
0 | package casbin | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "testing" | |
5 | ||
6 | stdcasbin "github.com/casbin/casbin" | |
7 | fileadapter "github.com/casbin/casbin/persist/file-adapter" | |
8 | ) | |
9 | ||
10 | func TestStructBaseContext(t *testing.T) { | |
11 | e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil } | |
12 | ||
13 | m := stdcasbin.NewModel() | |
14 | m.AddDef("r", "r", "sub, obj, act") | |
15 | m.AddDef("p", "p", "sub, obj, act") | |
16 | m.AddDef("e", "e", "some(where (p.eft == allow))") | |
17 | m.AddDef("m", "m", "r.sub == p.sub && keyMatch(r.obj, p.obj) && regexMatch(r.act, p.act)") | |
18 | ||
19 | a := fileadapter.NewAdapter("testdata/keymatch_policy.csv") | |
20 | ||
21 | ctx := context.WithValue(context.Background(), CasbinModelContextKey, m) | |
22 | ctx = context.WithValue(ctx, CasbinPolicyContextKey, a) | |
23 | ||
24 | // positive case | |
25 | middleware := NewEnforcer("alice", "/alice_data/resource1", "GET")(e) | |
26 | ctx1, err := middleware(ctx, struct{}{}) | |
27 | if err != nil { | |
28 | t.Fatalf("Enforcer returned error: %s", err) | |
29 | } | |
30 | _, ok := ctx1.(context.Context).Value(CasbinEnforcerContextKey).(*stdcasbin.Enforcer) | |
31 | if !ok { | |
32 | t.Fatalf("context should contains the active enforcer") | |
33 | } | |
34 | ||
35 | // negative case | |
36 | middleware = NewEnforcer("alice", "/alice_data/resource2", "POST")(e) | |
37 | _, err = middleware(ctx, struct{}{}) | |
38 | if err == nil { | |
39 | t.Fatalf("Enforcer should return error") | |
40 | } | |
41 | } | |
42 | ||
43 | func TestFileBaseContext(t *testing.T) { | |
44 | e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil } | |
45 | ctx := context.WithValue(context.Background(), CasbinModelContextKey, "testdata/basic_model.conf") | |
46 | ctx = context.WithValue(ctx, CasbinPolicyContextKey, "testdata/basic_policy.csv") | |
47 | ||
48 | // positive case | |
49 | middleware := NewEnforcer("alice", "data1", "read")(e) | |
50 | _, err := middleware(ctx, struct{}{}) | |
51 | if err != nil { | |
52 | t.Fatalf("Enforcer returned error: %s", err) | |
53 | } | |
54 | } |
0 | [request_definition] | |
1 | r = sub, obj, act | |
2 | ||
3 | [policy_definition] | |
4 | p = sub, obj, act | |
5 | ||
6 | [policy_effect] | |
7 | e = some(where (p.eft == allow)) | |
8 | ||
9 | [matchers] | |
10 | m = r.sub == p.sub && r.obj == p.obj && r.act == p.act⏎ |
0 | p, alice, /alice_data/*, GET | |
1 | p, alice, /alice_data/resource1, POST | |
2 | ||
3 | p, bob, /alice_data/resource2, GET | |
4 | p, bob, /bob_data/*, POST | |
5 | ||
6 | p, cathy, /cathy_data, (GET)|(POST)⏎ |
54 | 54 | ``` |
55 | 55 | |
56 | 56 | In order for the parser and the signer to work, the authorization headers need |
57 | to be passed between the request and the context. `ToHTTPContext()`, | |
58 | `FromHTTPContext()`, `ToGRPCContext()`, and `FromGRPCContext()` are given as | |
57 | to be passed between the request and the context. `HTTPToContext()`, | |
58 | `ContextToHTTP()`, `GRPCToContext()`, and `ContextToGRPC()` are given as | |
59 | 59 | helpers to do this. These functions implement the correlating transport's |
60 | 60 | RequestFunc interface and can be passed as ClientBefore or ServerBefore |
61 | 61 | options. |
76 | 76 | options := []httptransport.ClientOption{} |
77 | 77 | var exampleEndpoint endpoint.Endpoint |
78 | 78 | { |
79 | exampleEndpoint = grpctransport.NewClient(..., grpctransport.ClientBefore(jwt.FromGRPCContext())).Endpoint() | |
79 | exampleEndpoint = grpctransport.NewClient(..., grpctransport.ClientBefore(jwt.ContextToGRPC())).Endpoint() | |
80 | 80 | exampleEndpoint = jwt.NewSigner( |
81 | 81 | "kid-header", |
82 | 82 | []byte("SigningString"), |
107 | 107 | endpoints.CreateUserEndpoint, |
108 | 108 | DecodeGRPCCreateUserRequest, |
109 | 109 | EncodeGRPCCreateUserResponse, |
110 | append(options, grpctransport.ServerBefore(jwt.ToGRPCContext()))..., | |
110 | append(options, grpctransport.ServerBefore(jwt.GRPCToContext()))..., | |
111 | 111 | ), |
112 | 112 | getUser: grpctransport.NewServer( |
113 | 113 | ctx, |
0 | machine: | |
1 | pre: | |
2 | - curl -sSL https://s3.amazonaws.com/circle-downloads/install-circleci-docker.sh | bash -s -- 1.10.0 | |
3 | - sudo rm -rf /usr/local/go | |
4 | - curl -sSL https://storage.googleapis.com/golang/go1.9.linux-amd64.tar.gz | sudo tar xz -C /usr/local | |
5 | services: | |
6 | - docker | |
0 | version: 2 | |
7 | 1 | |
8 | dependencies: | |
9 | pre: | |
10 | - sudo curl -L "https://github.com/docker/compose/releases/download/1.10.0/docker-compose-linux-x86_64" -o /usr/local/bin/docker-compose | |
11 | - sudo chmod +x /usr/local/bin/docker-compose | |
12 | - docker-compose -f docker-compose-integration.yml up -d --force-recreate | |
13 | ||
14 | test: | |
15 | pre: | |
16 | - mkdir -p /home/ubuntu/.go_workspace/src/github.com/go-kit | |
17 | - mv /home/ubuntu/kit /home/ubuntu/.go_workspace/src/github.com/go-kit | |
18 | - ln -s /home/ubuntu/.go_workspace/src/github.com/go-kit/kit /home/ubuntu/kit | |
19 | - go get -t github.com/go-kit/kit/... | |
20 | override: | |
21 | - go test -v -race -tags integration github.com/go-kit/kit/...: | |
22 | environment: | |
23 | ETCD_ADDR: http://localhost:2379 | |
24 | CONSUL_ADDR: localhost:8500 | |
25 | ZK_ADDR: localhost:2181 | |
26 | EUREKA_ADDR: http://localhost:8761/eureka | |
2 | jobs: | |
3 | build: | |
4 | machine: true | |
5 | working_directory: /home/circleci/.go_workspace/src/github.com/go-kit/kit | |
6 | environment: | |
7 | ETCD_ADDR: http://localhost:2379 | |
8 | CONSUL_ADDR: localhost:8500 | |
9 | ZK_ADDR: localhost:2181 | |
10 | EUREKA_ADDR: http://localhost:8761/eureka | |
11 | steps: | |
12 | - checkout | |
13 | - run: wget -q https://storage.googleapis.com/golang/go1.11.linux-amd64.tar.gz | |
14 | - run: sudo rm -rf /usr/local/go | |
15 | - run: sudo tar -C /usr/local -xzf go1.11.linux-amd64.tar.gz | |
16 | - run: docker-compose -f docker-compose-integration.yml up -d --force-recreate | |
17 | - run: go get -t github.com/go-kit/kit/... | |
18 | - run: go test -v -race -tags integration github.com/go-kit/kit/... |
0 | # kitgen | |
1 | kitgen is an experimental code generation utility that helps with some of the | |
2 | boilerplate code required to implement the "onion" pattern `go-kit` utilizes. | |
3 | ||
4 | ## Usage | |
5 | Before using this tool please explore the [testdata]() directory for examples | |
6 | of the inputs it requires and the outputs that will be produced. _You may not | |
7 | need this tool._ If you are new to and just learning `go-kit` or if your use | |
8 | case involves introducing `go-kit` to an existing codebase you are better | |
9 | suited by slowly building out the "onion" by hand. | |
10 | ||
11 | Before starting you need to *install* `kitgen` utility — see instructions below. | |
12 | 1. **Define** your service. Create a `.go` file with the definition of your | |
13 | Service interface and any of the custom types it refers to: | |
14 | ```go | |
15 | // service.go | |
16 | package profilesvc // don't forget to name your package | |
17 | ||
18 | type Service interface { | |
19 | PostProfile(ctx context.Context, p Profile) error | |
20 | // ... | |
21 | } | |
22 | type Profile struct { | |
23 | ID string `json:"id"` | |
24 | Name string `json:"name,omitempty"` | |
25 | // ... | |
26 | } | |
27 | ``` | |
28 | 2. **Generate** your code. Run the following command: | |
29 | ```sh | |
30 | kitgen ./service.go | |
31 | # kitgen has a couple of flags that you may find useful | |
32 | ||
33 | # keep all code in the root directory | |
34 | kitgen -repo-layout flat ./service.go | |
35 | ||
36 | # put generated code elsewhere | |
37 | kitgen -target-dir ~/Projects/gohome/src/home.com/kitchenservice/brewcoffee | |
38 | ``` | |
39 | ||
40 | ## Installation | |
41 | 1. **Fetch** the `inlinefiles` utility. Go generate will use it to create your | |
42 | code: | |
43 | ``` | |
44 | go get github.com/nyarly/inlinefiles | |
45 | ``` | |
46 | 2. **Install** the binary for easy access to `kitgen`. Run the following commands: | |
47 | ```sh | |
48 | cd $GOPATH/src/github.com/go-kit/kit/cmd/kitgen | |
49 | go install | |
50 | ||
51 | # Check installation by running: | |
52 | kitgen -h | |
53 | ``` |
25 | 25 | return outer(next) |
26 | 26 | } |
27 | 27 | } |
28 | ||
29 | // Failer may be implemented by Go kit response types that contain business | |
30 | // logic error details. If Failed returns a non-nil error, the Go kit transport | |
31 | // layer may interpret this as a business logic error, and may encode it | |
32 | // differently than a regular, successful response. | |
33 | // | |
34 | // It's not necessary for your response types to implement Failer, but it may | |
35 | // help for more sophisticated use cases. The addsvc example shows how Failer | |
36 | // should be used by a complete application. | |
37 | type Failer interface { | |
38 | Failed() error | |
39 | } |
13 | 13 | "github.com/apache/thrift/lib/go/thrift" |
14 | 14 | lightstep "github.com/lightstep/lightstep-tracer-go" |
15 | 15 | stdopentracing "github.com/opentracing/opentracing-go" |
16 | zipkinot "github.com/openzipkin-contrib/zipkin-go-opentracing" | |
16 | 17 | zipkin "github.com/openzipkin/zipkin-go" |
17 | zipkinot "github.com/openzipkin/zipkin-go-opentracing" | |
18 | 18 | zipkinhttp "github.com/openzipkin/zipkin-go/reporter/http" |
19 | 19 | "sourcegraph.com/sourcegraph/appdash" |
20 | 20 | appdashot "sourcegraph.com/sourcegraph/appdash/opentracing" |
42 | 42 | thriftProtocol = fs.String("thrift-protocol", "binary", "binary, compact, json, simplejson") |
43 | 43 | thriftBuffer = fs.Int("thrift-buffer", 0, "0 for unbuffered") |
44 | 44 | thriftFramed = fs.Bool("thrift-framed", false, "true to enable framing") |
45 | zipkinV2URL = fs.String("zipkin-url", "", "Enable Zipkin v2 tracing (zipkin-go) via HTTP Reporter URL e.g. http://localhost:94111/api/v2/spans") | |
45 | zipkinV2URL = fs.String("zipkin-url", "", "Enable Zipkin v2 tracing (zipkin-go) via HTTP Reporter URL e.g. http://localhost:9411/api/v2/spans") | |
46 | 46 | zipkinV1URL = fs.String("zipkin-v1-url", "", "Enable Zipkin v1 tracing (zipkin-go-opentracing) via a collector URL e.g. http://localhost:9411/api/v1/spans") |
47 | 47 | lightstepToken = fs.String("lightstep-token", "", "Enable LightStep tracing via a LightStep access token") |
48 | 48 | appdashAddr = fs.String("appdash-addr", "", "Enable Appdash tracing via an Appdash server host:port") |
204 | 204 | fmt.Fprintf(os.Stdout, "%q + %q = %q\n", a, b, v) |
205 | 205 | |
206 | 206 | default: |
207 | fmt.Fprintf(os.Stderr, "error: invalid method %q\n", method) | |
207 | fmt.Fprintf(os.Stderr, "error: invalid method %q\n", *method) | |
208 | 208 | os.Exit(1) |
209 | 209 | } |
210 | 210 | } |
13 | 13 | lightstep "github.com/lightstep/lightstep-tracer-go" |
14 | 14 | "github.com/oklog/oklog/pkg/group" |
15 | 15 | stdopentracing "github.com/opentracing/opentracing-go" |
16 | zipkinot "github.com/openzipkin-contrib/zipkin-go-opentracing" | |
16 | 17 | zipkin "github.com/openzipkin/zipkin-go" |
17 | zipkinot "github.com/openzipkin/zipkin-go-opentracing" | |
18 | 18 | zipkinhttp "github.com/openzipkin/zipkin-go/reporter/http" |
19 | 19 | stdprometheus "github.com/prometheus/client_golang/prometheus" |
20 | 20 | "github.com/prometheus/client_golang/prometheus/promhttp" |
15 | 15 | } |
16 | 16 | |
17 | 17 | pact := dsl.Pact{ |
18 | Port: 6666, | |
19 | 18 | Consumer: "addsvc", |
20 | 19 | Provider: "stringsvc", |
21 | 20 | } |
24 | 23 | pact.AddInteraction(). |
25 | 24 | UponReceiving("stringsvc uppercase"). |
26 | 25 | WithRequest(dsl.Request{ |
27 | Headers: map[string]string{"Content-Type": "application/json; charset=utf-8"}, | |
26 | Headers: dsl.MapMatcher{"Content-Type": dsl.String("application/json; charset=utf-8")}, | |
28 | 27 | Method: "POST", |
29 | Path: "/uppercase", | |
28 | Path: dsl.String("/uppercase"), | |
30 | 29 | Body: `{"s":"foo"}`, |
31 | 30 | }). |
32 | 31 | WillRespondWith(dsl.Response{ |
33 | 32 | Status: 200, |
34 | Headers: map[string]string{"Content-Type": "application/json; charset=utf-8"}, | |
33 | Headers: dsl.MapMatcher{"Content-Type": dsl.String("application/json; charset=utf-8")}, | |
35 | 34 | Body: `{"v":"FOO"}`, |
36 | 35 | }) |
37 | 36 |
97 | 97 | } |
98 | 98 | } |
99 | 99 | |
100 | // Failer is an interface that should be implemented by response types. | |
101 | // Response encoders can check if responses are Failer, and if so if they've | |
102 | // failed, and if so encode them using a separate write path based on the error. | |
103 | type Failer interface { | |
104 | Failed() error | |
105 | } | |
100 | // compile time assertions for our response types implementing endpoint.Failer. | |
101 | var ( | |
102 | _ endpoint.Failer = SumResponse{} | |
103 | _ endpoint.Failer = ConcatResponse{} | |
104 | ) | |
106 | 105 | |
107 | 106 | // SumRequest collects the request parameters for the Sum method. |
108 | 107 | type SumRequest struct { |
115 | 114 | Err error `json:"-"` // should be intercepted by Failed/errorEncoder |
116 | 115 | } |
117 | 116 | |
118 | // Failed implements Failer. | |
117 | // Failed implements endpoint.Failer. | |
119 | 118 | func (r SumResponse) Failed() error { return r.Err } |
120 | 119 | |
121 | 120 | // ConcatRequest collects the request parameters for the Concat method. |
129 | 128 | Err error `json:"-"` |
130 | 129 | } |
131 | 130 | |
132 | // Failed implements Failer. | |
131 | // Failed implements endpoint.Failer. | |
133 | 132 | func (r ConcatResponse) Failed() error { return r.Err } |
106 | 106 | zipkinClient, |
107 | 107 | } |
108 | 108 | |
109 | // Each individual endpoint is an http/transport.Client (which implements | |
109 | // Each individual endpoint is an grpc/transport.Client (which implements | |
110 | 110 | // endpoint.Endpoint) that gets wrapped with various middlewares. If you |
111 | 111 | // made your own client library, you'd do this work there, so your server |
112 | 112 | // could rely on a consistent set of client behavior. |
234 | 234 | // encodeHTTPGenericResponse is a transport/http.EncodeResponseFunc that encodes |
235 | 235 | // the response as JSON to the response writer. Primarily useful in a server. |
236 | 236 | func encodeHTTPGenericResponse(ctx context.Context, w http.ResponseWriter, response interface{}) error { |
237 | if f, ok := response.(addendpoint.Failer); ok && f.Failed() != nil { | |
237 | if f, ok := response.(endpoint.Failer); ok && f.Failed() != nil { | |
238 | 238 | errorEncoder(ctx, f.Failed(), w) |
239 | 239 | return nil |
240 | 240 | } |
1 | 1 | |
2 | 2 | This example demonstrates how to use Go kit to implement a REST-y HTTP service. |
3 | 3 | It leverages the excellent [gorilla mux package](https://github.com/gorilla/mux) for routing. |
4 | ||
5 | Run the example with the optional port address for the service: | |
6 | ||
7 | ```bash | |
8 | $ go run ./cmd/profilesvc/main.go -http.addr :8080 | |
9 | ts=2018-05-01T16:13:12.849086255Z caller=main.go:47 transport=HTTP addr=:8080 | |
10 | ``` | |
11 | ||
12 | Create a Profile: | |
13 | ||
14 | ```bash | |
15 | $ curl -d '{"id":"1234","Name":"Go Kit"}' -H "Content-Type: application/json" -X POST http://localhost:8080/profiles/ | |
16 | {} | |
17 | ``` | |
18 | ||
19 | Get the profile you just created | |
20 | ||
21 | ```bash | |
22 | $ curl localhost:8080/profiles/1234 | |
23 | {"profile":{"id":"1234","name":"Go Kit"}} | |
24 | ``` |
67 | 67 | options := []httptransport.ClientOption{} |
68 | 68 | |
69 | 69 | // Note that the request encoders need to modify the request URL, changing |
70 | // the path and method. That's fine: we simply need to provide specific | |
71 | // encoders for each endpoint. | |
70 | // the path. That's fine: we simply need to provide specific encoders for | |
71 | // each endpoint. | |
72 | 72 | |
73 | 73 | return Endpoints{ |
74 | 74 | PostProfileEndpoint: httptransport.NewClient("POST", tgt, encodePostProfileRequest, decodePostProfileResponse, options...).Endpoint(), |
216 | 216 | |
217 | 217 | func encodePostProfileRequest(ctx context.Context, req *http.Request, request interface{}) error { |
218 | 218 | // r.Methods("POST").Path("/profiles/") |
219 | req.Method, req.URL.Path = "POST", "/profiles/" | |
219 | req.URL.Path = "/profiles/" | |
220 | 220 | return encodeRequest(ctx, req, request) |
221 | 221 | } |
222 | 222 | |
224 | 224 | // r.Methods("GET").Path("/profiles/{id}") |
225 | 225 | r := request.(getProfileRequest) |
226 | 226 | profileID := url.QueryEscape(r.ID) |
227 | req.Method, req.URL.Path = "GET", "/profiles/"+profileID | |
227 | req.URL.Path = "/profiles/" + profileID | |
228 | 228 | return encodeRequest(ctx, req, request) |
229 | 229 | } |
230 | 230 | |
232 | 232 | // r.Methods("PUT").Path("/profiles/{id}") |
233 | 233 | r := request.(putProfileRequest) |
234 | 234 | profileID := url.QueryEscape(r.ID) |
235 | req.Method, req.URL.Path = "PUT", "/profiles/"+profileID | |
235 | req.URL.Path = "/profiles/" + profileID | |
236 | 236 | return encodeRequest(ctx, req, request) |
237 | 237 | } |
238 | 238 | |
240 | 240 | // r.Methods("PATCH").Path("/profiles/{id}") |
241 | 241 | r := request.(patchProfileRequest) |
242 | 242 | profileID := url.QueryEscape(r.ID) |
243 | req.Method, req.URL.Path = "PATCH", "/profiles/"+profileID | |
243 | req.URL.Path = "/profiles/" + profileID | |
244 | 244 | return encodeRequest(ctx, req, request) |
245 | 245 | } |
246 | 246 | |
248 | 248 | // r.Methods("DELETE").Path("/profiles/{id}") |
249 | 249 | r := request.(deleteProfileRequest) |
250 | 250 | profileID := url.QueryEscape(r.ID) |
251 | req.Method, req.URL.Path = "DELETE", "/profiles/"+profileID | |
251 | req.URL.Path = "/profiles/" + profileID | |
252 | 252 | return encodeRequest(ctx, req, request) |
253 | 253 | } |
254 | 254 | |
256 | 256 | // r.Methods("GET").Path("/profiles/{id}/addresses/") |
257 | 257 | r := request.(getAddressesRequest) |
258 | 258 | profileID := url.QueryEscape(r.ProfileID) |
259 | req.Method, req.URL.Path = "GET", "/profiles/"+profileID+"/addresses/" | |
259 | req.URL.Path = "/profiles/" + profileID + "/addresses/" | |
260 | 260 | return encodeRequest(ctx, req, request) |
261 | 261 | } |
262 | 262 | |
265 | 265 | r := request.(getAddressRequest) |
266 | 266 | profileID := url.QueryEscape(r.ProfileID) |
267 | 267 | addressID := url.QueryEscape(r.AddressID) |
268 | req.Method, req.URL.Path = "GET", "/profiles/"+profileID+"/addresses/"+addressID | |
268 | req.URL.Path = "/profiles/" + profileID + "/addresses/" + addressID | |
269 | 269 | return encodeRequest(ctx, req, request) |
270 | 270 | } |
271 | 271 | |
273 | 273 | // r.Methods("POST").Path("/profiles/{id}/addresses/") |
274 | 274 | r := request.(postAddressRequest) |
275 | 275 | profileID := url.QueryEscape(r.ProfileID) |
276 | req.Method, req.URL.Path = "POST", "/profiles/"+profileID+"/addresses/" | |
276 | req.URL.Path = "/profiles/" + profileID + "/addresses/" | |
277 | 277 | return encodeRequest(ctx, req, request) |
278 | 278 | } |
279 | 279 | |
282 | 282 | r := request.(deleteAddressRequest) |
283 | 283 | profileID := url.QueryEscape(r.ProfileID) |
284 | 284 | addressID := url.QueryEscape(r.AddressID) |
285 | req.Method, req.URL.Path = "DELETE", "/profiles/"+profileID+"/addresses/"+addressID | |
285 | req.URL.Path = "/profiles/" + profileID + "/addresses/" + addressID | |
286 | 286 | return encodeRequest(ctx, req, request) |
287 | 287 | } |
288 | 288 |
13 | 13 | |
14 | 14 | // StringService provides operations on strings. |
15 | 15 | type StringService interface { |
16 | Uppercase(context.Context, string) (string, error) | |
17 | Count(context.Context, string) int | |
16 | Uppercase(string) (string, error) | |
17 | Count(string) int | |
18 | 18 | } |
19 | 19 | |
20 | 20 | // stringService is a concrete implementation of StringService |
21 | 21 | type stringService struct{} |
22 | 22 | |
23 | func (stringService) Uppercase(_ context.Context, s string) (string, error) { | |
23 | func (stringService) Uppercase(s string) (string, error) { | |
24 | 24 | if s == "" { |
25 | 25 | return "", ErrEmpty |
26 | 26 | } |
27 | 27 | return strings.ToUpper(s), nil |
28 | 28 | } |
29 | 29 | |
30 | func (stringService) Count(_ context.Context, s string) int { | |
30 | func (stringService) Count(s string) int { | |
31 | 31 | return len(s) |
32 | 32 | } |
33 | 33 | |
54 | 54 | |
55 | 55 | // Endpoints are a primary abstraction in go-kit. An endpoint represents a single RPC (method in our service interface) |
56 | 56 | func makeUppercaseEndpoint(svc StringService) endpoint.Endpoint { |
57 | return func(ctx context.Context, request interface{}) (interface{}, error) { | |
57 | return func(_ context.Context, request interface{}) (interface{}, error) { | |
58 | 58 | req := request.(uppercaseRequest) |
59 | v, err := svc.Uppercase(ctx, req.S) | |
59 | v, err := svc.Uppercase(req.S) | |
60 | 60 | if err != nil { |
61 | 61 | return uppercaseResponse{v, err.Error()}, nil |
62 | 62 | } |
65 | 65 | } |
66 | 66 | |
67 | 67 | func makeCountEndpoint(svc StringService) endpoint.Endpoint { |
68 | return func(ctx context.Context, request interface{}) (interface{}, error) { | |
68 | return func(_ context.Context, request interface{}) (interface{}, error) { | |
69 | 69 | req := request.(countRequest) |
70 | v := svc.Count(ctx, req.S) | |
70 | v := svc.Count(req.S) | |
71 | 71 | return countResponse{v}, nil |
72 | 72 | } |
73 | 73 | } |
8 | 8 | ) |
9 | 9 | |
10 | 10 | func makeUppercaseEndpoint(svc StringService) endpoint.Endpoint { |
11 | return func(ctx context.Context, request interface{}) (interface{}, error) { | |
11 | return func(_ context.Context, request interface{}) (interface{}, error) { | |
12 | 12 | req := request.(uppercaseRequest) |
13 | 13 | v, err := svc.Uppercase(req.S) |
14 | 14 | if err != nil { |
19 | 19 | } |
20 | 20 | |
21 | 21 | func makeCountEndpoint(svc StringService) endpoint.Endpoint { |
22 | return func(ctx context.Context, request interface{}) (interface{}, error) { | |
22 | return func(_ context.Context, request interface{}) (interface{}, error) { | |
23 | 23 | req := request.(countRequest) |
24 | 24 | v := svc.Count(req.S) |
25 | 25 | return countResponse{v}, nil |
0 | package main | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "encoding/json" | |
5 | "errors" | |
6 | "log" | |
7 | "strings" | |
8 | "flag" | |
9 | "net/http" | |
10 | ||
11 | "github.com/go-kit/kit/endpoint" | |
12 | natstransport "github.com/go-kit/kit/transport/nats" | |
13 | httptransport "github.com/go-kit/kit/transport/http" | |
14 | ||
15 | "github.com/nats-io/go-nats" | |
16 | ) | |
17 | ||
18 | // StringService provides operations on strings. | |
19 | type StringService interface { | |
20 | Uppercase(context.Context, string) (string, error) | |
21 | Count(context.Context, string) int | |
22 | } | |
23 | ||
24 | // stringService is a concrete implementation of StringService | |
25 | type stringService struct{} | |
26 | ||
27 | func (stringService) Uppercase(_ context.Context, s string) (string, error) { | |
28 | if s == "" { | |
29 | return "", ErrEmpty | |
30 | } | |
31 | return strings.ToUpper(s), nil | |
32 | } | |
33 | ||
34 | func (stringService) Count(_ context.Context, s string) int { | |
35 | return len(s) | |
36 | } | |
37 | ||
38 | // ErrEmpty is returned when an input string is empty. | |
39 | var ErrEmpty = errors.New("empty string") | |
40 | ||
41 | // For each method, we define request and response structs | |
42 | type uppercaseRequest struct { | |
43 | S string `json:"s"` | |
44 | } | |
45 | ||
46 | type uppercaseResponse struct { | |
47 | V string `json:"v"` | |
48 | Err string `json:"err,omitempty"` // errors don't define JSON marshaling | |
49 | } | |
50 | ||
51 | type countRequest struct { | |
52 | S string `json:"s"` | |
53 | } | |
54 | ||
55 | type countResponse struct { | |
56 | V int `json:"v"` | |
57 | } | |
58 | ||
59 | // Endpoints are a primary abstraction in go-kit. An endpoint represents a single RPC (method in our service interface) | |
60 | func makeUppercaseHTTPEndpoint(nc *nats.Conn) endpoint.Endpoint { | |
61 | return natstransport.NewPublisher( | |
62 | nc, | |
63 | "stringsvc.uppercase", | |
64 | natstransport.EncodeJSONRequest, | |
65 | decodeUppercaseResponse, | |
66 | ).Endpoint() | |
67 | } | |
68 | ||
69 | func makeCountHTTPEndpoint(nc *nats.Conn) endpoint.Endpoint { | |
70 | return natstransport.NewPublisher( | |
71 | nc, | |
72 | "stringsvc.count", | |
73 | natstransport.EncodeJSONRequest, | |
74 | decodeCountResponse, | |
75 | ).Endpoint() | |
76 | } | |
77 | ||
78 | func makeUppercaseEndpoint(svc StringService) endpoint.Endpoint { | |
79 | return func(ctx context.Context, request interface{}) (interface{}, error) { | |
80 | req := request.(uppercaseRequest) | |
81 | v, err := svc.Uppercase(ctx, req.S) | |
82 | if err != nil { | |
83 | return uppercaseResponse{v, err.Error()}, nil | |
84 | } | |
85 | return uppercaseResponse{v, ""}, nil | |
86 | } | |
87 | } | |
88 | ||
89 | func makeCountEndpoint(svc StringService) endpoint.Endpoint { | |
90 | return func(ctx context.Context, request interface{}) (interface{}, error) { | |
91 | req := request.(countRequest) | |
92 | v := svc.Count(ctx, req.S) | |
93 | return countResponse{v}, nil | |
94 | } | |
95 | } | |
96 | ||
97 | // Transports expose the service to the network. In this fourth example we utilize JSON over NATS and HTTP. | |
98 | func main() { | |
99 | svc := stringService{} | |
100 | ||
101 | natsURL := flag.String("nats-url", nats.DefaultURL, "URL for connection to NATS") | |
102 | flag.Parse() | |
103 | ||
104 | nc, err := nats.Connect(*natsURL) | |
105 | if err != nil { | |
106 | log.Fatal(err) | |
107 | } | |
108 | defer nc.Close() | |
109 | ||
110 | uppercaseHTTPHandler := httptransport.NewServer( | |
111 | makeUppercaseHTTPEndpoint(nc), | |
112 | decodeUppercaseHTTPRequest, | |
113 | httptransport.EncodeJSONResponse, | |
114 | ) | |
115 | ||
116 | countHTTPHandler := httptransport.NewServer( | |
117 | makeCountHTTPEndpoint(nc), | |
118 | decodeCountHTTPRequest, | |
119 | httptransport.EncodeJSONResponse, | |
120 | ) | |
121 | ||
122 | uppercaseHandler := natstransport.NewSubscriber( | |
123 | makeUppercaseEndpoint(svc), | |
124 | decodeUppercaseRequest, | |
125 | natstransport.EncodeJSONResponse, | |
126 | ) | |
127 | ||
128 | countHandler := natstransport.NewSubscriber( | |
129 | makeCountEndpoint(svc), | |
130 | decodeCountRequest, | |
131 | natstransport.EncodeJSONResponse, | |
132 | ) | |
133 | ||
134 | uSub, err := nc.QueueSubscribe("stringsvc.uppercase", "stringsvc", uppercaseHandler.ServeMsg(nc)) | |
135 | if err != nil { | |
136 | log.Fatal(err) | |
137 | } | |
138 | defer uSub.Unsubscribe() | |
139 | ||
140 | cSub, err := nc.QueueSubscribe("stringsvc.count", "stringsvc", countHandler.ServeMsg(nc)) | |
141 | if err != nil { | |
142 | log.Fatal(err) | |
143 | } | |
144 | defer cSub.Unsubscribe() | |
145 | ||
146 | http.Handle("/uppercase", uppercaseHTTPHandler) | |
147 | http.Handle("/count", countHTTPHandler) | |
148 | log.Fatal(http.ListenAndServe(":8080", nil)) | |
149 | ||
150 | } | |
151 | ||
152 | func decodeUppercaseHTTPRequest(_ context.Context, r *http.Request) (interface{}, error) { | |
153 | var request uppercaseRequest | |
154 | if err := json.NewDecoder(r.Body).Decode(&request); err != nil { | |
155 | return nil, err | |
156 | } | |
157 | return request, nil | |
158 | } | |
159 | ||
160 | func decodeCountHTTPRequest(_ context.Context, r *http.Request) (interface{}, error) { | |
161 | var request countRequest | |
162 | if err := json.NewDecoder(r.Body).Decode(&request); err != nil { | |
163 | return nil, err | |
164 | } | |
165 | return request, nil | |
166 | } | |
167 | ||
168 | func decodeUppercaseResponse(_ context.Context, msg *nats.Msg) (interface{}, error) { | |
169 | var response uppercaseResponse | |
170 | ||
171 | if err := json.Unmarshal(msg.Data, &response); err != nil { | |
172 | return nil, err | |
173 | } | |
174 | ||
175 | return response, nil | |
176 | } | |
177 | ||
178 | func decodeCountResponse(_ context.Context, msg *nats.Msg) (interface{}, error) { | |
179 | var response countResponse | |
180 | ||
181 | if err := json.Unmarshal(msg.Data, &response); err != nil { | |
182 | return nil, err | |
183 | } | |
184 | ||
185 | return response, nil | |
186 | } | |
187 | ||
188 | func decodeUppercaseRequest(_ context.Context, msg *nats.Msg) (interface{}, error) { | |
189 | var request uppercaseRequest | |
190 | ||
191 | if err := json.Unmarshal(msg.Data, &request); err != nil { | |
192 | return nil, err | |
193 | } | |
194 | return request, nil | |
195 | } | |
196 | ||
197 | func decodeCountRequest(_ context.Context, msg *nats.Msg) (interface{}, error) { | |
198 | var request countRequest | |
199 | ||
200 | if err := json.Unmarshal(msg.Data, &request); err != nil { | |
201 | return nil, err | |
202 | } | |
203 | return request, nil | |
204 | } | |
205 |
0 | // Package level implements leveled logging on top of package log. To use the | |
1 | // level package, create a logger as per normal in your func main, and wrap it | |
2 | // with level.NewFilter. | |
0 | // Package level implements leveled logging on top of Go kit's log package. To | |
1 | // use the level package, create a logger as per normal in your func main, and | |
2 | // wrap it with level.NewFilter. | |
3 | 3 | // |
4 | 4 | // var logger log.Logger |
5 | 5 | // logger = log.NewLogfmtLogger(os.Stderr) |
175 | 175 | func DebugValue() Value { return debugValue } |
176 | 176 | |
177 | 177 | var ( |
178 | // key is of type interfae{} so that it allocates once during package | |
178 | // key is of type interface{} so that it allocates once during package | |
179 | 179 | // initialization and avoids allocating every time the value is added to a |
180 | 180 | // []interface{} later. |
181 | 181 | key interface{} = "level" |
0 | // Package logrus provides an adapter to the | |
1 | // go-kit log.Logger interface. | |
2 | package logrus | |
3 | ||
4 | import ( | |
5 | "errors" | |
6 | "fmt" | |
7 | ||
8 | "github.com/go-kit/kit/log" | |
9 | "github.com/sirupsen/logrus" | |
10 | ) | |
11 | ||
12 | type logrusLogger struct { | |
13 | *logrus.Logger | |
14 | } | |
15 | ||
16 | var errMissingValue = errors.New("(MISSING)") | |
17 | ||
18 | // NewLogrusLogger returns a go-kit log.Logger that sends log events to a Logrus logger. | |
19 | func NewLogrusLogger(logger *logrus.Logger) log.Logger { | |
20 | return &logrusLogger{logger} | |
21 | } | |
22 | ||
23 | func (l logrusLogger) Log(keyvals ...interface{}) error { | |
24 | fields := logrus.Fields{} | |
25 | for i := 0; i < len(keyvals); i += 2 { | |
26 | if i+1 < len(keyvals) { | |
27 | fields[fmt.Sprint(keyvals[i])] = keyvals[i+1] | |
28 | } else { | |
29 | fields[fmt.Sprint(keyvals[i])] = errMissingValue | |
30 | } | |
31 | } | |
32 | l.WithFields(fields).Info() | |
33 | return nil | |
34 | } |
0 | package logrus_test | |
1 | ||
2 | import ( | |
3 | "bytes" | |
4 | "errors" | |
5 | "strings" | |
6 | "testing" | |
7 | ||
8 | log "github.com/go-kit/kit/log/logrus" | |
9 | "github.com/sirupsen/logrus" | |
10 | ) | |
11 | ||
12 | func TestLogrusLogger(t *testing.T) { | |
13 | t.Parallel() | |
14 | buf := &bytes.Buffer{} | |
15 | logrusLogger := logrus.New() | |
16 | logrusLogger.Out = buf | |
17 | logrusLogger.Formatter = &logrus.TextFormatter{TimestampFormat: "02-01-2006 15:04:05", FullTimestamp: true} | |
18 | logger := log.NewLogrusLogger(logrusLogger) | |
19 | ||
20 | if err := logger.Log("hello", "world"); err != nil { | |
21 | t.Fatal(err) | |
22 | } | |
23 | if want, have := "hello=world\n", strings.Split(buf.String(), " ")[3]; want != have { | |
24 | t.Errorf("want %#v, have %#v", want, have) | |
25 | } | |
26 | ||
27 | buf.Reset() | |
28 | if err := logger.Log("a", 1, "err", errors.New("error")); err != nil { | |
29 | t.Fatal(err) | |
30 | } | |
31 | if want, have := "a=1 err=error", strings.TrimSpace(strings.SplitAfterN(buf.String(), " ", 4)[3]); want != have { | |
32 | t.Errorf("want %#v, have %#v", want, have) | |
33 | } | |
34 | ||
35 | buf.Reset() | |
36 | if err := logger.Log("a", 1, "b"); err != nil { | |
37 | t.Fatal(err) | |
38 | } | |
39 | if want, have := "a=1 b=\"(MISSING)\"", strings.TrimSpace(strings.SplitAfterN(buf.String(), " ", 4)[3]); want != have { | |
40 | t.Errorf("want %#v, have %#v", want, have) | |
41 | } | |
42 | ||
43 | buf.Reset() | |
44 | if err := logger.Log("my_map", mymap{0: 0}); err != nil { | |
45 | t.Fatal(err) | |
46 | } | |
47 | if want, have := "my_map=special_behavior", strings.TrimSpace(strings.Split(buf.String(), " ")[3]); want != have { | |
48 | t.Errorf("want %#v, have %#v", want, have) | |
49 | } | |
50 | } | |
51 | ||
52 | type mymap map[int]int | |
53 | ||
54 | func (m mymap) String() string { return "special_behavior" } |
0 | 0 | // The code in this file is adapted from github.com/mattn/go-colorable. |
1 | ||
2 | // +build windows | |
3 | 1 | |
4 | 2 | package term |
5 | 3 |
0 | 0 | package log |
1 | 1 | |
2 | 2 | import ( |
3 | "runtime" | |
4 | "strconv" | |
5 | "strings" | |
3 | 6 | "time" |
4 | ||
5 | "github.com/go-stack/stack" | |
6 | 7 | ) |
7 | 8 | |
8 | 9 | // A Valuer generates a log value. When passed to With or WithPrefix in a |
80 | 81 | // Caller returns a Valuer that returns a file and line from a specified depth |
81 | 82 | // in the callstack. Users will probably want to use DefaultCaller. |
82 | 83 | func Caller(depth int) Valuer { |
83 | return func() interface{} { return stack.Caller(depth) } | |
84 | return func() interface{} { | |
85 | _, file, line, _ := runtime.Caller(depth) | |
86 | idx := strings.LastIndexByte(file, '/') | |
87 | // using idx+1 below handles both of following cases: | |
88 | // idx == -1 because no "/" was found, or | |
89 | // idx >= 0 and we want to start at the character after the found "/". | |
90 | return file[idx+1:] + ":" + strconv.Itoa(line) | |
91 | } | |
84 | 92 | } |
85 | 93 | |
86 | 94 | var ( |
0 | // Package influxstatsd provides support for InfluxData's StatsD Telegraf plugin. It's very | |
1 | // similar to StatsD, but supports arbitrary tags per-metric, which map to Go | |
2 | // kit's label values. So, while label values are no-ops in StatsD, they are | |
3 | // supported here. For more details, see the article at | |
4 | // https://www.influxdata.com/blog/getting-started-with-sending-statsd-metrics-to-telegraf-influxdb/ | |
5 | // | |
6 | // This package batches observations and emits them on some schedule to the | |
7 | // remote server. This is useful even if you connect to your service | |
8 | // over UDP. Emitting one network packet per observation can quickly overwhelm | |
9 | // even the fastest internal network. | |
10 | package influxstatsd | |
11 | ||
12 | import ( | |
13 | "fmt" | |
14 | "io" | |
15 | "strings" | |
16 | "sync" | |
17 | "sync/atomic" | |
18 | "time" | |
19 | ||
20 | "github.com/go-kit/kit/log" | |
21 | "github.com/go-kit/kit/metrics" | |
22 | "github.com/go-kit/kit/metrics/generic" | |
23 | "github.com/go-kit/kit/metrics/internal/lv" | |
24 | "github.com/go-kit/kit/metrics/internal/ratemap" | |
25 | "github.com/go-kit/kit/util/conn" | |
26 | ) | |
27 | ||
28 | // Influxstatsd receives metrics observations and forwards them to a server. | |
29 | // Create a Influxstatsd object, use it to create metrics, and pass those | |
30 | // metrics as dependencies to the components that will use them. | |
31 | // | |
32 | // All metrics are buffered until WriteTo is called. Counters and gauges are | |
33 | // aggregated into a single observation per timeseries per write. Timings and | |
34 | // histograms are buffered but not aggregated. | |
35 | // | |
36 | // To regularly report metrics to an io.Writer, use the WriteLoop helper method. | |
37 | // To send to a InfluxStatsD server, use the SendLoop helper method. | |
38 | type Influxstatsd struct { | |
39 | mtx sync.RWMutex | |
40 | prefix string | |
41 | rates *ratemap.RateMap | |
42 | counters *lv.Space | |
43 | gauges map[string]*gaugeNode | |
44 | timings *lv.Space | |
45 | histograms *lv.Space | |
46 | logger log.Logger | |
47 | lvs lv.LabelValues | |
48 | } | |
49 | ||
50 | // New returns a Influxstatsd object that may be used to create metrics. Prefix is | |
51 | // applied to all created metrics. Callers must ensure that regular calls to | |
52 | // WriteTo are performed, either manually or with one of the helper methods. | |
53 | func New(prefix string, logger log.Logger, lvs ...string) *Influxstatsd { | |
54 | if len(lvs)%2 != 0 { | |
55 | panic("odd number of LabelValues; programmer error!") | |
56 | } | |
57 | return &Influxstatsd{ | |
58 | prefix: prefix, | |
59 | rates: ratemap.New(), | |
60 | counters: lv.NewSpace(), | |
61 | gauges: map[string]*gaugeNode{}, // https://github.com/go-kit/kit/pull/588 | |
62 | timings: lv.NewSpace(), | |
63 | histograms: lv.NewSpace(), | |
64 | logger: logger, | |
65 | lvs: lvs, | |
66 | } | |
67 | } | |
68 | ||
69 | // NewCounter returns a counter, sending observations to this Influxstatsd object. | |
70 | func (d *Influxstatsd) NewCounter(name string, sampleRate float64) *Counter { | |
71 | d.rates.Set(name, sampleRate) | |
72 | return &Counter{ | |
73 | name: name, | |
74 | obs: d.counters.Observe, | |
75 | } | |
76 | } | |
77 | ||
78 | // NewGauge returns a gauge, sending observations to this Influxstatsd object. | |
79 | func (d *Influxstatsd) NewGauge(name string) *Gauge { | |
80 | d.mtx.Lock() | |
81 | n, ok := d.gauges[name] | |
82 | if !ok { | |
83 | n = &gaugeNode{gauge: &Gauge{g: generic.NewGauge(name), influx: d}} | |
84 | d.gauges[name] = n | |
85 | } | |
86 | d.mtx.Unlock() | |
87 | return n.gauge | |
88 | } | |
89 | ||
90 | // NewTiming returns a histogram whose observations are interpreted as | |
91 | // millisecond durations, and are forwarded to this Influxstatsd object. | |
92 | func (d *Influxstatsd) NewTiming(name string, sampleRate float64) *Timing { | |
93 | d.rates.Set(name, sampleRate) | |
94 | return &Timing{ | |
95 | name: name, | |
96 | obs: d.timings.Observe, | |
97 | } | |
98 | } | |
99 | ||
100 | // NewHistogram returns a histogram whose observations are of an unspecified | |
101 | // unit, and are forwarded to this Influxstatsd object. | |
102 | func (d *Influxstatsd) NewHistogram(name string, sampleRate float64) *Histogram { | |
103 | d.rates.Set(name, sampleRate) | |
104 | return &Histogram{ | |
105 | name: name, | |
106 | obs: d.histograms.Observe, | |
107 | } | |
108 | } | |
109 | ||
110 | // WriteLoop is a helper method that invokes WriteTo to the passed writer every | |
111 | // time the passed channel fires. This method blocks until the channel is | |
112 | // closed, so clients probably want to run it in its own goroutine. For typical | |
113 | // usage, create a time.Ticker and pass its C channel to this method. | |
114 | func (d *Influxstatsd) WriteLoop(c <-chan time.Time, w io.Writer) { | |
115 | for range c { | |
116 | if _, err := d.WriteTo(w); err != nil { | |
117 | d.logger.Log("during", "WriteTo", "err", err) | |
118 | } | |
119 | } | |
120 | } | |
121 | ||
122 | // SendLoop is a helper method that wraps WriteLoop, passing a managed | |
123 | // connection to the network and address. Like WriteLoop, this method blocks | |
124 | // until the channel is closed, so clients probably want to start it in its own | |
125 | // goroutine. For typical usage, create a time.Ticker and pass its C channel to | |
126 | // this method. | |
127 | func (d *Influxstatsd) SendLoop(c <-chan time.Time, network, address string) { | |
128 | d.WriteLoop(c, conn.NewDefaultManager(network, address, d.logger)) | |
129 | } | |
130 | ||
131 | // WriteTo flushes the buffered content of the metrics to the writer, in | |
132 | // InfluxStatsD format. WriteTo abides best-effort semantics, so observations are | |
133 | // lost if there is a problem with the write. Clients should be sure to call | |
134 | // WriteTo regularly, ideally through the WriteLoop or SendLoop helper methods. | |
135 | func (d *Influxstatsd) WriteTo(w io.Writer) (count int64, err error) { | |
136 | var n int | |
137 | ||
138 | d.counters.Reset().Walk(func(name string, lvs lv.LabelValues, values []float64) bool { | |
139 | n, err = fmt.Fprintf(w, "%s%s%s:%f|c%s\n", d.prefix, name, d.tagValues(lvs), sum(values), sampling(d.rates.Get(name))) | |
140 | if err != nil { | |
141 | return false | |
142 | } | |
143 | count += int64(n) | |
144 | return true | |
145 | }) | |
146 | if err != nil { | |
147 | return count, err | |
148 | } | |
149 | ||
150 | d.mtx.RLock() | |
151 | for _, root := range d.gauges { | |
152 | root.walk(func(name string, lvs lv.LabelValues, value float64) bool { | |
153 | n, err = fmt.Fprintf(w, "%s%s%s:%f|g\n", d.prefix, name, d.tagValues(lvs), value) | |
154 | if err != nil { | |
155 | return false | |
156 | } | |
157 | count += int64(n) | |
158 | return true | |
159 | }) | |
160 | } | |
161 | d.mtx.RUnlock() | |
162 | ||
163 | d.timings.Reset().Walk(func(name string, lvs lv.LabelValues, values []float64) bool { | |
164 | sampleRate := d.rates.Get(name) | |
165 | for _, value := range values { | |
166 | n, err = fmt.Fprintf(w, "%s%s%s:%f|ms%s\n", d.prefix, name, d.tagValues(lvs), value, sampling(sampleRate)) | |
167 | if err != nil { | |
168 | return false | |
169 | } | |
170 | count += int64(n) | |
171 | } | |
172 | return true | |
173 | }) | |
174 | if err != nil { | |
175 | return count, err | |
176 | } | |
177 | ||
178 | d.histograms.Reset().Walk(func(name string, lvs lv.LabelValues, values []float64) bool { | |
179 | sampleRate := d.rates.Get(name) | |
180 | for _, value := range values { | |
181 | n, err = fmt.Fprintf(w, "%s%s%s:%f|h%s\n", d.prefix, name, d.tagValues(lvs), value, sampling(sampleRate)) | |
182 | if err != nil { | |
183 | return false | |
184 | } | |
185 | count += int64(n) | |
186 | } | |
187 | return true | |
188 | }) | |
189 | if err != nil { | |
190 | return count, err | |
191 | } | |
192 | ||
193 | return count, err | |
194 | } | |
195 | ||
196 | func sum(a []float64) float64 { | |
197 | var v float64 | |
198 | for _, f := range a { | |
199 | v += f | |
200 | } | |
201 | return v | |
202 | } | |
203 | ||
204 | func last(a []float64) float64 { | |
205 | return a[len(a)-1] | |
206 | } | |
207 | ||
208 | func sampling(r float64) string { | |
209 | var sv string | |
210 | if r < 1.0 { | |
211 | sv = fmt.Sprintf("|@%f", r) | |
212 | } | |
213 | return sv | |
214 | } | |
215 | ||
216 | func (d *Influxstatsd) tagValues(labelValues []string) string { | |
217 | if len(labelValues) == 0 && len(d.lvs) == 0 { | |
218 | return "" | |
219 | } | |
220 | if len(labelValues)%2 != 0 { | |
221 | panic("tagValues received a labelValues with an odd number of strings") | |
222 | } | |
223 | pairs := make([]string, 0, (len(d.lvs)+len(labelValues))/2) | |
224 | for i := 0; i < len(d.lvs); i += 2 { | |
225 | pairs = append(pairs, d.lvs[i]+"="+d.lvs[i+1]) | |
226 | } | |
227 | for i := 0; i < len(labelValues); i += 2 { | |
228 | pairs = append(pairs, labelValues[i]+"="+labelValues[i+1]) | |
229 | } | |
230 | return "," + strings.Join(pairs, ",") | |
231 | } | |
232 | ||
233 | type observeFunc func(name string, lvs lv.LabelValues, value float64) | |
234 | ||
235 | // Counter is a InfluxStatsD counter. Observations are forwarded to a Influxstatsd | |
236 | // object, and aggregated (summed) per timeseries. | |
237 | type Counter struct { | |
238 | name string | |
239 | lvs lv.LabelValues | |
240 | obs observeFunc | |
241 | } | |
242 | ||
243 | // With implements metrics.Counter. | |
244 | func (c *Counter) With(labelValues ...string) metrics.Counter { | |
245 | return &Counter{ | |
246 | name: c.name, | |
247 | lvs: c.lvs.With(labelValues...), | |
248 | obs: c.obs, | |
249 | } | |
250 | } | |
251 | ||
252 | // Add implements metrics.Counter. | |
253 | func (c *Counter) Add(delta float64) { | |
254 | c.obs(c.name, c.lvs, delta) | |
255 | } | |
256 | ||
257 | // Gauge is a InfluxStatsD gauge. Observations are forwarded to a Influxstatsd | |
258 | // object, and aggregated (the last observation selected) per timeseries. | |
259 | type Gauge struct { | |
260 | g *generic.Gauge | |
261 | influx *Influxstatsd | |
262 | set int32 | |
263 | } | |
264 | ||
265 | // With implements metrics.Gauge. | |
266 | func (g *Gauge) With(labelValues ...string) metrics.Gauge { | |
267 | g.influx.mtx.RLock() | |
268 | node := g.influx.gauges[g.g.Name] | |
269 | g.influx.mtx.RUnlock() | |
270 | ||
271 | ga := &Gauge{g: g.g.With(labelValues...).(*generic.Gauge), influx: g.influx} | |
272 | return node.addGauge(ga, ga.g.LabelValues()) | |
273 | } | |
274 | ||
275 | // Set implements metrics.Gauge. | |
276 | func (g *Gauge) Set(value float64) { | |
277 | g.g.Set(value) | |
278 | g.touch() | |
279 | } | |
280 | ||
281 | // Add implements metrics.Gauge. | |
282 | func (g *Gauge) Add(delta float64) { | |
283 | g.g.Add(delta) | |
284 | g.touch() | |
285 | } | |
286 | ||
287 | // Timing is a InfluxStatsD timing, or metrics.Histogram. Observations are | |
288 | // forwarded to a Influxstatsd object, and collected (but not aggregated) per | |
289 | // timeseries. | |
290 | type Timing struct { | |
291 | name string | |
292 | lvs lv.LabelValues | |
293 | obs observeFunc | |
294 | } | |
295 | ||
296 | // With implements metrics.Timing. | |
297 | func (t *Timing) With(labelValues ...string) metrics.Histogram { | |
298 | return &Timing{ | |
299 | name: t.name, | |
300 | lvs: t.lvs.With(labelValues...), | |
301 | obs: t.obs, | |
302 | } | |
303 | } | |
304 | ||
305 | // Observe implements metrics.Histogram. Value is interpreted as milliseconds. | |
306 | func (t *Timing) Observe(value float64) { | |
307 | t.obs(t.name, t.lvs, value) | |
308 | } | |
309 | ||
310 | // Histogram is a InfluxStatsD histrogram. Observations are forwarded to a | |
311 | // Influxstatsd object, and collected (but not aggregated) per timeseries. | |
312 | type Histogram struct { | |
313 | name string | |
314 | lvs lv.LabelValues | |
315 | obs observeFunc | |
316 | } | |
317 | ||
318 | // With implements metrics.Histogram. | |
319 | func (h *Histogram) With(labelValues ...string) metrics.Histogram { | |
320 | return &Histogram{ | |
321 | name: h.name, | |
322 | lvs: h.lvs.With(labelValues...), | |
323 | obs: h.obs, | |
324 | } | |
325 | } | |
326 | ||
327 | // Observe implements metrics.Histogram. | |
328 | func (h *Histogram) Observe(value float64) { | |
329 | h.obs(h.name, h.lvs, value) | |
330 | } | |
331 | ||
332 | type pair struct{ label, value string } | |
333 | ||
334 | type gaugeNode struct { | |
335 | mtx sync.RWMutex | |
336 | gauge *Gauge | |
337 | children map[pair]*gaugeNode | |
338 | } | |
339 | ||
340 | func (n *gaugeNode) addGauge(g *Gauge, lvs lv.LabelValues) *Gauge { | |
341 | n.mtx.Lock() | |
342 | defer n.mtx.Unlock() | |
343 | if len(lvs) == 0 { | |
344 | if n.gauge == nil { | |
345 | n.gauge = g | |
346 | } | |
347 | return n.gauge | |
348 | } | |
349 | if len(lvs) < 2 { | |
350 | panic("too few LabelValues; programmer error!") | |
351 | } | |
352 | head, tail := pair{lvs[0], lvs[1]}, lvs[2:] | |
353 | if n.children == nil { | |
354 | n.children = map[pair]*gaugeNode{} | |
355 | } | |
356 | child, ok := n.children[head] | |
357 | if !ok { | |
358 | child = &gaugeNode{} | |
359 | n.children[head] = child | |
360 | } | |
361 | return child.addGauge(g, tail) | |
362 | } | |
363 | ||
364 | func (n *gaugeNode) walk(fn func(string, lv.LabelValues, float64) bool) bool { | |
365 | n.mtx.RLock() | |
366 | defer n.mtx.RUnlock() | |
367 | if n.gauge != nil { | |
368 | value, ok := n.gauge.read() | |
369 | if ok && !fn(n.gauge.g.Name, n.gauge.g.LabelValues(), value) { | |
370 | return false | |
371 | } | |
372 | } | |
373 | for _, child := range n.children { | |
374 | if !child.walk(fn) { | |
375 | return false | |
376 | } | |
377 | } | |
378 | return true | |
379 | } | |
380 | ||
381 | func (g *Gauge) touch() { | |
382 | atomic.StoreInt32(&(g.set), 1) | |
383 | } | |
384 | ||
385 | func (g *Gauge) read() (float64, bool) { | |
386 | set := atomic.SwapInt32(&(g.set), 0) | |
387 | return g.g.Value(), set != 0 | |
388 | } |
0 | package influxstatsd | |
1 | ||
2 | import ( | |
3 | "testing" | |
4 | ||
5 | "github.com/go-kit/kit/log" | |
6 | "github.com/go-kit/kit/metrics/teststat" | |
7 | ) | |
8 | ||
9 | func TestCounter(t *testing.T) { | |
10 | prefix, name := "abc.", "def" | |
11 | label, value := "label", "value" | |
12 | regex := `^` + prefix + name + "," + label + `=` + value + `:([0-9\.]+)\|c$` | |
13 | d := New(prefix, log.NewNopLogger()) | |
14 | counter := d.NewCounter(name, 1.0).With(label, value) | |
15 | valuef := teststat.SumLines(d, regex) | |
16 | if err := teststat.TestCounter(counter, valuef); err != nil { | |
17 | t.Fatal(err) | |
18 | } | |
19 | } | |
20 | ||
21 | func TestCounterSampled(t *testing.T) { | |
22 | // This will involve multiplying the observed sum by the inverse of the | |
23 | // sample rate and checking against the expected value within some | |
24 | // tolerance. | |
25 | t.Skip("TODO") | |
26 | } | |
27 | ||
28 | func TestGauge(t *testing.T) { | |
29 | prefix, name := "ghi.", "jkl" | |
30 | label, value := "xyz", "abc" | |
31 | regex := `^` + prefix + name + `,hostname=foohost,` + label + `=` + value + `:([0-9\.]+)\|g$` | |
32 | d := New(prefix, log.NewNopLogger(), "hostname", "foohost") | |
33 | gauge := d.NewGauge(name).With(label, value) | |
34 | valuef := teststat.LastLine(d, regex) | |
35 | if err := teststat.TestGauge(gauge, valuef); err != nil { | |
36 | t.Fatal(err) | |
37 | } | |
38 | } | |
39 | ||
40 | // InfluxStatsD histograms just emit all observations. So, we collect them into | |
41 | // a generic histogram, and run the statistics test on that. | |
42 | ||
43 | func TestHistogram(t *testing.T) { | |
44 | prefix, name := "influxstatsd.", "histogram_test" | |
45 | label, value := "abc", "def" | |
46 | regex := `^` + prefix + name + "," + label + `=` + value + `:([0-9\.]+)\|h$` | |
47 | d := New(prefix, log.NewNopLogger()) | |
48 | histogram := d.NewHistogram(name, 1.0).With(label, value) | |
49 | quantiles := teststat.Quantiles(d, regex, 50) // no |@0.X | |
50 | if err := teststat.TestHistogram(histogram, quantiles, 0.01); err != nil { | |
51 | t.Fatal(err) | |
52 | } | |
53 | } | |
54 | ||
55 | func TestHistogramSampled(t *testing.T) { | |
56 | prefix, name := "influxstatsd.", "sampled_histogram_test" | |
57 | label, value := "foo", "bar" | |
58 | regex := `^` + prefix + name + "," + label + `=` + value + `:([0-9\.]+)\|h\|@0\.01[0]*$` | |
59 | d := New(prefix, log.NewNopLogger()) | |
60 | histogram := d.NewHistogram(name, 0.01).With(label, value) | |
61 | quantiles := teststat.Quantiles(d, regex, 50) | |
62 | if err := teststat.TestHistogram(histogram, quantiles, 0.02); err != nil { | |
63 | t.Fatal(err) | |
64 | } | |
65 | } | |
66 | ||
67 | func TestTiming(t *testing.T) { | |
68 | prefix, name := "influxstatsd.", "timing_test" | |
69 | label, value := "wiggle", "bottom" | |
70 | regex := `^` + prefix + name + "," + label + `=` + value + `:([0-9\.]+)\|ms$` | |
71 | d := New(prefix, log.NewNopLogger()) | |
72 | histogram := d.NewTiming(name, 1.0).With(label, value) | |
73 | quantiles := teststat.Quantiles(d, regex, 50) // no |@0.X | |
74 | if err := teststat.TestHistogram(histogram, quantiles, 0.01); err != nil { | |
75 | t.Fatal(err) | |
76 | } | |
77 | } | |
78 | ||
79 | func TestTimingSampled(t *testing.T) { | |
80 | prefix, name := "influxstatsd.", "sampled_timing_test" | |
81 | label, value := "internal", "external" | |
82 | regex := `^` + prefix + name + "," + label + `=` + value + `:([0-9\.]+)\|ms\|@0.03[0]*$` | |
83 | d := New(prefix, log.NewNopLogger()) | |
84 | histogram := d.NewTiming(name, 0.03).With(label, value) | |
85 | quantiles := teststat.Quantiles(d, regex, 50) | |
86 | if err := teststat.TestHistogram(histogram, quantiles, 0.02); err != nil { | |
87 | t.Fatal(err) | |
88 | } | |
89 | } |
16 | 16 | if want, have := name, low.Name; want != have { |
17 | 17 | t.Errorf("Name: want %q, have %q", want, have) |
18 | 18 | } |
19 | value := func() float64 { return low.Value() } | |
20 | if err := teststat.TestCounter(top, value); err != nil { | |
19 | if err := teststat.TestCounter(top, low.Value); err != nil { | |
21 | 20 | t.Fatal(err) |
22 | 21 | } |
23 | 22 | } |
32 | 31 | if want, have := name, low.Name; want != have { |
33 | 32 | t.Errorf("Name: want %q, have %q", want, have) |
34 | 33 | } |
35 | value := func() float64 { return low.Value() } | |
36 | if err := teststat.TestCounter(top, value); err != nil { | |
34 | if err := teststat.TestCounter(top, low.Value); err != nil { | |
37 | 35 | t.Fatal(err) |
38 | 36 | } |
39 | 37 | } |
39 | 39 | |
40 | 40 | have := map[string]float64{} |
41 | 41 | s.Walk(func(name string, lvs LabelValues, obs []float64) bool { |
42 | //t.Logf("%s %v => %v", name, lvs, obs) | |
43 | 42 | have[name+" ["+strings.Join(lvs, "")+"]"] += sum(obs) |
44 | 43 | return true |
45 | 44 | }) |
11 | 11 | "strings" |
12 | 12 | "testing" |
13 | 13 | |
14 | "github.com/go-kit/kit/metrics/teststat" | |
14 | 15 | stdprometheus "github.com/prometheus/client_golang/prometheus" |
15 | ||
16 | "github.com/go-kit/kit/metrics/teststat" | |
17 | 16 | ) |
18 | 17 | |
19 | 18 | func TestCounter(t *testing.T) { |
197 | 196 | if !ok { |
198 | 197 | t.Fatalf("expected error, got %s", reflect.TypeOf(x)) |
199 | 198 | } |
200 | if want, have := "inconsistent label cardinality", err.Error(); want != have { | |
199 | if want, have := "inconsistent label cardinality", err.Error(); !strings.HasPrefix(have, want) { | |
201 | 200 | t.Fatalf("want %q, have %q", want, have) |
202 | 201 | } |
203 | 202 | }() |
46 | 46 | re := regexp.MustCompile(regex) |
47 | 47 | buf := &bytes.Buffer{} |
48 | 48 | w.WriteTo(buf) |
49 | //fmt.Fprintf(os.Stderr, "%s\n", buf.String()) | |
50 | 49 | s := bufio.NewScanner(buf) |
51 | 50 | for s.Scan() { |
52 | 51 | match := re.FindStringSubmatch(s.Text()) |
63 | 63 | z = math.Sqrt(-math.Log((1.0 - y) / 2.0)) |
64 | 64 | x = (((c[3]*z+c[2])*z+c[1])*z + c[0]) / ((d[1]*z+d[0])*z + 1.0) |
65 | 65 | } |
66 | x = x - (math.Erf(x)-y)/(2.0/math.SqrtPi*math.Exp(-x*x)) | |
67 | x = x - (math.Erf(x)-y)/(2.0/math.SqrtPi*math.Exp(-x*x)) | |
66 | x -= (math.Erf(x) - y) / (2.0 / math.SqrtPi * math.Exp(-x*x)) | |
67 | x -= (math.Erf(x) - y) / (2.0 / math.SqrtPi * math.Exp(-x*x)) | |
68 | 68 | } |
69 | 69 | |
70 | 70 | return x |
0 | 0 | package consul |
1 | 1 | |
2 | 2 | import ( |
3 | "errors" | |
3 | 4 | "fmt" |
4 | "io" | |
5 | "time" | |
5 | 6 | |
6 | 7 | consul "github.com/hashicorp/consul/api" |
7 | 8 | |
8 | 9 | "github.com/go-kit/kit/log" |
9 | 10 | "github.com/go-kit/kit/sd" |
10 | 11 | "github.com/go-kit/kit/sd/internal/instance" |
12 | "github.com/go-kit/kit/util/conn" | |
11 | 13 | ) |
12 | 14 | |
13 | 15 | const defaultIndex = 0 |
16 | ||
17 | // errStopped notifies the loop to quit. aka stopped via quitc | |
18 | var errStopped = errors.New("quit and closed consul instancer") | |
14 | 19 | |
15 | 20 | // Instancer yields instances for a service in Consul. |
16 | 21 | type Instancer struct { |
58 | 63 | var ( |
59 | 64 | instances []string |
60 | 65 | err error |
66 | d time.Duration = 10 * time.Millisecond | |
61 | 67 | ) |
62 | 68 | for { |
63 | 69 | instances, lastIndex, err = s.getInstances(lastIndex, s.quitc) |
64 | 70 | switch { |
65 | case err == io.EOF: | |
71 | case err == errStopped: | |
66 | 72 | return // stopped via quitc |
67 | 73 | case err != nil: |
68 | 74 | s.logger.Log("err", err) |
75 | time.Sleep(d) | |
76 | d = conn.Exponential(d) | |
69 | 77 | s.cache.Update(sd.Event{Err: err}) |
70 | 78 | default: |
71 | 79 | s.cache.Update(sd.Event{Instances: instances}) |
80 | d = 10 * time.Millisecond | |
72 | 81 | } |
73 | 82 | } |
74 | 83 | } |
118 | 127 | case res := <-resc: |
119 | 128 | return res.instances, res.index, nil |
120 | 129 | case <-interruptc: |
121 | return nil, 0, io.EOF | |
130 | return nil, 0, errStopped | |
122 | 131 | } |
123 | 132 | } |
124 | 133 |
1 | 1 | |
2 | 2 | import ( |
3 | 3 | "context" |
4 | consul "github.com/hashicorp/consul/api" | |
5 | "io" | |
4 | 6 | "testing" |
5 | ||
6 | consul "github.com/hashicorp/consul/api" | |
7 | "time" | |
7 | 8 | |
8 | 9 | "github.com/go-kit/kit/log" |
9 | 10 | "github.com/go-kit/kit/sd" |
130 | 131 | t.Errorf("want %q, have %q", want, have) |
131 | 132 | } |
132 | 133 | } |
134 | ||
135 | type eofTestClient struct { | |
136 | client *testClient | |
137 | eofSig chan bool | |
138 | called chan struct{} | |
139 | } | |
140 | ||
141 | func neweofTestClient(client *testClient, sig chan bool, called chan struct{}) Client { | |
142 | return &eofTestClient{client: client, eofSig: sig, called: called} | |
143 | } | |
144 | ||
145 | func (c *eofTestClient) Register(r *consul.AgentServiceRegistration) error { | |
146 | return c.client.Register(r) | |
147 | } | |
148 | ||
149 | func (c *eofTestClient) Deregister(r *consul.AgentServiceRegistration) error { | |
150 | return c.client.Deregister(r) | |
151 | } | |
152 | ||
153 | func (c *eofTestClient) Service(service, tag string, passingOnly bool, queryOpts *consul.QueryOptions) ([]*consul.ServiceEntry, *consul.QueryMeta, error) { | |
154 | c.called <- struct{}{} | |
155 | shouldEOF := <-c.eofSig | |
156 | if shouldEOF { | |
157 | return nil, &consul.QueryMeta{}, io.EOF | |
158 | } | |
159 | return c.client.Service(service, tag, passingOnly, queryOpts) | |
160 | } | |
161 | ||
162 | func TestInstancerWithEOF(t *testing.T) { | |
163 | var ( | |
164 | sig = make(chan bool, 1) | |
165 | called = make(chan struct{}, 1) | |
166 | logger = log.NewNopLogger() | |
167 | client = neweofTestClient(newTestClient(consulState), sig, called) | |
168 | ) | |
169 | ||
170 | sig <- false | |
171 | s := NewInstancer(client, logger, "search", []string{"api"}, true) | |
172 | defer s.Stop() | |
173 | ||
174 | select { | |
175 | case <-called: | |
176 | case <-time.Tick(time.Millisecond * 500): | |
177 | t.Error("failed, to receive call") | |
178 | } | |
179 | ||
180 | state := s.cache.State() | |
181 | if want, have := 2, len(state.Instances); want != have { | |
182 | t.Errorf("want %d, have %d", want, have) | |
183 | } | |
184 | ||
185 | // some error occurred resulting in io.EOF | |
186 | sig <- true | |
187 | ||
188 | // Service Called Once | |
189 | select { | |
190 | case <-called: | |
191 | case <-time.Tick(time.Millisecond * 500): | |
192 | t.Error("failed, to receive call in time") | |
193 | } | |
194 | ||
195 | sig <- false | |
196 | ||
197 | // loop should continue | |
198 | select { | |
199 | case <-called: | |
200 | case <-time.Tick(time.Millisecond * 500): | |
201 | t.Error("failed, to receive call in time") | |
202 | } | |
203 | } |
16 | 16 | func TestIntegration(t *testing.T) { |
17 | 17 | consulAddr := os.Getenv("CONSUL_ADDR") |
18 | 18 | if consulAddr == "" { |
19 | t.Fatal("CONSUL_ADDR is not set") | |
19 | t.Skip("CONSUL_ADDR not set; skipping integration test") | |
20 | 20 | } |
21 | 21 | stdClient, err := stdconsul.NewClient(&stdconsul.Config{ |
22 | 22 | Address: consulAddr, |
9 | 9 | "net/http" |
10 | 10 | "time" |
11 | 11 | |
12 | etcd "github.com/coreos/etcd/client" | |
12 | etcd "go.etcd.io/etcd/client" | |
13 | 13 | ) |
14 | 14 | |
15 | 15 | var ( |
7 | 7 | |
8 | 8 | "golang.org/x/net/context" |
9 | 9 | |
10 | etcd "github.com/coreos/etcd/client" | |
10 | etcd "go.etcd.io/etcd/client" | |
11 | 11 | ) |
12 | 12 | |
13 | 13 | func TestNewClient(t *testing.T) { |
3 | 3 | "errors" |
4 | 4 | "testing" |
5 | 5 | |
6 | stdetcd "github.com/coreos/etcd/client" | |
6 | stdetcd "go.etcd.io/etcd/client" | |
7 | 7 | |
8 | 8 | "github.com/go-kit/kit/log" |
9 | 9 | "github.com/go-kit/kit/sd" |
3 | 3 | "sync" |
4 | 4 | "time" |
5 | 5 | |
6 | etcd "github.com/coreos/etcd/client" | |
6 | etcd "go.etcd.io/etcd/client" | |
7 | 7 | |
8 | 8 | "github.com/go-kit/kit/log" |
9 | 9 | ) |
5 | 5 | "errors" |
6 | 6 | "time" |
7 | 7 | |
8 | "github.com/coreos/etcd/clientv3" | |
9 | "github.com/coreos/etcd/pkg/transport" | |
8 | "go.etcd.io/etcd/clientv3" | |
9 | "go.etcd.io/etcd/pkg/transport" | |
10 | 10 | ) |
11 | 11 | |
12 | 12 | var ( |
227 | 227 | } |
228 | 228 | if c.watcher != nil { |
229 | 229 | c.watcher.Close() |
230 | } | |
231 | if c.wcf != nil { | |
230 | 232 | c.wcf() |
231 | 233 | } |
232 | 234 | } |
33 | 33 | |
34 | 34 | // Register our instance. |
35 | 35 | registrar.Register() |
36 | t.Logf("Registered") | |
36 | t.Log("Registered") | |
37 | 37 | |
38 | 38 | // Retrieve entries from etcd manually. |
39 | 39 | entries, err = client.GetEntries(settings.key) |
55 | 55 | if err != nil { |
56 | 56 | t.Fatalf("NewInstancer: %v", err) |
57 | 57 | } |
58 | t.Logf("Constructed Instancer OK") | |
58 | t.Log("Constructed Instancer OK") | |
59 | 59 | defer instancer.Stop() |
60 | 60 | |
61 | 61 | endpointer := sd.NewEndpointer( |
63 | 63 | func(string) (endpoint.Endpoint, io.Closer, error) { return endpoint.Nop, nil, nil }, |
64 | 64 | log.With(log.NewLogfmtLogger(os.Stderr), "component", "instancer"), |
65 | 65 | ) |
66 | t.Logf("Constructed Endpointer OK") | |
66 | t.Log("Constructed Endpointer OK") | |
67 | 67 | defer endpointer.Close() |
68 | 68 | |
69 | 69 | if !within(time.Second, func() bool { |
70 | 70 | endpoints, err := endpointer.Endpoints() |
71 | 71 | return err == nil && len(endpoints) == 1 |
72 | 72 | }) { |
73 | t.Fatalf("Endpointer didn't see Register in time") | |
74 | } | |
75 | t.Logf("Endpointer saw Register OK") | |
73 | t.Fatal("Endpointer didn't see Register in time") | |
74 | } | |
75 | t.Log("Endpointer saw Register OK") | |
76 | 76 | |
77 | 77 | // Deregister first instance of test data. |
78 | 78 | registrar.Deregister() |
79 | t.Logf("Deregistered") | |
79 | t.Log("Deregistered") | |
80 | 80 | |
81 | 81 | // Check it was deregistered. |
82 | 82 | if !within(time.Second, func() bool { |
163 | 163 | runIntegration(settings, client, service, t) |
164 | 164 | } |
165 | 165 | |
166 | func TestIntegrationRegistrarOnly(t *testing.T) { | |
167 | settings := testIntegrationSettings(t) | |
168 | client, err := NewClient(context.Background(), []string{settings.addr}, ClientOptions{ | |
169 | DialTimeout: 2 * time.Second, | |
170 | DialKeepAlive: 2 * time.Second, | |
171 | }) | |
172 | if err != nil { | |
173 | t.Fatalf("NewClient(%q): %v", settings.addr, err) | |
174 | } | |
175 | ||
176 | service := Service{ | |
177 | Key: settings.key, | |
178 | Value: settings.value, | |
179 | TTL: NewTTLOption(time.Second*3, time.Second*10), | |
180 | } | |
181 | defer client.Deregister(service) | |
182 | ||
183 | // Verify test data is initially empty. | |
184 | entries, err := client.GetEntries(settings.key) | |
185 | if err != nil { | |
186 | t.Fatalf("GetEntries(%q): expected no error, got one: %v", settings.key, err) | |
187 | } | |
188 | if len(entries) > 0 { | |
189 | t.Fatalf("GetEntries(%q): expected no instance entries, got %d", settings.key, len(entries)) | |
190 | } | |
191 | t.Logf("GetEntries(%q): %v (OK)", settings.key, entries) | |
192 | ||
193 | // Instantiate a new Registrar, passing in test data. | |
194 | registrar := NewRegistrar( | |
195 | client, | |
196 | service, | |
197 | log.With(log.NewLogfmtLogger(os.Stderr), "component", "registrar"), | |
198 | ) | |
199 | ||
200 | // Register our instance. | |
201 | registrar.Register() | |
202 | t.Log("Registered") | |
203 | ||
204 | // Deregister our instance. (so we test registrar only scenario) | |
205 | registrar.Deregister() | |
206 | t.Log("Deregistered") | |
207 | ||
208 | } | |
209 | ||
166 | 210 | func within(d time.Duration, f func() bool) bool { |
167 | 211 | deadline := time.Now().Add(d) |
168 | 212 | for time.Now().Before(deadline) { |
3 | 3 | |
4 | 4 | import ( |
5 | 5 | "bytes" |
6 | "log" | |
7 | 6 | "os" |
8 | 7 | "testing" |
9 | 8 | "time" |
17 | 16 | |
18 | 17 | func TestMain(m *testing.M) { |
19 | 18 | zkAddr := os.Getenv("ZK_ADDR") |
20 | if zkAddr == "" { | |
21 | log.Fatal("ZK_ADDR is not set") | |
22 | } | |
23 | host = []string{zkAddr} | |
19 | if zkAddr != "" { | |
20 | host = []string{zkAddr} | |
21 | } | |
22 | m.Run() | |
24 | 23 | } |
25 | 24 | |
26 | 25 | func TestCreateParentNodesOnServer(t *testing.T) { |
26 | if len(host) == 0 { | |
27 | t.Skip("ZK_ADDR not set; skipping integration test") | |
28 | } | |
27 | 29 | payload := [][]byte{[]byte("Payload"), []byte("Test")} |
28 | 30 | c1, err := NewClient(host, logger, Payload(payload)) |
29 | 31 | if err != nil { |
66 | 68 | } |
67 | 69 | |
68 | 70 | func TestCreateBadParentNodesOnServer(t *testing.T) { |
71 | if len(host) == 0 { | |
72 | t.Skip("ZK_ADDR not set; skipping integration test") | |
73 | } | |
69 | 74 | c, _ := NewClient(host, logger) |
70 | 75 | defer c.Stop() |
71 | 76 | |
77 | 82 | } |
78 | 83 | |
79 | 84 | func TestCredentials1(t *testing.T) { |
85 | if len(host) == 0 { | |
86 | t.Skip("ZK_ADDR not set; skipping integration test") | |
87 | } | |
80 | 88 | acl := stdzk.DigestACL(stdzk.PermAll, "user", "secret") |
81 | 89 | c, _ := NewClient(host, logger, ACL(acl), Credentials("user", "secret")) |
82 | 90 | defer c.Stop() |
89 | 97 | } |
90 | 98 | |
91 | 99 | func TestCredentials2(t *testing.T) { |
100 | if len(host) == 0 { | |
101 | t.Skip("ZK_ADDR not set; skipping integration test") | |
102 | } | |
92 | 103 | acl := stdzk.DigestACL(stdzk.PermAll, "user", "secret") |
93 | 104 | c, _ := NewClient(host, logger, ACL(acl)) |
94 | 105 | defer c.Stop() |
101 | 112 | } |
102 | 113 | |
103 | 114 | func TestConnection(t *testing.T) { |
115 | if len(host) == 0 { | |
116 | t.Skip("ZK_ADDR not set; skipping integration test") | |
117 | } | |
104 | 118 | c, _ := NewClient(host, logger) |
105 | 119 | c.Stop() |
106 | 120 | |
112 | 126 | } |
113 | 127 | |
114 | 128 | func TestGetEntriesOnServer(t *testing.T) { |
129 | if len(host) == 0 { | |
130 | t.Skip("ZK_ADDR not set; skipping integration test") | |
131 | } | |
115 | 132 | var instancePayload = "10.0.3.204:8002" |
116 | 133 | |
117 | 134 | c1, err := NewClient(host, logger) |
157 | 174 | } |
158 | 175 | |
159 | 176 | func TestGetEntriesPayloadOnServer(t *testing.T) { |
177 | if len(host) == 0 { | |
178 | t.Skip("ZK_ADDR not set; skipping integration test") | |
179 | } | |
160 | 180 | c, err := NewClient(host, logger) |
161 | 181 | if err != nil { |
162 | 182 | t.Fatalf("Connect returned error: %v", err) |
18 | 18 | binding to use. Instrumentation exists for `kit/transport/http` and |
19 | 19 | `kit/transport/grpc`. The bindings are highlighted in the [addsvc] example. For |
20 | 20 | more information regarding Zipkin feel free to visit [Zipkin's Gitter]. |
21 | ||
22 | ## OpenCensus | |
23 | ||
24 | Go kit supports transport and endpoint middlewares for the [OpenCensus] | |
25 | instrumentation library. OpenCensus provides a cross language consistent data | |
26 | model and instrumentation libraries for tracing and metrics. From this data | |
27 | model it allows exports to various tracing and metrics backends including but | |
28 | not limited to Zipkin, Prometheus, Stackdriver Trace & Monitoring, Jaeger, | |
29 | AWS X-Ray and Datadog. Go kit uses the [opencensus-go] implementation to power | |
30 | its middlewares. | |
21 | 31 | |
22 | 32 | ## OpenTracing |
23 | 33 | |
61 | 71 | |
62 | 72 | [Zipkin]: http://zipkin.io/ |
63 | 73 | [Open Zipkin GitHub]: https://github.com/openzipkin |
64 | [zipkin-go-opentracing]: https://github.com/openzipkin/zipkin-go-opentracing | |
74 | [zipkin-go-opentracing]: https://github.com/openzipkin-contrib/zipkin-go-opentracing | |
65 | 75 | [zipkin-go]: https://github.com/openzipkin/zipkin-go |
66 | 76 | [Zipkin's Gitter]: https://gitter.im/openzipkin/zipkin |
67 | 77 | |
70 | 80 | |
71 | 81 | [LightStep]: http://lightstep.com/ |
72 | 82 | [lightstep-tracer-go]: https://github.com/lightstep/lightstep-tracer-go |
83 | ||
84 | [OpenCensus]: https://opencensus.io/ | |
85 | [opencensus-go]: https://github.com/census-instrumentation/opencensus-go |
0 | // Package opencensus provides Go kit integration to the OpenCensus project. | |
1 | // OpenCensus is a single distribution of libraries for metrics and distributed | |
2 | // tracing with minimal overhead that allows you to export data to multiple | |
3 | // backends. The Go kit OpenCencus package as provided here contains middlewares | |
4 | // for tracing. | |
5 | package opencensus |
0 | package opencensus | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "strconv" | |
5 | ||
6 | "go.opencensus.io/trace" | |
7 | ||
8 | "github.com/go-kit/kit/endpoint" | |
9 | "github.com/go-kit/kit/sd/lb" | |
10 | ) | |
11 | ||
12 | // TraceEndpointDefaultName is the default endpoint span name to use. | |
13 | const TraceEndpointDefaultName = "gokit/endpoint" | |
14 | ||
15 | // TraceEndpoint returns an Endpoint middleware, tracing a Go kit endpoint. | |
16 | // This endpoint tracer should be used in combination with a Go kit Transport | |
17 | // tracing middleware, generic OpenCensus transport middleware or custom before | |
18 | // and after transport functions as service propagation of SpanContext is not | |
19 | // provided in this middleware. | |
20 | func TraceEndpoint(name string, options ...EndpointOption) endpoint.Middleware { | |
21 | if name == "" { | |
22 | name = TraceEndpointDefaultName | |
23 | } | |
24 | ||
25 | cfg := &EndpointOptions{} | |
26 | ||
27 | for _, o := range options { | |
28 | o(cfg) | |
29 | } | |
30 | ||
31 | return func(next endpoint.Endpoint) endpoint.Endpoint { | |
32 | return func(ctx context.Context, request interface{}) (response interface{}, err error) { | |
33 | ctx, span := trace.StartSpan(ctx, name) | |
34 | if len(cfg.Attributes) > 0 { | |
35 | span.AddAttributes(cfg.Attributes...) | |
36 | } | |
37 | defer span.End() | |
38 | ||
39 | defer func() { | |
40 | if err != nil { | |
41 | if lberr, ok := err.(lb.RetryError); ok { | |
42 | // handle errors originating from lb.Retry | |
43 | attrs := make([]trace.Attribute, 0, len(lberr.RawErrors)) | |
44 | for idx, rawErr := range lberr.RawErrors { | |
45 | attrs = append(attrs, trace.StringAttribute( | |
46 | "gokit.retry.error."+strconv.Itoa(idx+1), rawErr.Error(), | |
47 | )) | |
48 | } | |
49 | span.AddAttributes(attrs...) | |
50 | span.SetStatus(trace.Status{ | |
51 | Code: trace.StatusCodeUnknown, | |
52 | Message: lberr.Final.Error(), | |
53 | }) | |
54 | return | |
55 | } | |
56 | // generic error | |
57 | span.SetStatus(trace.Status{ | |
58 | Code: trace.StatusCodeUnknown, | |
59 | Message: err.Error(), | |
60 | }) | |
61 | return | |
62 | } | |
63 | ||
64 | // test for business error | |
65 | if res, ok := response.(endpoint.Failer); ok && res.Failed() != nil { | |
66 | span.AddAttributes( | |
67 | trace.StringAttribute("gokit.business.error", res.Failed().Error()), | |
68 | ) | |
69 | if cfg.IgnoreBusinessError { | |
70 | span.SetStatus(trace.Status{Code: trace.StatusCodeOK}) | |
71 | return | |
72 | } | |
73 | // treating business error as real error in span. | |
74 | span.SetStatus(trace.Status{ | |
75 | Code: trace.StatusCodeUnknown, | |
76 | Message: res.Failed().Error(), | |
77 | }) | |
78 | return | |
79 | } | |
80 | ||
81 | // no errors identified | |
82 | span.SetStatus(trace.Status{Code: trace.StatusCodeOK}) | |
83 | }() | |
84 | response, err = next(ctx, request) | |
85 | return | |
86 | } | |
87 | } | |
88 | } |
0 | package opencensus | |
1 | ||
2 | import "go.opencensus.io/trace" | |
3 | ||
4 | // EndpointOptions holds the options for tracing an endpoint | |
5 | type EndpointOptions struct { | |
6 | // IgnoreBusinessError if set to true will not treat a business error | |
7 | // identified through the endpoint.Failer interface as a span error. | |
8 | IgnoreBusinessError bool | |
9 | ||
10 | // Attributes holds the default attributes which will be set on span | |
11 | // creation by our Endpoint middleware. | |
12 | Attributes []trace.Attribute | |
13 | } | |
14 | ||
15 | // EndpointOption allows for functional options to our OpenCensus endpoint | |
16 | // tracing middleware. | |
17 | type EndpointOption func(*EndpointOptions) | |
18 | ||
19 | // WithEndpointConfig sets all configuration options at once by use of the | |
20 | // EndpointOptions struct. | |
21 | func WithEndpointConfig(options EndpointOptions) EndpointOption { | |
22 | return func(o *EndpointOptions) { | |
23 | *o = options | |
24 | } | |
25 | } | |
26 | ||
27 | // WithEndpointAttributes sets the default attributes for the spans created by | |
28 | // the Endpoint tracer. | |
29 | func WithEndpointAttributes(attrs ...trace.Attribute) EndpointOption { | |
30 | return func(o *EndpointOptions) { | |
31 | o.Attributes = attrs | |
32 | } | |
33 | } | |
34 | ||
35 | // WithIgnoreBusinessError if set to true will not treat a business error | |
36 | // identified through the endpoint.Failer interface as a span error. | |
37 | func WithIgnoreBusinessError(val bool) EndpointOption { | |
38 | return func(o *EndpointOptions) { | |
39 | o.IgnoreBusinessError = val | |
40 | } | |
41 | } |
0 | package opencensus_test | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "errors" | |
5 | "testing" | |
6 | "time" | |
7 | ||
8 | "go.opencensus.io/trace" | |
9 | ||
10 | "github.com/go-kit/kit/endpoint" | |
11 | "github.com/go-kit/kit/sd" | |
12 | "github.com/go-kit/kit/sd/lb" | |
13 | "github.com/go-kit/kit/tracing/opencensus" | |
14 | ) | |
15 | ||
16 | const ( | |
17 | span1 = "" | |
18 | span2 = "SPAN-2" | |
19 | span3 = "SPAN-3" | |
20 | span4 = "SPAN-4" | |
21 | span5 = "SPAN-5" | |
22 | ) | |
23 | ||
24 | var ( | |
25 | err1 = errors.New("some error") | |
26 | err2 = errors.New("other error") | |
27 | err3 = errors.New("some business error") | |
28 | err4 = errors.New("other business error") | |
29 | ) | |
30 | ||
31 | // compile time assertion | |
32 | var _ endpoint.Failer = failedResponse{} | |
33 | ||
34 | type failedResponse struct { | |
35 | err error | |
36 | } | |
37 | ||
38 | func (r failedResponse) Failed() error { return r.err } | |
39 | ||
40 | func passEndpoint(_ context.Context, req interface{}) (interface{}, error) { | |
41 | if err, _ := req.(error); err != nil { | |
42 | return nil, err | |
43 | } | |
44 | return req, nil | |
45 | } | |
46 | ||
47 | func TestTraceEndpoint(t *testing.T) { | |
48 | ctx := context.Background() | |
49 | ||
50 | e := &recordingExporter{} | |
51 | trace.RegisterExporter(e) | |
52 | trace.ApplyConfig(trace.Config{DefaultSampler: trace.AlwaysSample()}) | |
53 | ||
54 | // span 1 | |
55 | span1Attrs := []trace.Attribute{ | |
56 | trace.StringAttribute("string", "value"), | |
57 | trace.Int64Attribute("int64", 42), | |
58 | } | |
59 | mw := opencensus.TraceEndpoint( | |
60 | span1, opencensus.WithEndpointAttributes(span1Attrs...), | |
61 | ) | |
62 | mw(endpoint.Nop)(ctx, nil) | |
63 | ||
64 | // span 2 | |
65 | opts := opencensus.EndpointOptions{} | |
66 | mw = opencensus.TraceEndpoint(span2, opencensus.WithEndpointConfig(opts)) | |
67 | mw(passEndpoint)(ctx, err1) | |
68 | ||
69 | // span3 | |
70 | mw = opencensus.TraceEndpoint(span3) | |
71 | ep := lb.Retry(5, 1*time.Second, lb.NewRoundRobin(sd.FixedEndpointer{passEndpoint})) | |
72 | mw(ep)(ctx, err2) | |
73 | ||
74 | // span4 | |
75 | mw = opencensus.TraceEndpoint(span4) | |
76 | mw(passEndpoint)(ctx, failedResponse{err: err3}) | |
77 | ||
78 | // span4 | |
79 | mw = opencensus.TraceEndpoint(span5, opencensus.WithIgnoreBusinessError(true)) | |
80 | mw(passEndpoint)(ctx, failedResponse{err: err4}) | |
81 | ||
82 | // check span count | |
83 | spans := e.Flush() | |
84 | if want, have := 5, len(spans); want != have { | |
85 | t.Fatalf("incorrected number of spans, wanted %d, got %d", want, have) | |
86 | } | |
87 | ||
88 | // test span 1 | |
89 | span := spans[0] | |
90 | if want, have := int32(trace.StatusCodeOK), span.Code; want != have { | |
91 | t.Errorf("incorrect status code, wanted %d, got %d", want, have) | |
92 | } | |
93 | ||
94 | if want, have := opencensus.TraceEndpointDefaultName, span.Name; want != have { | |
95 | t.Errorf("incorrect span name, wanted %q, got %q", want, have) | |
96 | } | |
97 | ||
98 | if want, have := 2, len(span.Attributes); want != have { | |
99 | t.Fatalf("incorrect attribute count, wanted %d, got %d", want, have) | |
100 | } | |
101 | ||
102 | // test span 2 | |
103 | span = spans[1] | |
104 | if want, have := int32(trace.StatusCodeUnknown), span.Code; want != have { | |
105 | t.Errorf("incorrect status code, wanted %d, got %d", want, have) | |
106 | } | |
107 | ||
108 | if want, have := span2, span.Name; want != have { | |
109 | t.Errorf("incorrect span name, wanted %q, got %q", want, have) | |
110 | } | |
111 | ||
112 | if want, have := 0, len(span.Attributes); want != have { | |
113 | t.Fatalf("incorrect attribute count, wanted %d, got %d", want, have) | |
114 | } | |
115 | ||
116 | // test span 3 | |
117 | span = spans[2] | |
118 | if want, have := int32(trace.StatusCodeUnknown), span.Code; want != have { | |
119 | t.Errorf("incorrect status code, wanted %d, got %d", want, have) | |
120 | } | |
121 | ||
122 | if want, have := span3, span.Name; want != have { | |
123 | t.Errorf("incorrect span name, wanted %q, got %q", want, have) | |
124 | } | |
125 | ||
126 | if want, have := 5, len(span.Attributes); want != have { | |
127 | t.Fatalf("incorrect attribute count, wanted %d, got %d", want, have) | |
128 | } | |
129 | ||
130 | // test span 4 | |
131 | span = spans[3] | |
132 | if want, have := int32(trace.StatusCodeUnknown), span.Code; want != have { | |
133 | t.Errorf("incorrect status code, wanted %d, got %d", want, have) | |
134 | } | |
135 | ||
136 | if want, have := span4, span.Name; want != have { | |
137 | t.Errorf("incorrect span name, wanted %q, got %q", want, have) | |
138 | } | |
139 | ||
140 | if want, have := 1, len(span.Attributes); want != have { | |
141 | t.Fatalf("incorrect attribute count, wanted %d, got %d", want, have) | |
142 | } | |
143 | ||
144 | // test span 5 | |
145 | span = spans[4] | |
146 | if want, have := int32(trace.StatusCodeOK), span.Code; want != have { | |
147 | t.Errorf("incorrect status code, wanted %d, got %d", want, have) | |
148 | } | |
149 | ||
150 | if want, have := span5, span.Name; want != have { | |
151 | t.Errorf("incorrect span name, wanted %q, got %q", want, have) | |
152 | } | |
153 | ||
154 | if want, have := 1, len(span.Attributes); want != have { | |
155 | t.Fatalf("incorrect attribute count, wanted %d, got %d", want, have) | |
156 | } | |
157 | ||
158 | } |
0 | package opencensus | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | ||
5 | "go.opencensus.io/trace" | |
6 | "go.opencensus.io/trace/propagation" | |
7 | "google.golang.org/grpc/codes" | |
8 | "google.golang.org/grpc/metadata" | |
9 | "google.golang.org/grpc/status" | |
10 | ||
11 | kitgrpc "github.com/go-kit/kit/transport/grpc" | |
12 | ) | |
13 | ||
14 | const propagationKey = "grpc-trace-bin" | |
15 | ||
16 | // GRPCClientTrace enables OpenCensus tracing of a Go kit gRPC transport client. | |
17 | func GRPCClientTrace(options ...TracerOption) kitgrpc.ClientOption { | |
18 | cfg := TracerOptions{} | |
19 | ||
20 | for _, option := range options { | |
21 | option(&cfg) | |
22 | } | |
23 | ||
24 | if cfg.Sampler == nil { | |
25 | cfg.Sampler = trace.AlwaysSample() | |
26 | } | |
27 | ||
28 | clientBefore := kitgrpc.ClientBefore( | |
29 | func(ctx context.Context, md *metadata.MD) context.Context { | |
30 | var name string | |
31 | ||
32 | if cfg.Name != "" { | |
33 | name = cfg.Name | |
34 | } else { | |
35 | name = ctx.Value(kitgrpc.ContextKeyRequestMethod).(string) | |
36 | } | |
37 | ||
38 | ctx, span := trace.StartSpan( | |
39 | ctx, | |
40 | name, | |
41 | trace.WithSampler(cfg.Sampler), | |
42 | trace.WithSpanKind(trace.SpanKindClient), | |
43 | ) | |
44 | ||
45 | if !cfg.Public { | |
46 | traceContextBinary := string(propagation.Binary(span.SpanContext())) | |
47 | (*md)[propagationKey] = append((*md)[propagationKey], traceContextBinary) | |
48 | } | |
49 | ||
50 | return ctx | |
51 | }, | |
52 | ) | |
53 | ||
54 | clientFinalizer := kitgrpc.ClientFinalizer( | |
55 | func(ctx context.Context, err error) { | |
56 | if span := trace.FromContext(ctx); span != nil { | |
57 | if s, ok := status.FromError(err); ok { | |
58 | span.SetStatus(trace.Status{Code: int32(s.Code()), Message: s.Message()}) | |
59 | } else { | |
60 | span.SetStatus(trace.Status{Code: int32(codes.Unknown), Message: err.Error()}) | |
61 | } | |
62 | span.End() | |
63 | } | |
64 | }, | |
65 | ) | |
66 | ||
67 | return func(c *kitgrpc.Client) { | |
68 | clientBefore(c) | |
69 | clientFinalizer(c) | |
70 | } | |
71 | } | |
72 | ||
73 | // GRPCServerTrace enables OpenCensus tracing of a Go kit gRPC transport server. | |
74 | func GRPCServerTrace(options ...TracerOption) kitgrpc.ServerOption { | |
75 | cfg := TracerOptions{} | |
76 | ||
77 | for _, option := range options { | |
78 | option(&cfg) | |
79 | } | |
80 | ||
81 | if cfg.Sampler == nil { | |
82 | cfg.Sampler = trace.AlwaysSample() | |
83 | } | |
84 | ||
85 | serverBefore := kitgrpc.ServerBefore( | |
86 | func(ctx context.Context, md metadata.MD) context.Context { | |
87 | var name string | |
88 | ||
89 | if cfg.Name != "" { | |
90 | name = cfg.Name | |
91 | } else { | |
92 | name, _ = ctx.Value(kitgrpc.ContextKeyRequestMethod).(string) | |
93 | if name == "" { | |
94 | // we can't find the gRPC method. probably the | |
95 | // unaryInterceptor was not wired up. | |
96 | name = "unknown grpc method" | |
97 | } | |
98 | } | |
99 | ||
100 | var ( | |
101 | parentContext trace.SpanContext | |
102 | traceContext = md[propagationKey] | |
103 | ok bool | |
104 | ) | |
105 | ||
106 | if len(traceContext) > 0 { | |
107 | traceContextBinary := []byte(traceContext[0]) | |
108 | parentContext, ok = propagation.FromBinary(traceContextBinary) | |
109 | if ok && !cfg.Public { | |
110 | ctx, _ = trace.StartSpanWithRemoteParent( | |
111 | ctx, | |
112 | name, | |
113 | parentContext, | |
114 | trace.WithSpanKind(trace.SpanKindServer), | |
115 | trace.WithSampler(cfg.Sampler), | |
116 | ) | |
117 | return ctx | |
118 | } | |
119 | } | |
120 | ctx, span := trace.StartSpan( | |
121 | ctx, | |
122 | name, | |
123 | trace.WithSpanKind(trace.SpanKindServer), | |
124 | trace.WithSampler(cfg.Sampler), | |
125 | ) | |
126 | if ok { | |
127 | span.AddLink( | |
128 | trace.Link{ | |
129 | TraceID: parentContext.TraceID, | |
130 | SpanID: parentContext.SpanID, | |
131 | Type: trace.LinkTypeChild, | |
132 | }, | |
133 | ) | |
134 | } | |
135 | return ctx | |
136 | }, | |
137 | ) | |
138 | ||
139 | serverFinalizer := kitgrpc.ServerFinalizer( | |
140 | func(ctx context.Context, err error) { | |
141 | if span := trace.FromContext(ctx); span != nil { | |
142 | if s, ok := status.FromError(err); ok { | |
143 | span.SetStatus(trace.Status{Code: int32(s.Code()), Message: s.Message()}) | |
144 | } else { | |
145 | span.SetStatus(trace.Status{Code: int32(codes.Internal), Message: err.Error()}) | |
146 | } | |
147 | span.End() | |
148 | } | |
149 | }, | |
150 | ) | |
151 | ||
152 | return func(s *kitgrpc.Server) { | |
153 | serverBefore(s) | |
154 | serverFinalizer(s) | |
155 | } | |
156 | } |
0 | package opencensus_test | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "errors" | |
5 | "testing" | |
6 | ||
7 | "go.opencensus.io/trace" | |
8 | "go.opencensus.io/trace/propagation" | |
9 | "google.golang.org/grpc" | |
10 | "google.golang.org/grpc/codes" | |
11 | "google.golang.org/grpc/metadata" | |
12 | ||
13 | "github.com/go-kit/kit/endpoint" | |
14 | ockit "github.com/go-kit/kit/tracing/opencensus" | |
15 | grpctransport "github.com/go-kit/kit/transport/grpc" | |
16 | ) | |
17 | ||
18 | type dummy struct{} | |
19 | ||
20 | const traceContextKey = "grpc-trace-bin" | |
21 | ||
22 | func unaryInterceptor( | |
23 | ctx context.Context, method string, req, reply interface{}, | |
24 | cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption, | |
25 | ) error { | |
26 | return nil | |
27 | } | |
28 | ||
29 | func TestGRPCClientTrace(t *testing.T) { | |
30 | rec := &recordingExporter{} | |
31 | ||
32 | trace.RegisterExporter(rec) | |
33 | ||
34 | cc, err := grpc.Dial( | |
35 | "", | |
36 | grpc.WithUnaryInterceptor(unaryInterceptor), | |
37 | grpc.WithInsecure(), | |
38 | ) | |
39 | if err != nil { | |
40 | t.Fatalf("unable to create gRPC dialer: %s", err.Error()) | |
41 | } | |
42 | ||
43 | traces := []struct { | |
44 | name string | |
45 | err error | |
46 | }{ | |
47 | {"", nil}, | |
48 | {"CustomName", nil}, | |
49 | {"", errors.New("dummy-error")}, | |
50 | } | |
51 | ||
52 | for _, tr := range traces { | |
53 | clientTracer := ockit.GRPCClientTrace(ockit.WithName(tr.name)) | |
54 | ||
55 | ep := grpctransport.NewClient( | |
56 | cc, | |
57 | "dummyService", | |
58 | "dummyMethod", | |
59 | func(context.Context, interface{}) (interface{}, error) { | |
60 | return nil, nil | |
61 | }, | |
62 | func(context.Context, interface{}) (interface{}, error) { | |
63 | return nil, tr.err | |
64 | }, | |
65 | dummy{}, | |
66 | clientTracer, | |
67 | ).Endpoint() | |
68 | ||
69 | ctx, parentSpan := trace.StartSpan(context.Background(), "test") | |
70 | ||
71 | _, err = ep(ctx, nil) | |
72 | if want, have := tr.err, err; want != have { | |
73 | t.Fatalf("unexpected error, want %s, have %s", tr.err.Error(), err.Error()) | |
74 | } | |
75 | ||
76 | spans := rec.Flush() | |
77 | if want, have := 1, len(spans); want != have { | |
78 | t.Fatalf("incorrect number of spans, want %d, have %d", want, have) | |
79 | } | |
80 | span := spans[0] | |
81 | if want, have := parentSpan.SpanContext().SpanID, span.ParentSpanID; want != have { | |
82 | t.Errorf("incorrect parent ID, want %s, have %s", want, have) | |
83 | } | |
84 | ||
85 | if want, have := tr.name, span.Name; want != have && want != "" { | |
86 | t.Errorf("incorrect span name, want %s, have %s", want, have) | |
87 | } | |
88 | ||
89 | if want, have := "/dummyService/dummyMethod", span.Name; want != have && tr.name == "" { | |
90 | t.Errorf("incorrect span name, want %s, have %s", want, have) | |
91 | } | |
92 | ||
93 | code := trace.StatusCodeOK | |
94 | if tr.err != nil { | |
95 | code = trace.StatusCodeUnknown | |
96 | ||
97 | if want, have := err.Error(), span.Status.Message; want != have { | |
98 | t.Errorf("incorrect span status msg, want %s, have %s", want, have) | |
99 | } | |
100 | } | |
101 | ||
102 | if want, have := int32(code), span.Status.Code; want != have { | |
103 | t.Errorf("incorrect span status code, want %d, have %d", want, have) | |
104 | } | |
105 | } | |
106 | } | |
107 | ||
108 | func TestGRPCServerTrace(t *testing.T) { | |
109 | rec := &recordingExporter{} | |
110 | ||
111 | trace.RegisterExporter(rec) | |
112 | ||
113 | traces := []struct { | |
114 | useParent bool | |
115 | name string | |
116 | err error | |
117 | }{ | |
118 | {false, "", nil}, | |
119 | {true, "", nil}, | |
120 | {true, "CustomName", nil}, | |
121 | {true, "", errors.New("dummy-error")}, | |
122 | } | |
123 | ||
124 | for _, tr := range traces { | |
125 | var ( | |
126 | ctx = context.Background() | |
127 | parentSpan *trace.Span | |
128 | ) | |
129 | ||
130 | server := grpctransport.NewServer( | |
131 | endpoint.Nop, | |
132 | func(context.Context, interface{}) (interface{}, error) { | |
133 | return nil, nil | |
134 | }, | |
135 | func(context.Context, interface{}) (interface{}, error) { | |
136 | return nil, tr.err | |
137 | }, | |
138 | ockit.GRPCServerTrace(ockit.WithName(tr.name)), | |
139 | ) | |
140 | ||
141 | if tr.useParent { | |
142 | _, parentSpan = trace.StartSpan(context.Background(), "test") | |
143 | traceContextBinary := propagation.Binary(parentSpan.SpanContext()) | |
144 | ||
145 | md := metadata.MD{} | |
146 | md.Set(traceContextKey, string(traceContextBinary)) | |
147 | ctx = metadata.NewIncomingContext(ctx, md) | |
148 | } | |
149 | ||
150 | server.ServeGRPC(ctx, nil) | |
151 | ||
152 | spans := rec.Flush() | |
153 | ||
154 | if want, have := 1, len(spans); want != have { | |
155 | t.Fatalf("incorrect number of spans, want %d, have %d", want, have) | |
156 | } | |
157 | ||
158 | if tr.useParent { | |
159 | if want, have := parentSpan.SpanContext().TraceID, spans[0].TraceID; want != have { | |
160 | t.Errorf("incorrect trace ID, want %s, have %s", want, have) | |
161 | } | |
162 | ||
163 | if want, have := parentSpan.SpanContext().SpanID, spans[0].ParentSpanID; want != have { | |
164 | t.Errorf("incorrect span ID, want %s, have %s", want, have) | |
165 | } | |
166 | } | |
167 | ||
168 | if want, have := tr.name, spans[0].Name; want != have && want != "" { | |
169 | t.Errorf("incorrect span name, want %s, have %s", want, have) | |
170 | } | |
171 | ||
172 | if tr.err != nil { | |
173 | if want, have := int32(codes.Internal), spans[0].Status.Code; want != have { | |
174 | t.Errorf("incorrect span status code, want %d, have %d", want, have) | |
175 | } | |
176 | ||
177 | if want, have := tr.err.Error(), spans[0].Status.Message; want != have { | |
178 | t.Errorf("incorrect span status message, want %s, have %s", want, have) | |
179 | } | |
180 | } | |
181 | } | |
182 | } |
0 | package opencensus | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "net/http" | |
5 | ||
6 | "go.opencensus.io/plugin/ochttp" | |
7 | "go.opencensus.io/plugin/ochttp/propagation/b3" | |
8 | "go.opencensus.io/trace" | |
9 | ||
10 | kithttp "github.com/go-kit/kit/transport/http" | |
11 | ) | |
12 | ||
13 | // HTTPClientTrace enables OpenCensus tracing of a Go kit HTTP transport client. | |
14 | func HTTPClientTrace(options ...TracerOption) kithttp.ClientOption { | |
15 | cfg := TracerOptions{} | |
16 | ||
17 | for _, option := range options { | |
18 | option(&cfg) | |
19 | } | |
20 | ||
21 | if cfg.Sampler == nil { | |
22 | cfg.Sampler = trace.AlwaysSample() | |
23 | } | |
24 | ||
25 | if !cfg.Public && cfg.HTTPPropagate == nil { | |
26 | cfg.HTTPPropagate = &b3.HTTPFormat{} | |
27 | } | |
28 | ||
29 | clientBefore := kithttp.ClientBefore( | |
30 | func(ctx context.Context, req *http.Request) context.Context { | |
31 | var name string | |
32 | ||
33 | if cfg.Name != "" { | |
34 | name = cfg.Name | |
35 | } else { | |
36 | // OpenCensus states Path being default naming for a client span | |
37 | name = req.Method + " " + req.URL.Path | |
38 | } | |
39 | ||
40 | ctx, span := trace.StartSpan( | |
41 | ctx, | |
42 | name, | |
43 | trace.WithSampler(cfg.Sampler), | |
44 | trace.WithSpanKind(trace.SpanKindClient), | |
45 | ) | |
46 | ||
47 | span.AddAttributes( | |
48 | trace.StringAttribute(ochttp.HostAttribute, req.URL.Host), | |
49 | trace.StringAttribute(ochttp.MethodAttribute, req.Method), | |
50 | trace.StringAttribute(ochttp.PathAttribute, req.URL.Path), | |
51 | trace.StringAttribute(ochttp.UserAgentAttribute, req.UserAgent()), | |
52 | ) | |
53 | ||
54 | if !cfg.Public { | |
55 | cfg.HTTPPropagate.SpanContextToRequest(span.SpanContext(), req) | |
56 | } | |
57 | ||
58 | return ctx | |
59 | }, | |
60 | ) | |
61 | ||
62 | clientAfter := kithttp.ClientAfter( | |
63 | func(ctx context.Context, res *http.Response) context.Context { | |
64 | if span := trace.FromContext(ctx); span != nil { | |
65 | span.SetStatus(ochttp.TraceStatus(res.StatusCode, http.StatusText(res.StatusCode))) | |
66 | span.AddAttributes( | |
67 | trace.Int64Attribute(ochttp.StatusCodeAttribute, int64(res.StatusCode)), | |
68 | ) | |
69 | } | |
70 | return ctx | |
71 | }, | |
72 | ) | |
73 | ||
74 | clientFinalizer := kithttp.ClientFinalizer( | |
75 | func(ctx context.Context, err error) { | |
76 | if span := trace.FromContext(ctx); span != nil { | |
77 | if err != nil { | |
78 | span.SetStatus(trace.Status{ | |
79 | Code: trace.StatusCodeUnknown, | |
80 | Message: err.Error(), | |
81 | }) | |
82 | } | |
83 | span.End() | |
84 | } | |
85 | }, | |
86 | ) | |
87 | ||
88 | return func(c *kithttp.Client) { | |
89 | clientBefore(c) | |
90 | clientAfter(c) | |
91 | clientFinalizer(c) | |
92 | } | |
93 | } | |
94 | ||
95 | // HTTPServerTrace enables OpenCensus tracing of a Go kit HTTP transport server. | |
96 | func HTTPServerTrace(options ...TracerOption) kithttp.ServerOption { | |
97 | cfg := TracerOptions{} | |
98 | ||
99 | for _, option := range options { | |
100 | option(&cfg) | |
101 | } | |
102 | ||
103 | if cfg.Sampler == nil { | |
104 | cfg.Sampler = trace.AlwaysSample() | |
105 | } | |
106 | ||
107 | if !cfg.Public && cfg.HTTPPropagate == nil { | |
108 | cfg.HTTPPropagate = &b3.HTTPFormat{} | |
109 | } | |
110 | ||
111 | serverBefore := kithttp.ServerBefore( | |
112 | func(ctx context.Context, req *http.Request) context.Context { | |
113 | var ( | |
114 | spanContext trace.SpanContext | |
115 | span *trace.Span | |
116 | name string | |
117 | ok bool | |
118 | ) | |
119 | ||
120 | if cfg.Name != "" { | |
121 | name = cfg.Name | |
122 | } else { | |
123 | name = req.Method + " " + req.URL.Path | |
124 | } | |
125 | ||
126 | spanContext, ok = cfg.HTTPPropagate.SpanContextFromRequest(req) | |
127 | if ok && !cfg.Public { | |
128 | ctx, span = trace.StartSpanWithRemoteParent( | |
129 | ctx, | |
130 | name, | |
131 | spanContext, | |
132 | trace.WithSpanKind(trace.SpanKindServer), | |
133 | trace.WithSampler(cfg.Sampler), | |
134 | ) | |
135 | } else { | |
136 | ctx, span = trace.StartSpan( | |
137 | ctx, | |
138 | name, | |
139 | trace.WithSpanKind(trace.SpanKindServer), | |
140 | trace.WithSampler(cfg.Sampler), | |
141 | ) | |
142 | if ok { | |
143 | span.AddLink(trace.Link{ | |
144 | TraceID: spanContext.TraceID, | |
145 | SpanID: spanContext.SpanID, | |
146 | Type: trace.LinkTypeChild, | |
147 | Attributes: nil, | |
148 | }) | |
149 | } | |
150 | } | |
151 | ||
152 | span.AddAttributes( | |
153 | trace.StringAttribute(ochttp.MethodAttribute, req.Method), | |
154 | trace.StringAttribute(ochttp.PathAttribute, req.URL.Path), | |
155 | ) | |
156 | ||
157 | return ctx | |
158 | }, | |
159 | ) | |
160 | ||
161 | serverFinalizer := kithttp.ServerFinalizer( | |
162 | func(ctx context.Context, code int, r *http.Request) { | |
163 | if span := trace.FromContext(ctx); span != nil { | |
164 | span.SetStatus(ochttp.TraceStatus(code, http.StatusText(code))) | |
165 | ||
166 | if rs, ok := ctx.Value(kithttp.ContextKeyResponseSize).(int64); ok { | |
167 | span.AddAttributes( | |
168 | trace.Int64Attribute("http.response_size", rs), | |
169 | ) | |
170 | } | |
171 | ||
172 | span.End() | |
173 | } | |
174 | }, | |
175 | ) | |
176 | ||
177 | return func(s *kithttp.Server) { | |
178 | serverBefore(s) | |
179 | serverFinalizer(s) | |
180 | } | |
181 | } |
0 | package opencensus_test | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "errors" | |
5 | "net/http" | |
6 | "net/http/httptest" | |
7 | "net/url" | |
8 | "testing" | |
9 | ||
10 | "go.opencensus.io/plugin/ochttp" | |
11 | "go.opencensus.io/plugin/ochttp/propagation/b3" | |
12 | "go.opencensus.io/plugin/ochttp/propagation/tracecontext" | |
13 | "go.opencensus.io/trace" | |
14 | "go.opencensus.io/trace/propagation" | |
15 | ||
16 | "github.com/go-kit/kit/endpoint" | |
17 | ockit "github.com/go-kit/kit/tracing/opencensus" | |
18 | kithttp "github.com/go-kit/kit/transport/http" | |
19 | ) | |
20 | ||
21 | func TestHTTPClientTrace(t *testing.T) { | |
22 | var ( | |
23 | err error | |
24 | rec = &recordingExporter{} | |
25 | rURL, _ = url.Parse("http://test.com/dummy/path") | |
26 | ) | |
27 | ||
28 | trace.RegisterExporter(rec) | |
29 | ||
30 | traces := []struct { | |
31 | name string | |
32 | err error | |
33 | }{ | |
34 | {"", nil}, | |
35 | {"CustomName", nil}, | |
36 | {"", errors.New("dummy-error")}, | |
37 | } | |
38 | ||
39 | for _, tr := range traces { | |
40 | clientTracer := ockit.HTTPClientTrace(ockit.WithName(tr.name)) | |
41 | ep := kithttp.NewClient( | |
42 | "GET", | |
43 | rURL, | |
44 | func(ctx context.Context, r *http.Request, i interface{}) error { | |
45 | return nil | |
46 | }, | |
47 | func(ctx context.Context, r *http.Response) (response interface{}, err error) { | |
48 | return nil, tr.err | |
49 | }, | |
50 | clientTracer, | |
51 | ).Endpoint() | |
52 | ||
53 | ctx, parentSpan := trace.StartSpan(context.Background(), "test") | |
54 | ||
55 | _, err = ep(ctx, nil) | |
56 | if want, have := tr.err, err; want != have { | |
57 | t.Fatalf("unexpected error, want %s, have %s", tr.err.Error(), err.Error()) | |
58 | } | |
59 | ||
60 | spans := rec.Flush() | |
61 | if want, have := 1, len(spans); want != have { | |
62 | t.Fatalf("incorrect number of spans, want %d, have %d", want, have) | |
63 | } | |
64 | ||
65 | span := spans[0] | |
66 | if want, have := parentSpan.SpanContext().SpanID, span.ParentSpanID; want != have { | |
67 | t.Errorf("incorrect parent ID, want %s, have %s", want, have) | |
68 | } | |
69 | ||
70 | if want, have := tr.name, span.Name; want != have && want != "" { | |
71 | t.Errorf("incorrect span name, want %s, have %s", want, have) | |
72 | } | |
73 | ||
74 | if want, have := "GET /dummy/path", span.Name; want != have && tr.name == "" { | |
75 | t.Errorf("incorrect span name, want %s, have %s", want, have) | |
76 | } | |
77 | ||
78 | code := trace.StatusCodeOK | |
79 | if tr.err != nil { | |
80 | code = trace.StatusCodeUnknown | |
81 | ||
82 | if want, have := err.Error(), span.Status.Message; want != have { | |
83 | t.Errorf("incorrect span status msg, want %s, have %s", want, have) | |
84 | } | |
85 | } | |
86 | ||
87 | if want, have := int32(code), span.Status.Code; want != have { | |
88 | t.Errorf("incorrect span status code, want %d, have %d", want, have) | |
89 | } | |
90 | } | |
91 | } | |
92 | ||
93 | func TestHTTPServerTrace(t *testing.T) { | |
94 | rec := &recordingExporter{} | |
95 | ||
96 | trace.RegisterExporter(rec) | |
97 | ||
98 | traces := []struct { | |
99 | useParent bool | |
100 | name string | |
101 | err error | |
102 | propagation propagation.HTTPFormat | |
103 | }{ | |
104 | {false, "", nil, nil}, | |
105 | {true, "", nil, nil}, | |
106 | {true, "CustomName", nil, &b3.HTTPFormat{}}, | |
107 | {true, "", errors.New("dummy-error"), &tracecontext.HTTPFormat{}}, | |
108 | } | |
109 | ||
110 | for _, tr := range traces { | |
111 | var client http.Client | |
112 | ||
113 | handler := kithttp.NewServer( | |
114 | endpoint.Nop, | |
115 | func(context.Context, *http.Request) (interface{}, error) { return nil, nil }, | |
116 | func(context.Context, http.ResponseWriter, interface{}) error { return errors.New("dummy") }, | |
117 | ockit.HTTPServerTrace( | |
118 | ockit.WithName(tr.name), | |
119 | ockit.WithHTTPPropagation(tr.propagation), | |
120 | ), | |
121 | ) | |
122 | ||
123 | server := httptest.NewServer(handler) | |
124 | defer server.Close() | |
125 | ||
126 | const httpMethod = "GET" | |
127 | ||
128 | req, err := http.NewRequest(httpMethod, server.URL, nil) | |
129 | if err != nil { | |
130 | t.Fatalf("unable to create HTTP request: %s", err.Error()) | |
131 | } | |
132 | ||
133 | if tr.useParent { | |
134 | client = http.Client{ | |
135 | Transport: &ochttp.Transport{ | |
136 | Propagation: tr.propagation, | |
137 | }, | |
138 | } | |
139 | } | |
140 | ||
141 | resp, err := client.Do(req.WithContext(context.Background())) | |
142 | if err != nil { | |
143 | t.Fatalf("unable to send HTTP request: %s", err.Error()) | |
144 | } | |
145 | resp.Body.Close() | |
146 | ||
147 | spans := rec.Flush() | |
148 | ||
149 | expectedSpans := 1 | |
150 | if tr.useParent { | |
151 | expectedSpans++ | |
152 | } | |
153 | ||
154 | if want, have := expectedSpans, len(spans); want != have { | |
155 | t.Fatalf("incorrect number of spans, want %d, have %d", want, have) | |
156 | } | |
157 | ||
158 | if tr.useParent { | |
159 | if want, have := spans[1].TraceID, spans[0].TraceID; want != have { | |
160 | t.Errorf("incorrect trace ID, want %s, have %s", want, have) | |
161 | } | |
162 | ||
163 | if want, have := spans[1].SpanID, spans[0].ParentSpanID; want != have { | |
164 | t.Errorf("incorrect span ID, want %s, have %s", want, have) | |
165 | } | |
166 | } | |
167 | ||
168 | if want, have := tr.name, spans[0].Name; want != have && want != "" { | |
169 | t.Errorf("incorrect span name, want %s, have %s", want, have) | |
170 | } | |
171 | ||
172 | if want, have := "GET /", spans[0].Name; want != have && tr.name == "" { | |
173 | t.Errorf("incorrect span name, want %s, have %s", want, have) | |
174 | } | |
175 | } | |
176 | } |
0 | package opencensus_test | |
1 | ||
2 | import ( | |
3 | "sync" | |
4 | ||
5 | "go.opencensus.io/trace" | |
6 | ) | |
7 | ||
8 | type recordingExporter struct { | |
9 | mu sync.Mutex | |
10 | data []*trace.SpanData | |
11 | } | |
12 | ||
13 | func (e *recordingExporter) ExportSpan(d *trace.SpanData) { | |
14 | e.mu.Lock() | |
15 | defer e.mu.Unlock() | |
16 | ||
17 | e.data = append(e.data, d) | |
18 | } | |
19 | ||
20 | func (e *recordingExporter) Flush() (data []*trace.SpanData) { | |
21 | e.mu.Lock() | |
22 | defer e.mu.Unlock() | |
23 | ||
24 | data = e.data | |
25 | e.data = nil | |
26 | return | |
27 | } |
0 | package opencensus | |
1 | ||
2 | import ( | |
3 | "go.opencensus.io/plugin/ochttp/propagation/b3" | |
4 | "go.opencensus.io/trace" | |
5 | "go.opencensus.io/trace/propagation" | |
6 | ) | |
7 | ||
8 | // defaultHTTPPropagate holds OpenCensus' default HTTP propagation format which | |
9 | // currently is Zipkin's B3. | |
10 | var defaultHTTPPropagate propagation.HTTPFormat = &b3.HTTPFormat{} | |
11 | ||
12 | // TracerOption allows for functional options to our OpenCensus tracing | |
13 | // middleware. | |
14 | type TracerOption func(o *TracerOptions) | |
15 | ||
16 | // WithTracerConfig sets all configuration options at once. | |
17 | func WithTracerConfig(options TracerOptions) TracerOption { | |
18 | return func(o *TracerOptions) { | |
19 | *o = options | |
20 | } | |
21 | } | |
22 | ||
23 | // WithSampler sets the sampler to use by our OpenCensus Tracer. | |
24 | func WithSampler(sampler trace.Sampler) TracerOption { | |
25 | return func(o *TracerOptions) { | |
26 | o.Sampler = sampler | |
27 | } | |
28 | } | |
29 | ||
30 | // WithName sets the name for an instrumented transport endpoint. If name is omitted | |
31 | // at tracing middleware creation, the method of the transport or transport rpc | |
32 | // name is used. | |
33 | func WithName(name string) TracerOption { | |
34 | return func(o *TracerOptions) { | |
35 | o.Name = name | |
36 | } | |
37 | } | |
38 | ||
39 | // IsPublic should be set to true for publicly accessible servers and for | |
40 | // clients that should not propagate their current trace metadata. | |
41 | // On the server side a new trace will always be started regardless of any | |
42 | // trace metadata being found in the incoming request. If any trace metadata | |
43 | // is found, it will be added as a linked trace instead. | |
44 | func IsPublic(isPublic bool) TracerOption { | |
45 | return func(o *TracerOptions) { | |
46 | o.Public = isPublic | |
47 | } | |
48 | } | |
49 | ||
50 | // WithHTTPPropagation sets the propagation handlers for the HTTP transport | |
51 | // middlewares. If used on a non HTTP transport this is a noop. | |
52 | func WithHTTPPropagation(p propagation.HTTPFormat) TracerOption { | |
53 | return func(o *TracerOptions) { | |
54 | if p == nil { | |
55 | // reset to default OC HTTP format | |
56 | o.HTTPPropagate = defaultHTTPPropagate | |
57 | return | |
58 | } | |
59 | o.HTTPPropagate = p | |
60 | } | |
61 | } | |
62 | ||
63 | // TracerOptions holds configuration for our tracing middlewares | |
64 | type TracerOptions struct { | |
65 | Sampler trace.Sampler | |
66 | Name string | |
67 | Public bool | |
68 | HTTPPropagate propagation.HTTPFormat | |
69 | } |
52 | 52 | func (w metadataReaderWriter) Set(key, val string) { |
53 | 53 | key = strings.ToLower(key) |
54 | 54 | if strings.HasSuffix(key, "-bin") { |
55 | val = string(base64.StdEncoding.EncodeToString([]byte(val))) | |
55 | val = base64.StdEncoding.EncodeToString([]byte(val)) | |
56 | 56 | } |
57 | 57 | (*w.MD)[key] = append((*w.MD)[key], val) |
58 | 58 | } |
141 | 141 | |
142 | 142 | rpcMethod, ok := ctx.Value(kitgrpc.ContextKeyRequestMethod).(string) |
143 | 143 | if !ok { |
144 | config.logger.Log("unable to retrieve method name: missing gRPC interceptor hook") | |
144 | config.logger.Log("err", "unable to retrieve method name: missing gRPC interceptor hook") | |
145 | 145 | } else { |
146 | 146 | tags["grpc.method"] = rpcMethod |
147 | 147 | } |
0 | package amqp | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "github.com/streadway/amqp" | |
5 | ) | |
6 | ||
7 | // DecodeRequestFunc extracts a user-domain request object from | |
8 | // an AMQP Delivery object. It is designed to be used in AMQP Subscribers. | |
9 | type DecodeRequestFunc func(context.Context, *amqp.Delivery) (request interface{}, err error) | |
10 | ||
11 | // EncodeRequestFunc encodes the passed request object into | |
12 | // an AMQP Publishing object. It is designed to be used in AMQP Publishers. | |
13 | type EncodeRequestFunc func(context.Context, *amqp.Publishing, interface{}) error | |
14 | ||
15 | // EncodeResponseFunc encodes the passed reponse object to | |
16 | // an AMQP Publishing object. It is designed to be used in AMQP Subscribers. | |
17 | type EncodeResponseFunc func(context.Context, *amqp.Publishing, interface{}) error | |
18 | ||
19 | // DecodeResponseFunc extracts a user-domain response object from | |
20 | // an AMQP Delivery object. It is designed to be used in AMQP Publishers. | |
21 | type DecodeResponseFunc func(context.Context, *amqp.Delivery) (response interface{}, err error) |
0 | package amqp | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "time" | |
5 | ||
6 | "github.com/go-kit/kit/endpoint" | |
7 | "github.com/streadway/amqp" | |
8 | ) | |
9 | ||
10 | // The golang AMQP implementation requires the []byte representation of | |
11 | // correlation id strings to have a maximum length of 255 bytes. | |
12 | const maxCorrelationIdLength = 255 | |
13 | ||
14 | // Publisher wraps an AMQP channel and queue, and provides a method that | |
15 | // implements endpoint.Endpoint. | |
16 | type Publisher struct { | |
17 | ch Channel | |
18 | q *amqp.Queue | |
19 | enc EncodeRequestFunc | |
20 | dec DecodeResponseFunc | |
21 | before []RequestFunc | |
22 | after []PublisherResponseFunc | |
23 | timeout time.Duration | |
24 | } | |
25 | ||
26 | // NewPublisher constructs a usable Publisher for a single remote method. | |
27 | func NewPublisher( | |
28 | ch Channel, | |
29 | q *amqp.Queue, | |
30 | enc EncodeRequestFunc, | |
31 | dec DecodeResponseFunc, | |
32 | options ...PublisherOption, | |
33 | ) *Publisher { | |
34 | p := &Publisher{ | |
35 | ch: ch, | |
36 | q: q, | |
37 | enc: enc, | |
38 | dec: dec, | |
39 | timeout: 10 * time.Second, | |
40 | } | |
41 | for _, option := range options { | |
42 | option(p) | |
43 | } | |
44 | return p | |
45 | } | |
46 | ||
47 | // PublisherOption sets an optional parameter for clients. | |
48 | type PublisherOption func(*Publisher) | |
49 | ||
50 | // PublisherBefore sets the RequestFuncs that are applied to the outgoing AMQP | |
51 | // request before it's invoked. | |
52 | func PublisherBefore(before ...RequestFunc) PublisherOption { | |
53 | return func(p *Publisher) { p.before = append(p.before, before...) } | |
54 | } | |
55 | ||
56 | // PublisherAfter sets the ClientResponseFuncs applied to the incoming AMQP | |
57 | // request prior to it being decoded. This is useful for obtaining anything off | |
58 | // of the response and adding onto the context prior to decoding. | |
59 | func PublisherAfter(after ...PublisherResponseFunc) PublisherOption { | |
60 | return func(p *Publisher) { p.after = append(p.after, after...) } | |
61 | } | |
62 | ||
63 | // PublisherTimeout sets the available timeout for an AMQP request. | |
64 | func PublisherTimeout(timeout time.Duration) PublisherOption { | |
65 | return func(p *Publisher) { p.timeout = timeout } | |
66 | } | |
67 | ||
68 | // Endpoint returns a usable endpoint that invokes the remote endpoint. | |
69 | func (p Publisher) Endpoint() endpoint.Endpoint { | |
70 | return func(ctx context.Context, request interface{}) (interface{}, error) { | |
71 | ctx, cancel := context.WithTimeout(ctx, p.timeout) | |
72 | defer cancel() | |
73 | ||
74 | pub := amqp.Publishing{ | |
75 | ReplyTo: p.q.Name, | |
76 | CorrelationId: randomString(randInt(5, maxCorrelationIdLength)), | |
77 | } | |
78 | ||
79 | if err := p.enc(ctx, &pub, request); err != nil { | |
80 | return nil, err | |
81 | } | |
82 | ||
83 | for _, f := range p.before { | |
84 | ctx = f(ctx, &pub) | |
85 | } | |
86 | ||
87 | deliv, err := p.publishAndConsumeFirstMatchingResponse(ctx, &pub) | |
88 | if err != nil { | |
89 | return nil, err | |
90 | } | |
91 | ||
92 | for _, f := range p.after { | |
93 | ctx = f(ctx, deliv) | |
94 | } | |
95 | response, err := p.dec(ctx, deliv) | |
96 | if err != nil { | |
97 | return nil, err | |
98 | } | |
99 | ||
100 | return response, nil | |
101 | } | |
102 | } | |
103 | ||
104 | // publishAndConsumeFirstMatchingResponse publishes the specified Publishing | |
105 | // and returns the first Delivery object with the matching correlationId. | |
106 | // If the context times out while waiting for a reply, an error will be returned. | |
107 | func (p Publisher) publishAndConsumeFirstMatchingResponse( | |
108 | ctx context.Context, | |
109 | pub *amqp.Publishing, | |
110 | ) (*amqp.Delivery, error) { | |
111 | err := p.ch.Publish( | |
112 | getPublishExchange(ctx), | |
113 | getPublishKey(ctx), | |
114 | false, //mandatory | |
115 | false, //immediate | |
116 | *pub, | |
117 | ) | |
118 | if err != nil { | |
119 | return nil, err | |
120 | } | |
121 | autoAck := getConsumeAutoAck(ctx) | |
122 | ||
123 | msg, err := p.ch.Consume( | |
124 | p.q.Name, | |
125 | "", //consumer | |
126 | autoAck, | |
127 | false, //exclusive | |
128 | false, //noLocal | |
129 | false, //noWait | |
130 | getConsumeArgs(ctx), | |
131 | ) | |
132 | if err != nil { | |
133 | return nil, err | |
134 | } | |
135 | ||
136 | for { | |
137 | select { | |
138 | case d := <-msg: | |
139 | if d.CorrelationId == pub.CorrelationId { | |
140 | if !autoAck { | |
141 | d.Ack(false) //multiple | |
142 | } | |
143 | return &d, nil | |
144 | } | |
145 | ||
146 | case <-ctx.Done(): | |
147 | return nil, ctx.Err() | |
148 | } | |
149 | } | |
150 | ||
151 | } |
0 | package amqp_test | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "encoding/json" | |
5 | "errors" | |
6 | "testing" | |
7 | "time" | |
8 | ||
9 | amqptransport "github.com/go-kit/kit/transport/amqp" | |
10 | "github.com/streadway/amqp" | |
11 | ) | |
12 | ||
13 | var ( | |
14 | defaultContentType = "" | |
15 | defaultContentEncoding = "" | |
16 | ) | |
17 | ||
18 | // TestBadEncode tests if encode errors are handled properly. | |
19 | func TestBadEncode(t *testing.T) { | |
20 | ch := &mockChannel{f: nullFunc} | |
21 | q := &amqp.Queue{Name: "some queue"} | |
22 | pub := amqptransport.NewPublisher( | |
23 | ch, | |
24 | q, | |
25 | func(context.Context, *amqp.Publishing, interface{}) error { return errors.New("err!") }, | |
26 | func(context.Context, *amqp.Delivery) (response interface{}, err error) { return struct{}{}, nil }, | |
27 | ) | |
28 | errChan := make(chan error, 1) | |
29 | var err error | |
30 | go func() { | |
31 | _, err := pub.Endpoint()(context.Background(), struct{}{}) | |
32 | errChan <- err | |
33 | ||
34 | }() | |
35 | select { | |
36 | case err = <-errChan: | |
37 | break | |
38 | ||
39 | case <-time.After(100 * time.Millisecond): | |
40 | t.Fatal("Timed out waiting for result") | |
41 | } | |
42 | if err == nil { | |
43 | t.Error("expected error") | |
44 | } | |
45 | if want, have := "err!", err.Error(); want != have { | |
46 | t.Errorf("want %s, have %s", want, have) | |
47 | } | |
48 | } | |
49 | ||
50 | // TestBadDecode tests if decode errors are handled properly. | |
51 | func TestBadDecode(t *testing.T) { | |
52 | cid := "correlation" | |
53 | ch := &mockChannel{ | |
54 | f: nullFunc, | |
55 | c: make(chan amqp.Publishing, 1), | |
56 | deliveries: []amqp.Delivery{ | |
57 | amqp.Delivery{ | |
58 | CorrelationId: cid, | |
59 | }, | |
60 | }, | |
61 | } | |
62 | q := &amqp.Queue{Name: "some queue"} | |
63 | ||
64 | pub := amqptransport.NewPublisher( | |
65 | ch, | |
66 | q, | |
67 | func(context.Context, *amqp.Publishing, interface{}) error { return nil }, | |
68 | func(context.Context, *amqp.Delivery) (response interface{}, err error) { | |
69 | return struct{}{}, errors.New("err!") | |
70 | }, | |
71 | amqptransport.PublisherBefore( | |
72 | amqptransport.SetCorrelationID(cid), | |
73 | ), | |
74 | ) | |
75 | ||
76 | var err error | |
77 | errChan := make(chan error, 1) | |
78 | go func() { | |
79 | _, err := pub.Endpoint()(context.Background(), struct{}{}) | |
80 | errChan <- err | |
81 | ||
82 | }() | |
83 | ||
84 | select { | |
85 | case err = <-errChan: | |
86 | break | |
87 | ||
88 | case <-time.After(100 * time.Millisecond): | |
89 | t.Fatal("Timed out waiting for result") | |
90 | } | |
91 | ||
92 | if err == nil { | |
93 | t.Error("expected error") | |
94 | } | |
95 | if want, have := "err!", err.Error(); want != have { | |
96 | t.Errorf("want %s, have %s", want, have) | |
97 | } | |
98 | } | |
99 | ||
100 | // TestPublisherTimeout ensures that the publisher timeout mechanism works. | |
101 | func TestPublisherTimeout(t *testing.T) { | |
102 | ch := &mockChannel{ | |
103 | f: nullFunc, | |
104 | c: make(chan amqp.Publishing, 1), | |
105 | deliveries: []amqp.Delivery{}, // no reply from mock subscriber | |
106 | } | |
107 | q := &amqp.Queue{Name: "some queue"} | |
108 | ||
109 | pub := amqptransport.NewPublisher( | |
110 | ch, | |
111 | q, | |
112 | func(context.Context, *amqp.Publishing, interface{}) error { return nil }, | |
113 | func(context.Context, *amqp.Delivery) (response interface{}, err error) { | |
114 | return struct{}{}, nil | |
115 | }, | |
116 | amqptransport.PublisherTimeout(50*time.Millisecond), | |
117 | ) | |
118 | ||
119 | var err error | |
120 | errChan := make(chan error, 1) | |
121 | go func() { | |
122 | _, err := pub.Endpoint()(context.Background(), struct{}{}) | |
123 | errChan <- err | |
124 | ||
125 | }() | |
126 | ||
127 | select { | |
128 | case err = <-errChan: | |
129 | break | |
130 | ||
131 | case <-time.After(100 * time.Millisecond): | |
132 | t.Fatal("timed out waiting for result") | |
133 | } | |
134 | ||
135 | if err == nil { | |
136 | t.Error("expected error") | |
137 | } | |
138 | if want, have := context.DeadlineExceeded.Error(), err.Error(); want != have { | |
139 | t.Errorf("want %s, have %s", want, have) | |
140 | } | |
141 | } | |
142 | ||
143 | func TestSuccessfulPublisher(t *testing.T) { | |
144 | cid := "correlation" | |
145 | mockReq := testReq{437} | |
146 | mockRes := testRes{ | |
147 | Squadron: mockReq.Squadron, | |
148 | Name: names[mockReq.Squadron], | |
149 | } | |
150 | b, err := json.Marshal(mockRes) | |
151 | if err != nil { | |
152 | t.Fatal(err) | |
153 | } | |
154 | reqChan := make(chan amqp.Publishing, 1) | |
155 | ch := &mockChannel{ | |
156 | f: nullFunc, | |
157 | c: reqChan, | |
158 | deliveries: []amqp.Delivery{ | |
159 | amqp.Delivery{ | |
160 | CorrelationId: cid, | |
161 | Body: b, | |
162 | }, | |
163 | }, | |
164 | } | |
165 | q := &amqp.Queue{Name: "some queue"} | |
166 | ||
167 | pub := amqptransport.NewPublisher( | |
168 | ch, | |
169 | q, | |
170 | testReqEncoder, | |
171 | testResDeliveryDecoder, | |
172 | amqptransport.PublisherBefore( | |
173 | amqptransport.SetCorrelationID(cid), | |
174 | ), | |
175 | ) | |
176 | var publishing amqp.Publishing | |
177 | var res testRes | |
178 | var ok bool | |
179 | resChan := make(chan interface{}, 1) | |
180 | errChan := make(chan error, 1) | |
181 | go func() { | |
182 | res, err := pub.Endpoint()(context.Background(), mockReq) | |
183 | if err != nil { | |
184 | errChan <- err | |
185 | } else { | |
186 | resChan <- res | |
187 | } | |
188 | }() | |
189 | ||
190 | select { | |
191 | case publishing = <-reqChan: | |
192 | break | |
193 | ||
194 | case <-time.After(100 * time.Millisecond): | |
195 | t.Fatal("timed out waiting for request") | |
196 | } | |
197 | if want, have := defaultContentType, publishing.ContentType; want != have { | |
198 | t.Errorf("want %s, have %s", want, have) | |
199 | } | |
200 | if want, have := defaultContentEncoding, publishing.ContentEncoding; want != have { | |
201 | t.Errorf("want %s, have %s", want, have) | |
202 | } | |
203 | ||
204 | select { | |
205 | case response := <-resChan: | |
206 | res, ok = response.(testRes) | |
207 | if !ok { | |
208 | t.Error("failed to assert endpoint response type") | |
209 | } | |
210 | break | |
211 | ||
212 | case err = <-errChan: | |
213 | break | |
214 | ||
215 | case <-time.After(100 * time.Millisecond): | |
216 | t.Fatal("timed out waiting for result") | |
217 | } | |
218 | ||
219 | if err != nil { | |
220 | t.Fatal(err) | |
221 | } | |
222 | if want, have := mockRes.Name, res.Name; want != have { | |
223 | t.Errorf("want %s, have %s", want, have) | |
224 | } | |
225 | } |
0 | package amqp | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "time" | |
5 | ||
6 | "github.com/streadway/amqp" | |
7 | ) | |
8 | ||
9 | // RequestFunc may take information from a publisher request and put it into a | |
10 | // request context. In Subscribers, RequestFuncs are executed prior to invoking | |
11 | // the endpoint. | |
12 | type RequestFunc func(context.Context, *amqp.Publishing) context.Context | |
13 | ||
14 | // SubscriberResponseFunc may take information from a request context and use it to | |
15 | // manipulate a Publisher. SubscriberResponseFuncs are only executed in | |
16 | // subscribers, after invoking the endpoint but prior to publishing a reply. | |
17 | type SubscriberResponseFunc func(context.Context, | |
18 | *amqp.Delivery, | |
19 | Channel, | |
20 | *amqp.Publishing, | |
21 | ) context.Context | |
22 | ||
23 | // PublisherResponseFunc may take information from an AMQP request and make the | |
24 | // response available for consumption. PublisherResponseFunc are only executed | |
25 | // in publishers, after a request has been made, but prior to it being decoded. | |
26 | type PublisherResponseFunc func(context.Context, *amqp.Delivery) context.Context | |
27 | ||
28 | // SetPublishExchange returns a RequestFunc that sets the Exchange field | |
29 | // of an AMQP Publish call. | |
30 | func SetPublishExchange(publishExchange string) RequestFunc { | |
31 | return func(ctx context.Context, pub *amqp.Publishing) context.Context { | |
32 | return context.WithValue(ctx, ContextKeyExchange, publishExchange) | |
33 | } | |
34 | } | |
35 | ||
36 | // SetPublishKey returns a RequestFunc that sets the Key field | |
37 | // of an AMQP Publish call. | |
38 | func SetPublishKey(publishKey string) RequestFunc { | |
39 | return func(ctx context.Context, pub *amqp.Publishing) context.Context { | |
40 | return context.WithValue(ctx, ContextKeyPublishKey, publishKey) | |
41 | } | |
42 | } | |
43 | ||
44 | // SetPublishDeliveryMode sets the delivery mode of a Publishing. | |
45 | // Please refer to AMQP delivery mode constants in the AMQP package. | |
46 | func SetPublishDeliveryMode(dmode uint8) RequestFunc { | |
47 | return func(ctx context.Context, pub *amqp.Publishing) context.Context { | |
48 | pub.DeliveryMode = dmode | |
49 | return ctx | |
50 | } | |
51 | } | |
52 | ||
53 | // SetNackSleepDuration returns a RequestFunc that sets the amount of time | |
54 | // to sleep in the event of a Nack. | |
55 | // This has to be used in conjunction with an error encoder that Nack and sleeps. | |
56 | // One example is the SingleNackRequeueErrorEncoder. | |
57 | // It is designed to be used by Subscribers. | |
58 | func SetNackSleepDuration(duration time.Duration) RequestFunc { | |
59 | return func(ctx context.Context, pub *amqp.Publishing) context.Context { | |
60 | return context.WithValue(ctx, ContextKeyNackSleepDuration, duration) | |
61 | } | |
62 | } | |
63 | ||
64 | // SetConsumeAutoAck returns a RequestFunc that sets whether or not to autoAck | |
65 | // messages when consuming. | |
66 | // When set to false, the publisher will Ack the first message it receives with | |
67 | // a matching correlationId. | |
68 | // It is designed to be used by Publishers. | |
69 | func SetConsumeAutoAck(autoAck bool) RequestFunc { | |
70 | return func(ctx context.Context, pub *amqp.Publishing) context.Context { | |
71 | return context.WithValue(ctx, ContextKeyAutoAck, autoAck) | |
72 | } | |
73 | } | |
74 | ||
75 | // SetConsumeArgs returns a RequestFunc that set the arguments for amqp Consume | |
76 | // function. | |
77 | // It is designed to be used by Publishers. | |
78 | func SetConsumeArgs(args amqp.Table) RequestFunc { | |
79 | return func(ctx context.Context, pub *amqp.Publishing) context.Context { | |
80 | return context.WithValue(ctx, ContextKeyConsumeArgs, args) | |
81 | } | |
82 | } | |
83 | ||
84 | // SetContentType returns a RequestFunc that sets the ContentType field of | |
85 | // an AMQP Publishing. | |
86 | func SetContentType(contentType string) RequestFunc { | |
87 | return func(ctx context.Context, pub *amqp.Publishing) context.Context { | |
88 | pub.ContentType = contentType | |
89 | return ctx | |
90 | } | |
91 | } | |
92 | ||
93 | // SetContentEncoding returns a RequestFunc that sets the ContentEncoding field | |
94 | // of an AMQP Publishing. | |
95 | func SetContentEncoding(contentEncoding string) RequestFunc { | |
96 | return func(ctx context.Context, pub *amqp.Publishing) context.Context { | |
97 | pub.ContentEncoding = contentEncoding | |
98 | return ctx | |
99 | } | |
100 | } | |
101 | ||
102 | // SetCorrelationID returns a RequestFunc that sets the CorrelationId field | |
103 | // of an AMQP Publishing. | |
104 | func SetCorrelationID(cid string) RequestFunc { | |
105 | return func(ctx context.Context, pub *amqp.Publishing) context.Context { | |
106 | pub.CorrelationId = cid | |
107 | return ctx | |
108 | } | |
109 | } | |
110 | ||
111 | // SetAckAfterEndpoint returns a SubscriberResponseFunc that prompts the service | |
112 | // to Ack the Delivery object after successfully evaluating the endpoint, | |
113 | // and before it encodes the response. | |
114 | // It is designed to be used by Subscribers. | |
115 | func SetAckAfterEndpoint(multiple bool) SubscriberResponseFunc { | |
116 | return func(ctx context.Context, | |
117 | deliv *amqp.Delivery, | |
118 | ch Channel, | |
119 | pub *amqp.Publishing, | |
120 | ) context.Context { | |
121 | deliv.Ack(multiple) | |
122 | return ctx | |
123 | } | |
124 | } | |
125 | ||
126 | func getPublishExchange(ctx context.Context) string { | |
127 | if exchange := ctx.Value(ContextKeyExchange); exchange != nil { | |
128 | return exchange.(string) | |
129 | } | |
130 | return "" | |
131 | } | |
132 | ||
133 | func getPublishKey(ctx context.Context) string { | |
134 | if publishKey := ctx.Value(ContextKeyPublishKey); publishKey != nil { | |
135 | return publishKey.(string) | |
136 | } | |
137 | return "" | |
138 | } | |
139 | ||
140 | func getNackSleepDuration(ctx context.Context) time.Duration { | |
141 | if duration := ctx.Value(ContextKeyNackSleepDuration); duration != nil { | |
142 | return duration.(time.Duration) | |
143 | } | |
144 | return 0 | |
145 | } | |
146 | ||
147 | func getConsumeAutoAck(ctx context.Context) bool { | |
148 | if autoAck := ctx.Value(ContextKeyAutoAck); autoAck != nil { | |
149 | return autoAck.(bool) | |
150 | } | |
151 | return false | |
152 | } | |
153 | ||
154 | func getConsumeArgs(ctx context.Context) amqp.Table { | |
155 | if args := ctx.Value(ContextKeyConsumeArgs); args != nil { | |
156 | return args.(amqp.Table) | |
157 | } | |
158 | return nil | |
159 | } | |
160 | ||
161 | type contextKey int | |
162 | ||
163 | const ( | |
164 | // ContextKeyExchange is the value of the reply Exchange in | |
165 | // amqp.Publish. | |
166 | ContextKeyExchange contextKey = iota | |
167 | // ContextKeyPublishKey is the value of the ReplyTo field in | |
168 | // amqp.Publish. | |
169 | ContextKeyPublishKey | |
170 | // ContextKeyNackSleepDuration is the duration to sleep for if the | |
171 | // service Nack and requeues a message. | |
172 | // This is to prevent sporadic send-resending of message | |
173 | // when a message is constantly Nack'd and requeued. | |
174 | ContextKeyNackSleepDuration | |
175 | // ContextKeyAutoAck is the value of autoAck field when calling | |
176 | // amqp.Channel.Consume. | |
177 | ContextKeyAutoAck | |
178 | // ContextKeyConsumeArgs is the value of consumeArgs field when calling | |
179 | // amqp.Channel.Consume. | |
180 | ContextKeyConsumeArgs | |
181 | ) |
0 | package amqp | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "encoding/json" | |
5 | "time" | |
6 | ||
7 | "github.com/go-kit/kit/endpoint" | |
8 | "github.com/go-kit/kit/log" | |
9 | "github.com/streadway/amqp" | |
10 | ) | |
11 | ||
12 | // Subscriber wraps an endpoint and provides a handler for AMQP Delivery messages. | |
13 | type Subscriber struct { | |
14 | e endpoint.Endpoint | |
15 | dec DecodeRequestFunc | |
16 | enc EncodeResponseFunc | |
17 | before []RequestFunc | |
18 | after []SubscriberResponseFunc | |
19 | errorEncoder ErrorEncoder | |
20 | logger log.Logger | |
21 | } | |
22 | ||
23 | // NewSubscriber constructs a new subscriber, which provides a handler | |
24 | // for AMQP Delivery messages. | |
25 | func NewSubscriber( | |
26 | e endpoint.Endpoint, | |
27 | dec DecodeRequestFunc, | |
28 | enc EncodeResponseFunc, | |
29 | options ...SubscriberOption, | |
30 | ) *Subscriber { | |
31 | s := &Subscriber{ | |
32 | e: e, | |
33 | dec: dec, | |
34 | enc: enc, | |
35 | errorEncoder: DefaultErrorEncoder, | |
36 | logger: log.NewNopLogger(), | |
37 | } | |
38 | for _, option := range options { | |
39 | option(s) | |
40 | } | |
41 | return s | |
42 | } | |
43 | ||
44 | // SubscriberOption sets an optional parameter for subscribers. | |
45 | type SubscriberOption func(*Subscriber) | |
46 | ||
47 | // SubscriberBefore functions are executed on the publisher delivery object | |
48 | // before the request is decoded. | |
49 | func SubscriberBefore(before ...RequestFunc) SubscriberOption { | |
50 | return func(s *Subscriber) { s.before = append(s.before, before...) } | |
51 | } | |
52 | ||
53 | // SubscriberAfter functions are executed on the subscriber reply after the | |
54 | // endpoint is invoked, but before anything is published to the reply. | |
55 | func SubscriberAfter(after ...SubscriberResponseFunc) SubscriberOption { | |
56 | return func(s *Subscriber) { s.after = append(s.after, after...) } | |
57 | } | |
58 | ||
59 | // SubscriberErrorEncoder is used to encode errors to the subscriber reply | |
60 | // whenever they're encountered in the processing of a request. Clients can | |
61 | // use this to provide custom error formatting. By default, | |
62 | // errors will be published with the DefaultErrorEncoder. | |
63 | func SubscriberErrorEncoder(ee ErrorEncoder) SubscriberOption { | |
64 | return func(s *Subscriber) { s.errorEncoder = ee } | |
65 | } | |
66 | ||
67 | // SubscriberErrorLogger is used to log non-terminal errors. By default, no errors | |
68 | // are logged. This is intended as a diagnostic measure. Finer-grained control | |
69 | // of error handling, including logging in more detail, should be performed in a | |
70 | // custom SubscriberErrorEncoder which has access to the context. | |
71 | func SubscriberErrorLogger(logger log.Logger) SubscriberOption { | |
72 | return func(s *Subscriber) { s.logger = logger } | |
73 | } | |
74 | ||
75 | // ServeDelivery handles AMQP Delivery messages | |
76 | // It is strongly recommended to use *amqp.Channel as the | |
77 | // Channel interface implementation. | |
78 | func (s Subscriber) ServeDelivery(ch Channel) func(deliv *amqp.Delivery) { | |
79 | return func(deliv *amqp.Delivery) { | |
80 | ctx, cancel := context.WithCancel(context.Background()) | |
81 | defer cancel() | |
82 | ||
83 | pub := amqp.Publishing{} | |
84 | ||
85 | for _, f := range s.before { | |
86 | ctx = f(ctx, &pub) | |
87 | } | |
88 | ||
89 | request, err := s.dec(ctx, deliv) | |
90 | if err != nil { | |
91 | s.logger.Log("err", err) | |
92 | s.errorEncoder(ctx, err, deliv, ch, &pub) | |
93 | return | |
94 | } | |
95 | ||
96 | response, err := s.e(ctx, request) | |
97 | if err != nil { | |
98 | s.logger.Log("err", err) | |
99 | s.errorEncoder(ctx, err, deliv, ch, &pub) | |
100 | return | |
101 | } | |
102 | ||
103 | for _, f := range s.after { | |
104 | ctx = f(ctx, deliv, ch, &pub) | |
105 | } | |
106 | ||
107 | if err := s.enc(ctx, &pub, response); err != nil { | |
108 | s.logger.Log("err", err) | |
109 | s.errorEncoder(ctx, err, deliv, ch, &pub) | |
110 | return | |
111 | } | |
112 | ||
113 | if err := s.publishResponse(ctx, deliv, ch, &pub); err != nil { | |
114 | s.logger.Log("err", err) | |
115 | s.errorEncoder(ctx, err, deliv, ch, &pub) | |
116 | return | |
117 | } | |
118 | } | |
119 | ||
120 | } | |
121 | ||
122 | func (s Subscriber) publishResponse( | |
123 | ctx context.Context, | |
124 | deliv *amqp.Delivery, | |
125 | ch Channel, | |
126 | pub *amqp.Publishing, | |
127 | ) error { | |
128 | if pub.CorrelationId == "" { | |
129 | pub.CorrelationId = deliv.CorrelationId | |
130 | } | |
131 | ||
132 | replyExchange := getPublishExchange(ctx) | |
133 | replyTo := getPublishKey(ctx) | |
134 | if replyTo == "" { | |
135 | replyTo = deliv.ReplyTo | |
136 | } | |
137 | ||
138 | return ch.Publish( | |
139 | replyExchange, | |
140 | replyTo, | |
141 | false, // mandatory | |
142 | false, // immediate | |
143 | *pub, | |
144 | ) | |
145 | } | |
146 | ||
147 | // EncodeJSONResponse marshals the response as JSON as part of the | |
148 | // payload of the AMQP Publishing object. | |
149 | func EncodeJSONResponse( | |
150 | ctx context.Context, | |
151 | pub *amqp.Publishing, | |
152 | response interface{}, | |
153 | ) error { | |
154 | b, err := json.Marshal(response) | |
155 | if err != nil { | |
156 | return err | |
157 | } | |
158 | pub.Body = b | |
159 | return nil | |
160 | } | |
161 | ||
162 | // EncodeNopResponse is a response function that does nothing. | |
163 | func EncodeNopResponse( | |
164 | ctx context.Context, | |
165 | pub *amqp.Publishing, | |
166 | response interface{}, | |
167 | ) error { | |
168 | return nil | |
169 | } | |
170 | ||
171 | // ErrorEncoder is responsible for encoding an error to the subscriber reply. | |
172 | // Users are encouraged to use custom ErrorEncoders to encode errors to | |
173 | // their replies, and will likely want to pass and check for their own error | |
174 | // types. | |
175 | type ErrorEncoder func(ctx context.Context, | |
176 | err error, deliv *amqp.Delivery, ch Channel, pub *amqp.Publishing) | |
177 | ||
178 | // DefaultErrorEncoder simply ignores the message. It does not reply | |
179 | // nor Ack/Nack the message. | |
180 | func DefaultErrorEncoder(ctx context.Context, | |
181 | err error, deliv *amqp.Delivery, ch Channel, pub *amqp.Publishing) { | |
182 | } | |
183 | ||
184 | // SingleNackRequeueErrorEncoder issues a Nack to the delivery with multiple flag set as false | |
185 | // and requeue flag set as true. It does not reply the message. | |
186 | func SingleNackRequeueErrorEncoder(ctx context.Context, | |
187 | err error, deliv *amqp.Delivery, ch Channel, pub *amqp.Publishing) { | |
188 | deliv.Nack( | |
189 | false, //multiple | |
190 | true, //requeue | |
191 | ) | |
192 | duration := getNackSleepDuration(ctx) | |
193 | time.Sleep(duration) | |
194 | } | |
195 | ||
196 | // ReplyErrorEncoder serializes the error message as a DefaultErrorResponse | |
197 | // JSON and sends the message to the ReplyTo address. | |
198 | func ReplyErrorEncoder( | |
199 | ctx context.Context, | |
200 | err error, | |
201 | deliv *amqp.Delivery, | |
202 | ch Channel, | |
203 | pub *amqp.Publishing, | |
204 | ) { | |
205 | ||
206 | if pub.CorrelationId == "" { | |
207 | pub.CorrelationId = deliv.CorrelationId | |
208 | } | |
209 | ||
210 | replyExchange := getPublishExchange(ctx) | |
211 | replyTo := getPublishKey(ctx) | |
212 | if replyTo == "" { | |
213 | replyTo = deliv.ReplyTo | |
214 | } | |
215 | ||
216 | response := DefaultErrorResponse{err.Error()} | |
217 | ||
218 | b, err := json.Marshal(response) | |
219 | if err != nil { | |
220 | return | |
221 | } | |
222 | pub.Body = b | |
223 | ||
224 | ch.Publish( | |
225 | replyExchange, | |
226 | replyTo, | |
227 | false, // mandatory | |
228 | false, // immediate | |
229 | *pub, | |
230 | ) | |
231 | } | |
232 | ||
233 | // ReplyAndAckErrorEncoder serializes the error message as a DefaultErrorResponse | |
234 | // JSON and sends the message to the ReplyTo address then Acks the original | |
235 | // message. | |
236 | func ReplyAndAckErrorEncoder(ctx context.Context, err error, deliv *amqp.Delivery, ch Channel, pub *amqp.Publishing) { | |
237 | ReplyErrorEncoder(ctx, err, deliv, ch, pub) | |
238 | deliv.Ack(false) | |
239 | } | |
240 | ||
241 | // DefaultErrorResponse is the default structure of responses in the event | |
242 | // of an error. | |
243 | type DefaultErrorResponse struct { | |
244 | Error string `json:"err"` | |
245 | } | |
246 | ||
247 | // Channel is a channel interface to make testing possible. | |
248 | // It is highly recommended to use *amqp.Channel as the interface implementation. | |
249 | type Channel interface { | |
250 | Publish(exchange, key string, mandatory, immediate bool, msg amqp.Publishing) error | |
251 | Consume(queue, consumer string, autoAck, exclusive, noLocal, noWail bool, args amqp.Table) (<-chan amqp.Delivery, error) | |
252 | } |
0 | package amqp_test | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "encoding/json" | |
5 | "errors" | |
6 | "testing" | |
7 | "time" | |
8 | ||
9 | amqptransport "github.com/go-kit/kit/transport/amqp" | |
10 | "github.com/streadway/amqp" | |
11 | ) | |
12 | ||
13 | var ( | |
14 | typeAssertionError = errors.New("type assertion error") | |
15 | ) | |
16 | ||
17 | // mockChannel is a mock of *amqp.Channel. | |
18 | type mockChannel struct { | |
19 | f func(exchange, key string, mandatory, immediate bool) | |
20 | c chan<- amqp.Publishing | |
21 | deliveries []amqp.Delivery | |
22 | } | |
23 | ||
24 | // Publish runs a test function f and sends resultant message to a channel. | |
25 | func (ch *mockChannel) Publish(exchange, key string, mandatory, immediate bool, msg amqp.Publishing) error { | |
26 | ch.f(exchange, key, mandatory, immediate) | |
27 | ch.c <- msg | |
28 | return nil | |
29 | } | |
30 | ||
31 | var nullFunc = func(exchange, key string, mandatory, immediate bool) { | |
32 | } | |
33 | ||
34 | func (ch *mockChannel) Consume(queue, consumer string, autoAck, exclusive, noLocal, noWail bool, args amqp.Table) (<-chan amqp.Delivery, error) { | |
35 | c := make(chan amqp.Delivery, len(ch.deliveries)) | |
36 | for _, d := range ch.deliveries { | |
37 | c <- d | |
38 | } | |
39 | return c, nil | |
40 | } | |
41 | ||
42 | // TestSubscriberBadDecode checks if decoder errors are handled properly. | |
43 | func TestSubscriberBadDecode(t *testing.T) { | |
44 | sub := amqptransport.NewSubscriber( | |
45 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, | |
46 | func(context.Context, *amqp.Delivery) (interface{}, error) { return nil, errors.New("err!") }, | |
47 | func(context.Context, *amqp.Publishing, interface{}) error { | |
48 | return nil | |
49 | }, | |
50 | amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder), | |
51 | ) | |
52 | ||
53 | outputChan := make(chan amqp.Publishing, 1) | |
54 | ch := &mockChannel{f: nullFunc, c: outputChan} | |
55 | sub.ServeDelivery(ch)(&amqp.Delivery{}) | |
56 | ||
57 | var msg amqp.Publishing | |
58 | select { | |
59 | case msg = <-outputChan: | |
60 | break | |
61 | ||
62 | case <-time.After(100 * time.Millisecond): | |
63 | t.Fatal("Timed out waiting for publishing") | |
64 | } | |
65 | res, err := decodeSubscriberError(msg) | |
66 | if err != nil { | |
67 | t.Fatal(err) | |
68 | } | |
69 | if want, have := "err!", res.Error; want != have { | |
70 | t.Errorf("want %s, have %s", want, have) | |
71 | } | |
72 | } | |
73 | ||
74 | // TestSubscriberBadEndpoint checks if endpoint errors are handled properly. | |
75 | func TestSubscriberBadEndpoint(t *testing.T) { | |
76 | sub := amqptransport.NewSubscriber( | |
77 | func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("err!") }, | |
78 | func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil }, | |
79 | func(context.Context, *amqp.Publishing, interface{}) error { | |
80 | return nil | |
81 | }, | |
82 | amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder), | |
83 | ) | |
84 | ||
85 | outputChan := make(chan amqp.Publishing, 1) | |
86 | ch := &mockChannel{f: nullFunc, c: outputChan} | |
87 | sub.ServeDelivery(ch)(&amqp.Delivery{}) | |
88 | ||
89 | var msg amqp.Publishing | |
90 | ||
91 | select { | |
92 | case msg = <-outputChan: | |
93 | break | |
94 | ||
95 | case <-time.After(100 * time.Millisecond): | |
96 | t.Fatal("Timed out waiting for publishing") | |
97 | } | |
98 | ||
99 | res, err := decodeSubscriberError(msg) | |
100 | if err != nil { | |
101 | t.Fatal(err) | |
102 | } | |
103 | if want, have := "err!", res.Error; want != have { | |
104 | t.Errorf("want %s, have %s", want, have) | |
105 | } | |
106 | } | |
107 | ||
108 | // TestSubscriberBadEncoder checks if encoder errors are handled properly. | |
109 | func TestSubscriberBadEncoder(t *testing.T) { | |
110 | sub := amqptransport.NewSubscriber( | |
111 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, | |
112 | func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil }, | |
113 | func(context.Context, *amqp.Publishing, interface{}) error { | |
114 | return errors.New("err!") | |
115 | }, | |
116 | amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder), | |
117 | ) | |
118 | ||
119 | outputChan := make(chan amqp.Publishing, 1) | |
120 | ch := &mockChannel{f: nullFunc, c: outputChan} | |
121 | sub.ServeDelivery(ch)(&amqp.Delivery{}) | |
122 | ||
123 | var msg amqp.Publishing | |
124 | ||
125 | select { | |
126 | case msg = <-outputChan: | |
127 | break | |
128 | ||
129 | case <-time.After(100 * time.Millisecond): | |
130 | t.Fatal("Timed out waiting for publishing") | |
131 | } | |
132 | ||
133 | res, err := decodeSubscriberError(msg) | |
134 | if err != nil { | |
135 | t.Fatal(err) | |
136 | } | |
137 | if want, have := "err!", res.Error; want != have { | |
138 | t.Errorf("want %s, have %s", want, have) | |
139 | } | |
140 | } | |
141 | ||
142 | // TestSubscriberSuccess checks if CorrelationId and ReplyTo are set properly | |
143 | // and if the payload is encoded properly. | |
144 | func TestSubscriberSuccess(t *testing.T) { | |
145 | cid := "correlation" | |
146 | replyTo := "sender" | |
147 | obj := testReq{ | |
148 | Squadron: 436, | |
149 | } | |
150 | b, err := json.Marshal(obj) | |
151 | if err != nil { | |
152 | t.Fatal(err) | |
153 | } | |
154 | ||
155 | sub := amqptransport.NewSubscriber( | |
156 | testEndpoint, | |
157 | testReqDecoder, | |
158 | amqptransport.EncodeJSONResponse, | |
159 | amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder), | |
160 | ) | |
161 | ||
162 | checkReplyToFunc := func(exchange, key string, mandatory, immediate bool) { | |
163 | if want, have := replyTo, key; want != have { | |
164 | t.Errorf("want %s, have %s", want, have) | |
165 | } | |
166 | } | |
167 | ||
168 | outputChan := make(chan amqp.Publishing, 1) | |
169 | ch := &mockChannel{f: checkReplyToFunc, c: outputChan} | |
170 | sub.ServeDelivery(ch)(&amqp.Delivery{ | |
171 | CorrelationId: cid, | |
172 | ReplyTo: replyTo, | |
173 | Body: b, | |
174 | }) | |
175 | ||
176 | var msg amqp.Publishing | |
177 | ||
178 | select { | |
179 | case msg = <-outputChan: | |
180 | break | |
181 | ||
182 | case <-time.After(100 * time.Millisecond): | |
183 | t.Fatal("Timed out waiting for publishing") | |
184 | } | |
185 | ||
186 | if want, have := cid, msg.CorrelationId; want != have { | |
187 | t.Errorf("want %s, have %s", want, have) | |
188 | } | |
189 | ||
190 | // check if error is not thrown | |
191 | errRes, err := decodeSubscriberError(msg) | |
192 | if err != nil { | |
193 | t.Fatal(err) | |
194 | } | |
195 | if errRes.Error != "" { | |
196 | t.Error("Received error from subscriber", errRes.Error) | |
197 | return | |
198 | } | |
199 | ||
200 | // check obj vals | |
201 | response, err := testResDecoder(msg.Body) | |
202 | if err != nil { | |
203 | t.Fatal(err) | |
204 | } | |
205 | res, ok := response.(testRes) | |
206 | if !ok { | |
207 | t.Error(typeAssertionError) | |
208 | } | |
209 | ||
210 | if want, have := obj.Squadron, res.Squadron; want != have { | |
211 | t.Errorf("want %d, have %d", want, have) | |
212 | } | |
213 | if want, have := names[obj.Squadron], res.Name; want != have { | |
214 | t.Errorf("want %s, have %s", want, have) | |
215 | } | |
216 | } | |
217 | ||
218 | // TestSubscriberMultipleBefore checks if options to set exchange, key, deliveryMode | |
219 | // are working. | |
220 | func TestSubscriberMultipleBefore(t *testing.T) { | |
221 | exchange := "some exchange" | |
222 | key := "some key" | |
223 | deliveryMode := uint8(127) | |
224 | contentType := "some content type" | |
225 | contentEncoding := "some content encoding" | |
226 | sub := amqptransport.NewSubscriber( | |
227 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, | |
228 | func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil }, | |
229 | amqptransport.EncodeJSONResponse, | |
230 | amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder), | |
231 | amqptransport.SubscriberBefore( | |
232 | amqptransport.SetPublishExchange(exchange), | |
233 | amqptransport.SetPublishKey(key), | |
234 | amqptransport.SetPublishDeliveryMode(deliveryMode), | |
235 | amqptransport.SetContentType(contentType), | |
236 | amqptransport.SetContentEncoding(contentEncoding), | |
237 | ), | |
238 | ) | |
239 | checkReplyToFunc := func(exch, k string, mandatory, immediate bool) { | |
240 | if want, have := exchange, exch; want != have { | |
241 | t.Errorf("want %s, have %s", want, have) | |
242 | } | |
243 | if want, have := key, k; want != have { | |
244 | t.Errorf("want %s, have %s", want, have) | |
245 | } | |
246 | } | |
247 | ||
248 | outputChan := make(chan amqp.Publishing, 1) | |
249 | ch := &mockChannel{f: checkReplyToFunc, c: outputChan} | |
250 | sub.ServeDelivery(ch)(&amqp.Delivery{}) | |
251 | ||
252 | var msg amqp.Publishing | |
253 | ||
254 | select { | |
255 | case msg = <-outputChan: | |
256 | break | |
257 | ||
258 | case <-time.After(100 * time.Millisecond): | |
259 | t.Fatal("Timed out waiting for publishing") | |
260 | } | |
261 | ||
262 | // check if error is not thrown | |
263 | errRes, err := decodeSubscriberError(msg) | |
264 | if err != nil { | |
265 | t.Fatal(err) | |
266 | } | |
267 | if errRes.Error != "" { | |
268 | t.Error("Received error from subscriber", errRes.Error) | |
269 | return | |
270 | } | |
271 | ||
272 | if want, have := contentType, msg.ContentType; want != have { | |
273 | t.Errorf("want %s, have %s", want, have) | |
274 | } | |
275 | ||
276 | if want, have := contentEncoding, msg.ContentEncoding; want != have { | |
277 | t.Errorf("want %s, have %s", want, have) | |
278 | } | |
279 | ||
280 | if want, have := deliveryMode, msg.DeliveryMode; want != have { | |
281 | t.Errorf("want %d, have %d", want, have) | |
282 | } | |
283 | } | |
284 | ||
285 | // TestDefaultContentMetaData checks that default ContentType and Content-Encoding | |
286 | // is not set as mentioned by AMQP specification. | |
287 | func TestDefaultContentMetaData(t *testing.T) { | |
288 | defaultContentType := "" | |
289 | defaultContentEncoding := "" | |
290 | sub := amqptransport.NewSubscriber( | |
291 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, | |
292 | func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil }, | |
293 | amqptransport.EncodeJSONResponse, | |
294 | amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder), | |
295 | ) | |
296 | checkReplyToFunc := func(exch, k string, mandatory, immediate bool) { return } | |
297 | outputChan := make(chan amqp.Publishing, 1) | |
298 | ch := &mockChannel{f: checkReplyToFunc, c: outputChan} | |
299 | sub.ServeDelivery(ch)(&amqp.Delivery{}) | |
300 | ||
301 | var msg amqp.Publishing | |
302 | ||
303 | select { | |
304 | case msg = <-outputChan: | |
305 | break | |
306 | ||
307 | case <-time.After(100 * time.Millisecond): | |
308 | t.Fatal("Timed out waiting for publishing") | |
309 | } | |
310 | ||
311 | // check if error is not thrown | |
312 | errRes, err := decodeSubscriberError(msg) | |
313 | if err != nil { | |
314 | t.Fatal(err) | |
315 | } | |
316 | if errRes.Error != "" { | |
317 | t.Error("Received error from subscriber", errRes.Error) | |
318 | return | |
319 | } | |
320 | ||
321 | if want, have := defaultContentType, msg.ContentType; want != have { | |
322 | t.Errorf("want %s, have %s", want, have) | |
323 | } | |
324 | if want, have := defaultContentEncoding, msg.ContentEncoding; want != have { | |
325 | t.Errorf("want %s, have %s", want, have) | |
326 | } | |
327 | } | |
328 | ||
329 | func decodeSubscriberError(pub amqp.Publishing) (amqptransport.DefaultErrorResponse, error) { | |
330 | var res amqptransport.DefaultErrorResponse | |
331 | err := json.Unmarshal(pub.Body, &res) | |
332 | return res, err | |
333 | } | |
334 | ||
335 | type testReq struct { | |
336 | Squadron int `json:"s"` | |
337 | } | |
338 | type testRes struct { | |
339 | Squadron int `json:"s"` | |
340 | Name string `json:"n"` | |
341 | } | |
342 | ||
343 | func testEndpoint(_ context.Context, request interface{}) (interface{}, error) { | |
344 | req, ok := request.(testReq) | |
345 | if !ok { | |
346 | return nil, typeAssertionError | |
347 | } | |
348 | name, prs := names[req.Squadron] | |
349 | if !prs { | |
350 | return nil, errors.New("unknown squadron name") | |
351 | } | |
352 | res := testRes{ | |
353 | Squadron: req.Squadron, | |
354 | Name: name, | |
355 | } | |
356 | return res, nil | |
357 | } | |
358 | ||
359 | func testReqDecoder(_ context.Context, d *amqp.Delivery) (interface{}, error) { | |
360 | var obj testReq | |
361 | err := json.Unmarshal(d.Body, &obj) | |
362 | return obj, err | |
363 | } | |
364 | ||
365 | func testReqEncoder(_ context.Context, p *amqp.Publishing, request interface{}) error { | |
366 | req, ok := request.(testReq) | |
367 | if !ok { | |
368 | return errors.New("type assertion failure") | |
369 | } | |
370 | b, err := json.Marshal(req) | |
371 | if err != nil { | |
372 | return err | |
373 | } | |
374 | p.Body = b | |
375 | return nil | |
376 | } | |
377 | ||
378 | func testResDeliveryDecoder(_ context.Context, d *amqp.Delivery) (interface{}, error) { | |
379 | return testResDecoder(d.Body) | |
380 | } | |
381 | ||
382 | func testResDecoder(b []byte) (interface{}, error) { | |
383 | var obj testRes | |
384 | err := json.Unmarshal(b, &obj) | |
385 | return obj, err | |
386 | } | |
387 | ||
388 | var names = map[int]string{ | |
389 | 424: "tiger", | |
390 | 426: "thunderbird", | |
391 | 429: "bison", | |
392 | 436: "tusker", | |
393 | 437: "husky", | |
394 | } |
0 | package amqp | |
1 | ||
2 | import ( | |
3 | "math/rand" | |
4 | ) | |
5 | ||
6 | func randomString(l int) string { | |
7 | bytes := make([]byte, l) | |
8 | for i := 0; i < l; i++ { | |
9 | bytes[i] = byte(randInt(65, 90)) | |
10 | } | |
11 | return string(bytes) | |
12 | } | |
13 | ||
14 | func randInt(min int, max int) int { | |
15 | return min + rand.Intn(max-min) | |
16 | } |
68 | 68 | func EncodeKeyValue(key, val string) (string, string) { |
69 | 69 | key = strings.ToLower(key) |
70 | 70 | if strings.HasSuffix(key, binHdrSuffix) { |
71 | v := base64.StdEncoding.EncodeToString([]byte(val)) | |
72 | val = string(v) | |
71 | val = base64.StdEncoding.EncodeToString([]byte(val)) | |
73 | 72 | } |
74 | 73 | return key, val |
75 | 74 | } |
54 | 54 | // ServerOption sets an optional parameter for servers. |
55 | 55 | type ServerOption func(*Server) |
56 | 56 | |
57 | // ServerBefore functions are executed on the HTTP request object before the | |
57 | // ServerBefore functions are executed on the gRPC request object before the | |
58 | 58 | // request is decoded. |
59 | 59 | func ServerBefore(before ...ServerRequestFunc) ServerOption { |
60 | 60 | return func(s *Server) { s.before = append(s.before, before...) } |
61 | 61 | } |
62 | 62 | |
63 | // ServerAfter functions are executed on the HTTP response writer after the | |
63 | // ServerAfter functions are executed on the gRPC response writer after the | |
64 | 64 | // endpoint is invoked, but before anything is written to the client. |
65 | 65 | func ServerAfter(after ...ServerResponseFunc) ServerOption { |
66 | 66 | return func(s *Server) { s.after = append(s.after, after...) } |
4 | 4 | "context" |
5 | 5 | "encoding/json" |
6 | 6 | "encoding/xml" |
7 | "io" | |
7 | 8 | "io/ioutil" |
8 | 9 | "net/http" |
9 | 10 | "net/url" |
11 | 12 | "github.com/go-kit/kit/endpoint" |
12 | 13 | ) |
13 | 14 | |
15 | // HTTPClient is an interface that models *http.Client. | |
16 | type HTTPClient interface { | |
17 | Do(req *http.Request) (*http.Response, error) | |
18 | } | |
19 | ||
14 | 20 | // Client wraps a URL and provides a method that implements endpoint.Endpoint. |
15 | 21 | type Client struct { |
16 | client *http.Client | |
22 | client HTTPClient | |
17 | 23 | method string |
18 | 24 | tgt *url.URL |
19 | 25 | enc EncodeRequestFunc |
53 | 59 | |
54 | 60 | // SetClient sets the underlying HTTP client used for requests. |
55 | 61 | // By default, http.DefaultClient is used. |
56 | func SetClient(client *http.Client) ClientOption { | |
62 | func SetClient(client HTTPClient) ClientOption { | |
57 | 63 | return func(c *Client) { c.client = client } |
58 | 64 | } |
59 | 65 | |
78 | 84 | |
79 | 85 | // BufferedStream sets whether the Response.Body is left open, allowing it |
80 | 86 | // to be read from later. Useful for transporting a file as a buffered stream. |
87 | // That body has to be Closed to propery end the request. | |
81 | 88 | func BufferedStream(buffered bool) ClientOption { |
82 | 89 | return func(c *Client) { c.bufferedStream = buffered } |
83 | 90 | } |
86 | 93 | func (c Client) Endpoint() endpoint.Endpoint { |
87 | 94 | return func(ctx context.Context, request interface{}) (interface{}, error) { |
88 | 95 | ctx, cancel := context.WithCancel(ctx) |
89 | defer cancel() | |
90 | 96 | |
91 | 97 | var ( |
92 | 98 | resp *http.Response |
106 | 112 | |
107 | 113 | req, err := http.NewRequest(c.method, c.tgt.String(), nil) |
108 | 114 | if err != nil { |
115 | cancel() | |
109 | 116 | return nil, err |
110 | 117 | } |
111 | 118 | |
112 | 119 | if err = c.enc(ctx, req, request); err != nil { |
120 | cancel() | |
113 | 121 | return nil, err |
114 | 122 | } |
115 | 123 | |
120 | 128 | resp, err = c.client.Do(req.WithContext(ctx)) |
121 | 129 | |
122 | 130 | if err != nil { |
123 | return nil, err | |
124 | } | |
125 | ||
126 | if !c.bufferedStream { | |
131 | cancel() | |
132 | return nil, err | |
133 | } | |
134 | ||
135 | // If we expect a buffered stream, we don't cancel the context when the endpoint returns. | |
136 | // Instead, we should call the cancel func when closing the response body. | |
137 | if c.bufferedStream { | |
138 | resp.Body = bodyWithCancel{ReadCloser: resp.Body, cancel: cancel} | |
139 | } else { | |
127 | 140 | defer resp.Body.Close() |
141 | defer cancel() | |
128 | 142 | } |
129 | 143 | |
130 | 144 | for _, f := range c.after { |
138 | 152 | |
139 | 153 | return response, nil |
140 | 154 | } |
155 | } | |
156 | ||
157 | // bodyWithCancel is a wrapper for an io.ReadCloser with also a | |
158 | // cancel function which is called when the Close is used | |
159 | type bodyWithCancel struct { | |
160 | io.ReadCloser | |
161 | ||
162 | cancel context.CancelFunc | |
163 | } | |
164 | ||
165 | func (bwc bodyWithCancel) Close() error { | |
166 | bwc.ReadCloser.Close() | |
167 | bwc.cancel() | |
168 | return nil | |
141 | 169 | } |
142 | 170 | |
143 | 171 | // ClientFinalizerFunc can be used to perform work at the end of a client HTTP |
0 | 0 | package http_test |
1 | 1 | |
2 | 2 | import ( |
3 | "bytes" | |
3 | 4 | "context" |
4 | 5 | "io" |
5 | 6 | "io/ioutil" |
96 | 97 | } |
97 | 98 | |
98 | 99 | func TestHTTPClientBufferedStream(t *testing.T) { |
100 | // bodysize has a size big enought to make the resopnse.Body not an instant read | |
101 | // so if the response is cancelled it wount be all readed and the test would fail | |
102 | // The 6000 has not a particular meaning, it big enough to fulfill the usecase. | |
103 | const bodysize = 6000 | |
99 | 104 | var ( |
100 | testbody = "testbody" | |
105 | testbody = string(make([]byte, bodysize)) | |
101 | 106 | encode = func(context.Context, *http.Request, interface{}) error { return nil } |
102 | 107 | decode = func(_ context.Context, r *http.Response) (interface{}, error) { |
103 | 108 | return TestResponse{r.Body, ""}, nil |
127 | 132 | if !ok { |
128 | 133 | t.Fatal("response should be TestResponse") |
129 | 134 | } |
135 | defer response.Body.Close() | |
136 | // Faking work | |
137 | time.Sleep(time.Second * 1) | |
130 | 138 | |
131 | 139 | // Check that response body was NOT closed |
132 | 140 | b := make([]byte, len(testbody)) |
251 | 259 | } |
252 | 260 | } |
253 | 261 | |
262 | func TestSetClient(t *testing.T) { | |
263 | var ( | |
264 | encode = func(context.Context, *http.Request, interface{}) error { return nil } | |
265 | decode = func(_ context.Context, r *http.Response) (interface{}, error) { | |
266 | t, err := ioutil.ReadAll(r.Body) | |
267 | if err != nil { | |
268 | return nil, err | |
269 | } | |
270 | return string(t), nil | |
271 | } | |
272 | ) | |
273 | ||
274 | testHttpClient := httpClientFunc(func(req *http.Request) (*http.Response, error) { | |
275 | return &http.Response{ | |
276 | StatusCode: http.StatusOK, | |
277 | Request: req, | |
278 | Body: ioutil.NopCloser(bytes.NewBufferString("hello, world!")), | |
279 | }, nil | |
280 | }) | |
281 | ||
282 | client := httptransport.NewClient( | |
283 | "GET", | |
284 | &url.URL{}, | |
285 | encode, | |
286 | decode, | |
287 | httptransport.SetClient(testHttpClient), | |
288 | ).Endpoint() | |
289 | ||
290 | resp, err := client(context.Background(), nil) | |
291 | if err != nil { | |
292 | t.Fatal(err) | |
293 | } | |
294 | if r, ok := resp.(string); !ok || r != "hello, world!" { | |
295 | t.Fatal("Expected response to be 'hello, world!' string") | |
296 | } | |
297 | } | |
298 | ||
254 | 299 | func mustParse(s string) *url.URL { |
255 | 300 | u, err := url.Parse(s) |
256 | 301 | if err != nil { |
264 | 309 | } |
265 | 310 | |
266 | 311 | func (e enhancedRequest) Headers() http.Header { return http.Header{"X-Edward": []string{"Snowden"}} } |
312 | ||
313 | type httpClientFunc func(req *http.Request) (*http.Response, error) | |
314 | ||
315 | func (f httpClientFunc) Do(req *http.Request) (*http.Response, error) { | |
316 | return f(req) | |
317 | } |
7 | 7 | // DecodeRequestFunc extracts a user-domain request object from an HTTP |
8 | 8 | // request object. It's designed to be used in HTTP servers, for server-side |
9 | 9 | // endpoints. One straightforward DecodeRequestFunc could be something that |
10 | // JSON decodes from the request body to the concrete response type. | |
10 | // JSON decodes from the request body to the concrete request type. | |
11 | 11 | type DecodeRequestFunc func(context.Context, *http.Request) (request interface{}, err error) |
12 | 12 | |
13 | 13 | // EncodeRequestFunc encodes the passed request object into the HTTP request |
14 | 14 | // object. It's designed to be used in HTTP clients, for client-side |
15 | // endpoints. One straightforward EncodeRequestFunc could something that JSON | |
15 | // endpoints. One straightforward EncodeRequestFunc could be something that JSON | |
16 | 16 | // encodes the object directly to the request body. |
17 | 17 | type EncodeRequestFunc func(context.Context, *http.Request, interface{}) error |
18 | 18 |
14 | 14 | |
15 | 15 | // Client wraps a JSON RPC method and provides a method that implements endpoint.Endpoint. |
16 | 16 | type Client struct { |
17 | client *http.Client | |
17 | client httptransport.HTTPClient | |
18 | 18 | |
19 | 19 | // JSON RPC endpoint URL |
20 | 20 | tgt *url.URL |
85 | 85 | |
86 | 86 | // SetClient sets the underlying HTTP client used for requests. |
87 | 87 | // By default, http.DefaultClient is used. |
88 | func SetClient(client *http.Client) ClientOption { | |
88 | func SetClient(client httptransport.HTTPClient) ClientOption { | |
89 | 89 | return func(c *Client) { c.client = client } |
90 | 90 | } |
91 | 91 |
34 | 34 | return nil |
35 | 35 | } |
36 | 36 | |
37 | func (id *RequestID) MarshalJSON() ([]byte, error) { | |
38 | if id.intError == nil { | |
39 | return json.Marshal(id.intValue) | |
40 | } else if id.floatError == nil { | |
41 | return json.Marshal(id.floatValue) | |
42 | } else { | |
43 | return json.Marshal(id.stringValue) | |
44 | } | |
45 | } | |
46 | ||
37 | 47 | // Int returns the ID as an integer value. |
38 | 48 | // An error is returned if the ID can't be treated as an int. |
39 | 49 | func (id *RequestID) Int() (int, error) { |
57 | 67 | type Response struct { |
58 | 68 | JSONRPC string `json:"jsonrpc"` |
59 | 69 | Result json.RawMessage `json:"result,omitempty"` |
60 | Error *Error `json:"error,omitemty"` | |
70 | Error *Error `json:"error,omitempty"` | |
71 | ID *RequestID `json:"id"` | |
61 | 72 | } |
62 | 73 | |
63 | 74 | const ( |
108 | 108 | t.Fatalf("Expected ID to be nil, got %+v.\n", r.ID) |
109 | 109 | } |
110 | 110 | } |
111 | ||
112 | func TestCanMarshalID(t *testing.T) { | |
113 | cases := []struct { | |
114 | JSON string | |
115 | expType string | |
116 | expValue interface{} | |
117 | }{ | |
118 | {`12345`, "int", 12345}, | |
119 | {`12345.6`, "float", 12345.6}, | |
120 | {`"stringaling"`, "string", "stringaling"}, | |
121 | {`null`, "null", nil}, | |
122 | } | |
123 | ||
124 | for _, c := range cases { | |
125 | req := jsonrpc.Request{} | |
126 | JSON := fmt.Sprintf(`{"jsonrpc":"2.0","id":%s}`, c.JSON) | |
127 | json.Unmarshal([]byte(JSON), &req) | |
128 | resp := jsonrpc.Response{ID: req.ID, JSONRPC: req.JSONRPC} | |
129 | ||
130 | want := JSON | |
131 | bol, _ := json.Marshal(resp) | |
132 | got := string(bol) | |
133 | if got != want { | |
134 | t.Fatalf("'%s': want %s, got %s.", c.expType, want, got) | |
135 | } | |
136 | } | |
137 | } |
0 | package nats | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | ||
5 | "github.com/nats-io/go-nats" | |
6 | ) | |
7 | ||
8 | // DecodeRequestFunc extracts a user-domain request object from a publisher | |
9 | // request object. It's designed to be used in NATS subscribers, for subscriber-side | |
10 | // endpoints. One straightforward DecodeRequestFunc could be something that | |
11 | // JSON decodes from the request body to the concrete response type. | |
12 | type DecodeRequestFunc func(context.Context, *nats.Msg) (request interface{}, err error) | |
13 | ||
14 | // EncodeRequestFunc encodes the passed request object into the NATS request | |
15 | // object. It's designed to be used in NATS publishers, for publisher-side | |
16 | // endpoints. One straightforward EncodeRequestFunc could something that JSON | |
17 | // encodes the object directly to the request payload. | |
18 | type EncodeRequestFunc func(context.Context, *nats.Msg, interface{}) error | |
19 | ||
20 | // EncodeResponseFunc encodes the passed response object to the subscriber reply. | |
21 | // It's designed to be used in NATS subscribers, for subscriber-side | |
22 | // endpoints. One straightforward EncodeResponseFunc could be something that | |
23 | // JSON encodes the object directly to the response body. | |
24 | type EncodeResponseFunc func(context.Context, string, *nats.Conn, interface{}) error | |
25 | ||
26 | // DecodeResponseFunc extracts a user-domain response object from an NATS | |
27 | // response object. It's designed to be used in NATS publisher, for publisher-side | |
28 | // endpoints. One straightforward DecodeResponseFunc could be something that | |
29 | // JSON decodes from the response payload to the concrete response type. | |
30 | type DecodeResponseFunc func(context.Context, *nats.Msg) (response interface{}, err error) | |
31 |
0 | package nats | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "encoding/json" | |
5 | "github.com/go-kit/kit/endpoint" | |
6 | "github.com/nats-io/go-nats" | |
7 | "time" | |
8 | ) | |
9 | ||
10 | // Publisher wraps a URL and provides a method that implements endpoint.Endpoint. | |
11 | type Publisher struct { | |
12 | publisher *nats.Conn | |
13 | subject string | |
14 | enc EncodeRequestFunc | |
15 | dec DecodeResponseFunc | |
16 | before []RequestFunc | |
17 | after []PublisherResponseFunc | |
18 | timeout time.Duration | |
19 | } | |
20 | ||
21 | // NewPublisher constructs a usable Publisher for a single remote method. | |
22 | func NewPublisher( | |
23 | publisher *nats.Conn, | |
24 | subject string, | |
25 | enc EncodeRequestFunc, | |
26 | dec DecodeResponseFunc, | |
27 | options ...PublisherOption, | |
28 | ) *Publisher { | |
29 | p := &Publisher{ | |
30 | publisher: publisher, | |
31 | subject: subject, | |
32 | enc: enc, | |
33 | dec: dec, | |
34 | timeout: 10 * time.Second, | |
35 | } | |
36 | for _, option := range options { | |
37 | option(p) | |
38 | } | |
39 | return p | |
40 | } | |
41 | ||
42 | // PublisherOption sets an optional parameter for clients. | |
43 | type PublisherOption func(*Publisher) | |
44 | ||
45 | // PublisherBefore sets the RequestFuncs that are applied to the outgoing NATS | |
46 | // request before it's invoked. | |
47 | func PublisherBefore(before ...RequestFunc) PublisherOption { | |
48 | return func(p *Publisher) { p.before = append(p.before, before...) } | |
49 | } | |
50 | ||
51 | // PublisherAfter sets the ClientResponseFuncs applied to the incoming NATS | |
52 | // request prior to it being decoded. This is useful for obtaining anything off | |
53 | // of the response and adding onto the context prior to decoding. | |
54 | func PublisherAfter(after ...PublisherResponseFunc) PublisherOption { | |
55 | return func(p *Publisher) { p.after = append(p.after, after...) } | |
56 | } | |
57 | ||
58 | // PublisherTimeout sets the available timeout for NATS request. | |
59 | func PublisherTimeout(timeout time.Duration) PublisherOption { | |
60 | return func(p *Publisher) { p.timeout = timeout } | |
61 | } | |
62 | ||
63 | // Endpoint returns a usable endpoint that invokes the remote endpoint. | |
64 | func (p Publisher) Endpoint() endpoint.Endpoint { | |
65 | return func(ctx context.Context, request interface{}) (interface{}, error) { | |
66 | ctx, cancel := context.WithTimeout(context.Background(), p.timeout) | |
67 | defer cancel() | |
68 | ||
69 | msg := nats.Msg{Subject: p.subject} | |
70 | ||
71 | if err := p.enc(ctx, &msg, request); err != nil { | |
72 | return nil, err | |
73 | } | |
74 | ||
75 | for _, f := range p.before { | |
76 | ctx = f(ctx, &msg) | |
77 | } | |
78 | ||
79 | resp, err := p.publisher.RequestWithContext(ctx, msg.Subject, msg.Data) | |
80 | if err != nil { | |
81 | return nil, err | |
82 | } | |
83 | ||
84 | for _, f := range p.after { | |
85 | ctx = f(ctx, resp) | |
86 | } | |
87 | ||
88 | response, err := p.dec(ctx, resp) | |
89 | if err != nil { | |
90 | return nil, err | |
91 | } | |
92 | ||
93 | return response, nil | |
94 | } | |
95 | } | |
96 | ||
97 | // EncodeJSONRequest is an EncodeRequestFunc that serializes the request as a | |
98 | // JSON object to the Data of the Msg. Many JSON-over-NATS services can use it as | |
99 | // a sensible default. | |
100 | func EncodeJSONRequest(_ context.Context, msg *nats.Msg, request interface{}) error { | |
101 | b, err := json.Marshal(request) | |
102 | if err != nil { | |
103 | return err | |
104 | } | |
105 | ||
106 | msg.Data = b | |
107 | ||
108 | return nil | |
109 | } |
0 | package nats_test | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "strings" | |
5 | "testing" | |
6 | "time" | |
7 | ||
8 | natstransport "github.com/go-kit/kit/transport/nats" | |
9 | "github.com/nats-io/go-nats" | |
10 | ) | |
11 | ||
12 | func TestPublisher(t *testing.T) { | |
13 | var ( | |
14 | testdata = "testdata" | |
15 | encode = func(context.Context, *nats.Msg, interface{}) error { return nil } | |
16 | decode = func(_ context.Context, msg *nats.Msg) (interface{}, error) { | |
17 | return TestResponse{string(msg.Data), ""}, nil | |
18 | } | |
19 | ) | |
20 | ||
21 | nc := newNatsConn(t) | |
22 | defer nc.Close() | |
23 | ||
24 | sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) { | |
25 | if err := nc.Publish(msg.Reply, []byte(testdata)); err != nil { | |
26 | t.Fatal(err) | |
27 | } | |
28 | }) | |
29 | if err != nil { | |
30 | t.Fatal(err) | |
31 | } | |
32 | defer sub.Unsubscribe() | |
33 | ||
34 | publisher := natstransport.NewPublisher( | |
35 | nc, | |
36 | "natstransport.test", | |
37 | encode, | |
38 | decode, | |
39 | ) | |
40 | ||
41 | res, err := publisher.Endpoint()(context.Background(), struct{}{}) | |
42 | if err != nil { | |
43 | t.Fatal(err) | |
44 | } | |
45 | ||
46 | response, ok := res.(TestResponse) | |
47 | if !ok { | |
48 | t.Fatal("response should be TestResponse") | |
49 | } | |
50 | if want, have := testdata, response.String; want != have { | |
51 | t.Errorf("want %q, have %q", want, have) | |
52 | } | |
53 | ||
54 | } | |
55 | ||
56 | func TestPublisherBefore(t *testing.T) { | |
57 | var ( | |
58 | testdata = "testdata" | |
59 | encode = func(context.Context, *nats.Msg, interface{}) error { return nil } | |
60 | decode = func(_ context.Context, msg *nats.Msg) (interface{}, error) { | |
61 | return TestResponse{string(msg.Data), ""}, nil | |
62 | } | |
63 | ) | |
64 | ||
65 | nc := newNatsConn(t) | |
66 | defer nc.Close() | |
67 | ||
68 | sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) { | |
69 | if err := nc.Publish(msg.Reply, msg.Data); err != nil { | |
70 | t.Fatal(err) | |
71 | } | |
72 | }) | |
73 | if err != nil { | |
74 | t.Fatal(err) | |
75 | } | |
76 | defer sub.Unsubscribe() | |
77 | ||
78 | publisher := natstransport.NewPublisher( | |
79 | nc, | |
80 | "natstransport.test", | |
81 | encode, | |
82 | decode, | |
83 | natstransport.PublisherBefore(func(ctx context.Context, msg *nats.Msg) context.Context { | |
84 | msg.Data = []byte(strings.ToUpper(string(testdata))) | |
85 | return ctx | |
86 | }), | |
87 | ) | |
88 | ||
89 | res, err := publisher.Endpoint()(context.Background(), struct{}{}) | |
90 | if err != nil { | |
91 | t.Fatal(err) | |
92 | } | |
93 | ||
94 | response, ok := res.(TestResponse) | |
95 | if !ok { | |
96 | t.Fatal("response should be TestResponse") | |
97 | } | |
98 | if want, have := strings.ToUpper(testdata), response.String; want != have { | |
99 | t.Errorf("want %q, have %q", want, have) | |
100 | } | |
101 | ||
102 | } | |
103 | ||
104 | func TestPublisherAfter(t *testing.T) { | |
105 | var ( | |
106 | testdata = "testdata" | |
107 | encode = func(context.Context, *nats.Msg, interface{}) error { return nil } | |
108 | decode = func(_ context.Context, msg *nats.Msg) (interface{}, error) { | |
109 | return TestResponse{string(msg.Data), ""}, nil | |
110 | } | |
111 | ) | |
112 | ||
113 | nc := newNatsConn(t) | |
114 | defer nc.Close() | |
115 | ||
116 | sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) { | |
117 | if err := nc.Publish(msg.Reply, []byte(testdata)); err != nil { | |
118 | t.Fatal(err) | |
119 | } | |
120 | }) | |
121 | if err != nil { | |
122 | t.Fatal(err) | |
123 | } | |
124 | defer sub.Unsubscribe() | |
125 | ||
126 | publisher := natstransport.NewPublisher( | |
127 | nc, | |
128 | "natstransport.test", | |
129 | encode, | |
130 | decode, | |
131 | natstransport.PublisherAfter(func(ctx context.Context, msg *nats.Msg) context.Context { | |
132 | msg.Data = []byte(strings.ToUpper(string(msg.Data))) | |
133 | return ctx | |
134 | }), | |
135 | ) | |
136 | ||
137 | res, err := publisher.Endpoint()(context.Background(), struct{}{}) | |
138 | if err != nil { | |
139 | t.Fatal(err) | |
140 | } | |
141 | ||
142 | response, ok := res.(TestResponse) | |
143 | if !ok { | |
144 | t.Fatal("response should be TestResponse") | |
145 | } | |
146 | if want, have := strings.ToUpper(testdata), response.String; want != have { | |
147 | t.Errorf("want %q, have %q", want, have) | |
148 | } | |
149 | ||
150 | } | |
151 | ||
152 | func TestPublisherTimeout(t *testing.T) { | |
153 | var ( | |
154 | encode = func(context.Context, *nats.Msg, interface{}) error { return nil } | |
155 | decode = func(_ context.Context, msg *nats.Msg) (interface{}, error) { | |
156 | return TestResponse{string(msg.Data), ""}, nil | |
157 | } | |
158 | ) | |
159 | ||
160 | nc := newNatsConn(t) | |
161 | defer nc.Close() | |
162 | ||
163 | ch := make(chan struct{}) | |
164 | defer close(ch) | |
165 | ||
166 | sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) { | |
167 | <-ch | |
168 | }) | |
169 | if err != nil { | |
170 | t.Fatal(err) | |
171 | } | |
172 | defer sub.Unsubscribe() | |
173 | ||
174 | publisher := natstransport.NewPublisher( | |
175 | nc, | |
176 | "natstransport.test", | |
177 | encode, | |
178 | decode, | |
179 | natstransport.PublisherTimeout(time.Second), | |
180 | ) | |
181 | ||
182 | _, err = publisher.Endpoint()(context.Background(), struct{}{}) | |
183 | if err != context.DeadlineExceeded { | |
184 | t.Errorf("want %s, have %s", context.DeadlineExceeded, err) | |
185 | } | |
186 | } | |
187 | ||
188 | func TestEncodeJSONRequest(t *testing.T) { | |
189 | var data string | |
190 | ||
191 | nc := newNatsConn(t) | |
192 | defer nc.Close() | |
193 | ||
194 | sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) { | |
195 | data = string(msg.Data) | |
196 | ||
197 | if err := nc.Publish(msg.Reply, []byte("")); err != nil { | |
198 | t.Fatal(err) | |
199 | } | |
200 | }) | |
201 | if err != nil { | |
202 | t.Fatal(err) | |
203 | } | |
204 | defer sub.Unsubscribe() | |
205 | ||
206 | publisher := natstransport.NewPublisher( | |
207 | nc, | |
208 | "natstransport.test", | |
209 | natstransport.EncodeJSONRequest, | |
210 | func(context.Context, *nats.Msg) (interface{}, error) { return nil, nil }, | |
211 | ).Endpoint() | |
212 | ||
213 | for _, test := range []struct { | |
214 | value interface{} | |
215 | body string | |
216 | }{ | |
217 | {nil, "null"}, | |
218 | {12, "12"}, | |
219 | {1.2, "1.2"}, | |
220 | {true, "true"}, | |
221 | {"test", "\"test\""}, | |
222 | {struct { | |
223 | Foo string `json:"foo"` | |
224 | }{"foo"}, "{\"foo\":\"foo\"}"}, | |
225 | } { | |
226 | if _, err := publisher(context.Background(), test.value); err != nil { | |
227 | t.Fatal(err) | |
228 | continue | |
229 | } | |
230 | ||
231 | if data != test.body { | |
232 | t.Errorf("%v: actual %#v, expected %#v", test.value, data, test.body) | |
233 | } | |
234 | } | |
235 | ||
236 | } |
0 | package nats | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | ||
5 | "github.com/nats-io/go-nats" | |
6 | ) | |
7 | ||
8 | // RequestFunc may take information from a publisher request and put it into a | |
9 | // request context. In Subscribers, RequestFuncs are executed prior to invoking the | |
10 | // endpoint. | |
11 | type RequestFunc func(context.Context, *nats.Msg) context.Context | |
12 | ||
13 | // SubscriberResponseFunc may take information from a request context and use it to | |
14 | // manipulate a Publisher. SubscriberResponseFuncs are only executed in | |
15 | // subscribers, after invoking the endpoint but prior to publishing a reply. | |
16 | type SubscriberResponseFunc func(context.Context, *nats.Conn) context.Context | |
17 | ||
18 | // PublisherResponseFunc may take information from an NATS request and make the | |
19 | // response available for consumption. ClientResponseFuncs are only executed in | |
20 | // clients, after a request has been made, but prior to it being decoded. | |
21 | type PublisherResponseFunc func(context.Context, *nats.Msg) context.Context |
0 | package nats | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "encoding/json" | |
5 | ||
6 | "github.com/go-kit/kit/endpoint" | |
7 | "github.com/go-kit/kit/log" | |
8 | ||
9 | "github.com/nats-io/go-nats" | |
10 | ) | |
11 | ||
12 | // Subscriber wraps an endpoint and provides nats.MsgHandler. | |
13 | type Subscriber struct { | |
14 | e endpoint.Endpoint | |
15 | dec DecodeRequestFunc | |
16 | enc EncodeResponseFunc | |
17 | before []RequestFunc | |
18 | after []SubscriberResponseFunc | |
19 | errorEncoder ErrorEncoder | |
20 | finalizer []SubscriberFinalizerFunc | |
21 | logger log.Logger | |
22 | } | |
23 | ||
24 | // NewSubscriber constructs a new subscriber, which provides nats.MsgHandler and wraps | |
25 | // the provided endpoint. | |
26 | func NewSubscriber( | |
27 | e endpoint.Endpoint, | |
28 | dec DecodeRequestFunc, | |
29 | enc EncodeResponseFunc, | |
30 | options ...SubscriberOption, | |
31 | ) *Subscriber { | |
32 | s := &Subscriber{ | |
33 | e: e, | |
34 | dec: dec, | |
35 | enc: enc, | |
36 | errorEncoder: DefaultErrorEncoder, | |
37 | logger: log.NewNopLogger(), | |
38 | } | |
39 | for _, option := range options { | |
40 | option(s) | |
41 | } | |
42 | return s | |
43 | } | |
44 | ||
45 | // SubscriberOption sets an optional parameter for subscribers. | |
46 | type SubscriberOption func(*Subscriber) | |
47 | ||
48 | // SubscriberBefore functions are executed on the publisher request object before the | |
49 | // request is decoded. | |
50 | func SubscriberBefore(before ...RequestFunc) SubscriberOption { | |
51 | return func(s *Subscriber) { s.before = append(s.before, before...) } | |
52 | } | |
53 | ||
54 | // SubscriberAfter functions are executed on the subscriber reply after the | |
55 | // endpoint is invoked, but before anything is published to the reply. | |
56 | func SubscriberAfter(after ...SubscriberResponseFunc) SubscriberOption { | |
57 | return func(s *Subscriber) { s.after = append(s.after, after...) } | |
58 | } | |
59 | ||
60 | // SubscriberErrorEncoder is used to encode errors to the subscriber reply | |
61 | // whenever they're encountered in the processing of a request. Clients can | |
62 | // use this to provide custom error formatting. By default, | |
63 | // errors will be published with the DefaultErrorEncoder. | |
64 | func SubscriberErrorEncoder(ee ErrorEncoder) SubscriberOption { | |
65 | return func(s *Subscriber) { s.errorEncoder = ee } | |
66 | } | |
67 | ||
68 | // SubscriberErrorLogger is used to log non-terminal errors. By default, no errors | |
69 | // are logged. This is intended as a diagnostic measure. Finer-grained control | |
70 | // of error handling, including logging in more detail, should be performed in a | |
71 | // custom SubscriberErrorEncoder which has access to the context. | |
72 | func SubscriberErrorLogger(logger log.Logger) SubscriberOption { | |
73 | return func(s *Subscriber) { s.logger = logger } | |
74 | } | |
75 | ||
76 | // SubscriberFinalizer is executed at the end of every request from a publisher through NATS. | |
77 | // By default, no finalizer is registered. | |
78 | func SubscriberFinalizer(f ...SubscriberFinalizerFunc) SubscriberOption { | |
79 | return func(s *Subscriber) { s.finalizer = f } | |
80 | } | |
81 | ||
82 | // ServeMsg provides nats.MsgHandler. | |
83 | func (s Subscriber) ServeMsg(nc *nats.Conn) func(msg *nats.Msg) { | |
84 | return func(msg *nats.Msg) { | |
85 | ctx, cancel := context.WithCancel(context.Background()) | |
86 | defer cancel() | |
87 | ||
88 | if len(s.finalizer) > 0 { | |
89 | defer func() { | |
90 | for _, f := range s.finalizer { | |
91 | f(ctx, msg) | |
92 | } | |
93 | }() | |
94 | } | |
95 | ||
96 | for _, f := range s.before { | |
97 | ctx = f(ctx, msg) | |
98 | } | |
99 | ||
100 | request, err := s.dec(ctx, msg) | |
101 | if err != nil { | |
102 | s.logger.Log("err", err) | |
103 | if msg.Reply == "" { | |
104 | return | |
105 | } | |
106 | s.errorEncoder(ctx, err, msg.Reply, nc) | |
107 | return | |
108 | } | |
109 | ||
110 | response, err := s.e(ctx, request) | |
111 | if err != nil { | |
112 | s.logger.Log("err", err) | |
113 | if msg.Reply == "" { | |
114 | return | |
115 | } | |
116 | s.errorEncoder(ctx, err, msg.Reply, nc) | |
117 | return | |
118 | } | |
119 | ||
120 | for _, f := range s.after { | |
121 | ctx = f(ctx, nc) | |
122 | } | |
123 | ||
124 | if msg.Reply == "" { | |
125 | return | |
126 | } | |
127 | ||
128 | if err := s.enc(ctx, msg.Reply, nc, response); err != nil { | |
129 | s.logger.Log("err", err) | |
130 | s.errorEncoder(ctx, err, msg.Reply, nc) | |
131 | return | |
132 | } | |
133 | } | |
134 | } | |
135 | ||
136 | // ErrorEncoder is responsible for encoding an error to the subscriber reply. | |
137 | // Users are encouraged to use custom ErrorEncoders to encode errors to | |
138 | // their replies, and will likely want to pass and check for their own error | |
139 | // types. | |
140 | type ErrorEncoder func(ctx context.Context, err error, reply string, nc *nats.Conn) | |
141 | ||
142 | // ServerFinalizerFunc can be used to perform work at the end of an request | |
143 | // from a publisher, after the response has been written to the publisher. The principal | |
144 | // intended use is for request logging. | |
145 | type SubscriberFinalizerFunc func(ctx context.Context, msg *nats.Msg) | |
146 | ||
147 | // NopRequestDecoder is a DecodeRequestFunc that can be used for requests that do not | |
148 | // need to be decoded, and simply returns nil, nil. | |
149 | func NopRequestDecoder(_ context.Context, _ *nats.Msg) (interface{}, error) { | |
150 | return nil, nil | |
151 | } | |
152 | ||
153 | // EncodeJSONResponse is a EncodeResponseFunc that serializes the response as a | |
154 | // JSON object to the subscriber reply. Many JSON-over services can use it as | |
155 | // a sensible default. | |
156 | func EncodeJSONResponse(_ context.Context, reply string, nc *nats.Conn, response interface{}) error { | |
157 | b, err := json.Marshal(response) | |
158 | if err != nil { | |
159 | return err | |
160 | } | |
161 | ||
162 | return nc.Publish(reply, b) | |
163 | } | |
164 | ||
165 | // DefaultErrorEncoder writes the error to the subscriber reply. | |
166 | func DefaultErrorEncoder(_ context.Context, err error, reply string, nc *nats.Conn) { | |
167 | logger := log.NewNopLogger() | |
168 | ||
169 | type Response struct { | |
170 | Error string `json:"err"` | |
171 | } | |
172 | ||
173 | var response Response | |
174 | ||
175 | response.Error = err.Error() | |
176 | ||
177 | b, err := json.Marshal(response) | |
178 | if err != nil { | |
179 | logger.Log("err", err) | |
180 | return | |
181 | } | |
182 | ||
183 | if err := nc.Publish(reply, b); err != nil { | |
184 | logger.Log("err", err) | |
185 | } | |
186 | } |
0 | package nats_test | |
1 | ||
2 | import ( | |
3 | "context" | |
4 | "encoding/json" | |
5 | "errors" | |
6 | "strings" | |
7 | "sync" | |
8 | "testing" | |
9 | "time" | |
10 | ||
11 | "github.com/nats-io/gnatsd/server" | |
12 | "github.com/nats-io/go-nats" | |
13 | ||
14 | "github.com/go-kit/kit/endpoint" | |
15 | natstransport "github.com/go-kit/kit/transport/nats" | |
16 | ) | |
17 | ||
18 | type TestResponse struct { | |
19 | String string `json:"str"` | |
20 | Error string `json:"err"` | |
21 | } | |
22 | ||
23 | var natsServer *server.Server | |
24 | ||
25 | func init() { | |
26 | natsServer = server.New(&server.Options{ | |
27 | Host: "localhost", | |
28 | Port: 4222, | |
29 | }) | |
30 | ||
31 | go func() { | |
32 | natsServer.Start() | |
33 | }() | |
34 | ||
35 | if ok := natsServer.ReadyForConnections(2 * time.Second); !ok { | |
36 | panic("Failed start of NATS") | |
37 | } | |
38 | } | |
39 | ||
40 | func newNatsConn(t *testing.T) *nats.Conn { | |
41 | // Subscriptions and connections are closed asynchronously, so it's possible | |
42 | // that there's still a subscription from an old connection that must be closed | |
43 | // before the current test can be run. | |
44 | for tries := 20; tries > 0; tries-- { | |
45 | if natsServer.NumSubscriptions() == 0 { | |
46 | break | |
47 | } | |
48 | ||
49 | time.Sleep(5 * time.Millisecond) | |
50 | } | |
51 | ||
52 | if n := natsServer.NumSubscriptions(); n > 0 { | |
53 | t.Fatalf("found %d active subscriptions on the server", n) | |
54 | } | |
55 | ||
56 | nc, err := nats.Connect("nats://"+natsServer.Addr().String(), nats.Name(t.Name())) | |
57 | if err != nil { | |
58 | t.Fatalf("failed to connect to gnatsd server: %s", err) | |
59 | } | |
60 | ||
61 | return nc | |
62 | } | |
63 | ||
64 | func TestSubscriberBadDecode(t *testing.T) { | |
65 | nc := newNatsConn(t) | |
66 | defer nc.Close() | |
67 | ||
68 | handler := natstransport.NewSubscriber( | |
69 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, | |
70 | func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, errors.New("dang") }, | |
71 | func(context.Context, string, *nats.Conn, interface{}) error { return nil }, | |
72 | ) | |
73 | ||
74 | resp := testRequest(t, nc, handler) | |
75 | ||
76 | if want, have := "dang", resp.Error; want != have { | |
77 | t.Errorf("want %s, have %s", want, have) | |
78 | } | |
79 | ||
80 | } | |
81 | ||
82 | func TestSubscriberBadEndpoint(t *testing.T) { | |
83 | nc := newNatsConn(t) | |
84 | defer nc.Close() | |
85 | ||
86 | handler := natstransport.NewSubscriber( | |
87 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errors.New("dang") }, | |
88 | func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil }, | |
89 | func(context.Context, string, *nats.Conn, interface{}) error { return nil }, | |
90 | ) | |
91 | ||
92 | resp := testRequest(t, nc, handler) | |
93 | ||
94 | if want, have := "dang", resp.Error; want != have { | |
95 | t.Errorf("want %s, have %s", want, have) | |
96 | } | |
97 | } | |
98 | ||
99 | func TestSubscriberBadEncode(t *testing.T) { | |
100 | nc := newNatsConn(t) | |
101 | defer nc.Close() | |
102 | ||
103 | handler := natstransport.NewSubscriber( | |
104 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, | |
105 | func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil }, | |
106 | func(context.Context, string, *nats.Conn, interface{}) error { return errors.New("dang") }, | |
107 | ) | |
108 | ||
109 | resp := testRequest(t, nc, handler) | |
110 | ||
111 | if want, have := "dang", resp.Error; want != have { | |
112 | t.Errorf("want %s, have %s", want, have) | |
113 | } | |
114 | } | |
115 | ||
116 | func TestSubscriberErrorEncoder(t *testing.T) { | |
117 | nc := newNatsConn(t) | |
118 | defer nc.Close() | |
119 | ||
120 | errTeapot := errors.New("teapot") | |
121 | code := func(err error) error { | |
122 | if err == errTeapot { | |
123 | return err | |
124 | } | |
125 | return errors.New("dang") | |
126 | } | |
127 | handler := natstransport.NewSubscriber( | |
128 | func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errTeapot }, | |
129 | func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil }, | |
130 | func(context.Context, string, *nats.Conn, interface{}) error { return nil }, | |
131 | natstransport.SubscriberErrorEncoder(func(_ context.Context, err error, reply string, nc *nats.Conn) { | |
132 | var r TestResponse | |
133 | r.Error = code(err).Error() | |
134 | ||
135 | b, err := json.Marshal(r) | |
136 | if err != nil { | |
137 | t.Fatal(err) | |
138 | } | |
139 | ||
140 | if err := nc.Publish(reply, b); err != nil { | |
141 | t.Fatal(err) | |
142 | } | |
143 | }), | |
144 | ) | |
145 | ||
146 | resp := testRequest(t, nc, handler) | |
147 | ||
148 | if want, have := errTeapot.Error(), resp.Error; want != have { | |
149 | t.Errorf("want %s, have %s", want, have) | |
150 | } | |
151 | } | |
152 | ||
153 | func TestSubscriberHappySubject(t *testing.T) { | |
154 | step, response := testSubscriber(t) | |
155 | step() | |
156 | r := <-response | |
157 | ||
158 | var resp TestResponse | |
159 | err := json.Unmarshal(r.Data, &resp) | |
160 | if err != nil { | |
161 | t.Fatal(err) | |
162 | } | |
163 | ||
164 | if want, have := "", resp.Error; want != have { | |
165 | t.Errorf("want %s, have %s (%s)", want, have, r.Data) | |
166 | } | |
167 | } | |
168 | ||
169 | func TestMultipleSubscriberBefore(t *testing.T) { | |
170 | nc := newNatsConn(t) | |
171 | defer nc.Close() | |
172 | ||
173 | var ( | |
174 | response = struct{ Body string }{"go eat a fly ugly\n"} | |
175 | wg sync.WaitGroup | |
176 | done = make(chan struct{}) | |
177 | ) | |
178 | handler := natstransport.NewSubscriber( | |
179 | endpoint.Nop, | |
180 | func(context.Context, *nats.Msg) (interface{}, error) { | |
181 | return struct{}{}, nil | |
182 | }, | |
183 | func(_ context.Context, reply string, nc *nats.Conn, _ interface{}) error { | |
184 | b, err := json.Marshal(response) | |
185 | if err != nil { | |
186 | return err | |
187 | } | |
188 | ||
189 | return nc.Publish(reply, b) | |
190 | }, | |
191 | natstransport.SubscriberBefore(func(ctx context.Context, _ *nats.Msg) context.Context { | |
192 | ctx = context.WithValue(ctx, "one", 1) | |
193 | ||
194 | return ctx | |
195 | }), | |
196 | natstransport.SubscriberBefore(func(ctx context.Context, _ *nats.Msg) context.Context { | |
197 | if _, ok := ctx.Value("one").(int); !ok { | |
198 | t.Error("Value was not set properly when multiple ServerBefores are used") | |
199 | } | |
200 | ||
201 | close(done) | |
202 | return ctx | |
203 | }), | |
204 | ) | |
205 | ||
206 | sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc)) | |
207 | if err != nil { | |
208 | t.Fatal(err) | |
209 | } | |
210 | defer sub.Unsubscribe() | |
211 | ||
212 | wg.Add(1) | |
213 | go func() { | |
214 | defer wg.Done() | |
215 | _, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second) | |
216 | if err != nil { | |
217 | t.Fatal(err) | |
218 | } | |
219 | }() | |
220 | ||
221 | select { | |
222 | case <-done: | |
223 | case <-time.After(time.Second): | |
224 | t.Fatal("timeout waiting for finalizer") | |
225 | } | |
226 | ||
227 | wg.Wait() | |
228 | } | |
229 | ||
230 | func TestMultipleSubscriberAfter(t *testing.T) { | |
231 | nc := newNatsConn(t) | |
232 | defer nc.Close() | |
233 | ||
234 | var ( | |
235 | response = struct{ Body string }{"go eat a fly ugly\n"} | |
236 | wg sync.WaitGroup | |
237 | done = make(chan struct{}) | |
238 | ) | |
239 | handler := natstransport.NewSubscriber( | |
240 | endpoint.Nop, | |
241 | func(context.Context, *nats.Msg) (interface{}, error) { | |
242 | return struct{}{}, nil | |
243 | }, | |
244 | func(_ context.Context, reply string, nc *nats.Conn, _ interface{}) error { | |
245 | b, err := json.Marshal(response) | |
246 | if err != nil { | |
247 | return err | |
248 | } | |
249 | ||
250 | return nc.Publish(reply, b) | |
251 | }, | |
252 | natstransport.SubscriberAfter(func(ctx context.Context, nc *nats.Conn) context.Context { | |
253 | ctx = context.WithValue(ctx, "one", 1) | |
254 | ||
255 | return ctx | |
256 | }), | |
257 | natstransport.SubscriberAfter(func(ctx context.Context, nc *nats.Conn) context.Context { | |
258 | if _, ok := ctx.Value("one").(int); !ok { | |
259 | t.Error("Value was not set properly when multiple ServerAfters are used") | |
260 | } | |
261 | ||
262 | close(done) | |
263 | return ctx | |
264 | }), | |
265 | ) | |
266 | ||
267 | sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc)) | |
268 | if err != nil { | |
269 | t.Fatal(err) | |
270 | } | |
271 | defer sub.Unsubscribe() | |
272 | ||
273 | wg.Add(1) | |
274 | go func() { | |
275 | defer wg.Done() | |
276 | _, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second) | |
277 | if err != nil { | |
278 | t.Fatal(err) | |
279 | } | |
280 | }() | |
281 | ||
282 | select { | |
283 | case <-done: | |
284 | case <-time.After(time.Second): | |
285 | t.Fatal("timeout waiting for finalizer") | |
286 | } | |
287 | ||
288 | wg.Wait() | |
289 | } | |
290 | ||
291 | func TestSubscriberFinalizerFunc(t *testing.T) { | |
292 | nc := newNatsConn(t) | |
293 | defer nc.Close() | |
294 | ||
295 | var ( | |
296 | response = struct{ Body string }{"go eat a fly ugly\n"} | |
297 | wg sync.WaitGroup | |
298 | done = make(chan struct{}) | |
299 | ) | |
300 | handler := natstransport.NewSubscriber( | |
301 | endpoint.Nop, | |
302 | func(context.Context, *nats.Msg) (interface{}, error) { | |
303 | return struct{}{}, nil | |
304 | }, | |
305 | func(_ context.Context, reply string, nc *nats.Conn, _ interface{}) error { | |
306 | b, err := json.Marshal(response) | |
307 | if err != nil { | |
308 | return err | |
309 | } | |
310 | ||
311 | return nc.Publish(reply, b) | |
312 | }, | |
313 | natstransport.SubscriberFinalizer(func(ctx context.Context, _ *nats.Msg) { | |
314 | close(done) | |
315 | }), | |
316 | ) | |
317 | ||
318 | sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc)) | |
319 | if err != nil { | |
320 | t.Fatal(err) | |
321 | } | |
322 | defer sub.Unsubscribe() | |
323 | ||
324 | wg.Add(1) | |
325 | go func() { | |
326 | defer wg.Done() | |
327 | _, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second) | |
328 | if err != nil { | |
329 | t.Fatal(err) | |
330 | } | |
331 | }() | |
332 | ||
333 | select { | |
334 | case <-done: | |
335 | case <-time.After(time.Second): | |
336 | t.Fatal("timeout waiting for finalizer") | |
337 | } | |
338 | ||
339 | wg.Wait() | |
340 | } | |
341 | ||
342 | func TestEncodeJSONResponse(t *testing.T) { | |
343 | nc := newNatsConn(t) | |
344 | defer nc.Close() | |
345 | ||
346 | handler := natstransport.NewSubscriber( | |
347 | func(context.Context, interface{}) (interface{}, error) { | |
348 | return struct { | |
349 | Foo string `json:"foo"` | |
350 | }{"bar"}, nil | |
351 | }, | |
352 | func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil }, | |
353 | natstransport.EncodeJSONResponse, | |
354 | ) | |
355 | ||
356 | sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc)) | |
357 | if err != nil { | |
358 | t.Fatal(err) | |
359 | } | |
360 | defer sub.Unsubscribe() | |
361 | ||
362 | r, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second) | |
363 | if err != nil { | |
364 | t.Fatal(err) | |
365 | } | |
366 | ||
367 | if want, have := `{"foo":"bar"}`, strings.TrimSpace(string(r.Data)); want != have { | |
368 | t.Errorf("Body: want %s, have %s", want, have) | |
369 | } | |
370 | } | |
371 | ||
372 | type responseError struct { | |
373 | msg string | |
374 | } | |
375 | ||
376 | func (m responseError) Error() string { | |
377 | return m.msg | |
378 | } | |
379 | ||
380 | func TestErrorEncoder(t *testing.T) { | |
381 | nc := newNatsConn(t) | |
382 | defer nc.Close() | |
383 | ||
384 | errResp := struct { | |
385 | Error string `json:"err"` | |
386 | }{"oh no"} | |
387 | handler := natstransport.NewSubscriber( | |
388 | func(context.Context, interface{}) (interface{}, error) { | |
389 | return nil, responseError{msg: errResp.Error} | |
390 | }, | |
391 | func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil }, | |
392 | natstransport.EncodeJSONResponse, | |
393 | ) | |
394 | ||
395 | sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc)) | |
396 | if err != nil { | |
397 | t.Fatal(err) | |
398 | } | |
399 | defer sub.Unsubscribe() | |
400 | ||
401 | r, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second) | |
402 | if err != nil { | |
403 | t.Fatal(err) | |
404 | } | |
405 | ||
406 | b, err := json.Marshal(errResp) | |
407 | if err != nil { | |
408 | t.Fatal(err) | |
409 | } | |
410 | if string(b) != string(r.Data) { | |
411 | t.Errorf("ErrorEncoder: got: %q, expected: %q", r.Data, b) | |
412 | } | |
413 | } | |
414 | ||
415 | type noContentResponse struct{} | |
416 | ||
417 | func TestEncodeNoContent(t *testing.T) { | |
418 | nc := newNatsConn(t) | |
419 | defer nc.Close() | |
420 | ||
421 | handler := natstransport.NewSubscriber( | |
422 | func(context.Context, interface{}) (interface{}, error) { return noContentResponse{}, nil }, | |
423 | func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil }, | |
424 | natstransport.EncodeJSONResponse, | |
425 | ) | |
426 | ||
427 | sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc)) | |
428 | if err != nil { | |
429 | t.Fatal(err) | |
430 | } | |
431 | defer sub.Unsubscribe() | |
432 | ||
433 | r, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second) | |
434 | if err != nil { | |
435 | t.Fatal(err) | |
436 | } | |
437 | ||
438 | if want, have := `{}`, strings.TrimSpace(string(r.Data)); want != have { | |
439 | t.Errorf("Body: want %s, have %s", want, have) | |
440 | } | |
441 | } | |
442 | ||
443 | func TestNoOpRequestDecoder(t *testing.T) { | |
444 | nc := newNatsConn(t) | |
445 | defer nc.Close() | |
446 | ||
447 | handler := natstransport.NewSubscriber( | |
448 | func(ctx context.Context, request interface{}) (interface{}, error) { | |
449 | if request != nil { | |
450 | t.Error("Expected nil request in endpoint when using NopRequestDecoder") | |
451 | } | |
452 | return nil, nil | |
453 | }, | |
454 | natstransport.NopRequestDecoder, | |
455 | natstransport.EncodeJSONResponse, | |
456 | ) | |
457 | ||
458 | sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc)) | |
459 | if err != nil { | |
460 | t.Fatal(err) | |
461 | } | |
462 | defer sub.Unsubscribe() | |
463 | ||
464 | r, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second) | |
465 | if err != nil { | |
466 | t.Fatal(err) | |
467 | } | |
468 | ||
469 | if want, have := `null`, strings.TrimSpace(string(r.Data)); want != have { | |
470 | t.Errorf("Body: want %s, have %s", want, have) | |
471 | } | |
472 | } | |
473 | ||
474 | func testSubscriber(t *testing.T) (step func(), resp <-chan *nats.Msg) { | |
475 | var ( | |
476 | stepch = make(chan bool) | |
477 | endpoint = func(context.Context, interface{}) (interface{}, error) { | |
478 | <-stepch | |
479 | return struct{}{}, nil | |
480 | } | |
481 | response = make(chan *nats.Msg) | |
482 | handler = natstransport.NewSubscriber( | |
483 | endpoint, | |
484 | func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil }, | |
485 | natstransport.EncodeJSONResponse, | |
486 | natstransport.SubscriberBefore(func(ctx context.Context, msg *nats.Msg) context.Context { return ctx }), | |
487 | natstransport.SubscriberAfter(func(ctx context.Context, nc *nats.Conn) context.Context { return ctx }), | |
488 | ) | |
489 | ) | |
490 | ||
491 | go func() { | |
492 | nc := newNatsConn(t) | |
493 | defer nc.Close() | |
494 | ||
495 | sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc)) | |
496 | if err != nil { | |
497 | t.Fatal(err) | |
498 | } | |
499 | defer sub.Unsubscribe() | |
500 | ||
501 | r, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second) | |
502 | if err != nil { | |
503 | t.Fatal(err) | |
504 | } | |
505 | ||
506 | response <- r | |
507 | }() | |
508 | ||
509 | return func() { stepch <- true }, response | |
510 | } | |
511 | ||
512 | func testRequest(t *testing.T, nc *nats.Conn, handler *natstransport.Subscriber) TestResponse { | |
513 | sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc)) | |
514 | if err != nil { | |
515 | t.Fatal(err) | |
516 | } | |
517 | defer sub.Unsubscribe() | |
518 | ||
519 | r, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second) | |
520 | if err != nil { | |
521 | t.Fatal(err) | |
522 | } | |
523 | ||
524 | var resp TestResponse | |
525 | err = json.Unmarshal(r.Data, &resp) | |
526 | if err != nil { | |
527 | t.Fatal(err) | |
528 | } | |
529 | ||
530 | return resp | |
531 | } |
1 | 1 | |
2 | 2 | import ( |
3 | 3 | "errors" |
4 | "math/rand" | |
4 | 5 | "net" |
5 | 6 | "time" |
6 | 7 | |
102 | 103 | case conn = <-connc: |
103 | 104 | if conn == nil { |
104 | 105 | // didn't work |
105 | backoff = exponential(backoff) // wait longer | |
106 | backoff = Exponential(backoff) // wait longer | |
106 | 107 | reconnectc = m.after(backoff) // try again |
107 | 108 | } else { |
108 | 109 | // worked! |
131 | 132 | return conn |
132 | 133 | } |
133 | 134 | |
134 | func exponential(d time.Duration) time.Duration { | |
135 | // Exponential takes a duration and returns another one that is twice as long, +/- 50%. It is | |
136 | // used to provide backoff for operations that may fail and should avoid thundering herds. | |
137 | // See https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ for rationale | |
138 | func Exponential(d time.Duration) time.Duration { | |
135 | 139 | d *= 2 |
140 | jitter := rand.Float64() + 0.5 | |
141 | d = time.Duration(int64(float64(d.Nanoseconds()) * jitter)) | |
136 | 142 | if d > time.Minute { |
137 | 143 | d = time.Minute |
138 | 144 | } |
139 | 145 | return d |
146 | ||
140 | 147 | } |
141 | 148 | |
142 | 149 | // ErrConnectionUnavailable is returned by the Manager's Write method when the |