Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(client): add support for middleware #643

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
107 changes: 107 additions & 0 deletions databases/tests/test_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from typing import Iterator, cast
from prisma.models import User

import pytest

from prisma import Prisma
from prisma.middleware import (
MiddlewareParams,
NextMiddleware,
MiddlewareResult,
)

# TODO: more tests
# TODO: test every action


@pytest.fixture(autouse=True)
def cleanup_middlewares(client: Prisma) -> Iterator[None]:
middlewares = client._middlewares.copy()

client._middlewares.clear()

try:
yield
finally:
client._middlewares = middlewares


@pytest.mark.asyncio
async def test_basic(client: Prisma) -> None:
__ran__ = False

async def middleware(
params: MiddlewareParams, get_result: NextMiddleware
) -> MiddlewareResult:
nonlocal __ran__

__ran__ = True
return await get_result(params)

client.use(middleware)
await client.user.create({'name': 'Robert'})
assert __ran__ is True


@pytest.mark.asyncio
async def test_modified_return_field(client: Prisma) -> None:
async def middleware(
params: MiddlewareParams, get_result: NextMiddleware
) -> MiddlewareResult:
assert params.model is not None
assert params.model.__name__ == 'User'
assert params.method == 'create'

result = await get_result(params)

assert isinstance(result, User)
result.name = 'Tegan'
return result

client.use(middleware)
user = await client.user.create({'name': 'Robert'})
assert user.name == 'Tegan'


@pytest.mark.asyncio
async def test_modified_return_type(client: Prisma) -> None:
# TODO: note about alternatives
class MyCustomUser(User):
@property
def full_name(self) -> str:
return self.name + ' Smith'

async def middleware(
params: MiddlewareParams, get_result: NextMiddleware
) -> MiddlewareResult:
result = await get_result(params)

return MiddlewareResult(MyCustomUser.parse_obj(result))

client.use(middleware)
user = await client.user.create({'name': 'Robert'})
assert user.name == 'Robert'
assert isinstance(user, MyCustomUser)
assert user.full_name == 'Robert Smith'


@pytest.mark.asyncio
async def test_modified_arguments(client: Prisma) -> None:
async def middleware(
params: MiddlewareParams, get_result: NextMiddleware
) -> MiddlewareResult:
data = cast('dict[str, object] | None', params.arguments.get('data'))
if data is not None: # pragma: no branch
name = data.get('name')
if name == 'Robert': # pragma: no branch
data['name'] = 'Tegan'

return await get_result(params)

client.use(middleware)

user = await client.user.create({'name': 'Robert'})
assert user.name == 'Tegan'

user = await client.user.create({'name': 'Alfie'})
assert user.name == 'Alfie'
8 changes: 8 additions & 0 deletions src/prisma/_raw_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ def deserialize_raw_results(
...


@overload
def deserialize_raw_results(
raw_list: list[dict[str, object]],
model: type[BaseModelT] | None,
) -> list[BaseModelT] | list[dict[str, Any]]:
...


def deserialize_raw_results(
raw_list: list[dict[str, Any]],
model: type[BaseModelT] | None = None,
Expand Down
20 changes: 8 additions & 12 deletions src/prisma/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,9 @@
from enum import Enum
from pathlib import Path
from typing import (
Optional,
List,
Union,
NoReturn,
Mapping,
Any,
Type,
overload,
cast,
)
Expand All @@ -32,8 +28,8 @@ class PrismaCLI(click.MultiCommand):
base_package: str = 'prisma.cli.commands'
folder: Path = Path(__file__).parent / 'commands'

def list_commands(self, ctx: click.Context) -> List[str]:
commands: List[str] = []
def list_commands(self, ctx: click.Context) -> list[str]:
commands: list[str] = []

for path in self.folder.iterdir():
name = path.name
Expand All @@ -50,7 +46,7 @@ def list_commands(self, ctx: click.Context) -> List[str]:

def get_command(
self, ctx: click.Context, cmd_name: str
) -> Optional[click.Command]:
) -> click.Command | None:
name = f'{self.base_package}.{cmd_name}'
if not module_exists(name):
# command not found
Expand Down Expand Up @@ -92,7 +88,7 @@ class MyEnum(str, Enum):
results in click.Choice(['bar'])
"""

def __init__(self, enum: Type[Enum]) -> None:
def __init__(self, enum: type[Enum]) -> None:
if str not in enum.__mro__:
raise TypeError('Enum does not subclass `str`')

Expand All @@ -102,8 +98,8 @@ def __init__(self, enum: Type[Enum]) -> None:
def convert(
self,
value: str,
param: Optional[click.Parameter],
ctx: Optional[click.Context],
param: click.Parameter | None,
ctx: click.Context | None,
) -> str:
return str(
cast(Any, self.__enum(super().convert(value, param, ctx)).value)
Expand All @@ -121,7 +117,7 @@ def maybe_exit(retcode: int) -> None:


def generate_client(
schema: Optional[str] = None, *, reload: bool = False
schema: str | None = None, *, reload: bool = False
) -> None:
"""Run `prisma generate` and update sys.modules"""
args = ['generate']
Expand Down Expand Up @@ -158,7 +154,7 @@ def error(message: str, exit_: Literal[False]) -> None:
...


def error(message: str, exit_: bool = True) -> Union[None, NoReturn]:
def error(message: str, exit_: bool = True) -> None | NoReturn:
click.echo(click.style(message, fg='bright_red', bold=True), err=True)
if exit_:
sys.exit(1)
Expand Down