diff --git a/beanie/odm/fields.py b/beanie/odm/fields.py index c93cc892..839ea4c2 100644 --- a/beanie/odm/fields.py +++ b/beanie/odm/fields.py @@ -25,7 +25,9 @@ from bson import DBRef, ObjectId from bson.errors import InvalidId -from pydantic import BaseModel +from pydantic import ( + BaseModel, +) from pymongo import ASCENDING, IndexModel from beanie.odm.enums import SortDirection @@ -58,7 +60,6 @@ from pydantic_core.core_schema import ( ValidationInfo, simple_ser_schema, - str_schema, ) else: from pydantic.fields import ModelField # type: ignore @@ -159,7 +160,16 @@ def __get_pydantic_core_schema__( ) -> CoreSchema: # type: ignore return core_schema.json_or_python_schema( python_schema=plain_validator(cls.validate), - json_schema=str_schema(), + json_schema=plain_validator( + cls.validate, + metadata={ + "pydantic_js_input_core_schema": core_schema.str_schema( + pattern="^[0-9a-f]{24}$", + min_length=24, + max_length=24, + ) + }, + ), serialization=core_schema.plain_serializer_function_ser_schema( lambda instance: str(instance), when_used="json" ), diff --git a/tests/odm/test_id.py b/tests/odm/test_id.py index b98d5450..62550f60 100644 --- a/tests/odm/test_id.py +++ b/tests/odm/test_id.py @@ -1,5 +1,9 @@ from uuid import UUID +import pytest +from pydantic import BaseModel + +from beanie import PydanticObjectId from tests.odm.models import DocumentWithCustomIdInt, DocumentWithCustomIdUUID @@ -15,3 +19,25 @@ async def test_integer_id(): await doc.insert() new_doc = await DocumentWithCustomIdInt.get(doc.id) assert isinstance(new_doc.id, int) + + +class A(BaseModel): + id: PydanticObjectId + + +async def test_pydantic_object_id_validation_json(): + deserialized = A.model_validate_json('{"id": "5eb7cf5a86d9755df3a6c593"}') + assert isinstance(deserialized.id, PydanticObjectId) + assert str(deserialized.id) == "5eb7cf5a86d9755df3a6c593" + assert deserialized.id == PydanticObjectId("5eb7cf5a86d9755df3a6c593") + + +@pytest.mark.parametrize( + "data", + ["5eb7cf5a86d9755df3a6c593", PydanticObjectId("5eb7cf5a86d9755df3a6c593")], +) +async def test_pydantic_object_id_serialization(data): + deserialized = A(**{"id": data}) + assert isinstance(deserialized.id, PydanticObjectId) + assert str(deserialized.id) == "5eb7cf5a86d9755df3a6c593" + assert deserialized.id == PydanticObjectId("5eb7cf5a86d9755df3a6c593")