
fix: fix DocList schema when using Pydantic V2 #1876

Merged · 31 commits · Mar 21, 2025

Commits

a864589  fix: try to fix doclist schema (Mar 8, 2024)
951679c  chore: push tmp changes (Mar 8, 2024)
febea8d  fix: make DocList properly a Generic (Mar 11, 2024)
11eeb6a  fix: make DocList properly a Generic (Mar 11, 2024)
3f4707f  Merge branch 'try-fix-doclist-schema' of https://github.com/docarray/… (Mar 11, 2024)
a07919d  fix: undo some changes (Mar 11, 2024)
5ec78bd  test: test fixes (Mar 12, 2024)
50e191a  test: set test (Mar 13, 2024)
949f185  fix: full test for fastapi (Mar 14, 2024)
d626453  fix: try to make generic (Mar 14, 2024)
f33c55f  Merge branch 'main' into try-fix-doclist-schema (Feb 25, 2025)
4180a4e  test: fix some tests (Feb 25, 2025)
d9782a4  Merge branch 'try-fix-doclist-schema' of https://github.com/docarray/… (Feb 25, 2025)
fb13c65  fix: small tests (Feb 25, 2025)
bfee17a  test: fix test (Feb 25, 2025)
2f80c56  fix: small test fix (Feb 25, 2025)
5ff2f68  test: change tests (Mar 12, 2025)
00245aa  fix: try to fix all pydantic-v1 tests (Mar 13, 2025)
40f9420  fix: fix small dynamic creation (Mar 13, 2025)
cb9cc94  test: further fix tests (Mar 13, 2025)
4f42c0b  test: new iteration (Mar 13, 2025)
cfa7ea5  test: more tests (Mar 14, 2025)
659b992  fix: tests (Mar 14, 2025)
9c94394  test: fix schemas from new model (Mar 17, 2025)
f422e88  fix: improve cleaning refs (Mar 18, 2025)
c7e9bf6  fix: get from definitions (Mar 18, 2025)
5d3e73f  fix: remove unneeded argument (Mar 18, 2025)
e12c03a  fix: fix update (Mar 18, 2025)
a6fef44  fix: handle ID optional (Mar 18, 2025)
559c0ee  fix: remove problematic action (Mar 18, 2025)
8265d95  fix: fix resp as json (Mar 19, 2025)

18 changes: 7 additions & 11 deletions .github/workflows/cd.yml
@@ -21,7 +21,7 @@ jobs:
- name: Pre-release (.devN)
run: |
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
pip install poetry
pip install poetry==1.7.1
./scripts/release.sh
env:
PYPI_USERNAME: ${{ secrets.TWINE_USERNAME }}
@@ -35,20 +35,16 @@ jobs:
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0

- name: Get changed files
id: changed-files-specific
uses: tj-actions/changed-files@v41
with:
files: |
README.md
fetch-depth: 2

- name: Check if README is modified
id: step_output
if: steps.changed-files-specific.outputs.any_changed == 'true'
run: |
echo "readme_changed=true" >> $GITHUB_OUTPUT
if git diff --name-only HEAD^ HEAD | grep -q "README.md"; then
echo "readme_changed=true" >> $GITHUB_OUTPUT
else
echo "readme_changed=false" >> $GITHUB_OUTPUT
fi

publish-docarray-org:
needs: check-readme-modification

20 changes: 10 additions & 10 deletions .github/workflows/ci.yml
@@ -25,7 +25,7 @@ jobs:
- name: Lint with ruff
run: |
python -m pip install --upgrade pip
python -m pip install poetry
python -m pip install poetry==1.7.1
poetry install

# stop the build if there are Python syntax errors or undefined names
@@ -44,7 +44,7 @@ jobs:
- name: check black
run: |
python -m pip install --upgrade pip
python -m pip install poetry
python -m pip install poetry==1.7.1
poetry install --only dev
poetry run black --check .

@@ -62,7 +62,7 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
python -m pip install poetry
python -m pip install poetry==1.7.1
poetry install --without dev
poetry run pip install tensorflow==2.12.0
poetry run pip install jax
@@ -106,7 +106,7 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
python -m pip install poetry
python -m pip install poetry==1.7.1
poetry install --all-extras
poetry run pip install elasticsearch==8.6.2
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
@@ -156,7 +156,7 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
python -m pip install poetry
python -m pip install poetry==1.7.1
poetry install --all-extras
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
poetry run pip install protobuf==3.20.0 # we check that we support 3.19
@@ -204,7 +204,7 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
python -m pip install poetry
python -m pip install poetry==1.7.1
poetry install --all-extras
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
poetry run pip install protobuf==3.20.0
@@ -253,7 +253,7 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
python -m pip install poetry
python -m pip install poetry==1.7.1
poetry install --all-extras
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
poetry run pip install protobuf==3.20.0
@@ -302,7 +302,7 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
python -m pip install poetry
python -m pip install poetry==1.7.1
poetry install --all-extras
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
poetry run pip install protobuf==3.20.0
@@ -351,7 +351,7 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
python -m pip install poetry
python -m pip install poetry==1.7.1
poetry install --all-extras
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
poetry run pip uninstall -y torch
@@ -398,7 +398,7 @@ jobs:
- name: Prepare environment
run: |
python -m pip install --upgrade pip
python -m pip install poetry
python -m pip install poetry==1.7.1
poetry install --all-extras
poetry run pip uninstall -y torch
poetry run pip install torch

2 changes: 1 addition & 1 deletion .github/workflows/ci_only_pr.yml
@@ -43,7 +43,7 @@ jobs:
run: |
npm i -g netlify-cli
python -m pip install --upgrade pip
python -m pip install poetry
python -m pip install poetry==1.7.1
python -m poetry config virtualenvs.create false && python -m poetry install --no-interaction --no-ansi --all-extras

cd docs

54 changes: 54 additions & 0 deletions docarray/__init__.py
@@ -20,6 +20,60 @@
from docarray.array import DocList, DocVec
from docarray.base_doc.doc import BaseDoc
from docarray.utils._internal.misc import _get_path_from_docarray_root_level
from docarray.utils._internal.pydantic import is_pydantic_v2


def unpickle_doclist(doc_type, b):
return DocList[doc_type].from_bytes(b, protocol="protobuf")


def unpickle_docvec(doc_type, tensor_type, b):
return DocVec[doc_type].from_bytes(b, protocol="protobuf", tensor_type=tensor_type)


if is_pydantic_v2:
# Register the pickle functions
def register_serializers():
import copyreg
from functools import partial

unpickle_doc_fn = partial(BaseDoc.from_bytes, protocol="protobuf")

def pickle_doc(doc):
b = doc.to_bytes(protocol='protobuf')
return unpickle_doc_fn, (doc.__class__, b)

# Register BaseDoc serialization
copyreg.pickle(BaseDoc, pickle_doc)

# For DocList, we need to hook into __reduce__ since it's a generic

def pickle_doclist(doc_list):
b = doc_list.to_bytes(protocol='protobuf')
doc_type = doc_list.doc_type
return unpickle_doclist, (doc_type, b)

# Replace DocList.__reduce__ with a method that returns the correct format
def doclist_reduce(self):
return pickle_doclist(self)

DocList.__reduce__ = doclist_reduce

# For DocVec, we need to hook into __reduce__ since it's a generic

def pickle_docvec(doc_vec):
b = doc_vec.to_bytes(protocol='protobuf')
doc_type = doc_vec.doc_type
tensor_type = doc_vec.tensor_type
return unpickle_docvec, (doc_type, tensor_type, b)

# Replace DocList.__reduce__ with a method that returns the correct format
def docvec_reduce(self):
return pickle_docvec(self)

DocVec.__reduce__ = docvec_reduce

register_serializers()

__all__ = ['BaseDoc', 'DocList', 'DocVec']
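
For context: the block above registers pickle support for BaseDoc, DocList and DocVec under pydantic v2 by serializing to protobuf bytes and rebuilding the parametrized container on load. A minimal sketch of the intended effect (MyDoc is a made-up example class, not part of this PR):

import pickle

from docarray import BaseDoc, DocList


class MyDoc(BaseDoc):
    text: str = ''


docs = DocList[MyDoc]([MyDoc(text='hello'), MyDoc(text='world')])

# the registered __reduce__ serializes to protobuf bytes and rebuilds
# DocList[MyDoc] on unpickling, e.g. when crossing a multiprocessing boundary
restored = pickle.loads(pickle.dumps(docs))
assert restored.doc_type is MyDoc
assert [d.text for d in restored] == ['hello', 'world']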


40 changes: 31 additions & 9 deletions docarray/array/any_array.py
@@ -25,6 +25,7 @@
from docarray.exceptions.exceptions import UnusableObjectError
from docarray.typing.abstract_type import AbstractType
from docarray.utils._internal._typing import change_cls_name, safe_issubclass
from docarray.utils._internal.pydantic import is_pydantic_v2

if TYPE_CHECKING:
from docarray.proto import DocListProto, NodeProto
@@ -73,8 +74,19 @@ def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]):
# Promote to global scope so multiprocessing can pickle it
global _DocArrayTyped

class _DocArrayTyped(cls): # type: ignore
doc_type: Type[BaseDocWithoutId] = cast(Type[BaseDocWithoutId], item)
if not is_pydantic_v2:

class _DocArrayTyped(cls): # type: ignore
doc_type: Type[BaseDocWithoutId] = cast(
Type[BaseDocWithoutId], item
)

else:

class _DocArrayTyped(cls, Generic[T_doc]): # type: ignore
doc_type: Type[BaseDocWithoutId] = cast(
Type[BaseDocWithoutId], item
)

for field in _DocArrayTyped.doc_type._docarray_fields().keys():

@@ -99,14 +111,24 @@ def _setter(self, value):
setattr(_DocArrayTyped, field, _property_generator(field))
# this generates property on the fly based on the schema of the item

# The global scope and qualname need to refer to this class a unique name.
# Otherwise, creating another _DocArrayTyped will overwrite this one.
change_cls_name(
_DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals()
)

cls.__typed_da__[cls][item] = _DocArrayTyped
# # The global scope and qualname need to refer to this class a unique name.
# # Otherwise, creating another _DocArrayTyped will overwrite this one.
if not is_pydantic_v2:
change_cls_name(
_DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals()
)

cls.__typed_da__[cls][item] = _DocArrayTyped
else:
change_cls_name(_DocArrayTyped, f'{cls.__name__}', globals())
if sys.version_info < (3, 12):
cls.__typed_da__[cls][item] = Generic.__class_getitem__.__func__(
_DocArrayTyped, item
) # type: ignore
# this do nothing that checking that item is valid type var or str
# Keep the approach in #1147 to be compatible with lower versions of Python.
else:
cls.__typed_da__[cls][item] = GenericAlias(_DocArrayTyped, item) # type: ignore
return cls.__typed_da__[cls][item]

@overload
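
For readers unfamiliar with the pattern above: under pydantic v2 the parametrized class is routed through the stock generic machinery, Generic.__class_getitem__ on Python < 3.12 and types.GenericAlias on 3.12+, so that DocList[SomeDoc] still behaves like a real parametrized Generic while one dynamically created subclass per item type is kept in a cache. A rough standalone sketch of that pattern, with purely illustrative names that are not docarray API:

import sys
from types import GenericAlias
from typing import Generic, TypeVar

T = TypeVar('T')


class Container(Generic[T]):
    item_type: type = object
    _typed_cache: dict = {}

    def __class_getitem__(cls, item):
        if not isinstance(item, type):
            # TypeVars, forward refs etc. go through the stock Generic machinery
            return super().__class_getitem__(item)
        if item not in cls._typed_cache:
            # one concrete subclass per item type, akin to _DocArrayTyped above
            class _Typed(cls, Generic[T]):  # type: ignore[valid-type, misc]
                pass

            _Typed.item_type = item
            _Typed.__name__ = _Typed.__qualname__ = f'{cls.__name__}[{item.__name__}]'
            if sys.version_info < (3, 12):
                # parametrize through the underlying Generic machinery
                cls._typed_cache[item] = Generic.__class_getitem__.__func__(_Typed, item)
            else:
                cls._typed_cache[item] = GenericAlias(_Typed, item)
        return cls._typed_cache[item]


IntContainer = Container[int]
assert Container[int] is IntContainer  # cached: repeated access returns the same object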

25 changes: 17 additions & 8 deletions docarray/array/doc_list/doc_list.py
@@ -12,6 +12,7 @@
Union,
cast,
overload,
Callable,
)

from pydantic import parse_obj_as
@@ -28,7 +29,6 @@
from docarray.utils._internal.pydantic import is_pydantic_v2

if is_pydantic_v2:
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema

from docarray.utils._internal._typing import safe_issubclass
@@ -45,10 +45,7 @@


class DocList(
ListAdvancedIndexing[T_doc],
PushPullMixin,
IOMixinDocList,
AnyDocArray[T_doc],
ListAdvancedIndexing[T_doc], PushPullMixin, IOMixinDocList, AnyDocArray[T_doc]
):
"""
DocList is a container of Documents.
@@ -357,8 +354,20 @@ def __repr__(self):

@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
cls, source: Any, handler: Callable[[Any], core_schema.CoreSchema]
) -> core_schema.CoreSchema:
return core_schema.general_plain_validator_function(
cls.validate,
instance_schema = core_schema.is_instance_schema(cls)
args = getattr(source, '__args__', None)
if args:
sequence_t_schema = handler(Sequence[args[0]])
else:
sequence_t_schema = handler(Sequence)

def validate_fn(v, info):
# input has already been validated
return cls(v, validate_input_docs=False)

non_instance_schema = core_schema.with_info_after_validator_function(
validate_fn, sequence_t_schema
)
return core_schema.union_schema([instance_schema, non_instance_schema])
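
A hedged sketch of what this schema is meant to allow under pydantic v2 (MyDoc and Library are made-up names, not part of the PR): DocList[MyDoc] can be used directly as a model field, accepting either an existing DocList (the is_instance branch) or a plain sequence that is validated item by item (the Sequence branch), and JSON schema generation for such models works again:

from pydantic import BaseModel

from docarray import BaseDoc, DocList


class MyDoc(BaseDoc):
    text: str = ''


class Library(BaseModel):
    docs: DocList[MyDoc]


# sequence branch of the union: items are validated into MyDoc first, then
# wrapped without re-validating (validate_input_docs=False above)
lib = Library(docs=[{'text': 'a'}, {'text': 'b'}])
assert isinstance(lib.docs, DocList)

# instance branch: an existing DocList[MyDoc] passes straight through
lib2 = Library(docs=lib.docs)

# generating the JSON schema for a DocList field is the case this PR targets
schema = Library.model_json_schema()
assert 'docs' in schema['properties']
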
1 change: 0 additions & 1 deletion docarray/array/doc_list/io.py
@@ -256,7 +256,6 @@ def to_bytes(
:param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf`
:return: the binary serialization in bytes or None if file_ctx is passed where to store
"""

with file_ctx or io.BytesIO() as bf:
self._write_bytes(
bf=bf,

6 changes: 4 additions & 2 deletions docarray/array/doc_vec/doc_vec.py
@@ -198,7 +198,7 @@ def _check_doc_field_not_none(field_name, doc):
if safe_issubclass(tensor.__class__, tensor_type):
field_type = tensor_type

if isinstance(field_type, type):
if isinstance(field_type, type) or safe_issubclass(field_type, AnyDocArray):
if tf_available and safe_issubclass(field_type, TensorFlowTensor):
# tf.Tensor does not allow item assignment, therefore the
# optimized way
@@ -335,7 +335,9 @@ def _docarray_validate(
return cast(T, value.to_doc_vec())
else:
raise ValueError(f'DocVec[value.doc_type] is not compatible with {cls}')
elif isinstance(value, DocList.__class_getitem__(cls.doc_type)):
elif not is_pydantic_v2 and isinstance(
value, DocList.__class_getitem__(cls.doc_type)
):
return cast(T, value.to_doc_vec())
elif isinstance(value, Sequence):
return cls(value)
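
To illustrate the case the added safe_issubclass(field_type, AnyDocArray) check above covers (Page and Book are hypothetical classes): under pydantic v2 a field annotated as DocList[...] is a generic alias rather than a plain class, so isinstance(field_type, type) alone would skip it while DocVec builds its columns. A rough, hedged sketch:

import numpy as np

from docarray import BaseDoc, DocList, DocVec
from docarray.typing import NdArray


class Page(BaseDoc):
    embedding: NdArray


class Book(BaseDoc):
    pages: DocList[Page]


# a DocVec over docs with a nested DocList field exercises that branch
books = DocVec[Book](
    [Book(pages=DocList[Page]([Page(embedding=np.zeros(3))])) for _ in range(2)]
)
assert len(books) == 2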

10 changes: 7 additions & 3 deletions docarray/base_doc/doc.py
@@ -326,8 +326,13 @@ def _exclude_doclist(
from docarray.array.any_array import AnyDocArray

type_ = self._get_field_annotation(field)
if isinstance(type_, type) and safe_issubclass(type_, AnyDocArray):
doclist_exclude_fields.append(field)
if is_pydantic_v2:
# Conservative when touching pydantic v1 logic
if safe_issubclass(type_, AnyDocArray):
doclist_exclude_fields.append(field)
else:
if isinstance(type_, type) and safe_issubclass(type_, AnyDocArray):
doclist_exclude_fields.append(field)

original_exclude = exclude
if exclude is None:
@@ -480,7 +485,6 @@ def model_dump( # type: ignore
warnings: bool = True,
) -> Dict[str, Any]:
def _model_dump(doc):

(
exclude_,
original_exclude,

4 changes: 1 addition & 3 deletions docarray/base_doc/mixins/update.py
@@ -110,9 +110,7 @@ def _group_fields(doc: 'UpdateMixin') -> _FieldGroups:
if field_name not in FORBIDDEN_FIELDS_TO_UPDATE:
field_type = doc._get_field_annotation(field_name)

if isinstance(field_type, type) and safe_issubclass(
field_type, DocList
):
if safe_issubclass(field_type, DocList):
nested_docarray_fields.append(field_name)
else:
origin = get_origin(field_type)

8 changes: 4 additions & 4 deletions docarray/index/backends/elastic.py
@@ -352,12 +352,12 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
dict: 'object',
}

for type in elastic_py_types.keys():
if safe_issubclass(python_type, type):
for t in elastic_py_types.keys():
if safe_issubclass(python_type, t):
self._logger.info(
f'Mapped Python type {python_type} to database type "{elastic_py_types[type]}"'
f'Mapped Python type {python_type} to database type "{elastic_py_types[t]}"'
)
return elastic_py_types[type]
return elastic_py_types[t]

err_msg = f'Unsupported column type for {type(self)}: {python_type}'
self._logger.error(err_msg)