diff --git a/transport/grpc/_grpc_test/client.go b/transport/grpc/_grpc_test/client.go index 11d78ca..a70ebc2 100644 --- a/transport/grpc/_grpc_test/client.go +++ b/transport/grpc/_grpc_test/client.go @@ -32,8 +32,19 @@ encodeRequest, decodeResponse, &pb.TestResponse{}, - grpctransport.ClientBefore(clientBefore), - grpctransport.ClientAfter(clientAfter), + grpctransport.ClientBefore( + injectCorrelationID, + ), + grpctransport.ClientBefore( + displayClientRequestHeaders, + ), + grpctransport.ClientAfter( + displayClientResponseHeaders, + displayClientResponseTrailers, + ), + grpctransport.ClientAfter( + extractConsumedCorrelationID, + ), ).Endpoint(), } } diff --git a/transport/grpc/_grpc_test/context_metadata.go b/transport/grpc/_grpc_test/context_metadata.go index f31b50b..5bd545d 100644 --- a/transport/grpc/_grpc_test/context_metadata.go +++ b/transport/grpc/_grpc_test/context_metadata.go @@ -3,24 +3,30 @@ import ( "context" "fmt" - "log" - "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) type metaContext string const ( - correlationID metaContext = "correlation-id" - responseHDR metaContext = "my-response-header" - responseTRLR metaContext = "correlation-id-consumed" + correlationID metaContext = "correlation-id" + responseHDR metaContext = "my-response-header" + responseTRLR metaContext = "my-response-trailer" + correlationIDTRLR metaContext = "correlation-id-consumed" ) -func clientBefore(ctx context.Context, md *metadata.MD) context.Context { +/* client before functions */ + +func injectCorrelationID(ctx context.Context, md *metadata.MD) context.Context { if hdr, ok := ctx.Value(correlationID).(string); ok { + fmt.Printf("\tClient found correlationID %q in context, set metadata header\n", hdr) (*md)[string(correlationID)] = append((*md)[string(correlationID)], hdr) } + return ctx +} + +func displayClientRequestHeaders(ctx context.Context, md *metadata.MD) context.Context { if len(*md) > 0 { fmt.Println("\tClient >> Request Headers:") for key, val := range *md { @@ -30,76 +36,100 @@ return ctx } -func serverBefore(ctx context.Context, md *metadata.MD) context.Context { +/* server before functions */ + +func extractCorrelationID(ctx context.Context, md *metadata.MD) context.Context { + if hdr, ok := (*md)[string(correlationID)]; ok { + cID := hdr[len(hdr)-1] + ctx = context.WithValue(ctx, correlationID, cID) + fmt.Printf("\tServer received correlationID %q in metadata header, set context\n", cID) + } + return ctx +} + +func displayServerRequestHeaders(ctx context.Context, md *metadata.MD) context.Context { if len(*md) > 0 { fmt.Println("\tServer << Request Headers:") for key, val := range *md { fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) } } - if hdr, ok := (*md)[string(correlationID)]; ok { - cID := hdr[len(hdr)-1] - ctx = context.WithValue(ctx, correlationID, cID) - fmt.Printf("\tServer placed correlationID %q in context\n", cID) - } return ctx } -func serverAfter(ctx context.Context, _ *metadata.MD) { - var mdHeader, mdTrailer metadata.MD +/* server after functions */ - mdHeader = metadata.Pairs(string(responseHDR), "has-a-value") - if err := grpc.SendHeader(ctx, mdHeader); err != nil { - log.Fatalf("unable to send header: %+v\n", err) - } +func injectResponseHeader(ctx context.Context, md *metadata.MD, _ *metadata.MD) { + *md = metadata.Join(*md, metadata.Pairs(string(responseHDR), "has-a-value")) +} - if hdr, ok := ctx.Value(correlationID).(string); ok { - mdTrailer = metadata.Pairs(string(responseTRLR), hdr) - if err := grpc.SetTrailer(ctx, mdTrailer); err != nil { - log.Fatalf("unable to set trailer: %+v\n", err) - } - fmt.Printf("\tServer found correlationID %q in context, set consumed trailer\n", hdr) - } - if len(mdHeader) > 0 { +func displayServerResponseHeaders(ctx context.Context, md *metadata.MD, _ *metadata.MD) { + if len(*md) > 0 { fmt.Println("\tServer >> Response Headers:") - for key, val := range mdHeader { - fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) - } - } - if len(mdTrailer) > 0 { - fmt.Println("\tServer >> Response Trailers:") - for key, val := range mdTrailer { + for key, val := range *md { fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) } } } -func clientAfter(ctx context.Context, mdHeader metadata.MD, mdTrailer metadata.MD) context.Context { - if len(mdHeader) > 0 { - fmt.Println("\tClient << Response Headers:") - for key, val := range mdHeader { +func injectResponseTrailer(ctx context.Context, _ *metadata.MD, md *metadata.MD) { + *md = metadata.Join(*md, metadata.Pairs(string(responseTRLR), "has-a-value-too")) +} + +func injectConsumedCorrelationID(ctx context.Context, _ *metadata.MD, md *metadata.MD) { + if hdr, ok := ctx.Value(correlationID).(string); ok { + fmt.Printf("\tServer found correlationID %q in context, set consumed trailer\n", hdr) + *md = metadata.Join(*md, metadata.Pairs(string(correlationIDTRLR), hdr)) + } +} + +func displayServerResponseTrailers(ctx context.Context, _ *metadata.MD, md *metadata.MD) { + if len(*md) > 0 { + fmt.Println("\tServer >> Response Trailers:") + for key, val := range *md { fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) } } - if len(mdTrailer) > 0 { - fmt.Println("\tClient << Response Trailers:") - for key, val := range mdTrailer { +} + +/* client after functions */ + +func displayClientResponseHeaders(ctx context.Context, md metadata.MD, _ metadata.MD) context.Context { + if len(md) > 0 { + fmt.Println("\tClient << Response Headers:") + for key, val := range md { fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) } } + return ctx +} - if hdr, ok := mdTrailer[string(responseTRLR)]; ok { - ctx = context.WithValue(ctx, responseTRLR, hdr[len(hdr)-1]) +func displayClientResponseTrailers(ctx context.Context, _ metadata.MD, md metadata.MD) context.Context { + if len(md) > 0 { + fmt.Println("\tClient << Response Trailers:") + for key, val := range md { + fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) + } } return ctx } + +func extractConsumedCorrelationID(ctx context.Context, _ metadata.MD, md metadata.MD) context.Context { + if hdr, ok := md[string(correlationIDTRLR)]; ok { + fmt.Printf("\tClient received consumed correlationID %q in metadata trailer, set context\n", hdr[len(hdr)-1]) + ctx = context.WithValue(ctx, correlationIDTRLR, hdr[len(hdr)-1]) + } + return ctx +} + +/* CorrelationID context handlers */ func SetCorrelationID(ctx context.Context, v string) context.Context { return context.WithValue(ctx, correlationID, v) } func GetConsumedCorrelationID(ctx context.Context) string { - if trlr, ok := ctx.Value(responseTRLR).(string); ok { + if trlr, ok := ctx.Value(correlationIDTRLR).(string); ok { return trlr } return "" diff --git a/transport/grpc/_grpc_test/server.go b/transport/grpc/_grpc_test/server.go index 6c55b11..52e9048 100644 --- a/transport/grpc/_grpc_test/server.go +++ b/transport/grpc/_grpc_test/server.go @@ -50,8 +50,21 @@ makeTestEndpoint(svc), decodeRequest, encodeResponse, - grpctransport.ServerBefore(serverBefore), - grpctransport.ServerAfter(serverAfter), + grpctransport.ServerBefore( + extractCorrelationID, + ), + grpctransport.ServerBefore( + displayServerRequestHeaders, + ), + grpctransport.ServerAfter( + injectResponseHeader, + injectResponseTrailer, + injectConsumedCorrelationID, + ), + grpctransport.ServerAfter( + displayServerResponseHeaders, + displayServerResponseTrailers, + ), ), } } diff --git a/transport/grpc/request_response_funcs.go b/transport/grpc/request_response_funcs.go index 7192bb5..05a9d34 100644 --- a/transport/grpc/request_response_funcs.go +++ b/transport/grpc/request_response_funcs.go @@ -19,24 +19,16 @@ type RequestFunc func(context.Context, *metadata.MD) context.Context // ServerResponseFunc may take information from a request context and use it to -// manipulate the gRPC metadata header. ResponseFuncs are only executed in -// servers, after invoking the endpoint but prior to writing a response. -type ServerResponseFunc func(context.Context, *metadata.MD) +// manipulate the gRPC response metadata headers and trailers. ResponseFuncs are +// only executed in servers, after invoking the endpoint but prior to writing a +// response. +type ServerResponseFunc func(ctx context.Context, header *metadata.MD, trailer *metadata.MD) // ClientResponseFunc may take information from a gRPC metadata header and/or // trailer and make the responses available for consumption. ClientResponseFuncs // are only executed in clients, after a request has been made, but prior to it // being decoded. type ClientResponseFunc func(ctx context.Context, header metadata.MD, trailer metadata.MD) context.Context - -// SetResponseHeader returns a ResponseFunc that sets the specified metadata -// key-value pair. -func SetResponseHeader(key, val string) ServerResponseFunc { - return func(_ context.Context, md *metadata.MD) { - key, val := EncodeKeyValue(key, val) - (*md)[key] = append((*md)[key], val) - } -} // SetRequestHeader returns a RequestFunc that sets the specified metadata // key-value pair. @@ -45,6 +37,24 @@ key, val := EncodeKeyValue(key, val) (*md)[key] = append((*md)[key], val) return ctx + } +} + +// SetResponseHeader returns a ResponseFunc that sets the specified metadata +// key-value pair. +func SetResponseHeader(key, val string) ServerResponseFunc { + return func(_ context.Context, md *metadata.MD, _ *metadata.MD) { + key, val := EncodeKeyValue(key, val) + (*md)[key] = append((*md)[key], val) + } +} + +// SetResponseTrailer returns a ResponseFunc that sets the specified metadata +// key-value pair. +func SetResponseTrailer(key, val string) ServerResponseFunc { + return func(_ context.Context, _ *metadata.MD, md *metadata.MD) { + key, val := EncodeKeyValue(key, val) + (*md)[key] = append((*md)[key], val) } } diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 9289c4f..476902e 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -2,6 +2,7 @@ import ( oldcontext "golang.org/x/net/context" + "google.golang.org/grpc" "google.golang.org/grpc/metadata" "github.com/go-kit/kit/endpoint" @@ -96,12 +97,10 @@ return ctx, nil, err } + var mdHeader, mdTrailer metadata.MD for _, f := range s.after { - f(ctx, &md) + f(ctx, &mdHeader, &mdTrailer) } - - // Store potentially updated metadata in the gRPC context. - ctx = metadata.NewContext(ctx, md) grpcResp, err := s.enc(ctx, response) if err != nil { @@ -109,5 +108,19 @@ return ctx, nil, err } + if len(mdHeader) > 0 { + if err = grpc.SendHeader(ctx, mdHeader); err != nil { + s.logger.Log("err", err) + return ctx, nil, err + } + } + + if len(mdTrailer) > 0 { + if err = grpc.SetTrailer(ctx, mdTrailer); err != nil { + s.logger.Log("err", err) + return ctx, nil, err + } + } + return ctx, grpcResp, nil }