Skip to content

Commit

Permalink
Annotate decorators that wrap Document methods (#679) (#886)
Browse files Browse the repository at this point in the history
* Annotate decorators that wrap `Document` methods (#679)

Type checkers were complaining about missing `self`
argument in decorated `Document` methods. This was
caused by incomplete annotations of used decorators.

* fixup! Annotate decorators that wrap `Document` methods (#679)

Removed sync/async overload in favour of ignoring errors in wrappers
because mypy confused them and always expected async function.

---------

Co-authored-by: Maxim Borisov <[email protected]>
  • Loading branch information
bedlamzd and Maxim Borisov authored May 1, 2024
1 parent fc79936 commit 7de0303
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 37 deletions.
55 changes: 40 additions & 15 deletions beanie/odm/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,17 @@
Optional,
Tuple,
Type,
TypeVar,
Union,
)

from typing_extensions import ParamSpec

if TYPE_CHECKING:
from beanie.odm.documents import Document
from beanie.odm.documents import AsyncDocMethod, DocType, Document

P = ParamSpec("P")
R = TypeVar("R")


class EventTypes(str, Enum):
Expand Down Expand Up @@ -136,10 +142,14 @@ async def run_actions(
await asyncio.gather(*coros)


# `Any` because there is arbitrary attribute assignment on this type
F = TypeVar("F", bound=Any)


def register_action(
event_types: Tuple[Union[List[EventTypes], EventTypes]],
event_types: Tuple[Union[List[EventTypes], EventTypes], ...],
action_direction: ActionDirections,
):
) -> Callable[[F], F]:
"""
Decorator. Base registration method.
Used inside `before_event` and `after_event`
Expand All @@ -154,7 +164,7 @@ def register_action(
else:
final_event_types.append(event_type)

def decorator(f):
def decorator(f: F) -> F:
f.has_action = True
f.event_types = final_event_types
f.action_direction = action_direction
Expand All @@ -163,7 +173,9 @@ def decorator(f):
return decorator


def before_event(*args: Union[List[EventTypes], EventTypes]):
def before_event(
*args: Union[List[EventTypes], EventTypes]
) -> Callable[[F], F]:
"""
Decorator. It adds action, which should run before mentioned one
or many events happen
Expand All @@ -172,11 +184,13 @@ def before_event(*args: Union[List[EventTypes], EventTypes]):
:return: None
"""
return register_action(
action_direction=ActionDirections.BEFORE, event_types=args # type: ignore
action_direction=ActionDirections.BEFORE, event_types=args
)


def after_event(*args: Union[List[EventTypes], EventTypes]):
def after_event(
*args: Union[List[EventTypes], EventTypes]
) -> Callable[[F], F]:
"""
Decorator. It adds action, which should run after mentioned one
or many events happen
Expand All @@ -186,26 +200,32 @@ def after_event(*args: Union[List[EventTypes], EventTypes]):
"""

return register_action(
action_direction=ActionDirections.AFTER, event_types=args # type: ignore
action_direction=ActionDirections.AFTER, event_types=args
)


def wrap_with_actions(event_type: EventTypes):
def wrap_with_actions(
event_type: EventTypes,
) -> Callable[
["AsyncDocMethod[DocType, P, R]"], "AsyncDocMethod[DocType, P, R]"
]:
"""
Helper function to wrap Document methods with
before and after event listeners
:param event_type: EventTypes - event types
:return: None
"""

def decorator(f: Callable):
def decorator(
f: "AsyncDocMethod[DocType, P, R]",
) -> "AsyncDocMethod[DocType, P, R]":
@wraps(f)
async def wrapper(
self,
*args,
self: "DocType",
*args: P.args,
skip_actions: Optional[List[Union[ActionDirections, str]]] = None,
**kwargs,
):
**kwargs: P.kwargs,
) -> R:
if skip_actions is None:
skip_actions = []

Expand All @@ -216,7 +236,12 @@ async def wrapper(
exclude=skip_actions,
)

result = await f(self, *args, skip_actions=skip_actions, **kwargs)
result = await f(
self,
*args,
skip_actions=skip_actions, # type: ignore[arg-type]
**kwargs,
)

await ActionRegistry.run_actions(
self,
Expand Down
28 changes: 20 additions & 8 deletions beanie/odm/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from enum import Enum
from typing import (
Any,
Awaitable,
Callable,
ClassVar,
Coroutine,
Dict,
Iterable,
List,
Expand Down Expand Up @@ -32,6 +35,7 @@
DeleteResult,
InsertManyResult,
)
from typing_extensions import Concatenate, ParamSpec, TypeAlias

from beanie.exceptions import (
CollectionWasNotInitialized,
Expand Down Expand Up @@ -104,6 +108,14 @@
from pydantic import model_validator

DocType = TypeVar("DocType", bound="Document")
P = ParamSpec("P")
R = TypeVar("R")
# can describe both sync and async, where R itself is a coroutine
AnyDocMethod: TypeAlias = Callable[Concatenate[DocType, P], R]
# describes only async
AsyncDocMethod: TypeAlias = Callable[
Concatenate[DocType, P], Coroutine[Any, Any, R]
]
DocumentProjectionType = TypeVar("DocumentProjectionType", bound=BaseModel)


Expand Down Expand Up @@ -529,7 +541,7 @@ async def save(
link_rule: WriteRules = WriteRules.DO_NOTHING,
ignore_revision: bool = False,
**kwargs,
) -> None:
) -> DocType:
"""
Update an existing model in the database or
insert it if it does not yet exist.
Expand Down Expand Up @@ -605,12 +617,12 @@ async def save(
@wrap_with_actions(EventTypes.SAVE_CHANGES)
@validate_self_before
async def save_changes(
self,
self: DocType,
ignore_revision: bool = False,
session: Optional[ClientSession] = None,
bulk_writer: Optional[BulkWriter] = None,
skip_actions: Optional[List[Union[ActionDirections, str]]] = None,
) -> None:
) -> Optional[DocType]:
"""
Save changes.
State management usage must be turned on
Expand All @@ -632,7 +644,7 @@ async def save_changes(
)
else:
return await self.set(
changes, # type: ignore #TODO fix typing
changes,
ignore_revision=ignore_revision,
session=session,
bulk_writer=bulk_writer,
Expand Down Expand Up @@ -741,13 +753,13 @@ def update_all(
)

def set(
self,
self: DocType,
expression: Dict[Union[ExpressionField, str], Any],
session: Optional[ClientSession] = None,
bulk_writer: Optional[BulkWriter] = None,
skip_sync: Optional[bool] = None,
**kwargs,
):
) -> Awaitable[DocType]:
"""
Set values
Expand Down Expand Up @@ -976,7 +988,7 @@ def get_previous_saved_state(self) -> Optional[Dict[str, Any]]:
"""
return self._previous_saved_state

@property # type: ignore
@property
@saved_state_needed
def is_changed(self) -> bool:
if self._saved_state == get_dict(
Expand All @@ -988,7 +1000,7 @@ def is_changed(self) -> bool:
return False
return True

@property # type: ignore
@property
@saved_state_needed
@previous_saved_state_needed
def has_changed(self) -> bool:
Expand Down
15 changes: 11 additions & 4 deletions beanie/odm/utils/self_validation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
from functools import wraps
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, TypeVar

from typing_extensions import ParamSpec

if TYPE_CHECKING:
from beanie.odm.documents import DocType
from beanie.odm.documents import AsyncDocMethod, DocType

P = ParamSpec("P")
R = TypeVar("R")


def validate_self_before(f: Callable):
def validate_self_before(
f: "AsyncDocMethod[DocType, P, R]",
) -> "AsyncDocMethod[DocType, P, R]":
@wraps(f)
async def wrapper(self: "DocType", *args, **kwargs):
async def wrapper(self: "DocType", *args: P.args, **kwargs: P.kwargs) -> R:
await self.validate_self(*args, **kwargs)
return await f(self, *args, **kwargs)

Expand Down
39 changes: 29 additions & 10 deletions beanie/odm/utils/state.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import inspect
from functools import wraps
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, TypeVar

from typing_extensions import ParamSpec

from beanie.exceptions import StateManagementIsTurnedOff, StateNotSaved

if TYPE_CHECKING:
from beanie.odm.documents import DocType
from beanie.odm.documents import AnyDocMethod, AsyncDocMethod, DocType

P = ParamSpec("P")
R = TypeVar("R")


def check_if_state_saved(self: "DocType"):
Expand All @@ -17,7 +22,9 @@ def check_if_state_saved(self: "DocType"):
raise StateNotSaved("No state was saved")


def saved_state_needed(f: Callable):
def saved_state_needed(
f: "AnyDocMethod[DocType, P, R]",
) -> "AnyDocMethod[DocType, P, R]":
@wraps(f)
def sync_wrapper(self: "DocType", *args, **kwargs):
check_if_state_saved(self)
Expand All @@ -26,10 +33,14 @@ def sync_wrapper(self: "DocType", *args, **kwargs):
@wraps(f)
async def async_wrapper(self: "DocType", *args, **kwargs):
check_if_state_saved(self)
return await f(self, *args, **kwargs)
# type ignore because there is no nice/proper way to annotate both sync
# and async case without parametrized TypeVar, which is not supported
return await f(self, *args, **kwargs) # type: ignore[misc]

if inspect.iscoroutinefunction(f):
return async_wrapper
# type ignore because there is no nice/proper way to annotate both sync
# and async case without parametrized TypeVar, which is not supported
return async_wrapper # type: ignore[return-value]
return sync_wrapper


Expand All @@ -44,7 +55,9 @@ def check_if_previous_state_saved(self: "DocType"):
)


def previous_saved_state_needed(f: Callable):
def previous_saved_state_needed(
f: "AnyDocMethod[DocType, P, R]",
) -> "AnyDocMethod[DocType, P, R]":
@wraps(f)
def sync_wrapper(self: "DocType", *args, **kwargs):
check_if_previous_state_saved(self)
Expand All @@ -53,16 +66,22 @@ def sync_wrapper(self: "DocType", *args, **kwargs):
@wraps(f)
async def async_wrapper(self: "DocType", *args, **kwargs):
check_if_previous_state_saved(self)
return await f(self, *args, **kwargs)
# type ignore because there is no nice/proper way to annotate both sync
# and async case without parametrized TypeVar, which is not supported
return await f(self, *args, **kwargs) # type: ignore[misc]

if inspect.iscoroutinefunction(f):
return async_wrapper
# type ignore because there is no nice/proper way to annotate both sync
# and async case without parametrized TypeVar, which is not supported
return async_wrapper # type: ignore[return-value]
return sync_wrapper


def save_state_after(f: Callable):
def save_state_after(
f: "AsyncDocMethod[DocType, P, R]",
) -> "AsyncDocMethod[DocType, P, R]":
@wraps(f)
async def wrapper(self: "DocType", *args, **kwargs):
async def wrapper(self: "DocType", *args: P.args, **kwargs: P.kwargs) -> R:
result = await f(self, *args, **kwargs)
self._save_state()
return result
Expand Down
Loading

0 comments on commit 7de0303

Please sign in to comment.