Skip to content

Commit

Permalink
feat(client): initial support for middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie committed Jan 8, 2023
1 parent 47c9844 commit e87e2b4
Show file tree
Hide file tree
Showing 12 changed files with 980 additions and 339 deletions.
42 changes: 42 additions & 0 deletions databases/tests/test_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Iterator

import pytest

from prisma import Prisma
from prisma.middleware import (
MiddlewareParams,
NextMiddleware,
MiddlewareResult,
)

# TODO: more tests
# TODO: test every action


@pytest.fixture(autouse=True)
def cleanup_middlewares(client: Prisma) -> Iterator[None]:
middlewares = client._middlewares.copy()

client._middlewares.clear()

try:
yield
finally:
client._middlewares = middlewares


@pytest.mark.asyncio
async def test_basic(client: Prisma) -> None:
__ran__ = False

async def middleware(
params: MiddlewareParams, get_result: NextMiddleware
) -> MiddlewareResult:
nonlocal __ran__

__ran__ = True
return await get_result(params)

client.use(middleware)
await client.user.create({'name': 'Robert'})
assert __ran__ is True
36 changes: 24 additions & 12 deletions src/prisma/generator/templates/actions.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ class {{ model.name }}Actions:
"""
resp = {{ maybe_await }}self._client._execute(
operation='{{ operations.create }}',
method='{{ methods.create }}',
method='create',
prisma_method='{{ methods.create }}',
model='{{ model.name }}',
arguments={
'data': data,
Expand Down Expand Up @@ -239,7 +240,8 @@ class {{ model.name }}Actions:

resp = {{ maybe_await }}self._client._execute(
operation='{{ operations.create_many }}',
method='{{ methods.create_many }}',
method='create_many',
prisma_method='{{ methods.create_many }}',
model='{{ model.name }}',
arguments={
'data': data,
Expand Down Expand Up @@ -288,7 +290,8 @@ class {{ model.name }}Actions:
try:
resp = {{ maybe_await }}self._client._execute(
operation='{{ operations.delete }}',
method='{{ methods.delete }}',
method='delete',
prisma_method='{{ methods.delete }}',
model='{{ model.name }}',
arguments={
'where': where,
Expand Down Expand Up @@ -338,7 +341,8 @@ class {{ model.name }}Actions:
"""
resp = {{ maybe_await }}self._client._execute(
operation='{{ operations.find_unique }}',
method='{{ methods.find_unique }}',
method='find_unique',
prisma_method='{{ methods.find_unique }}',
model='{{ model.name }}',
arguments={
'where': where,
Expand Down Expand Up @@ -408,7 +412,8 @@ class {{ model.name }}Actions:
"""
resp = {{ maybe_await }}self._client._execute(
operation='{{ operations.find_many }}',
method='{{ methods.find_many }}',
method='find_many',
prisma_method='{{ methods.find_many }}',
model='{{ model.name }}',
arguments={
'take': take,
Expand Down Expand Up @@ -474,7 +479,8 @@ class {{ model.name }}Actions:
"""
resp = {{ maybe_await }}self._client._execute(
operation='{{ operations.find_first }}',
method='{{ methods.find_first }}',
method='find_first',
prisma_method='{{ methods.find_first }}',
model='{{ model.name }}',
arguments={
'skip': skip,
Expand Down Expand Up @@ -534,7 +540,8 @@ class {{ model.name }}Actions:
try:
resp = {{ maybe_await }}self._client._execute(
operation='{{ operations["update"] }}',
method='{{ methods["update"] }}',
method='update',
prisma_method='{{ methods["update"] }}',
model='{{ model.name }}',
arguments={
'data': data,
Expand Down Expand Up @@ -620,7 +627,8 @@ class {{ model.name }}Actions:
"""
resp = {{ maybe_await }}self._client._execute(
operation='{{ operations.upsert }}',
method='{{ methods.upsert }}',
method='upsert',
prisma_method='{{ methods.upsert }}',
model='{{ model.name }}',
arguments={
'where': where,
Expand Down Expand Up @@ -669,7 +677,8 @@ class {{ model.name }}Actions:
"""
resp = {{ maybe_await }}self._client._execute(
operation='{{ operations.update_many }}',
method='{{ methods.update_many }}',
method='update_many',
prisma_method='{{ methods.update_many }}',
model='{{ model.name }}',
arguments={'data': data, 'where': where,},
root_selection=['count'],
Expand Down Expand Up @@ -776,7 +785,8 @@ class {{ model.name }}Actions:

resp = {{ maybe_await }}self._client._execute(
operation='{{ operations.count }}',
method='{{ methods.count }}',
method='count',
prisma_method='{{ methods.count }}',
model='{{ model.name }}',
arguments={
'take': take,
Expand Down Expand Up @@ -821,7 +831,8 @@ class {{ model.name }}Actions:
"""
resp = {{ maybe_await }}self._client._execute(
operation='{{ operations.delete_many }}',
method='{{ methods.delete_many }}',
method='delete_many',
prisma_method='{{ methods.delete_many }}',
model='{{ model.name }}',
arguments={'where': where},
root_selection=['count'],
Expand Down Expand Up @@ -940,7 +951,8 @@ class {{ model.name }}Actions:

resp = {{ maybe_await }}self._client._execute(
operation='{{ operations.group_by }}',
method='{{ methods.group_by }}',
method='group_by',
prisma_method='{{ methods.group_by }}',
model='{{ model.name }}',
arguments={
'by': by,
Expand Down
6 changes: 4 additions & 2 deletions src/prisma/generator/templates/builder.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ RELATIONAL_FIELD_MAPPINGS: Dict[str, Dict[str, str]] = {
{% endfor %}
}

Operation = Literal['query', 'mutation']


class QueryBuilder:
# prisma method
method: str

# GraphQL operation
operation: str
operation: Operation

# prisma model
model: Optional[str]
Expand All @@ -85,7 +87,7 @@ class QueryBuilder:
self,
*,
method: str,
operation: str,
operation: Operation,
arguments: Dict[str, Any],
model: Optional[str] = None,
root_selection: Optional[List[str]] = None
Expand Down
104 changes: 93 additions & 11 deletions src/prisma/generator/templates/client.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ from . import types, models, errors, actions
from .types import DatasourceOverride, HttpConfig
from ._types import BaseModelT
from .engine import AbstractEngine, QueryEngine
from .builder import QueryBuilder
from .builder import QueryBuilder, Operation
from .middleware import MiddlewareFunc, MiddlewareParams, MiddlewareResult
from .generator.models import EngineType, OptionalValueFromEnvVar, BinaryPaths
from ._compat import removeprefix
from ._raw_query import deserialize_raw_results
Expand Down Expand Up @@ -131,6 +132,7 @@ class Prisma:
'_datasource',
'_connect_timeout',
'_http_config',
'_middlewares',
)

def __init__(
Expand All @@ -156,6 +158,7 @@ class Prisma:
self._datasource = datasource
self._connect_timeout = connect_timeout
self._http_config: HttpConfig = http or {}
self._middlewares: List[MiddlewareFunc] = []

if use_dotenv:
load_env()
Expand Down Expand Up @@ -243,11 +246,21 @@ class Prisma:
{% endif %}
self.__engine = None

# TODO: add support for clearing middlewares

def use(self, *middlewares: MiddlewareFunc) -> None:
"""Setup a middleware function to be called before each query is executed.

TODO: link to docs
"""
self._middlewares.extend(middlewares)

{% if active_provider != 'mongodb' %}
{{ maybe_async_def }}execute_raw(self, query: LiteralString, *args: Any) -> int:
resp = {{ maybe_await }}self._execute(
operation='{{ operations.execute_raw }}',
method='{{ methods.execute_raw }}',
method='execute_raw',
prisma_method='{{ methods.execute_raw }}',
arguments={
'query': query,
'parameters': args,
Expand Down Expand Up @@ -285,9 +298,18 @@ class Prisma:
"""
results: Sequence[Union[BaseModelT, dict[str, Any]]]
if model is not None:
results = {{ maybe_await }}self.query_raw(query, *args, model=model)
results = {{ maybe_await }}self._execute_raw_query(
query,
*args,
model=model,
_from_method='query_first',
)
else:
results = {{ maybe_await }}self.query_raw(query, *args)
results = {{ maybe_await }}self._execute_raw_query(
query,
*args,
_from_method='query_first',
)

if not results:
return None
Expand Down Expand Up @@ -322,9 +344,41 @@ class Prisma:
If model is given, each returned record is converted to the pydantic model first,
otherwise results will be raw dictionaries.
"""
return {{ maybe_await }}self._execute_raw_query(
query, *args, _from_method='query_raw'
)


@overload
{{ maybe_async_def }}_execute_raw_query(
self,
query: LiteralString,
*args: Any,
_from_method: str,
) -> list[dict[str, Any]]:
...

@overload
{{ maybe_async_def }}_execute_raw_query(
self,
query: LiteralString,
*args: Any,
_from_method: str,
model: Type[BaseModelT],
) -> list[BaseModelT]:
...

{{ maybe_async_def }}_execute_raw_query(
self,
query: LiteralString,
*args: Any,
_from_method: str,
model: Optional[Type[BaseModelT]] = None,
) -> list[BaseModelT] | list[dict[str, Any]]:
resp = {{ maybe_await }}self._execute(
operation='{{ operations.query_raw }}',
method='{{ methods.query_raw }}',
method=_from_method,
prisma_method='{{ methods.query_raw }}',
arguments={
'query': query,
'parameters': args,
Expand All @@ -335,6 +389,7 @@ class Prisma:
return deserialize_raw_results(result, model=model)

return deserialize_raw_results(result)

{% endif %}

def batch_(self) -> 'Batch':
Expand All @@ -344,20 +399,47 @@ class Prisma:
# TODO: don't return Any
{{ maybe_async_def }}_execute(
self,
*,
method: str,
operation: str,
prisma_method: str,
operation: Operation,
arguments: Dict[str, Any],
model: Optional[str] = None,
root_selection: Optional[List[str]] = None
) -> Any:
builder = QueryBuilder(
operation=operation,
{{ maybe_async_def }}executor(params: MiddlewareParams) -> MiddlewareResult:
builder = QueryBuilder(
operation=params.operation,
method=params.prisma_method,
model=params.model_name,
arguments=params.arguments,
# TODO: move this to middleware params too
root_selection=root_selection,
)
return {{ maybe_await }}self._engine.query(builder.build())

# TODO: should this also operate on the parsed return type?
{{ maybe_async_def }}get_result(params: MiddlewareParams) -> MiddlewareResult:
try:
middleware = next(iterator)
except StopIteration:
return {{ maybe_await }}executor(params)

return {{ maybe_await }}middleware(params, get_result)

params = MiddlewareParams(
method=method,
model=model,
operation=operation,
model_name=model,
arguments=arguments,
root_selection=root_selection,
prisma_method=prisma_method,
)
return {{ maybe_await }}self._engine.query(builder.build())
if not self._middlewares:
return {{ maybe_await }}executor(params)

first, *rest = self._middlewares
iterator = iter(rest)
return {{ maybe_await }}first(params, get_result)

def _create_engine(self, dml_path: Path = PACKAGED_SCHEMA_PATH) -> AbstractEngine:
if ENGINE_TYPE == EngineType.binary:
Expand Down
6 changes: 6 additions & 0 deletions src/prisma/middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from ._middleware import (
MiddlewareParams as MiddlewareParams,
MiddlewareResult as MiddlewareResult,
NextMiddleware as NextMiddleware,
MiddlewareFunc as MiddlewareFunc,
)

0 comments on commit e87e2b4

Please sign in to comment.