Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: support generic dataclass #525

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions dataclasses_json/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
_get_type_arg_param,
_get_type_args, _is_counter,
_NO_ARGS,
_issubclass_safe, _is_tuple)
_issubclass_safe, _is_tuple,
_is_generic_dataclass)

Json = Union[dict, list, str, int, float, bool, None]

Expand Down Expand Up @@ -269,8 +270,9 @@ def _is_supported_generic(type_):
return False
not_str = not _issubclass_safe(type_, str)
is_enum = _issubclass_safe(type_, Enum)
is_generic_dataclass = _is_generic_dataclass(type_)
return (not_str and _is_collection(type_)) or _is_optional(
type_) or is_union_type(type_) or is_enum
type_) or is_union_type(type_) or is_enum or is_generic_dataclass


def _decode_generic(type_, value, infer_missing):
Expand Down Expand Up @@ -308,6 +310,9 @@ def _decode_generic(type_, value, infer_missing):
except (TypeError, AttributeError):
pass
res = materialize_type(xs)
elif _is_generic_dataclass(type_):
origin = _get_type_origin(type_)
res = _decode_dataclass(origin, value, infer_missing)
else: # Optional or Union
_args = _get_type_args(type_)
if _args is _NO_ARGS:
Expand Down
5 changes: 5 additions & 0 deletions dataclasses_json/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
from datetime import datetime, timezone
from collections import Counter
from dataclasses import is_dataclass # type: ignore
from typing import (Collection, Mapping, Optional, TypeVar, Any, Type, Tuple,
Union, cast)

Expand Down Expand Up @@ -164,6 +165,10 @@ def _is_nonstr_collection(type_):
and not _issubclass_safe(type_, str))


def _is_generic_dataclass(type_):
return is_dataclass(_get_type_origin(type_))


def _timestamp_to_dt_aware(timestamp: float):
tz = datetime.now(timezone.utc).astimezone().tzinfo
dt = datetime.fromtimestamp(timestamp, tz=tz)
Expand Down
83 changes: 83 additions & 0 deletions tests/test_generic_dataclass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Generic, TypeVar

import pytest

from dataclasses_json import dataclass_json

S = TypeVar("S")
T = TypeVar("T")


@dataclass_json
@dataclass
class Bar:
value: int


@dataclass_json
@dataclass
class Foo(Generic[T]):
bar: T


@dataclass_json
@dataclass
class Baz(Generic[T]):
foo: Foo[T]


@pytest.mark.parametrize(
"instance_of_t, decodes_successfully",
[
pytest.param(1, True, id="literal"),
pytest.param([1], True, id="literal_list"),
pytest.param({"a": 1}, True, id="map_of_literal"),
pytest.param(datetime(2021, 1, 1), False, id="extended_type"),
pytest.param(Bar(1), False, id="object"),
]
)
def test_dataclass_with_generic_dataclass_field(instance_of_t, decodes_successfully):
foo = Foo(bar=instance_of_t)
baz = Baz(foo=foo)
decoded = Baz[type(instance_of_t)].from_json(baz.to_json())
assert decoded.foo == Foo.from_json(foo.to_json())
if decodes_successfully:
assert decoded == baz
else:
assert decoded != baz


@dataclass_json
@dataclass
class Foo2(Generic[T, S]):
bar1: T
bar2: S


@dataclass_json
@dataclass
class Baz2(Generic[T, S]):
foo2: Foo2[T, S]


@pytest.mark.parametrize(
"instance_of_t, decodes_successfully",
[
pytest.param(1, True, id="literal"),
pytest.param([1], True, id="literal_list"),
pytest.param({"a": 1}, True, id="map_of_literal"),
pytest.param(datetime(2021, 1, 1), False, id="extended_type"),
pytest.param(Bar(1), False, id="object"),
]
)
def test_dataclass_with_multiple_generic_dataclass_fields(instance_of_t, decodes_successfully):
foo2 = Foo2(bar1=instance_of_t, bar2=instance_of_t)
baz = Baz2(foo2=foo2)
decoded = Baz2[type(instance_of_t), type(instance_of_t)].from_json(baz.to_json())
assert decoded.foo2 == Foo2.from_json(foo2.to_json())
if decodes_successfully:
assert decoded == baz
else:
assert decoded != baz
Loading