Codebase list golang-github-go-kit-kit / 87706f1
Merge tag 'upstream/0.8.0' into debian/sid Upstream version 0.8.0 Daniel Swarbrick 1 year, 4 months ago
90 changed file(s) with 4891 addition(s) and 141 deletion(s). Raw diff Collapse all Expand all
1111 - ./coveralls.bash
1212
1313 go:
14 - 1.9.x
1514 - 1.10.x
15 - 1.11.x
1616 - tip
106106 - [Martini](https://github.com/go-martini/martini)
107107 - [Beego](http://beego.me/)
108108 - [Revel](https://revel.github.io/) (considered [harmful](https://github.com/go-kit/kit/issues/350))
109 - [GoBuffalo](https://gobuffalo.io/)
109110
110111 ## Additional reading
111112
112 - [Architecting for the Cloud](http://fr.slideshare.net/stonse/architecting-for-the-cloud-using-netflixoss-codemash-workshop-29852233) — Netflix
113 - [Architecting for the Cloud](https://slideshare.net/stonse/architecting-for-the-cloud-using-netflixoss-codemash-workshop-29852233) — Netflix
113114 - [Dapper, a Large-Scale Distributed Systems Tracing Infrastructure](http://research.google.com/pubs/pub36356.html) — Google
114115 - [Your Server as a Function](http://monkey.org/~marius/funsrv.pdf) (PDF) — Twitter
115116
0 package casbin
1
2 import (
3 "context"
4 "errors"
5
6 stdcasbin "github.com/casbin/casbin"
7 "github.com/go-kit/kit/endpoint"
8 )
9
10 type contextKey string
11
12 const (
13 // CasbinModelContextKey holds the key to store the access control model
14 // in context, it can be a path to configuration file or a casbin/model
15 // Model.
16 CasbinModelContextKey contextKey = "CasbinModel"
17
18 // CasbinPolicyContextKey holds the key to store the access control policy
19 // in context, it can be a path to policy file or an implementation of
20 // casbin/persist Adapter interface.
21 CasbinPolicyContextKey contextKey = "CasbinPolicy"
22
23 // CasbinEnforcerContextKey holds the key to retrieve the active casbin
24 // Enforcer.
25 CasbinEnforcerContextKey contextKey = "CasbinEnforcer"
26 )
27
28 var (
29 // ErrModelContextMissing denotes a casbin model was not passed into
30 // the parsing of middleware's context.
31 ErrModelContextMissing = errors.New("CasbinModel is required in context")
32
33 // ErrPolicyContextMissing denotes a casbin policy was not passed into
34 // the parsing of middleware's context.
35 ErrPolicyContextMissing = errors.New("CasbinPolicy is required in context")
36
37 // ErrUnauthorized denotes the subject is not authorized to do the action
38 // intended on the given object, based on the context model and policy.
39 ErrUnauthorized = errors.New("Unauthorized Access")
40 )
41
42 // NewEnforcer checks whether the subject is authorized to do the specified
43 // action on the given object. If a valid access control model and policy
44 // is given, then the generated casbin Enforcer is stored in the context
45 // with CasbinEnforcer as the key.
46 func NewEnforcer(
47 subject string, object interface{}, action string,
48 ) endpoint.Middleware {
49 return func(next endpoint.Endpoint) endpoint.Endpoint {
50 return func(ctx context.Context, request interface{}) (
51 response interface{}, err error,
52 ) {
53 casbinModel := ctx.Value(CasbinModelContextKey)
54 casbinPolicy := ctx.Value(CasbinPolicyContextKey)
55
56 enforcer := stdcasbin.NewEnforcer(casbinModel, casbinPolicy)
57 ctx = context.WithValue(ctx, CasbinEnforcerContextKey, enforcer)
58 if !enforcer.Enforce(subject, object, action) {
59 return nil, ErrUnauthorized
60 }
61 return next(ctx, request)
62 }
63 }
64 }
0 package casbin
1
2 import (
3 "context"
4 "testing"
5
6 stdcasbin "github.com/casbin/casbin"
7 fileadapter "github.com/casbin/casbin/persist/file-adapter"
8 )
9
10 func TestStructBaseContext(t *testing.T) {
11 e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil }
12
13 m := stdcasbin.NewModel()
14 m.AddDef("r", "r", "sub, obj, act")
15 m.AddDef("p", "p", "sub, obj, act")
16 m.AddDef("e", "e", "some(where (p.eft == allow))")
17 m.AddDef("m", "m", "r.sub == p.sub && keyMatch(r.obj, p.obj) && regexMatch(r.act, p.act)")
18
19 a := fileadapter.NewAdapter("testdata/keymatch_policy.csv")
20
21 ctx := context.WithValue(context.Background(), CasbinModelContextKey, m)
22 ctx = context.WithValue(ctx, CasbinPolicyContextKey, a)
23
24 // positive case
25 middleware := NewEnforcer("alice", "/alice_data/resource1", "GET")(e)
26 ctx1, err := middleware(ctx, struct{}{})
27 if err != nil {
28 t.Fatalf("Enforcer returned error: %s", err)
29 }
30 _, ok := ctx1.(context.Context).Value(CasbinEnforcerContextKey).(*stdcasbin.Enforcer)
31 if !ok {
32 t.Fatalf("context should contains the active enforcer")
33 }
34
35 // negative case
36 middleware = NewEnforcer("alice", "/alice_data/resource2", "POST")(e)
37 _, err = middleware(ctx, struct{}{})
38 if err == nil {
39 t.Fatalf("Enforcer should return error")
40 }
41 }
42
43 func TestFileBaseContext(t *testing.T) {
44 e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil }
45 ctx := context.WithValue(context.Background(), CasbinModelContextKey, "testdata/basic_model.conf")
46 ctx = context.WithValue(ctx, CasbinPolicyContextKey, "testdata/basic_policy.csv")
47
48 // positive case
49 middleware := NewEnforcer("alice", "data1", "read")(e)
50 _, err := middleware(ctx, struct{}{})
51 if err != nil {
52 t.Fatalf("Enforcer returned error: %s", err)
53 }
54 }
0 [request_definition]
1 r = sub, obj, act
2
3 [policy_definition]
4 p = sub, obj, act
5
6 [policy_effect]
7 e = some(where (p.eft == allow))
8
9 [matchers]
10 m = r.sub == p.sub && r.obj == p.obj && r.act == p.act
0 p, alice, data1, read
1 p, bob, data2, write
0 p, alice, /alice_data/*, GET
1 p, alice, /alice_data/resource1, POST
2
3 p, bob, /alice_data/resource2, GET
4 p, bob, /bob_data/*, POST
5
6 p, cathy, /cathy_data, (GET)|(POST)
5454 ```
5555
5656 In order for the parser and the signer to work, the authorization headers need
57 to be passed between the request and the context. `ToHTTPContext()`,
58 `FromHTTPContext()`, `ToGRPCContext()`, and `FromGRPCContext()` are given as
57 to be passed between the request and the context. `HTTPToContext()`,
58 `ContextToHTTP()`, `GRPCToContext()`, and `ContextToGRPC()` are given as
5959 helpers to do this. These functions implement the correlating transport's
6060 RequestFunc interface and can be passed as ClientBefore or ServerBefore
6161 options.
7676 options := []httptransport.ClientOption{}
7777 var exampleEndpoint endpoint.Endpoint
7878 {
79 exampleEndpoint = grpctransport.NewClient(..., grpctransport.ClientBefore(jwt.FromGRPCContext())).Endpoint()
79 exampleEndpoint = grpctransport.NewClient(..., grpctransport.ClientBefore(jwt.ContextToGRPC())).Endpoint()
8080 exampleEndpoint = jwt.NewSigner(
8181 "kid-header",
8282 []byte("SigningString"),
107107 endpoints.CreateUserEndpoint,
108108 DecodeGRPCCreateUserRequest,
109109 EncodeGRPCCreateUserResponse,
110 append(options, grpctransport.ServerBefore(jwt.ToGRPCContext()))...,
110 append(options, grpctransport.ServerBefore(jwt.GRPCToContext()))...,
111111 ),
112112 getUser: grpctransport.NewServer(
113113 ctx,
0 machine:
1 pre:
2 - curl -sSL https://s3.amazonaws.com/circle-downloads/install-circleci-docker.sh | bash -s -- 1.10.0
3 - sudo rm -rf /usr/local/go
4 - curl -sSL https://storage.googleapis.com/golang/go1.9.linux-amd64.tar.gz | sudo tar xz -C /usr/local
5 services:
6 - docker
0 version: 2
71
8 dependencies:
9 pre:
10 - sudo curl -L "https://github.com/docker/compose/releases/download/1.10.0/docker-compose-linux-x86_64" -o /usr/local/bin/docker-compose
11 - sudo chmod +x /usr/local/bin/docker-compose
12 - docker-compose -f docker-compose-integration.yml up -d --force-recreate
13
14 test:
15 pre:
16 - mkdir -p /home/ubuntu/.go_workspace/src/github.com/go-kit
17 - mv /home/ubuntu/kit /home/ubuntu/.go_workspace/src/github.com/go-kit
18 - ln -s /home/ubuntu/.go_workspace/src/github.com/go-kit/kit /home/ubuntu/kit
19 - go get -t github.com/go-kit/kit/...
20 override:
21 - go test -v -race -tags integration github.com/go-kit/kit/...:
22 environment:
23 ETCD_ADDR: http://localhost:2379
24 CONSUL_ADDR: localhost:8500
25 ZK_ADDR: localhost:2181
26 EUREKA_ADDR: http://localhost:8761/eureka
2 jobs:
3 build:
4 machine: true
5 working_directory: /home/circleci/.go_workspace/src/github.com/go-kit/kit
6 environment:
7 ETCD_ADDR: http://localhost:2379
8 CONSUL_ADDR: localhost:8500
9 ZK_ADDR: localhost:2181
10 EUREKA_ADDR: http://localhost:8761/eureka
11 steps:
12 - checkout
13 - run: wget -q https://storage.googleapis.com/golang/go1.11.linux-amd64.tar.gz
14 - run: sudo rm -rf /usr/local/go
15 - run: sudo tar -C /usr/local -xzf go1.11.linux-amd64.tar.gz
16 - run: docker-compose -f docker-compose-integration.yml up -d --force-recreate
17 - run: go get -t github.com/go-kit/kit/...
18 - run: go test -v -race -tags integration github.com/go-kit/kit/...
0 # kitgen
1 kitgen is an experimental code generation utility that helps with some of the
2 boilerplate code required to implement the "onion" pattern `go-kit` utilizes.
3
4 ## Usage
5 Before using this tool please explore the [testdata]() directory for examples
6 of the inputs it requires and the outputs that will be produced. _You may not
7 need this tool._ If you are new to and just learning `go-kit` or if your use
8 case involves introducing `go-kit` to an existing codebase you are better
9 suited by slowly building out the "onion" by hand.
10
11 Before starting you need to *install* `kitgen` utility — see instructions below.
12 1. **Define** your service. Create a `.go` file with the definition of your
13 Service interface and any of the custom types it refers to:
14 ```go
15 // service.go
16 package profilesvc // don't forget to name your package
17
18 type Service interface {
19 PostProfile(ctx context.Context, p Profile) error
20 // ...
21 }
22 type Profile struct {
23 ID string `json:"id"`
24 Name string `json:"name,omitempty"`
25 // ...
26 }
27 ```
28 2. **Generate** your code. Run the following command:
29 ```sh
30 kitgen ./service.go
31 # kitgen has a couple of flags that you may find useful
32
33 # keep all code in the root directory
34 kitgen -repo-layout flat ./service.go
35
36 # put generated code elsewhere
37 kitgen -target-dir ~/Projects/gohome/src/home.com/kitchenservice/brewcoffee
38 ```
39
40 ## Installation
41 1. **Fetch** the `inlinefiles` utility. Go generate will use it to create your
42 code:
43 ```
44 go get github.com/nyarly/inlinefiles
45 ```
46 2. **Install** the binary for easy access to `kitgen`. Run the following commands:
47 ```sh
48 cd $GOPATH/src/github.com/go-kit/kit/cmd/kitgen
49 go install
50
51 # Check installation by running:
52 kitgen -h
53 ```
2525 return outer(next)
2626 }
2727 }
28
29 // Failer may be implemented by Go kit response types that contain business
30 // logic error details. If Failed returns a non-nil error, the Go kit transport
31 // layer may interpret this as a business logic error, and may encode it
32 // differently than a regular, successful response.
33 //
34 // It's not necessary for your response types to implement Failer, but it may
35 // help for more sophisticated use cases. The addsvc example shows how Failer
36 // should be used by a complete application.
37 type Failer interface {
38 Failed() error
39 }
1313 "github.com/apache/thrift/lib/go/thrift"
1414 lightstep "github.com/lightstep/lightstep-tracer-go"
1515 stdopentracing "github.com/opentracing/opentracing-go"
16 zipkinot "github.com/openzipkin-contrib/zipkin-go-opentracing"
1617 zipkin "github.com/openzipkin/zipkin-go"
17 zipkinot "github.com/openzipkin/zipkin-go-opentracing"
1818 zipkinhttp "github.com/openzipkin/zipkin-go/reporter/http"
1919 "sourcegraph.com/sourcegraph/appdash"
2020 appdashot "sourcegraph.com/sourcegraph/appdash/opentracing"
4242 thriftProtocol = fs.String("thrift-protocol", "binary", "binary, compact, json, simplejson")
4343 thriftBuffer = fs.Int("thrift-buffer", 0, "0 for unbuffered")
4444 thriftFramed = fs.Bool("thrift-framed", false, "true to enable framing")
45 zipkinV2URL = fs.String("zipkin-url", "", "Enable Zipkin v2 tracing (zipkin-go) via HTTP Reporter URL e.g. http://localhost:94111/api/v2/spans")
45 zipkinV2URL = fs.String("zipkin-url", "", "Enable Zipkin v2 tracing (zipkin-go) via HTTP Reporter URL e.g. http://localhost:9411/api/v2/spans")
4646 zipkinV1URL = fs.String("zipkin-v1-url", "", "Enable Zipkin v1 tracing (zipkin-go-opentracing) via a collector URL e.g. http://localhost:9411/api/v1/spans")
4747 lightstepToken = fs.String("lightstep-token", "", "Enable LightStep tracing via a LightStep access token")
4848 appdashAddr = fs.String("appdash-addr", "", "Enable Appdash tracing via an Appdash server host:port")
204204 fmt.Fprintf(os.Stdout, "%q + %q = %q\n", a, b, v)
205205
206206 default:
207 fmt.Fprintf(os.Stderr, "error: invalid method %q\n", method)
207 fmt.Fprintf(os.Stderr, "error: invalid method %q\n", *method)
208208 os.Exit(1)
209209 }
210210 }
1313 lightstep "github.com/lightstep/lightstep-tracer-go"
1414 "github.com/oklog/oklog/pkg/group"
1515 stdopentracing "github.com/opentracing/opentracing-go"
16 zipkinot "github.com/openzipkin-contrib/zipkin-go-opentracing"
1617 zipkin "github.com/openzipkin/zipkin-go"
17 zipkinot "github.com/openzipkin/zipkin-go-opentracing"
1818 zipkinhttp "github.com/openzipkin/zipkin-go/reporter/http"
1919 stdprometheus "github.com/prometheus/client_golang/prometheus"
2020 "github.com/prometheus/client_golang/prometheus/promhttp"
1515 }
1616
1717 pact := dsl.Pact{
18 Port: 6666,
1918 Consumer: "addsvc",
2019 Provider: "stringsvc",
2120 }
2423 pact.AddInteraction().
2524 UponReceiving("stringsvc uppercase").
2625 WithRequest(dsl.Request{
27 Headers: map[string]string{"Content-Type": "application/json; charset=utf-8"},
26 Headers: dsl.MapMatcher{"Content-Type": dsl.String("application/json; charset=utf-8")},
2827 Method: "POST",
29 Path: "/uppercase",
28 Path: dsl.String("/uppercase"),
3029 Body: `{"s":"foo"}`,
3130 }).
3231 WillRespondWith(dsl.Response{
3332 Status: 200,
34 Headers: map[string]string{"Content-Type": "application/json; charset=utf-8"},
33 Headers: dsl.MapMatcher{"Content-Type": dsl.String("application/json; charset=utf-8")},
3534 Body: `{"v":"FOO"}`,
3635 })
3736
9797 }
9898 }
9999
100 // Failer is an interface that should be implemented by response types.
101 // Response encoders can check if responses are Failer, and if so if they've
102 // failed, and if so encode them using a separate write path based on the error.
103 type Failer interface {
104 Failed() error
105 }
100 // compile time assertions for our response types implementing endpoint.Failer.
101 var (
102 _ endpoint.Failer = SumResponse{}
103 _ endpoint.Failer = ConcatResponse{}
104 )
106105
107106 // SumRequest collects the request parameters for the Sum method.
108107 type SumRequest struct {
115114 Err error `json:"-"` // should be intercepted by Failed/errorEncoder
116115 }
117116
118 // Failed implements Failer.
117 // Failed implements endpoint.Failer.
119118 func (r SumResponse) Failed() error { return r.Err }
120119
121120 // ConcatRequest collects the request parameters for the Concat method.
129128 Err error `json:"-"`
130129 }
131130
132 // Failed implements Failer.
131 // Failed implements endpoint.Failer.
133132 func (r ConcatResponse) Failed() error { return r.Err }
106106 zipkinClient,
107107 }
108108
109 // Each individual endpoint is an http/transport.Client (which implements
109 // Each individual endpoint is an grpc/transport.Client (which implements
110110 // endpoint.Endpoint) that gets wrapped with various middlewares. If you
111111 // made your own client library, you'd do this work there, so your server
112112 // could rely on a consistent set of client behavior.
234234 // encodeHTTPGenericResponse is a transport/http.EncodeResponseFunc that encodes
235235 // the response as JSON to the response writer. Primarily useful in a server.
236236 func encodeHTTPGenericResponse(ctx context.Context, w http.ResponseWriter, response interface{}) error {
237 if f, ok := response.(addendpoint.Failer); ok && f.Failed() != nil {
237 if f, ok := response.(endpoint.Failer); ok && f.Failed() != nil {
238238 errorEncoder(ctx, f.Failed(), w)
239239 return nil
240240 }
11
22 This example demonstrates how to use Go kit to implement a REST-y HTTP service.
33 It leverages the excellent [gorilla mux package](https://github.com/gorilla/mux) for routing.
4
5 Run the example with the optional port address for the service:
6
7 ```bash
8 $ go run ./cmd/profilesvc/main.go -http.addr :8080
9 ts=2018-05-01T16:13:12.849086255Z caller=main.go:47 transport=HTTP addr=:8080
10 ```
11
12 Create a Profile:
13
14 ```bash
15 $ curl -d '{"id":"1234","Name":"Go Kit"}' -H "Content-Type: application/json" -X POST http://localhost:8080/profiles/
16 {}
17 ```
18
19 Get the profile you just created
20
21 ```bash
22 $ curl localhost:8080/profiles/1234
23 {"profile":{"id":"1234","name":"Go Kit"}}
24 ```
6767 options := []httptransport.ClientOption{}
6868
6969 // Note that the request encoders need to modify the request URL, changing
70 // the path and method. That's fine: we simply need to provide specific
71 // encoders for each endpoint.
70 // the path. That's fine: we simply need to provide specific encoders for
71 // each endpoint.
7272
7373 return Endpoints{
7474 PostProfileEndpoint: httptransport.NewClient("POST", tgt, encodePostProfileRequest, decodePostProfileResponse, options...).Endpoint(),
216216
217217 func encodePostProfileRequest(ctx context.Context, req *http.Request, request interface{}) error {
218218 // r.Methods("POST").Path("/profiles/")
219 req.Method, req.URL.Path = "POST", "/profiles/"
219 req.URL.Path = "/profiles/"
220220 return encodeRequest(ctx, req, request)
221221 }
222222
224224 // r.Methods("GET").Path("/profiles/{id}")
225225 r := request.(getProfileRequest)
226226 profileID := url.QueryEscape(r.ID)
227 req.Method, req.URL.Path = "GET", "/profiles/"+profileID
227 req.URL.Path = "/profiles/" + profileID
228228 return encodeRequest(ctx, req, request)
229229 }
230230
232232 // r.Methods("PUT").Path("/profiles/{id}")
233233 r := request.(putProfileRequest)
234234 profileID := url.QueryEscape(r.ID)
235 req.Method, req.URL.Path = "PUT", "/profiles/"+profileID
235 req.URL.Path = "/profiles/" + profileID
236236 return encodeRequest(ctx, req, request)
237237 }
238238
240240 // r.Methods("PATCH").Path("/profiles/{id}")
241241 r := request.(patchProfileRequest)
242242 profileID := url.QueryEscape(r.ID)
243 req.Method, req.URL.Path = "PATCH", "/profiles/"+profileID
243 req.URL.Path = "/profiles/" + profileID
244244 return encodeRequest(ctx, req, request)
245245 }
246246
248248 // r.Methods("DELETE").Path("/profiles/{id}")
249249 r := request.(deleteProfileRequest)
250250 profileID := url.QueryEscape(r.ID)
251 req.Method, req.URL.Path = "DELETE", "/profiles/"+profileID
251 req.URL.Path = "/profiles/" + profileID
252252 return encodeRequest(ctx, req, request)
253253 }
254254
256256 // r.Methods("GET").Path("/profiles/{id}/addresses/")
257257 r := request.(getAddressesRequest)
258258 profileID := url.QueryEscape(r.ProfileID)
259 req.Method, req.URL.Path = "GET", "/profiles/"+profileID+"/addresses/"
259 req.URL.Path = "/profiles/" + profileID + "/addresses/"
260260 return encodeRequest(ctx, req, request)
261261 }
262262
265265 r := request.(getAddressRequest)
266266 profileID := url.QueryEscape(r.ProfileID)
267267 addressID := url.QueryEscape(r.AddressID)
268 req.Method, req.URL.Path = "GET", "/profiles/"+profileID+"/addresses/"+addressID
268 req.URL.Path = "/profiles/" + profileID + "/addresses/" + addressID
269269 return encodeRequest(ctx, req, request)
270270 }
271271
273273 // r.Methods("POST").Path("/profiles/{id}/addresses/")
274274 r := request.(postAddressRequest)
275275 profileID := url.QueryEscape(r.ProfileID)
276 req.Method, req.URL.Path = "POST", "/profiles/"+profileID+"/addresses/"
276 req.URL.Path = "/profiles/" + profileID + "/addresses/"
277277 return encodeRequest(ctx, req, request)
278278 }
279279
282282 r := request.(deleteAddressRequest)
283283 profileID := url.QueryEscape(r.ProfileID)
284284 addressID := url.QueryEscape(r.AddressID)
285 req.Method, req.URL.Path = "DELETE", "/profiles/"+profileID+"/addresses/"+addressID
285 req.URL.Path = "/profiles/" + profileID + "/addresses/" + addressID
286286 return encodeRequest(ctx, req, request)
287287 }
288288
1313
1414 // StringService provides operations on strings.
1515 type StringService interface {
16 Uppercase(context.Context, string) (string, error)
17 Count(context.Context, string) int
16 Uppercase(string) (string, error)
17 Count(string) int
1818 }
1919
2020 // stringService is a concrete implementation of StringService
2121 type stringService struct{}
2222
23 func (stringService) Uppercase(_ context.Context, s string) (string, error) {
23 func (stringService) Uppercase(s string) (string, error) {
2424 if s == "" {
2525 return "", ErrEmpty
2626 }
2727 return strings.ToUpper(s), nil
2828 }
2929
30 func (stringService) Count(_ context.Context, s string) int {
30 func (stringService) Count(s string) int {
3131 return len(s)
3232 }
3333
5454
5555 // Endpoints are a primary abstraction in go-kit. An endpoint represents a single RPC (method in our service interface)
5656 func makeUppercaseEndpoint(svc StringService) endpoint.Endpoint {
57 return func(ctx context.Context, request interface{}) (interface{}, error) {
57 return func(_ context.Context, request interface{}) (interface{}, error) {
5858 req := request.(uppercaseRequest)
59 v, err := svc.Uppercase(ctx, req.S)
59 v, err := svc.Uppercase(req.S)
6060 if err != nil {
6161 return uppercaseResponse{v, err.Error()}, nil
6262 }
6565 }
6666
6767 func makeCountEndpoint(svc StringService) endpoint.Endpoint {
68 return func(ctx context.Context, request interface{}) (interface{}, error) {
68 return func(_ context.Context, request interface{}) (interface{}, error) {
6969 req := request.(countRequest)
70 v := svc.Count(ctx, req.S)
70 v := svc.Count(req.S)
7171 return countResponse{v}, nil
7272 }
7373 }
88 )
99
1010 func makeUppercaseEndpoint(svc StringService) endpoint.Endpoint {
11 return func(ctx context.Context, request interface{}) (interface{}, error) {
11 return func(_ context.Context, request interface{}) (interface{}, error) {
1212 req := request.(uppercaseRequest)
1313 v, err := svc.Uppercase(req.S)
1414 if err != nil {
1919 }
2020
2121 func makeCountEndpoint(svc StringService) endpoint.Endpoint {
22 return func(ctx context.Context, request interface{}) (interface{}, error) {
22 return func(_ context.Context, request interface{}) (interface{}, error) {
2323 req := request.(countRequest)
2424 v := svc.Count(req.S)
2525 return countResponse{v}, nil
0 package main
1
2 import (
3 "context"
4 "encoding/json"
5 "errors"
6 "log"
7 "strings"
8 "flag"
9 "net/http"
10
11 "github.com/go-kit/kit/endpoint"
12 natstransport "github.com/go-kit/kit/transport/nats"
13 httptransport "github.com/go-kit/kit/transport/http"
14
15 "github.com/nats-io/go-nats"
16 )
17
18 // StringService provides operations on strings.
19 type StringService interface {
20 Uppercase(context.Context, string) (string, error)
21 Count(context.Context, string) int
22 }
23
24 // stringService is a concrete implementation of StringService
25 type stringService struct{}
26
27 func (stringService) Uppercase(_ context.Context, s string) (string, error) {
28 if s == "" {
29 return "", ErrEmpty
30 }
31 return strings.ToUpper(s), nil
32 }
33
34 func (stringService) Count(_ context.Context, s string) int {
35 return len(s)
36 }
37
38 // ErrEmpty is returned when an input string is empty.
39 var ErrEmpty = errors.New("empty string")
40
41 // For each method, we define request and response structs
42 type uppercaseRequest struct {
43 S string `json:"s"`
44 }
45
46 type uppercaseResponse struct {
47 V string `json:"v"`
48 Err string `json:"err,omitempty"` // errors don't define JSON marshaling
49 }
50
51 type countRequest struct {
52 S string `json:"s"`
53 }
54
55 type countResponse struct {
56 V int `json:"v"`
57 }
58
59 // Endpoints are a primary abstraction in go-kit. An endpoint represents a single RPC (method in our service interface)
60 func makeUppercaseHTTPEndpoint(nc *nats.Conn) endpoint.Endpoint {
61 return natstransport.NewPublisher(
62 nc,
63 "stringsvc.uppercase",
64 natstransport.EncodeJSONRequest,
65 decodeUppercaseResponse,
66 ).Endpoint()
67 }
68
69 func makeCountHTTPEndpoint(nc *nats.Conn) endpoint.Endpoint {
70 return natstransport.NewPublisher(
71 nc,
72 "stringsvc.count",
73 natstransport.EncodeJSONRequest,
74 decodeCountResponse,
75 ).Endpoint()
76 }
77
78 func makeUppercaseEndpoint(svc StringService) endpoint.Endpoint {
79 return func(ctx context.Context, request interface{}) (interface{}, error) {
80 req := request.(uppercaseRequest)
81 v, err := svc.Uppercase(ctx, req.S)
82 if err != nil {
83 return uppercaseResponse{v, err.Error()}, nil
84 }
85 return uppercaseResponse{v, ""}, nil
86 }
87 }
88
89 func makeCountEndpoint(svc StringService) endpoint.Endpoint {
90 return func(ctx context.Context, request interface{}) (interface{}, error) {
91 req := request.(countRequest)
92 v := svc.Count(ctx, req.S)
93 return countResponse{v}, nil
94 }
95 }
96
97 // Transports expose the service to the network. In this fourth example we utilize JSON over NATS and HTTP.
98 func main() {
99 svc := stringService{}
100
101 natsURL := flag.String("nats-url", nats.DefaultURL, "URL for connection to NATS")
102 flag.Parse()
103
104 nc, err := nats.Connect(*natsURL)
105 if err != nil {
106 log.Fatal(err)
107 }
108 defer nc.Close()
109
110 uppercaseHTTPHandler := httptransport.NewServer(
111 makeUppercaseHTTPEndpoint(nc),
112 decodeUppercaseHTTPRequest,
113 httptransport.EncodeJSONResponse,
114 )
115
116 countHTTPHandler := httptransport.NewServer(
117 makeCountHTTPEndpoint(nc),
118 decodeCountHTTPRequest,
119 httptransport.EncodeJSONResponse,
120 )
121
122 uppercaseHandler := natstransport.NewSubscriber(
123 makeUppercaseEndpoint(svc),
124 decodeUppercaseRequest,
125 natstransport.EncodeJSONResponse,
126 )
127
128 countHandler := natstransport.NewSubscriber(
129 makeCountEndpoint(svc),
130 decodeCountRequest,
131 natstransport.EncodeJSONResponse,
132 )
133
134 uSub, err := nc.QueueSubscribe("stringsvc.uppercase", "stringsvc", uppercaseHandler.ServeMsg(nc))
135 if err != nil {
136 log.Fatal(err)
137 }
138 defer uSub.Unsubscribe()
139
140 cSub, err := nc.QueueSubscribe("stringsvc.count", "stringsvc", countHandler.ServeMsg(nc))
141 if err != nil {
142 log.Fatal(err)
143 }
144 defer cSub.Unsubscribe()
145
146 http.Handle("/uppercase", uppercaseHTTPHandler)
147 http.Handle("/count", countHTTPHandler)
148 log.Fatal(http.ListenAndServe(":8080", nil))
149
150 }
151
152 func decodeUppercaseHTTPRequest(_ context.Context, r *http.Request) (interface{}, error) {
153 var request uppercaseRequest
154 if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
155 return nil, err
156 }
157 return request, nil
158 }
159
160 func decodeCountHTTPRequest(_ context.Context, r *http.Request) (interface{}, error) {
161 var request countRequest
162 if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
163 return nil, err
164 }
165 return request, nil
166 }
167
168 func decodeUppercaseResponse(_ context.Context, msg *nats.Msg) (interface{}, error) {
169 var response uppercaseResponse
170
171 if err := json.Unmarshal(msg.Data, &response); err != nil {
172 return nil, err
173 }
174
175 return response, nil
176 }
177
178 func decodeCountResponse(_ context.Context, msg *nats.Msg) (interface{}, error) {
179 var response countResponse
180
181 if err := json.Unmarshal(msg.Data, &response); err != nil {
182 return nil, err
183 }
184
185 return response, nil
186 }
187
188 func decodeUppercaseRequest(_ context.Context, msg *nats.Msg) (interface{}, error) {
189 var request uppercaseRequest
190
191 if err := json.Unmarshal(msg.Data, &request); err != nil {
192 return nil, err
193 }
194 return request, nil
195 }
196
197 func decodeCountRequest(_ context.Context, msg *nats.Msg) (interface{}, error) {
198 var request countRequest
199
200 if err := json.Unmarshal(msg.Data, &request); err != nil {
201 return nil, err
202 }
203 return request, nil
204 }
205
0 // Package level implements leveled logging on top of package log. To use the
1 // level package, create a logger as per normal in your func main, and wrap it
2 // with level.NewFilter.
0 // Package level implements leveled logging on top of Go kit's log package. To
1 // use the level package, create a logger as per normal in your func main, and
2 // wrap it with level.NewFilter.
33 //
44 // var logger log.Logger
55 // logger = log.NewLogfmtLogger(os.Stderr)
175175 func DebugValue() Value { return debugValue }
176176
177177 var (
178 // key is of type interfae{} so that it allocates once during package
178 // key is of type interface{} so that it allocates once during package
179179 // initialization and avoids allocating every time the value is added to a
180180 // []interface{} later.
181181 key interface{} = "level"
0 // Package logrus provides an adapter to the
1 // go-kit log.Logger interface.
2 package logrus
3
4 import (
5 "errors"
6 "fmt"
7
8 "github.com/go-kit/kit/log"
9 "github.com/sirupsen/logrus"
10 )
11
12 type logrusLogger struct {
13 *logrus.Logger
14 }
15
16 var errMissingValue = errors.New("(MISSING)")
17
18 // NewLogrusLogger returns a go-kit log.Logger that sends log events to a Logrus logger.
19 func NewLogrusLogger(logger *logrus.Logger) log.Logger {
20 return &logrusLogger{logger}
21 }
22
23 func (l logrusLogger) Log(keyvals ...interface{}) error {
24 fields := logrus.Fields{}
25 for i := 0; i < len(keyvals); i += 2 {
26 if i+1 < len(keyvals) {
27 fields[fmt.Sprint(keyvals[i])] = keyvals[i+1]
28 } else {
29 fields[fmt.Sprint(keyvals[i])] = errMissingValue
30 }
31 }
32 l.WithFields(fields).Info()
33 return nil
34 }
0 package logrus_test
1
2 import (
3 "bytes"
4 "errors"
5 "strings"
6 "testing"
7
8 log "github.com/go-kit/kit/log/logrus"
9 "github.com/sirupsen/logrus"
10 )
11
12 func TestLogrusLogger(t *testing.T) {
13 t.Parallel()
14 buf := &bytes.Buffer{}
15 logrusLogger := logrus.New()
16 logrusLogger.Out = buf
17 logrusLogger.Formatter = &logrus.TextFormatter{TimestampFormat: "02-01-2006 15:04:05", FullTimestamp: true}
18 logger := log.NewLogrusLogger(logrusLogger)
19
20 if err := logger.Log("hello", "world"); err != nil {
21 t.Fatal(err)
22 }
23 if want, have := "hello=world\n", strings.Split(buf.String(), " ")[3]; want != have {
24 t.Errorf("want %#v, have %#v", want, have)
25 }
26
27 buf.Reset()
28 if err := logger.Log("a", 1, "err", errors.New("error")); err != nil {
29 t.Fatal(err)
30 }
31 if want, have := "a=1 err=error", strings.TrimSpace(strings.SplitAfterN(buf.String(), " ", 4)[3]); want != have {
32 t.Errorf("want %#v, have %#v", want, have)
33 }
34
35 buf.Reset()
36 if err := logger.Log("a", 1, "b"); err != nil {
37 t.Fatal(err)
38 }
39 if want, have := "a=1 b=\"(MISSING)\"", strings.TrimSpace(strings.SplitAfterN(buf.String(), " ", 4)[3]); want != have {
40 t.Errorf("want %#v, have %#v", want, have)
41 }
42
43 buf.Reset()
44 if err := logger.Log("my_map", mymap{0: 0}); err != nil {
45 t.Fatal(err)
46 }
47 if want, have := "my_map=special_behavior", strings.TrimSpace(strings.Split(buf.String(), " ")[3]); want != have {
48 t.Errorf("want %#v, have %#v", want, have)
49 }
50 }
51
52 type mymap map[int]int
53
54 func (m mymap) String() string { return "special_behavior" }
0 // +build !windows
1 // +build !plan9
2 // +build !nacl
3
04 package syslog_test
15
26 import (
0 // +build !windows
1 // +build !plan9
2 // +build !nacl
3
04 package syslog
15
26 import (
0 // +build !windows
1 // +build !plan9
2 // +build !nacl
3
04 package syslog
15
26 import (
00 // The code in this file is adapted from github.com/mattn/go-colorable.
1
2 // +build windows
31
42 package term
53
44 "syscall"
55 "testing"
66 )
7
8 // +build windows
97
108 type myWriter struct {
119 fd uintptr
00 package log
11
22 import (
3 "runtime"
4 "strconv"
5 "strings"
36 "time"
4
5 "github.com/go-stack/stack"
67 )
78
89 // A Valuer generates a log value. When passed to With or WithPrefix in a
8081 // Caller returns a Valuer that returns a file and line from a specified depth
8182 // in the callstack. Users will probably want to use DefaultCaller.
8283 func Caller(depth int) Valuer {
83 return func() interface{} { return stack.Caller(depth) }
84 return func() interface{} {
85 _, file, line, _ := runtime.Caller(depth)
86 idx := strings.LastIndexByte(file, '/')
87 // using idx+1 below handles both of following cases:
88 // idx == -1 because no "/" was found, or
89 // idx >= 0 and we want to start at the character after the found "/".
90 return file[idx+1:] + ":" + strconv.Itoa(line)
91 }
8492 }
8593
8694 var (
metrics/debug.test less more
Binary diff not shown
0 // Package influxstatsd provides support for InfluxData's StatsD Telegraf plugin. It's very
1 // similar to StatsD, but supports arbitrary tags per-metric, which map to Go
2 // kit's label values. So, while label values are no-ops in StatsD, they are
3 // supported here. For more details, see the article at
4 // https://www.influxdata.com/blog/getting-started-with-sending-statsd-metrics-to-telegraf-influxdb/
5 //
6 // This package batches observations and emits them on some schedule to the
7 // remote server. This is useful even if you connect to your service
8 // over UDP. Emitting one network packet per observation can quickly overwhelm
9 // even the fastest internal network.
10 package influxstatsd
11
12 import (
13 "fmt"
14 "io"
15 "strings"
16 "sync"
17 "sync/atomic"
18 "time"
19
20 "github.com/go-kit/kit/log"
21 "github.com/go-kit/kit/metrics"
22 "github.com/go-kit/kit/metrics/generic"
23 "github.com/go-kit/kit/metrics/internal/lv"
24 "github.com/go-kit/kit/metrics/internal/ratemap"
25 "github.com/go-kit/kit/util/conn"
26 )
27
28 // Influxstatsd receives metrics observations and forwards them to a server.
29 // Create a Influxstatsd object, use it to create metrics, and pass those
30 // metrics as dependencies to the components that will use them.
31 //
32 // All metrics are buffered until WriteTo is called. Counters and gauges are
33 // aggregated into a single observation per timeseries per write. Timings and
34 // histograms are buffered but not aggregated.
35 //
36 // To regularly report metrics to an io.Writer, use the WriteLoop helper method.
37 // To send to a InfluxStatsD server, use the SendLoop helper method.
38 type Influxstatsd struct {
39 mtx sync.RWMutex
40 prefix string
41 rates *ratemap.RateMap
42 counters *lv.Space
43 gauges map[string]*gaugeNode
44 timings *lv.Space
45 histograms *lv.Space
46 logger log.Logger
47 lvs lv.LabelValues
48 }
49
50 // New returns a Influxstatsd object that may be used to create metrics. Prefix is
51 // applied to all created metrics. Callers must ensure that regular calls to
52 // WriteTo are performed, either manually or with one of the helper methods.
53 func New(prefix string, logger log.Logger, lvs ...string) *Influxstatsd {
54 if len(lvs)%2 != 0 {
55 panic("odd number of LabelValues; programmer error!")
56 }
57 return &Influxstatsd{
58 prefix: prefix,
59 rates: ratemap.New(),
60 counters: lv.NewSpace(),
61 gauges: map[string]*gaugeNode{}, // https://github.com/go-kit/kit/pull/588
62 timings: lv.NewSpace(),
63 histograms: lv.NewSpace(),
64 logger: logger,
65 lvs: lvs,
66 }
67 }
68
69 // NewCounter returns a counter, sending observations to this Influxstatsd object.
70 func (d *Influxstatsd) NewCounter(name string, sampleRate float64) *Counter {
71 d.rates.Set(name, sampleRate)
72 return &Counter{
73 name: name,
74 obs: d.counters.Observe,
75 }
76 }
77
78 // NewGauge returns a gauge, sending observations to this Influxstatsd object.
79 func (d *Influxstatsd) NewGauge(name string) *Gauge {
80 d.mtx.Lock()
81 n, ok := d.gauges[name]
82 if !ok {
83 n = &gaugeNode{gauge: &Gauge{g: generic.NewGauge(name), influx: d}}
84 d.gauges[name] = n
85 }
86 d.mtx.Unlock()
87 return n.gauge
88 }
89
90 // NewTiming returns a histogram whose observations are interpreted as
91 // millisecond durations, and are forwarded to this Influxstatsd object.
92 func (d *Influxstatsd) NewTiming(name string, sampleRate float64) *Timing {
93 d.rates.Set(name, sampleRate)
94 return &Timing{
95 name: name,
96 obs: d.timings.Observe,
97 }
98 }
99
100 // NewHistogram returns a histogram whose observations are of an unspecified
101 // unit, and are forwarded to this Influxstatsd object.
102 func (d *Influxstatsd) NewHistogram(name string, sampleRate float64) *Histogram {
103 d.rates.Set(name, sampleRate)
104 return &Histogram{
105 name: name,
106 obs: d.histograms.Observe,
107 }
108 }
109
110 // WriteLoop is a helper method that invokes WriteTo to the passed writer every
111 // time the passed channel fires. This method blocks until the channel is
112 // closed, so clients probably want to run it in its own goroutine. For typical
113 // usage, create a time.Ticker and pass its C channel to this method.
114 func (d *Influxstatsd) WriteLoop(c <-chan time.Time, w io.Writer) {
115 for range c {
116 if _, err := d.WriteTo(w); err != nil {
117 d.logger.Log("during", "WriteTo", "err", err)
118 }
119 }
120 }
121
122 // SendLoop is a helper method that wraps WriteLoop, passing a managed
123 // connection to the network and address. Like WriteLoop, this method blocks
124 // until the channel is closed, so clients probably want to start it in its own
125 // goroutine. For typical usage, create a time.Ticker and pass its C channel to
126 // this method.
127 func (d *Influxstatsd) SendLoop(c <-chan time.Time, network, address string) {
128 d.WriteLoop(c, conn.NewDefaultManager(network, address, d.logger))
129 }
130
131 // WriteTo flushes the buffered content of the metrics to the writer, in
132 // InfluxStatsD format. WriteTo abides best-effort semantics, so observations are
133 // lost if there is a problem with the write. Clients should be sure to call
134 // WriteTo regularly, ideally through the WriteLoop or SendLoop helper methods.
135 func (d *Influxstatsd) WriteTo(w io.Writer) (count int64, err error) {
136 var n int
137
138 d.counters.Reset().Walk(func(name string, lvs lv.LabelValues, values []float64) bool {
139 n, err = fmt.Fprintf(w, "%s%s%s:%f|c%s\n", d.prefix, name, d.tagValues(lvs), sum(values), sampling(d.rates.Get(name)))
140 if err != nil {
141 return false
142 }
143 count += int64(n)
144 return true
145 })
146 if err != nil {
147 return count, err
148 }
149
150 d.mtx.RLock()
151 for _, root := range d.gauges {
152 root.walk(func(name string, lvs lv.LabelValues, value float64) bool {
153 n, err = fmt.Fprintf(w, "%s%s%s:%f|g\n", d.prefix, name, d.tagValues(lvs), value)
154 if err != nil {
155 return false
156 }
157 count += int64(n)
158 return true
159 })
160 }
161 d.mtx.RUnlock()
162
163 d.timings.Reset().Walk(func(name string, lvs lv.LabelValues, values []float64) bool {
164 sampleRate := d.rates.Get(name)
165 for _, value := range values {
166 n, err = fmt.Fprintf(w, "%s%s%s:%f|ms%s\n", d.prefix, name, d.tagValues(lvs), value, sampling(sampleRate))
167 if err != nil {
168 return false
169 }
170 count += int64(n)
171 }
172 return true
173 })
174 if err != nil {
175 return count, err
176 }
177
178 d.histograms.Reset().Walk(func(name string, lvs lv.LabelValues, values []float64) bool {
179 sampleRate := d.rates.Get(name)
180 for _, value := range values {
181 n, err = fmt.Fprintf(w, "%s%s%s:%f|h%s\n", d.prefix, name, d.tagValues(lvs), value, sampling(sampleRate))
182 if err != nil {
183 return false
184 }
185 count += int64(n)
186 }
187 return true
188 })
189 if err != nil {
190 return count, err
191 }
192
193 return count, err
194 }
195
196 func sum(a []float64) float64 {
197 var v float64
198 for _, f := range a {
199 v += f
200 }
201 return v
202 }
203
204 func last(a []float64) float64 {
205 return a[len(a)-1]
206 }
207
208 func sampling(r float64) string {
209 var sv string
210 if r < 1.0 {
211 sv = fmt.Sprintf("|@%f", r)
212 }
213 return sv
214 }
215
216 func (d *Influxstatsd) tagValues(labelValues []string) string {
217 if len(labelValues) == 0 && len(d.lvs) == 0 {
218 return ""
219 }
220 if len(labelValues)%2 != 0 {
221 panic("tagValues received a labelValues with an odd number of strings")
222 }
223 pairs := make([]string, 0, (len(d.lvs)+len(labelValues))/2)
224 for i := 0; i < len(d.lvs); i += 2 {
225 pairs = append(pairs, d.lvs[i]+"="+d.lvs[i+1])
226 }
227 for i := 0; i < len(labelValues); i += 2 {
228 pairs = append(pairs, labelValues[i]+"="+labelValues[i+1])
229 }
230 return "," + strings.Join(pairs, ",")
231 }
232
233 type observeFunc func(name string, lvs lv.LabelValues, value float64)
234
235 // Counter is a InfluxStatsD counter. Observations are forwarded to a Influxstatsd
236 // object, and aggregated (summed) per timeseries.
237 type Counter struct {
238 name string
239 lvs lv.LabelValues
240 obs observeFunc
241 }
242
243 // With implements metrics.Counter.
244 func (c *Counter) With(labelValues ...string) metrics.Counter {
245 return &Counter{
246 name: c.name,
247 lvs: c.lvs.With(labelValues...),
248 obs: c.obs,
249 }
250 }
251
252 // Add implements metrics.Counter.
253 func (c *Counter) Add(delta float64) {
254 c.obs(c.name, c.lvs, delta)
255 }
256
257 // Gauge is a InfluxStatsD gauge. Observations are forwarded to a Influxstatsd
258 // object, and aggregated (the last observation selected) per timeseries.
259 type Gauge struct {
260 g *generic.Gauge
261 influx *Influxstatsd
262 set int32
263 }
264
265 // With implements metrics.Gauge.
266 func (g *Gauge) With(labelValues ...string) metrics.Gauge {
267 g.influx.mtx.RLock()
268 node := g.influx.gauges[g.g.Name]
269 g.influx.mtx.RUnlock()
270
271 ga := &Gauge{g: g.g.With(labelValues...).(*generic.Gauge), influx: g.influx}
272 return node.addGauge(ga, ga.g.LabelValues())
273 }
274
275 // Set implements metrics.Gauge.
276 func (g *Gauge) Set(value float64) {
277 g.g.Set(value)
278 g.touch()
279 }
280
281 // Add implements metrics.Gauge.
282 func (g *Gauge) Add(delta float64) {
283 g.g.Add(delta)
284 g.touch()
285 }
286
287 // Timing is a InfluxStatsD timing, or metrics.Histogram. Observations are
288 // forwarded to a Influxstatsd object, and collected (but not aggregated) per
289 // timeseries.
290 type Timing struct {
291 name string
292 lvs lv.LabelValues
293 obs observeFunc
294 }
295
296 // With implements metrics.Timing.
297 func (t *Timing) With(labelValues ...string) metrics.Histogram {
298 return &Timing{
299 name: t.name,
300 lvs: t.lvs.With(labelValues...),
301 obs: t.obs,
302 }
303 }
304
305 // Observe implements metrics.Histogram. Value is interpreted as milliseconds.
306 func (t *Timing) Observe(value float64) {
307 t.obs(t.name, t.lvs, value)
308 }
309
310 // Histogram is a InfluxStatsD histrogram. Observations are forwarded to a
311 // Influxstatsd object, and collected (but not aggregated) per timeseries.
312 type Histogram struct {
313 name string
314 lvs lv.LabelValues
315 obs observeFunc
316 }
317
318 // With implements metrics.Histogram.
319 func (h *Histogram) With(labelValues ...string) metrics.Histogram {
320 return &Histogram{
321 name: h.name,
322 lvs: h.lvs.With(labelValues...),
323 obs: h.obs,
324 }
325 }
326
327 // Observe implements metrics.Histogram.
328 func (h *Histogram) Observe(value float64) {
329 h.obs(h.name, h.lvs, value)
330 }
331
332 type pair struct{ label, value string }
333
334 type gaugeNode struct {
335 mtx sync.RWMutex
336 gauge *Gauge
337 children map[pair]*gaugeNode
338 }
339
340 func (n *gaugeNode) addGauge(g *Gauge, lvs lv.LabelValues) *Gauge {
341 n.mtx.Lock()
342 defer n.mtx.Unlock()
343 if len(lvs) == 0 {
344 if n.gauge == nil {
345 n.gauge = g
346 }
347 return n.gauge
348 }
349 if len(lvs) < 2 {
350 panic("too few LabelValues; programmer error!")
351 }
352 head, tail := pair{lvs[0], lvs[1]}, lvs[2:]
353 if n.children == nil {
354 n.children = map[pair]*gaugeNode{}
355 }
356 child, ok := n.children[head]
357 if !ok {
358 child = &gaugeNode{}
359 n.children[head] = child
360 }
361 return child.addGauge(g, tail)
362 }
363
364 func (n *gaugeNode) walk(fn func(string, lv.LabelValues, float64) bool) bool {
365 n.mtx.RLock()
366 defer n.mtx.RUnlock()
367 if n.gauge != nil {
368 value, ok := n.gauge.read()
369 if ok && !fn(n.gauge.g.Name, n.gauge.g.LabelValues(), value) {
370 return false
371 }
372 }
373 for _, child := range n.children {
374 if !child.walk(fn) {
375 return false
376 }
377 }
378 return true
379 }
380
381 func (g *Gauge) touch() {
382 atomic.StoreInt32(&(g.set), 1)
383 }
384
385 func (g *Gauge) read() (float64, bool) {
386 set := atomic.SwapInt32(&(g.set), 0)
387 return g.g.Value(), set != 0
388 }
0 package influxstatsd
1
2 import (
3 "testing"
4
5 "github.com/go-kit/kit/log"
6 "github.com/go-kit/kit/metrics/teststat"
7 )
8
9 func TestCounter(t *testing.T) {
10 prefix, name := "abc.", "def"
11 label, value := "label", "value"
12 regex := `^` + prefix + name + "," + label + `=` + value + `:([0-9\.]+)\|c$`
13 d := New(prefix, log.NewNopLogger())
14 counter := d.NewCounter(name, 1.0).With(label, value)
15 valuef := teststat.SumLines(d, regex)
16 if err := teststat.TestCounter(counter, valuef); err != nil {
17 t.Fatal(err)
18 }
19 }
20
21 func TestCounterSampled(t *testing.T) {
22 // This will involve multiplying the observed sum by the inverse of the
23 // sample rate and checking against the expected value within some
24 // tolerance.
25 t.Skip("TODO")
26 }
27
28 func TestGauge(t *testing.T) {
29 prefix, name := "ghi.", "jkl"
30 label, value := "xyz", "abc"
31 regex := `^` + prefix + name + `,hostname=foohost,` + label + `=` + value + `:([0-9\.]+)\|g$`
32 d := New(prefix, log.NewNopLogger(), "hostname", "foohost")
33 gauge := d.NewGauge(name).With(label, value)
34 valuef := teststat.LastLine(d, regex)
35 if err := teststat.TestGauge(gauge, valuef); err != nil {
36 t.Fatal(err)
37 }
38 }
39
40 // InfluxStatsD histograms just emit all observations. So, we collect them into
41 // a generic histogram, and run the statistics test on that.
42
43 func TestHistogram(t *testing.T) {
44 prefix, name := "influxstatsd.", "histogram_test"
45 label, value := "abc", "def"
46 regex := `^` + prefix + name + "," + label + `=` + value + `:([0-9\.]+)\|h$`
47 d := New(prefix, log.NewNopLogger())
48 histogram := d.NewHistogram(name, 1.0).With(label, value)
49 quantiles := teststat.Quantiles(d, regex, 50) // no |@0.X
50 if err := teststat.TestHistogram(histogram, quantiles, 0.01); err != nil {
51 t.Fatal(err)
52 }
53 }
54
55 func TestHistogramSampled(t *testing.T) {
56 prefix, name := "influxstatsd.", "sampled_histogram_test"
57 label, value := "foo", "bar"
58 regex := `^` + prefix + name + "," + label + `=` + value + `:([0-9\.]+)\|h\|@0\.01[0]*$`
59 d := New(prefix, log.NewNopLogger())
60 histogram := d.NewHistogram(name, 0.01).With(label, value)
61 quantiles := teststat.Quantiles(d, regex, 50)
62 if err := teststat.TestHistogram(histogram, quantiles, 0.02); err != nil {
63 t.Fatal(err)
64 }
65 }
66
67 func TestTiming(t *testing.T) {
68 prefix, name := "influxstatsd.", "timing_test"
69 label, value := "wiggle", "bottom"
70 regex := `^` + prefix + name + "," + label + `=` + value + `:([0-9\.]+)\|ms$`
71 d := New(prefix, log.NewNopLogger())
72 histogram := d.NewTiming(name, 1.0).With(label, value)
73 quantiles := teststat.Quantiles(d, regex, 50) // no |@0.X
74 if err := teststat.TestHistogram(histogram, quantiles, 0.01); err != nil {
75 t.Fatal(err)
76 }
77 }
78
79 func TestTimingSampled(t *testing.T) {
80 prefix, name := "influxstatsd.", "sampled_timing_test"
81 label, value := "internal", "external"
82 regex := `^` + prefix + name + "," + label + `=` + value + `:([0-9\.]+)\|ms\|@0.03[0]*$`
83 d := New(prefix, log.NewNopLogger())
84 histogram := d.NewTiming(name, 0.03).With(label, value)
85 quantiles := teststat.Quantiles(d, regex, 50)
86 if err := teststat.TestHistogram(histogram, quantiles, 0.02); err != nil {
87 t.Fatal(err)
88 }
89 }
1616 if want, have := name, low.Name; want != have {
1717 t.Errorf("Name: want %q, have %q", want, have)
1818 }
19 value := func() float64 { return low.Value() }
20 if err := teststat.TestCounter(top, value); err != nil {
19 if err := teststat.TestCounter(top, low.Value); err != nil {
2120 t.Fatal(err)
2221 }
2322 }
3231 if want, have := name, low.Name; want != have {
3332 t.Errorf("Name: want %q, have %q", want, have)
3433 }
35 value := func() float64 { return low.Value() }
36 if err := teststat.TestCounter(top, value); err != nil {
34 if err := teststat.TestCounter(top, low.Value); err != nil {
3735 t.Fatal(err)
3836 }
3937 }
3939
4040 have := map[string]float64{}
4141 s.Walk(func(name string, lvs LabelValues, obs []float64) bool {
42 //t.Logf("%s %v => %v", name, lvs, obs)
4342 have[name+" ["+strings.Join(lvs, "")+"]"] += sum(obs)
4443 return true
4544 })
1111 "strings"
1212 "testing"
1313
14 "github.com/go-kit/kit/metrics/teststat"
1415 stdprometheus "github.com/prometheus/client_golang/prometheus"
15
16 "github.com/go-kit/kit/metrics/teststat"
1716 )
1817
1918 func TestCounter(t *testing.T) {
197196 if !ok {
198197 t.Fatalf("expected error, got %s", reflect.TypeOf(x))
199198 }
200 if want, have := "inconsistent label cardinality", err.Error(); want != have {
199 if want, have := "inconsistent label cardinality", err.Error(); !strings.HasPrefix(have, want) {
201200 t.Fatalf("want %q, have %q", want, have)
202201 }
203202 }()
4646 re := regexp.MustCompile(regex)
4747 buf := &bytes.Buffer{}
4848 w.WriteTo(buf)
49 //fmt.Fprintf(os.Stderr, "%s\n", buf.String())
5049 s := bufio.NewScanner(buf)
5150 for s.Scan() {
5251 match := re.FindStringSubmatch(s.Text())
6363 z = math.Sqrt(-math.Log((1.0 - y) / 2.0))
6464 x = (((c[3]*z+c[2])*z+c[1])*z + c[0]) / ((d[1]*z+d[0])*z + 1.0)
6565 }
66 x = x - (math.Erf(x)-y)/(2.0/math.SqrtPi*math.Exp(-x*x))
67 x = x - (math.Erf(x)-y)/(2.0/math.SqrtPi*math.Exp(-x*x))
66 x -= (math.Erf(x) - y) / (2.0 / math.SqrtPi * math.Exp(-x*x))
67 x -= (math.Erf(x) - y) / (2.0 / math.SqrtPi * math.Exp(-x*x))
6868 }
6969
7070 return x
00 package consul
11
22 import (
3 "errors"
34 "fmt"
4 "io"
5 "time"
56
67 consul "github.com/hashicorp/consul/api"
78
89 "github.com/go-kit/kit/log"
910 "github.com/go-kit/kit/sd"
1011 "github.com/go-kit/kit/sd/internal/instance"
12 "github.com/go-kit/kit/util/conn"
1113 )
1214
1315 const defaultIndex = 0
16
17 // errStopped notifies the loop to quit. aka stopped via quitc
18 var errStopped = errors.New("quit and closed consul instancer")
1419
1520 // Instancer yields instances for a service in Consul.
1621 type Instancer struct {
5863 var (
5964 instances []string
6065 err error
66 d time.Duration = 10 * time.Millisecond
6167 )
6268 for {
6369 instances, lastIndex, err = s.getInstances(lastIndex, s.quitc)
6470 switch {
65 case err == io.EOF:
71 case err == errStopped:
6672 return // stopped via quitc
6773 case err != nil:
6874 s.logger.Log("err", err)
75 time.Sleep(d)
76 d = conn.Exponential(d)
6977 s.cache.Update(sd.Event{Err: err})
7078 default:
7179 s.cache.Update(sd.Event{Instances: instances})
80 d = 10 * time.Millisecond
7281 }
7382 }
7483 }
118127 case res := <-resc:
119128 return res.instances, res.index, nil
120129 case <-interruptc:
121 return nil, 0, io.EOF
130 return nil, 0, errStopped
122131 }
123132 }
124133
11
22 import (
33 "context"
4 consul "github.com/hashicorp/consul/api"
5 "io"
46 "testing"
5
6 consul "github.com/hashicorp/consul/api"
7 "time"
78
89 "github.com/go-kit/kit/log"
910 "github.com/go-kit/kit/sd"
130131 t.Errorf("want %q, have %q", want, have)
131132 }
132133 }
134
135 type eofTestClient struct {
136 client *testClient
137 eofSig chan bool
138 called chan struct{}
139 }
140
141 func neweofTestClient(client *testClient, sig chan bool, called chan struct{}) Client {
142 return &eofTestClient{client: client, eofSig: sig, called: called}
143 }
144
145 func (c *eofTestClient) Register(r *consul.AgentServiceRegistration) error {
146 return c.client.Register(r)
147 }
148
149 func (c *eofTestClient) Deregister(r *consul.AgentServiceRegistration) error {
150 return c.client.Deregister(r)
151 }
152
153 func (c *eofTestClient) Service(service, tag string, passingOnly bool, queryOpts *consul.QueryOptions) ([]*consul.ServiceEntry, *consul.QueryMeta, error) {
154 c.called <- struct{}{}
155 shouldEOF := <-c.eofSig
156 if shouldEOF {
157 return nil, &consul.QueryMeta{}, io.EOF
158 }
159 return c.client.Service(service, tag, passingOnly, queryOpts)
160 }
161
162 func TestInstancerWithEOF(t *testing.T) {
163 var (
164 sig = make(chan bool, 1)
165 called = make(chan struct{}, 1)
166 logger = log.NewNopLogger()
167 client = neweofTestClient(newTestClient(consulState), sig, called)
168 )
169
170 sig <- false
171 s := NewInstancer(client, logger, "search", []string{"api"}, true)
172 defer s.Stop()
173
174 select {
175 case <-called:
176 case <-time.Tick(time.Millisecond * 500):
177 t.Error("failed, to receive call")
178 }
179
180 state := s.cache.State()
181 if want, have := 2, len(state.Instances); want != have {
182 t.Errorf("want %d, have %d", want, have)
183 }
184
185 // some error occurred resulting in io.EOF
186 sig <- true
187
188 // Service Called Once
189 select {
190 case <-called:
191 case <-time.Tick(time.Millisecond * 500):
192 t.Error("failed, to receive call in time")
193 }
194
195 sig <- false
196
197 // loop should continue
198 select {
199 case <-called:
200 case <-time.Tick(time.Millisecond * 500):
201 t.Error("failed, to receive call in time")
202 }
203 }
1616 func TestIntegration(t *testing.T) {
1717 consulAddr := os.Getenv("CONSUL_ADDR")
1818 if consulAddr == "" {
19 t.Fatal("CONSUL_ADDR is not set")
19 t.Skip("CONSUL_ADDR not set; skipping integration test")
2020 }
2121 stdClient, err := stdconsul.NewClient(&stdconsul.Config{
2222 Address: consulAddr,
99 "net/http"
1010 "time"
1111
12 etcd "github.com/coreos/etcd/client"
12 etcd "go.etcd.io/etcd/client"
1313 )
1414
1515 var (
77
88 "golang.org/x/net/context"
99
10 etcd "github.com/coreos/etcd/client"
10 etcd "go.etcd.io/etcd/client"
1111 )
1212
1313 func TestNewClient(t *testing.T) {
33 "errors"
44 "testing"
55
6 stdetcd "github.com/coreos/etcd/client"
6 stdetcd "go.etcd.io/etcd/client"
77
88 "github.com/go-kit/kit/log"
99 "github.com/go-kit/kit/sd"
33 "sync"
44 "time"
55
6 etcd "github.com/coreos/etcd/client"
6 etcd "go.etcd.io/etcd/client"
77
88 "github.com/go-kit/kit/log"
99 )
55 "errors"
66 "time"
77
8 "github.com/coreos/etcd/clientv3"
9 "github.com/coreos/etcd/pkg/transport"
8 "go.etcd.io/etcd/clientv3"
9 "go.etcd.io/etcd/pkg/transport"
1010 )
1111
1212 var (
227227 }
228228 if c.watcher != nil {
229229 c.watcher.Close()
230 }
231 if c.wcf != nil {
230232 c.wcf()
231233 }
232234 }
3333
3434 // Register our instance.
3535 registrar.Register()
36 t.Logf("Registered")
36 t.Log("Registered")
3737
3838 // Retrieve entries from etcd manually.
3939 entries, err = client.GetEntries(settings.key)
5555 if err != nil {
5656 t.Fatalf("NewInstancer: %v", err)
5757 }
58 t.Logf("Constructed Instancer OK")
58 t.Log("Constructed Instancer OK")
5959 defer instancer.Stop()
6060
6161 endpointer := sd.NewEndpointer(
6363 func(string) (endpoint.Endpoint, io.Closer, error) { return endpoint.Nop, nil, nil },
6464 log.With(log.NewLogfmtLogger(os.Stderr), "component", "instancer"),
6565 )
66 t.Logf("Constructed Endpointer OK")
66 t.Log("Constructed Endpointer OK")
6767 defer endpointer.Close()
6868
6969 if !within(time.Second, func() bool {
7070 endpoints, err := endpointer.Endpoints()
7171 return err == nil && len(endpoints) == 1
7272 }) {
73 t.Fatalf("Endpointer didn't see Register in time")
74 }
75 t.Logf("Endpointer saw Register OK")
73 t.Fatal("Endpointer didn't see Register in time")
74 }
75 t.Log("Endpointer saw Register OK")
7676
7777 // Deregister first instance of test data.
7878 registrar.Deregister()
79 t.Logf("Deregistered")
79 t.Log("Deregistered")
8080
8181 // Check it was deregistered.
8282 if !within(time.Second, func() bool {
163163 runIntegration(settings, client, service, t)
164164 }
165165
166 func TestIntegrationRegistrarOnly(t *testing.T) {
167 settings := testIntegrationSettings(t)
168 client, err := NewClient(context.Background(), []string{settings.addr}, ClientOptions{
169 DialTimeout: 2 * time.Second,
170 DialKeepAlive: 2 * time.Second,
171 })
172 if err != nil {
173 t.Fatalf("NewClient(%q): %v", settings.addr, err)
174 }
175
176 service := Service{
177 Key: settings.key,
178 Value: settings.value,
179 TTL: NewTTLOption(time.Second*3, time.Second*10),
180 }
181 defer client.Deregister(service)
182
183 // Verify test data is initially empty.
184 entries, err := client.GetEntries(settings.key)
185 if err != nil {
186 t.Fatalf("GetEntries(%q): expected no error, got one: %v", settings.key, err)
187 }
188 if len(entries) > 0 {
189 t.Fatalf("GetEntries(%q): expected no instance entries, got %d", settings.key, len(entries))
190 }
191 t.Logf("GetEntries(%q): %v (OK)", settings.key, entries)
192
193 // Instantiate a new Registrar, passing in test data.
194 registrar := NewRegistrar(
195 client,
196 service,
197 log.With(log.NewLogfmtLogger(os.Stderr), "component", "registrar"),
198 )
199
200 // Register our instance.
201 registrar.Register()
202 t.Log("Registered")
203
204 // Deregister our instance. (so we test registrar only scenario)
205 registrar.Deregister()
206 t.Log("Deregistered")
207
208 }
209
166210 func within(d time.Duration, f func() bool) bool {
167211 deadline := time.Now().Add(d)
168212 for time.Now().Before(deadline) {
33
44 import (
55 "bytes"
6 "log"
76 "os"
87 "testing"
98 "time"
1716
1817 func TestMain(m *testing.M) {
1918 zkAddr := os.Getenv("ZK_ADDR")
20 if zkAddr == "" {
21 log.Fatal("ZK_ADDR is not set")
22 }
23 host = []string{zkAddr}
19 if zkAddr != "" {
20 host = []string{zkAddr}
21 }
22 m.Run()
2423 }
2524
2625 func TestCreateParentNodesOnServer(t *testing.T) {
26 if len(host) == 0 {
27 t.Skip("ZK_ADDR not set; skipping integration test")
28 }
2729 payload := [][]byte{[]byte("Payload"), []byte("Test")}
2830 c1, err := NewClient(host, logger, Payload(payload))
2931 if err != nil {
6668 }
6769
6870 func TestCreateBadParentNodesOnServer(t *testing.T) {
71 if len(host) == 0 {
72 t.Skip("ZK_ADDR not set; skipping integration test")
73 }
6974 c, _ := NewClient(host, logger)
7075 defer c.Stop()
7176
7782 }
7883
7984 func TestCredentials1(t *testing.T) {
85 if len(host) == 0 {
86 t.Skip("ZK_ADDR not set; skipping integration test")
87 }
8088 acl := stdzk.DigestACL(stdzk.PermAll, "user", "secret")
8189 c, _ := NewClient(host, logger, ACL(acl), Credentials("user", "secret"))
8290 defer c.Stop()
8997 }
9098
9199 func TestCredentials2(t *testing.T) {
100 if len(host) == 0 {
101 t.Skip("ZK_ADDR not set; skipping integration test")
102 }
92103 acl := stdzk.DigestACL(stdzk.PermAll, "user", "secret")
93104 c, _ := NewClient(host, logger, ACL(acl))
94105 defer c.Stop()
101112 }
102113
103114 func TestConnection(t *testing.T) {
115 if len(host) == 0 {
116 t.Skip("ZK_ADDR not set; skipping integration test")
117 }
104118 c, _ := NewClient(host, logger)
105119 c.Stop()
106120
112126 }
113127
114128 func TestGetEntriesOnServer(t *testing.T) {
129 if len(host) == 0 {
130 t.Skip("ZK_ADDR not set; skipping integration test")
131 }
115132 var instancePayload = "10.0.3.204:8002"
116133
117134 c1, err := NewClient(host, logger)
157174 }
158175
159176 func TestGetEntriesPayloadOnServer(t *testing.T) {
177 if len(host) == 0 {
178 t.Skip("ZK_ADDR not set; skipping integration test")
179 }
160180 c, err := NewClient(host, logger)
161181 if err != nil {
162182 t.Fatalf("Connect returned error: %v", err)
1818 binding to use. Instrumentation exists for `kit/transport/http` and
1919 `kit/transport/grpc`. The bindings are highlighted in the [addsvc] example. For
2020 more information regarding Zipkin feel free to visit [Zipkin's Gitter].
21
22 ## OpenCensus
23
24 Go kit supports transport and endpoint middlewares for the [OpenCensus]
25 instrumentation library. OpenCensus provides a cross language consistent data
26 model and instrumentation libraries for tracing and metrics. From this data
27 model it allows exports to various tracing and metrics backends including but
28 not limited to Zipkin, Prometheus, Stackdriver Trace & Monitoring, Jaeger,
29 AWS X-Ray and Datadog. Go kit uses the [opencensus-go] implementation to power
30 its middlewares.
2131
2232 ## OpenTracing
2333
6171
6272 [Zipkin]: http://zipkin.io/
6373 [Open Zipkin GitHub]: https://github.com/openzipkin
64 [zipkin-go-opentracing]: https://github.com/openzipkin/zipkin-go-opentracing
74 [zipkin-go-opentracing]: https://github.com/openzipkin-contrib/zipkin-go-opentracing
6575 [zipkin-go]: https://github.com/openzipkin/zipkin-go
6676 [Zipkin's Gitter]: https://gitter.im/openzipkin/zipkin
6777
7080
7181 [LightStep]: http://lightstep.com/
7282 [lightstep-tracer-go]: https://github.com/lightstep/lightstep-tracer-go
83
84 [OpenCensus]: https://opencensus.io/
85 [opencensus-go]: https://github.com/census-instrumentation/opencensus-go
0 // Package opencensus provides Go kit integration to the OpenCensus project.
1 // OpenCensus is a single distribution of libraries for metrics and distributed
2 // tracing with minimal overhead that allows you to export data to multiple
3 // backends. The Go kit OpenCencus package as provided here contains middlewares
4 // for tracing.
5 package opencensus
0 package opencensus
1
2 import (
3 "context"
4 "strconv"
5
6 "go.opencensus.io/trace"
7
8 "github.com/go-kit/kit/endpoint"
9 "github.com/go-kit/kit/sd/lb"
10 )
11
12 // TraceEndpointDefaultName is the default endpoint span name to use.
13 const TraceEndpointDefaultName = "gokit/endpoint"
14
15 // TraceEndpoint returns an Endpoint middleware, tracing a Go kit endpoint.
16 // This endpoint tracer should be used in combination with a Go kit Transport
17 // tracing middleware, generic OpenCensus transport middleware or custom before
18 // and after transport functions as service propagation of SpanContext is not
19 // provided in this middleware.
20 func TraceEndpoint(name string, options ...EndpointOption) endpoint.Middleware {
21 if name == "" {
22 name = TraceEndpointDefaultName
23 }
24
25 cfg := &EndpointOptions{}
26
27 for _, o := range options {
28 o(cfg)
29 }
30
31 return func(next endpoint.Endpoint) endpoint.Endpoint {
32 return func(ctx context.Context, request interface{}) (response interface{}, err error) {
33 ctx, span := trace.StartSpan(ctx, name)
34 if len(cfg.Attributes) > 0 {
35 span.AddAttributes(cfg.Attributes...)
36 }
37 defer span.End()
38
39 defer func() {
40 if err != nil {
41 if lberr, ok := err.(lb.RetryError); ok {
42 // handle errors originating from lb.Retry
43 attrs := make([]trace.Attribute, 0, len(lberr.RawErrors))
44 for idx, rawErr := range lberr.RawErrors {
45 attrs = append(attrs, trace.StringAttribute(
46 "gokit.retry.error."+strconv.Itoa(idx+1), rawErr.Error(),
47 ))
48 }
49 span.AddAttributes(attrs...)
50 span.SetStatus(trace.Status{
51 Code: trace.StatusCodeUnknown,
52 Message: lberr.Final.Error(),
53 })
54 return
55 }
56 // generic error
57 span.SetStatus(trace.Status{
58 Code: trace.StatusCodeUnknown,
59 Message: err.Error(),
60 })
61 return
62 }
63
64 // test for business error
65 if res, ok := response.(endpoint.Failer); ok && res.Failed() != nil {
66 span.AddAttributes(
67 trace.StringAttribute("gokit.business.error", res.Failed().Error()),
68 )
69 if cfg.IgnoreBusinessError {
70 span.SetStatus(trace.Status{Code: trace.StatusCodeOK})
71 return
72 }
73 // treating business error as real error in span.
74 span.SetStatus(trace.Status{
75 Code: trace.StatusCodeUnknown,
76 Message: res.Failed().Error(),
77 })
78 return
79 }
80
81 // no errors identified
82 span.SetStatus(trace.Status{Code: trace.StatusCodeOK})
83 }()
84 response, err = next(ctx, request)
85 return
86 }
87 }
88 }
0 package opencensus
1
2 import "go.opencensus.io/trace"
3
4 // EndpointOptions holds the options for tracing an endpoint
5 type EndpointOptions struct {
6 // IgnoreBusinessError if set to true will not treat a business error
7 // identified through the endpoint.Failer interface as a span error.
8 IgnoreBusinessError bool
9
10 // Attributes holds the default attributes which will be set on span
11 // creation by our Endpoint middleware.
12 Attributes []trace.Attribute
13 }
14
15 // EndpointOption allows for functional options to our OpenCensus endpoint
16 // tracing middleware.
17 type EndpointOption func(*EndpointOptions)
18
19 // WithEndpointConfig sets all configuration options at once by use of the
20 // EndpointOptions struct.
21 func WithEndpointConfig(options EndpointOptions) EndpointOption {
22 return func(o *EndpointOptions) {
23 *o = options
24 }
25 }
26
27 // WithEndpointAttributes sets the default attributes for the spans created by
28 // the Endpoint tracer.
29 func WithEndpointAttributes(attrs ...trace.Attribute) EndpointOption {
30 return func(o *EndpointOptions) {
31 o.Attributes = attrs
32 }
33 }
34
35 // WithIgnoreBusinessError if set to true will not treat a business error
36 // identified through the endpoint.Failer interface as a span error.
37 func WithIgnoreBusinessError(val bool) EndpointOption {
38 return func(o *EndpointOptions) {
39 o.IgnoreBusinessError = val
40 }
41 }
0 package opencensus_test
1
2 import (
3 "context"
4 "errors"
5 "testing"
6 "time"
7
8 "go.opencensus.io/trace"
9
10 "github.com/go-kit/kit/endpoint"
11 "github.com/go-kit/kit/sd"
12 "github.com/go-kit/kit/sd/lb"
13 "github.com/go-kit/kit/tracing/opencensus"
14 )
15
16 const (
17 span1 = ""
18 span2 = "SPAN-2"
19 span3 = "SPAN-3"
20 span4 = "SPAN-4"
21 span5 = "SPAN-5"
22 )
23
24 var (
25 err1 = errors.New("some error")
26 err2 = errors.New("other error")
27 err3 = errors.New("some business error")
28 err4 = errors.New("other business error")
29 )
30
31 // compile time assertion
32 var _ endpoint.Failer = failedResponse{}
33
34 type failedResponse struct {
35 err error
36 }
37
38 func (r failedResponse) Failed() error { return r.err }
39
40 func passEndpoint(_ context.Context, req interface{}) (interface{}, error) {
41 if err, _ := req.(error); err != nil {
42 return nil, err
43 }
44 return req, nil
45 }
46
47 func TestTraceEndpoint(t *testing.T) {
48 ctx := context.Background()
49
50 e := &recordingExporter{}
51 trace.RegisterExporter(e)
52 trace.ApplyConfig(trace.Config{DefaultSampler: trace.AlwaysSample()})
53
54 // span 1
55 span1Attrs := []trace.Attribute{
56 trace.StringAttribute("string", "value"),
57 trace.Int64Attribute("int64", 42),
58 }
59 mw := opencensus.TraceEndpoint(
60 span1, opencensus.WithEndpointAttributes(span1Attrs...),
61 )
62 mw(endpoint.Nop)(ctx, nil)
63
64 // span 2
65 opts := opencensus.EndpointOptions{}
66 mw = opencensus.TraceEndpoint(span2, opencensus.WithEndpointConfig(opts))
67 mw(passEndpoint)(ctx, err1)
68
69 // span3
70 mw = opencensus.TraceEndpoint(span3)
71 ep := lb.Retry(5, 1*time.Second, lb.NewRoundRobin(sd.FixedEndpointer{passEndpoint}))
72 mw(ep)(ctx, err2)
73
74 // span4
75 mw = opencensus.TraceEndpoint(span4)
76 mw(passEndpoint)(ctx, failedResponse{err: err3})
77
78 // span4
79 mw = opencensus.TraceEndpoint(span5, opencensus.WithIgnoreBusinessError(true))
80 mw(passEndpoint)(ctx, failedResponse{err: err4})
81
82 // check span count
83 spans := e.Flush()
84 if want, have := 5, len(spans); want != have {
85 t.Fatalf("incorrected number of spans, wanted %d, got %d", want, have)
86 }
87
88 // test span 1
89 span := spans[0]
90 if want, have := int32(trace.StatusCodeOK), span.Code; want != have {
91 t.Errorf("incorrect status code, wanted %d, got %d", want, have)
92 }
93
94 if want, have := opencensus.TraceEndpointDefaultName, span.Name; want != have {
95 t.Errorf("incorrect span name, wanted %q, got %q", want, have)
96 }
97
98 if want, have := 2, len(span.Attributes); want != have {
99 t.Fatalf("incorrect attribute count, wanted %d, got %d", want, have)
100 }
101
102 // test span 2
103 span = spans[1]
104 if want, have := int32(trace.StatusCodeUnknown), span.Code; want != have {
105 t.Errorf("incorrect status code, wanted %d, got %d", want, have)
106 }
107
108 if want, have := span2, span.Name; want != have {
109 t.Errorf("incorrect span name, wanted %q, got %q", want, have)
110 }
111
112 if want, have := 0, len(span.Attributes); want != have {
113 t.Fatalf("incorrect attribute count, wanted %d, got %d", want, have)
114 }
115
116 // test span 3
117 span = spans[2]
118 if want, have := int32(trace.StatusCodeUnknown), span.Code; want != have {
119 t.Errorf("incorrect status code, wanted %d, got %d", want, have)
120 }
121
122 if want, have := span3, span.Name; want != have {
123 t.Errorf("incorrect span name, wanted %q, got %q", want, have)
124 }
125
126 if want, have := 5, len(span.Attributes); want != have {
127 t.Fatalf("incorrect attribute count, wanted %d, got %d", want, have)
128 }
129
130 // test span 4
131 span = spans[3]
132 if want, have := int32(trace.StatusCodeUnknown), span.Code; want != have {
133 t.Errorf("incorrect status code, wanted %d, got %d", want, have)
134 }
135
136 if want, have := span4, span.Name; want != have {
137 t.Errorf("incorrect span name, wanted %q, got %q", want, have)
138 }
139
140 if want, have := 1, len(span.Attributes); want != have {
141 t.Fatalf("incorrect attribute count, wanted %d, got %d", want, have)
142 }
143
144 // test span 5
145 span = spans[4]
146 if want, have := int32(trace.StatusCodeOK), span.Code; want != have {
147 t.Errorf("incorrect status code, wanted %d, got %d", want, have)
148 }
149
150 if want, have := span5, span.Name; want != have {
151 t.Errorf("incorrect span name, wanted %q, got %q", want, have)
152 }
153
154 if want, have := 1, len(span.Attributes); want != have {
155 t.Fatalf("incorrect attribute count, wanted %d, got %d", want, have)
156 }
157
158 }
0 package opencensus
1
2 import (
3 "context"
4
5 "go.opencensus.io/trace"
6 "go.opencensus.io/trace/propagation"
7 "google.golang.org/grpc/codes"
8 "google.golang.org/grpc/metadata"
9 "google.golang.org/grpc/status"
10
11 kitgrpc "github.com/go-kit/kit/transport/grpc"
12 )
13
14 const propagationKey = "grpc-trace-bin"
15
16 // GRPCClientTrace enables OpenCensus tracing of a Go kit gRPC transport client.
17 func GRPCClientTrace(options ...TracerOption) kitgrpc.ClientOption {
18 cfg := TracerOptions{}
19
20 for _, option := range options {
21 option(&cfg)
22 }
23
24 if cfg.Sampler == nil {
25 cfg.Sampler = trace.AlwaysSample()
26 }
27
28 clientBefore := kitgrpc.ClientBefore(
29 func(ctx context.Context, md *metadata.MD) context.Context {
30 var name string
31
32 if cfg.Name != "" {
33 name = cfg.Name
34 } else {
35 name = ctx.Value(kitgrpc.ContextKeyRequestMethod).(string)
36 }
37
38 ctx, span := trace.StartSpan(
39 ctx,
40 name,
41 trace.WithSampler(cfg.Sampler),
42 trace.WithSpanKind(trace.SpanKindClient),
43 )
44
45 if !cfg.Public {
46 traceContextBinary := string(propagation.Binary(span.SpanContext()))
47 (*md)[propagationKey] = append((*md)[propagationKey], traceContextBinary)
48 }
49
50 return ctx
51 },
52 )
53
54 clientFinalizer := kitgrpc.ClientFinalizer(
55 func(ctx context.Context, err error) {
56 if span := trace.FromContext(ctx); span != nil {
57 if s, ok := status.FromError(err); ok {
58 span.SetStatus(trace.Status{Code: int32(s.Code()), Message: s.Message()})
59 } else {
60 span.SetStatus(trace.Status{Code: int32(codes.Unknown), Message: err.Error()})
61 }
62 span.End()
63 }
64 },
65 )
66
67 return func(c *kitgrpc.Client) {
68 clientBefore(c)
69 clientFinalizer(c)
70 }
71 }
72
73 // GRPCServerTrace enables OpenCensus tracing of a Go kit gRPC transport server.
74 func GRPCServerTrace(options ...TracerOption) kitgrpc.ServerOption {
75 cfg := TracerOptions{}
76
77 for _, option := range options {
78 option(&cfg)
79 }
80
81 if cfg.Sampler == nil {
82 cfg.Sampler = trace.AlwaysSample()
83 }
84
85 serverBefore := kitgrpc.ServerBefore(
86 func(ctx context.Context, md metadata.MD) context.Context {
87 var name string
88
89 if cfg.Name != "" {
90 name = cfg.Name
91 } else {
92 name, _ = ctx.Value(kitgrpc.ContextKeyRequestMethod).(string)
93 if name == "" {
94 // we can't find the gRPC method. probably the
95 // unaryInterceptor was not wired up.
96 name = "unknown grpc method"
97 }
98 }
99
100 var (
101 parentContext trace.SpanContext
102 traceContext = md[propagationKey]
103 ok bool
104 )
105
106 if len(traceContext) > 0 {
107 traceContextBinary := []byte(traceContext[0])
108 parentContext, ok = propagation.FromBinary(traceContextBinary)
109 if ok && !cfg.Public {
110 ctx, _ = trace.StartSpanWithRemoteParent(
111 ctx,
112 name,
113 parentContext,
114 trace.WithSpanKind(trace.SpanKindServer),
115 trace.WithSampler(cfg.Sampler),
116 )
117 return ctx
118 }
119 }
120 ctx, span := trace.StartSpan(
121 ctx,
122 name,
123 trace.WithSpanKind(trace.SpanKindServer),
124 trace.WithSampler(cfg.Sampler),
125 )
126 if ok {
127 span.AddLink(
128 trace.Link{
129 TraceID: parentContext.TraceID,
130 SpanID: parentContext.SpanID,
131 Type: trace.LinkTypeChild,
132 },
133 )
134 }
135 return ctx
136 },
137 )
138
139 serverFinalizer := kitgrpc.ServerFinalizer(
140 func(ctx context.Context, err error) {
141 if span := trace.FromContext(ctx); span != nil {
142 if s, ok := status.FromError(err); ok {
143 span.SetStatus(trace.Status{Code: int32(s.Code()), Message: s.Message()})
144 } else {
145 span.SetStatus(trace.Status{Code: int32(codes.Internal), Message: err.Error()})
146 }
147 span.End()
148 }
149 },
150 )
151
152 return func(s *kitgrpc.Server) {
153 serverBefore(s)
154 serverFinalizer(s)
155 }
156 }
0 package opencensus_test
1
2 import (
3 "context"
4 "errors"
5 "testing"
6
7 "go.opencensus.io/trace"
8 "go.opencensus.io/trace/propagation"
9 "google.golang.org/grpc"
10 "google.golang.org/grpc/codes"
11 "google.golang.org/grpc/metadata"
12
13 "github.com/go-kit/kit/endpoint"
14 ockit "github.com/go-kit/kit/tracing/opencensus"
15 grpctransport "github.com/go-kit/kit/transport/grpc"
16 )
17
18 type dummy struct{}
19
20 const traceContextKey = "grpc-trace-bin"
21
22 func unaryInterceptor(
23 ctx context.Context, method string, req, reply interface{},
24 cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption,
25 ) error {
26 return nil
27 }
28
29 func TestGRPCClientTrace(t *testing.T) {
30 rec := &recordingExporter{}
31
32 trace.RegisterExporter(rec)
33
34 cc, err := grpc.Dial(
35 "",
36 grpc.WithUnaryInterceptor(unaryInterceptor),
37 grpc.WithInsecure(),
38 )
39 if err != nil {
40 t.Fatalf("unable to create gRPC dialer: %s", err.Error())
41 }
42
43 traces := []struct {
44 name string
45 err error
46 }{
47 {"", nil},
48 {"CustomName", nil},
49 {"", errors.New("dummy-error")},
50 }
51
52 for _, tr := range traces {
53 clientTracer := ockit.GRPCClientTrace(ockit.WithName(tr.name))
54
55 ep := grpctransport.NewClient(
56 cc,
57 "dummyService",
58 "dummyMethod",
59 func(context.Context, interface{}) (interface{}, error) {
60 return nil, nil
61 },
62 func(context.Context, interface{}) (interface{}, error) {
63 return nil, tr.err
64 },
65 dummy{},
66 clientTracer,
67 ).Endpoint()
68
69 ctx, parentSpan := trace.StartSpan(context.Background(), "test")
70
71 _, err = ep(ctx, nil)
72 if want, have := tr.err, err; want != have {
73 t.Fatalf("unexpected error, want %s, have %s", tr.err.Error(), err.Error())
74 }
75
76 spans := rec.Flush()
77 if want, have := 1, len(spans); want != have {
78 t.Fatalf("incorrect number of spans, want %d, have %d", want, have)
79 }
80 span := spans[0]
81 if want, have := parentSpan.SpanContext().SpanID, span.ParentSpanID; want != have {
82 t.Errorf("incorrect parent ID, want %s, have %s", want, have)
83 }
84
85 if want, have := tr.name, span.Name; want != have && want != "" {
86 t.Errorf("incorrect span name, want %s, have %s", want, have)
87 }
88
89 if want, have := "/dummyService/dummyMethod", span.Name; want != have && tr.name == "" {
90 t.Errorf("incorrect span name, want %s, have %s", want, have)
91 }
92
93 code := trace.StatusCodeOK
94 if tr.err != nil {
95 code = trace.StatusCodeUnknown
96
97 if want, have := err.Error(), span.Status.Message; want != have {
98 t.Errorf("incorrect span status msg, want %s, have %s", want, have)
99 }
100 }
101
102 if want, have := int32(code), span.Status.Code; want != have {
103 t.Errorf("incorrect span status code, want %d, have %d", want, have)
104 }
105 }
106 }
107
108 func TestGRPCServerTrace(t *testing.T) {
109 rec := &recordingExporter{}
110
111 trace.RegisterExporter(rec)
112
113 traces := []struct {
114 useParent bool
115 name string
116 err error
117 }{
118 {false, "", nil},
119 {true, "", nil},
120 {true, "CustomName", nil},
121 {true, "", errors.New("dummy-error")},
122 }
123
124 for _, tr := range traces {
125 var (
126 ctx = context.Background()
127 parentSpan *trace.Span
128 )
129
130 server := grpctransport.NewServer(
131 endpoint.Nop,
132 func(context.Context, interface{}) (interface{}, error) {
133 return nil, nil
134 },
135 func(context.Context, interface{}) (interface{}, error) {
136 return nil, tr.err
137 },
138 ockit.GRPCServerTrace(ockit.WithName(tr.name)),
139 )
140
141 if tr.useParent {
142 _, parentSpan = trace.StartSpan(context.Background(), "test")
143 traceContextBinary := propagation.Binary(parentSpan.SpanContext())
144
145 md := metadata.MD{}
146 md.Set(traceContextKey, string(traceContextBinary))
147 ctx = metadata.NewIncomingContext(ctx, md)
148 }
149
150 server.ServeGRPC(ctx, nil)
151
152 spans := rec.Flush()
153
154 if want, have := 1, len(spans); want != have {
155 t.Fatalf("incorrect number of spans, want %d, have %d", want, have)
156 }
157
158 if tr.useParent {
159 if want, have := parentSpan.SpanContext().TraceID, spans[0].TraceID; want != have {
160 t.Errorf("incorrect trace ID, want %s, have %s", want, have)
161 }
162
163 if want, have := parentSpan.SpanContext().SpanID, spans[0].ParentSpanID; want != have {
164 t.Errorf("incorrect span ID, want %s, have %s", want, have)
165 }
166 }
167
168 if want, have := tr.name, spans[0].Name; want != have && want != "" {
169 t.Errorf("incorrect span name, want %s, have %s", want, have)
170 }
171
172 if tr.err != nil {
173 if want, have := int32(codes.Internal), spans[0].Status.Code; want != have {
174 t.Errorf("incorrect span status code, want %d, have %d", want, have)
175 }
176
177 if want, have := tr.err.Error(), spans[0].Status.Message; want != have {
178 t.Errorf("incorrect span status message, want %s, have %s", want, have)
179 }
180 }
181 }
182 }
0 package opencensus
1
2 import (
3 "context"
4 "net/http"
5
6 "go.opencensus.io/plugin/ochttp"
7 "go.opencensus.io/plugin/ochttp/propagation/b3"
8 "go.opencensus.io/trace"
9
10 kithttp "github.com/go-kit/kit/transport/http"
11 )
12
13 // HTTPClientTrace enables OpenCensus tracing of a Go kit HTTP transport client.
14 func HTTPClientTrace(options ...TracerOption) kithttp.ClientOption {
15 cfg := TracerOptions{}
16
17 for _, option := range options {
18 option(&cfg)
19 }
20
21 if cfg.Sampler == nil {
22 cfg.Sampler = trace.AlwaysSample()
23 }
24
25 if !cfg.Public && cfg.HTTPPropagate == nil {
26 cfg.HTTPPropagate = &b3.HTTPFormat{}
27 }
28
29 clientBefore := kithttp.ClientBefore(
30 func(ctx context.Context, req *http.Request) context.Context {
31 var name string
32
33 if cfg.Name != "" {
34 name = cfg.Name
35 } else {
36 // OpenCensus states Path being default naming for a client span
37 name = req.Method + " " + req.URL.Path
38 }
39
40 ctx, span := trace.StartSpan(
41 ctx,
42 name,
43 trace.WithSampler(cfg.Sampler),
44 trace.WithSpanKind(trace.SpanKindClient),
45 )
46
47 span.AddAttributes(
48 trace.StringAttribute(ochttp.HostAttribute, req.URL.Host),
49 trace.StringAttribute(ochttp.MethodAttribute, req.Method),
50 trace.StringAttribute(ochttp.PathAttribute, req.URL.Path),
51 trace.StringAttribute(ochttp.UserAgentAttribute, req.UserAgent()),
52 )
53
54 if !cfg.Public {
55 cfg.HTTPPropagate.SpanContextToRequest(span.SpanContext(), req)
56 }
57
58 return ctx
59 },
60 )
61
62 clientAfter := kithttp.ClientAfter(
63 func(ctx context.Context, res *http.Response) context.Context {
64 if span := trace.FromContext(ctx); span != nil {
65 span.SetStatus(ochttp.TraceStatus(res.StatusCode, http.StatusText(res.StatusCode)))
66 span.AddAttributes(
67 trace.Int64Attribute(ochttp.StatusCodeAttribute, int64(res.StatusCode)),
68 )
69 }
70 return ctx
71 },
72 )
73
74 clientFinalizer := kithttp.ClientFinalizer(
75 func(ctx context.Context, err error) {
76 if span := trace.FromContext(ctx); span != nil {
77 if err != nil {
78 span.SetStatus(trace.Status{
79 Code: trace.StatusCodeUnknown,
80 Message: err.Error(),
81 })
82 }
83 span.End()
84 }
85 },
86 )
87
88 return func(c *kithttp.Client) {
89 clientBefore(c)
90 clientAfter(c)
91 clientFinalizer(c)
92 }
93 }
94
95 // HTTPServerTrace enables OpenCensus tracing of a Go kit HTTP transport server.
96 func HTTPServerTrace(options ...TracerOption) kithttp.ServerOption {
97 cfg := TracerOptions{}
98
99 for _, option := range options {
100 option(&cfg)
101 }
102
103 if cfg.Sampler == nil {
104 cfg.Sampler = trace.AlwaysSample()
105 }
106
107 if !cfg.Public && cfg.HTTPPropagate == nil {
108 cfg.HTTPPropagate = &b3.HTTPFormat{}
109 }
110
111 serverBefore := kithttp.ServerBefore(
112 func(ctx context.Context, req *http.Request) context.Context {
113 var (
114 spanContext trace.SpanContext
115 span *trace.Span
116 name string
117 ok bool
118 )
119
120 if cfg.Name != "" {
121 name = cfg.Name
122 } else {
123 name = req.Method + " " + req.URL.Path
124 }
125
126 spanContext, ok = cfg.HTTPPropagate.SpanContextFromRequest(req)
127 if ok && !cfg.Public {
128 ctx, span = trace.StartSpanWithRemoteParent(
129 ctx,
130 name,
131 spanContext,
132 trace.WithSpanKind(trace.SpanKindServer),
133 trace.WithSampler(cfg.Sampler),
134 )
135 } else {
136 ctx, span = trace.StartSpan(
137 ctx,
138 name,
139 trace.WithSpanKind(trace.SpanKindServer),
140 trace.WithSampler(cfg.Sampler),
141 )
142 if ok {
143 span.AddLink(trace.Link{
144 TraceID: spanContext.TraceID,
145 SpanID: spanContext.SpanID,
146 Type: trace.LinkTypeChild,
147 Attributes: nil,
148 })
149 }
150 }
151
152 span.AddAttributes(
153 trace.StringAttribute(ochttp.MethodAttribute, req.Method),
154 trace.StringAttribute(ochttp.PathAttribute, req.URL.Path),
155 )
156
157 return ctx
158 },
159 )
160
161 serverFinalizer := kithttp.ServerFinalizer(
162 func(ctx context.Context, code int, r *http.Request) {
163 if span := trace.FromContext(ctx); span != nil {
164 span.SetStatus(ochttp.TraceStatus(code, http.StatusText(code)))
165
166 if rs, ok := ctx.Value(kithttp.ContextKeyResponseSize).(int64); ok {
167 span.AddAttributes(
168 trace.Int64Attribute("http.response_size", rs),
169 )
170 }
171
172 span.End()
173 }
174 },
175 )
176
177 return func(s *kithttp.Server) {
178 serverBefore(s)
179 serverFinalizer(s)
180 }
181 }
0 package opencensus_test
1
2 import (
3 "context"
4 "errors"
5 "net/http"
6 "net/http/httptest"
7 "net/url"
8 "testing"
9
10 "go.opencensus.io/plugin/ochttp"
11 "go.opencensus.io/plugin/ochttp/propagation/b3"
12 "go.opencensus.io/plugin/ochttp/propagation/tracecontext"
13 "go.opencensus.io/trace"
14 "go.opencensus.io/trace/propagation"
15
16 "github.com/go-kit/kit/endpoint"
17 ockit "github.com/go-kit/kit/tracing/opencensus"
18 kithttp "github.com/go-kit/kit/transport/http"
19 )
20
21 func TestHTTPClientTrace(t *testing.T) {
22 var (
23 err error
24 rec = &recordingExporter{}
25 rURL, _ = url.Parse("http://test.com/dummy/path")
26 )
27
28 trace.RegisterExporter(rec)
29
30 traces := []struct {
31 name string
32 err error
33 }{
34 {"", nil},
35 {"CustomName", nil},
36 {"", errors.New("dummy-error")},
37 }
38
39 for _, tr := range traces {
40 clientTracer := ockit.HTTPClientTrace(ockit.WithName(tr.name))
41 ep := kithttp.NewClient(
42 "GET",
43 rURL,
44 func(ctx context.Context, r *http.Request, i interface{}) error {
45 return nil
46 },
47 func(ctx context.Context, r *http.Response) (response interface{}, err error) {
48 return nil, tr.err
49 },
50 clientTracer,
51 ).Endpoint()
52
53 ctx, parentSpan := trace.StartSpan(context.Background(), "test")
54
55 _, err = ep(ctx, nil)
56 if want, have := tr.err, err; want != have {
57 t.Fatalf("unexpected error, want %s, have %s", tr.err.Error(), err.Error())
58 }
59
60 spans := rec.Flush()
61 if want, have := 1, len(spans); want != have {
62 t.Fatalf("incorrect number of spans, want %d, have %d", want, have)
63 }
64
65 span := spans[0]
66 if want, have := parentSpan.SpanContext().SpanID, span.ParentSpanID; want != have {
67 t.Errorf("incorrect parent ID, want %s, have %s", want, have)
68 }
69
70 if want, have := tr.name, span.Name; want != have && want != "" {
71 t.Errorf("incorrect span name, want %s, have %s", want, have)
72 }
73
74 if want, have := "GET /dummy/path", span.Name; want != have && tr.name == "" {
75 t.Errorf("incorrect span name, want %s, have %s", want, have)
76 }
77
78 code := trace.StatusCodeOK
79 if tr.err != nil {
80 code = trace.StatusCodeUnknown
81
82 if want, have := err.Error(), span.Status.Message; want != have {
83 t.Errorf("incorrect span status msg, want %s, have %s", want, have)
84 }
85 }
86
87 if want, have := int32(code), span.Status.Code; want != have {
88 t.Errorf("incorrect span status code, want %d, have %d", want, have)
89 }
90 }
91 }
92
93 func TestHTTPServerTrace(t *testing.T) {
94 rec := &recordingExporter{}
95
96 trace.RegisterExporter(rec)
97
98 traces := []struct {
99 useParent bool
100 name string
101 err error
102 propagation propagation.HTTPFormat
103 }{
104 {false, "", nil, nil},
105 {true, "", nil, nil},
106 {true, "CustomName", nil, &b3.HTTPFormat{}},
107 {true, "", errors.New("dummy-error"), &tracecontext.HTTPFormat{}},
108 }
109
110 for _, tr := range traces {
111 var client http.Client
112
113 handler := kithttp.NewServer(
114 endpoint.Nop,
115 func(context.Context, *http.Request) (interface{}, error) { return nil, nil },
116 func(context.Context, http.ResponseWriter, interface{}) error { return errors.New("dummy") },
117 ockit.HTTPServerTrace(
118 ockit.WithName(tr.name),
119 ockit.WithHTTPPropagation(tr.propagation),
120 ),
121 )
122
123 server := httptest.NewServer(handler)
124 defer server.Close()
125
126 const httpMethod = "GET"
127
128 req, err := http.NewRequest(httpMethod, server.URL, nil)
129 if err != nil {
130 t.Fatalf("unable to create HTTP request: %s", err.Error())
131 }
132
133 if tr.useParent {
134 client = http.Client{
135 Transport: &ochttp.Transport{
136 Propagation: tr.propagation,
137 },
138 }
139 }
140
141 resp, err := client.Do(req.WithContext(context.Background()))
142 if err != nil {
143 t.Fatalf("unable to send HTTP request: %s", err.Error())
144 }
145 resp.Body.Close()
146
147 spans := rec.Flush()
148
149 expectedSpans := 1
150 if tr.useParent {
151 expectedSpans++
152 }
153
154 if want, have := expectedSpans, len(spans); want != have {
155 t.Fatalf("incorrect number of spans, want %d, have %d", want, have)
156 }
157
158 if tr.useParent {
159 if want, have := spans[1].TraceID, spans[0].TraceID; want != have {
160 t.Errorf("incorrect trace ID, want %s, have %s", want, have)
161 }
162
163 if want, have := spans[1].SpanID, spans[0].ParentSpanID; want != have {
164 t.Errorf("incorrect span ID, want %s, have %s", want, have)
165 }
166 }
167
168 if want, have := tr.name, spans[0].Name; want != have && want != "" {
169 t.Errorf("incorrect span name, want %s, have %s", want, have)
170 }
171
172 if want, have := "GET /", spans[0].Name; want != have && tr.name == "" {
173 t.Errorf("incorrect span name, want %s, have %s", want, have)
174 }
175 }
176 }
0 package opencensus_test
1
2 import (
3 "sync"
4
5 "go.opencensus.io/trace"
6 )
7
8 type recordingExporter struct {
9 mu sync.Mutex
10 data []*trace.SpanData
11 }
12
13 func (e *recordingExporter) ExportSpan(d *trace.SpanData) {
14 e.mu.Lock()
15 defer e.mu.Unlock()
16
17 e.data = append(e.data, d)
18 }
19
20 func (e *recordingExporter) Flush() (data []*trace.SpanData) {
21 e.mu.Lock()
22 defer e.mu.Unlock()
23
24 data = e.data
25 e.data = nil
26 return
27 }
0 package opencensus
1
2 import (
3 "go.opencensus.io/plugin/ochttp/propagation/b3"
4 "go.opencensus.io/trace"
5 "go.opencensus.io/trace/propagation"
6 )
7
8 // defaultHTTPPropagate holds OpenCensus' default HTTP propagation format which
9 // currently is Zipkin's B3.
10 var defaultHTTPPropagate propagation.HTTPFormat = &b3.HTTPFormat{}
11
12 // TracerOption allows for functional options to our OpenCensus tracing
13 // middleware.
14 type TracerOption func(o *TracerOptions)
15
16 // WithTracerConfig sets all configuration options at once.
17 func WithTracerConfig(options TracerOptions) TracerOption {
18 return func(o *TracerOptions) {
19 *o = options
20 }
21 }
22
23 // WithSampler sets the sampler to use by our OpenCensus Tracer.
24 func WithSampler(sampler trace.Sampler) TracerOption {
25 return func(o *TracerOptions) {
26 o.Sampler = sampler
27 }
28 }
29
30 // WithName sets the name for an instrumented transport endpoint. If name is omitted
31 // at tracing middleware creation, the method of the transport or transport rpc
32 // name is used.
33 func WithName(name string) TracerOption {
34 return func(o *TracerOptions) {
35 o.Name = name
36 }
37 }
38
39 // IsPublic should be set to true for publicly accessible servers and for
40 // clients that should not propagate their current trace metadata.
41 // On the server side a new trace will always be started regardless of any
42 // trace metadata being found in the incoming request. If any trace metadata
43 // is found, it will be added as a linked trace instead.
44 func IsPublic(isPublic bool) TracerOption {
45 return func(o *TracerOptions) {
46 o.Public = isPublic
47 }
48 }
49
50 // WithHTTPPropagation sets the propagation handlers for the HTTP transport
51 // middlewares. If used on a non HTTP transport this is a noop.
52 func WithHTTPPropagation(p propagation.HTTPFormat) TracerOption {
53 return func(o *TracerOptions) {
54 if p == nil {
55 // reset to default OC HTTP format
56 o.HTTPPropagate = defaultHTTPPropagate
57 return
58 }
59 o.HTTPPropagate = p
60 }
61 }
62
63 // TracerOptions holds configuration for our tracing middlewares
64 type TracerOptions struct {
65 Sampler trace.Sampler
66 Name string
67 Public bool
68 HTTPPropagate propagation.HTTPFormat
69 }
5252 func (w metadataReaderWriter) Set(key, val string) {
5353 key = strings.ToLower(key)
5454 if strings.HasSuffix(key, "-bin") {
55 val = string(base64.StdEncoding.EncodeToString([]byte(val)))
55 val = base64.StdEncoding.EncodeToString([]byte(val))
5656 }
5757 (*w.MD)[key] = append((*w.MD)[key], val)
5858 }
141141
142142 rpcMethod, ok := ctx.Value(kitgrpc.ContextKeyRequestMethod).(string)
143143 if !ok {
144 config.logger.Log("unable to retrieve method name: missing gRPC interceptor hook")
144 config.logger.Log("err", "unable to retrieve method name: missing gRPC interceptor hook")
145145 } else {
146146 tags["grpc.method"] = rpcMethod
147147 }
0 // Package amqp implements an AMQP transport.
1 package amqp
0 package amqp
1
2 import (
3 "context"
4 "github.com/streadway/amqp"
5 )
6
7 // DecodeRequestFunc extracts a user-domain request object from
8 // an AMQP Delivery object. It is designed to be used in AMQP Subscribers.
9 type DecodeRequestFunc func(context.Context, *amqp.Delivery) (request interface{}, err error)
10
11 // EncodeRequestFunc encodes the passed request object into
12 // an AMQP Publishing object. It is designed to be used in AMQP Publishers.
13 type EncodeRequestFunc func(context.Context, *amqp.Publishing, interface{}) error
14
15 // EncodeResponseFunc encodes the passed reponse object to
16 // an AMQP Publishing object. It is designed to be used in AMQP Subscribers.
17 type EncodeResponseFunc func(context.Context, *amqp.Publishing, interface{}) error
18
19 // DecodeResponseFunc extracts a user-domain response object from
20 // an AMQP Delivery object. It is designed to be used in AMQP Publishers.
21 type DecodeResponseFunc func(context.Context, *amqp.Delivery) (response interface{}, err error)
0 package amqp
1
2 import (
3 "context"
4 "time"
5
6 "github.com/go-kit/kit/endpoint"
7 "github.com/streadway/amqp"
8 )
9
10 // The golang AMQP implementation requires the []byte representation of
11 // correlation id strings to have a maximum length of 255 bytes.
12 const maxCorrelationIdLength = 255
13
14 // Publisher wraps an AMQP channel and queue, and provides a method that
15 // implements endpoint.Endpoint.
16 type Publisher struct {
17 ch Channel
18 q *amqp.Queue
19 enc EncodeRequestFunc
20 dec DecodeResponseFunc
21 before []RequestFunc
22 after []PublisherResponseFunc
23 timeout time.Duration
24 }
25
26 // NewPublisher constructs a usable Publisher for a single remote method.
27 func NewPublisher(
28 ch Channel,
29 q *amqp.Queue,
30 enc EncodeRequestFunc,
31 dec DecodeResponseFunc,
32 options ...PublisherOption,
33 ) *Publisher {
34 p := &Publisher{
35 ch: ch,
36 q: q,
37 enc: enc,
38 dec: dec,
39 timeout: 10 * time.Second,
40 }
41 for _, option := range options {
42 option(p)
43 }
44 return p
45 }
46
47 // PublisherOption sets an optional parameter for clients.
48 type PublisherOption func(*Publisher)
49
50 // PublisherBefore sets the RequestFuncs that are applied to the outgoing AMQP
51 // request before it's invoked.
52 func PublisherBefore(before ...RequestFunc) PublisherOption {
53 return func(p *Publisher) { p.before = append(p.before, before...) }
54 }
55
56 // PublisherAfter sets the ClientResponseFuncs applied to the incoming AMQP
57 // request prior to it being decoded. This is useful for obtaining anything off
58 // of the response and adding onto the context prior to decoding.
59 func PublisherAfter(after ...PublisherResponseFunc) PublisherOption {
60 return func(p *Publisher) { p.after = append(p.after, after...) }
61 }
62
63 // PublisherTimeout sets the available timeout for an AMQP request.
64 func PublisherTimeout(timeout time.Duration) PublisherOption {
65 return func(p *Publisher) { p.timeout = timeout }
66 }
67
68 // Endpoint returns a usable endpoint that invokes the remote endpoint.
69 func (p Publisher) Endpoint() endpoint.Endpoint {
70 return func(ctx context.Context, request interface{}) (interface{}, error) {
71 ctx, cancel := context.WithTimeout(ctx, p.timeout)
72 defer cancel()
73
74 pub := amqp.Publishing{
75 ReplyTo: p.q.Name,
76 CorrelationId: randomString(randInt(5, maxCorrelationIdLength)),
77 }
78
79 if err := p.enc(ctx, &pub, request); err != nil {
80 return nil, err
81 }
82
83 for _, f := range p.before {
84 ctx = f(ctx, &pub)
85 }
86
87 deliv, err := p.publishAndConsumeFirstMatchingResponse(ctx, &pub)
88 if err != nil {
89 return nil, err
90 }
91
92 for _, f := range p.after {
93 ctx = f(ctx, deliv)
94 }
95 response, err := p.dec(ctx, deliv)
96 if err != nil {
97 return nil, err
98 }
99
100 return response, nil
101 }
102 }
103
104 // publishAndConsumeFirstMatchingResponse publishes the specified Publishing
105 // and returns the first Delivery object with the matching correlationId.
106 // If the context times out while waiting for a reply, an error will be returned.
107 func (p Publisher) publishAndConsumeFirstMatchingResponse(
108 ctx context.Context,
109 pub *amqp.Publishing,
110 ) (*amqp.Delivery, error) {
111 err := p.ch.Publish(
112 getPublishExchange(ctx),
113 getPublishKey(ctx),
114 false, //mandatory
115 false, //immediate
116 *pub,
117 )
118 if err != nil {
119 return nil, err
120 }
121 autoAck := getConsumeAutoAck(ctx)
122
123 msg, err := p.ch.Consume(
124 p.q.Name,
125 "", //consumer
126 autoAck,
127 false, //exclusive
128 false, //noLocal
129 false, //noWait
130 getConsumeArgs(ctx),
131 )
132 if err != nil {
133 return nil, err
134 }
135
136 for {
137 select {
138 case d := <-msg:
139 if d.CorrelationId == pub.CorrelationId {
140 if !autoAck {
141 d.Ack(false) //multiple
142 }
143 return &d, nil
144 }
145
146 case <-ctx.Done():
147 return nil, ctx.Err()
148 }
149 }
150
151 }
0 package amqp_test
1
2 import (
3 "context"
4 "encoding/json"
5 "errors"
6 "testing"
7 "time"
8
9 amqptransport "github.com/go-kit/kit/transport/amqp"
10 "github.com/streadway/amqp"
11 )
12
13 var (
14 defaultContentType = ""
15 defaultContentEncoding = ""
16 )
17
18 // TestBadEncode tests if encode errors are handled properly.
19 func TestBadEncode(t *testing.T) {
20 ch := &mockChannel{f: nullFunc}
21 q := &amqp.Queue{Name: "some queue"}
22 pub := amqptransport.NewPublisher(
23 ch,
24 q,
25 func(context.Context, *amqp.Publishing, interface{}) error { return errors.New("err!") },
26 func(context.Context, *amqp.Delivery) (response interface{}, err error) { return struct{}{}, nil },
27 )
28 errChan := make(chan error, 1)
29 var err error
30 go func() {
31 _, err := pub.Endpoint()(context.Background(), struct{}{})
32 errChan <- err
33
34 }()
35 select {
36 case err = <-errChan:
37 break
38
39 case <-time.After(100 * time.Millisecond):
40 t.Fatal("Timed out waiting for result")
41 }
42 if err == nil {
43 t.Error("expected error")
44 }
45 if want, have := "err!", err.Error(); want != have {
46 t.Errorf("want %s, have %s", want, have)
47 }
48 }
49
50 // TestBadDecode tests if decode errors are handled properly.
51 func TestBadDecode(t *testing.T) {
52 cid := "correlation"
53 ch := &mockChannel{
54 f: nullFunc,
55 c: make(chan amqp.Publishing, 1),
56 deliveries: []amqp.Delivery{
57 amqp.Delivery{
58 CorrelationId: cid,
59 },
60 },
61 }
62 q := &amqp.Queue{Name: "some queue"}
63
64 pub := amqptransport.NewPublisher(
65 ch,
66 q,
67 func(context.Context, *amqp.Publishing, interface{}) error { return nil },
68 func(context.Context, *amqp.Delivery) (response interface{}, err error) {
69 return struct{}{}, errors.New("err!")
70 },
71 amqptransport.PublisherBefore(
72 amqptransport.SetCorrelationID(cid),
73 ),
74 )
75
76 var err error
77 errChan := make(chan error, 1)
78 go func() {
79 _, err := pub.Endpoint()(context.Background(), struct{}{})
80 errChan <- err
81
82 }()
83
84 select {
85 case err = <-errChan:
86 break
87
88 case <-time.After(100 * time.Millisecond):
89 t.Fatal("Timed out waiting for result")
90 }
91
92 if err == nil {
93 t.Error("expected error")
94 }
95 if want, have := "err!", err.Error(); want != have {
96 t.Errorf("want %s, have %s", want, have)
97 }
98 }
99
100 // TestPublisherTimeout ensures that the publisher timeout mechanism works.
101 func TestPublisherTimeout(t *testing.T) {
102 ch := &mockChannel{
103 f: nullFunc,
104 c: make(chan amqp.Publishing, 1),
105 deliveries: []amqp.Delivery{}, // no reply from mock subscriber
106 }
107 q := &amqp.Queue{Name: "some queue"}
108
109 pub := amqptransport.NewPublisher(
110 ch,
111 q,
112 func(context.Context, *amqp.Publishing, interface{}) error { return nil },
113 func(context.Context, *amqp.Delivery) (response interface{}, err error) {
114 return struct{}{}, nil
115 },
116 amqptransport.PublisherTimeout(50*time.Millisecond),
117 )
118
119 var err error
120 errChan := make(chan error, 1)
121 go func() {
122 _, err := pub.Endpoint()(context.Background(), struct{}{})
123 errChan <- err
124
125 }()
126
127 select {
128 case err = <-errChan:
129 break
130
131 case <-time.After(100 * time.Millisecond):
132 t.Fatal("timed out waiting for result")
133 }
134
135 if err == nil {
136 t.Error("expected error")
137 }
138 if want, have := context.DeadlineExceeded.Error(), err.Error(); want != have {
139 t.Errorf("want %s, have %s", want, have)
140 }
141 }
142
143 func TestSuccessfulPublisher(t *testing.T) {
144 cid := "correlation"
145 mockReq := testReq{437}
146 mockRes := testRes{
147 Squadron: mockReq.Squadron,
148 Name: names[mockReq.Squadron],
149 }
150 b, err := json.Marshal(mockRes)
151 if err != nil {
152 t.Fatal(err)
153 }
154 reqChan := make(chan amqp.Publishing, 1)
155 ch := &mockChannel{
156 f: nullFunc,
157 c: reqChan,
158 deliveries: []amqp.Delivery{
159 amqp.Delivery{
160 CorrelationId: cid,
161 Body: b,
162 },
163 },
164 }
165 q := &amqp.Queue{Name: "some queue"}
166
167 pub := amqptransport.NewPublisher(
168 ch,
169 q,
170 testReqEncoder,
171 testResDeliveryDecoder,
172 amqptransport.PublisherBefore(
173 amqptransport.SetCorrelationID(cid),
174 ),
175 )
176 var publishing amqp.Publishing
177 var res testRes
178 var ok bool
179 resChan := make(chan interface{}, 1)
180 errChan := make(chan error, 1)
181 go func() {
182 res, err := pub.Endpoint()(context.Background(), mockReq)
183 if err != nil {
184 errChan <- err
185 } else {
186 resChan <- res
187 }
188 }()
189
190 select {
191 case publishing = <-reqChan:
192 break
193
194 case <-time.After(100 * time.Millisecond):
195 t.Fatal("timed out waiting for request")
196 }
197 if want, have := defaultContentType, publishing.ContentType; want != have {
198 t.Errorf("want %s, have %s", want, have)
199 }
200 if want, have := defaultContentEncoding, publishing.ContentEncoding; want != have {
201 t.Errorf("want %s, have %s", want, have)
202 }
203
204 select {
205 case response := <-resChan:
206 res, ok = response.(testRes)
207 if !ok {
208 t.Error("failed to assert endpoint response type")
209 }
210 break
211
212 case err = <-errChan:
213 break
214
215 case <-time.After(100 * time.Millisecond):
216 t.Fatal("timed out waiting for result")
217 }
218
219 if err != nil {
220 t.Fatal(err)
221 }
222 if want, have := mockRes.Name, res.Name; want != have {
223 t.Errorf("want %s, have %s", want, have)
224 }
225 }
0 package amqp
1
2 import (
3 "context"
4 "time"
5
6 "github.com/streadway/amqp"
7 )
8
9 // RequestFunc may take information from a publisher request and put it into a
10 // request context. In Subscribers, RequestFuncs are executed prior to invoking
11 // the endpoint.
12 type RequestFunc func(context.Context, *amqp.Publishing) context.Context
13
14 // SubscriberResponseFunc may take information from a request context and use it to
15 // manipulate a Publisher. SubscriberResponseFuncs are only executed in
16 // subscribers, after invoking the endpoint but prior to publishing a reply.
17 type SubscriberResponseFunc func(context.Context,
18 *amqp.Delivery,
19 Channel,
20 *amqp.Publishing,
21 ) context.Context
22
23 // PublisherResponseFunc may take information from an AMQP request and make the
24 // response available for consumption. PublisherResponseFunc are only executed
25 // in publishers, after a request has been made, but prior to it being decoded.
26 type PublisherResponseFunc func(context.Context, *amqp.Delivery) context.Context
27
28 // SetPublishExchange returns a RequestFunc that sets the Exchange field
29 // of an AMQP Publish call.
30 func SetPublishExchange(publishExchange string) RequestFunc {
31 return func(ctx context.Context, pub *amqp.Publishing) context.Context {
32 return context.WithValue(ctx, ContextKeyExchange, publishExchange)
33 }
34 }
35
36 // SetPublishKey returns a RequestFunc that sets the Key field
37 // of an AMQP Publish call.
38 func SetPublishKey(publishKey string) RequestFunc {
39 return func(ctx context.Context, pub *amqp.Publishing) context.Context {
40 return context.WithValue(ctx, ContextKeyPublishKey, publishKey)
41 }
42 }
43
44 // SetPublishDeliveryMode sets the delivery mode of a Publishing.
45 // Please refer to AMQP delivery mode constants in the AMQP package.
46 func SetPublishDeliveryMode(dmode uint8) RequestFunc {
47 return func(ctx context.Context, pub *amqp.Publishing) context.Context {
48 pub.DeliveryMode = dmode
49 return ctx
50 }
51 }
52
53 // SetNackSleepDuration returns a RequestFunc that sets the amount of time
54 // to sleep in the event of a Nack.
55 // This has to be used in conjunction with an error encoder that Nack and sleeps.
56 // One example is the SingleNackRequeueErrorEncoder.
57 // It is designed to be used by Subscribers.
58 func SetNackSleepDuration(duration time.Duration) RequestFunc {
59 return func(ctx context.Context, pub *amqp.Publishing) context.Context {
60 return context.WithValue(ctx, ContextKeyNackSleepDuration, duration)
61 }
62 }
63
64 // SetConsumeAutoAck returns a RequestFunc that sets whether or not to autoAck
65 // messages when consuming.
66 // When set to false, the publisher will Ack the first message it receives with
67 // a matching correlationId.
68 // It is designed to be used by Publishers.
69 func SetConsumeAutoAck(autoAck bool) RequestFunc {
70 return func(ctx context.Context, pub *amqp.Publishing) context.Context {
71 return context.WithValue(ctx, ContextKeyAutoAck, autoAck)
72 }
73 }
74
75 // SetConsumeArgs returns a RequestFunc that set the arguments for amqp Consume
76 // function.
77 // It is designed to be used by Publishers.
78 func SetConsumeArgs(args amqp.Table) RequestFunc {
79 return func(ctx context.Context, pub *amqp.Publishing) context.Context {
80 return context.WithValue(ctx, ContextKeyConsumeArgs, args)
81 }
82 }
83
84 // SetContentType returns a RequestFunc that sets the ContentType field of
85 // an AMQP Publishing.
86 func SetContentType(contentType string) RequestFunc {
87 return func(ctx context.Context, pub *amqp.Publishing) context.Context {
88 pub.ContentType = contentType
89 return ctx
90 }
91 }
92
93 // SetContentEncoding returns a RequestFunc that sets the ContentEncoding field
94 // of an AMQP Publishing.
95 func SetContentEncoding(contentEncoding string) RequestFunc {
96 return func(ctx context.Context, pub *amqp.Publishing) context.Context {
97 pub.ContentEncoding = contentEncoding
98 return ctx
99 }
100 }
101
102 // SetCorrelationID returns a RequestFunc that sets the CorrelationId field
103 // of an AMQP Publishing.
104 func SetCorrelationID(cid string) RequestFunc {
105 return func(ctx context.Context, pub *amqp.Publishing) context.Context {
106 pub.CorrelationId = cid
107 return ctx
108 }
109 }
110
111 // SetAckAfterEndpoint returns a SubscriberResponseFunc that prompts the service
112 // to Ack the Delivery object after successfully evaluating the endpoint,
113 // and before it encodes the response.
114 // It is designed to be used by Subscribers.
115 func SetAckAfterEndpoint(multiple bool) SubscriberResponseFunc {
116 return func(ctx context.Context,
117 deliv *amqp.Delivery,
118 ch Channel,
119 pub *amqp.Publishing,
120 ) context.Context {
121 deliv.Ack(multiple)
122 return ctx
123 }
124 }
125
126 func getPublishExchange(ctx context.Context) string {
127 if exchange := ctx.Value(ContextKeyExchange); exchange != nil {
128 return exchange.(string)
129 }
130 return ""
131 }
132
133 func getPublishKey(ctx context.Context) string {
134 if publishKey := ctx.Value(ContextKeyPublishKey); publishKey != nil {
135 return publishKey.(string)
136 }
137 return ""
138 }
139
140 func getNackSleepDuration(ctx context.Context) time.Duration {
141 if duration := ctx.Value(ContextKeyNackSleepDuration); duration != nil {
142 return duration.(time.Duration)
143 }
144 return 0
145 }
146
147 func getConsumeAutoAck(ctx context.Context) bool {
148 if autoAck := ctx.Value(ContextKeyAutoAck); autoAck != nil {
149 return autoAck.(bool)
150 }
151 return false
152 }
153
154 func getConsumeArgs(ctx context.Context) amqp.Table {
155 if args := ctx.Value(ContextKeyConsumeArgs); args != nil {
156 return args.(amqp.Table)
157 }
158 return nil
159 }
160
161 type contextKey int
162
163 const (
164 // ContextKeyExchange is the value of the reply Exchange in
165 // amqp.Publish.
166 ContextKeyExchange contextKey = iota
167 // ContextKeyPublishKey is the value of the ReplyTo field in
168 // amqp.Publish.
169 ContextKeyPublishKey
170 // ContextKeyNackSleepDuration is the duration to sleep for if the
171 // service Nack and requeues a message.
172 // This is to prevent sporadic send-resending of message
173 // when a message is constantly Nack'd and requeued.
174 ContextKeyNackSleepDuration
175 // ContextKeyAutoAck is the value of autoAck field when calling
176 // amqp.Channel.Consume.
177 ContextKeyAutoAck
178 // ContextKeyConsumeArgs is the value of consumeArgs field when calling
179 // amqp.Channel.Consume.
180 ContextKeyConsumeArgs
181 )
0 package amqp
1
2 import (
3 "context"
4 "encoding/json"
5 "time"
6
7 "github.com/go-kit/kit/endpoint"
8 "github.com/go-kit/kit/log"
9 "github.com/streadway/amqp"
10 )
11
12 // Subscriber wraps an endpoint and provides a handler for AMQP Delivery messages.
13 type Subscriber struct {
14 e endpoint.Endpoint
15 dec DecodeRequestFunc
16 enc EncodeResponseFunc
17 before []RequestFunc
18 after []SubscriberResponseFunc
19 errorEncoder ErrorEncoder
20 logger log.Logger
21 }
22
23 // NewSubscriber constructs a new subscriber, which provides a handler
24 // for AMQP Delivery messages.
25 func NewSubscriber(
26 e endpoint.Endpoint,
27 dec DecodeRequestFunc,
28 enc EncodeResponseFunc,
29 options ...SubscriberOption,
30 ) *Subscriber {
31 s := &Subscriber{
32 e: e,
33 dec: dec,
34 enc: enc,
35 errorEncoder: DefaultErrorEncoder,
36 logger: log.NewNopLogger(),
37 }
38 for _, option := range options {
39 option(s)
40 }
41 return s
42 }
43
44 // SubscriberOption sets an optional parameter for subscribers.
45 type SubscriberOption func(*Subscriber)
46
47 // SubscriberBefore functions are executed on the publisher delivery object
48 // before the request is decoded.
49 func SubscriberBefore(before ...RequestFunc) SubscriberOption {
50 return func(s *Subscriber) { s.before = append(s.before, before...) }
51 }
52
53 // SubscriberAfter functions are executed on the subscriber reply after the
54 // endpoint is invoked, but before anything is published to the reply.
55 func SubscriberAfter(after ...SubscriberResponseFunc) SubscriberOption {
56 return func(s *Subscriber) { s.after = append(s.after, after...) }
57 }
58
59 // SubscriberErrorEncoder is used to encode errors to the subscriber reply
60 // whenever they're encountered in the processing of a request. Clients can
61 // use this to provide custom error formatting. By default,
62 // errors will be published with the DefaultErrorEncoder.
63 func SubscriberErrorEncoder(ee ErrorEncoder) SubscriberOption {
64 return func(s *Subscriber) { s.errorEncoder = ee }
65 }
66
67 // SubscriberErrorLogger is used to log non-terminal errors. By default, no errors
68 // are logged. This is intended as a diagnostic measure. Finer-grained control
69 // of error handling, including logging in more detail, should be performed in a
70 // custom SubscriberErrorEncoder which has access to the context.
71 func SubscriberErrorLogger(logger log.Logger) SubscriberOption {
72 return func(s *Subscriber) { s.logger = logger }
73 }
74
75 // ServeDelivery handles AMQP Delivery messages
76 // It is strongly recommended to use *amqp.Channel as the
77 // Channel interface implementation.
78 func (s Subscriber) ServeDelivery(ch Channel) func(deliv *amqp.Delivery) {
79 return func(deliv *amqp.Delivery) {
80 ctx, cancel := context.WithCancel(context.Background())
81 defer cancel()
82
83 pub := amqp.Publishing{}
84
85 for _, f := range s.before {
86 ctx = f(ctx, &pub)
87 }
88
89 request, err := s.dec(ctx, deliv)
90 if err != nil {
91 s.logger.Log("err", err)
92 s.errorEncoder(ctx, err, deliv, ch, &pub)
93 return
94 }
95
96 response, err := s.e(ctx, request)
97 if err != nil {
98 s.logger.Log("err", err)
99 s.errorEncoder(ctx, err, deliv, ch, &pub)
100 return
101 }
102
103 for _, f := range s.after {
104 ctx = f(ctx, deliv, ch, &pub)
105 }
106
107 if err := s.enc(ctx, &pub, response); err != nil {
108 s.logger.Log("err", err)
109 s.errorEncoder(ctx, err, deliv, ch, &pub)
110 return
111 }
112
113 if err := s.publishResponse(ctx, deliv, ch, &pub); err != nil {
114 s.logger.Log("err", err)
115 s.errorEncoder(ctx, err, deliv, ch, &pub)
116 return
117 }
118 }
119
120 }
121
122 func (s Subscriber) publishResponse(
123 ctx context.Context,
124 deliv *amqp.Delivery,
125 ch Channel,
126 pub *amqp.Publishing,
127 ) error {
128 if pub.CorrelationId == "" {
129 pub.CorrelationId = deliv.CorrelationId
130 }
131
132 replyExchange := getPublishExchange(ctx)
133 replyTo := getPublishKey(ctx)
134 if replyTo == "" {
135 replyTo = deliv.ReplyTo
136 }
137
138 return ch.Publish(
139 replyExchange,
140 replyTo,
141 false, // mandatory
142 false, // immediate
143 *pub,
144 )
145 }
146
147 // EncodeJSONResponse marshals the response as JSON as part of the
148 // payload of the AMQP Publishing object.
149 func EncodeJSONResponse(
150 ctx context.Context,
151 pub *amqp.Publishing,
152 response interface{},
153 ) error {
154 b, err := json.Marshal(response)
155 if err != nil {
156 return err
157 }
158 pub.Body = b
159 return nil
160 }
161
162 // EncodeNopResponse is a response function that does nothing.
163 func EncodeNopResponse(
164 ctx context.Context,
165 pub *amqp.Publishing,
166 response interface{},
167 ) error {
168 return nil
169 }
170
171 // ErrorEncoder is responsible for encoding an error to the subscriber reply.
172 // Users are encouraged to use custom ErrorEncoders to encode errors to
173 // their replies, and will likely want to pass and check for their own error
174 // types.
175 type ErrorEncoder func(ctx context.Context,
176 err error, deliv *amqp.Delivery, ch Channel, pub *amqp.Publishing)
177
178 // DefaultErrorEncoder simply ignores the message. It does not reply
179 // nor Ack/Nack the message.
180 func DefaultErrorEncoder(ctx context.Context,
181 err error, deliv *amqp.Delivery, ch Channel, pub *amqp.Publishing) {
182 }
183
184 // SingleNackRequeueErrorEncoder issues a Nack to the delivery with multiple flag set as false
185 // and requeue flag set as true. It does not reply the message.
186 func SingleNackRequeueErrorEncoder(ctx context.Context,
187 err error, deliv *amqp.Delivery, ch Channel, pub *amqp.Publishing) {
188 deliv.Nack(
189 false, //multiple
190 true, //requeue
191 )
192 duration := getNackSleepDuration(ctx)
193 time.Sleep(duration)
194 }
195
196 // ReplyErrorEncoder serializes the error message as a DefaultErrorResponse
197 // JSON and sends the message to the ReplyTo address.
198 func ReplyErrorEncoder(
199 ctx context.Context,
200 err error,
201 deliv *amqp.Delivery,
202 ch Channel,
203 pub *amqp.Publishing,
204 ) {
205
206 if pub.CorrelationId == "" {
207 pub.CorrelationId = deliv.CorrelationId
208 }
209
210 replyExchange := getPublishExchange(ctx)
211 replyTo := getPublishKey(ctx)
212 if replyTo == "" {
213 replyTo = deliv.ReplyTo
214 }
215
216 response := DefaultErrorResponse{err.Error()}
217
218 b, err := json.Marshal(response)
219 if err != nil {
220 return
221 }
222 pub.Body = b
223
224 ch.Publish(
225 replyExchange,
226 replyTo,
227 false, // mandatory
228 false, // immediate
229 *pub,
230 )
231 }
232
233 // ReplyAndAckErrorEncoder serializes the error message as a DefaultErrorResponse
234 // JSON and sends the message to the ReplyTo address then Acks the original
235 // message.
236 func ReplyAndAckErrorEncoder(ctx context.Context, err error, deliv *amqp.Delivery, ch Channel, pub *amqp.Publishing) {
237 ReplyErrorEncoder(ctx, err, deliv, ch, pub)
238 deliv.Ack(false)
239 }
240
241 // DefaultErrorResponse is the default structure of responses in the event
242 // of an error.
243 type DefaultErrorResponse struct {
244 Error string `json:"err"`
245 }
246
247 // Channel is a channel interface to make testing possible.
248 // It is highly recommended to use *amqp.Channel as the interface implementation.
249 type Channel interface {
250 Publish(exchange, key string, mandatory, immediate bool, msg amqp.Publishing) error
251 Consume(queue, consumer string, autoAck, exclusive, noLocal, noWail bool, args amqp.Table) (<-chan amqp.Delivery, error)
252 }
0 package amqp_test
1
2 import (
3 "context"
4 "encoding/json"
5 "errors"
6 "testing"
7 "time"
8
9 amqptransport "github.com/go-kit/kit/transport/amqp"
10 "github.com/streadway/amqp"
11 )
12
13 var (
14 typeAssertionError = errors.New("type assertion error")
15 )
16
17 // mockChannel is a mock of *amqp.Channel.
18 type mockChannel struct {
19 f func(exchange, key string, mandatory, immediate bool)
20 c chan<- amqp.Publishing
21 deliveries []amqp.Delivery
22 }
23
24 // Publish runs a test function f and sends resultant message to a channel.
25 func (ch *mockChannel) Publish(exchange, key string, mandatory, immediate bool, msg amqp.Publishing) error {
26 ch.f(exchange, key, mandatory, immediate)
27 ch.c <- msg
28 return nil
29 }
30
31 var nullFunc = func(exchange, key string, mandatory, immediate bool) {
32 }
33
34 func (ch *mockChannel) Consume(queue, consumer string, autoAck, exclusive, noLocal, noWail bool, args amqp.Table) (<-chan amqp.Delivery, error) {
35 c := make(chan amqp.Delivery, len(ch.deliveries))
36 for _, d := range ch.deliveries {
37 c <- d
38 }
39 return c, nil
40 }
41
42 // TestSubscriberBadDecode checks if decoder errors are handled properly.
43 func TestSubscriberBadDecode(t *testing.T) {
44 sub := amqptransport.NewSubscriber(
45 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
46 func(context.Context, *amqp.Delivery) (interface{}, error) { return nil, errors.New("err!") },
47 func(context.Context, *amqp.Publishing, interface{}) error {
48 return nil
49 },
50 amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder),
51 )
52
53 outputChan := make(chan amqp.Publishing, 1)
54 ch := &mockChannel{f: nullFunc, c: outputChan}
55 sub.ServeDelivery(ch)(&amqp.Delivery{})
56
57 var msg amqp.Publishing
58 select {
59 case msg = <-outputChan:
60 break
61
62 case <-time.After(100 * time.Millisecond):
63 t.Fatal("Timed out waiting for publishing")
64 }
65 res, err := decodeSubscriberError(msg)
66 if err != nil {
67 t.Fatal(err)
68 }
69 if want, have := "err!", res.Error; want != have {
70 t.Errorf("want %s, have %s", want, have)
71 }
72 }
73
74 // TestSubscriberBadEndpoint checks if endpoint errors are handled properly.
75 func TestSubscriberBadEndpoint(t *testing.T) {
76 sub := amqptransport.NewSubscriber(
77 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("err!") },
78 func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil },
79 func(context.Context, *amqp.Publishing, interface{}) error {
80 return nil
81 },
82 amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder),
83 )
84
85 outputChan := make(chan amqp.Publishing, 1)
86 ch := &mockChannel{f: nullFunc, c: outputChan}
87 sub.ServeDelivery(ch)(&amqp.Delivery{})
88
89 var msg amqp.Publishing
90
91 select {
92 case msg = <-outputChan:
93 break
94
95 case <-time.After(100 * time.Millisecond):
96 t.Fatal("Timed out waiting for publishing")
97 }
98
99 res, err := decodeSubscriberError(msg)
100 if err != nil {
101 t.Fatal(err)
102 }
103 if want, have := "err!", res.Error; want != have {
104 t.Errorf("want %s, have %s", want, have)
105 }
106 }
107
108 // TestSubscriberBadEncoder checks if encoder errors are handled properly.
109 func TestSubscriberBadEncoder(t *testing.T) {
110 sub := amqptransport.NewSubscriber(
111 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
112 func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil },
113 func(context.Context, *amqp.Publishing, interface{}) error {
114 return errors.New("err!")
115 },
116 amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder),
117 )
118
119 outputChan := make(chan amqp.Publishing, 1)
120 ch := &mockChannel{f: nullFunc, c: outputChan}
121 sub.ServeDelivery(ch)(&amqp.Delivery{})
122
123 var msg amqp.Publishing
124
125 select {
126 case msg = <-outputChan:
127 break
128
129 case <-time.After(100 * time.Millisecond):
130 t.Fatal("Timed out waiting for publishing")
131 }
132
133 res, err := decodeSubscriberError(msg)
134 if err != nil {
135 t.Fatal(err)
136 }
137 if want, have := "err!", res.Error; want != have {
138 t.Errorf("want %s, have %s", want, have)
139 }
140 }
141
142 // TestSubscriberSuccess checks if CorrelationId and ReplyTo are set properly
143 // and if the payload is encoded properly.
144 func TestSubscriberSuccess(t *testing.T) {
145 cid := "correlation"
146 replyTo := "sender"
147 obj := testReq{
148 Squadron: 436,
149 }
150 b, err := json.Marshal(obj)
151 if err != nil {
152 t.Fatal(err)
153 }
154
155 sub := amqptransport.NewSubscriber(
156 testEndpoint,
157 testReqDecoder,
158 amqptransport.EncodeJSONResponse,
159 amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder),
160 )
161
162 checkReplyToFunc := func(exchange, key string, mandatory, immediate bool) {
163 if want, have := replyTo, key; want != have {
164 t.Errorf("want %s, have %s", want, have)
165 }
166 }
167
168 outputChan := make(chan amqp.Publishing, 1)
169 ch := &mockChannel{f: checkReplyToFunc, c: outputChan}
170 sub.ServeDelivery(ch)(&amqp.Delivery{
171 CorrelationId: cid,
172 ReplyTo: replyTo,
173 Body: b,
174 })
175
176 var msg amqp.Publishing
177
178 select {
179 case msg = <-outputChan:
180 break
181
182 case <-time.After(100 * time.Millisecond):
183 t.Fatal("Timed out waiting for publishing")
184 }
185
186 if want, have := cid, msg.CorrelationId; want != have {
187 t.Errorf("want %s, have %s", want, have)
188 }
189
190 // check if error is not thrown
191 errRes, err := decodeSubscriberError(msg)
192 if err != nil {
193 t.Fatal(err)
194 }
195 if errRes.Error != "" {
196 t.Error("Received error from subscriber", errRes.Error)
197 return
198 }
199
200 // check obj vals
201 response, err := testResDecoder(msg.Body)
202 if err != nil {
203 t.Fatal(err)
204 }
205 res, ok := response.(testRes)
206 if !ok {
207 t.Error(typeAssertionError)
208 }
209
210 if want, have := obj.Squadron, res.Squadron; want != have {
211 t.Errorf("want %d, have %d", want, have)
212 }
213 if want, have := names[obj.Squadron], res.Name; want != have {
214 t.Errorf("want %s, have %s", want, have)
215 }
216 }
217
218 // TestSubscriberMultipleBefore checks if options to set exchange, key, deliveryMode
219 // are working.
220 func TestSubscriberMultipleBefore(t *testing.T) {
221 exchange := "some exchange"
222 key := "some key"
223 deliveryMode := uint8(127)
224 contentType := "some content type"
225 contentEncoding := "some content encoding"
226 sub := amqptransport.NewSubscriber(
227 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
228 func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil },
229 amqptransport.EncodeJSONResponse,
230 amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder),
231 amqptransport.SubscriberBefore(
232 amqptransport.SetPublishExchange(exchange),
233 amqptransport.SetPublishKey(key),
234 amqptransport.SetPublishDeliveryMode(deliveryMode),
235 amqptransport.SetContentType(contentType),
236 amqptransport.SetContentEncoding(contentEncoding),
237 ),
238 )
239 checkReplyToFunc := func(exch, k string, mandatory, immediate bool) {
240 if want, have := exchange, exch; want != have {
241 t.Errorf("want %s, have %s", want, have)
242 }
243 if want, have := key, k; want != have {
244 t.Errorf("want %s, have %s", want, have)
245 }
246 }
247
248 outputChan := make(chan amqp.Publishing, 1)
249 ch := &mockChannel{f: checkReplyToFunc, c: outputChan}
250 sub.ServeDelivery(ch)(&amqp.Delivery{})
251
252 var msg amqp.Publishing
253
254 select {
255 case msg = <-outputChan:
256 break
257
258 case <-time.After(100 * time.Millisecond):
259 t.Fatal("Timed out waiting for publishing")
260 }
261
262 // check if error is not thrown
263 errRes, err := decodeSubscriberError(msg)
264 if err != nil {
265 t.Fatal(err)
266 }
267 if errRes.Error != "" {
268 t.Error("Received error from subscriber", errRes.Error)
269 return
270 }
271
272 if want, have := contentType, msg.ContentType; want != have {
273 t.Errorf("want %s, have %s", want, have)
274 }
275
276 if want, have := contentEncoding, msg.ContentEncoding; want != have {
277 t.Errorf("want %s, have %s", want, have)
278 }
279
280 if want, have := deliveryMode, msg.DeliveryMode; want != have {
281 t.Errorf("want %d, have %d", want, have)
282 }
283 }
284
285 // TestDefaultContentMetaData checks that default ContentType and Content-Encoding
286 // is not set as mentioned by AMQP specification.
287 func TestDefaultContentMetaData(t *testing.T) {
288 defaultContentType := ""
289 defaultContentEncoding := ""
290 sub := amqptransport.NewSubscriber(
291 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
292 func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil },
293 amqptransport.EncodeJSONResponse,
294 amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder),
295 )
296 checkReplyToFunc := func(exch, k string, mandatory, immediate bool) { return }
297 outputChan := make(chan amqp.Publishing, 1)
298 ch := &mockChannel{f: checkReplyToFunc, c: outputChan}
299 sub.ServeDelivery(ch)(&amqp.Delivery{})
300
301 var msg amqp.Publishing
302
303 select {
304 case msg = <-outputChan:
305 break
306
307 case <-time.After(100 * time.Millisecond):
308 t.Fatal("Timed out waiting for publishing")
309 }
310
311 // check if error is not thrown
312 errRes, err := decodeSubscriberError(msg)
313 if err != nil {
314 t.Fatal(err)
315 }
316 if errRes.Error != "" {
317 t.Error("Received error from subscriber", errRes.Error)
318 return
319 }
320
321 if want, have := defaultContentType, msg.ContentType; want != have {
322 t.Errorf("want %s, have %s", want, have)
323 }
324 if want, have := defaultContentEncoding, msg.ContentEncoding; want != have {
325 t.Errorf("want %s, have %s", want, have)
326 }
327 }
328
329 func decodeSubscriberError(pub amqp.Publishing) (amqptransport.DefaultErrorResponse, error) {
330 var res amqptransport.DefaultErrorResponse
331 err := json.Unmarshal(pub.Body, &res)
332 return res, err
333 }
334
335 type testReq struct {
336 Squadron int `json:"s"`
337 }
338 type testRes struct {
339 Squadron int `json:"s"`
340 Name string `json:"n"`
341 }
342
343 func testEndpoint(_ context.Context, request interface{}) (interface{}, error) {
344 req, ok := request.(testReq)
345 if !ok {
346 return nil, typeAssertionError
347 }
348 name, prs := names[req.Squadron]
349 if !prs {
350 return nil, errors.New("unknown squadron name")
351 }
352 res := testRes{
353 Squadron: req.Squadron,
354 Name: name,
355 }
356 return res, nil
357 }
358
359 func testReqDecoder(_ context.Context, d *amqp.Delivery) (interface{}, error) {
360 var obj testReq
361 err := json.Unmarshal(d.Body, &obj)
362 return obj, err
363 }
364
365 func testReqEncoder(_ context.Context, p *amqp.Publishing, request interface{}) error {
366 req, ok := request.(testReq)
367 if !ok {
368 return errors.New("type assertion failure")
369 }
370 b, err := json.Marshal(req)
371 if err != nil {
372 return err
373 }
374 p.Body = b
375 return nil
376 }
377
378 func testResDeliveryDecoder(_ context.Context, d *amqp.Delivery) (interface{}, error) {
379 return testResDecoder(d.Body)
380 }
381
382 func testResDecoder(b []byte) (interface{}, error) {
383 var obj testRes
384 err := json.Unmarshal(b, &obj)
385 return obj, err
386 }
387
388 var names = map[int]string{
389 424: "tiger",
390 426: "thunderbird",
391 429: "bison",
392 436: "tusker",
393 437: "husky",
394 }
0 package amqp
1
2 import (
3 "math/rand"
4 )
5
6 func randomString(l int) string {
7 bytes := make([]byte, l)
8 for i := 0; i < l; i++ {
9 bytes[i] = byte(randInt(65, 90))
10 }
11 return string(bytes)
12 }
13
14 func randInt(min int, max int) int {
15 return min + rand.Intn(max-min)
16 }
6868 func EncodeKeyValue(key, val string) (string, string) {
6969 key = strings.ToLower(key)
7070 if strings.HasSuffix(key, binHdrSuffix) {
71 v := base64.StdEncoding.EncodeToString([]byte(val))
72 val = string(v)
71 val = base64.StdEncoding.EncodeToString([]byte(val))
7372 }
7473 return key, val
7574 }
5454 // ServerOption sets an optional parameter for servers.
5555 type ServerOption func(*Server)
5656
57 // ServerBefore functions are executed on the HTTP request object before the
57 // ServerBefore functions are executed on the gRPC request object before the
5858 // request is decoded.
5959 func ServerBefore(before ...ServerRequestFunc) ServerOption {
6060 return func(s *Server) { s.before = append(s.before, before...) }
6161 }
6262
63 // ServerAfter functions are executed on the HTTP response writer after the
63 // ServerAfter functions are executed on the gRPC response writer after the
6464 // endpoint is invoked, but before anything is written to the client.
6565 func ServerAfter(after ...ServerResponseFunc) ServerOption {
6666 return func(s *Server) { s.after = append(s.after, after...) }
44 "context"
55 "encoding/json"
66 "encoding/xml"
7 "io"
78 "io/ioutil"
89 "net/http"
910 "net/url"
1112 "github.com/go-kit/kit/endpoint"
1213 )
1314
15 // HTTPClient is an interface that models *http.Client.
16 type HTTPClient interface {
17 Do(req *http.Request) (*http.Response, error)
18 }
19
1420 // Client wraps a URL and provides a method that implements endpoint.Endpoint.
1521 type Client struct {
16 client *http.Client
22 client HTTPClient
1723 method string
1824 tgt *url.URL
1925 enc EncodeRequestFunc
5359
5460 // SetClient sets the underlying HTTP client used for requests.
5561 // By default, http.DefaultClient is used.
56 func SetClient(client *http.Client) ClientOption {
62 func SetClient(client HTTPClient) ClientOption {
5763 return func(c *Client) { c.client = client }
5864 }
5965
7884
7985 // BufferedStream sets whether the Response.Body is left open, allowing it
8086 // to be read from later. Useful for transporting a file as a buffered stream.
87 // That body has to be Closed to propery end the request.
8188 func BufferedStream(buffered bool) ClientOption {
8289 return func(c *Client) { c.bufferedStream = buffered }
8390 }
8693 func (c Client) Endpoint() endpoint.Endpoint {
8794 return func(ctx context.Context, request interface{}) (interface{}, error) {
8895 ctx, cancel := context.WithCancel(ctx)
89 defer cancel()
9096
9197 var (
9298 resp *http.Response
106112
107113 req, err := http.NewRequest(c.method, c.tgt.String(), nil)
108114 if err != nil {
115 cancel()
109116 return nil, err
110117 }
111118
112119 if err = c.enc(ctx, req, request); err != nil {
120 cancel()
113121 return nil, err
114122 }
115123
120128 resp, err = c.client.Do(req.WithContext(ctx))
121129
122130 if err != nil {
123 return nil, err
124 }
125
126 if !c.bufferedStream {
131 cancel()
132 return nil, err
133 }
134
135 // If we expect a buffered stream, we don't cancel the context when the endpoint returns.
136 // Instead, we should call the cancel func when closing the response body.
137 if c.bufferedStream {
138 resp.Body = bodyWithCancel{ReadCloser: resp.Body, cancel: cancel}
139 } else {
127140 defer resp.Body.Close()
141 defer cancel()
128142 }
129143
130144 for _, f := range c.after {
138152
139153 return response, nil
140154 }
155 }
156
157 // bodyWithCancel is a wrapper for an io.ReadCloser with also a
158 // cancel function which is called when the Close is used
159 type bodyWithCancel struct {
160 io.ReadCloser
161
162 cancel context.CancelFunc
163 }
164
165 func (bwc bodyWithCancel) Close() error {
166 bwc.ReadCloser.Close()
167 bwc.cancel()
168 return nil
141169 }
142170
143171 // ClientFinalizerFunc can be used to perform work at the end of a client HTTP
00 package http_test
11
22 import (
3 "bytes"
34 "context"
45 "io"
56 "io/ioutil"
9697 }
9798
9899 func TestHTTPClientBufferedStream(t *testing.T) {
100 // bodysize has a size big enought to make the resopnse.Body not an instant read
101 // so if the response is cancelled it wount be all readed and the test would fail
102 // The 6000 has not a particular meaning, it big enough to fulfill the usecase.
103 const bodysize = 6000
99104 var (
100 testbody = "testbody"
105 testbody = string(make([]byte, bodysize))
101106 encode = func(context.Context, *http.Request, interface{}) error { return nil }
102107 decode = func(_ context.Context, r *http.Response) (interface{}, error) {
103108 return TestResponse{r.Body, ""}, nil
127132 if !ok {
128133 t.Fatal("response should be TestResponse")
129134 }
135 defer response.Body.Close()
136 // Faking work
137 time.Sleep(time.Second * 1)
130138
131139 // Check that response body was NOT closed
132140 b := make([]byte, len(testbody))
251259 }
252260 }
253261
262 func TestSetClient(t *testing.T) {
263 var (
264 encode = func(context.Context, *http.Request, interface{}) error { return nil }
265 decode = func(_ context.Context, r *http.Response) (interface{}, error) {
266 t, err := ioutil.ReadAll(r.Body)
267 if err != nil {
268 return nil, err
269 }
270 return string(t), nil
271 }
272 )
273
274 testHttpClient := httpClientFunc(func(req *http.Request) (*http.Response, error) {
275 return &http.Response{
276 StatusCode: http.StatusOK,
277 Request: req,
278 Body: ioutil.NopCloser(bytes.NewBufferString("hello, world!")),
279 }, nil
280 })
281
282 client := httptransport.NewClient(
283 "GET",
284 &url.URL{},
285 encode,
286 decode,
287 httptransport.SetClient(testHttpClient),
288 ).Endpoint()
289
290 resp, err := client(context.Background(), nil)
291 if err != nil {
292 t.Fatal(err)
293 }
294 if r, ok := resp.(string); !ok || r != "hello, world!" {
295 t.Fatal("Expected response to be 'hello, world!' string")
296 }
297 }
298
254299 func mustParse(s string) *url.URL {
255300 u, err := url.Parse(s)
256301 if err != nil {
264309 }
265310
266311 func (e enhancedRequest) Headers() http.Header { return http.Header{"X-Edward": []string{"Snowden"}} }
312
313 type httpClientFunc func(req *http.Request) (*http.Response, error)
314
315 func (f httpClientFunc) Do(req *http.Request) (*http.Response, error) {
316 return f(req)
317 }
77 // DecodeRequestFunc extracts a user-domain request object from an HTTP
88 // request object. It's designed to be used in HTTP servers, for server-side
99 // endpoints. One straightforward DecodeRequestFunc could be something that
10 // JSON decodes from the request body to the concrete response type.
10 // JSON decodes from the request body to the concrete request type.
1111 type DecodeRequestFunc func(context.Context, *http.Request) (request interface{}, err error)
1212
1313 // EncodeRequestFunc encodes the passed request object into the HTTP request
1414 // object. It's designed to be used in HTTP clients, for client-side
15 // endpoints. One straightforward EncodeRequestFunc could something that JSON
15 // endpoints. One straightforward EncodeRequestFunc could be something that JSON
1616 // encodes the object directly to the request body.
1717 type EncodeRequestFunc func(context.Context, *http.Request, interface{}) error
1818
8686
8787 {
8888 "jsonrpc": "2.0",
89 "result": 4,
90 "error": null
89 "result": 4
9190 }
1414
1515 // Client wraps a JSON RPC method and provides a method that implements endpoint.Endpoint.
1616 type Client struct {
17 client *http.Client
17 client httptransport.HTTPClient
1818
1919 // JSON RPC endpoint URL
2020 tgt *url.URL
8585
8686 // SetClient sets the underlying HTTP client used for requests.
8787 // By default, http.DefaultClient is used.
88 func SetClient(client *http.Client) ClientOption {
88 func SetClient(client httptransport.HTTPClient) ClientOption {
8989 return func(c *Client) { c.client = client }
9090 }
9191
3434 return nil
3535 }
3636
37 func (id *RequestID) MarshalJSON() ([]byte, error) {
38 if id.intError == nil {
39 return json.Marshal(id.intValue)
40 } else if id.floatError == nil {
41 return json.Marshal(id.floatValue)
42 } else {
43 return json.Marshal(id.stringValue)
44 }
45 }
46
3747 // Int returns the ID as an integer value.
3848 // An error is returned if the ID can't be treated as an int.
3949 func (id *RequestID) Int() (int, error) {
5767 type Response struct {
5868 JSONRPC string `json:"jsonrpc"`
5969 Result json.RawMessage `json:"result,omitempty"`
60 Error *Error `json:"error,omitemty"`
70 Error *Error `json:"error,omitempty"`
71 ID *RequestID `json:"id"`
6172 }
6273
6374 const (
108108 t.Fatalf("Expected ID to be nil, got %+v.\n", r.ID)
109109 }
110110 }
111
112 func TestCanMarshalID(t *testing.T) {
113 cases := []struct {
114 JSON string
115 expType string
116 expValue interface{}
117 }{
118 {`12345`, "int", 12345},
119 {`12345.6`, "float", 12345.6},
120 {`"stringaling"`, "string", "stringaling"},
121 {`null`, "null", nil},
122 }
123
124 for _, c := range cases {
125 req := jsonrpc.Request{}
126 JSON := fmt.Sprintf(`{"jsonrpc":"2.0","id":%s}`, c.JSON)
127 json.Unmarshal([]byte(JSON), &req)
128 resp := jsonrpc.Response{ID: req.ID, JSONRPC: req.JSONRPC}
129
130 want := JSON
131 bol, _ := json.Marshal(resp)
132 got := string(bol)
133 if got != want {
134 t.Fatalf("'%s': want %s, got %s.", c.expType, want, got)
135 }
136 }
137 }
135135 }
136136
137137 res := Response{
138 ID: req.ID,
138139 JSONRPC: Version,
139140 }
140141
0 // Package nats provides a NATS transport.
1 package nats
0 package nats
1
2 import (
3 "context"
4
5 "github.com/nats-io/go-nats"
6 )
7
8 // DecodeRequestFunc extracts a user-domain request object from a publisher
9 // request object. It's designed to be used in NATS subscribers, for subscriber-side
10 // endpoints. One straightforward DecodeRequestFunc could be something that
11 // JSON decodes from the request body to the concrete response type.
12 type DecodeRequestFunc func(context.Context, *nats.Msg) (request interface{}, err error)
13
14 // EncodeRequestFunc encodes the passed request object into the NATS request
15 // object. It's designed to be used in NATS publishers, for publisher-side
16 // endpoints. One straightforward EncodeRequestFunc could something that JSON
17 // encodes the object directly to the request payload.
18 type EncodeRequestFunc func(context.Context, *nats.Msg, interface{}) error
19
20 // EncodeResponseFunc encodes the passed response object to the subscriber reply.
21 // It's designed to be used in NATS subscribers, for subscriber-side
22 // endpoints. One straightforward EncodeResponseFunc could be something that
23 // JSON encodes the object directly to the response body.
24 type EncodeResponseFunc func(context.Context, string, *nats.Conn, interface{}) error
25
26 // DecodeResponseFunc extracts a user-domain response object from an NATS
27 // response object. It's designed to be used in NATS publisher, for publisher-side
28 // endpoints. One straightforward DecodeResponseFunc could be something that
29 // JSON decodes from the response payload to the concrete response type.
30 type DecodeResponseFunc func(context.Context, *nats.Msg) (response interface{}, err error)
31
0 package nats
1
2 import (
3 "context"
4 "encoding/json"
5 "github.com/go-kit/kit/endpoint"
6 "github.com/nats-io/go-nats"
7 "time"
8 )
9
10 // Publisher wraps a URL and provides a method that implements endpoint.Endpoint.
11 type Publisher struct {
12 publisher *nats.Conn
13 subject string
14 enc EncodeRequestFunc
15 dec DecodeResponseFunc
16 before []RequestFunc
17 after []PublisherResponseFunc
18 timeout time.Duration
19 }
20
21 // NewPublisher constructs a usable Publisher for a single remote method.
22 func NewPublisher(
23 publisher *nats.Conn,
24 subject string,
25 enc EncodeRequestFunc,
26 dec DecodeResponseFunc,
27 options ...PublisherOption,
28 ) *Publisher {
29 p := &Publisher{
30 publisher: publisher,
31 subject: subject,
32 enc: enc,
33 dec: dec,
34 timeout: 10 * time.Second,
35 }
36 for _, option := range options {
37 option(p)
38 }
39 return p
40 }
41
42 // PublisherOption sets an optional parameter for clients.
43 type PublisherOption func(*Publisher)
44
45 // PublisherBefore sets the RequestFuncs that are applied to the outgoing NATS
46 // request before it's invoked.
47 func PublisherBefore(before ...RequestFunc) PublisherOption {
48 return func(p *Publisher) { p.before = append(p.before, before...) }
49 }
50
51 // PublisherAfter sets the ClientResponseFuncs applied to the incoming NATS
52 // request prior to it being decoded. This is useful for obtaining anything off
53 // of the response and adding onto the context prior to decoding.
54 func PublisherAfter(after ...PublisherResponseFunc) PublisherOption {
55 return func(p *Publisher) { p.after = append(p.after, after...) }
56 }
57
58 // PublisherTimeout sets the available timeout for NATS request.
59 func PublisherTimeout(timeout time.Duration) PublisherOption {
60 return func(p *Publisher) { p.timeout = timeout }
61 }
62
63 // Endpoint returns a usable endpoint that invokes the remote endpoint.
64 func (p Publisher) Endpoint() endpoint.Endpoint {
65 return func(ctx context.Context, request interface{}) (interface{}, error) {
66 ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
67 defer cancel()
68
69 msg := nats.Msg{Subject: p.subject}
70
71 if err := p.enc(ctx, &msg, request); err != nil {
72 return nil, err
73 }
74
75 for _, f := range p.before {
76 ctx = f(ctx, &msg)
77 }
78
79 resp, err := p.publisher.RequestWithContext(ctx, msg.Subject, msg.Data)
80 if err != nil {
81 return nil, err
82 }
83
84 for _, f := range p.after {
85 ctx = f(ctx, resp)
86 }
87
88 response, err := p.dec(ctx, resp)
89 if err != nil {
90 return nil, err
91 }
92
93 return response, nil
94 }
95 }
96
97 // EncodeJSONRequest is an EncodeRequestFunc that serializes the request as a
98 // JSON object to the Data of the Msg. Many JSON-over-NATS services can use it as
99 // a sensible default.
100 func EncodeJSONRequest(_ context.Context, msg *nats.Msg, request interface{}) error {
101 b, err := json.Marshal(request)
102 if err != nil {
103 return err
104 }
105
106 msg.Data = b
107
108 return nil
109 }
0 package nats_test
1
2 import (
3 "context"
4 "strings"
5 "testing"
6 "time"
7
8 natstransport "github.com/go-kit/kit/transport/nats"
9 "github.com/nats-io/go-nats"
10 )
11
12 func TestPublisher(t *testing.T) {
13 var (
14 testdata = "testdata"
15 encode = func(context.Context, *nats.Msg, interface{}) error { return nil }
16 decode = func(_ context.Context, msg *nats.Msg) (interface{}, error) {
17 return TestResponse{string(msg.Data), ""}, nil
18 }
19 )
20
21 nc := newNatsConn(t)
22 defer nc.Close()
23
24 sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) {
25 if err := nc.Publish(msg.Reply, []byte(testdata)); err != nil {
26 t.Fatal(err)
27 }
28 })
29 if err != nil {
30 t.Fatal(err)
31 }
32 defer sub.Unsubscribe()
33
34 publisher := natstransport.NewPublisher(
35 nc,
36 "natstransport.test",
37 encode,
38 decode,
39 )
40
41 res, err := publisher.Endpoint()(context.Background(), struct{}{})
42 if err != nil {
43 t.Fatal(err)
44 }
45
46 response, ok := res.(TestResponse)
47 if !ok {
48 t.Fatal("response should be TestResponse")
49 }
50 if want, have := testdata, response.String; want != have {
51 t.Errorf("want %q, have %q", want, have)
52 }
53
54 }
55
56 func TestPublisherBefore(t *testing.T) {
57 var (
58 testdata = "testdata"
59 encode = func(context.Context, *nats.Msg, interface{}) error { return nil }
60 decode = func(_ context.Context, msg *nats.Msg) (interface{}, error) {
61 return TestResponse{string(msg.Data), ""}, nil
62 }
63 )
64
65 nc := newNatsConn(t)
66 defer nc.Close()
67
68 sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) {
69 if err := nc.Publish(msg.Reply, msg.Data); err != nil {
70 t.Fatal(err)
71 }
72 })
73 if err != nil {
74 t.Fatal(err)
75 }
76 defer sub.Unsubscribe()
77
78 publisher := natstransport.NewPublisher(
79 nc,
80 "natstransport.test",
81 encode,
82 decode,
83 natstransport.PublisherBefore(func(ctx context.Context, msg *nats.Msg) context.Context {
84 msg.Data = []byte(strings.ToUpper(string(testdata)))
85 return ctx
86 }),
87 )
88
89 res, err := publisher.Endpoint()(context.Background(), struct{}{})
90 if err != nil {
91 t.Fatal(err)
92 }
93
94 response, ok := res.(TestResponse)
95 if !ok {
96 t.Fatal("response should be TestResponse")
97 }
98 if want, have := strings.ToUpper(testdata), response.String; want != have {
99 t.Errorf("want %q, have %q", want, have)
100 }
101
102 }
103
104 func TestPublisherAfter(t *testing.T) {
105 var (
106 testdata = "testdata"
107 encode = func(context.Context, *nats.Msg, interface{}) error { return nil }
108 decode = func(_ context.Context, msg *nats.Msg) (interface{}, error) {
109 return TestResponse{string(msg.Data), ""}, nil
110 }
111 )
112
113 nc := newNatsConn(t)
114 defer nc.Close()
115
116 sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) {
117 if err := nc.Publish(msg.Reply, []byte(testdata)); err != nil {
118 t.Fatal(err)
119 }
120 })
121 if err != nil {
122 t.Fatal(err)
123 }
124 defer sub.Unsubscribe()
125
126 publisher := natstransport.NewPublisher(
127 nc,
128 "natstransport.test",
129 encode,
130 decode,
131 natstransport.PublisherAfter(func(ctx context.Context, msg *nats.Msg) context.Context {
132 msg.Data = []byte(strings.ToUpper(string(msg.Data)))
133 return ctx
134 }),
135 )
136
137 res, err := publisher.Endpoint()(context.Background(), struct{}{})
138 if err != nil {
139 t.Fatal(err)
140 }
141
142 response, ok := res.(TestResponse)
143 if !ok {
144 t.Fatal("response should be TestResponse")
145 }
146 if want, have := strings.ToUpper(testdata), response.String; want != have {
147 t.Errorf("want %q, have %q", want, have)
148 }
149
150 }
151
152 func TestPublisherTimeout(t *testing.T) {
153 var (
154 encode = func(context.Context, *nats.Msg, interface{}) error { return nil }
155 decode = func(_ context.Context, msg *nats.Msg) (interface{}, error) {
156 return TestResponse{string(msg.Data), ""}, nil
157 }
158 )
159
160 nc := newNatsConn(t)
161 defer nc.Close()
162
163 ch := make(chan struct{})
164 defer close(ch)
165
166 sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) {
167 <-ch
168 })
169 if err != nil {
170 t.Fatal(err)
171 }
172 defer sub.Unsubscribe()
173
174 publisher := natstransport.NewPublisher(
175 nc,
176 "natstransport.test",
177 encode,
178 decode,
179 natstransport.PublisherTimeout(time.Second),
180 )
181
182 _, err = publisher.Endpoint()(context.Background(), struct{}{})
183 if err != context.DeadlineExceeded {
184 t.Errorf("want %s, have %s", context.DeadlineExceeded, err)
185 }
186 }
187
188 func TestEncodeJSONRequest(t *testing.T) {
189 var data string
190
191 nc := newNatsConn(t)
192 defer nc.Close()
193
194 sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) {
195 data = string(msg.Data)
196
197 if err := nc.Publish(msg.Reply, []byte("")); err != nil {
198 t.Fatal(err)
199 }
200 })
201 if err != nil {
202 t.Fatal(err)
203 }
204 defer sub.Unsubscribe()
205
206 publisher := natstransport.NewPublisher(
207 nc,
208 "natstransport.test",
209 natstransport.EncodeJSONRequest,
210 func(context.Context, *nats.Msg) (interface{}, error) { return nil, nil },
211 ).Endpoint()
212
213 for _, test := range []struct {
214 value interface{}
215 body string
216 }{
217 {nil, "null"},
218 {12, "12"},
219 {1.2, "1.2"},
220 {true, "true"},
221 {"test", "\"test\""},
222 {struct {
223 Foo string `json:"foo"`
224 }{"foo"}, "{\"foo\":\"foo\"}"},
225 } {
226 if _, err := publisher(context.Background(), test.value); err != nil {
227 t.Fatal(err)
228 continue
229 }
230
231 if data != test.body {
232 t.Errorf("%v: actual %#v, expected %#v", test.value, data, test.body)
233 }
234 }
235
236 }
0 package nats
1
2 import (
3 "context"
4
5 "github.com/nats-io/go-nats"
6 )
7
8 // RequestFunc may take information from a publisher request and put it into a
9 // request context. In Subscribers, RequestFuncs are executed prior to invoking the
10 // endpoint.
11 type RequestFunc func(context.Context, *nats.Msg) context.Context
12
13 // SubscriberResponseFunc may take information from a request context and use it to
14 // manipulate a Publisher. SubscriberResponseFuncs are only executed in
15 // subscribers, after invoking the endpoint but prior to publishing a reply.
16 type SubscriberResponseFunc func(context.Context, *nats.Conn) context.Context
17
18 // PublisherResponseFunc may take information from an NATS request and make the
19 // response available for consumption. ClientResponseFuncs are only executed in
20 // clients, after a request has been made, but prior to it being decoded.
21 type PublisherResponseFunc func(context.Context, *nats.Msg) context.Context
0 package nats
1
2 import (
3 "context"
4 "encoding/json"
5
6 "github.com/go-kit/kit/endpoint"
7 "github.com/go-kit/kit/log"
8
9 "github.com/nats-io/go-nats"
10 )
11
12 // Subscriber wraps an endpoint and provides nats.MsgHandler.
13 type Subscriber struct {
14 e endpoint.Endpoint
15 dec DecodeRequestFunc
16 enc EncodeResponseFunc
17 before []RequestFunc
18 after []SubscriberResponseFunc
19 errorEncoder ErrorEncoder
20 finalizer []SubscriberFinalizerFunc
21 logger log.Logger
22 }
23
24 // NewSubscriber constructs a new subscriber, which provides nats.MsgHandler and wraps
25 // the provided endpoint.
26 func NewSubscriber(
27 e endpoint.Endpoint,
28 dec DecodeRequestFunc,
29 enc EncodeResponseFunc,
30 options ...SubscriberOption,
31 ) *Subscriber {
32 s := &Subscriber{
33 e: e,
34 dec: dec,
35 enc: enc,
36 errorEncoder: DefaultErrorEncoder,
37 logger: log.NewNopLogger(),
38 }
39 for _, option := range options {
40 option(s)
41 }
42 return s
43 }
44
45 // SubscriberOption sets an optional parameter for subscribers.
46 type SubscriberOption func(*Subscriber)
47
48 // SubscriberBefore functions are executed on the publisher request object before the
49 // request is decoded.
50 func SubscriberBefore(before ...RequestFunc) SubscriberOption {
51 return func(s *Subscriber) { s.before = append(s.before, before...) }
52 }
53
54 // SubscriberAfter functions are executed on the subscriber reply after the
55 // endpoint is invoked, but before anything is published to the reply.
56 func SubscriberAfter(after ...SubscriberResponseFunc) SubscriberOption {
57 return func(s *Subscriber) { s.after = append(s.after, after...) }
58 }
59
60 // SubscriberErrorEncoder is used to encode errors to the subscriber reply
61 // whenever they're encountered in the processing of a request. Clients can
62 // use this to provide custom error formatting. By default,
63 // errors will be published with the DefaultErrorEncoder.
64 func SubscriberErrorEncoder(ee ErrorEncoder) SubscriberOption {
65 return func(s *Subscriber) { s.errorEncoder = ee }
66 }
67
68 // SubscriberErrorLogger is used to log non-terminal errors. By default, no errors
69 // are logged. This is intended as a diagnostic measure. Finer-grained control
70 // of error handling, including logging in more detail, should be performed in a
71 // custom SubscriberErrorEncoder which has access to the context.
72 func SubscriberErrorLogger(logger log.Logger) SubscriberOption {
73 return func(s *Subscriber) { s.logger = logger }
74 }
75
76 // SubscriberFinalizer is executed at the end of every request from a publisher through NATS.
77 // By default, no finalizer is registered.
78 func SubscriberFinalizer(f ...SubscriberFinalizerFunc) SubscriberOption {
79 return func(s *Subscriber) { s.finalizer = f }
80 }
81
82 // ServeMsg provides nats.MsgHandler.
83 func (s Subscriber) ServeMsg(nc *nats.Conn) func(msg *nats.Msg) {
84 return func(msg *nats.Msg) {
85 ctx, cancel := context.WithCancel(context.Background())
86 defer cancel()
87
88 if len(s.finalizer) > 0 {
89 defer func() {
90 for _, f := range s.finalizer {
91 f(ctx, msg)
92 }
93 }()
94 }
95
96 for _, f := range s.before {
97 ctx = f(ctx, msg)
98 }
99
100 request, err := s.dec(ctx, msg)
101 if err != nil {
102 s.logger.Log("err", err)
103 if msg.Reply == "" {
104 return
105 }
106 s.errorEncoder(ctx, err, msg.Reply, nc)
107 return
108 }
109
110 response, err := s.e(ctx, request)
111 if err != nil {
112 s.logger.Log("err", err)
113 if msg.Reply == "" {
114 return
115 }
116 s.errorEncoder(ctx, err, msg.Reply, nc)
117 return
118 }
119
120 for _, f := range s.after {
121 ctx = f(ctx, nc)
122 }
123
124 if msg.Reply == "" {
125 return
126 }
127
128 if err := s.enc(ctx, msg.Reply, nc, response); err != nil {
129 s.logger.Log("err", err)
130 s.errorEncoder(ctx, err, msg.Reply, nc)
131 return
132 }
133 }
134 }
135
136 // ErrorEncoder is responsible for encoding an error to the subscriber reply.
137 // Users are encouraged to use custom ErrorEncoders to encode errors to
138 // their replies, and will likely want to pass and check for their own error
139 // types.
140 type ErrorEncoder func(ctx context.Context, err error, reply string, nc *nats.Conn)
141
142 // ServerFinalizerFunc can be used to perform work at the end of an request
143 // from a publisher, after the response has been written to the publisher. The principal
144 // intended use is for request logging.
145 type SubscriberFinalizerFunc func(ctx context.Context, msg *nats.Msg)
146
147 // NopRequestDecoder is a DecodeRequestFunc that can be used for requests that do not
148 // need to be decoded, and simply returns nil, nil.
149 func NopRequestDecoder(_ context.Context, _ *nats.Msg) (interface{}, error) {
150 return nil, nil
151 }
152
153 // EncodeJSONResponse is a EncodeResponseFunc that serializes the response as a
154 // JSON object to the subscriber reply. Many JSON-over services can use it as
155 // a sensible default.
156 func EncodeJSONResponse(_ context.Context, reply string, nc *nats.Conn, response interface{}) error {
157 b, err := json.Marshal(response)
158 if err != nil {
159 return err
160 }
161
162 return nc.Publish(reply, b)
163 }
164
165 // DefaultErrorEncoder writes the error to the subscriber reply.
166 func DefaultErrorEncoder(_ context.Context, err error, reply string, nc *nats.Conn) {
167 logger := log.NewNopLogger()
168
169 type Response struct {
170 Error string `json:"err"`
171 }
172
173 var response Response
174
175 response.Error = err.Error()
176
177 b, err := json.Marshal(response)
178 if err != nil {
179 logger.Log("err", err)
180 return
181 }
182
183 if err := nc.Publish(reply, b); err != nil {
184 logger.Log("err", err)
185 }
186 }
0 package nats_test
1
2 import (
3 "context"
4 "encoding/json"
5 "errors"
6 "strings"
7 "sync"
8 "testing"
9 "time"
10
11 "github.com/nats-io/gnatsd/server"
12 "github.com/nats-io/go-nats"
13
14 "github.com/go-kit/kit/endpoint"
15 natstransport "github.com/go-kit/kit/transport/nats"
16 )
17
18 type TestResponse struct {
19 String string `json:"str"`
20 Error string `json:"err"`
21 }
22
23 var natsServer *server.Server
24
25 func init() {
26 natsServer = server.New(&server.Options{
27 Host: "localhost",
28 Port: 4222,
29 })
30
31 go func() {
32 natsServer.Start()
33 }()
34
35 if ok := natsServer.ReadyForConnections(2 * time.Second); !ok {
36 panic("Failed start of NATS")
37 }
38 }
39
40 func newNatsConn(t *testing.T) *nats.Conn {
41 // Subscriptions and connections are closed asynchronously, so it's possible
42 // that there's still a subscription from an old connection that must be closed
43 // before the current test can be run.
44 for tries := 20; tries > 0; tries-- {
45 if natsServer.NumSubscriptions() == 0 {
46 break
47 }
48
49 time.Sleep(5 * time.Millisecond)
50 }
51
52 if n := natsServer.NumSubscriptions(); n > 0 {
53 t.Fatalf("found %d active subscriptions on the server", n)
54 }
55
56 nc, err := nats.Connect("nats://"+natsServer.Addr().String(), nats.Name(t.Name()))
57 if err != nil {
58 t.Fatalf("failed to connect to gnatsd server: %s", err)
59 }
60
61 return nc
62 }
63
64 func TestSubscriberBadDecode(t *testing.T) {
65 nc := newNatsConn(t)
66 defer nc.Close()
67
68 handler := natstransport.NewSubscriber(
69 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
70 func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, errors.New("dang") },
71 func(context.Context, string, *nats.Conn, interface{}) error { return nil },
72 )
73
74 resp := testRequest(t, nc, handler)
75
76 if want, have := "dang", resp.Error; want != have {
77 t.Errorf("want %s, have %s", want, have)
78 }
79
80 }
81
82 func TestSubscriberBadEndpoint(t *testing.T) {
83 nc := newNatsConn(t)
84 defer nc.Close()
85
86 handler := natstransport.NewSubscriber(
87 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errors.New("dang") },
88 func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil },
89 func(context.Context, string, *nats.Conn, interface{}) error { return nil },
90 )
91
92 resp := testRequest(t, nc, handler)
93
94 if want, have := "dang", resp.Error; want != have {
95 t.Errorf("want %s, have %s", want, have)
96 }
97 }
98
99 func TestSubscriberBadEncode(t *testing.T) {
100 nc := newNatsConn(t)
101 defer nc.Close()
102
103 handler := natstransport.NewSubscriber(
104 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
105 func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil },
106 func(context.Context, string, *nats.Conn, interface{}) error { return errors.New("dang") },
107 )
108
109 resp := testRequest(t, nc, handler)
110
111 if want, have := "dang", resp.Error; want != have {
112 t.Errorf("want %s, have %s", want, have)
113 }
114 }
115
116 func TestSubscriberErrorEncoder(t *testing.T) {
117 nc := newNatsConn(t)
118 defer nc.Close()
119
120 errTeapot := errors.New("teapot")
121 code := func(err error) error {
122 if err == errTeapot {
123 return err
124 }
125 return errors.New("dang")
126 }
127 handler := natstransport.NewSubscriber(
128 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errTeapot },
129 func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil },
130 func(context.Context, string, *nats.Conn, interface{}) error { return nil },
131 natstransport.SubscriberErrorEncoder(func(_ context.Context, err error, reply string, nc *nats.Conn) {
132 var r TestResponse
133 r.Error = code(err).Error()
134
135 b, err := json.Marshal(r)
136 if err != nil {
137 t.Fatal(err)
138 }
139
140 if err := nc.Publish(reply, b); err != nil {
141 t.Fatal(err)
142 }
143 }),
144 )
145
146 resp := testRequest(t, nc, handler)
147
148 if want, have := errTeapot.Error(), resp.Error; want != have {
149 t.Errorf("want %s, have %s", want, have)
150 }
151 }
152
153 func TestSubscriberHappySubject(t *testing.T) {
154 step, response := testSubscriber(t)
155 step()
156 r := <-response
157
158 var resp TestResponse
159 err := json.Unmarshal(r.Data, &resp)
160 if err != nil {
161 t.Fatal(err)
162 }
163
164 if want, have := "", resp.Error; want != have {
165 t.Errorf("want %s, have %s (%s)", want, have, r.Data)
166 }
167 }
168
169 func TestMultipleSubscriberBefore(t *testing.T) {
170 nc := newNatsConn(t)
171 defer nc.Close()
172
173 var (
174 response = struct{ Body string }{"go eat a fly ugly\n"}
175 wg sync.WaitGroup
176 done = make(chan struct{})
177 )
178 handler := natstransport.NewSubscriber(
179 endpoint.Nop,
180 func(context.Context, *nats.Msg) (interface{}, error) {
181 return struct{}{}, nil
182 },
183 func(_ context.Context, reply string, nc *nats.Conn, _ interface{}) error {
184 b, err := json.Marshal(response)
185 if err != nil {
186 return err
187 }
188
189 return nc.Publish(reply, b)
190 },
191 natstransport.SubscriberBefore(func(ctx context.Context, _ *nats.Msg) context.Context {
192 ctx = context.WithValue(ctx, "one", 1)
193
194 return ctx
195 }),
196 natstransport.SubscriberBefore(func(ctx context.Context, _ *nats.Msg) context.Context {
197 if _, ok := ctx.Value("one").(int); !ok {
198 t.Error("Value was not set properly when multiple ServerBefores are used")
199 }
200
201 close(done)
202 return ctx
203 }),
204 )
205
206 sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc))
207 if err != nil {
208 t.Fatal(err)
209 }
210 defer sub.Unsubscribe()
211
212 wg.Add(1)
213 go func() {
214 defer wg.Done()
215 _, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second)
216 if err != nil {
217 t.Fatal(err)
218 }
219 }()
220
221 select {
222 case <-done:
223 case <-time.After(time.Second):
224 t.Fatal("timeout waiting for finalizer")
225 }
226
227 wg.Wait()
228 }
229
230 func TestMultipleSubscriberAfter(t *testing.T) {
231 nc := newNatsConn(t)
232 defer nc.Close()
233
234 var (
235 response = struct{ Body string }{"go eat a fly ugly\n"}
236 wg sync.WaitGroup
237 done = make(chan struct{})
238 )
239 handler := natstransport.NewSubscriber(
240 endpoint.Nop,
241 func(context.Context, *nats.Msg) (interface{}, error) {
242 return struct{}{}, nil
243 },
244 func(_ context.Context, reply string, nc *nats.Conn, _ interface{}) error {
245 b, err := json.Marshal(response)
246 if err != nil {
247 return err
248 }
249
250 return nc.Publish(reply, b)
251 },
252 natstransport.SubscriberAfter(func(ctx context.Context, nc *nats.Conn) context.Context {
253 ctx = context.WithValue(ctx, "one", 1)
254
255 return ctx
256 }),
257 natstransport.SubscriberAfter(func(ctx context.Context, nc *nats.Conn) context.Context {
258 if _, ok := ctx.Value("one").(int); !ok {
259 t.Error("Value was not set properly when multiple ServerAfters are used")
260 }
261
262 close(done)
263 return ctx
264 }),
265 )
266
267 sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc))
268 if err != nil {
269 t.Fatal(err)
270 }
271 defer sub.Unsubscribe()
272
273 wg.Add(1)
274 go func() {
275 defer wg.Done()
276 _, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second)
277 if err != nil {
278 t.Fatal(err)
279 }
280 }()
281
282 select {
283 case <-done:
284 case <-time.After(time.Second):
285 t.Fatal("timeout waiting for finalizer")
286 }
287
288 wg.Wait()
289 }
290
291 func TestSubscriberFinalizerFunc(t *testing.T) {
292 nc := newNatsConn(t)
293 defer nc.Close()
294
295 var (
296 response = struct{ Body string }{"go eat a fly ugly\n"}
297 wg sync.WaitGroup
298 done = make(chan struct{})
299 )
300 handler := natstransport.NewSubscriber(
301 endpoint.Nop,
302 func(context.Context, *nats.Msg) (interface{}, error) {
303 return struct{}{}, nil
304 },
305 func(_ context.Context, reply string, nc *nats.Conn, _ interface{}) error {
306 b, err := json.Marshal(response)
307 if err != nil {
308 return err
309 }
310
311 return nc.Publish(reply, b)
312 },
313 natstransport.SubscriberFinalizer(func(ctx context.Context, _ *nats.Msg) {
314 close(done)
315 }),
316 )
317
318 sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc))
319 if err != nil {
320 t.Fatal(err)
321 }
322 defer sub.Unsubscribe()
323
324 wg.Add(1)
325 go func() {
326 defer wg.Done()
327 _, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second)
328 if err != nil {
329 t.Fatal(err)
330 }
331 }()
332
333 select {
334 case <-done:
335 case <-time.After(time.Second):
336 t.Fatal("timeout waiting for finalizer")
337 }
338
339 wg.Wait()
340 }
341
342 func TestEncodeJSONResponse(t *testing.T) {
343 nc := newNatsConn(t)
344 defer nc.Close()
345
346 handler := natstransport.NewSubscriber(
347 func(context.Context, interface{}) (interface{}, error) {
348 return struct {
349 Foo string `json:"foo"`
350 }{"bar"}, nil
351 },
352 func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil },
353 natstransport.EncodeJSONResponse,
354 )
355
356 sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc))
357 if err != nil {
358 t.Fatal(err)
359 }
360 defer sub.Unsubscribe()
361
362 r, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second)
363 if err != nil {
364 t.Fatal(err)
365 }
366
367 if want, have := `{"foo":"bar"}`, strings.TrimSpace(string(r.Data)); want != have {
368 t.Errorf("Body: want %s, have %s", want, have)
369 }
370 }
371
372 type responseError struct {
373 msg string
374 }
375
376 func (m responseError) Error() string {
377 return m.msg
378 }
379
380 func TestErrorEncoder(t *testing.T) {
381 nc := newNatsConn(t)
382 defer nc.Close()
383
384 errResp := struct {
385 Error string `json:"err"`
386 }{"oh no"}
387 handler := natstransport.NewSubscriber(
388 func(context.Context, interface{}) (interface{}, error) {
389 return nil, responseError{msg: errResp.Error}
390 },
391 func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil },
392 natstransport.EncodeJSONResponse,
393 )
394
395 sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc))
396 if err != nil {
397 t.Fatal(err)
398 }
399 defer sub.Unsubscribe()
400
401 r, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second)
402 if err != nil {
403 t.Fatal(err)
404 }
405
406 b, err := json.Marshal(errResp)
407 if err != nil {
408 t.Fatal(err)
409 }
410 if string(b) != string(r.Data) {
411 t.Errorf("ErrorEncoder: got: %q, expected: %q", r.Data, b)
412 }
413 }
414
415 type noContentResponse struct{}
416
417 func TestEncodeNoContent(t *testing.T) {
418 nc := newNatsConn(t)
419 defer nc.Close()
420
421 handler := natstransport.NewSubscriber(
422 func(context.Context, interface{}) (interface{}, error) { return noContentResponse{}, nil },
423 func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil },
424 natstransport.EncodeJSONResponse,
425 )
426
427 sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc))
428 if err != nil {
429 t.Fatal(err)
430 }
431 defer sub.Unsubscribe()
432
433 r, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second)
434 if err != nil {
435 t.Fatal(err)
436 }
437
438 if want, have := `{}`, strings.TrimSpace(string(r.Data)); want != have {
439 t.Errorf("Body: want %s, have %s", want, have)
440 }
441 }
442
443 func TestNoOpRequestDecoder(t *testing.T) {
444 nc := newNatsConn(t)
445 defer nc.Close()
446
447 handler := natstransport.NewSubscriber(
448 func(ctx context.Context, request interface{}) (interface{}, error) {
449 if request != nil {
450 t.Error("Expected nil request in endpoint when using NopRequestDecoder")
451 }
452 return nil, nil
453 },
454 natstransport.NopRequestDecoder,
455 natstransport.EncodeJSONResponse,
456 )
457
458 sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc))
459 if err != nil {
460 t.Fatal(err)
461 }
462 defer sub.Unsubscribe()
463
464 r, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second)
465 if err != nil {
466 t.Fatal(err)
467 }
468
469 if want, have := `null`, strings.TrimSpace(string(r.Data)); want != have {
470 t.Errorf("Body: want %s, have %s", want, have)
471 }
472 }
473
474 func testSubscriber(t *testing.T) (step func(), resp <-chan *nats.Msg) {
475 var (
476 stepch = make(chan bool)
477 endpoint = func(context.Context, interface{}) (interface{}, error) {
478 <-stepch
479 return struct{}{}, nil
480 }
481 response = make(chan *nats.Msg)
482 handler = natstransport.NewSubscriber(
483 endpoint,
484 func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil },
485 natstransport.EncodeJSONResponse,
486 natstransport.SubscriberBefore(func(ctx context.Context, msg *nats.Msg) context.Context { return ctx }),
487 natstransport.SubscriberAfter(func(ctx context.Context, nc *nats.Conn) context.Context { return ctx }),
488 )
489 )
490
491 go func() {
492 nc := newNatsConn(t)
493 defer nc.Close()
494
495 sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc))
496 if err != nil {
497 t.Fatal(err)
498 }
499 defer sub.Unsubscribe()
500
501 r, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second)
502 if err != nil {
503 t.Fatal(err)
504 }
505
506 response <- r
507 }()
508
509 return func() { stepch <- true }, response
510 }
511
512 func testRequest(t *testing.T, nc *nats.Conn, handler *natstransport.Subscriber) TestResponse {
513 sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc))
514 if err != nil {
515 t.Fatal(err)
516 }
517 defer sub.Unsubscribe()
518
519 r, err := nc.Request("natstransport.test", []byte("test data"), 2*time.Second)
520 if err != nil {
521 t.Fatal(err)
522 }
523
524 var resp TestResponse
525 err = json.Unmarshal(r.Data, &resp)
526 if err != nil {
527 t.Fatal(err)
528 }
529
530 return resp
531 }
11
22 import (
33 "errors"
4 "math/rand"
45 "net"
56 "time"
67
102103 case conn = <-connc:
103104 if conn == nil {
104105 // didn't work
105 backoff = exponential(backoff) // wait longer
106 backoff = Exponential(backoff) // wait longer
106107 reconnectc = m.after(backoff) // try again
107108 } else {
108109 // worked!
131132 return conn
132133 }
133134
134 func exponential(d time.Duration) time.Duration {
135 // Exponential takes a duration and returns another one that is twice as long, +/- 50%. It is
136 // used to provide backoff for operations that may fail and should avoid thundering herds.
137 // See https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ for rationale
138 func Exponential(d time.Duration) time.Duration {
135139 d *= 2
140 jitter := rand.Float64() + 0.5
141 d = time.Duration(int64(float64(d.Nanoseconds()) * jitter))
136142 if d > time.Minute {
137143 d = time.Minute
138144 }
139145 return d
146
140147 }
141148
142149 // ErrConnectionUnavailable is returned by the Manager's Write method when the