fix: fix DocList schema when using Pydantic V2 #1876

Open · wants to merge 10 commits into base: main

Summary of the diff: pins Poetry to 1.7.1 across the CD and CI workflows; rewrites DocList.__get_pydantic_core_schema__ for Pydantic V2 so a DocList[T] annotation validates both existing instances and plain sequences; parametrizes the dynamically created _DocArrayTyped class as a real generic alias under Pydantic V2 so nested DocList schemas are generated correctly (see issue #1521); replaces the removed core_schema.general_*_validator_function calls with their with_info_* successors; and adds FastAPI and schema tests.
2 changes: 1 addition & 1 deletion .github/workflows/cd.yml
@@ -21,7 +21,7 @@ jobs:
       - name: Pre-release (.devN)
        run: |
          git fetch --depth=1 origin +refs/tags/*:refs/tags/*
-         pip install poetry
+         pip install poetry==1.7.1
          ./scripts/release.sh
        env:
          PYPI_USERNAME: ${{ secrets.TWINE_USERNAME }}
20 changes: 10 additions & 10 deletions .github/workflows/ci.yml
@@ -25,7 +25,7 @@ jobs:
       - name: Lint with ruff
        run: |
          python -m pip install --upgrade pip
-         python -m pip install poetry
+         python -m pip install poetry==1.7.1
          poetry install

          # stop the build if there are Python syntax errors or undefined names
@@ -44,7 +44,7 @@
       - name: check black
        run: |
          python -m pip install --upgrade pip
-         python -m pip install poetry
+         python -m pip install poetry==1.7.1
          poetry install --only dev
          poetry run black --check .

@@ -62,7 +62,7 @@
       - name: Prepare environment
        run: |
          python -m pip install --upgrade pip
-         python -m pip install poetry
+         python -m pip install poetry==1.7.1
          poetry install --without dev
          poetry run pip install tensorflow==2.12.0
          poetry run pip install jax
@@ -106,7 +106,7 @@ jobs:
       - name: Prepare environment
        run: |
          python -m pip install --upgrade pip
-         python -m pip install poetry
+         python -m pip install poetry==1.7.1
          poetry install --all-extras
          poetry run pip install elasticsearch==8.6.2
          ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
@@ -156,7 +156,7 @@ jobs:
       - name: Prepare environment
        run: |
          python -m pip install --upgrade pip
-         python -m pip install poetry
+         python -m pip install poetry==1.7.1
          poetry install --all-extras
          ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
          poetry run pip install protobuf==3.20.0 # we check that we support 3.19
@@ -204,7 +204,7 @@ jobs:
       - name: Prepare environment
        run: |
          python -m pip install --upgrade pip
-         python -m pip install poetry
+         python -m pip install poetry==1.7.1
          poetry install --all-extras
          ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
          poetry run pip install protobuf==3.20.0
@@ -253,7 +253,7 @@ jobs:
       - name: Prepare environment
        run: |
          python -m pip install --upgrade pip
-         python -m pip install poetry
+         python -m pip install poetry==1.7.1
          poetry install --all-extras
          ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
          poetry run pip install protobuf==3.20.0
@@ -302,7 +302,7 @@ jobs:
       - name: Prepare environment
        run: |
          python -m pip install --upgrade pip
-         python -m pip install poetry
+         python -m pip install poetry==1.7.1
          poetry install --all-extras
          ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
          poetry run pip install protobuf==3.20.0
@@ -351,7 +351,7 @@ jobs:
       - name: Prepare environment
        run: |
          python -m pip install --upgrade pip
-         python -m pip install poetry
+         python -m pip install poetry==1.7.1
          poetry install --all-extras
          ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
          poetry run pip uninstall -y torch
@@ -398,7 +398,7 @@ jobs:
       - name: Prepare environment
        run: |
          python -m pip install --upgrade pip
-         python -m pip install poetry
+         python -m pip install poetry==1.7.1
          poetry install --all-extras
          poetry run pip uninstall -y torch
          poetry run pip install torch
2 changes: 1 addition & 1 deletion .github/workflows/ci_only_pr.yml
@@ -43,7 +43,7 @@ jobs:
        run: |
          npm i -g netlify-cli
          python -m pip install --upgrade pip
-         python -m pip install poetry
+         python -m pip install poetry==1.7.1
          python -m poetry config virtualenvs.create false && python -m poetry install --no-interaction --no-ansi --all-extras

          cd docs
26 changes: 17 additions & 9 deletions docarray/array/any_array.py
@@ -25,6 +25,7 @@
 from docarray.exceptions.exceptions import UnusableObjectError
 from docarray.typing.abstract_type import AbstractType
 from docarray.utils._internal._typing import change_cls_name, safe_issubclass
+from docarray.utils._internal.pydantic import is_pydantic_v2

 if TYPE_CHECKING:
     from docarray.proto import DocListProto, NodeProto
@@ -73,7 +74,7 @@
         # Promote to global scope so multiprocessing can pickle it
         global _DocArrayTyped

-        class _DocArrayTyped(cls):  # type: ignore
+        class _DocArrayTyped(cls, Generic[T_doc]):  # type: ignore
             doc_type: Type[BaseDocWithoutId] = cast(Type[BaseDocWithoutId], item)

         for field in _DocArrayTyped.doc_type._docarray_fields().keys():
@@ -99,14 +100,21 @@
             setattr(_DocArrayTyped, field, _property_generator(field))
             # this generates property on the fly based on the schema of the item

-        # The global scope and qualname need to refer to this class a unique name.
-        # Otherwise, creating another _DocArrayTyped will overwrite this one.
-        change_cls_name(
-            _DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals()
-        )
-
-        cls.__typed_da__[cls][item] = _DocArrayTyped
+        # The global scope and qualname need to refer to this class by a unique
+        # name. Otherwise, creating another _DocArrayTyped will overwrite this one.
+        if not is_pydantic_v2:
+            change_cls_name(_DocArrayTyped, f'{cls.__name__}[{item}]', globals())
+            cls.__typed_da__[cls][item] = _DocArrayTyped
+        else:
+            change_cls_name(_DocArrayTyped, f'{cls.__name__}', globals())
+            if sys.version_info < (3, 12):
+                cls.__typed_da__[cls][item] = Generic.__class_getitem__.__func__(
+                    _DocArrayTyped, item
+                )  # type: ignore
+                # this does nothing beyond checking that item is a valid type var
+                # or str; keep the approach of #1147 to stay compatible with
+                # lower versions of Python
+            else:
+                cls.__typed_da__[cls][item] = GenericAlias(_DocArrayTyped, item)  # type: ignore

         return cls.__typed_da__[cls][item]

     @overload
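The branch on sys.version_info matters because the two paths build the same kind of object in different ways: Generic.__class_getitem__ validates the type argument and returns a typing generic alias, while types.GenericAlias constructs one directly on Python 3.12+. A minimal sketch of the distinction, with Container standing in for the dynamically created _DocArrayTyped (an illustrative name, not from the codebase):

import sys
from typing import Generic, TypeVar

T = TypeVar('T')


class Container(Generic[T]):
    # stand-in for the dynamically created _DocArrayTyped class
    pass


if sys.version_info < (3, 12):
    # the same call the diff uses: validates `int` as a type argument and
    # returns a typing generic alias
    alias = Generic.__class_getitem__.__func__(Container, int)
else:
    # on newer Pythons the diff constructs the alias directly
    from types import GenericAlias

    alias = GenericAlias(Container, int)

print(alias)           # __main__.Container[int]
print(alias.__args__)  # (<class 'int'>,)

Either way the result carries the item type in __args__, which is what the Pydantic V2 schema generation relies on.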
25 changes: 17 additions & 8 deletions docarray/array/doc_list/doc_list.py
@@ -12,6 +12,7 @@
     Union,
     cast,
     overload,
+    Callable,
 )

 from pydantic import parse_obj_as
@@ -28,7 +29,6 @@
 from docarray.utils._internal.pydantic import is_pydantic_v2

 if is_pydantic_v2:
-    from pydantic import GetCoreSchemaHandler
     from pydantic_core import core_schema

 from docarray.utils._internal._typing import safe_issubclass
@@ -45,10 +45,7 @@


 class DocList(
-    ListAdvancedIndexing[T_doc],
-    PushPullMixin,
-    IOMixinDocList,
-    AnyDocArray[T_doc],
+    ListAdvancedIndexing[T_doc], PushPullMixin, IOMixinDocList, AnyDocArray[T_doc]
 ):
     """
     DocList is a container of Documents.
@@ -357,8 +354,20 @@

     @classmethod
     def __get_pydantic_core_schema__(
-        cls, _source_type: Any, _handler: GetCoreSchemaHandler
+        cls, source: Any, handler: Callable[[Any], core_schema.CoreSchema]
     ) -> core_schema.CoreSchema:
-        return core_schema.general_plain_validator_function(
-            cls.validate,
-        )
+        instance_schema = core_schema.is_instance_schema(cls)
+        args = getattr(source, '__args__', None)
+        if args:
+            sequence_t_schema = handler(Sequence[args[0]])
+        else:
+            sequence_t_schema = handler(Sequence)
+
+        def validate_fn(v, info):
+            # input has already been validated
+            return cls(v, validate_input_docs=False)
+
+        non_instance_schema = core_schema.with_info_after_validator_function(
+            validate_fn, sequence_t_schema
+        )
+        return core_schema.union_schema([instance_schema, non_instance_schema])
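The union schema above is what lets a DocList[T] annotation sit directly in a Pydantic V2 model: is_instance_schema passes an existing DocList through untouched, while any other sequence is validated element-wise via the handler and then wrapped by validate_fn without re-validating each doc. A hedged usage sketch (ItemDoc and WrapperDoc are illustrative names, assuming Pydantic V2 is installed):

from docarray import BaseDoc, DocList


class ItemDoc(BaseDoc):
    aux: str


class WrapperDoc(BaseDoc):
    docs: DocList[ItemDoc]


# instance path: an existing DocList should pass is_instance_schema unchanged
w1 = WrapperDoc(docs=DocList[ItemDoc]([ItemDoc(aux='a')]))

# sequence path: a plain list should be validated as Sequence[ItemDoc] and then
# wrapped into a DocList by validate_fn (validate_input_docs=False skips a
# second validation pass)
w2 = WrapperDoc(docs=[ItemDoc(aux='b')])
assert isinstance(w2.docs, DocList)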
2 changes: 1 addition & 1 deletion docarray/typing/bytes/base_bytes.py
@@ -62,7 +62,7 @@
     def __get_pydantic_core_schema__(
         cls, _source_type: Any, _handler: 'GetCoreSchemaHandler'
     ) -> 'core_schema.CoreSchema':
-        return core_schema.general_after_validator_function(
+        return core_schema.with_info_after_validator_function(
             cls.validate,
             core_schema.bytes_schema(),
         )
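This file and the three that follow track the same pydantic-core rename: core_schema.general_plain/after_validator_function became with_info_plain/after_validator_function, with the validator still receiving (value, info). A minimal sketch of the new spelling on a toy type (UpperStr is an illustrative name, not from the codebase):

from typing import Any

from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import core_schema


class UpperStr(str):
    @classmethod
    def _validate(cls, v: str, info: core_schema.ValidationInfo) -> 'UpperStr':
        # uppercase after the inner str schema has accepted the value
        return cls(v.upper())

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # the with_info_* spelling replaces the removed general_* one
        return core_schema.with_info_after_validator_function(
            cls._validate, core_schema.str_schema()
        )


class Model(BaseModel):
    s: UpperStr


assert Model(s='abc').s == 'ABC'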
2 changes: 1 addition & 1 deletion docarray/typing/id.py
@@ -77,7 +77,7 @@ def from_protobuf(cls: Type[T], pb_msg: 'str') -> T:
     def __get_pydantic_core_schema__(
         cls, source: Type[Any], handler: 'GetCoreSchemaHandler'
     ) -> core_schema.CoreSchema:
-        return core_schema.general_plain_validator_function(
+        return core_schema.with_info_plain_validator_function(
             cls.validate,
         )
2 changes: 1 addition & 1 deletion docarray/typing/tensor/abstract_tensor.py
@@ -395,7 +395,7 @@ def _docarray_to_ndarray(self) -> np.ndarray:
     def __get_pydantic_core_schema__(
         cls, _source_type: Any, handler: GetCoreSchemaHandler
     ) -> core_schema.CoreSchema:
-        return core_schema.general_plain_validator_function(
+        return core_schema.with_info_plain_validator_function(
             cls.validate,
             serialization=core_schema.plain_serializer_function_ser_schema(
                 function=orjson_dumps,
2 changes: 1 addition & 1 deletion docarray/typing/url/any_url.py
@@ -56,7 +56,7 @@
     def __get_pydantic_core_schema__(
         cls, source: Type[Any], handler: Optional['GetCoreSchemaHandler'] = None
     ) -> core_schema.CoreSchema:
-        return core_schema.general_after_validator_function(
+        return core_schema.with_info_after_validator_function(
             cls._docarray_validate,
             core_schema.str_schema(),
         )
63 changes: 63 additions & 0 deletions tests/integrations/externals/test_fastapi.py
@@ -9,6 +9,7 @@
 from docarray.base_doc import DocArrayResponse
 from docarray.documents import ImageDoc, TextDoc
 from docarray.typing import NdArray
+from docarray.utils._internal.pydantic import is_pydantic_v2


 @pytest.mark.asyncio
@@ -135,3 +136,65 @@ async def func(fastapi_docs: List[ImageDoc]) -> List[ImageDoc]:
     docs = DocList[ImageDoc].from_json(response.content.decode())
     assert len(docs) == 2
     assert docs[0].tensor.shape == (3, 224, 224)
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(
+    not is_pydantic_v2, reason='Behavior is only available for Pydantic V2'
+)
+async def test_doclist_directly():
+    from fastapi import Body
+
+    doc = ImageDoc(tensor=np.zeros((3, 224, 224)))
+    docs = DocList[ImageDoc]([doc, doc])
+
+    app = FastAPI()
+
+    @app.post("/doc/", response_class=DocArrayResponse)
+    async def func_embed_false(
+        fastapi_docs: DocList[ImageDoc] = Body(embed=False),
+    ) -> DocList[ImageDoc]:
+        return fastapi_docs
+
+    @app.post("/doc_default/", response_class=DocArrayResponse)
+    async def func_default(fastapi_docs: DocList[ImageDoc]) -> DocList[ImageDoc]:
+        return fastapi_docs
+
+    @app.post("/doc_embed/", response_class=DocArrayResponse)
+    async def func_embed_true(
+        fastapi_docs: DocList[ImageDoc] = Body(embed=True),
+    ) -> DocList[ImageDoc]:
+        return fastapi_docs
+
+    async with AsyncClient(app=app, base_url="http://test") as ac:
+        response = await ac.post("/doc/", data=docs.to_json())
+        response_default = await ac.post("/doc_default/", data=docs.to_json())
+        response_embed = await ac.post(
+            "/doc_embed/",
+            json={
+                'fastapi_docs': [
+                    {'tensor': doc.tensor.tolist()},
+                    {'tensor': doc.tensor.tolist()},
+                ]
+            },
+        )
+        resp_doc = await ac.get("/docs")
+        resp_redoc = await ac.get("/redoc")
+
+    assert response.status_code == 200
+    assert response_default.status_code == 200
+    assert response_embed.status_code == 200
+    assert resp_doc.status_code == 200
+    assert resp_redoc.status_code == 200
+
+    docs = DocList[ImageDoc].from_json(response.content.decode())
+    assert len(docs) == 2
+    assert docs[0].tensor.shape == (3, 224, 224)
+
+    docs_default = DocList[ImageDoc].from_json(response_default.content.decode())
+    assert len(docs_default) == 2
+    assert docs_default[0].tensor.shape == (3, 224, 224)
+
+    docs_embed = DocList[ImageDoc].from_json(response_embed.content.decode())
+    assert len(docs_embed) == 2
+    assert docs_embed[0].tensor.shape == (3, 224, 224)
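For the endpoints in this test, the two body styles differ only in nesting: with embed=False (and the default), the request body is the DocList JSON itself, e.g. [{"tensor": [...]}, {"tensor": [...]}]; with Body(embed=True), FastAPI expects the list under the parameter name, e.g. {"fastapi_docs": [{"tensor": [...]}, {"tensor": [...]}]} (tensor values elided for brevity).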
2 changes: 2 additions & 0 deletions tests/units/array/test_array.py
@@ -486,6 +486,8 @@ def test_validate_list_dict():
         dict(url=f'http://url.com/foo_{i}.png', tensor=NdArray(i)) for i in [2, 0, 1]
     ]

+    # docs = DocList[Image]([Image(url=image['url'], tensor=image['tensor']) for image in images])
+
     docs = parse_obj_as(DocList[Image], images)

     assert docs.url == [
20 changes: 20 additions & 0 deletions tests/units/array/test_doclist_schema.py
@@ -0,0 +1,20 @@
+import pytest
+from docarray import BaseDoc, DocList
+from docarray.utils._internal.pydantic import is_pydantic_v2
+
+
+@pytest.mark.skipif(not is_pydantic_v2, reason='Feature only available for Pydantic V2')
+def test_schema_nested():
+    # check issue https://github.com/docarray/docarray/issues/1521
+
+    class Doc1Test(BaseDoc):
+        aux: str
+
+    class DocDocTest(BaseDoc):
+        docs: DocList[Doc1Test]
+
+    assert 'Doc1Test' in DocDocTest.schema()['$defs']
+    d = DocDocTest(docs=DocList[Doc1Test]([Doc1Test(aux='aux')]))
+
+    assert type(d.docs) == DocList[Doc1Test]
+    assert d.docs.aux == ['aux']
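For reference, a hedged sketch of the abbreviated JSON schema the test asserts on (exact keys can vary with the pydantic version):

# DocDocTest.schema() -> roughly:
# {
#     '$defs': {
#         'Doc1Test': {'properties': {'id': {...}, 'aux': {...}}, ...}
#     },
#     'properties': {
#         'docs': {'items': {'$ref': '#/$defs/Doc1Test'}, ...},
#         ...
#     },
#     ...
# }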