Codebase list golang-github-go-kit-kit / ff58f40
Merge pull request #203 from marshauf/apigateway_example [WIP] apigateway example Peter Bourgon 8 years ago
1 changed file(s) with 202 addition(s) and 0 deletion(s). Raw diff Collapse all Expand all
0 package main
1
2 import (
3 "encoding/json"
4 "flag"
5 "fmt"
6 "io"
7 "io/ioutil"
8 stdlog "log"
9 "net/http"
10 "net/url"
11 "os"
12 "os/signal"
13 "strings"
14 "syscall"
15 "time"
16
17 "github.com/gorilla/mux"
18 "github.com/hashicorp/consul/api"
19 "golang.org/x/net/context"
20 "google.golang.org/grpc"
21
22 "github.com/go-kit/kit/endpoint"
23 addsvc "github.com/go-kit/kit/examples/addsvc/client/grpc"
24 "github.com/go-kit/kit/examples/addsvc/server"
25 "github.com/go-kit/kit/loadbalancer"
26 "github.com/go-kit/kit/loadbalancer/consul"
27 "github.com/go-kit/kit/log"
28 httptransport "github.com/go-kit/kit/transport/http"
29 )
30
31 func main() {
32 var (
33 httpAddr = flag.String("http.addr", ":8000", "Address for HTTP (JSON) server")
34 consulAddr = flag.String("consul.addr", "", "Consul agent address")
35 retryMax = flag.Int("retry.max", 3, "per-request retries to different instances")
36 retryTimeout = flag.Duration("retry.timeout", 500*time.Millisecond, "per-request timeout, including retries")
37 )
38 flag.Parse()
39
40 // Log domain
41 logger := log.NewLogfmtLogger(os.Stderr)
42 logger = log.NewContext(logger).With("ts", log.DefaultTimestampUTC).With("caller", log.DefaultCaller)
43 stdlog.SetFlags(0) // flags are handled by Go kit's logger
44 stdlog.SetOutput(log.NewStdlibAdapter(logger)) // redirect anything using stdlib log to us
45
46 // Service discovery domain. In this example we use Consul.
47 consulConfig := api.DefaultConfig()
48 if len(*consulAddr) > 0 {
49 consulConfig.Address = *consulAddr
50 }
51 consulClient, err := api.NewClient(consulConfig)
52 if err != nil {
53 logger.Log("err", err)
54 os.Exit(1)
55 }
56 discoveryClient := consul.NewClient(consulClient)
57
58 // Context domain.
59 ctx := context.Background()
60
61 // Set up our routes.
62 //
63 // Each Consul service name maps to multiple instances of that service. We
64 // connect to each instance according to its pre-determined transport: in this
65 // case, we choose to access addsvc via its gRPC client, and stringsvc over
66 // plain transport/http (it has no client package).
67 //
68 // Each service instance implements multiple methods, and we want to map each
69 // method to a unique path on the API gateway. So, we define that path and its
70 // corresponding factory function, which takes an instance string and returns an
71 // endpoint.Endpoint for the specific method.
72 //
73 // Finally, we mount that path + endpoint handler into the router.
74 r := mux.NewRouter()
75 for consulName, methods := range map[string][]struct {
76 path string
77 factory loadbalancer.Factory
78 }{
79 "addsvc": {
80 {path: "/api/addsvc/concat", factory: addsvcGRPCFactory(ctx, makeConcatEndpoint, logger)},
81 {path: "/api/addsvc/sum", factory: addsvcGRPCFactory(ctx, makeSumEndpoint, logger)},
82 },
83 "stringsvc": {
84 {path: "/api/stringsvc/uppercase", factory: httpFactory(ctx, "GET", "uppercase/")},
85 {path: "/api/stringsvc/concat", factory: httpFactory(ctx, "GET", "concat/")},
86 },
87 } {
88 for _, method := range methods {
89 publisher, err := consul.NewPublisher(discoveryClient, method.factory, logger, consulName)
90 if err != nil {
91 logger.Log("service", consulName, "path", method.path, "err", err)
92 continue
93 }
94 lb := loadbalancer.NewRoundRobin(publisher)
95 e := loadbalancer.Retry(*retryMax, *retryTimeout, lb)
96 h := makeHandler(ctx, e, logger)
97 r.HandleFunc(method.path, h)
98 }
99 }
100
101 // Mechanical stuff.
102 errc := make(chan error)
103 go func() {
104 errc <- interrupt()
105 }()
106 go func() {
107 logger.Log("transport", "http", "addr", *httpAddr)
108 errc <- http.ListenAndServe(*httpAddr, r)
109 }()
110 logger.Log("err", <-errc)
111 }
112
113 func makeHandler(ctx context.Context, e endpoint.Endpoint, logger log.Logger) http.HandlerFunc {
114 return func(w http.ResponseWriter, r *http.Request) {
115 resp, err := e(ctx, r.Body)
116 if err != nil {
117 logger.Log("err", err)
118 http.Error(w, err.Error(), http.StatusInternalServerError)
119 return
120 }
121 b, ok := resp.([]byte)
122 if !ok {
123 logger.Log("err", "endpoint response is not of type []byte")
124 http.Error(w, err.Error(), http.StatusInternalServerError)
125 return
126 }
127 _, err = w.Write(b)
128 if err != nil {
129 logger.Log("err", err)
130 return
131 }
132 }
133 }
134
135 func addsvcGRPCFactory(ctx context.Context, makeEndpoint func(server.AddService) endpoint.Endpoint, logger log.Logger) loadbalancer.Factory {
136 return func(instance string) (endpoint.Endpoint, io.Closer, error) {
137 var e endpoint.Endpoint
138 conn, err := grpc.Dial(instance, grpc.WithInsecure())
139 if err != nil {
140 return e, nil, err
141 }
142 svc := addsvc.New(ctx, conn, logger)
143 return makeEndpoint(svc), nil, nil
144 }
145 }
146
147 func makeSumEndpoint(svc server.AddService) endpoint.Endpoint {
148 return func(ctx context.Context, request interface{}) (interface{}, error) {
149 r := request.(io.Reader)
150 var req server.SumRequest
151 if err := json.NewDecoder(r).Decode(&req); err != nil {
152 return nil, err
153 }
154 v := svc.Sum(req.A, req.B)
155 return json.Marshal(v)
156 }
157 }
158
159 func makeConcatEndpoint(svc server.AddService) endpoint.Endpoint {
160 return func(ctx context.Context, request interface{}) (interface{}, error) {
161 r := request.(io.Reader)
162 var req server.ConcatRequest
163 if err := json.NewDecoder(r).Decode(&req); err != nil {
164 return nil, err
165 }
166 v := svc.Concat(req.A, req.B)
167 return json.Marshal(v)
168 }
169 }
170
171 func httpFactory(ctx context.Context, method, path string) loadbalancer.Factory {
172 return func(instance string) (endpoint.Endpoint, io.Closer, error) {
173 var e endpoint.Endpoint
174 if !strings.HasPrefix(instance, "http") {
175 instance = "http://" + instance
176 }
177 u, err := url.Parse(instance)
178 if err != nil {
179 return nil, nil, err
180 }
181 u.Path = path
182
183 e = httptransport.NewClient(method, u, passEncode, passDecode).Endpoint()
184 return e, nil, nil
185 }
186 }
187
188 func passEncode(r *http.Request, request interface{}) error {
189 r.Body = request.(io.ReadCloser)
190 return nil
191 }
192
193 func passDecode(r *http.Response) (interface{}, error) {
194 return ioutil.ReadAll(r.Body)
195 }
196
197 func interrupt() error {
198 c := make(chan os.Signal)
199 signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
200 return fmt.Errorf("%s", <-c)
201 }