Codebase list golang-github-go-kit-kit / 2677e49 transport / grpc / _grpc_test / context_metadata.go
2677e49

Tree @2677e49 (Download .tar.gz)

context_metadata.go @2677e49raw · history · blame

package test

import (
	"context"
	"fmt"
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)

type metaContext string

const (
	correlationID metaContext = "correlation-id"
	responseHDR   metaContext = "my-response-header"
	responseTRLR  metaContext = "correlation-id-consumed"
)

func clientBefore(ctx context.Context, md *metadata.MD) context.Context {
	if hdr, ok := ctx.Value(correlationID).(string); ok {
		(*md)[string(correlationID)] = append((*md)[string(correlationID)], hdr)
	}
	if len(*md) > 0 {
		fmt.Println("\tClient >> Request Headers:")
		for key, val := range *md {
			fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
		}
	}
	return ctx
}

func serverBefore(ctx context.Context, md *metadata.MD) context.Context {
	if len(*md) > 0 {
		fmt.Println("\tServer << Request Headers:")
		for key, val := range *md {
			fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
		}
	}
	if hdr, ok := (*md)[string(correlationID)]; ok {
		cID := hdr[len(hdr)-1]
		ctx = context.WithValue(ctx, correlationID, cID)
		fmt.Printf("\tServer placed correlationID %q in context\n", cID)
	}
	return ctx
}

func serverAfter(ctx context.Context, _ *metadata.MD) {
	var mdHeader, mdTrailer metadata.MD

	mdHeader = metadata.Pairs(string(responseHDR), "has-a-value")
	if err := grpc.SendHeader(ctx, mdHeader); err != nil {
		log.Fatalf("unable to send header: %+v\n", err)
	}

	if hdr, ok := ctx.Value(correlationID).(string); ok {
		mdTrailer = metadata.Pairs(string(responseTRLR), hdr)
		if err := grpc.SetTrailer(ctx, mdTrailer); err != nil {
			log.Fatalf("unable to set trailer: %+v\n", err)
		}
		fmt.Printf("\tServer found correlationID %q in context, set consumed trailer\n", hdr)
	}
	if len(mdHeader) > 0 {
		fmt.Println("\tServer >> Response Headers:")
		for key, val := range mdHeader {
			fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
		}
	}
	if len(mdTrailer) > 0 {
		fmt.Println("\tServer >> Response Trailers:")
		for key, val := range mdTrailer {
			fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
		}
	}
}

func clientAfter(ctx context.Context, mdHeader metadata.MD, mdTrailer metadata.MD) context.Context {
	if len(mdHeader) > 0 {
		fmt.Println("\tClient << Response Headers:")
		for key, val := range mdHeader {
			fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
		}
	}
	if len(mdTrailer) > 0 {
		fmt.Println("\tClient << Response Trailers:")
		for key, val := range mdTrailer {
			fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
		}
	}

	if hdr, ok := mdTrailer[string(responseTRLR)]; ok {
		ctx = context.WithValue(ctx, responseTRLR, hdr[len(hdr)-1])
	}
	return ctx
}

func SetCorrelationID(ctx context.Context, v string) context.Context {
	return context.WithValue(ctx, correlationID, v)
}

func GetConsumedCorrelationID(ctx context.Context) string {
	if trlr, ok := ctx.Value(responseTRLR).(string); ok {
		return trlr
	}
	return ""
}