Skip to content

Commit

Permalink
fix | decimal
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-right committed Jul 19, 2023
1 parent af26601 commit e68a2a2
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/github-actions-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
fail-fast: false
matrix:
python-version: [ 3.10.6]
mongodb-version: [ 5.0, 6.0 ]
mongodb-version: [ 5.0 ]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
Expand Down
3 changes: 3 additions & 0 deletions beanie/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Update,
)
from beanie.odm.bulk import BulkWriter
from beanie.odm.custom_types.decimal import DecimalAnnotation
from beanie.odm.fields import (
PydanticObjectId,
Indexed,
Expand Down Expand Up @@ -61,4 +62,6 @@
"BackLink",
"WriteRules",
"DeleteRules",
# Custom Types
"DecimalAnnotation",
]
Empty file.
40 changes: 40 additions & 0 deletions beanie/odm/custom_types/decimal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from decimal import Decimal as NativeDecimal
from typing import Annotated, Any, Callable

from bson import Decimal128
from pydantic import GetJsonSchemaHandler
from pydantic.fields import FieldInfo
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema


class DecimalCustomAnnotation:

@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: Callable[[Any], core_schema.CoreSchema],
) -> core_schema.CoreSchema:
def validate(value, _: FieldInfo) -> NativeDecimal:
if isinstance(value, Decimal128):
return value.to_decimal()
return value

python_schema = core_schema.general_plain_validator_function(validate)

return core_schema.json_or_python_schema(
json_schema=core_schema.float_schema(),
python_schema=python_schema,
)

@classmethod
def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
return handler(core_schema.float_schema())


DecimalAnnotation = Annotated[
NativeDecimal, DecimalCustomAnnotation
]
4 changes: 3 additions & 1 deletion beanie/odm/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,9 @@ async def insert(
for obj in value:
if isinstance(obj, Document):
await obj.save(link_rule=WriteRules.WRITE)

print("insert", get_dict(
self, to_db=True, keep_nulls=self.get_settings().keep_nulls
),)
result = await self.get_motor_collection().insert_one(
get_dict(
self, to_db=True, keep_nulls=self.get_settings().keep_nulls
Expand Down
2 changes: 1 addition & 1 deletion beanie/odm/queries/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,6 @@ async def _find_one(self):
projection_model=self.projection_model,
**self.pymongo_kwargs,
).first_or_none()
print(get_projection(self.projection_model))
return await self.document_model.get_motor_collection().find_one(
filter=self.get_filter_query(),
projection=get_projection(self.projection_model),
Expand Down Expand Up @@ -1009,6 +1008,7 @@ def __await__(
document = yield from self._find_one().__await__() # type: ignore
if document is None:
return None
print(document)
if type(document) == self.projection_model:
return cast(FindQueryResultType, document)
return cast(
Expand Down
4 changes: 3 additions & 1 deletion beanie/odm/utils/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def merge_models(left: BaseModel, right: BaseModel) -> None:
):
left._previous_revision_id = right._previous_revision_id # type: ignore
for k, right_value in right.__iter__():
left_value = left.__getattribute__(k)
print("LEFT", left)
print("RIGHT", right)
left_value = getattr(left, k)
if isinstance(right_value, BaseModel) and isinstance(
left_value, BaseModel
):
Expand Down
18 changes: 7 additions & 11 deletions tests/odm/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import decimal
from decimal import Decimal
from beanie import DecimalAnnotation
from ipaddress import (
IPv4Address,
IPv4Interface,
Expand All @@ -22,7 +22,7 @@
PrivateAttr,
SecretBytes,
SecretStr,
condecimal, ConfigDict,
ConfigDict,
)
from pydantic.color import Color
from pymongo import IndexModel
Expand Down Expand Up @@ -186,7 +186,7 @@ class DocumentWithCustomIdInt(Document):

class DocumentWithCustomFiledsTypes(Document):
color: Color
decimal: Decimal
decimal: DecimalAnnotation
secret_bytes: SecretBytes
secret_string: SecretStr
ipv4address: IPv4Address
Expand Down Expand Up @@ -643,10 +643,8 @@ class Settings:


class StateAndDecimalFieldModel(Document):
amt: decimal.Decimal
other_amt: condecimal(
decimal_places=1, multiple_of=decimal.Decimal("0.5")
) = 0
amt: DecimalAnnotation
other_amt: DecimalAnnotation = Field(decimal_places=1, multiple_of=0.5, default=0)

class Settings:
name = "amounts"
Expand Down Expand Up @@ -706,10 +704,8 @@ class Collection:


class DocumentWithDecimalField(Document):
amt: decimal.Decimal
other_amt: pydantic.condecimal(
decimal_places=1, multiple_of=decimal.Decimal("0.5")
) = 0
amt: DecimalAnnotation
other_amt: DecimalAnnotation = Field(decimal_places=1, multiple_of=0.5, default=0)

class Config:
validate_assignment = True
Expand Down
1 change: 1 addition & 0 deletions tests/odm/test_fields.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
from decimal import Decimal

from pathlib import Path
from typing import Mapping, AbstractSet
import pytest
Expand Down

0 comments on commit e68a2a2

Please sign in to comment.