Skip to content

Commit

Permalink
fix set
Browse files Browse the repository at this point in the history
  • Loading branch information
CAPITAINMARVEL committed Sep 18, 2024
1 parent 18a6962 commit fe351b7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 50 deletions.
66 changes: 17 additions & 49 deletions beanie/odm/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,7 @@
# can describe both sync and async, where R itself is a coroutine
AnyDocMethod: TypeAlias = Callable[Concatenate[DocType, P], R]
# describes only async
AsyncDocMethod: TypeAlias = Callable[
Concatenate[DocType, P], Coroutine[Any, Any, R]
]
AsyncDocMethod: TypeAlias = Callable[Concatenate[DocType, P], Coroutine[Any, Any, R]]
DocumentProjectionType = TypeVar("DocumentProjectionType", bound=BaseModel)


Expand Down Expand Up @@ -227,9 +225,7 @@ def _fill_back_refs(cls, values):
and field_name not in values
):
values[field_name] = [
BackLink[link_info.document_class](
link_info.document_class
)
BackLink[link_info.document_class](link_info.document_class)
]
return values

Expand Down Expand Up @@ -310,9 +306,7 @@ async def sync(self, merge_strategy: MergeStrategy = MergeStrategy.remote):
new_state = document.get_saved_state()
if new_state is None:
raise DocumentWasNotSaved
changes_to_apply = self._collect_updates(
new_state, original_changes
)
changes_to_apply = self._collect_updates(new_state, original_changes)
merge_models(self, document)
apply_changes(changes_to_apply, self)
elif merge_strategy == MergeStrategy.remote:
Expand Down Expand Up @@ -360,9 +354,7 @@ async def insert(
]
)
result = await self.get_motor_collection().insert_one(
get_dict(
self, to_db=True, keep_nulls=self.get_settings().keep_nulls
),
get_dict(self, to_db=True, keep_nulls=self.get_settings().keep_nulls),
session=session,
)
new_id = result.inserted_id
Expand Down Expand Up @@ -403,16 +395,12 @@ async def insert_one(
:return: DocType
"""
if not isinstance(document, cls):
raise TypeError(
"Inserting document must be of the original document class"
)
raise TypeError("Inserting document must be of the original document class")
if bulk_writer is None:
return await document.insert(link_rule=link_rule, session=session)
else:
if link_rule == WriteRules.WRITE:
raise NotSupported(
"Cascade insert with bulk writing not supported"
)
raise NotSupported("Cascade insert with bulk writing not supported")
bulk_writer.add_operation(
Operation(
operation=InsertOne,
Expand Down Expand Up @@ -443,9 +431,7 @@ async def insert_many(
:return: InsertManyResult
"""
if link_rule == WriteRules.WRITE:
raise NotSupported(
"Cascade insert not supported for insert many method"
)
raise NotSupported("Cascade insert not supported for insert many method")
documents_list = [
get_dict(
document,
Expand Down Expand Up @@ -572,9 +558,7 @@ async def save(
LinkTypes.OPTIONAL_BACK_DIRECT,
]:
if isinstance(value, Document):
await value.save(
link_rule=link_rule, session=session
)
await value.save(link_rule=link_rule, session=session)
if field_info.link_type in [
LinkTypes.LIST,
LinkTypes.OPTIONAL_LIST,
Expand All @@ -584,9 +568,7 @@ async def save(
if isinstance(value, List):
await asyncio.gather(
*[
obj.save(
link_rule=link_rule, session=session
)
obj.save(link_rule=link_rule, session=session)
for obj in value
if isinstance(obj, Document)
]
Expand Down Expand Up @@ -674,14 +656,10 @@ async def replace_many(
"""
ids_list = [document.id for document in documents]
if await cls.find(In(cls.id, ids_list)).count() != len(ids_list):
raise ReplaceError(
"Some of the documents are not exist in the collection"
)
raise ReplaceError("Some of the documents are not exist in the collection")
async with BulkWriter(session=session) as bulk_writer:
for document in documents:
await document.replace(
bulk_writer=bulk_writer, session=session
)
await document.replace(bulk_writer=bulk_writer, session=session)

@wrap_with_actions(EventTypes.UPDATE)
@save_state_after
Expand Down Expand Up @@ -763,7 +741,7 @@ def update_all(

def set(
self: DocType,
expression: Dict[Any, Any],
expression: Dict[Union[ExpressionField, str, Any], Any],
session: Optional[ClientSession] = None,
bulk_writer: Optional[BulkWriter] = None,
skip_sync: Optional[bool] = None,
Expand Down Expand Up @@ -1116,18 +1094,14 @@ async def inspect_collection(
:return: InspectionResult
"""
inspection_result = InspectionResult()
async for json_document in cls.get_motor_collection().find(
{}, session=session
):
async for json_document in cls.get_motor_collection().find({}, session=session):
try:
parse_model(cls, json_document)
except ValidationError as e:
if inspection_result.status == InspectionStatuses.OK:
inspection_result.status = InspectionStatuses.FAIL
inspection_result.errors.append(
InspectionError(
document_id=json_document["_id"], error=str(e)
)
InspectionError(document_id=json_document["_id"], error=str(e))
)
return inspection_result

Expand Down Expand Up @@ -1200,9 +1174,7 @@ async def distinct(
session: Optional[ClientSession] = None,
**kwargs: Any,
) -> list:
return await cls.get_motor_collection().distinct(
key, filter, session, **kwargs
)
return await cls.get_motor_collection().distinct(key, filter, session, **kwargs)

@classmethod
def link_from_id(cls, id: Any):
Expand Down Expand Up @@ -1293,9 +1265,7 @@ def find_many( # type: ignore
nesting_depths_per_field: Optional[Dict[str, int]] = None,
**pymongo_kwargs,
) -> Union[FindMany[FindType], FindMany["DocumentProjectionType"]]:
args = cls._add_class_id_filter(args, with_children) + (
{"deleted_at": None},
)
args = cls._add_class_id_filter(args, with_children) + ({"deleted_at": None},)
return cls._find_many_query_class(document_model=cls).find_many(
*args,
sort=sort,
Expand Down Expand Up @@ -1324,9 +1294,7 @@ def find_one( # type: ignore
nesting_depths_per_field: Optional[Dict[str, int]] = None,
**pymongo_kwargs,
) -> Union[FindOne[FindType], FindOne["DocumentProjectionType"]]:
args = cls._add_class_id_filter(args, with_children) + (
{"deleted_at": None},
)
args = cls._add_class_id_filter(args, with_children) + ({"deleted_at": None},)
return cls._find_one_query_class(document_model=cls).find_one(
*args,
projection_model=projection_model,
Expand Down
2 changes: 1 addition & 1 deletion beanie/odm/interfaces/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def update(

def set(
self,
expression: Dict[Any, Any],
expression: Dict[Union[ExpressionField, str, Any], Any],
session: Optional[ClientSession] = None,
bulk_writer: Optional[BulkWriter] = None,
**kwargs,
Expand Down

0 comments on commit fe351b7

Please sign in to comment.