From 6cd367775fbfd307546619345eb63ebe5f1ada19 Mon Sep 17 00:00:00 2001 From: George Blue Date: Fri, 26 Jan 2024 14:26:31 +0000 Subject: [PATCH] feat: WithAdditionalMiddleware() option Allows custom middleware to be added *after* the default middleware. Resolves https://github.com/pivotal-cf/brokerapi/issues/273 --- api_options.go | 58 ++++++++++++++++++++++++++++++++------------------ api_test.go | 39 ++++++++++++++++++++++++++------- 2 files changed, 68 insertions(+), 29 deletions(-) diff --git a/api_options.go b/api_options.go index 488b2417..69e51c21 100644 --- a/api_options.go +++ b/api_options.go @@ -27,9 +27,20 @@ import ( type middlewareFunc func(http.Handler) http.Handler +type config struct { + router chi.Router + customRouter bool + logger lager.Logger + additionalMiddleware []middlewareFunc +} + func NewWithOptions(serviceBroker domain.ServiceBroker, logger lager.Logger, opts ...Option) http.Handler { - cfg := newDefaultConfig(logger) - WithOptions(append(opts, withDefaultMiddleware())...)(cfg) + cfg := config{ + router: chi.NewRouter(), + logger: logger, + } + + WithOptions(append(opts, withDefaultMiddleware())...)(&cfg) attachRoutes(cfg.router, serviceBroker, logger) return cfg.router @@ -50,12 +61,25 @@ func WithBrokerCredentials(brokerCredentials BrokerCredentials) Option { } } +// WithCustomAuth adds the specified middleware *before* any other middleware. +// Despite the name, any middleware can be added whether nor not it has anything to do with authentication. +// But `WithAdditionalMiddleware()` may be a better choice if the middleware is not related to authentication. +// Can be called multiple times. func WithCustomAuth(authMiddleware middlewareFunc) Option { return func(c *config) { c.router.Use(authMiddleware) } } +// WithAdditionalMiddleware adds the specified middleware *after* the default middleware. +// Can be called multiple times. +// This option is ignored if `WithRouter()` is used. +func WithAdditionalMiddleware(m middlewareFunc) Option { + return func(c *config) { + c.additionalMiddleware = append(c.additionalMiddleware, m) + } +} + // WithEncodedPath used to opt in to a gorilla/mux behaviour that would treat encoded // slashes "/" as IDs. For example, it would change `PUT /v2/service_instances/foo%2Fbar` // to treat `foo%2Fbar` as an instance ID, while the default behavior was to treat it @@ -70,11 +94,17 @@ func WithEncodedPath() Option { func withDefaultMiddleware() Option { return func(c *config) { if !c.customRouter { - c.router.Use(middlewares.APIVersionMiddleware{LoggerFactory: c.logger}.ValidateAPIVersionHdr) - c.router.Use(middlewares.AddCorrelationIDToContext) - c.router.Use(middlewares.AddOriginatingIdentityToContext) - c.router.Use(middlewares.AddInfoLocationToContext) - c.router.Use(middlewares.AddRequestIdentityToContext) + defaults := []middlewareFunc{ + middlewares.APIVersionMiddleware{LoggerFactory: c.logger}.ValidateAPIVersionHdr, + middlewares.AddCorrelationIDToContext, + middlewares.AddOriginatingIdentityToContext, + middlewares.AddInfoLocationToContext, + middlewares.AddRequestIdentityToContext, + } + + for _, m := range append(defaults, c.additionalMiddleware...) { + c.router.Use(m) + } } } } @@ -86,17 +116,3 @@ func WithOptions(opts ...Option) Option { } } } - -func newDefaultConfig(logger lager.Logger) *config { - return &config{ - router: chi.NewRouter(), - customRouter: false, - logger: logger, - } -} - -type config struct { - router chi.Router - customRouter bool - logger lager.Logger -} diff --git a/api_test.go b/api_test.go index 6bf33f9d..d6bc04ef 100644 --- a/api_test.go +++ b/api_test.go @@ -488,6 +488,7 @@ var _ = Describe("Service Broker API", func() { }) }) + When("X-Api-Info-Location is not passed", func() { It("Adds empty infoLocation to the context", func() { _, err := http.DefaultClient.Do(req) @@ -529,16 +530,17 @@ var _ = Describe("Service Broker API", func() { testServer.Close() }) - DescribeTable("Adds correlation id to the context", func(tc testCase) { - req.Header.Add(tc.correlationIDHeaderName, correlationID) + DescribeTable("Adds correlation id to the context", + func(tc testCase) { + req.Header.Add(tc.correlationIDHeaderName, correlationID) - _, err := http.DefaultClient.Do(req) - Expect(err).NotTo(HaveOccurred()) + _, err := http.DefaultClient.Do(req) + Expect(err).NotTo(HaveOccurred()) - Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called") - ctx := fakeServiceBroker.ServicesArgsForCall(0) - Expect(ctx.Value(middlewares.CorrelationIDKey)).To(Equal(correlationID)) - }, + Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called") + ctx := fakeServiceBroker.ServicesArgsForCall(0) + Expect(ctx.Value(middlewares.CorrelationIDKey)).To(Equal(correlationID)) + }, Entry("X-Correlation-ID", testCase{ correlationIDHeaderName: "X-Correlation-ID", }), @@ -2752,6 +2754,27 @@ var _ = Describe("Service Broker API", func() { }) }) + Describe("WithAdditionalMiddleware()", func() { + It("adds additional middleware", func() { + const ( + customMiddlewareError = "fake custom middleware error" + customMiddlewareCode = http.StatusTeapot + ) + + By("adding some custom middleware that fails") + brokerAPI = brokerapi.NewWithOptions(fakeServiceBroker, brokerLogger, brokerapi.WithAdditionalMiddleware(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + http.Error(w, customMiddlewareError, customMiddlewareCode) + }) + })) + + By("checking for the specific failure from the custom middleware") + response := makeInstanceProvisioningRequest(uniqueInstanceID(), provisionDetails, "") + Expect(response.RawResponse).To(HaveHTTPStatus(customMiddlewareCode)) + Expect(response.Body).To(Equal(customMiddlewareError + "\n")) + }) + }) + It("will accept URL-encoded paths", func() { const encodedInstanceID = "foo%2Fbar" brokerAPI = brokerapi.NewWithOptions(fakeServiceBroker, brokerLogger, brokerapi.WithBrokerCredentials(credentials))