diff --git a/api.go b/api.go index 705e4d2a..ab730675 100644 --- a/api.go +++ b/api.go @@ -41,6 +41,7 @@ func New(serviceBroker ServiceBroker, logger lager.Logger, brokerCredentials Bro router.Use(authMiddleware) router.Use(middlewares.AddOriginatingIdentityToContext) router.Use(apiVersionMiddleware.ValidateAPIVersionHdr) + router.Use(middlewares.AddInfoLocationToContext) return router } diff --git a/api_test.go b/api_test.go index 48dbc37f..225409aa 100644 --- a/api_test.go +++ b/api_test.go @@ -2206,4 +2206,54 @@ var _ = Describe("Service Broker API", func() { }) }) }) + + Describe("InfoLocationHeader", func() { + + var ( + fakeServiceBroker *fakes.AutoFakeServiceBroker + req *http.Request + testServer *httptest.Server + ) + + BeforeEach(func() { + fakeServiceBroker = new(fakes.AutoFakeServiceBroker) + brokerAPI = brokerapi.New(fakeServiceBroker, brokerLogger, credentials) + + testServer = httptest.NewServer(brokerAPI) + var err error + req, err = http.NewRequest("GET", testServer.URL+"/v2/catalog", nil) + Expect(err).NotTo(HaveOccurred()) + req.Header.Add("X-Broker-API-Version", "2.14") + req.SetBasicAuth(credentials.Username, credentials.Password) + }) + + AfterEach(func() { + testServer.Close() + }) + + When("X-Api-Info-Location is passed", func() { + It("Adds it to the context", func() { + infoLocation := "API Info Location Value" + req.Header.Add("X-Api-Info-Location", infoLocation) + + _, err := http.DefaultClient.Do(req) + Expect(err).NotTo(HaveOccurred()) + + Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called") + ctx := fakeServiceBroker.ServicesArgsForCall(0) + Expect(ctx.Value("infoLocation")).To(Equal(infoLocation)) + + }) + }) + When("X-Api-Info-Location is not passed", func() { + It("Adds empty infoLocation to the context", func() { + _, err := http.DefaultClient.Do(req) + Expect(err).NotTo(HaveOccurred()) + + Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called") + ctx := fakeServiceBroker.ServicesArgsForCall(0) + Expect(ctx.Value("infoLocation")).To(Equal("")) + }) + }) + }) }) diff --git a/middlewares/info_location_header.go b/middlewares/info_location_header.go new file mode 100644 index 00000000..f23acc90 --- /dev/null +++ b/middlewares/info_location_header.go @@ -0,0 +1,33 @@ +// Copyright (C) 2015-Present Pivotal Software, Inc. All rights reserved. + +// This program and the accompanying materials are made available under +// the terms of the under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package middlewares + +import ( + "context" + "net/http" +) + +const ( + infoLocationKey = "infoLocation" +) + +func AddInfoLocationToContext(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + infoLocation := req.Header.Get("X-Api-Info-Location") + newCtx := context.WithValue(req.Context(), infoLocationKey, infoLocation) + next.ServeHTTP(w, req.WithContext(newCtx)) + }) +}