Skip to content

Commit

Permalink
feat: WithAdditionalMiddleware() option
Browse files Browse the repository at this point in the history
Allows custom middleware to be added *after* the default middleware.

Resolves #273
  • Loading branch information
blgm committed Jan 26, 2024
1 parent 4b27d1c commit 6cd3677
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 29 deletions.
58 changes: 37 additions & 21 deletions api_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,20 @@ import (

type middlewareFunc func(http.Handler) http.Handler

type config struct {
router chi.Router
customRouter bool
logger lager.Logger
additionalMiddleware []middlewareFunc
}

func NewWithOptions(serviceBroker domain.ServiceBroker, logger lager.Logger, opts ...Option) http.Handler {
cfg := newDefaultConfig(logger)
WithOptions(append(opts, withDefaultMiddleware())...)(cfg)
cfg := config{
router: chi.NewRouter(),
logger: logger,
}

WithOptions(append(opts, withDefaultMiddleware())...)(&cfg)
attachRoutes(cfg.router, serviceBroker, logger)

return cfg.router
Expand All @@ -50,12 +61,25 @@ func WithBrokerCredentials(brokerCredentials BrokerCredentials) Option {
}
}

// WithCustomAuth adds the specified middleware *before* any other middleware.
// Despite the name, any middleware can be added whether nor not it has anything to do with authentication.
// But `WithAdditionalMiddleware()` may be a better choice if the middleware is not related to authentication.
// Can be called multiple times.
func WithCustomAuth(authMiddleware middlewareFunc) Option {
return func(c *config) {
c.router.Use(authMiddleware)
}
}

// WithAdditionalMiddleware adds the specified middleware *after* the default middleware.
// Can be called multiple times.
// This option is ignored if `WithRouter()` is used.
func WithAdditionalMiddleware(m middlewareFunc) Option {
return func(c *config) {
c.additionalMiddleware = append(c.additionalMiddleware, m)
}
}

// WithEncodedPath used to opt in to a gorilla/mux behaviour that would treat encoded
// slashes "/" as IDs. For example, it would change `PUT /v2/service_instances/foo%2Fbar`
// to treat `foo%2Fbar` as an instance ID, while the default behavior was to treat it
Expand All @@ -70,11 +94,17 @@ func WithEncodedPath() Option {
func withDefaultMiddleware() Option {
return func(c *config) {
if !c.customRouter {
c.router.Use(middlewares.APIVersionMiddleware{LoggerFactory: c.logger}.ValidateAPIVersionHdr)
c.router.Use(middlewares.AddCorrelationIDToContext)
c.router.Use(middlewares.AddOriginatingIdentityToContext)
c.router.Use(middlewares.AddInfoLocationToContext)
c.router.Use(middlewares.AddRequestIdentityToContext)
defaults := []middlewareFunc{
middlewares.APIVersionMiddleware{LoggerFactory: c.logger}.ValidateAPIVersionHdr,
middlewares.AddCorrelationIDToContext,
middlewares.AddOriginatingIdentityToContext,
middlewares.AddInfoLocationToContext,
middlewares.AddRequestIdentityToContext,
}

for _, m := range append(defaults, c.additionalMiddleware...) {
c.router.Use(m)
}
}
}
}
Expand All @@ -86,17 +116,3 @@ func WithOptions(opts ...Option) Option {
}
}
}

func newDefaultConfig(logger lager.Logger) *config {
return &config{
router: chi.NewRouter(),
customRouter: false,
logger: logger,
}
}

type config struct {
router chi.Router
customRouter bool
logger lager.Logger
}
39 changes: 31 additions & 8 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ var _ = Describe("Service Broker API", func() {

})
})

When("X-Api-Info-Location is not passed", func() {
It("Adds empty infoLocation to the context", func() {
_, err := http.DefaultClient.Do(req)
Expand Down Expand Up @@ -529,16 +530,17 @@ var _ = Describe("Service Broker API", func() {
testServer.Close()
})

DescribeTable("Adds correlation id to the context", func(tc testCase) {
req.Header.Add(tc.correlationIDHeaderName, correlationID)
DescribeTable("Adds correlation id to the context",
func(tc testCase) {
req.Header.Add(tc.correlationIDHeaderName, correlationID)

_, err := http.DefaultClient.Do(req)
Expect(err).NotTo(HaveOccurred())
_, err := http.DefaultClient.Do(req)
Expect(err).NotTo(HaveOccurred())

Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called")
ctx := fakeServiceBroker.ServicesArgsForCall(0)
Expect(ctx.Value(middlewares.CorrelationIDKey)).To(Equal(correlationID))
},
Expect(fakeServiceBroker.ServicesCallCount()).To(Equal(1), "Services was not called")
ctx := fakeServiceBroker.ServicesArgsForCall(0)
Expect(ctx.Value(middlewares.CorrelationIDKey)).To(Equal(correlationID))
},
Entry("X-Correlation-ID", testCase{
correlationIDHeaderName: "X-Correlation-ID",
}),
Expand Down Expand Up @@ -2752,6 +2754,27 @@ var _ = Describe("Service Broker API", func() {
})
})

Describe("WithAdditionalMiddleware()", func() {
It("adds additional middleware", func() {
const (
customMiddlewareError = "fake custom middleware error"
customMiddlewareCode = http.StatusTeapot
)

By("adding some custom middleware that fails")
brokerAPI = brokerapi.NewWithOptions(fakeServiceBroker, brokerLogger, brokerapi.WithAdditionalMiddleware(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
http.Error(w, customMiddlewareError, customMiddlewareCode)
})
}))

By("checking for the specific failure from the custom middleware")
response := makeInstanceProvisioningRequest(uniqueInstanceID(), provisionDetails, "")
Expect(response.RawResponse).To(HaveHTTPStatus(customMiddlewareCode))
Expect(response.Body).To(Equal(customMiddlewareError + "\n"))
})
})

It("will accept URL-encoded paths", func() {
const encodedInstanceID = "foo%2Fbar"
brokerAPI = brokerapi.NewWithOptions(fakeServiceBroker, brokerLogger, brokerapi.WithBrokerCredentials(credentials))
Expand Down

0 comments on commit 6cd3677

Please sign in to comment.