diff --git a/dataclasses_json/core.py b/dataclasses_json/core.py index c7b696e7..7901e80d 100644 --- a/dataclasses_json/core.py +++ b/dataclasses_json/core.py @@ -21,7 +21,7 @@ _is_collection, _is_mapping, _is_new_type, _is_optional, _isinstance_safe, _get_type_arg_param, - _get_type_args, + _get_type_args, _is_counter, _NO_ARGS, _issubclass_safe, _is_tuple) @@ -271,7 +271,7 @@ def _decode_generic(type_, value, infer_missing): res = type_(value) # FIXME this is a hack to fix a deeper underlying issue. A refactor is due. elif _is_collection(type_): - if _is_mapping(type_): + if _is_mapping(type_) and not _is_counter(type_): k_type, v_type = _get_type_args(type_, (Any, Any)) # a mapping type has `.keys()` and `.values()` # (see collections.abc) @@ -284,6 +284,8 @@ def _decode_generic(type_, value, infer_missing): xs = _decode_items(types[0], value, infer_missing) else: xs = _decode_items(_get_type_args(type_) or _NO_ARGS, value, infer_missing) + elif _is_counter(type_): + xs = dict(zip(_decode_items(_get_type_arg_param(type_, 0), value.keys(), infer_missing), value.values())) else: xs = _decode_items(_get_type_arg_param(type_, 0), value, infer_missing) diff --git a/dataclasses_json/utils.py b/dataclasses_json/utils.py index 0927cd01..f63c142a 100644 --- a/dataclasses_json/utils.py +++ b/dataclasses_json/utils.py @@ -2,7 +2,7 @@ import sys from datetime import datetime, timezone from typing import (Collection, Mapping, Optional, TypeVar, Any, Type, Tuple, - Union, cast) + Union, cast, Counter) def _get_type_cons(type_): @@ -142,6 +142,10 @@ def _is_optional(type_): type_ is Any) +def _is_counter(type_): + return _issubclass_safe(_get_type_origin(type_), Counter) + + def _is_mapping(type_): return _issubclass_safe(_get_type_origin(type_), Mapping) diff --git a/tests/entities.py b/tests/entities.py index 480c8a07..dd90b8fd 100644 --- a/tests/entities.py +++ b/tests/entities.py @@ -3,6 +3,7 @@ from datetime import datetime from decimal import Decimal from typing import (Collection, + Counter, Deque, Dict, FrozenSet, @@ -370,3 +371,9 @@ class DataClassWithNestedOptional: @dataclass class DataClassWithNestedDictWithTupleKeys: a: Dict[Tuple[int], int] + + +@dataclass_json +@dataclass +class DataClassWithCounter: + c: Counter[str] diff --git a/tests/test_collections.py b/tests/test_collections.py index e8c54333..9c872dda 100644 --- a/tests/test_collections.py +++ b/tests/test_collections.py @@ -1,4 +1,4 @@ -from collections import deque +from collections import Counter, deque from tests.entities import (DataClassIntImmutableDefault, DataClassMutableDefaultDict, @@ -19,7 +19,8 @@ DataClassWithFrozenSetUnbound, DataClassWithDequeCollections, DataClassWithTuple, DataClassWithTupleUnbound, - DataClassWithUnionIntNone, MyCollection) + DataClassWithUnionIntNone, MyCollection, + DataClassWithCounter) class TestEncoder: @@ -112,6 +113,10 @@ def test_mutable_default_list(self): def test_mutable_default_dict(self): assert DataClassMutableDefaultDict().to_json() == '{"xs": {}}' + def test_counter(self): + assert DataClassWithCounter( + c=Counter('foo')).to_json() == '{"c": {"f": 1, "o": 2}}' + class TestDecoder: def test_list(self): @@ -235,3 +240,7 @@ def test_mutable_default_dict(self): == DataClassMutableDefaultDict()) assert (DataClassMutableDefaultDict.from_json('{}', infer_missing=True) == DataClassMutableDefaultDict()) + + def test_counter(self): + assert DataClassWithCounter.from_json('{"c": {"f": 1, "o": 2}}') == \ + DataClassWithCounter(c=Counter('foo'))