From 8d2345093094ff48a1c06a80b8988685edd592b0 Mon Sep 17 00:00:00 2001 From: George Blue Date: Fri, 9 Apr 2021 11:55:21 +0100 Subject: [PATCH] fix: use non-string types for context value keys (#155) The documentation for context.WithValue() says that: > The provided key must be comparable and should not be of type string or any other built-in type to avoid collisions between packages using context. https://golang.org/pkg/context/#WithValue Updated to use derived types for context value keys, which are all exported as constants so that users can import the constant to access the values. CorrelationIDKey and RequestIdentityKey were already exported constants, so users who imported and used these constants should not experience a breaking change. BREAKING CHANGE --- api_test.go | 16 ++++++------ fakes/fake_service_broker.go | 29 ++++++++++++++-------- handlers/bind.go | 2 +- handlers/catalog.go | 4 ++- handlers/deprovision.go | 2 +- handlers/get_binding.go | 2 +- handlers/get_instance.go | 2 +- handlers/last_binding_operation.go | 2 +- handlers/last_operation.go | 2 +- handlers/provision.go | 2 +- handlers/unbind.go | 2 +- handlers/update.go | 2 +- middlewares/context_keys.go | 10 ++++++++ middlewares/correlation_id_header.go | 2 -- middlewares/info_location_header.go | 6 +---- middlewares/originating_identity_header.go | 6 +---- middlewares/request_identity_header.go | 2 -- staticcheck.conf | 2 +- utils/context.go | 5 ++-- utils/context_test.go | 11 +++++--- 20 files changed, 61 insertions(+), 50 deletions(-) create mode 100644 middlewares/context_keys.go diff --git a/api_test.go b/api_test.go index 573084ac..f1c40414 100644 --- a/api_test.go +++ b/api_test.go @@ -140,7 +140,7 @@ var _ = Describe("Service Broker API", func() { } BeforeEach(func() { - ctx = context.WithValue(context.Background(), "test_context", true) + ctx = context.WithValue(context.Background(), fakes.FakeBrokerContextDataKey, true) reqBody = fmt.Sprintf(`{"service_id":"%s","plan_id":"456"}`, fakeServiceBroker.ServiceID) }) @@ -291,7 +291,7 @@ var _ = Describe("Service Broker API", func() { Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called") ctx := fakeServiceBroker.ServicesArgsForCall(0) - Expect(ctx.Value("originatingIdentity")).To(Equal(originatingIdentity)) + Expect(ctx.Value(middlewares.OriginatingIdentityKey)).To(Equal(originatingIdentity)) }) }) @@ -302,7 +302,7 @@ var _ = Describe("Service Broker API", func() { Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called") ctx := fakeServiceBroker.ServicesArgsForCall(0) - Expect(ctx.Value("originatingIdentity")).To(Equal("")) + Expect(ctx.Value(middlewares.OriginatingIdentityKey)).To(Equal("")) }) }) }) @@ -340,7 +340,7 @@ var _ = Describe("Service Broker API", func() { Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called") ctx := fakeServiceBroker.ServicesArgsForCall(0) - Expect(ctx.Value("requestIdentity")).To(Equal(requestIdentity)) + Expect(ctx.Value(middlewares.RequestIdentityKey)).To(Equal(requestIdentity)) header := response.Header.Get("X-Broker-API-Request-Identity") Expect(header).To(Equal(requestIdentity)) @@ -354,7 +354,7 @@ var _ = Describe("Service Broker API", func() { Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called") ctx := fakeServiceBroker.ServicesArgsForCall(0) - Expect(ctx.Value("requestIdentity")).To(Equal("")) + Expect(ctx.Value(middlewares.RequestIdentityKey)).To(Equal("")) header := response.Header.Get("X-Broker-API-Request-Identity") Expect(header).To(Equal("")) @@ -396,7 +396,7 @@ var _ = Describe("Service Broker API", func() { Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called") ctx := fakeServiceBroker.ServicesArgsForCall(0) - Expect(ctx.Value("infoLocation")).To(Equal(infoLocation)) + Expect(ctx.Value(middlewares.InfoLocationKey)).To(Equal(infoLocation)) }) }) @@ -407,7 +407,7 @@ var _ = Describe("Service Broker API", func() { Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called") ctx := fakeServiceBroker.ServicesArgsForCall(0) - Expect(ctx.Value("infoLocation")).To(Equal("")) + Expect(ctx.Value(middlewares.InfoLocationKey)).To(Equal("")) }) }) }) @@ -492,7 +492,7 @@ var _ = Describe("Service Broker API", func() { request.Header.Add("X-Broker-API-Request-Identity", requestIdentity) ctx := context.Background() if fail { - ctx = context.WithValue(ctx, "fails", true) + ctx = context.WithValue(ctx, fakes.FakeBrokerContextFailsKey, true) } request = request.WithContext(ctx) brokerAPI.ServeHTTP(recorder, request) diff --git a/fakes/fake_service_broker.go b/fakes/fake_service_broker.go index 827944c9..23c72296 100644 --- a/fakes/fake_service_broker.go +++ b/fakes/fake_service_broker.go @@ -71,14 +71,21 @@ type FakeAsyncOnlyServiceBroker struct { FakeServiceBroker } +type FakeBrokerContextKeyType string + +const ( + FakeBrokerContextDataKey FakeBrokerContextKeyType = "test_context" + FakeBrokerContextFailsKey FakeBrokerContextKeyType = "fails" +) + func (fakeBroker *FakeServiceBroker) Services(ctx context.Context) ([]brokerapi.Service, error) { fakeBroker.BrokerCalled = true - if val, ok := ctx.Value("test_context").(bool); ok { + if val, ok := ctx.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } - if val, ok := ctx.Value("fails").(bool); ok && val { + if val, ok := ctx.Value(FakeBrokerContextFailsKey).(bool); ok && val { return []brokerapi.Service{}, errors.New("something went wrong!") } @@ -164,7 +171,7 @@ func (fakeBroker *FakeServiceBroker) Services(ctx context.Context) ([]brokerapi. func (fakeBroker *FakeServiceBroker) Provision(context context.Context, instanceID string, details brokerapi.ProvisionDetails, asyncAllowed bool) (brokerapi.ProvisionedServiceSpec, error) { fakeBroker.BrokerCalled = true - if val, ok := context.Value("test_context").(bool); ok { + if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } @@ -241,7 +248,7 @@ func (fakeBroker *FakeAsyncOnlyServiceBroker) Provision(context context.Context, func (fakeBroker *FakeServiceBroker) Update(context context.Context, instanceID string, details brokerapi.UpdateDetails, asyncAllowed bool) (brokerapi.UpdateServiceSpec, error) { fakeBroker.BrokerCalled = true - if val, ok := context.Value("test_context").(bool); ok { + if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } @@ -258,7 +265,7 @@ func (fakeBroker *FakeServiceBroker) Update(context context.Context, instanceID func (fakeBroker *FakeServiceBroker) GetInstance(context context.Context, instanceID string, details domain.FetchInstanceDetails) (brokerapi.GetInstanceDetailsSpec, error) { fakeBroker.BrokerCalled = true - if val, ok := context.Value("test_context").(bool); ok { + if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } @@ -277,7 +284,7 @@ func (fakeBroker *FakeServiceBroker) GetInstance(context context.Context, instan func (fakeBroker *FakeServiceBroker) Deprovision(context context.Context, instanceID string, details brokerapi.DeprovisionDetails, asyncAllowed bool) (brokerapi.DeprovisionServiceSpec, error) { fakeBroker.BrokerCalled = true - if val, ok := context.Value("test_context").(bool); ok { + if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } @@ -335,7 +342,7 @@ func (fakeBroker *FakeAsyncServiceBroker) Deprovision(context context.Context, i func (fakeBroker *FakeServiceBroker) GetBinding(context context.Context, instanceID, bindingID string, details domain.FetchBindingDetails) (brokerapi.GetBindingSpec, error) { fakeBroker.BrokerCalled = true - if val, ok := context.Value("test_context").(bool); ok { + if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } @@ -375,7 +382,7 @@ func (fakeBroker *FakeAsyncServiceBroker) Bind(context context.Context, instance func (fakeBroker *FakeServiceBroker) Bind(context context.Context, instanceID, bindingID string, details brokerapi.BindDetails, asyncAllowed bool) (brokerapi.Binding, error) { fakeBroker.BrokerCalled = true - if val, ok := context.Value("test_context").(bool); ok { + if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } @@ -415,7 +422,7 @@ func (fakeBroker *FakeServiceBroker) Bind(context context.Context, instanceID, b func (fakeBroker *FakeServiceBroker) Unbind(context context.Context, instanceID, bindingID string, details brokerapi.UnbindDetails, asyncAllowed bool) (brokerapi.UnbindSpec, error) { fakeBroker.BrokerCalled = true - if val, ok := context.Value("test_context").(bool); ok { + if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } @@ -437,7 +444,7 @@ func (fakeBroker *FakeServiceBroker) Unbind(context context.Context, instanceID, func (fakeBroker *FakeServiceBroker) LastBindingOperation(context context.Context, instanceID, bindingID string, details brokerapi.PollDetails) (brokerapi.LastOperation, error) { - if val, ok := context.Value("test_context").(bool); ok { + if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } @@ -452,7 +459,7 @@ func (fakeBroker *FakeServiceBroker) LastOperation(context context.Context, inst fakeBroker.LastOperationInstanceID = instanceID fakeBroker.LastOperationData = details.OperationData - if val, ok := context.Value("test_context").(bool); ok { + if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok { fakeBroker.ReceivedContext = val } diff --git a/handlers/bind.go b/handlers/bind.go index 4529dd2e..00186dd9 100644 --- a/handlers/bind.go +++ b/handlers/bind.go @@ -34,7 +34,7 @@ func (h APIHandler) Bind(w http.ResponseWriter, req *http.Request) { asyncAllowed = req.FormValue("accepts_incomplete") == "true" } - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) var details domain.BindDetails if err := json.NewDecoder(req.Body).Decode(&details); err != nil { diff --git a/handlers/catalog.go b/handlers/catalog.go index 40e25cab..5f6c53b2 100644 --- a/handlers/catalog.go +++ b/handlers/catalog.go @@ -4,11 +4,13 @@ import ( "fmt" "net/http" + "github.com/pivotal-cf/brokerapi/v8/middlewares" + "github.com/pivotal-cf/brokerapi/v8/domain/apiresponses" ) func (h *APIHandler) Catalog(w http.ResponseWriter, req *http.Request) { - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) services, err := h.serviceBroker.Services(req.Context()) if err != nil { diff --git a/handlers/deprovision.go b/handlers/deprovision.go index 4f599581..01131e07 100644 --- a/handlers/deprovision.go +++ b/handlers/deprovision.go @@ -28,7 +28,7 @@ func (h APIHandler) Deprovision(w http.ResponseWriter, req *http.Request) { Force: req.FormValue("force") == "true", } - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) if details.ServiceID == "" { h.respond(w, http.StatusBadRequest, requestId, apiresponses.ErrorResponse{ diff --git a/handlers/get_binding.go b/handlers/get_binding.go index 6d833c1f..4206c918 100644 --- a/handlers/get_binding.go +++ b/handlers/get_binding.go @@ -25,7 +25,7 @@ func (h APIHandler) GetBinding(w http.ResponseWriter, req *http.Request) { bindingIDLogKey: bindingID, }, utils.DataForContext(req.Context(), middlewares.CorrelationIDKey)) - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) version := getAPIVersion(req) if version.Minor < 14 { diff --git a/handlers/get_instance.go b/handlers/get_instance.go index d2a362aa..3499ef9b 100644 --- a/handlers/get_instance.go +++ b/handlers/get_instance.go @@ -24,7 +24,7 @@ func (h APIHandler) GetInstance(w http.ResponseWriter, req *http.Request) { instanceIDLogKey: instanceID, }, utils.DataForContext(req.Context(), middlewares.CorrelationIDKey)) - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) version := getAPIVersion(req) if version.Minor < 14 { diff --git a/handlers/last_binding_operation.go b/handlers/last_binding_operation.go index d37e1d3d..ffbae14b 100644 --- a/handlers/last_binding_operation.go +++ b/handlers/last_binding_operation.go @@ -29,7 +29,7 @@ func (h APIHandler) LastBindingOperation(w http.ResponseWriter, req *http.Reques instanceIDLogKey: instanceID, }, utils.DataForContext(req.Context(), middlewares.CorrelationIDKey)) - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) version := getAPIVersion(req) if version.Minor < 14 { diff --git a/handlers/last_operation.go b/handlers/last_operation.go index b9607e18..21f86b3a 100644 --- a/handlers/last_operation.go +++ b/handlers/last_operation.go @@ -29,7 +29,7 @@ func (h APIHandler) LastOperation(w http.ResponseWriter, req *http.Request) { logger.Info("starting-check-for-operation") - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) lastOperation, err := h.serviceBroker.LastOperation(req.Context(), instanceID, pollDetails) if err != nil { diff --git a/handlers/provision.go b/handlers/provision.go index 991d9e22..a041a47a 100644 --- a/handlers/provision.go +++ b/handlers/provision.go @@ -30,7 +30,7 @@ func (h *APIHandler) Provision(w http.ResponseWriter, req *http.Request) { instanceIDLogKey: instanceID, }, utils.DataForContext(req.Context(), middlewares.CorrelationIDKey)) - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) var details domain.ProvisionDetails if err := json.NewDecoder(req.Body).Decode(&details); err != nil { diff --git a/handlers/unbind.go b/handlers/unbind.go index c16c1b82..35bc2d51 100644 --- a/handlers/unbind.go +++ b/handlers/unbind.go @@ -24,7 +24,7 @@ func (h APIHandler) Unbind(w http.ResponseWriter, req *http.Request) { bindingIDLogKey: bindingID, }, utils.DataForContext(req.Context(), middlewares.CorrelationIDKey)) - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) details := domain.UnbindDetails{ PlanID: req.FormValue("plan_id"), diff --git a/handlers/update.go b/handlers/update.go index aab37188..f95b0a3e 100644 --- a/handlers/update.go +++ b/handlers/update.go @@ -24,7 +24,7 @@ func (h APIHandler) Update(w http.ResponseWriter, req *http.Request) { instanceIDLogKey: instanceID, }, utils.DataForContext(req.Context(), middlewares.CorrelationIDKey)) - requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity")) + requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey)) var details domain.UpdateDetails if err := json.NewDecoder(req.Body).Decode(&details); err != nil { diff --git a/middlewares/context_keys.go b/middlewares/context_keys.go new file mode 100644 index 00000000..03e38b2a --- /dev/null +++ b/middlewares/context_keys.go @@ -0,0 +1,10 @@ +package middlewares + +type ContextKey string + +const ( + CorrelationIDKey ContextKey = "correlation-id" + InfoLocationKey ContextKey = "infoLocation" + OriginatingIdentityKey ContextKey = "originatingIdentity" + RequestIdentityKey ContextKey = "requestIdentity" +) diff --git a/middlewares/correlation_id_header.go b/middlewares/correlation_id_header.go index c90ec1ce..a314e7d0 100644 --- a/middlewares/correlation_id_header.go +++ b/middlewares/correlation_id_header.go @@ -7,8 +7,6 @@ import ( "github.com/pborman/uuid" ) -const CorrelationIDKey = "correlation-id" - var correlationIDHeaders = []string{"X-Correlation-ID", "X-CorrelationID", "X-ForRequest-ID", "X-Request-ID", "X-Vcap-Request-Id"} func AddCorrelationIDToContext(next http.Handler) http.Handler { diff --git a/middlewares/info_location_header.go b/middlewares/info_location_header.go index f23acc90..7790f8f2 100644 --- a/middlewares/info_location_header.go +++ b/middlewares/info_location_header.go @@ -20,14 +20,10 @@ import ( "net/http" ) -const ( - infoLocationKey = "infoLocation" -) - func AddInfoLocationToContext(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { infoLocation := req.Header.Get("X-Api-Info-Location") - newCtx := context.WithValue(req.Context(), infoLocationKey, infoLocation) + newCtx := context.WithValue(req.Context(), InfoLocationKey, infoLocation) next.ServeHTTP(w, req.WithContext(newCtx)) }) } diff --git a/middlewares/originating_identity_header.go b/middlewares/originating_identity_header.go index a95fb891..db60c211 100644 --- a/middlewares/originating_identity_header.go +++ b/middlewares/originating_identity_header.go @@ -20,14 +20,10 @@ import ( "net/http" ) -const ( - originatingIdentityKey = "originatingIdentity" -) - func AddOriginatingIdentityToContext(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { originatingIdentity := req.Header.Get("X-Broker-API-Originating-Identity") - newCtx := context.WithValue(req.Context(), originatingIdentityKey, originatingIdentity) + newCtx := context.WithValue(req.Context(), OriginatingIdentityKey, originatingIdentity) next.ServeHTTP(w, req.WithContext(newCtx)) }) } diff --git a/middlewares/request_identity_header.go b/middlewares/request_identity_header.go index 0305a4c7..d11ba6e6 100644 --- a/middlewares/request_identity_header.go +++ b/middlewares/request_identity_header.go @@ -5,8 +5,6 @@ import ( "net/http" ) -const RequestIdentityKey = "requestIdentity" - func AddRequestIdentityToContext(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { requestIdentity := req.Header.Get("X-Broker-API-Request-Identity") diff --git a/staticcheck.conf b/staticcheck.conf index 7feb2a23..ea71844c 100644 --- a/staticcheck.conf +++ b/staticcheck.conf @@ -1,4 +1,4 @@ # When adding staticcheck, we thought it was better to get it working with some checks disabled # rather than fixing all the problems in one go. Some problems cannot be fixed without making # breaking changes. -checks = ["all", "-SA1019", "-ST1000", "-ST1003", "-ST1005", "-ST1012", "-ST1021", "-SA1029", "-ST1020"] +checks = ["all", "-SA1019", "-ST1000", "-ST1003", "-ST1005", "-ST1012", "-ST1021", "-ST1020"] diff --git a/utils/context.go b/utils/context.go index 39997106..480bbd77 100644 --- a/utils/context.go +++ b/utils/context.go @@ -5,6 +5,7 @@ import ( "code.cloudfoundry.org/lager" "github.com/pivotal-cf/brokerapi/v8/domain" + "github.com/pivotal-cf/brokerapi/v8/middlewares" ) type contextKey string @@ -42,11 +43,11 @@ func RetrieveServicePlanFromContext(ctx context.Context) *domain.ServicePlan { return nil } -func DataForContext(context context.Context, dataKeys ...string) lager.Data { +func DataForContext(context context.Context, dataKeys ...middlewares.ContextKey) lager.Data { data := lager.Data{} for _, key := range dataKeys { if value := context.Value(key); value != nil { - data[key] = value + data[string(key)] = value } } diff --git a/utils/context_test.go b/utils/context_test.go index ca63263c..d6c802ae 100644 --- a/utils/context_test.go +++ b/utils/context_test.go @@ -6,6 +6,7 @@ import ( . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/pivotal-cf/brokerapi/v8/domain" + "github.com/pivotal-cf/brokerapi/v8/middlewares" "github.com/pivotal-cf/brokerapi/v8/utils" ) @@ -75,7 +76,7 @@ var _ = Describe("Context", func() { }) Describe("Log data for context", func() { - const testKey = "test-key" + const testKey middlewares.ContextKey = "test-key" Context("the provided key is present in the context", func() { It("returns data containing the key", func() { @@ -83,10 +84,11 @@ var _ = Describe("Context", func() { ctx = context.WithValue(ctx, testKey, expectedValue) data := utils.DataForContext(ctx, testKey) - value, ok := data[testKey] + value, ok := data[string(testKey)] Expect(ok).To(BeTrue()) Expect(value).Should(Equal(expectedValue)) }) + Context("the key value is a struct", func() { It("returns data containing the key", func() { type testType struct{} @@ -94,16 +96,17 @@ var _ = Describe("Context", func() { ctx = context.WithValue(ctx, testKey, expectedValue) data := utils.DataForContext(ctx, testKey) - value, ok := data[testKey] + value, ok := data[string(testKey)] Expect(ok).To(BeTrue()) Expect(value).Should(Equal(expectedValue)) }) }) }) + Context("the provided key is not in the context", func() { It("returns data without the key", func() { data := utils.DataForContext(ctx, testKey) - _, ok := data[testKey] + _, ok := data[string(testKey)] Expect(ok).To(BeFalse()) }) })