Skip to content

Commit

Permalink
fix: use non-string types for context value keys (#155)
Browse files Browse the repository at this point in the history
The documentation for context.WithValue() says that:
> The provided key must be comparable and should not be of type string
or any other built-in type to avoid collisions between packages using
context.

https://golang.org/pkg/context/#WithValue

Updated to use derived types for context value keys, which are all
exported as constants so that users can import the constant to access
the values.

CorrelationIDKey and RequestIdentityKey were already exported constants, so users
who imported and used these constants should not experience a breaking change.

BREAKING CHANGE
  • Loading branch information
blgm authored Apr 9, 2021
1 parent 3d753e9 commit 8d23450
Show file tree
Hide file tree
Showing 20 changed files with 61 additions and 50 deletions.
16 changes: 8 additions & 8 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ var _ = Describe("Service Broker API", func() {
}

BeforeEach(func() {
ctx = context.WithValue(context.Background(), "test_context", true)
ctx = context.WithValue(context.Background(), fakes.FakeBrokerContextDataKey, true)
reqBody = fmt.Sprintf(`{"service_id":"%s","plan_id":"456"}`, fakeServiceBroker.ServiceID)
})

Expand Down Expand Up @@ -291,7 +291,7 @@ var _ = Describe("Service Broker API", func() {

Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called")
ctx := fakeServiceBroker.ServicesArgsForCall(0)
Expect(ctx.Value("originatingIdentity")).To(Equal(originatingIdentity))
Expect(ctx.Value(middlewares.OriginatingIdentityKey)).To(Equal(originatingIdentity))
})
})

