Skip to content

Commit

Permalink
fix: support generic dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
PJCampi committed Apr 18, 2024
1 parent 8512afc commit 240a956
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 2 deletions.
9 changes: 7 additions & 2 deletions dataclasses_json/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
_get_type_arg_param,
_get_type_args, _is_counter,
_NO_ARGS,
_issubclass_safe, _is_tuple)
_issubclass_safe, _is_tuple,
_is_generic_dataclass)

Json = Union[dict, list, str, int, float, bool, None]

Expand Down Expand Up @@ -259,8 +260,9 @@ def _is_supported_generic(type_):
return False
not_str = not _issubclass_safe(type_, str)
is_enum = _issubclass_safe(type_, Enum)
is_generic_dataclass = _is_generic_dataclass(type_)
return (not_str and _is_collection(type_)) or _is_optional(
type_) or is_union_type(type_) or is_enum
type_) or is_union_type(type_) or is_enum or is_generic_dataclass


def _decode_generic(type_, value, infer_missing):
Expand Down Expand Up @@ -298,6 +300,9 @@ def _decode_generic(type_, value, infer_missing):
except (TypeError, AttributeError):
pass
res = materialize_type(xs)
elif _is_generic_dataclass(type_):
origin = _get_type_origin(type_)
res = _decode_dataclass(origin, value, infer_missing)
else: # Optional or Union
_args = _get_type_args(type_)
if _args is _NO_ARGS:
Expand Down
5 changes: 5 additions & 0 deletions dataclasses_json/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
from datetime import datetime, timezone
from collections import Counter
from dataclasses import is_dataclass # type: ignore
from typing import (Collection, Mapping, Optional, TypeVar, Any, Type, Tuple,
Union, cast)

Expand Down Expand Up @@ -164,6 +165,10 @@ def _is_nonstr_collection(type_):
and not _issubclass_safe(type_, str))


def _is_generic_dataclass(type_):
return is_dataclass(_get_type_origin(type_))


def _timestamp_to_dt_aware(timestamp: float):
tz = datetime.now(timezone.utc).astimezone().tzinfo
dt = datetime.fromtimestamp(timestamp, tz=tz)
Expand Down
23 changes: 23 additions & 0 deletions tests/test_generic_dataclass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from dataclasses import dataclass
from typing import Generic, TypeVar

from dataclasses_json import dataclass_json

T = TypeVar("T")


@dataclass_json
@dataclass
class NestedClass(Generic[T]):
value: T


@dataclass_json
@dataclass
class MyClass(Generic[T]):
nested: NestedClass[T]


def test_dataclass_with_generic_dataclass_field():
a = MyClass(nested=NestedClass(value="value"))
assert MyClass.from_json(a.to_json()) == a

0 comments on commit 240a956

Please sign in to comment.