diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 84e27cf..742c1a0 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -59,13 +59,13 @@ // ServerBefore functions are executed on the HTTP request object before the // request is decoded. func ServerBefore(before ...RequestFunc) ServerOption { - return func(s *Server) { s.before = before } + return func(s *Server) { s.before = append(s.before, before...) } } // ServerAfter functions are executed on the HTTP response writer after the // endpoint is invoked, but before anything is written to the client. func ServerAfter(after ...ResponseFunc) ServerOption { - return func(s *Server) { s.after = after } + return func(s *Server) { s.after = append(s.after, after...) } } // ServerErrorLogger is used to log non-terminal errors. By default, no errors diff --git a/transport/http/server.go b/transport/http/server.go index ab2b22e..de109be 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -51,13 +51,13 @@ // ServerBefore functions are executed on the HTTP request object before the // request is decoded. func ServerBefore(before ...RequestFunc) ServerOption { - return func(s *Server) { s.before = before } + return func(s *Server) { s.before = append(s.before, before...) } } // ServerAfter functions are executed on the HTTP response writer after the // endpoint is invoked, but before anything is written to the client. func ServerAfter(after ...ServerResponseFunc) ServerOption { - return func(s *Server) { s.after = after } + return func(s *Server) { s.after = append(s.after, after...) } } // ServerErrorEncoder is used to encode errors to the http.ResponseWriter diff --git a/transport/http/server_test.go b/transport/http/server_test.go index 654fbec..51e311e 100644 --- a/transport/http/server_test.go +++ b/transport/http/server_test.go @@ -90,6 +90,99 @@ buf, _ := ioutil.ReadAll(resp.Body) if want, have := http.StatusOK, resp.StatusCode; want != have { t.Errorf("want %d, have %d (%s)", want, have, buf) + } +} + + +func TestMultipleServerBefore(t *testing.T) { + var ( + headerKey = "X-Henlo-Lizer" + headerVal = "Helllo you stinky lizard" + statusCode = http.StatusTeapot + responseBody = "go eat a fly ugly\n" + done = make(chan struct{}) + ) + handler := httptransport.NewServer( + context.Background(), + endpoint.Nop, + func(context.Context, *http.Request) (interface{}, error) { + return struct{}{}, nil + }, + func(_ context.Context, w http.ResponseWriter, _ interface{}) error { + w.Header().Set(headerKey, headerVal) + w.WriteHeader(statusCode) + w.Write([]byte(responseBody)) + return nil + }, + httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context { + ctx = context.WithValue(ctx, "one", 1) + + return ctx + }), + httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context { + if _, ok := ctx.Value("one").(int); !ok { + t.Error("Value was not set properly when multiple ServerBefores are used") + } + + close(done) + return ctx + }), + ) + + server := httptest.NewServer(handler) + defer server.Close() + go http.Get(server.URL) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout waiting for finalizer") + } +} + +func TestMultipleServerAfter(t *testing.T) { + var ( + headerKey = "X-Henlo-Lizer" + headerVal = "Helllo you stinky lizard" + statusCode = http.StatusTeapot + responseBody = "go eat a fly ugly\n" + done = make(chan struct{}) + ) + handler := httptransport.NewServer( + context.Background(), + endpoint.Nop, + func(context.Context, *http.Request) (interface{}, error) { + return struct{}{}, nil + }, + func(_ context.Context, w http.ResponseWriter, _ interface{}) error { + w.Header().Set(headerKey, headerVal) + w.WriteHeader(statusCode) + w.Write([]byte(responseBody)) + return nil + }, + httptransport.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context { + ctx = context.WithValue(ctx, "one", 1) + + return ctx + }), + httptransport.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context { + if _, ok := ctx.Value("one").(int); !ok { + t.Error("Value was not set properly when multiple ServerAfters are used") + } + + close(done) + return ctx + }), + ) + + server := httptest.NewServer(handler) + defer server.Close() + go http.Get(server.URL) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout waiting for finalizer") } } diff --git a/transport/httprp/server.go b/transport/httprp/server.go index 25776ed..5c3ad1b 100644 --- a/transport/httprp/server.go +++ b/transport/httprp/server.go @@ -45,7 +45,7 @@ // ServerBefore functions are executed on the HTTP request object before the // request is decoded. func ServerBefore(before ...RequestFunc) ServerOption { - return func(s *Server) { s.before = before } + return func(s *Server) { s.before = append(s.before, before...) } } // ServeHTTP implements http.Handler. diff --git a/transport/httprp/server_test.go b/transport/httprp/server_test.go index 45fd429..47d6f77 100644 --- a/transport/httprp/server_test.go +++ b/transport/httprp/server_test.go @@ -119,3 +119,45 @@ t.Errorf("want %d or %d, have %d", http.StatusBadGateway, http.StatusInternalServerError, resp.StatusCode) } } + +func TestMultipleServerBefore(t *testing.T) { + const ( + headerKey = "X-TEST-HEADER" + headerVal = "go-kit-proxy" + ) + + originServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if want, have := headerVal, r.Header.Get(headerKey); want != have { + t.Errorf("want %q, have %q", want, have) + } + + w.WriteHeader(http.StatusOK) + w.Write([]byte("hey")) + })) + defer originServer.Close() + originURL, _ := url.Parse(originServer.URL) + + handler := httptransport.NewServer( + context.Background(), + originURL, + httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context { + r.Header.Add(headerKey, headerVal) + return ctx + }), + httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context { + return ctx + }), + ) + proxyServer := httptest.NewServer(handler) + defer proxyServer.Close() + + resp, _ := http.Get(proxyServer.URL) + if want, have := http.StatusOK, resp.StatusCode; want != have { + t.Errorf("want %d, have %d", want, have) + } + + responseBody, _ := ioutil.ReadAll(resp.Body) + if want, have := "hey", string(responseBody); want != have { + t.Errorf("want %q, have %q", want, have) + } +}