Codebase list golang-github-influxdata-yarpc / 64bb1c83-1ef7-4be1-9ecd-61368bc45879/main server.go
64bb1c83-1ef7-4be1-9ecd-61368bc45879/main

Tree @64bb1c83-1ef7-4be1-9ecd-61368bc45879/main (Download .tar.gz)

server.go @64bb1c83-1ef7-4be1-9ecd-61368bc45879/mainraw · history · blame

package yarpc

import (
	"net"
	"sync"

	"encoding/binary"
	"io"

	"context"

	"reflect"

	"log"

	"github.com/influxdata/yamux"
	"github.com/influxdata/yarpc/codes"
	"github.com/influxdata/yarpc/status"
)

// methodHandler is the signature of a unary method handler. It is invoked
// with the registered service implementation, a context, and a decode
// function that unmarshals the received request into the value passed to it;
// it returns the reply to send back, or an error.
type methodHandler func(srv interface{}, ctx context.Context, dec func(interface{}) error) (interface{}, error)

// MethodDesc describes a single unary method of a service.
type MethodDesc struct {
	// Index is the key the method is registered under within its service; the
	// server matches it against the method byte of the stream header.
	Index      uint8
	// MethodName is the name of the method. It is not read by the code in
	// this file.
	MethodName string
	// Handler is called to execute the method for a received request.
	Handler    methodHandler
}

// ServiceDesc represents an RPC service's specification.
type ServiceDesc struct {
	// Index is the key the service is registered under; the server matches it
	// against the service byte of the stream header.
	Index       uint8
	// ServiceName is the name reported in registration diagnostics.
	ServiceName string
	// The pointer to the service interface. Used to check whether the user
	// provided implementation satisfies the interface requirements.
	HandlerType interface{}
	// Methods lists the unary methods of the service. When a stream arrives,
	// the method index is looked up here before Streams is consulted.
	Methods     []MethodDesc
	// Streams lists the streaming methods of the service.
	Streams     []StreamDesc
	// Metadata is not read by the code in this file.
	Metadata    interface{}
}

// service is the server's record of one registered service: the user
// supplied implementation plus its handlers keyed by index.
type service struct {
	// server is the implementation passed to RegisterService; it is handed to
	// every handler as its first argument.
	server interface{}
	// md maps a method index to its unary method descriptor.
	md     map[uint8]*MethodDesc
	// sd maps a method index to its streaming method descriptor.
	sd     map[uint8]*StreamDesc
}

// Server accepts connections from a net.Listener and dispatches the streams
// opened on them to the services registered with it.
type Server struct {
	// opts holds the settings applied by the ServerOptions given to NewServer.
	opts  options
	// m maps a service index to its registered service.
	m     map[uint8]*service
	// serve is consulted by register, which aborts the process when it is true.
	serve bool
	// lis is the listener most recently passed to Serve; Stop closes it and
	// resets it to nil.
	lis   net.Listener
	// lisMu guards lis.
	lisMu sync.Mutex
}

// options collects the configurable settings of a Server.
type options struct {
	// codec unmarshals incoming requests and is used when encoding replies.
	// NewServer falls back to NewCodec() when no codec is configured.
	codec Codec
}

// ServerOption configures a Server by mutating its options. NewServer applies
// the options it is given in order.
type ServerOption func(*options)

// CustomCodec returns a ServerOption that makes the server use c as its codec
// instead of the default one.
func CustomCodec(c Codec) ServerOption {
	apply := func(target *options) {
		target.codec = c
	}
	return apply
}

// NewServer creates a Server with no registered services, applies the given
// options in order, and falls back to the codec returned by NewCodec when
// none of the options configured one.
func NewServer(opts ...ServerOption) *Server {
	srv := new(Server)
	srv.m = map[uint8]*service{}

	for i := range opts {
		opts[i](&srv.opts)
	}

	// Fill in defaults for anything the options left unset.
	if srv.opts.codec == nil {
		srv.opts.codec = NewCodec()
	}

	return srv
}

// RegisterService registers a service description together with its
// implementation on the server. It is called from the IDL generated code and
// must be called before invoking Serve.
//
// The process is terminated via log.Fatalf when the type of ss does not
// implement the interface that sd.HandlerType points to.
func (s *Server) RegisterService(sd *ServiceDesc, ss interface{}) {
	implType := reflect.TypeOf(ss)
	ifaceType := reflect.TypeOf(sd.HandlerType).Elem()

	if satisfied := implType.Implements(ifaceType); !satisfied {
		log.Fatalf("rpc: Server.RegisterService found the handler of type %v that does not satisfy %v", implType, ifaceType)
	}

	s.register(sd, ss)
}

