diff --git a/transport/awslambda/doc.go b/transport/awslambda/doc.go new file mode 100644 index 0000000..18942a1 --- /dev/null +++ b/transport/awslambda/doc.go @@ -0,0 +1,2 @@ +// Package awslambda provides an AWS Lambda transport layer. +package awslambda diff --git a/transport/awslambda/encode_decode.go b/transport/awslambda/encode_decode.go new file mode 100644 index 0000000..6cb5a69 --- /dev/null +++ b/transport/awslambda/encode_decode.go @@ -0,0 +1,16 @@ +package awslambda + +import ( + "context" +) + +// DecodeRequestFunc extracts a user-domain request object from an +// AWS Lambda payload. +type DecodeRequestFunc func(context.Context, []byte) (interface{}, error) + +// EncodeResponseFunc encodes the passed response object into []byte, +// ready to be sent as AWS Lambda response. +type EncodeResponseFunc func(context.Context, interface{}) ([]byte, error) + +// ErrorEncoder is responsible for encoding an error. +type ErrorEncoder func(ctx context.Context, err error) ([]byte, error) diff --git a/transport/awslambda/handler.go b/transport/awslambda/handler.go new file mode 100644 index 0000000..1aedb28 --- /dev/null +++ b/transport/awslambda/handler.go @@ -0,0 +1,120 @@ +package awslambda + +import ( + "context" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" +) + +// Handler wraps an endpoint. +type Handler struct { + e endpoint.Endpoint + dec DecodeRequestFunc + enc EncodeResponseFunc + before []HandlerRequestFunc + after []HandlerResponseFunc + errorEncoder ErrorEncoder + finalizer []HandlerFinalizerFunc + logger log.Logger +} + +// NewHandler constructs a new handler, which implements +// the AWS lambda.Handler interface. +func NewHandler( + e endpoint.Endpoint, + dec DecodeRequestFunc, + enc EncodeResponseFunc, + options ...HandlerOption, +) *Handler { + h := &Handler{ + e: e, + dec: dec, + enc: enc, + logger: log.NewNopLogger(), + errorEncoder: DefaultErrorEncoder, + } + for _, option := range options { + option(h) + } + return h +} + +// HandlerOption sets an optional parameter for handlers. +type HandlerOption func(*Handler) + +// HandlerBefore functions are executed on the payload byte, +// before the request is decoded. +func HandlerBefore(before ...HandlerRequestFunc) HandlerOption { + return func(h *Handler) { h.before = append(h.before, before...) } +} + +// HandlerAfter functions are only executed after invoking the endpoint +// but prior to returning a response. +func HandlerAfter(after ...HandlerResponseFunc) HandlerOption { + return func(h *Handler) { h.after = append(h.after, after...) } +} + +// HandlerErrorLogger is used to log non-terminal errors. +// By default, no errors are logged. +func HandlerErrorLogger(logger log.Logger) HandlerOption { + return func(h *Handler) { h.logger = logger } +} + +// HandlerErrorEncoder is used to encode errors. +func HandlerErrorEncoder(ee ErrorEncoder) HandlerOption { + return func(h *Handler) { h.errorEncoder = ee } +} + +// HandlerFinalizer sets finalizer which are called at the end of +// request. By default no finalizer is registered. +func HandlerFinalizer(f ...HandlerFinalizerFunc) HandlerOption { + return func(h *Handler) { h.finalizer = append(h.finalizer, f...) } +} + +// DefaultErrorEncoder defines the default behavior of encoding an error response, +// where it returns nil, and the error itself. +func DefaultErrorEncoder(ctx context.Context, err error) ([]byte, error) { + return nil, err +} + +// Invoke represents implementation of the AWS lambda.Handler interface. +func (h *Handler) Invoke( + ctx context.Context, + payload []byte, +) (resp []byte, err error) { + if len(h.finalizer) > 0 { + defer func() { + for _, f := range h.finalizer { + f(ctx, resp, err) + } + }() + } + + for _, f := range h.before { + ctx = f(ctx, payload) + } + + request, err := h.dec(ctx, payload) + if err != nil { + h.logger.Log("err", err) + return h.errorEncoder(ctx, err) + } + + response, err := h.e(ctx, request) + if err != nil { + h.logger.Log("err", err) + return h.errorEncoder(ctx, err) + } + + for _, f := range h.after { + ctx = f(ctx, response) + } + + if resp, err = h.enc(ctx, response); err != nil { + h.logger.Log("err", err) + return h.errorEncoder(ctx, err) + } + + return resp, err +} diff --git a/transport/awslambda/handler_test.go b/transport/awslambda/handler_test.go new file mode 100644 index 0000000..8add6be --- /dev/null +++ b/transport/awslambda/handler_test.go @@ -0,0 +1,350 @@ +package awslambda + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + "github.com/aws/aws-lambda-go/events" + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" +) + +type key int + +const ( + KeyBeforeOne key = iota + KeyBeforeTwo key = iota + KeyAfterOne key = iota + KeyEncMode key = iota +) + +func TestDefaultErrorEncoder(t *testing.T) { + ctx := context.Background() + rootErr := fmt.Errorf("root") + b, err := DefaultErrorEncoder(ctx, rootErr) + if b != nil { + t.Fatalf("DefaultErrorEncoder should return nil as []byte") + } + if err != rootErr { + t.Fatalf("DefaultErrorEncoder expects return back the given error.") + } +} + +func TestInvokeHappyPath(t *testing.T) { + svc := serviceTest01{} + + helloHandler := NewHandler( + makeTest01HelloEndpoint(svc), + decodeHelloRequestWithTwoBefores, + encodeResponse, + HandlerErrorLogger(log.NewNopLogger()), + HandlerBefore(func( + ctx context.Context, + payload []byte, + ) context.Context { + ctx = context.WithValue(ctx, KeyBeforeOne, "bef1") + return ctx + }), + HandlerBefore(func( + ctx context.Context, + payload []byte, + ) context.Context { + ctx = context.WithValue(ctx, KeyBeforeTwo, "bef2") + return ctx + }), + HandlerAfter(func( + ctx context.Context, + response interface{}, + ) context.Context { + ctx = context.WithValue(ctx, KeyAfterOne, "af1") + return ctx + }), + HandlerAfter(func( + ctx context.Context, + response interface{}, + ) context.Context { + if _, ok := ctx.Value(KeyAfterOne).(string); !ok { + t.Fatalf("Value was not set properly during multi HandlerAfter") + } + return ctx + }), + HandlerFinalizer(func( + _ context.Context, + resp []byte, + _ error, + ) { + apigwResp := events.APIGatewayProxyResponse{} + err := json.Unmarshal(resp, &apigwResp) + if err != nil { + t.Fatalf("Should have no error, but got: %+v", err) + } + + response := helloResponse{} + err = json.Unmarshal([]byte(apigwResp.Body), &response) + if err != nil { + t.Fatalf("Should have no error, but got: %+v", err) + } + + expectedGreeting := "hello john doe bef1 bef2" + if response.Greeting != expectedGreeting { + t.Fatalf( + "Expect: %s, Actual: %s", expectedGreeting, response.Greeting) + } + }), + ) + + ctx := context.Background() + req, _ := json.Marshal(events.APIGatewayProxyRequest{ + Body: `{"name":"john doe"}`, + }) + resp, err := helloHandler.Invoke(ctx, req) + + if err != nil { + t.Fatalf("Should have no error, but got: %+v", err) + } + + apigwResp := events.APIGatewayProxyResponse{} + err = json.Unmarshal(resp, &apigwResp) + if err != nil { + t.Fatalf("Should have no error, but got: %+v", err) + } + + response := helloResponse{} + err = json.Unmarshal([]byte(apigwResp.Body), &response) + if err != nil { + t.Fatalf("Should have no error, but got: %+v", err) + } + + expectedGreeting := "hello john doe bef1 bef2" + if response.Greeting != expectedGreeting { + t.Fatalf( + "Expect: %s, Actual: %s", expectedGreeting, response.Greeting) + } +} + +func TestInvokeFailDecode(t *testing.T) { + svc := serviceTest01{} + + helloHandler := NewHandler( + makeTest01HelloEndpoint(svc), + decodeHelloRequestWithTwoBefores, + encodeResponse, + HandlerErrorEncoder(func( + ctx context.Context, + err error, + ) ([]byte, error) { + apigwResp := events.APIGatewayProxyResponse{} + apigwResp.Body = `{"error":"yes"}` + apigwResp.StatusCode = 500 + resp, err := json.Marshal(apigwResp) + return resp, err + }), + ) + + ctx := context.Background() + req, _ := json.Marshal(events.APIGatewayProxyRequest{ + Body: `{"name":"john doe"}`, + }) + resp, err := helloHandler.Invoke(ctx, req) + + if err != nil { + t.Fatalf("Should have no error, but got: %+v", err) + } + + apigwResp := events.APIGatewayProxyResponse{} + json.Unmarshal(resp, &apigwResp) + if apigwResp.StatusCode != 500 { + t.Fatalf("Expect status code of 500, instead of %d", apigwResp.StatusCode) + } +} + +func TestInvokeFailEndpoint(t *testing.T) { + svc := serviceTest01{} + + helloHandler := NewHandler( + makeTest01FailEndpoint(svc), + decodeHelloRequestWithTwoBefores, + encodeResponse, + HandlerBefore(func( + ctx context.Context, + payload []byte, + ) context.Context { + ctx = context.WithValue(ctx, KeyBeforeOne, "bef1") + return ctx + }), + HandlerBefore(func( + ctx context.Context, + payload []byte, + ) context.Context { + ctx = context.WithValue(ctx, KeyBeforeTwo, "bef2") + return ctx + }), + HandlerErrorEncoder(func( + ctx context.Context, + err error, + ) ([]byte, error) { + apigwResp := events.APIGatewayProxyResponse{} + apigwResp.Body = `{"error":"yes"}` + apigwResp.StatusCode = 500 + resp, err := json.Marshal(apigwResp) + return resp, err + }), + ) + + ctx := context.Background() + req, _ := json.Marshal(events.APIGatewayProxyRequest{ + Body: `{"name":"john doe"}`, + }) + resp, err := helloHandler.Invoke(ctx, req) + + if err != nil { + t.Fatalf("Should have no error, but got: %+v", err) + } + + apigwResp := events.APIGatewayProxyResponse{} + json.Unmarshal(resp, &apigwResp) + if apigwResp.StatusCode != 500 { + t.Fatalf("Expect status code of 500, instead of %d", apigwResp.StatusCode) + } +} + +func TestInvokeFailEncode(t *testing.T) { + svc := serviceTest01{} + + helloHandler := NewHandler( + makeTest01HelloEndpoint(svc), + decodeHelloRequestWithTwoBefores, + encodeResponse, + HandlerBefore(func( + ctx context.Context, + payload []byte, + ) context.Context { + ctx = context.WithValue(ctx, KeyBeforeOne, "bef1") + return ctx + }), + HandlerBefore(func( + ctx context.Context, + payload []byte, + ) context.Context { + ctx = context.WithValue(ctx, KeyBeforeTwo, "bef2") + return ctx + }), + HandlerAfter(func( + ctx context.Context, + response interface{}, + ) context.Context { + ctx = context.WithValue(ctx, KeyEncMode, "fail_encode") + return ctx + }), + HandlerErrorEncoder(func( + ctx context.Context, + err error, + ) ([]byte, error) { + // convert error into proper APIGateway response. + apigwResp := events.APIGatewayProxyResponse{} + apigwResp.Body = `{"error":"yes"}` + apigwResp.StatusCode = 500 + resp, err := json.Marshal(apigwResp) + return resp, err + }), + ) + + ctx := context.Background() + req, _ := json.Marshal(events.APIGatewayProxyRequest{ + Body: `{"name":"john doe"}`, + }) + resp, err := helloHandler.Invoke(ctx, req) + + if err != nil { + t.Fatalf("Should have no error, but got: %+v", err) + } + + apigwResp := events.APIGatewayProxyResponse{} + json.Unmarshal(resp, &apigwResp) + if apigwResp.StatusCode != 500 { + t.Fatalf("Expect status code of 500, instead of %d", apigwResp.StatusCode) + } +} + +func decodeHelloRequestWithTwoBefores( + ctx context.Context, req []byte, +) (interface{}, error) { + apigwReq := events.APIGatewayProxyRequest{} + err := json.Unmarshal([]byte(req), &apigwReq) + if err != nil { + return apigwReq, err + } + + request := helloRequest{} + err = json.Unmarshal([]byte(apigwReq.Body), &request) + if err != nil { + return request, err + } + + valOne, ok := ctx.Value(KeyBeforeOne).(string) + if !ok { + return request, fmt.Errorf( + "Value was not set properly when multiple HandlerBefores are used") + } + + valTwo, ok := ctx.Value(KeyBeforeTwo).(string) + if !ok { + return request, fmt.Errorf( + "Value was not set properly when multiple HandlerBefores are used") + } + + request.Name += " " + valOne + " " + valTwo + return request, err +} + +func encodeResponse( + ctx context.Context, response interface{}, +) ([]byte, error) { + apigwResp := events.APIGatewayProxyResponse{} + + mode, ok := ctx.Value(KeyEncMode).(string) + if ok && mode == "fail_encode" { + return nil, fmt.Errorf("fail encoding") + } + + respByte, err := json.Marshal(response) + if err != nil { + return nil, err + } + + apigwResp.Body = string(respByte) + apigwResp.StatusCode = 200 + + resp, err := json.Marshal(apigwResp) + return resp, err +} + +type helloRequest struct { + Name string `json:"name"` +} + +type helloResponse struct { + Greeting string `json:"greeting"` +} + +func makeTest01HelloEndpoint(svc serviceTest01) endpoint.Endpoint { + return func(_ context.Context, request interface{}) (interface{}, error) { + req := request.(helloRequest) + greeting := svc.hello(req.Name) + return helloResponse{greeting}, nil + } +} + +func makeTest01FailEndpoint(_ serviceTest01) endpoint.Endpoint { + return func(_ context.Context, request interface{}) (interface{}, error) { + return nil, fmt.Errorf("test error endpoint") + } +} + +type serviceTest01 struct{} + +func (ts *serviceTest01) hello(name string) string { + return fmt.Sprintf("hello %s", name) +} diff --git a/transport/awslambda/request_response_funcs.go b/transport/awslambda/request_response_funcs.go new file mode 100644 index 0000000..d85f273 --- /dev/null +++ b/transport/awslambda/request_response_funcs.go @@ -0,0 +1,21 @@ +package awslambda + +import ( + "context" +) + +// HandlerRequestFunc may take information from the received +// payload and use it to place items in the request scoped context. +// HandlerRequestFuncs are executed prior to invoking the endpoint and +// decoding of the payload. +type HandlerRequestFunc func(ctx context.Context, payload []byte) context.Context + +// HandlerResponseFunc may take information from a request context +// and use it to manipulate the response before it's marshaled. +// HandlerResponseFunc are executed after invoking the endpoint +// but prior to returning a response. +type HandlerResponseFunc func(ctx context.Context, response interface{}) context.Context + +// HandlerFinalizerFunc is executed at the end of Invoke. +// This can be used for logging purposes. +type HandlerFinalizerFunc func(ctx context.Context, resp []byte, err error)