Codebase list golang-github-go-kit-kit / e9f8dfb transport / grpc / server.go
e9f8dfb

Tree @e9f8dfb (Download .tar.gz)

server.go @e9f8dfb · raw · history · blame

package grpc

import (
	"golang.org/x/net/context"
	"google.golang.org/grpc/metadata"

	"github.com/go-kit/kit/endpoint"
	"github.com/go-kit/kit/log"
)

// Handler is what the gRPC binding of a service implementation calls for each
// incoming request. ServeGRPC receives the call's context and the request
// message, and returns a context, the response message, and an error.
type Handler interface {
	ServeGRPC(ctx context.Context, request interface{}) (context.Context, interface{}, error)
}

// Server wraps an endpoint and implements grpc.Handler (the Handler interface
// declared in this package). Build one with NewServer rather than a literal.
type Server struct {
	// Base context from which every ServeGRPC call derives its own context.
	ctx context.Context

	// The wrapped endpoint, plus the functions that translate the gRPC
	// request into an endpoint request and the endpoint response back out.
	e   endpoint.Endpoint
	dec DecodeRequestFunc
	enc EncodeResponseFunc

	// Hooks run against gRPC metadata: before decoding, and after the
	// endpoint has been invoked.
	before []RequestFunc
	after  []ResponseFunc

	// Receives decode, endpoint and encode errors seen by ServeGRPC.
	logger log.Logger
}

// NewServer constructs a new server, which implements Handler and wraps the
// provided endpoint.
//
// ctx is the base context from which each ServeGRPC call derives its own
// context. dec turns the incoming gRPC request into the endpoint's request,
// and enc turns the endpoint's response into the gRPC reply. By default no
// errors are logged; use ServerErrorLogger to change that. Options are
// applied in the order given.
func NewServer(
	ctx context.Context,
	e endpoint.Endpoint,
	dec DecodeRequestFunc,
	enc EncodeResponseFunc,
	options ...ServerOption,
) *Server {
	s := &Server{
		ctx:    ctx,
		e:      e,
		dec:    dec,
		enc:    enc,
		logger: log.NewNopLogger(),
	}
	for _, option := range options {
		option(s)
	}
	return s
}

// ServerOption sets an optional parameter for servers.
type ServerOption func(*Server)

// ServerBefore functions are executed on the incoming gRPC metadata before the
// request is decoded. Each function receives the per-call context and returns
// the context to use from then on. Using ServerBefore more than once appends
// functions rather than replacing earlier ones. Functions run in the order
// they were supplied.
func ServerBefore(before ...RequestFunc) ServerOption {
	// append also detaches s.before from the caller's variadic slice.
	return func(s *Server) { s.before = append(s.before, before...) }
}

// ServerAfter functions are executed on the gRPC metadata after the endpoint
// has been invoked successfully, but before the response is encoded. Using
// ServerAfter more than once appends functions rather than replacing earlier
// ones. Functions run in the order they were supplied.
func ServerAfter(after ...ResponseFunc) ServerOption {
	// append also detaches s.after from the caller's variadic slice.
	return func(s *Server) { s.after = append(s.after, after...) }
}

// ServerErrorLogger sets the logger used to record errors encountered while
// serving a request (decode, endpoint and encode failures). By default, no
// errors are logged.
func ServerErrorLogger(logger log.Logger) ServerOption {
	return func(s *Server) { s.logger = logger }
}

// ServeGRPC implements grpc.Handler.
//
// A per-call context is derived from the server's base context and is
// cancelled when ServeGRPC returns. The incoming gRPC metadata (an empty MD if
// grpcCtx carries none) is offered to every before-function and then stored
// back into grpcCtx, so the decoder sees any changes. The decoded request goes
// to the endpoint. On success, the after-functions may update the same
// metadata, which is stored into grpcCtx again before the response is
// encoded. Every failure is logged and returned with a nil response. Decode
// failures are wrapped in BadRequestError.
func (s Server) ServeGRPC(grpcCtx context.Context, r interface{}) (context.Context, interface{}, error) {
	// fail logs cause and reports the given error to the gRPC binding along
	// with whatever grpcCtx holds at the moment of failure.
	fail := func(cause, reported error) (context.Context, interface{}, error) {
		s.logger.Log("err", cause)
		return grpcCtx, nil, reported
	}

	reqCtx, cancel := context.WithCancel(s.ctx)
	defer cancel()

	md, found := metadata.FromContext(grpcCtx)
	if !found {
		md = metadata.MD{}
	}

	for _, before := range s.before {
		reqCtx = before(reqCtx, &md)
	}
	grpcCtx = metadata.NewContext(grpcCtx, md)

	request, err := s.dec(grpcCtx, r)
	if err != nil {
		return fail(err, BadRequestError{err})
	}

	response, err := s.e(reqCtx, request)
	if err != nil {
		return fail(err, err)
	}

	for _, after := range s.after {
		after(reqCtx, &md)
	}
	grpcCtx = metadata.NewContext(grpcCtx, md)

	reply, err := s.enc(grpcCtx, response)
	if err != nil {
		return fail(err, err)
	}
	return grpcCtx, reply, nil
}

// BadRequestError is an error in decoding the request. ServeGRPC returns it,
// wrapping the decoder's error, when the incoming request cannot be decoded.
type BadRequestError struct {
	Err error
}

// Error implements the error interface by returning the wrapped error's
// message. A BadRequestError with a nil Err reports "bad request" instead of
// panicking.
func (err BadRequestError) Error() string {
	if err.Err == nil {
		return "bad request"
	}
	return err.Err.Error()
}

// Unwrap returns the wrapped decoding error so that callers can inspect it
// with errors.Is and errors.As.
func (err BadRequestError) Unwrap() error {
	return err.Err
}