Skip to content

Commit

Permalink
Support for Counter objects
Browse files Browse the repository at this point in the history
  • Loading branch information
matt035343 committed Aug 12, 2023
1 parent 04ddea8 commit 12b7f4d
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 5 deletions.
6 changes: 4 additions & 2 deletions dataclasses_json/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
_is_collection, _is_mapping, _is_new_type,
_is_optional, _isinstance_safe,
_get_type_arg_param,
_get_type_args,
_get_type_args, _is_counter,
_NO_ARGS,
_issubclass_safe, _is_tuple)

Expand Down Expand Up @@ -271,7 +271,7 @@ def _decode_generic(type_, value, infer_missing):
res = type_(value)
# FIXME this is a hack to fix a deeper underlying issue. A refactor is due.
elif _is_collection(type_):
if _is_mapping(type_):
if _is_mapping(type_) and not _is_counter(type_):
k_type, v_type = _get_type_args(type_, (Any, Any))
# a mapping type has `.keys()` and `.values()`
# (see collections.abc)
Expand All @@ -284,6 +284,8 @@ def _decode_generic(type_, value, infer_missing):
xs = _decode_items(types[0], value, infer_missing)
else:
xs = _decode_items(_get_type_args(type_) or _NO_ARGS, value, infer_missing)
elif _is_counter(type_):
xs = dict(zip(_decode_items(_get_type_arg_param(type_, 0), value.keys(), infer_missing), value.values()))
else:
xs = _decode_items(_get_type_arg_param(type_, 0), value, infer_missing)

Expand Down
6 changes: 5 additions & 1 deletion dataclasses_json/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
from datetime import datetime, timezone
from typing import (Collection, Mapping, Optional, TypeVar, Any, Type, Tuple,
Union, cast)
Union, cast, Counter)


def _get_type_cons(type_):
Expand Down Expand Up @@ -142,6 +142,10 @@ def _is_optional(type_):
type_ is Any)


def _is_counter(type_):
return _issubclass_safe(_get_type_origin(type_), Counter)


def _is_mapping(type_):
return _issubclass_safe(_get_type_origin(type_), Mapping)

Expand Down
7 changes: 7 additions & 0 deletions tests/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import datetime
from decimal import Decimal
from typing import (Collection,
Counter,
Deque,
Dict,
FrozenSet,
Expand Down Expand Up @@ -370,3 +371,9 @@ class DataClassWithNestedOptional:
@dataclass
class DataClassWithNestedDictWithTupleKeys:
a: Dict[Tuple[int], int]


@dataclass_json
@dataclass
class DataClassWithCounter:
c: Counter[str]
13 changes: 11 additions & 2 deletions tests/test_collections.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections import deque
from collections import Counter, deque

from tests.entities import (DataClassIntImmutableDefault,
DataClassMutableDefaultDict,
Expand All @@ -19,7 +19,8 @@
DataClassWithFrozenSetUnbound,
DataClassWithDequeCollections,
DataClassWithTuple, DataClassWithTupleUnbound,
DataClassWithUnionIntNone, MyCollection)
DataClassWithUnionIntNone, MyCollection,
DataClassWithCounter)


class TestEncoder:
Expand Down Expand Up @@ -112,6 +113,10 @@ def test_mutable_default_list(self):
def test_mutable_default_dict(self):
assert DataClassMutableDefaultDict().to_json() == '{"xs": {}}'

def test_counter(self):
assert DataClassWithCounter(
c=Counter('foo')).to_json() == '{"c": {"f": 1, "o": 2}}'


class TestDecoder:
def test_list(self):
Expand Down Expand Up @@ -235,3 +240,7 @@ def test_mutable_default_dict(self):
== DataClassMutableDefaultDict())
assert (DataClassMutableDefaultDict.from_json('{}', infer_missing=True)
== DataClassMutableDefaultDict())

def test_counter(self):
assert DataClassWithCounter.from_json('{"c": {"f": 1, "o": 2}}') == \
DataClassWithCounter(c=Counter('foo'))

0 comments on commit 12b7f4d

Please sign in to comment.