Skip to content

Commit

Permalink
fix: DTO factory narrowed with a generic alias.
Browse files Browse the repository at this point in the history
This PR is an attempt at handling DTOs that are narrowed with a `_GenericAlias` of a type supported by the DTO factory type.

Closes #2500
  • Loading branch information
peterschutt committed Nov 28, 2023
1 parent 84710a1 commit b69ae5f
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 14 deletions.
3 changes: 2 additions & 1 deletion litestar/dto/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)

from msgspec import UNSET, Struct, UnsetType, convert, defstruct, field
from typing_extensions import get_origin

from litestar.dto._types import (
CollectionType,
Expand Down Expand Up @@ -110,7 +111,7 @@ def __init__(
rename_fields=self.dto_factory.config.rename_fields,
)
self.transfer_model_type = self.create_transfer_model_type(
model_name=model_type.__name__, field_definitions=self.parsed_field_definitions
model_name=(get_origin(model_type) or model_type).__name__, field_definitions=self.parsed_field_definitions
)
self.dto_data_type: type[DTOData] | None = None

Expand Down
7 changes: 5 additions & 2 deletions litestar/dto/base_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from inspect import getmodule
from typing import TYPE_CHECKING, Collection, Generic, TypeVar

from typing_extensions import NotRequired, TypedDict, get_type_hints
from typing_extensions import NotRequired, TypedDict

from litestar.dto._backend import DTOBackend
from litestar.dto._codegen_backend import DTOCodegenBackend
Expand All @@ -17,6 +17,7 @@
from litestar.types.builtin_types import NoneType
from litestar.types.composite_types import TypeEncodersMap
from litestar.typing import FieldDefinition
from litestar.utils.typing import get_type_hints_with_generics_resolved

if TYPE_CHECKING:
from typing import Any, ClassVar, Generator
Expand Down Expand Up @@ -267,7 +268,9 @@ def get_model_type_hints(

return {
k: FieldDefinition.from_kwarg(annotation=v, name=k)
for k, v in get_type_hints(model_type, localns=namespace, include_extras=True).items()
for k, v in get_type_hints_with_generics_resolved(
model_type, localns=namespace, include_extras=True
).items()
}

@staticmethod
Expand Down
7 changes: 5 additions & 2 deletions litestar/dto/dataclass_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import MISSING, fields, replace
from typing import TYPE_CHECKING, Generic, TypeVar

from typing_extensions import get_origin

from litestar.dto.base_dto import AbstractDTO
from litestar.dto.data_structures import DTOFieldDefinition
from litestar.dto.field import DTO_FIELD_META_KEY, DTOField
Expand All @@ -29,7 +31,8 @@ class DataclassDTO(AbstractDTO[T], Generic[T]):
def generate_field_definitions(
cls, model_type: type[DataclassProtocol]
) -> Generator[DTOFieldDefinition, None, None]:
dc_fields = {f.name: f for f in fields(model_type)}
model_origin = get_origin(model_type) or model_type
dc_fields = {f.name: f for f in fields(model_origin)}
for key, field_definition in cls.get_model_type_hints(model_type).items():
if not (dc_field := dc_fields.get(key)):
continue
Expand All @@ -41,7 +44,7 @@ def generate_field_definitions(
field_definition=field_definition,
default_factory=default_factory,
dto_field=dc_field.metadata.get(DTO_FIELD_META_KEY, DTOField()),
model_name=model_type.__name__,
model_name=model_origin.__name__,
),
name=key,
default=default,
Expand Down
25 changes: 24 additions & 1 deletion litestar/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,21 @@
from copy import deepcopy
from dataclasses import dataclass, is_dataclass, replace
from inspect import Parameter, Signature
from typing import Any, AnyStr, Callable, Collection, ForwardRef, Literal, Mapping, Protocol, Sequence, TypeVar, cast
from typing import ( # type: ignore[attr-defined]
Any,
AnyStr,
Callable,
ClassVar,
Collection,
ForwardRef,
Literal,
Mapping,
Protocol,
Sequence,
TypeVar,
_GenericAlias, # pyright: ignore
cast,
)

from msgspec import UnsetType
from typing_extensions import NotRequired, Required, Self, get_args, get_origin, get_type_hints, is_typeddict
Expand Down Expand Up @@ -442,6 +456,15 @@ def is_subclass_of(self, cl: type[Any] | tuple[type[Any], ...]) -> bool:
if self.origin in UnionTypes:
return all(t.is_subclass_of(cl) for t in self.inner_types)

if isinstance(self.annotation, _GenericAlias) and self.origin not in (ClassVar, Literal):
cl_args = get_args(cl)
cl_origin = get_origin(cl) or cl
return (
(len(cl_args) == len(self.args) if cl_args else True)
and issubclass(self.origin, cl_origin)
and all(t.is_subclass_of(cl_arg) for t, cl_arg in zip(self.inner_types, cl_args))
)

return self.origin not in UnionTypes and is_class_and_subclass(self.origin, cl)

if self.annotation is AnyStr:
Expand Down
4 changes: 3 additions & 1 deletion litestar/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,9 @@ def get_type_hints_with_generics_resolved(
if origin is None:
# Implies the generic types have not been specified in the annotation
type_hints = get_type_hints(annotation, globalns=globalns, localns=localns, include_extras=include_extras)
typevar_map = {p: p for p in annotation.__parameters__}
if not (parameters := getattr(annotation, "__parameters__", None)):
return type_hints
typevar_map = {p: p for p in parameters}
else:
type_hints = get_type_hints(origin, globalns=globalns, localns=localns, include_extras=include_extras)
# the __parameters__ is only available on the origin itself and not the annotation
Expand Down
37 changes: 32 additions & 5 deletions tests/unit/test_dto/test_factory/test_base_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Tuple, TypeVar, Union
from typing import TYPE_CHECKING, Generic, Tuple, TypeVar, Union

import pytest
from typing_extensions import Annotated

from litestar import Request
from litestar.dto import DataclassDTO, DTOConfig
from litestar.exceptions.dto_exceptions import InvalidAnnotationException
from litestar.types.empty import Empty
from litestar.typing import FieldDefinition

from . import Model
Expand All @@ -19,7 +20,8 @@

from litestar.dto._backend import DTOBackend

T = TypeVar("T", bound=Model)
T = TypeVar("T")
ModelT = TypeVar("ModelT", bound=Model)


def get_backend(dto_type: type[DataclassDTO[Any]]) -> DTOBackend:
Expand Down Expand Up @@ -77,7 +79,7 @@ def test_extra_annotated_metadata_ignored() -> None:

def test_overwrite_config() -> None:
first = DTOConfig(exclude={"a"})
generic_dto = DataclassDTO[Annotated[T, first]] # pyright: ignore
generic_dto = DataclassDTO[Annotated[ModelT, first]] # pyright: ignore
second = DTOConfig(exclude={"b"})
dto = generic_dto[Annotated[Model, second]] # pyright: ignore
assert dto.config is second
Expand All @@ -86,13 +88,13 @@ def test_overwrite_config() -> None:
def test_existing_config_not_overwritten() -> None:
assert getattr(DataclassDTO, "_config", None) is None
first = DTOConfig(exclude={"a"})
generic_dto = DataclassDTO[Annotated[T, first]] # pyright: ignore
generic_dto = DataclassDTO[Annotated[ModelT, first]] # pyright: ignore
dto = generic_dto[Model] # pyright: ignore
assert dto.config is first


def test_config_assigned_via_subclassing() -> None:
class CustomGenericDTO(DataclassDTO[T]):
class CustomGenericDTO(DataclassDTO[ModelT]):
config = DTOConfig(exclude={"a"})

concrete_dto = CustomGenericDTO[Model]
Expand Down Expand Up @@ -161,3 +163,28 @@ class SubType(Model):
assert (
dto_type._dto_backends["handler_id"]["data_backend"].parsed_field_definitions[-1].name == "c" # pyright: ignore
)


def test_type_narrowing_with_generic_type() -> None:
@dataclass
class Foo(Generic[T]):
foo: T

hints = DataclassDTO.get_model_type_hints(Foo[int])
assert hints == {
"foo": FieldDefinition(
raw=int,
annotation=int,
type_wrappers=(),
origin=None,
args=(),
metadata=(),
instantiable_origin=None,
safe_generic_origin=None,
inner_types=(),
default=Empty,
extra={},
kwarg_definition=None,
name="foo",
)
}
27 changes: 25 additions & 2 deletions tests/unit/test_dto/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
from __future__ import annotations

from typing import Dict
from dataclasses import dataclass
from typing import Dict, Generic, TypeVar
from unittest.mock import MagicMock

import pytest

from litestar import Controller, Litestar, Router, post
from litestar.config.app import ExperimentalFeatures
from litestar.dto import AbstractDTO, DTOConfig
from litestar.dto import AbstractDTO, DataclassDTO, DTOConfig, DTOData
from litestar.dto._backend import DTOBackend
from litestar.dto._codegen_backend import DTOCodegenBackend
from litestar.testing import create_test_client

from . import Model

T = TypeVar("T")


@pytest.fixture()
def experimental_features(use_experimental_dto_backend: bool) -> list[ExperimentalFeatures] | None:
Expand Down Expand Up @@ -153,3 +156,23 @@ def handler(data: Model) -> Model:

backend = handler.resolve_data_dto()._dto_backends[handler.handler_id]["data_backend"] # type: ignore[union-attr]
assert isinstance(backend, DTOBackend)


def test_dto_for_generic_model() -> None:
@dataclass
class Foo(Generic[T]):
foo: T

FooDTO = DataclassDTO[Foo[int]]

@post("/foo", dto=FooDTO, signature_types=[Foo])
async def foo_handler(data: DTOData[Foo[int]]) -> Foo[int]:
return data.create_instance()

with create_test_client(route_handlers=foo_handler) as client:
response = client.post("/foo", json={"foo": 1})
assert response.status_code == 201
assert response.json() == {"foo": 1}
response = client.post("/foo", json={"foo": "1"})
assert response.status_code == 400
assert response.json() == {"status_code": 400, "detail": "Expected `int`, got `str` - at `$.foo`"}

0 comments on commit b69ae5f

Please sign in to comment.