diff --git a/examples/apigateway/main.go b/examples/apigateway/main.go index fe3ea13..5781c82 100644 --- a/examples/apigateway/main.go +++ b/examples/apigateway/main.go @@ -13,134 +13,127 @@ "os/signal" "strings" "syscall" + "time" + + "github.com/gorilla/mux" + "github.com/hashicorp/consul/api" + "golang.org/x/net/context" + "google.golang.org/grpc" "github.com/go-kit/kit/endpoint" addsvc "github.com/go-kit/kit/examples/addsvc/client/grpc" "github.com/go-kit/kit/examples/addsvc/server" "github.com/go-kit/kit/loadbalancer" "github.com/go-kit/kit/loadbalancer/consul" - log "github.com/go-kit/kit/log" - //grpctransport "github.com/go-kit/kit/transport/grpc" + "github.com/go-kit/kit/log" httptransport "github.com/go-kit/kit/transport/http" - //proto "github.com/golang/protobuf/proto" - "github.com/gorilla/mux" - "github.com/hashicorp/consul/api" - "golang.org/x/net/context" - "google.golang.org/grpc" ) func main() { - fs := flag.NewFlagSet("", flag.ExitOnError) var ( - httpAddr = fs.String("http.addr", ":8000", "Address for HTTP (JSON) server") - consulAddr = fs.String("consul.addr", "", "Consul agent address") + httpAddr = flag.String("http.addr", ":8000", "Address for HTTP (JSON) server") + consulAddr = flag.String("consul.addr", "", "Consul agent address") + retryMax = flag.Int("retry.max", 3, "per-request retries to different instances") + retryTimeout = flag.Duration("retry.timeout", 500*time.Millisecond, "per-request timeout, including retries") ) - flag.Usage = fs.Usage - if err := fs.Parse(os.Args[1:]); err != nil { - fmt.Fprintf(os.Stderr, "%v", err) - os.Exit(1) - } - - // log + flag.Parse() + + // Log domain logger := log.NewLogfmtLogger(os.Stderr) logger = log.NewContext(logger).With("ts", log.DefaultTimestampUTC).With("caller", log.DefaultCaller) stdlog.SetFlags(0) // flags are handled by Go kit's logger stdlog.SetOutput(log.NewStdlibAdapter(logger)) // redirect anything using stdlib log to us - // errors + // Service discovery domain. In this example we use Consul. + consulConfig := api.DefaultConfig() + if len(*consulAddr) > 0 { + consulConfig.Address = *consulAddr + } + consulClient, err := api.NewClient(consulConfig) + if err != nil { + logger.Log("err", err) + os.Exit(1) + } + discoveryClient := consul.NewClient(consulClient) + + // Context domain. + ctx := context.Background() + + // Set up our routes. + // + // Each Consul service name maps to multiple instances of that service. We + // connect to each instance according to its pre-determined transport: in this + // case, we choose to access addsvc via its gRPC client, and stringsvc over + // plain transport/http (it has no client package). + // + // Each service instance implements multiple methods, and we want to map each + // method to a unique path on the API gateway. So, we define that path and its + // corresponding factory function, which takes an instance string and returns an + // endpoint.Endpoint for the specific method. + // + // Finally, we mount that path + endpoint handler into the router. + r := mux.NewRouter() + for consulName, methods := range map[string][]struct { + path string + factory loadbalancer.Factory + }{ + "addsvc": { + {path: "/api/addsvc/concat", factory: addsvcGRPCFactory(ctx, makeConcatEndpoint, logger)}, + {path: "/api/addsvc/sum", factory: addsvcGRPCFactory(ctx, makeSumEndpoint, logger)}, + }, + "stringsvc": { + {path: "/api/stringsvc/uppercase", factory: httpFactory(ctx, "GET", "uppercase/")}, + {path: "/api/stringsvc/concat", factory: httpFactory(ctx, "GET", "concat/")}, + }, + } { + for _, method := range methods { + publisher, err := consul.NewPublisher(discoveryClient, method.factory, logger, consulName) + if err != nil { + logger.Log("service", consulName, "path", method.path, "err", err) + continue + } + lb := loadbalancer.NewRoundRobin(publisher) + e := loadbalancer.Retry(*retryMax, *retryTimeout, lb) + h := makeHandler(ctx, e, logger) + r.HandleFunc(method.path, h) + } + } + + // Mechanical stuff. errc := make(chan error) go func() { errc <- interrupt() }() - - // consul - consulConfig := api.DefaultConfig() - if len(*consulAddr) > 0 { - consulConfig.Address = *consulAddr - } - consulClient, err := api.NewClient(consulConfig) - if err != nil { - logger.Log("fatal", err) - } - discoveryClient := consul.NewClient(consulClient) - - ctx := context.Background() - - // service definitions - serviceDefs := []*ServiceDef{} - serviceDefs = append(serviceDefs, &ServiceDef{ - Name: "addsvc", - Endpoints: map[string]loadbalancer.Factory{ - "/api/addsvc/concat": factoryAddsvc(ctx, logger, makeConcatEndpoint), - "/api/addsvc/sum": factoryAddsvc(ctx, logger, makeSumEndpoint), - }, - }) - serviceDefs = append(serviceDefs, &ServiceDef{ - Name: "stringsvc", - Endpoints: map[string]loadbalancer.Factory{ - "/api/stringsvc/uppercase": routeFactory(ctx, "uppercase"), - "/api/stringsvc/count": routeFactory(ctx, "count"), - }, - }) - - // discover instances and register endpoints - r := mux.NewRouter() - for _, def := range serviceDefs { - for path, e := range def.Endpoints { - pub, err := consul.NewPublisher(discoveryClient, e, logger, def.Name) - if err != nil { - logger.Log("fatal", err) - } - r.HandleFunc(path, makeHandler(ctx, loadbalancer.NewRoundRobin(pub), logger)) - } - } - - // apigateway go func() { + logger.Log("transport", "http", "addr", *httpAddr) errc <- http.ListenAndServe(*httpAddr, r) }() - - // wait for interrupt/error - logger.Log("fatal", <-errc) -} - -type ServiceDef struct { - Name string - Endpoints map[string]loadbalancer.Factory -} - -func interrupt() error { - c := make(chan os.Signal) - signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) - return fmt.Errorf("%s", <-c) -} - -func makeHandler(ctx context.Context, lb loadbalancer.LoadBalancer, logger log.Logger) http.HandlerFunc { + logger.Log("err", <-errc) +} + +func makeHandler(ctx context.Context, e endpoint.Endpoint, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - e, err := lb.Endpoint() - if err != nil { - logger.Log("error", err) - return - } resp, err := e(ctx, r.Body) if err != nil { - logger.Log("error", err) + logger.Log("err", err) + http.Error(w, err.Error(), http.StatusInternalServerError) return } b, ok := resp.([]byte) if !ok { - logger.Log("error", "endpoint response is not of type []byte") + logger.Log("err", "endpoint response is not of type []byte") + http.Error(w, err.Error(), http.StatusInternalServerError) return } _, err = w.Write(b) if err != nil { - logger.Log("error", err) + logger.Log("err", err) return } } } -func factoryAddsvc(ctx context.Context, logger log.Logger, maker func(server.AddService) endpoint.Endpoint) loadbalancer.Factory { +func addsvcGRPCFactory(ctx context.Context, makeEndpoint func(server.AddService) endpoint.Endpoint, logger log.Logger) loadbalancer.Factory { return func(instance string) (endpoint.Endpoint, io.Closer, error) { var e endpoint.Endpoint conn, err := grpc.Dial(instance, grpc.WithInsecure()) @@ -148,7 +141,7 @@ return e, nil, err } svc := addsvc.New(ctx, conn, logger) - return maker(svc), nil, nil + return makeEndpoint(svc), nil, nil } } @@ -176,7 +169,7 @@ } } -func routeFactory(ctx context.Context, method string) loadbalancer.Factory { +func httpFactory(ctx context.Context, method, path string) loadbalancer.Factory { return func(instance string) (endpoint.Endpoint, io.Closer, error) { var e endpoint.Endpoint if !strings.HasPrefix(instance, "http") { @@ -186,9 +179,9 @@ if err != nil { return nil, nil, err } - u.Path = method - - e = httptransport.NewClient("GET", u, passEncode, passDecode).Endpoint() + u.Path = path + + e = httptransport.NewClient(method, u, passEncode, passDecode).Endpoint() return e, nil, nil } } @@ -201,3 +194,9 @@ func passDecode(r *http.Response) (interface{}, error) { return ioutil.ReadAll(r.Body) } + +func interrupt() error { + c := make(chan os.Signal) + signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + return fmt.Errorf("%s", <-c) +}