Skip to content

Commit

Permalink
fix: recoginze optional link field's type annotation in the form of u…
Browse files Browse the repository at this point in the history
…nion type expression with bitwise or operator (#952)

Co-authored-by: IterableTrucks <IterableTrucks@localhost>
  • Loading branch information
IterableTrucks and IterableTrucks committed Jun 18, 2024
1 parent dcd73a7 commit a37021f
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 5 deletions.
23 changes: 18 additions & 5 deletions beanie/odm/utils/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
else:
from typing_extensions import get_args, get_origin

if sys.version_info >= (3, 10):
from types import UnionType as TypesUnionType
else:
TypesUnionType = ()

import importlib
import inspect
from typing import ( # type: ignore
Expand Down Expand Up @@ -262,13 +267,21 @@ def detect_link(
)

# Check if annotation is Optional[custom class] or Optional[List[custom class]]
elif origin is Union and len(args) == 2 and args[1] is type(None):
optional_origin = get_origin(args[0])
optional_args = get_args(args[0])
elif (
(origin is Union or origin is TypesUnionType)
and len(args) == 2
and type(None) in args
):
if args[1] is type(None):
optional = args[0]
else:
optional = args[1]
optional_origin = get_origin(optional)
optional_args = get_args(optional)

if (
isinstance(args[0], _GenericAlias)
and args[0].__origin__ is cls
isinstance(optional, _GenericAlias)
and optional.__origin__ is cls
):
if cls is Link:
return LinkInfo(
Expand Down
19 changes: 19 additions & 0 deletions tests/odm/documents/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
DocumentTestModelWithSimpleIndex,
DocumentWithCustomInit,
DocumentWithIndexMerging2,
DocumentWithLink,
DocumentWithListLink,
DocumentWithUnionTypeExpressionOptionalBackLink,
)


Expand Down Expand Up @@ -335,3 +338,19 @@ class Settings:
)

await db.drop_collection("sample")


async def test_init_document_with_union_type_expression_optional_back_link(db):
await init_beanie(
database=db,
document_models=[
DocumentWithUnionTypeExpressionOptionalBackLink,
DocumentWithListLink,
DocumentWithLink,
],
)

assert DocumentWithUnionTypeExpressionOptionalBackLink.get_link_fields().keys() == {
"back_link_list",
"back_link",
}
29 changes: 29 additions & 0 deletions tests/odm/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import sys
from enum import Enum
from ipaddress import (
IPv4Address,
Expand Down Expand Up @@ -60,6 +61,16 @@
if IS_PYDANTIC_V2:
from pydantic import RootModel, validate_call

if sys.version_info >= (3, 10):

def type_union(A, B):
return A | B

else:

def type_union(A, B):
return Union[A, B]


class Color:
def __init__(self, value):
Expand Down Expand Up @@ -952,6 +963,24 @@ class DocumentWithOptionalListBackLink(Document):
i: int = 1


class DocumentWithUnionTypeExpressionOptionalBackLink(Document):
if IS_PYDANTIC_V2:
back_link_list: type_union(
List[BackLink[DocumentWithListLink]], None
) = Field(json_schema_extra={"original_field": "link"})
back_link: type_union(BackLink[DocumentWithLink], None) = Field(
json_schema_extra={"original_field": "link"}
)
else:
back_link_list: type_union(
List[BackLink[DocumentWithListLink]], None
) = Field(original_field="link")
back_link: type_union(BackLink[DocumentWithLink], None) = Field(
original_field="link"
)
i: int = 1


class DocumentToBeLinked(Document):
s: str = "TEST"

Expand Down

0 comments on commit a37021f

Please sign in to comment.