// register stores ss as the implementation of the service described by sd,
// indexing the service's unary and streaming descriptors by their Index.
//
// The process is terminated via log.Fatalf when the server's serve flag is
// set or when a service is already registered under sd.Index.
func (s *Server) register(sd *ServiceDesc, ss interface{}) {
	if s.serve {
		log.Fatalf("rpc: Server.RegisterService after Server.Serve for %q", sd.ServiceName)
	}
	if _, dup := s.m[sd.Index]; dup {
		log.Fatalf("rpc: Server.RegisterService found duplicate service registration for %q", sd.ServiceName)
	}

	// Store pointers into the descriptor's own slices rather than copies.
	methods := make(map[uint8]*MethodDesc)
	for i := 0; i < len(sd.Methods); i++ {
		desc := &sd.Methods[i]
		methods[desc.Index] = desc
	}

	streams := make(map[uint8]*StreamDesc)
	for i := 0; i < len(sd.Streams); i++ {
		desc := &sd.Streams[i]
		streams[desc.Index] = desc
	}

	s.m[sd.Index] = &service{
		server: ss,
		md:     methods,
		sd:     streams,
	}
}

// Serve accepts incoming connections on lis and handles each one on its own
// goroutine. It records lis so that Stop can close it, and marks the server
// as serving so that register rejects any later service registration.
//
// Serve blocks until lis.Accept fails and returns that error; closing the
// listener (for example through Stop) is what makes it return.
func (s *Server) Serve(lis net.Listener) error {
	s.lisMu.Lock()
	s.lis = lis
	// Without this flag the "RegisterService after Serve" check in register
	// can never trigger.
	s.serve = true
	s.lisMu.Unlock()

	for {
		rawConn, err := lis.Accept()
		if err != nil {
			// TODO(sgc): add logic to retry temporary errors instead of
			// giving up on the first Accept failure.
			return err
		}

		go s.handleRawConn(rawConn)
	}
}

// Stop closes the listener recorded by Serve, if there is one, and forgets
// it. Calling Stop when no listener is recorded does nothing.
func (s *Server) Stop() {
	s.lisMu.Lock()
	defer s.lisMu.Unlock()

	if s.lis == nil {
		return
	}

	s.lis.Close()
	s.lis = nil
}

// handleRawConn wraps rawConn in a server-side yamux session and serves the
// streams opened on it. If the session cannot be created the failure is
// logged and rawConn is closed.
func (s *Server) handleRawConn(rawConn net.Conn) {
	sess, err := yamux.Server(rawConn, nil)
	if err == nil {
		s.serveSession(sess)
		return
	}

	log.Printf("ERR yamux.Server failed: error=%v", err)
	rawConn.Close()
}

// serveSession accepts streams from session until AcceptStream fails, handing
// each stream to handleStream on its own goroutine. An io.EOF ends the loop
// quietly; any other error is logged and the session is closed.
func (s *Server) serveSession(session *yamux.Session) {
	for {
		stream, err := session.AcceptStream()
		switch {
		case err == nil:
			go s.handleStream(stream)
		case err == io.EOF:
			return
		default:
			// TODO(sgc): handle session errors
			log.Printf("ERR session.AcceptStream failed: error=%v", err)
			session.Close()
			return
		}
	}
}

// decodeServiceMethod splits a 16-bit stream header into its service index
// (high byte) and method index (low byte).
func decodeServiceMethod(v uint16) (svc, mth uint8) {
	//┌────────────────────────┬────────────────────────┐
	//│      SERVICE (8)       │       METHOD (8)       │
	//└────────────────────────┴────────────────────────┘

	svc = uint8(v >> 8)
	mth = uint8(v & 0xff)
	return svc, mth
}

