diff --git a/transport/grpc/_grpc_test/context_metadata.go b/transport/grpc/_grpc_test/context_metadata.go index 5bd545d..0769325 100644 --- a/transport/grpc/_grpc_test/context_metadata.go +++ b/transport/grpc/_grpc_test/context_metadata.go @@ -38,8 +38,8 @@ /* server before functions */ -func extractCorrelationID(ctx context.Context, md *metadata.MD) context.Context { - if hdr, ok := (*md)[string(correlationID)]; ok { +func extractCorrelationID(ctx context.Context, md metadata.MD) context.Context { + if hdr, ok := md[string(correlationID)]; ok { cID := hdr[len(hdr)-1] ctx = context.WithValue(ctx, correlationID, cID) fmt.Printf("\tServer received correlationID %q in metadata header, set context\n", cID) @@ -47,10 +47,10 @@ return ctx } -func displayServerRequestHeaders(ctx context.Context, md *metadata.MD) context.Context { - if len(*md) > 0 { +func displayServerRequestHeaders(ctx context.Context, md metadata.MD) context.Context { + if len(md) > 0 { fmt.Println("\tServer << Request Headers:") - for key, val := range *md { + for key, val := range md { fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) } } @@ -59,37 +59,42 @@ /* server after functions */ -func injectResponseHeader(ctx context.Context, md *metadata.MD, _ *metadata.MD) { +func injectResponseHeader(ctx context.Context, md *metadata.MD, _ *metadata.MD) context.Context { *md = metadata.Join(*md, metadata.Pairs(string(responseHDR), "has-a-value")) + return ctx } -func displayServerResponseHeaders(ctx context.Context, md *metadata.MD, _ *metadata.MD) { +func displayServerResponseHeaders(ctx context.Context, md *metadata.MD, _ *metadata.MD) context.Context { if len(*md) > 0 { fmt.Println("\tServer >> Response Headers:") for key, val := range *md { fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) } } + return ctx } -func injectResponseTrailer(ctx context.Context, _ *metadata.MD, md *metadata.MD) { +func injectResponseTrailer(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context { *md = metadata.Join(*md, metadata.Pairs(string(responseTRLR), "has-a-value-too")) + return ctx } -func injectConsumedCorrelationID(ctx context.Context, _ *metadata.MD, md *metadata.MD) { +func injectConsumedCorrelationID(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context { if hdr, ok := ctx.Value(correlationID).(string); ok { fmt.Printf("\tServer found correlationID %q in context, set consumed trailer\n", hdr) *md = metadata.Join(*md, metadata.Pairs(string(correlationIDTRLR), hdr)) } + return ctx } -func displayServerResponseTrailers(ctx context.Context, _ *metadata.MD, md *metadata.MD) { +func displayServerResponseTrailers(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context { if len(*md) > 0 { fmt.Println("\tServer >> Response Trailers:") for key, val := range *md { fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) } } + return ctx } /* client after functions */ diff --git a/transport/grpc/client.go b/transport/grpc/client.go index 622ca88..c0faa2b 100644 --- a/transport/grpc/client.go +++ b/transport/grpc/client.go @@ -21,7 +21,7 @@ enc EncodeRequestFunc dec DecodeResponseFunc grpcReply reflect.Type - before []RequestFunc + before []ClientRequestFunc after []ClientResponseFunc } @@ -54,7 +54,7 @@ reflect.ValueOf(grpcReply), ).Interface(), ), - before: []RequestFunc{}, + before: []ClientRequestFunc{}, after: []ClientResponseFunc{}, } for _, option := range options { @@ -68,7 +68,7 @@ // ClientBefore sets the RequestFuncs that are applied to the outgoing gRPC // request before it's invoked. -func ClientBefore(before ...RequestFunc) ClientOption { +func ClientBefore(before ...ClientRequestFunc) ClientOption { return func(c *Client) { c.before = append(c.before, before...) } } diff --git a/transport/grpc/request_response_funcs.go b/transport/grpc/request_response_funcs.go index 05a9d34..8d072ed 100644 --- a/transport/grpc/request_response_funcs.go +++ b/transport/grpc/request_response_funcs.go @@ -12,17 +12,22 @@ binHdrSuffix = "-bin" ) -// RequestFunc may take information from a gRPC request and put it into a -// request context. In Servers, RequestFuncs are executed prior to invoking the -// endpoint. In Clients, RequestFuncs are executed after creating the request -// but prior to invoking the gRPC client. -type RequestFunc func(context.Context, *metadata.MD) context.Context +// ClientRequestFunc may take information from context and use it to construct +// metadata headers to be transported to the server. ClientRequestFuncs are +// executed after creating the request but prior to sending the gRPC request to +// the server. +type ClientRequestFunc func(context.Context, *metadata.MD) context.Context + +// ServerRequestFunc may take information from the received metadata header and +// use it to place items in the request scoped context. ServerRequestFuncs are +// executed prior to invoking the endpoint. +type ServerRequestFunc func(context.Context, metadata.MD) context.Context // ServerResponseFunc may take information from a request context and use it to // manipulate the gRPC response metadata headers and trailers. ResponseFuncs are // only executed in servers, after invoking the endpoint but prior to writing a // response. -type ServerResponseFunc func(ctx context.Context, header *metadata.MD, trailer *metadata.MD) +type ServerResponseFunc func(ctx context.Context, header *metadata.MD, trailer *metadata.MD) context.Context // ClientResponseFunc may take information from a gRPC metadata header and/or // trailer and make the responses available for consumption. ClientResponseFuncs @@ -30,9 +35,9 @@ // being decoded. type ClientResponseFunc func(ctx context.Context, header metadata.MD, trailer metadata.MD) context.Context -// SetRequestHeader returns a RequestFunc that sets the specified metadata +// SetRequestHeader returns a ClientRequestFunc that sets the specified metadata // key-value pair. -func SetRequestHeader(key, val string) RequestFunc { +func SetRequestHeader(key, val string) ClientRequestFunc { return func(ctx context.Context, md *metadata.MD) context.Context { key, val := EncodeKeyValue(key, val) (*md)[key] = append((*md)[key], val) @@ -43,18 +48,20 @@ // SetResponseHeader returns a ResponseFunc that sets the specified metadata // key-value pair. func SetResponseHeader(key, val string) ServerResponseFunc { - return func(_ context.Context, md *metadata.MD, _ *metadata.MD) { + return func(ctx context.Context, md *metadata.MD, _ *metadata.MD) context.Context { key, val := EncodeKeyValue(key, val) (*md)[key] = append((*md)[key], val) + return ctx } } // SetResponseTrailer returns a ResponseFunc that sets the specified metadata // key-value pair. func SetResponseTrailer(key, val string) ServerResponseFunc { - return func(_ context.Context, _ *metadata.MD, md *metadata.MD) { + return func(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context { key, val := EncodeKeyValue(key, val) (*md)[key] = append((*md)[key], val) + return ctx } } diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 476902e..b14d7d8 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -21,7 +21,7 @@ e endpoint.Endpoint dec DecodeRequestFunc enc EncodeResponseFunc - before []RequestFunc + before []ServerRequestFunc after []ServerResponseFunc logger log.Logger } @@ -54,7 +54,7 @@ // ServerBefore functions are executed on the HTTP request object before the // request is decoded. -func ServerBefore(before ...RequestFunc) ServerOption { +func ServerBefore(before ...ServerRequestFunc) ServerOption { return func(s *Server) { s.before = append(s.before, before...) } } @@ -79,11 +79,8 @@ } for _, f := range s.before { - ctx = f(ctx, &md) + ctx = f(ctx, md) } - - // Store potentially updated metadata in the gRPC context. - ctx = metadata.NewContext(ctx, md) request, err := s.dec(ctx, req) if err != nil { @@ -99,7 +96,7 @@ var mdHeader, mdTrailer metadata.MD for _, f := range s.after { - f(ctx, &mdHeader, &mdTrailer) + ctx = f(ctx, &mdHeader, &mdTrailer) } grpcResp, err := s.enc(ctx, response)