diff --git a/internal/xmaps/set.go b/internal/xmaps/set.go new file mode 100644 index 000000000..0b8bacc1b --- /dev/null +++ b/internal/xmaps/set.go @@ -0,0 +1,10 @@ +package xmaps + +// BuildSet builds a set from the given values. +func BuildSet[K comparable](s ...K) map[K]struct{} { + r := make(map[K]struct{}, len(s)) + for _, v := range s { + r[v] = struct{}{} + } + return r +} diff --git a/middleware/match.go b/middleware/match.go new file mode 100644 index 000000000..c12960ebd --- /dev/null +++ b/middleware/match.go @@ -0,0 +1,79 @@ +package middleware + +import ( + "regexp" + + "github.com/ogen-go/ogen/internal/xmaps" +) + +// OperationID calls the next middleware if request operation ID matches the given operationID. +func OperationID(m Middleware, operationID ...string) Middleware { + switch len(operationID) { + case 0: + return justCallNext + case 1: + val := operationID[0] + return func(req Request, next Next) (Response, error) { + if req.OperationID == val { + return m(req, next) + } + return next(req) + } + default: + set := xmaps.BuildSet(operationID...) + return func(req Request, next Next) (Response, error) { + if _, ok := set[req.OperationID]; ok { + return m(req, next) + } + return next(req) + } + } +} + +// OperationName calls the next middleware if request operation name matches the given operationName. +func OperationName(m Middleware, operationName ...string) Middleware { + switch len(operationName) { + case 0: + return justCallNext + case 1: + val := operationName[0] + return func(req Request, next Next) (Response, error) { + if req.OperationName == val { + return m(req, next) + } + return next(req) + } + default: + set := xmaps.BuildSet(operationName...) + return func(req Request, next Next) (Response, error) { + if _, ok := set[req.OperationName]; ok { + return m(req, next) + } + return next(req) + } + } +} + +// PathRegex calls the next middleware if request path matches the given regex. +func PathRegex(re *regexp.Regexp, m Middleware) Middleware { + if re == nil { + return justCallNext + } + + return func(req Request, next Next) (Response, error) { + if re.MatchString(req.Raw.URL.Path) { + return m(req, next) + } + return next(req) + } +} + +// BodyType calls the next middleware if request body type matches the given type. +func BodyType[T any](m Middleware) Middleware { + return func(req Request, next Next) (Response, error) { + if _, ok := req.Body.(T); ok { + return m(req, next) + } + return next(req) + } +} diff --git a/middleware/middleware.go b/middleware/middleware.go index ed2aa80f0..78eb4250c 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -73,12 +73,14 @@ type ( Middleware func(req Request, next Next) (Response, error) ) +func justCallNext(req Request, next Next) (Response, error) { + return next(req) +} + // ChainMiddlewares chains middlewares into a single middleware, which will be executed in the order they are passed. func ChainMiddlewares(m ...Middleware) Middleware { if len(m) == 0 { - return func(req Request, next Next) (Response, error) { - return next(req) - } + return justCallNext } tail := ChainMiddlewares(m[1:]...) return func(req Request, next Next) (Response, error) { diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 9f53f52bf..d8ee8eae1 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -79,14 +79,11 @@ func TestChainMiddlewares(t *testing.T) { func BenchmarkChainMiddlewares(b *testing.B) { const N = 20 - noop := func(req Request, next Next) (Response, error) { - return next(req) - } var ( chain = ChainMiddlewares(func() (r []Middleware) { for i := 0; i < N; i++ { - r = append(r, noop) + r = append(r, justCallNext) } return r }()...)