// handleStream serves a single yamux stream. It reads the two-byte
// service/method header, looks up the registered service, and dispatches to
// the unary handler if the method index names a unary method, otherwise to
// the streaming handler. The stream is closed when handleStream returns.
func (s *Server) handleStream(st *yamux.Stream) {
	defer st.Close()

	var hdr [2]byte
	if _, err := io.ReadFull(st, hdr[:]); err != nil {
		// Without a complete header the zeroed buffer would decode to
		// service 0 / method 0 and invoke the wrong handler, so stop here.
		// io.EOF means the peer closed the stream before sending anything
		// and is not worth reporting.
		if err != io.EOF {
			log.Printf("ERR reading stream header failed: error=%v", err)
		}
		return
	}

	svcID, mthID := decodeServiceMethod(binary.BigEndian.Uint16(hdr[:]))
	srv, ok := s.m[svcID]
	if !ok {
		// TODO(sgc): handle unknown service
		log.Printf("invalid service identifier: service=%d", svcID)
		return
	}

	if md, ok := srv.md[mthID]; ok {
		// handle unary
		if err := s.handleUnaryRPC(st, srv, md); err != nil && err != io.EOF {
			// Report the failure rather than dropping it silently; no status
			// is written back to the peer yet.
			log.Printf("ERR unary RPC failed: service=%d method=%d error=%v", svcID, mthID, err)
		}
		return
	}

	if sd, ok := srv.sd[mthID]; ok {
		// handle streaming
		s.handleStreamingRPC(st, srv, sd)
		return
	}

	// TODO(sgc): handle unknown method
	log.Printf("ERR invalid method identifier: service=%d method=%d", svcID, mthID)
}

// handleStreamingRPC runs the streaming handler described by sd over st. The
// handler receives the service implementation (nil when srv is nil) and a
// serverStream that reads from and writes to st using the server's codec. A
// handler error is logged; nothing is written back to the peer for it.
func (s *Server) handleStreamingRPC(st *yamux.Stream, srv *service, sd *StreamDesc) {
	var impl interface{}
	if srv != nil {
		impl = srv.server
	}

	stream := &serverStream{
		cn:    st,
		codec: s.opts.codec,
		p:     &parser{r: st},
	}

	if err := sd.Handler(impl, stream); err != nil {
		// TODO(sgc): handle app error using similar code style to gRPC
		log.Printf("ERR sd.Handler failed: error=%v", err)
	}

	// TODO(sgc): write OK status?
}

// handleUnaryRPC serves one unary call on st: it receives a single request
// message, invokes md.Handler with the service implementation and a decode
// function for that request, and sends the handler's reply back on st.
//
// It returns io.EOF when the stream ends before a request is received, a
// status error with codes.Internal when the request is truncated, any other
// receive error unchanged, the handler's error (converted to a status error
// with codes.Unknown when it is not already one), or the error from sending
// the reply. It returns nil once the reply has been written.
func (s *Server) handleUnaryRPC(st *yamux.Stream, srv *service, md *MethodDesc) error {
	p := &parser{r: st}
	req, err := p.recvMsg()
	if err == io.EOF {
		return err
	}

	if err == io.ErrUnexpectedEOF {
		// Pass the error as an argument, never as the format string, so a
		// '%' in its text cannot be misread as a formatting verb.
		return status.Errorf(codes.Internal, "%v", err)
	}

	if err != nil {
		// Any other receive failure leaves req unusable; do not run the
		// handler on it.
		// TODO(sgc): write error status
		return err
	}

	// df is handed to the handler so it can decode the request into its own
	// request type.
	df := func(v interface{}) error {
		if err := s.opts.codec.Unmarshal(req, v); err != nil {
			return status.Errorf(codes.Internal, "rpc: error unmarshalling request: %v", err)
		}
		return nil
	}

	reply, appErr := md.Handler(srv.server, context.Background(), df)
	if appErr != nil {
		appStatus, ok := status.FromError(appErr)
		if !ok {
			// convert to app error
			appStatus = &status.Status{Code: codes.Unknown, Message: appErr.Error()}
			appErr = appStatus
		}

		// TODO(sgc): write error status
		return appErr
	}

	if err := s.sendResponse(st, reply); err != nil {
		// TODO(sgc): write error status
		return err
	}

	// TODO(sgc): write OK status
	return nil
}

// sendResponse encodes msg with the server's codec and writes the encoded
// bytes to stream. It returns the encoding error, or else the write error.
func (s *Server) sendResponse(stream *yamux.Stream, msg interface{}) error {
	payload, encErr := encode(s.opts.codec, msg)
	if encErr != nil {
		return encErr
	}

	if _, writeErr := stream.Write(payload); writeErr != nil {
		return writeErr
	}
	return nil
}