Expand All @@ -302,7 +302,7 @@ var _ = Describe("Service Broker API", func() {

Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called")
ctx := fakeServiceBroker.ServicesArgsForCall(0)
Expect(ctx.Value("originatingIdentity")).To(Equal(""))
Expect(ctx.Value(middlewares.OriginatingIdentityKey)).To(Equal(""))
})
})
})
Expand Down Expand Up @@ -340,7 +340,7 @@ var _ = Describe("Service Broker API", func() {

Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called")
ctx := fakeServiceBroker.ServicesArgsForCall(0)
Expect(ctx.Value("requestIdentity")).To(Equal(requestIdentity))
Expect(ctx.Value(middlewares.RequestIdentityKey)).To(Equal(requestIdentity))

header := response.Header.Get("X-Broker-API-Request-Identity")
Expect(header).To(Equal(requestIdentity))
Expand All @@ -354,7 +354,7 @@ var _ = Describe("Service Broker API", func() {

Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called")
ctx := fakeServiceBroker.ServicesArgsForCall(0)
Expect(ctx.Value("requestIdentity")).To(Equal(""))
Expect(ctx.Value(middlewares.RequestIdentityKey)).To(Equal(""))

header := response.Header.Get("X-Broker-API-Request-Identity")
Expect(header).To(Equal(""))
Expand Down Expand Up @@ -396,7 +396,7 @@ var _ = Describe("Service Broker API", func() {

Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called")
ctx := fakeServiceBroker.ServicesArgsForCall(0)
Expect(ctx.Value("infoLocation")).To(Equal(infoLocation))
Expect(ctx.Value(middlewares.InfoLocationKey)).To(Equal(infoLocation))

})
})
Expand All @@ -407,7 +407,7 @@ var _ = Describe("Service Broker API", func() {

Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called")
ctx := fakeServiceBroker.ServicesArgsForCall(0)
Expect(ctx.Value("infoLocation")).To(Equal(""))
Expect(ctx.Value(middlewares.InfoLocationKey)).To(Equal(""))
})
})
})
Expand Down Expand Up @@ -492,7 +492,7 @@ var _ = Describe("Service Broker API", func() {
request.Header.Add("X-Broker-API-Request-Identity", requestIdentity)
ctx := context.Background()
if fail {
ctx = context.WithValue(ctx, "fails", true)
ctx = context.WithValue(ctx, fakes.FakeBrokerContextFailsKey, true)
}
request = request.WithContext(ctx)
brokerAPI.ServeHTTP(recorder, request)
Expand Down
29 changes: 18 additions & 11 deletions fakes/fake_service_broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,21 @@ type FakeAsyncOnlyServiceBroker struct {
FakeServiceBroker
}

type FakeBrokerContextKeyType string

const (
FakeBrokerContextDataKey FakeBrokerContextKeyType = "test_context"
FakeBrokerContextFailsKey FakeBrokerContextKeyType = "fails"
)

func (fakeBroker *FakeServiceBroker) Services(ctx context.Context) ([]brokerapi.Service, error) {
fakeBroker.BrokerCalled = true

if val, ok := ctx.Value("test_context").(bool); ok {
if val, ok := ctx.Value(FakeBrokerContextDataKey).(bool); ok {
fakeBroker.ReceivedContext = val
}

if val, ok := ctx.Value("fails").(bool); ok && val {
if val, ok := ctx.Value(FakeBrokerContextFailsKey).(bool); ok && val {
return []brokerapi.Service{}, errors.New("something went wrong!")
}

Expand Down Expand Up @@ -164,7 +171,7 @@ func (fakeBroker *FakeServiceBroker) Services(ctx context.Context) ([]brokerapi.
func (fakeBroker *FakeServiceBroker) Provision(context context.Context, instanceID string, details brokerapi.ProvisionDetails, asyncAllowed bool) (brokerapi.ProvisionedServiceSpec, error) {
fakeBroker.BrokerCalled = true

if val, ok := context.Value("test_context").(bool); ok {
if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok {
fakeBroker.ReceivedContext = val
}

Expand Down Expand Up @@ -241,7 +248,7 @@ func (fakeBroker *FakeAsyncOnlyServiceBroker) Provision(context context.Context,
func (fakeBroker *FakeServiceBroker) Update(context context.Context, instanceID string, details brokerapi.UpdateDetails, asyncAllowed bool) (brokerapi.UpdateServiceSpec, error) {
fakeBroker.BrokerCalled = true

if val, ok := context.Value("test_context").(bool); ok {
if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok {
fakeBroker.ReceivedContext = val
}

Expand All @@ -258,7 +265,7 @@ func (fakeBroker *FakeServiceBroker) Update(context context.Context, instanceID
func (fakeBroker *FakeServiceBroker) GetInstance(context context.Context, instanceID string, details domain.FetchInstanceDetails) (brokerapi.GetInstanceDetailsSpec, error) {
fakeBroker.BrokerCalled = true

if val, ok := context.Value("test_context").(bool); ok {
if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok {
fakeBroker.ReceivedContext = val
}

Expand All @@ -277,7 +284,7 @@ func (fakeBroker *FakeServiceBroker) GetInstance(context context.Context, instan
func (fakeBroker *FakeServiceBroker) Deprovision(context context.Context, instanceID string, details brokerapi.DeprovisionDetails, asyncAllowed bool) (brokerapi.DeprovisionServiceSpec, error) {
fakeBroker.BrokerCalled = true

if val, ok := context.Value("test_context").(bool); ok {
if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok {
fakeBroker.ReceivedContext = val
}

Expand Down Expand Up @@ -335,7 +342,7 @@ func (fakeBroker *FakeAsyncServiceBroker) Deprovision(context context.Context, i
func (fakeBroker *FakeServiceBroker) GetBinding(context context.Context, instanceID, bindingID string, details domain.FetchBindingDetails) (brokerapi.GetBindingSpec, error) {
fakeBroker.BrokerCalled = true

if val, ok := context.Value("test_context").(bool); ok {
if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok {
fakeBroker.ReceivedContext = val
}

Expand Down Expand Up @@ -375,7 +382,7 @@ func (fakeBroker *FakeAsyncServiceBroker) Bind(context context.Context, instance
func (fakeBroker *FakeServiceBroker) Bind(context context.Context, instanceID, bindingID string, details brokerapi.BindDetails, asyncAllowed bool) (brokerapi.Binding, error) {
fakeBroker.BrokerCalled = true

if val, ok := context.Value("test_context").(bool); ok {
if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok {
fakeBroker.ReceivedContext = val
}

Expand Down Expand Up @@ -415,7 +422,7 @@ func (fakeBroker *FakeServiceBroker) Bind(context context.Context, instanceID, b
func (fakeBroker *FakeServiceBroker) Unbind(context context.Context, instanceID, bindingID string, details brokerapi.UnbindDetails, asyncAllowed bool) (brokerapi.UnbindSpec, error) {
fakeBroker.BrokerCalled = true

if val, ok := context.Value("test_context").(bool); ok {
if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok {
fakeBroker.ReceivedContext = val
}

Expand All @@ -437,7 +444,7 @@ func (fakeBroker *FakeServiceBroker) Unbind(context context.Context, instanceID,

func (fakeBroker *FakeServiceBroker) LastBindingOperation(context context.Context, instanceID, bindingID string, details brokerapi.PollDetails) (brokerapi.LastOperation, error) {

if val, ok := context.Value("test_context").(bool); ok {
if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok {
fakeBroker.ReceivedContext = val
}

Expand All @@ -452,7 +459,7 @@ func (fakeBroker *FakeServiceBroker) LastOperation(context context.Context, inst
fakeBroker.LastOperationInstanceID = instanceID
fakeBroker.LastOperationData = details.OperationData

if val, ok := context.Value("test_context").(bool); ok {
if val, ok := context.Value(FakeBrokerContextDataKey).(bool); ok {
fakeBroker.ReceivedContext = val
}

Expand Down
2 changes: 1 addition & 1 deletion handlers/bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (h APIHandler) Bind(w http.ResponseWriter, req *http.Request) {
asyncAllowed = req.FormValue("accepts_incomplete") == "true"
}

requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity"))
requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey))

var details domain.BindDetails
if err := json.NewDecoder(req.Body).Decode(&details); err != nil {
Expand Down
4 changes: 3 additions & 1 deletion handlers/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"fmt"
"net/http"

"github.com/pivotal-cf/brokerapi/v8/middlewares"

"github.com/pivotal-cf/brokerapi/v8/domain/apiresponses"
)

func (h *APIHandler) Catalog(w http.ResponseWriter, req *http.Request) {
requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity"))
requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey))

services, err := h.serviceBroker.Services(req.Context())
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion handlers/deprovision.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (h APIHandler) Deprovision(w http.ResponseWriter, req *http.Request) {
Force: req.FormValue("force") == "true",
}

requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity"))
requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey))

if details.ServiceID == "" {
h.respond(w, http.StatusBadRequest, requestId, apiresponses.ErrorResponse{
Expand Down
2 changes: 1 addition & 1 deletion handlers/get_binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (h APIHandler) GetBinding(w http.ResponseWriter, req *http.Request) {
bindingIDLogKey: bindingID,
}, utils.DataForContext(req.Context(), middlewares.CorrelationIDKey))

requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity"))
requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey))

version := getAPIVersion(req)
if version.Minor < 14 {
Expand Down
2 changes: 1 addition & 1 deletion handlers/get_instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func (h APIHandler) GetInstance(w http.ResponseWriter, req *http.Request) {
instanceIDLogKey: instanceID,
}, utils.DataForContext(req.Context(), middlewares.CorrelationIDKey))

requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity"))
requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey))

version := getAPIVersion(req)
if version.Minor < 14 {
Expand Down
2 changes: 1 addition & 1 deletion handlers/last_binding_operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (h APIHandler) LastBindingOperation(w http.ResponseWriter, req *http.Reques
instanceIDLogKey: instanceID,
}, utils.DataForContext(req.Context(), middlewares.CorrelationIDKey))

requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity"))
requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey))

version := getAPIVersion(req)
if version.Minor < 14 {
Expand Down
2 changes: 1 addition & 1 deletion handlers/last_operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (h APIHandler) LastOperation(w http.ResponseWriter, req *http.Request) {

logger.Info("starting-check-for-operation")

requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity"))
requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey))

lastOperation, err := h.serviceBroker.LastOperation(req.Context(), instanceID, pollDetails)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion handlers/provision.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (h *APIHandler) Provision(w http.ResponseWriter, req *http.Request) {
instanceIDLogKey: instanceID,
}, utils.DataForContext(req.Context(), middlewares.CorrelationIDKey))

requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity"))
requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey))

var details domain.ProvisionDetails
if err := json.NewDecoder(req.Body).Decode(&details); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion handlers/unbind.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func (h APIHandler) Unbind(w http.ResponseWriter, req *http.Request) {
bindingIDLogKey: bindingID,
}, utils.DataForContext(req.Context(), middlewares.CorrelationIDKey))

requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity"))
requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey))

details := domain.UnbindDetails{
PlanID: req.FormValue("plan_id"),
Expand Down
2 changes: 1 addition & 1 deletion handlers/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func (h APIHandler) Update(w http.ResponseWriter, req *http.Request) {
instanceIDLogKey: instanceID,
}, utils.DataForContext(req.Context(), middlewares.CorrelationIDKey))

requestId := fmt.Sprintf("%v", req.Context().Value("requestIdentity"))
requestId := fmt.Sprintf("%v", req.Context().Value(middlewares.RequestIdentityKey))

var details domain.UpdateDetails
if err := json.NewDecoder(req.Body).Decode(&details); err != nil {
Expand Down
10 changes: 10 additions & 0 deletions middlewares/context_keys.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package middlewares

type ContextKey string

const (
CorrelationIDKey ContextKey = "correlation-id"
InfoLocationKey ContextKey = "infoLocation"
OriginatingIdentityKey ContextKey = "originatingIdentity"
RequestIdentityKey ContextKey = "requestIdentity"
)
2 changes: 0 additions & 2 deletions middlewares/correlation_id_header.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import (
"github.com/pborman/uuid"
)

const CorrelationIDKey = "correlation-id"

var correlationIDHeaders = []string{"X-Correlation-ID", "X-CorrelationID", "X-ForRequest-ID", "X-Request-ID", "X-Vcap-Request-Id"}

func AddCorrelationIDToContext(next http.Handler) http.Handler {
Expand Down
6 changes: 1 addition & 5 deletions middlewares/info_location_header.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,10 @@ import (
"net/http"
)

const (
infoLocationKey = "infoLocation"
)

func AddInfoLocationToContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
infoLocation := req.Header.Get("X-Api-Info-Location")
newCtx := context.WithValue(req.Context(), infoLocationKey, infoLocation)
newCtx := context.WithValue(req.Context(), InfoLocationKey, infoLocation)
next.ServeHTTP(w, req.WithContext(newCtx))
})
}
6 changes: 1 addition & 5 deletions middlewares/originating_identity_header.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,10 @@ import (
"net/http"
)

const (
originatingIdentityKey = "originatingIdentity"
)

func AddOriginatingIdentityToContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
originatingIdentity := req.Header.Get("X-Broker-API-Originating-Identity")
newCtx := context.WithValue(req.Context(), originatingIdentityKey, originatingIdentity)
newCtx := context.WithValue(req.Context(), OriginatingIdentityKey, originatingIdentity)
next.ServeHTTP(w, req.WithContext(newCtx))
})
}
2 changes: 0 additions & 2 deletions middlewares/request_identity_header.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"net/http"
)

const RequestIdentityKey = "requestIdentity"

func AddRequestIdentityToContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
requestIdentity := req.Header.Get("X-Broker-API-Request-Identity")
Expand Down
2 changes: 1 addition & 1 deletion staticcheck.conf
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# When adding staticcheck, we thought it was better to get it working with some checks disabled
# rather than fixing all the problems in one go. Some problems cannot be fixed without making
# breaking changes.
checks = ["all", "-SA1019", "-ST1000", "-ST1003", "-ST1005", "-ST1012", "-ST1021", "-SA1029", "-ST1020"]
checks = ["all", "-SA1019", "-ST1000", "-ST1003", "-ST1005", "-ST1012", "-ST1021", "-ST1020"]
5 changes: 3 additions & 2 deletions utils/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"code.cloudfoundry.org/lager"
"github.com/pivotal-cf/brokerapi/v8/domain"
"github.com/pivotal-cf/brokerapi/v8/middlewares"
)

type contextKey string
Expand Down Expand Up @@ -42,11 +43,11 @@ func RetrieveServicePlanFromContext(ctx context.Context) *domain.ServicePlan {
return nil
}

func DataForContext(context context.Context, dataKeys ...string) lager.Data {
func DataForContext(context context.Context, dataKeys ...middlewares.ContextKey) lager.Data {
data := lager.Data{}
for _, key := range dataKeys {
if value := context.Value(key); value != nil {
data[key] = value
data[string(key)] = value
}
}

Expand Down
Loading

0 comments on commit 8d23450

Please sign in to comment.