diff --git a/examples/addsvc/cmd/addsvc/main.go b/examples/addsvc/cmd/addsvc/main.go index 842f34b..3fba0ad 100644 --- a/examples/addsvc/cmd/addsvc/main.go +++ b/examples/addsvc/cmd/addsvc/main.go @@ -222,7 +222,7 @@ return } - srv := addsvc.MakeGRPCServer(ctx, endpoints, tracer, logger) + srv := addsvc.MakeGRPCServer(endpoints, tracer, logger) s := grpc.NewServer() pb.RegisterAddServer(s, srv) diff --git a/examples/addsvc/transport_grpc.go b/examples/addsvc/transport_grpc.go index 21e60bc..dcfc03a 100644 --- a/examples/addsvc/transport_grpc.go +++ b/examples/addsvc/transport_grpc.go @@ -16,20 +16,18 @@ ) // MakeGRPCServer makes a set of endpoints available as a gRPC AddServer. -func MakeGRPCServer(ctx context.Context, endpoints Endpoints, tracer stdopentracing.Tracer, logger log.Logger) pb.AddServer { +func MakeGRPCServer(endpoints Endpoints, tracer stdopentracing.Tracer, logger log.Logger) pb.AddServer { options := []grpctransport.ServerOption{ grpctransport.ServerErrorLogger(logger), } return &grpcServer{ sum: grpctransport.NewServer( - ctx, endpoints.SumEndpoint, DecodeGRPCSumRequest, EncodeGRPCSumResponse, append(options, grpctransport.ServerBefore(opentracing.FromGRPCRequest(tracer, "Sum", logger)))..., ), concat: grpctransport.NewServer( - ctx, endpoints.ConcatEndpoint, DecodeGRPCConcatRequest, EncodeGRPCConcatResponse, diff --git a/transport/grpc/_grpc_test/client.go b/transport/grpc/_grpc_test/client.go new file mode 100644 index 0000000..11d78ca --- /dev/null +++ b/transport/grpc/_grpc_test/client.go @@ -0,0 +1,39 @@ +package test + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/go-kit/kit/endpoint" + grpctransport "github.com/go-kit/kit/transport/grpc" + pb "github.com/go-kit/kit/transport/grpc/_pb" +) + +type clientBinding struct { + test endpoint.Endpoint +} + +func (c *clientBinding) Test(ctx context.Context, a string, b int64) (context.Context, string, error) { + response, err := c.test(ctx, TestRequest{A: a, B: b}) + if err != nil { + return nil, "", err + } + r := response.(*TestResponse) + return r.Ctx, r.V, nil +} + +func NewClient(cc *grpc.ClientConn) Service { + return &clientBinding{ + test: grpctransport.NewClient( + cc, + "pb.Test", + "Test", + encodeRequest, + decodeResponse, + &pb.TestResponse{}, + grpctransport.ClientBefore(clientBefore), + grpctransport.ClientAfter(clientAfter), + ).Endpoint(), + } +} diff --git a/transport/grpc/_grpc_test/context_metadata.go b/transport/grpc/_grpc_test/context_metadata.go new file mode 100644 index 0000000..f31b50b --- /dev/null +++ b/transport/grpc/_grpc_test/context_metadata.go @@ -0,0 +1,106 @@ +package test + +import ( + "context" + "fmt" + "log" + + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +type metaContext string + +const ( + correlationID metaContext = "correlation-id" + responseHDR metaContext = "my-response-header" + responseTRLR metaContext = "correlation-id-consumed" +) + +func clientBefore(ctx context.Context, md *metadata.MD) context.Context { + if hdr, ok := ctx.Value(correlationID).(string); ok { + (*md)[string(correlationID)] = append((*md)[string(correlationID)], hdr) + } + if len(*md) > 0 { + fmt.Println("\tClient >> Request Headers:") + for key, val := range *md { + fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) + } + } + return ctx +} + +func serverBefore(ctx context.Context, md *metadata.MD) context.Context { + if len(*md) > 0 { + fmt.Println("\tServer << Request Headers:") + for key, val := range *md { + fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) + } + } + if hdr, ok := (*md)[string(correlationID)]; ok { + cID := hdr[len(hdr)-1] + ctx = context.WithValue(ctx, correlationID, cID) + fmt.Printf("\tServer placed correlationID %q in context\n", cID) + } + return ctx +} + +func serverAfter(ctx context.Context, _ *metadata.MD) { + var mdHeader, mdTrailer metadata.MD + + mdHeader = metadata.Pairs(string(responseHDR), "has-a-value") + if err := grpc.SendHeader(ctx, mdHeader); err != nil { + log.Fatalf("unable to send header: %+v\n", err) + } + + if hdr, ok := ctx.Value(correlationID).(string); ok { + mdTrailer = metadata.Pairs(string(responseTRLR), hdr) + if err := grpc.SetTrailer(ctx, mdTrailer); err != nil { + log.Fatalf("unable to set trailer: %+v\n", err) + } + fmt.Printf("\tServer found correlationID %q in context, set consumed trailer\n", hdr) + } + if len(mdHeader) > 0 { + fmt.Println("\tServer >> Response Headers:") + for key, val := range mdHeader { + fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) + } + } + if len(mdTrailer) > 0 { + fmt.Println("\tServer >> Response Trailers:") + for key, val := range mdTrailer { + fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) + } + } +} + +func clientAfter(ctx context.Context, mdHeader metadata.MD, mdTrailer metadata.MD) context.Context { + if len(mdHeader) > 0 { + fmt.Println("\tClient << Response Headers:") + for key, val := range mdHeader { + fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) + } + } + if len(mdTrailer) > 0 { + fmt.Println("\tClient << Response Trailers:") + for key, val := range mdTrailer { + fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) + } + } + + if hdr, ok := mdTrailer[string(responseTRLR)]; ok { + ctx = context.WithValue(ctx, responseTRLR, hdr[len(hdr)-1]) + } + return ctx +} + +func SetCorrelationID(ctx context.Context, v string) context.Context { + return context.WithValue(ctx, correlationID, v) +} + +func GetConsumedCorrelationID(ctx context.Context) string { + if trlr, ok := ctx.Value(responseTRLR).(string); ok { + return trlr + } + return "" +} diff --git a/transport/grpc/_grpc_test/request_response.go b/transport/grpc/_grpc_test/request_response.go new file mode 100644 index 0000000..441bc65 --- /dev/null +++ b/transport/grpc/_grpc_test/request_response.go @@ -0,0 +1,40 @@ +package test + +import ( + "context" + "errors" + + pb "github.com/go-kit/kit/transport/grpc/_pb" +) + +func encodeRequest(ctx context.Context, req interface{}) (interface{}, error) { + r, ok := req.(TestRequest) + if !ok { + return nil, errors.New("request encode error") + } + return &pb.TestRequest{A: r.A, B: r.B}, nil +} + +func decodeRequest(ctx context.Context, req interface{}) (interface{}, error) { + r, ok := req.(*pb.TestRequest) + if !ok { + return nil, errors.New("request decode error") + } + return TestRequest{A: r.A, B: r.B}, nil +} + +func encodeResponse(ctx context.Context, resp interface{}) (interface{}, error) { + r, ok := resp.(*TestResponse) + if !ok { + return nil, errors.New("response encode error") + } + return &pb.TestResponse{V: r.V}, nil +} + +func decodeResponse(ctx context.Context, resp interface{}) (interface{}, error) { + r, ok := resp.(*pb.TestResponse) + if !ok { + return nil, errors.New("response decode error") + } + return &TestResponse{V: r.V, Ctx: ctx}, nil +} diff --git a/transport/grpc/_grpc_test/server.go b/transport/grpc/_grpc_test/server.go new file mode 100644 index 0000000..6c55b11 --- /dev/null +++ b/transport/grpc/_grpc_test/server.go @@ -0,0 +1,57 @@ +package test + +import ( + "context" + "fmt" + + oldcontext "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" + grpctransport "github.com/go-kit/kit/transport/grpc" + pb "github.com/go-kit/kit/transport/grpc/_pb" +) + +type service struct{} + +func (service) Test(ctx context.Context, a string, b int64) (context.Context, string, error) { + return nil, fmt.Sprintf("%s = %d", a, b), nil +} + +func NewService() Service { + return service{} +} + +func makeTestEndpoint(svc Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(TestRequest) + newCtx, v, err := svc.Test(ctx, req.A, req.B) + return &TestResponse{ + V: v, + Ctx: newCtx, + }, err + } +} + +type serverBinding struct { + test grpctransport.Handler +} + +func (b *serverBinding) Test(ctx oldcontext.Context, req *pb.TestRequest) (*pb.TestResponse, error) { + _, response, err := b.test.ServeGRPC(ctx, req) + if err != nil { + return nil, err + } + return response.(*pb.TestResponse), nil +} + +func NewBinding(svc Service) *serverBinding { + return &serverBinding{ + test: grpctransport.NewServer( + makeTestEndpoint(svc), + decodeRequest, + encodeResponse, + grpctransport.ServerBefore(serverBefore), + grpctransport.ServerAfter(serverAfter), + ), + } +} diff --git a/transport/grpc/_grpc_test/service.go b/transport/grpc/_grpc_test/service.go new file mode 100644 index 0000000..536b27c --- /dev/null +++ b/transport/grpc/_grpc_test/service.go @@ -0,0 +1,17 @@ +package test + +import "context" + +type Service interface { + Test(ctx context.Context, a string, b int64) (context.Context, string, error) +} + +type TestRequest struct { + A string + B int64 +} + +type TestResponse struct { + Ctx context.Context + V string +} diff --git a/transport/grpc/_pb/generate.go b/transport/grpc/_pb/generate.go new file mode 100644 index 0000000..aa20bb6 --- /dev/null +++ b/transport/grpc/_pb/generate.go @@ -0,0 +1,3 @@ +package pb + +//go:generate protoc test.proto --go_out=plugins=grpc:. diff --git a/transport/grpc/_pb/test.pb.go b/transport/grpc/_pb/test.pb.go new file mode 100644 index 0000000..97d29bb --- /dev/null +++ b/transport/grpc/_pb/test.pb.go @@ -0,0 +1,167 @@ +// Code generated by protoc-gen-go. +// source: test.proto +// DO NOT EDIT! + +/* +Package pb is a generated protocol buffer package. + +It is generated from these files: + test.proto + +It has these top-level messages: + TestRequest + TestResponse +*/ +package pb + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type TestRequest struct { + A string `protobuf:"bytes,1,opt,name=a" json:"a,omitempty"` + B int64 `protobuf:"varint,2,opt,name=b" json:"b,omitempty"` +} + +func (m *TestRequest) Reset() { *m = TestRequest{} } +func (m *TestRequest) String() string { return proto.CompactTextString(m) } +func (*TestRequest) ProtoMessage() {} +func (*TestRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func (m *TestRequest) GetA() string { + if m != nil { + return m.A + } + return "" +} + +func (m *TestRequest) GetB() int64 { + if m != nil { + return m.B + } + return 0 +} + +type TestResponse struct { + V string `protobuf:"bytes,1,opt,name=v" json:"v,omitempty"` +} + +func (m *TestResponse) Reset() { *m = TestResponse{} } +func (m *TestResponse) String() string { return proto.CompactTextString(m) } +func (*TestResponse) ProtoMessage() {} +func (*TestResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func (m *TestResponse) GetV() string { + if m != nil { + return m.V + } + return "" +} + +func init() { + proto.RegisterType((*TestRequest)(nil), "pb.TestRequest") + proto.RegisterType((*TestResponse)(nil), "pb.TestResponse") +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// Client API for Test service + +type TestClient interface { + Test(ctx context.Context, in *TestRequest, opts ...grpc.CallOption) (*TestResponse, error) +} + +type testClient struct { + cc *grpc.ClientConn +} + +func NewTestClient(cc *grpc.ClientConn) TestClient { + return &testClient{cc} +} + +func (c *testClient) Test(ctx context.Context, in *TestRequest, opts ...grpc.CallOption) (*TestResponse, error) { + out := new(TestResponse) + err := grpc.Invoke(ctx, "/pb.Test/Test", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// Server API for Test service + +type TestServer interface { + Test(context.Context, *TestRequest) (*TestResponse, error) +} + +func RegisterTestServer(s *grpc.Server, srv TestServer) { + s.RegisterService(&_Test_serviceDesc, srv) +} + +func _Test_Test_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(TestRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(TestServer).Test(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/pb.Test/Test", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(TestServer).Test(ctx, req.(*TestRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _Test_serviceDesc = grpc.ServiceDesc{ + ServiceName: "pb.Test", + HandlerType: (*TestServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Test", + Handler: _Test_Test_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "test.proto", +} + +func init() { proto.RegisterFile("test.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 129 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x2a, 0x49, 0x2d, 0x2e, + 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2a, 0x48, 0x52, 0xd2, 0xe4, 0xe2, 0x0e, 0x49, + 0x2d, 0x2e, 0x09, 0x4a, 0x2d, 0x2c, 0x4d, 0x2d, 0x2e, 0x11, 0xe2, 0xe1, 0x62, 0x4c, 0x94, 0x60, + 0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x62, 0x4c, 0x04, 0xf1, 0x92, 0x24, 0x98, 0x14, 0x18, 0x35, 0x98, + 0x83, 0x18, 0x93, 0x94, 0x64, 0xb8, 0x78, 0x20, 0x4a, 0x8b, 0x0b, 0xf2, 0xf3, 0x8a, 0x53, 0x41, + 0xb2, 0x65, 0x30, 0xb5, 0x65, 0x46, 0xc6, 0x5c, 0x2c, 0x20, 0x59, 0x21, 0x6d, 0x28, 0xcd, 0xaf, + 0x57, 0x90, 0xa4, 0x87, 0x64, 0xb4, 0x94, 0x00, 0x42, 0x00, 0x62, 0x80, 0x12, 0x43, 0x12, 0x1b, + 0xd8, 0x21, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x49, 0xfc, 0xd8, 0xf1, 0x96, 0x00, 0x00, + 0x00, +} diff --git a/transport/grpc/_pb/test.proto b/transport/grpc/_pb/test.proto new file mode 100644 index 0000000..6a3555e --- /dev/null +++ b/transport/grpc/_pb/test.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +package pb; + +service Test { + rpc Test (TestRequest) returns (TestResponse) {} +} + +message TestRequest { + string a = 1; + int64 b = 2; +} + +message TestResponse { + string v = 1; +} diff --git a/transport/grpc/client.go b/transport/grpc/client.go index 437d931..622ca88 100644 --- a/transport/grpc/client.go +++ b/transport/grpc/client.go @@ -107,7 +107,7 @@ } for _, f := range c.after { - ctx = f(ctx, &header, &trailer) + ctx = f(ctx, header, trailer) } response, err := c.dec(ctx, grpcReply) diff --git a/transport/grpc/client_test.go b/transport/grpc/client_test.go new file mode 100644 index 0000000..bbd2f5e --- /dev/null +++ b/transport/grpc/client_test.go @@ -0,0 +1,59 @@ +package grpc_test + +import ( + "context" + "fmt" + "net" + "testing" + + "google.golang.org/grpc" + + test "github.com/go-kit/kit/transport/grpc/_grpc_test" + pb "github.com/go-kit/kit/transport/grpc/_pb" +) + +const ( + hostPort string = "localhost:8002" +) + +func TestGRPCClient(t *testing.T) { + var ( + server = grpc.NewServer() + service = test.NewService() + ) + + sc, err := net.Listen("tcp", hostPort) + if err != nil { + t.Fatalf("unable to listen: %+v", err) + } + defer server.GracefulStop() + + go func() { + pb.RegisterTestServer(server, test.NewBinding(service)) + _ = server.Serve(sc) + }() + + cc, err := grpc.Dial(hostPort, grpc.WithInsecure()) + if err != nil { + t.Fatalf("unable to Dial: %+v", err) + } + + client := test.NewClient(cc) + + var ( + a = "the answer to life the universe and everything" + b = int64(42) + cID = "request-1" + ctx = test.SetCorrelationID(context.Background(), cID) + ) + + responseCTX, v, err := client.Test(ctx, a, b) + + if want, have := fmt.Sprintf("%s = %d", a, b), v; want != have { + t.Fatalf("want %q, have %q", want, have) + } + + if want, have := cID, test.GetConsumedCorrelationID(responseCTX); want != have { + t.Fatalf("want %q, have %q", want, have) + } +} diff --git a/transport/grpc/request_response_funcs.go b/transport/grpc/request_response_funcs.go index 067ef3f..7192bb5 100644 --- a/transport/grpc/request_response_funcs.go +++ b/transport/grpc/request_response_funcs.go @@ -27,7 +27,7 @@ // trailer and make the responses available for consumption. ClientResponseFuncs // are only executed in clients, after a request has been made, but prior to it // being decoded. -type ClientResponseFunc func(ctx context.Context, header *metadata.MD, trailer *metadata.MD) context.Context +type ClientResponseFunc func(ctx context.Context, header metadata.MD, trailer metadata.MD) context.Context // SetResponseHeader returns a ResponseFunc that sets the specified metadata // key-value pair. diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 9f6a94a..9289c4f 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -1,8 +1,6 @@ package grpc import ( - "context" - oldcontext "golang.org/x/net/context" "google.golang.org/grpc/metadata" @@ -19,7 +17,6 @@ // Server wraps an endpoint and implements grpc.Handler. type Server struct { - ctx context.Context e endpoint.Endpoint dec DecodeRequestFunc enc EncodeResponseFunc @@ -34,14 +31,12 @@ // definitions to individual handlers. Request and response objects are from the // caller business domain, not gRPC request and reply types. func NewServer( - ctx context.Context, e endpoint.Endpoint, dec DecodeRequestFunc, enc EncodeResponseFunc, options ...ServerOption, ) *Server { s := &Server{ - ctx: ctx, e: e, dec: dec, enc: enc, @@ -75,11 +70,9 @@ } // ServeGRPC implements the Handler interface. -func (s Server) ServeGRPC(grpcCtx oldcontext.Context, req interface{}) (oldcontext.Context, interface{}, error) { - ctx := s.ctx - +func (s Server) ServeGRPC(ctx oldcontext.Context, req interface{}) (oldcontext.Context, interface{}, error) { // Retrieve gRPC metadata. - md, ok := metadata.FromContext(grpcCtx) + md, ok := metadata.FromContext(ctx) if !ok { md = metadata.MD{} } @@ -89,18 +82,18 @@ } // Store potentially updated metadata in the gRPC context. - grpcCtx = metadata.NewContext(grpcCtx, md) + ctx = metadata.NewContext(ctx, md) - request, err := s.dec(grpcCtx, req) + request, err := s.dec(ctx, req) if err != nil { s.logger.Log("err", err) - return grpcCtx, nil, err + return ctx, nil, err } response, err := s.e(ctx, request) if err != nil { s.logger.Log("err", err) - return grpcCtx, nil, err + return ctx, nil, err } for _, f := range s.after { @@ -108,13 +101,13 @@ } // Store potentially updated metadata in the gRPC context. - grpcCtx = metadata.NewContext(grpcCtx, md) + ctx = metadata.NewContext(ctx, md) - grpcResp, err := s.enc(grpcCtx, response) + grpcResp, err := s.enc(ctx, response) if err != nil { s.logger.Log("err", err) - return grpcCtx, nil, err + return ctx, nil, err } - return grpcCtx, grpcResp, nil + return ctx, grpcResp, nil }