diff --git a/api_test.go b/api_test.go index 3c126366..92bef5c5 100644 --- a/api_test.go +++ b/api_test.go @@ -26,6 +26,8 @@ import ( "net/url" "strings" + "github.com/onsi/ginkgo/extensions/table" + "github.com/pivotal-cf/brokerapi/middlewares" "code.cloudfoundry.org/lager" @@ -111,13 +113,6 @@ var _ = Describe("Service Broker API", func() { return recorder } - It("has a X-Correlation-ID header", func() { - response := makeRequest() - - header := response.Header().Get("X-Correlation-ID") - Expect(header).Should(Not(BeNil())) - }) - It("has a Content-Type header", func() { response := makeRequest() @@ -2337,6 +2332,11 @@ var _ = Describe("Service Broker API", func() { }) Describe("CorrelationIDHeader", func() { + const correlationID = "fake-correlation-id" + + type testCase struct { + correlationIDHeaderName string + } var ( fakeServiceBroker *fakes.AutoFakeServiceBroker @@ -2360,20 +2360,33 @@ var _ = Describe("Service Broker API", func() { testServer.Close() }) - When("X-Correlation-ID is passed", func() { - It("Adds correlation id to the context", func() { - const correlationID = "fake-correlation-id" - req.Header.Add("X-Correlation-ID", correlationID) + table.DescribeTable("Adds correlation id to the context", func(tc testCase) { + req.Header.Add(tc.correlationIDHeaderName, correlationID) - _, err := http.DefaultClient.Do(req) - Expect(err).NotTo(HaveOccurred()) + _, err := http.DefaultClient.Do(req) + Expect(err).NotTo(HaveOccurred()) - Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called") - ctx := fakeServiceBroker.ServicesArgsForCall(0) - Expect(ctx.Value(middlewares.CorrelationIDKey)).To(Equal(correlationID)) + Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called") + ctx := fakeServiceBroker.ServicesArgsForCall(0) + Expect(ctx.Value(middlewares.CorrelationIDKey)).To(Equal(correlationID)) + }, + table.Entry("X-Correlation-ID", testCase{ + correlationIDHeaderName: "X-Correlation-ID", + }), + table.Entry("X-CorrelationID", testCase{ + correlationIDHeaderName: "X-CorrelationID", + }), + table.Entry("X-ForRequest-ID", testCase{ + correlationIDHeaderName: "X-ForRequest-ID", + }), + table.Entry("X-Request-ID", testCase{ + correlationIDHeaderName: "X-Request-ID", + }), + table.Entry("X-Vcap-Request-Id", testCase{ + correlationIDHeaderName: "X-Vcap-Request-Id", + }), + ) - }) - }) When("X-Correlation-ID is not passed", func() { It("Generates correlation id and adds it to the context", func() { _, err := http.DefaultClient.Do(req) diff --git a/middlewares/correlation_id_header.go b/middlewares/correlation_id_header.go index c6374f6f..4094732e 100644 --- a/middlewares/correlation_id_header.go +++ b/middlewares/correlation_id_header.go @@ -13,14 +13,13 @@ var correlationIDHeaders = []string{"X-Correlation-ID", "X-CorrelationID", "X-Fo func AddCorrelationIDToContext(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - var correlationID, headerName string + var correlationID string var found bool for _, header := range correlationIDHeaders { headerValue := req.Header.Get(header) if headerValue != "" { correlationID = headerValue - headerName = header found = true break } @@ -28,10 +27,8 @@ func AddCorrelationIDToContext(next http.Handler) http.Handler { if !found { correlationID = generateCorrelationID() - headerName = correlationIDHeaders[0] } - w.Header().Set(headerName, correlationID) newCtx := context.WithValue(req.Context(), CorrelationIDKey, correlationID) next.ServeHTTP(w, req.WithContext(newCtx)) })