diff --git a/api_test.go b/api_test.go index 573084ac..f1c40414 100644 --- a/api_test.go +++ b/api_test.go @@ -140,7 +140,7 @@ var _ = Describe("Service Broker API", func() { } BeforeEach(func() { - ctx = context.WithValue(context.Background(), "test_context", true) + ctx = context.WithValue(context.Background(), fakes.FakeBrokerContextDataKey, true) reqBody = fmt.Sprintf(`{"service_id":"%s","plan_id":"456"}`, fakeServiceBroker.ServiceID) }) @@ -291,7 +291,7 @@ var _ = Describe("Service Broker API", func() { Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called") ctx := fakeServiceBroker.ServicesArgsForCall(0) - Expect(ctx.Value("originatingIdentity")).To(Equal(originatingIdentity)) + Expect(ctx.Value(middlewares.OriginatingIdentityKey)).To(Equal(originatingIdentity)) }) }) @@ -302,7 +302,7 @@ var _ = Describe("Service Broker API", func() { Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called") ctx := fakeServiceBroker.ServicesArgsForCall(0) - Expect(ctx.Value("originatingIdentity")).To(Equal("")) + Expect(ctx.Value(middlewares.OriginatingIdentityKey)).To(Equal("")) }) }) }) @@ -340,7 +340,7 @@ var _ = Describe("Service Broker API", func() { Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called") ctx := fakeServiceBroker.ServicesArgsForCall(0) - Expect(ctx.Value("requestIdentity")).To(Equal(requestIdentity)) + Expect(ctx.Value(middlewares.RequestIdentityKey)).To(Equal(requestIdentity)) header := response.Header.Get("X-Broker-API-Request-Identity") Expect(header).To(Equal(requestIdentity)) @@ -354,7 +354,7 @@ var _ = Describe("Service Broker API", func() { Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called") ctx := fakeServiceBroker.ServicesArgsForCall(0) - Expect(ctx.Value("requestIdentity")).To(Equal("")) + Expect(ctx.Value(middlewares.RequestIdentityKey)).To(Equal("")) header := response.Header.Get("X-Broker-API-Request-Identity") Expect(header).To(Equal("")) @@ -396,7 +396,7 @@ var _ = Describe("Service Broker API", func() { Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called") ctx := fakeServiceBroker.ServicesArgsForCall(0) - Expect(ctx.Value("infoLocation")).To(Equal(infoLocation)) + Expect(ctx.Value(middlewares.InfoLocationKey)).To(Equal(infoLocation)) }) }) @@ -407,7 +407,7 @@ var _ = Describe("Service Broker API", func() { Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called") ctx := fakeServiceBroker.ServicesArgsForCall(0) - Expect(ctx.Value("infoLocation")).To(Equal("")) + Expect(ctx.Value(middlewares.InfoLocationKey)).To(Equal("")) }) }) }) @@ -492,7 +492,7 @@ var _ = Describe("Service Broker API", func() { request.Header.Add("X-Broker-API-Request-Identity", requestIdentity) ctx := context.Background() if fail { - ctx = context.WithValue(ctx, "fails", true) + ctx = context.WithValue(ctx, fakes.FakeBrokerContextFailsKey, true) } request = request.WithContext(ctx) brokerAPI.ServeHTTP(recorder, request) diff --git a/fakes/fake_service_broker.go b/fakes/fake_service_broker.go index 827944c9..23c72296 100644 --- a/fakes/fake_service_broker.go +++ b/fakes/fake_service_broker.go @@ -71,14 +71,21 @@ type FakeAsyncOnlyServiceBroker struct { FakeServiceBroker } +type FakeBrokerContextKeyType string + +const ( + FakeBrokerContextDataKey FakeBrokerContextKeyType = "test_context" + FakeBrokerContextFailsKey FakeBrokerContextKeyType = "fails" +) + func (fakeBroker *FakeServiceBroker) Services(ctx context.Context) ([]brokerapi.Service, error) { fakeBroker.BrokerCalled = true - if val, ok := ctx.Value("test_context").(bool); ok { + if val, ok := ctx.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } - if val, ok := ctx.Value("fails").(bool); ok && val { + if val, ok := ctx.Value(FakeBrokerContextFailsKey).(bool); ok && val { return []brokerapi.Service{}, errors.New("something went wrong!") } @@ -164,7 +171,7 @@ func (fakeBroker *FakeServiceBroker) Services(ctx context.Context) ([]brokerapi. func (fakeBroker *FakeServiceBroker) Provision(context context.Context, instanceID string, details brokerapi.ProvisionDetails, asyncAllowed bool) (brokerapi.ProvisionedServiceSpec, error) { fakeBroker.BrokerCalled = true - if val, ok := context.Value("test_context").(bool); ok { + if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } @@ -241,7 +248,7 @@ func (fakeBroker *FakeAsyncOnlyServiceBroker) Provision(context context.Context, func (fakeBroker *FakeServiceBroker) Update(context context.Context, instanceID string, details brokerapi.UpdateDetails, asyncAllowed bool) (brokerapi.UpdateServiceSpec, error) { fakeBroker.BrokerCalled = true - if val, ok := context.Value("test_context").(bool); ok { + if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } @@ -258,7 +265,7 @@ func (fakeBroker *FakeServiceBroker) Update(context context.Context, instanceID func (fakeBroker *FakeServiceBroker) GetInstance(context context.Context, instanceID string, details domain.FetchInstanceDetails) (brokerapi.GetInstanceDetailsSpec, error) { fakeBroker.BrokerCalled = true - if val, ok := context.Value("test_context").(bool); ok { + if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } @@ -277,7 +284,7 @@ func (fakeBroker *FakeServiceBroker) GetInstance(context context.Context, instan func (fakeBroker *FakeServiceBroker) Deprovision(context context.Context, instanceID string, details brokerapi.DeprovisionDetails, asyncAllowed bool) (brokerapi.DeprovisionServiceSpec, error) { fakeBroker.BrokerCalled = true - if val, ok := context.Value("test_context").(bool); ok { + if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } @@ -335,7 +342,7 @@ func (fakeBroker *FakeAsyncServiceBroker) Deprovision(context context.Context, i func (fakeBroker *FakeServiceBroker) GetBinding(context context.Context, instanceID, bindingID string, details domain.FetchBindingDetails) (brokerapi.GetBindingSpec, error) { fakeBroker.BrokerCalled = true - if val, ok := context.Value("test_context").(bool); ok { + if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } @@ -375,7 +382,7 @@ func (fakeBroker *FakeAsyncServiceBroker) Bind(context context.Context, instance func (fakeBroker *FakeServiceBroker) Bind(context context.Context, instanceID, bindingID string, details brokerapi.BindDetails, asyncAllowed bool) (brokerapi.Binding, error) { fakeBroker.BrokerCalled = true - if val, ok := context.Value("test_context").(bool); ok { + if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } @@ -415,7 +422,7 @@ func (fakeBroker *FakeServiceBroker) Bind(context context.Context, instanceID, b func (fakeBroker *FakeServiceBroker) Unbind(context context.Context, instanceID, bindingID string, details brokerapi.UnbindDetails, asyncAllowed bool) (brokerapi.UnbindSpec, error) { fakeBroker.BrokerCalled = true - if val, ok := context.Value("test_context").(bool); ok { + if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } @@ -437,7 +444,7 @@ func (fakeBroker *FakeServiceBroker) Unbind(context context.Context, instanceID, func (fakeBroker *FakeServiceBroker) LastBindingOperation(context context.Context, instanceID, bindingID string, details brokerapi.PollDetails) (brokerapi.LastOperation, error) { - if val, ok := context.Value("test_context").(bool); ok { + if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } @@ -452,7 +459,7 @@ func (fakeBroker *FakeServiceBroker) LastOperation(context context.Context, inst fakeBroker.LastOperationInstanceID = instanceID fakeBroker.LastOperationData = details.OperationData - if val, ok := context.Value("test_context").(bool); ok { + if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } diff --git a/handlers/bind.go b/handlers/bind.go index 4529dd2e..00186dd9 100644 --- a/handlers/bind.go +++ b/handlers/bind.go @@ -34,7 +34,7 @@ func (h APIHandler) Bind(w http.ResponseWriter, req *http.Request) { asyncAllowed = req.FormValue("accepts_incomplete") == "true" } - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) var details domain.BindDetails if err := json.NewDecoder(req.Body).Decode(&details); err != nil { diff --git a/handlers/catalog.go b/handlers/catalog.go index 40e25cab..5f6c53b2 100644 --- a/handlers/catalog.go +++ b/handlers/catalog.go @@ -4,11 +4,13 @@ import ( "fmt" "net/http" + "github.com/pivotal-cf/brokerapi/v8/middlewares" + "github.com/pivotal-cf/brokerapi/v8/domain/apiresponses" ) func (h *APIHandler) Catalog(w http.ResponseWriter, req *http.Request) { - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) services, err := h.serviceBroker.Services(req.Context()) if err != nil { diff --git a/handlers/deprovision.go b/handlers/deprovision.go index 4f599581..01131e07 100644 --- a/handlers/deprovision.go +++ b/handlers/deprovision.go @@ -28,7 +28,7 @@ func (h APIHandler) Deprovision(w http.ResponseWriter, req *http.Request) { Force: req.FormValue("force") == "true", } - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) if details.ServiceID == "" { h.respond(w, http.StatusBadRequest, requestId, apiresponses.ErrorResponse{ diff --git a/handlers/get_binding.go b/handlers/get_binding.go index 6d833c1f..4206c918 100644 --- a/handlers/get_binding.go +++ b/handlers/get_binding.go @@ -25,7 +25,7 @@ func (h APIHandler) GetBinding(w http.ResponseWriter, req *http.Request) { bindingIDLogKey: bindingID, }, utils.DataForContext(req.Context(), middlewares.CorrelationIDKey)) - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) version := getAPIVersion(req) if version.Minor < 14 { diff --git a/handlers/get_instance.go b/handlers/get_instance.go index d2a362aa..3499ef9b 100644 --- a/handlers/get_instance.go +++ b/handlers/get_instance.go @@ -24,7 +24,7 @@ func (h APIHandler) GetInstance(w http.ResponseWriter, req *http.Request) { instanceIDLogKey: instanceID, }, utils.DataForContext(req.Context(), middlewares.CorrelationIDKey)) - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) version := getAPIVersion(req) if version.Minor < 14 { diff --git a/handlers/last_binding_operation.go b/handlers/last_binding_operation.go index d37e1d3d..ffbae14b 100644 --- a/handlers/last_binding_operation.go +++ b/handlers/last_binding_operation.go @@ -29,7 +29,7 @@ func (h APIHandler) LastBindingOperation(w http.ResponseWriter, req *http.Reques instanceIDLogKey: instanceID, }, utils.DataForContext(req.Context(), middlewares.CorrelationIDKey)) - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) version := getAPIVersion(req) if version.Minor < 14 { diff --git a/handlers/last_operation.go b/handlers/last_operation.go index b9607e18..21f86b3a 100644 --- a/handlers/last_operation.go +++ b/handlers/last_operation.go @@ -29,7 +29,7 @@ func (h APIHandler) LastOperation(w http.ResponseWriter, req *http.Request) { logger.Info("starting-check-for-operation") - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) lastOperation, err := h.serviceBroker.LastOperation(req.Context(), instanceID, pollDetails) if err != nil { diff --git a/handlers/provision.go b/handlers/provision.go index 991d9e22..a041a47a 100644 --- a/handlers/provision.go +++ b/handlers/provision.go @@ -30,7 +30,7 @@ func (h *APIHandler) Provision(w http.ResponseWriter, req *http.Request) { instanceIDLogKey: instanceID, }, utils.DataForContext(req.Context(), middlewares.CorrelationIDKey)) - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) var details domain.ProvisionDetails if err := json.NewDecoder(req.Body).Decode(&details); err != nil { diff --git a/handlers/unbind.go b/handlers/unbind.go index c16c1b82..35bc2d51 100644 --- a/handlers/unbind.go +++ b/handlers/unbind.go @@ -24,7 +24,7 @@ func (h APIHandler) Unbind(w http.ResponseWriter, req *http.Request) { bindingIDLogKey: bindingID, }, utils.DataForContext(req.Context(), middlewares.CorrelationIDKey)) - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) details := domain.UnbindDetails{ PlanID: req.FormValue("plan_id"), diff --git a/handlers/update.go b/handlers/update.go index aab37188..f95b0a3e 100644 --- a/handlers/update.go +++ b/handlers/update.go @@ -24,7 +24,7 @@ func (h APIHandler) Update(w http.ResponseWriter, req *http.Request) { instanceIDLogKey: instanceID, }, utils.DataForContext(req.Context(), middlewares.CorrelationIDKey)) - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) var details domain.UpdateDetails if err := json.NewDecoder(req.Body).Decode(&details); err != nil { diff --git a/middlewares/context_keys.go b/middlewares/context_keys.go new file mode 100644 index 00000000..03e38b2a --- /dev/null +++ b/middlewares/context_keys.go @@ -0,0 +1,10 @@ +package middlewares + +type ContextKey string + +const ( + CorrelationIDKey ContextKey = "correlation-id" + InfoLocationKey ContextKey = "infoLocation" + OriginatingIdentityKey ContextKey = "originatingIdentity" + RequestIdentityKey ContextKey = "requestIdentity" +) diff --git a/middlewares/correlation_id_header.go b/middlewares/correlation_id_header.go index c90ec1ce..a314e7d0 100644 --- a/middlewares/correlation_id_header.go +++ b/middlewares/correlation_id_header.go @@ -7,8 +7,6 @@ import ( "github.com/pborman/uuid" ) -const CorrelationIDKey = "correlation-id" - var correlationIDHeaders = []string{"X-Correlation-ID", "X-CorrelationID", "X-ForRequest-ID", "X-Request-ID", "X-Vcap-Request-Id"} func AddCorrelationIDToContext(next http.Handler) http.Handler { diff --git a/middlewares/info_location_header.go b/middlewares/info_location_header.go index f23acc90..7790f8f2 100644 --- a/middlewares/info_location_header.go +++ b/middlewares/info_location_header.go @@ -20,14 +20,10 @@ import ( "net/http" ) -const ( - infoLocationKey = "infoLocation" -) - func AddInfoLocationToContext(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { infoLocation := req.Header.Get("X-Api-Info-Location") - newCtx := context.WithValue(req.Context(), infoLocationKey, infoLocation) + newCtx := context.WithValue(req.Context(), InfoLocationKey, infoLocation) next.ServeHTTP(w, req.WithContext(newCtx)) }) } diff --git a/middlewares/originating_identity_header.go b/middlewares/originating_identity_header.go index a95fb891..db60c211 100644 --- a/middlewares/originating_identity_header.go +++ b/middlewares/originating_identity_header.go @@ -20,14 +20,10 @@ import ( "net/http" ) -const ( - originatingIdentityKey = "originatingIdentity" -) - func AddOriginatingIdentityToContext(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { originatingIdentity := req.Header.Get("X-Broker-API-Originating-Identity") - newCtx := context.WithValue(req.Context(), originatingIdentityKey, originatingIdentity) + newCtx := context.WithValue(req.Context(), OriginatingIdentityKey, originatingIdentity) next.ServeHTTP(w, req.WithContext(newCtx)) }) } diff --git a/middlewares/request_identity_header.go b/middlewares/request_identity_header.go index 0305a4c7..d11ba6e6 100644 --- a/middlewares/request_identity_header.go +++ b/middlewares/request_identity_header.go @@ -5,8 +5,6 @@ import ( "net/http" ) -const RequestIdentityKey = "requestIdentity" - func AddRequestIdentityToContext(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { requestIdentity := req.Header.Get("X-Broker-API-Request-Identity") diff --git a/staticcheck.conf b/staticcheck.conf index 7feb2a23..ea71844c 100644 --- a/staticcheck.conf +++ b/staticcheck.conf @@ -1,4 +1,4 @@ # When adding staticcheck, we thought it was better to get it working with some checks disabled # rather than fixing all the problems in one go. Some problems cannot be fixed without making # breaking changes. -checks = ["all", "-SA1019", "-ST1000", "-ST1003", "-ST1005", "-ST1012", "-ST1021", "-SA1029", "-ST1020"] +checks = ["all", "-SA1019", "-ST1000", "-ST1003", "-ST1005", "-ST1012", "-ST1021", "-ST1020"] diff --git a/utils/context.go b/utils/context.go index 39997106..480bbd77 100644 --- a/utils/context.go +++ b/utils/context.go @@ -5,6 +5,7 @@ import ( "code.cloudfoundry.org/lager" "github.com/pivotal-cf/brokerapi/v8/domain" + "github.com/pivotal-cf/brokerapi/v8/middlewares" ) type contextKey string @@ -42,11 +43,11 @@ func RetrieveServicePlanFromContext(ctx context.Context) *domain.ServicePlan { return nil } -func DataForContext(context context.Context, dataKeys ...string) lager.Data { +func DataForContext(context context.Context, dataKeys ...middlewares.ContextKey) lager.Data { data := lager.Data{} for _, key := range dataKeys { if value := context.Value(key); value != nil { - data[key] = value + data[string(key)] = value } } diff --git a/utils/context_test.go b/utils/context_test.go index ca63263c..d6c802ae 100644 --- a/utils/context_test.go +++ b/utils/context_test.go @@ -6,6 +6,7 @@ import ( . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/pivotal-cf/brokerapi/v8/domain" + "github.com/pivotal-cf/brokerapi/v8/middlewares" "github.com/pivotal-cf/brokerapi/v8/utils" ) @@ -75,7 +76,7 @@ var _ = Describe("Context", func() { }) Describe("Log data for context", func() { - const testKey = "test-key" + const testKey middlewares.ContextKey = "test-key" Context("the provided key is present in the context", func() { It("returns data containing the key", func() { @@ -83,10 +84,11 @@ var _ = Describe("Context", func() { ctx = context.WithValue(ctx, testKey, expectedValue) data := utils.DataForContext(ctx, testKey) - value, ok := data[testKey] + value, ok := data[string(testKey)] Expect(ok).To(BeTrue()) Expect(value).Should(Equal(expectedValue)) }) + Context("the key value is a struct", func() { It("returns data containing the key", func() { type testType struct{} @@ -94,16 +96,17 @@ var _ = Describe("Context", func() { ctx = context.WithValue(ctx, testKey, expectedValue) data := utils.DataForContext(ctx, testKey) - value, ok := data[testKey] + value, ok := data[string(testKey)] Expect(ok).To(BeTrue()) Expect(value).Should(Equal(expectedValue)) }) }) }) + Context("the provided key is not in the context", func() { It("returns data without the key", func() { data := utils.DataForContext(ctx, testKey) - _, ok := data[testKey] + _, ok := data[string(testKey)] Expect(ok).To(BeFalse()) }) })