Merge pull request #203 from marshauf/apigateway_example
[WIP] apigateway example
Peter Bourgon
8 years ago
0 | package main | |
1 | ||
2 | import ( | |
3 | "encoding/json" | |
4 | "flag" | |
5 | "fmt" | |
6 | "io" | |
7 | "io/ioutil" | |
8 | stdlog "log" | |
9 | "net/http" | |
10 | "net/url" | |
11 | "os" | |
12 | "os/signal" | |
13 | "strings" | |
14 | "syscall" | |
15 | "time" | |
16 | ||
17 | "github.com/gorilla/mux" | |
18 | "github.com/hashicorp/consul/api" | |
19 | "golang.org/x/net/context" | |
20 | "google.golang.org/grpc" | |
21 | ||
22 | "github.com/go-kit/kit/endpoint" | |
23 | addsvc "github.com/go-kit/kit/examples/addsvc/client/grpc" | |
24 | "github.com/go-kit/kit/examples/addsvc/server" | |
25 | "github.com/go-kit/kit/loadbalancer" | |
26 | "github.com/go-kit/kit/loadbalancer/consul" | |
27 | "github.com/go-kit/kit/log" | |
28 | httptransport "github.com/go-kit/kit/transport/http" | |
29 | ) | |
30 | ||
31 | func main() { | |
32 | var ( | |
33 | httpAddr = flag.String("http.addr", ":8000", "Address for HTTP (JSON) server") | |
34 | consulAddr = flag.String("consul.addr", "", "Consul agent address") | |
35 | retryMax = flag.Int("retry.max", 3, "per-request retries to different instances") | |
36 | retryTimeout = flag.Duration("retry.timeout", 500*time.Millisecond, "per-request timeout, including retries") | |
37 | ) | |
38 | flag.Parse() | |
39 | ||
40 | // Log domain | |
41 | logger := log.NewLogfmtLogger(os.Stderr) | |
42 | logger = log.NewContext(logger).With("ts", log.DefaultTimestampUTC).With("caller", log.DefaultCaller) | |
43 | stdlog.SetFlags(0) // flags are handled by Go kit's logger | |
44 | stdlog.SetOutput(log.NewStdlibAdapter(logger)) // redirect anything using stdlib log to us | |
45 | ||
46 | // Service discovery domain. In this example we use Consul. | |
47 | consulConfig := api.DefaultConfig() | |
48 | if len(*consulAddr) > 0 { | |
49 | consulConfig.Address = *consulAddr | |
50 | } | |
51 | consulClient, err := api.NewClient(consulConfig) | |
52 | if err != nil { | |
53 | logger.Log("err", err) | |
54 | os.Exit(1) | |
55 | } | |
56 | discoveryClient := consul.NewClient(consulClient) | |
57 | ||
58 | // Context domain. | |
59 | ctx := context.Background() | |
60 | ||
61 | // Set up our routes. | |
62 | // | |
63 | // Each Consul service name maps to multiple instances of that service. We | |
64 | // connect to each instance according to its pre-determined transport: in this | |
65 | // case, we choose to access addsvc via its gRPC client, and stringsvc over | |
66 | // plain transport/http (it has no client package). | |
67 | // | |
68 | // Each service instance implements multiple methods, and we want to map each | |
69 | // method to a unique path on the API gateway. So, we define that path and its | |
70 | // corresponding factory function, which takes an instance string and returns an | |
71 | // endpoint.Endpoint for the specific method. | |
72 | // | |
73 | // Finally, we mount that path + endpoint handler into the router. | |
74 | r := mux.NewRouter() | |
75 | for consulName, methods := range map[string][]struct { | |
76 | path string | |
77 | factory loadbalancer.Factory | |
78 | }{ | |
79 | "addsvc": { | |
80 | {path: "/api/addsvc/concat", factory: addsvcGRPCFactory(ctx, makeConcatEndpoint, logger)}, | |
81 | {path: "/api/addsvc/sum", factory: addsvcGRPCFactory(ctx, makeSumEndpoint, logger)}, | |
82 | }, | |
83 | "stringsvc": { | |
84 | {path: "/api/stringsvc/uppercase", factory: httpFactory(ctx, "GET", "uppercase/")}, | |
85 | {path: "/api/stringsvc/concat", factory: httpFactory(ctx, "GET", "concat/")}, | |
86 | }, | |
87 | } { | |
88 | for _, method := range methods { | |
89 | publisher, err := consul.NewPublisher(discoveryClient, method.factory, logger, consulName) | |
90 | if err != nil { | |
91 | logger.Log("service", consulName, "path", method.path, "err", err) | |
92 | continue | |
93 | } | |
94 | lb := loadbalancer.NewRoundRobin(publisher) | |
95 | e := loadbalancer.Retry(*retryMax, *retryTimeout, lb) | |
96 | h := makeHandler(ctx, e, logger) | |
97 | r.HandleFunc(method.path, h) | |
98 | } | |
99 | } | |
100 | ||
101 | // Mechanical stuff. | |
102 | errc := make(chan error) | |
103 | go func() { | |
104 | errc <- interrupt() | |
105 | }() | |
106 | go func() { | |
107 | logger.Log("transport", "http", "addr", *httpAddr) | |
108 | errc <- http.ListenAndServe(*httpAddr, r) | |
109 | }() | |
110 | logger.Log("err", <-errc) | |
111 | } | |
112 | ||
113 | func makeHandler(ctx context.Context, e endpoint.Endpoint, logger log.Logger) http.HandlerFunc { | |
114 | return func(w http.ResponseWriter, r *http.Request) { | |
115 | resp, err := e(ctx, r.Body) | |
116 | if err != nil { | |
117 | logger.Log("err", err) | |
118 | http.Error(w, err.Error(), http.StatusInternalServerError) | |
119 | return | |
120 | } | |
121 | b, ok := resp.([]byte) | |
122 | if !ok { | |
123 | logger.Log("err", "endpoint response is not of type []byte") | |
124 | http.Error(w, err.Error(), http.StatusInternalServerError) | |
125 | return | |
126 | } | |
127 | _, err = w.Write(b) | |
128 | if err != nil { | |
129 | logger.Log("err", err) | |
130 | return | |
131 | } | |
132 | } | |
133 | } | |
134 | ||
135 | func addsvcGRPCFactory(ctx context.Context, makeEndpoint func(server.AddService) endpoint.Endpoint, logger log.Logger) loadbalancer.Factory { | |
136 | return func(instance string) (endpoint.Endpoint, io.Closer, error) { | |
137 | var e endpoint.Endpoint | |
138 | conn, err := grpc.Dial(instance, grpc.WithInsecure()) | |
139 | if err != nil { | |
140 | return e, nil, err | |
141 | } | |
142 | svc := addsvc.New(ctx, conn, logger) | |
143 | return makeEndpoint(svc), nil, nil | |
144 | } | |
145 | } | |
146 | ||
147 | func makeSumEndpoint(svc server.AddService) endpoint.Endpoint { | |
148 | return func(ctx context.Context, request interface{}) (interface{}, error) { | |
149 | r := request.(io.Reader) | |
150 | var req server.SumRequest | |
151 | if err := json.NewDecoder(r).Decode(&req); err != nil { | |
152 | return nil, err | |
153 | } | |
154 | v := svc.Sum(req.A, req.B) | |
155 | return json.Marshal(v) | |
156 | } | |
157 | } | |
158 | ||
159 | func makeConcatEndpoint(svc server.AddService) endpoint.Endpoint { | |
160 | return func(ctx context.Context, request interface{}) (interface{}, error) { | |
161 | r := request.(io.Reader) | |
162 | var req server.ConcatRequest | |
163 | if err := json.NewDecoder(r).Decode(&req); err != nil { | |
164 | return nil, err | |
165 | } | |
166 | v := svc.Concat(req.A, req.B) | |
167 | return json.Marshal(v) | |
168 | } | |
169 | } | |
170 | ||
171 | func httpFactory(ctx context.Context, method, path string) loadbalancer.Factory { | |
172 | return func(instance string) (endpoint.Endpoint, io.Closer, error) { | |
173 | var e endpoint.Endpoint | |
174 | if !strings.HasPrefix(instance, "http") { | |
175 | instance = "http://" + instance | |
176 | } | |
177 | u, err := url.Parse(instance) | |
178 | if err != nil { | |
179 | return nil, nil, err | |
180 | } | |
181 | u.Path = path | |
182 | ||
183 | e = httptransport.NewClient(method, u, passEncode, passDecode).Endpoint() | |
184 | return e, nil, nil | |
185 | } | |
186 | } | |
187 | ||
188 | func passEncode(r *http.Request, request interface{}) error { | |
189 | r.Body = request.(io.ReadCloser) | |
190 | return nil | |
191 | } | |
192 | ||
193 | func passDecode(r *http.Response) (interface{}, error) { | |
194 | return ioutil.ReadAll(r.Body) | |
195 | } | |
196 | ||
197 | func interrupt() error { | |
198 | c := make(chan os.Signal) | |
199 | signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) | |
200 | return fmt.Errorf("%s", <-c) | |
201 | } |