Codebase list golang-github-go-kit-kit / 2677e49
added unit test for gRPC header and trailer request/response propagation Bas van Beek 7 years ago
14 changed file(s) with 518 addition(s) and 23 deletion(s). Raw diff Collapse all Expand all
221221 return
222222 }
223223
224 srv := addsvc.MakeGRPCServer(ctx, endpoints, tracer, logger)
224 srv := addsvc.MakeGRPCServer(endpoints, tracer, logger)
225225 s := grpc.NewServer()
226226 pb.RegisterAddServer(s, srv)
227227
1515 )
1616
1717 // MakeGRPCServer makes a set of endpoints available as a gRPC AddServer.
18 func MakeGRPCServer(ctx context.Context, endpoints Endpoints, tracer stdopentracing.Tracer, logger log.Logger) pb.AddServer {
18 func MakeGRPCServer(endpoints Endpoints, tracer stdopentracing.Tracer, logger log.Logger) pb.AddServer {
1919 options := []grpctransport.ServerOption{
2020 grpctransport.ServerErrorLogger(logger),
2121 }
2222 return &grpcServer{
2323 sum: grpctransport.NewServer(
24 ctx,
2524 endpoints.SumEndpoint,
2625 DecodeGRPCSumRequest,
2726 EncodeGRPCSumResponse,
2827 append(options, grpctransport.ServerBefore(opentracing.FromGRPCRequest(tracer, "Sum", logger)))...,
2928 ),
3029 concat: grpctransport.NewServer(
31 ctx,
3230 endpoints.ConcatEndpoint,
3331 DecodeGRPCConcatRequest,
3432 EncodeGRPCConcatResponse,
0 package test
1
2 import (
3 "context"
4
5 "google.golang.org/grpc"
6
7 "github.com/go-kit/kit/endpoint"
8 grpctransport "github.com/go-kit/kit/transport/grpc"
9 pb "github.com/go-kit/kit/transport/grpc/_pb"
10 )
11
12 type clientBinding struct {
13 test endpoint.Endpoint
14 }
15
16 func (c *clientBinding) Test(ctx context.Context, a string, b int64) (context.Context, string, error) {
17 response, err := c.test(ctx, TestRequest{A: a, B: b})
18 if err != nil {
19 return nil, "", err
20 }
21 r := response.(*TestResponse)
22 return r.Ctx, r.V, nil
23 }
24
25 func NewClient(cc *grpc.ClientConn) Service {
26 return &clientBinding{
27 test: grpctransport.NewClient(
28 cc,
29 "pb.Test",
30 "Test",
31 encodeRequest,
32 decodeResponse,
33 &pb.TestResponse{},
34 grpctransport.ClientBefore(clientBefore),
35 grpctransport.ClientAfter(clientAfter),
36 ).Endpoint(),
37 }
38 }
0 package test
1
2 import (
3 "context"
4 "fmt"
5 "log"
6
7 "google.golang.org/grpc"
8 "google.golang.org/grpc/metadata"
9 )
10
11 type metaContext string
12
13 const (
14 correlationID metaContext = "correlation-id"
15 responseHDR metaContext = "my-response-header"
16 responseTRLR metaContext = "correlation-id-consumed"
17 )
18
19 func clientBefore(ctx context.Context, md *metadata.MD) context.Context {
20 if hdr, ok := ctx.Value(correlationID).(string); ok {
21 (*md)[string(correlationID)] = append((*md)[string(correlationID)], hdr)
22 }
23 if len(*md) > 0 {
24 fmt.Println("\tClient >> Request Headers:")
25 for key, val := range *md {
26 fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
27 }
28 }
29 return ctx
30 }
31
32 func serverBefore(ctx context.Context, md *metadata.MD) context.Context {
33 if len(*md) > 0 {
34 fmt.Println("\tServer << Request Headers:")
35 for key, val := range *md {
36 fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
37 }
38 }
39 if hdr, ok := (*md)[string(correlationID)]; ok {
40 cID := hdr[len(hdr)-1]
41 ctx = context.WithValue(ctx, correlationID, cID)
42 fmt.Printf("\tServer placed correlationID %q in context\n", cID)
43 }
44 return ctx
45 }
46
47 func serverAfter(ctx context.Context, _ *metadata.MD) {
48 var mdHeader, mdTrailer metadata.MD
49
50 mdHeader = metadata.Pairs(string(responseHDR), "has-a-value")
51 if err := grpc.SendHeader(ctx, mdHeader); err != nil {
52 log.Fatalf("unable to send header: %+v\n", err)
53 }
54
55 if hdr, ok := ctx.Value(correlationID).(string); ok {
56 mdTrailer = metadata.Pairs(string(responseTRLR), hdr)
57 if err := grpc.SetTrailer(ctx, mdTrailer); err != nil {
58 log.Fatalf("unable to set trailer: %+v\n", err)
59 }
60 fmt.Printf("\tServer found correlationID %q in context, set consumed trailer\n", hdr)
61 }
62 if len(mdHeader) > 0 {
63 fmt.Println("\tServer >> Response Headers:")
64 for key, val := range mdHeader {
65 fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
66 }
67 }
68 if len(mdTrailer) > 0 {
69 fmt.Println("\tServer >> Response Trailers:")
70 for key, val := range mdTrailer {
71 fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
72 }
73 }
74 }
75
76 func clientAfter(ctx context.Context, mdHeader metadata.MD, mdTrailer metadata.MD) context.Context {
77 if len(mdHeader) > 0 {
78 fmt.Println("\tClient << Response Headers:")
79 for key, val := range mdHeader {
80 fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
81 }
82 }
83 if len(mdTrailer) > 0 {
84 fmt.Println("\tClient << Response Trailers:")
85 for key, val := range mdTrailer {
86 fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
87 }
88 }
89
90 if hdr, ok := mdTrailer[string(responseTRLR)]; ok {
91 ctx = context.WithValue(ctx, responseTRLR, hdr[len(hdr)-1])
92 }
93 return ctx
94 }
95
96 func SetCorrelationID(ctx context.Context, v string) context.Context {
97 return context.WithValue(ctx, correlationID, v)
98 }
99
100 func GetConsumedCorrelationID(ctx context.Context) string {
101 if trlr, ok := ctx.Value(responseTRLR).(string); ok {
102 return trlr
103 }
104 return ""
105 }
0 package test
1
2 import (
3 "context"
4 "errors"
5
6 pb "github.com/go-kit/kit/transport/grpc/_pb"
7 )
8
9 func encodeRequest(ctx context.Context, req interface{}) (interface{}, error) {
10 r, ok := req.(TestRequest)
11 if !ok {
12 return nil, errors.New("request encode error")
13 }
14 return &pb.TestRequest{A: r.A, B: r.B}, nil
15 }
16
17 func decodeRequest(ctx context.Context, req interface{}) (interface{}, error) {
18 r, ok := req.(*pb.TestRequest)
19 if !ok {
20 return nil, errors.New("request decode error")
21 }
22 return TestRequest{A: r.A, B: r.B}, nil
23 }
24
25 func encodeResponse(ctx context.Context, resp interface{}) (interface{}, error) {
26 r, ok := resp.(*TestResponse)
27 if !ok {
28 return nil, errors.New("response encode error")
29 }
30 return &pb.TestResponse{V: r.V}, nil
31 }
32
33 func decodeResponse(ctx context.Context, resp interface{}) (interface{}, error) {
34 r, ok := resp.(*pb.TestResponse)
35 if !ok {
36 return nil, errors.New("response decode error")
37 }
38 return &TestResponse{V: r.V, Ctx: ctx}, nil
39 }
0 package test
1
2 import (
3 "context"
4 "fmt"
5
6 oldcontext "golang.org/x/net/context"
7
8 "github.com/go-kit/kit/endpoint"
9 grpctransport "github.com/go-kit/kit/transport/grpc"
10 pb "github.com/go-kit/kit/transport/grpc/_pb"
11 )
12
13 type service struct{}
14
15 func (service) Test(ctx context.Context, a string, b int64) (context.Context, string, error) {
16 return nil, fmt.Sprintf("%s = %d", a, b), nil
17 }
18
19 func NewService() Service {
20 return service{}
21 }
22
23 func makeTestEndpoint(svc Service) endpoint.Endpoint {
24 return func(ctx context.Context, request interface{}) (interface{}, error) {
25 req := request.(TestRequest)
26 newCtx, v, err := svc.Test(ctx, req.A, req.B)
27 return &TestResponse{
28 V: v,
29 Ctx: newCtx,
30 }, err
31 }
32 }
33
34 type serverBinding struct {
35 test grpctransport.Handler
36 }
37
38 func (b *serverBinding) Test(ctx oldcontext.Context, req *pb.TestRequest) (*pb.TestResponse, error) {
39 _, response, err := b.test.ServeGRPC(ctx, req)
40 if err != nil {
41 return nil, err
42 }
43 return response.(*pb.TestResponse), nil
44 }
45
46 func NewBinding(svc Service) *serverBinding {
47 return &serverBinding{
48 test: grpctransport.NewServer(
49 makeTestEndpoint(svc),
50 decodeRequest,
51 encodeResponse,
52 grpctransport.ServerBefore(serverBefore),
53 grpctransport.ServerAfter(serverAfter),
54 ),
55 }
56 }
0 package test
1
2 import "context"
3
4 type Service interface {
5 Test(ctx context.Context, a string, b int64) (context.Context, string, error)
6 }
7
8 type TestRequest struct {
9 A string
10 B int64
11 }
12
13 type TestResponse struct {
14 Ctx context.Context
15 V string
16 }
0 package pb
1
2 //go:generate protoc test.proto --go_out=plugins=grpc:.
0 // Code generated by protoc-gen-go.
1 // source: test.proto
2 // DO NOT EDIT!
3
4 /*
5 Package pb is a generated protocol buffer package.
6
7 It is generated from these files:
8 test.proto
9
10 It has these top-level messages:
11 TestRequest
12 TestResponse
13 */
14 package pb
15
16 import proto "github.com/golang/protobuf/proto"
17 import fmt "fmt"
18 import math "math"
19
20 import (
21 context "golang.org/x/net/context"
22 grpc "google.golang.org/grpc"
23 )
24
25 // Reference imports to suppress errors if they are not otherwise used.
26 var _ = proto.Marshal
27 var _ = fmt.Errorf
28 var _ = math.Inf
29
30 // This is a compile-time assertion to ensure that this generated file
31 // is compatible with the proto package it is being compiled against.
32 // A compilation error at this line likely means your copy of the
33 // proto package needs to be updated.
34 const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
35
36 type TestRequest struct {
37 A string `protobuf:"bytes,1,opt,name=a" json:"a,omitempty"`
38 B int64 `protobuf:"varint,2,opt,name=b" json:"b,omitempty"`
39 }
40
41 func (m *TestRequest) Reset() { *m = TestRequest{} }
42 func (m *TestRequest) String() string { return proto.CompactTextString(m) }
43 func (*TestRequest) ProtoMessage() {}
44 func (*TestRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
45
46 func (m *TestRequest) GetA() string {
47 if m != nil {
48 return m.A
49 }
50 return ""
51 }
52
53 func (m *TestRequest) GetB() int64 {
54 if m != nil {
55 return m.B
56 }
57 return 0
58 }
59
60 type TestResponse struct {
61 V string `protobuf:"bytes,1,opt,name=v" json:"v,omitempty"`
62 }
63
64 func (m *TestResponse) Reset() { *m = TestResponse{} }
65 func (m *TestResponse) String() string { return proto.CompactTextString(m) }
66 func (*TestResponse) ProtoMessage() {}
67 func (*TestResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
68
69 func (m *TestResponse) GetV() string {
70 if m != nil {
71 return m.V
72 }
73 return ""
74 }
75
76 func init() {
77 proto.RegisterType((*TestRequest)(nil), "pb.TestRequest")
78 proto.RegisterType((*TestResponse)(nil), "pb.TestResponse")
79 }
80
81 // Reference imports to suppress errors if they are not otherwise used.
82 var _ context.Context
83 var _ grpc.ClientConn
84
85 // This is a compile-time assertion to ensure that this generated file
86 // is compatible with the grpc package it is being compiled against.
87 const _ = grpc.SupportPackageIsVersion4
88
89 // Client API for Test service
90
91 type TestClient interface {
92 Test(ctx context.Context, in *TestRequest, opts ...grpc.CallOption) (*TestResponse, error)
93 }
94
95 type testClient struct {
96 cc *grpc.ClientConn
97 }
98
99 func NewTestClient(cc *grpc.ClientConn) TestClient {
100 return &testClient{cc}
101 }
102
103 func (c *testClient) Test(ctx context.Context, in *TestRequest, opts ...grpc.CallOption) (*TestResponse, error) {
104 out := new(TestResponse)
105 err := grpc.Invoke(ctx, "/pb.Test/Test", in, out, c.cc, opts...)
106 if err != nil {
107 return nil, err
108 }
109 return out, nil
110 }
111
112 // Server API for Test service
113
114 type TestServer interface {
115 Test(context.Context, *TestRequest) (*TestResponse, error)
116 }
117
118 func RegisterTestServer(s *grpc.Server, srv TestServer) {
119 s.RegisterService(&_Test_serviceDesc, srv)
120 }
121
122 func _Test_Test_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
123 in := new(TestRequest)
124 if err := dec(in); err != nil {
125 return nil, err
126 }
127 if interceptor == nil {
128 return srv.(TestServer).Test(ctx, in)
129 }
130 info := &grpc.UnaryServerInfo{
131 Server: srv,
132 FullMethod: "/pb.Test/Test",
133 }
134 handler := func(ctx context.Context, req interface{}) (interface{}, error) {
135 return srv.(TestServer).Test(ctx, req.(*TestRequest))
136 }
137 return interceptor(ctx, in, info, handler)
138 }
139
140 var _Test_serviceDesc = grpc.ServiceDesc{
141 ServiceName: "pb.Test",
142 HandlerType: (*TestServer)(nil),
143 Methods: []grpc.MethodDesc{
144 {
145 MethodName: "Test",
146 Handler: _Test_Test_Handler,
147 },
148 },
149 Streams: []grpc.StreamDesc{},
150 Metadata: "test.proto",
151 }
152
153 func init() { proto.RegisterFile("test.proto", fileDescriptor0) }
154
155 var fileDescriptor0 = []byte{
156 // 129 bytes of a gzipped FileDescriptorProto
157 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x2a, 0x49, 0x2d, 0x2e,
158 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2a, 0x48, 0x52, 0xd2, 0xe4, 0xe2, 0x0e, 0x49,
159 0x2d, 0x2e, 0x09, 0x4a, 0x2d, 0x2c, 0x4d, 0x2d, 0x2e, 0x11, 0xe2, 0xe1, 0x62, 0x4c, 0x94, 0x60,
160 0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x62, 0x4c, 0x04, 0xf1, 0x92, 0x24, 0x98, 0x14, 0x18, 0x35, 0x98,
161 0x83, 0x18, 0x93, 0x94, 0x64, 0xb8, 0x78, 0x20, 0x4a, 0x8b, 0x0b, 0xf2, 0xf3, 0x8a, 0x53, 0x41,
162 0xb2, 0x65, 0x30, 0xb5, 0x65, 0x46, 0xc6, 0x5c, 0x2c, 0x20, 0x59, 0x21, 0x6d, 0x28, 0xcd, 0xaf,
163 0x57, 0x90, 0xa4, 0x87, 0x64, 0xb4, 0x94, 0x00, 0x42, 0x00, 0x62, 0x80, 0x12, 0x43, 0x12, 0x1b,
164 0xd8, 0x21, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x49, 0xfc, 0xd8, 0xf1, 0x96, 0x00, 0x00,
165 0x00,
166 }
0 syntax = "proto3";
1
2 package pb;
3
4 service Test {
5 rpc Test (TestRequest) returns (TestResponse) {}
6 }
7
8 message TestRequest {
9 string a = 1;
10 int64 b = 2;
11 }
12
13 message TestResponse {
14 string v = 1;
15 }
106106 }
107107
108108 for _, f := range c.after {
109 ctx = f(ctx, &header, &trailer)
109 ctx = f(ctx, header, trailer)
110110 }
111111
112112 response, err := c.dec(ctx, grpcReply)
0 package grpc_test
1
2 import (
3 "context"
4 "fmt"
5 "net"
6 "testing"
7
8 "google.golang.org/grpc"
9
10 test "github.com/go-kit/kit/transport/grpc/_grpc_test"
11 pb "github.com/go-kit/kit/transport/grpc/_pb"
12 )
13
14 const (
15 hostPort string = "localhost:8002"
16 )
17
18 func TestGRPCClient(t *testing.T) {
19 var (
20 server = grpc.NewServer()
21 service = test.NewService()
22 )
23
24 sc, err := net.Listen("tcp", hostPort)
25 if err != nil {
26 t.Fatalf("unable to listen: %+v", err)
27 }
28 defer server.GracefulStop()
29
30 go func() {
31 pb.RegisterTestServer(server, test.NewBinding(service))
32 _ = server.Serve(sc)
33 }()
34
35 cc, err := grpc.Dial(hostPort, grpc.WithInsecure())
36 if err != nil {
37 t.Fatalf("unable to Dial: %+v", err)
38 }
39
40 client := test.NewClient(cc)
41
42 var (
43 a = "the answer to life the universe and everything"
44 b = int64(42)
45 cID = "request-1"
46 ctx = test.SetCorrelationID(context.Background(), cID)
47 )
48
49 responseCTX, v, err := client.Test(ctx, a, b)
50
51 if want, have := fmt.Sprintf("%s = %d", a, b), v; want != have {
52 t.Fatalf("want %q, have %q", want, have)
53 }
54
55 if want, have := cID, test.GetConsumedCorrelationID(responseCTX); want != have {
56 t.Fatalf("want %q, have %q", want, have)
57 }
58 }
2626 // trailer and make the responses available for consumption. ClientResponseFuncs
2727 // are only executed in clients, after a request has been made, but prior to it
2828 // being decoded.
29 type ClientResponseFunc func(ctx context.Context, header *metadata.MD, trailer *metadata.MD) context.Context
29 type ClientResponseFunc func(ctx context.Context, header metadata.MD, trailer metadata.MD) context.Context
3030
3131 // SetResponseHeader returns a ResponseFunc that sets the specified metadata
3232 // key-value pair.
00 package grpc
11
22 import (
3 "context"
4
53 oldcontext "golang.org/x/net/context"
64 "google.golang.org/grpc/metadata"
75
1816
1917 // Server wraps an endpoint and implements grpc.Handler.
2018 type Server struct {
21 ctx context.Context
2219 e endpoint.Endpoint
2320 dec DecodeRequestFunc
2421 enc EncodeResponseFunc
3330 // definitions to individual handlers. Request and response objects are from the
3431 // caller business domain, not gRPC request and reply types.
3532 func NewServer(
36 ctx context.Context,
3733 e endpoint.Endpoint,
3834 dec DecodeRequestFunc,
3935 enc EncodeResponseFunc,
4036 options ...ServerOption,
4137 ) *Server {
4238 s := &Server{
43 ctx: ctx,
4439 e: e,
4540 dec: dec,
4641 enc: enc,
7469 }
7570
7671 // ServeGRPC implements the Handler interface.
77 func (s Server) ServeGRPC(grpcCtx oldcontext.Context, req interface{}) (oldcontext.Context, interface{}, error) {
78 ctx := s.ctx
79
72 func (s Server) ServeGRPC(ctx oldcontext.Context, req interface{}) (oldcontext.Context, interface{}, error) {
8073 // Retrieve gRPC metadata.
81 md, ok := metadata.FromContext(grpcCtx)
74 md, ok := metadata.FromContext(ctx)
8275 if !ok {
8376 md = metadata.MD{}
8477 }
8881 }
8982
9083 // Store potentially updated metadata in the gRPC context.
91 grpcCtx = metadata.NewContext(grpcCtx, md)
84 ctx = metadata.NewContext(ctx, md)
9285
93 request, err := s.dec(grpcCtx, req)
86 request, err := s.dec(ctx, req)
9487 if err != nil {
9588 s.logger.Log("err", err)
96 return grpcCtx, nil, err
89 return ctx, nil, err
9790 }
9891
9992 response, err := s.e(ctx, request)
10093 if err != nil {
10194 s.logger.Log("err", err)
102 return grpcCtx, nil, err
95 return ctx, nil, err
10396 }
10497
10598 for _, f := range s.after {
107100 }
108101
109102 // Store potentially updated metadata in the gRPC context.
110 grpcCtx = metadata.NewContext(grpcCtx, md)
103 ctx = metadata.NewContext(ctx, md)
111104
112 grpcResp, err := s.enc(grpcCtx, response)
105 grpcResp, err := s.enc(ctx, response)
113106 if err != nil {
114107 s.logger.Log("err", err)
115 return grpcCtx, nil, err
108 return ctx, nil, err
116109 }
117110
118 return grpcCtx, grpcResp, nil
111 return ctx, grpcResp, nil
119112 }