Skip to content

Commit

Permalink
Improve dataclass_json and _process_class type annotations (#475)
Browse files Browse the repository at this point in the history
Use typing.overload to differentiate the return type of dataclass_json based on whether _cls parameter is None or not
  • Loading branch information
ringohoffman authored Aug 25, 2023
1 parent 89578cb commit dd4c414
Showing 1 changed file with 27 additions and 16 deletions.
43 changes: 27 additions & 16 deletions dataclasses_json/api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import abc
import json
from typing import (Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar,
Union)
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, overload

from dataclasses_json.cfg import config, LetterCase # noqa: F401
from dataclasses_json.cfg import config, LetterCase
from dataclasses_json.core import (Json, _ExtendedEncoder, _asdict,
_decode_dataclass)
from dataclasses_json.mm import (JsonData, SchemaType, build_schema)
Expand All @@ -12,6 +11,7 @@
_undefined_parameter_action_safe)

A = TypeVar('A', bound="DataClassJsonMixin")
T = TypeVar('T')
Fields = List[Tuple[str, Any]]


Expand Down Expand Up @@ -102,8 +102,18 @@ def schema(cls: Type[A],
unknown=unknown)


def dataclass_json(_cls=None, *, letter_case=None,
undefined: Optional[Union[str, Undefined]] = None):
@overload
def dataclass_json(_cls: None = ..., *, letter_case: Optional[LetterCase] = ...,
undefined: Optional[Union[str, Undefined]] = ...) -> Callable[[Type[T]], Type[T]]: ...


@overload
def dataclass_json(_cls: Type[T], *, letter_case: Optional[LetterCase] = ...,
undefined: Optional[Union[str, Undefined]] = ...) -> Type[T]: ...


def dataclass_json(_cls: Optional[Type[T]] = None, *, letter_case: Optional[LetterCase] = None,
undefined: Optional[Union[str, Undefined]] = None) -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
"""
Based on the code in the `dataclasses` module to handle optional-parens
decorators. See example below:
Expand All @@ -114,29 +124,30 @@ class Example:
...
"""

def wrap(cls):
def wrap(cls: Type[T]) -> Type[T]:
return _process_class(cls, letter_case, undefined)

if _cls is None:
return wrap
return wrap(_cls)


def _process_class(cls, letter_case, undefined):
def _process_class(cls: Type[T], letter_case: Optional[LetterCase],
undefined: Optional[Union[str, Undefined]]) -> Type[T]:
if letter_case is not None or undefined is not None:
cls.dataclass_json_config = config(letter_case=letter_case,
undefined=undefined)[
'dataclasses_json']
cls.dataclass_json_config = config(letter_case=letter_case, # type: ignore[attr-defined]
undefined=undefined)['dataclasses_json']

cls.to_json = DataClassJsonMixin.to_json
cls.to_json = DataClassJsonMixin.to_json # type: ignore[attr-defined]
# unwrap and rewrap classmethod to tag it to cls rather than the literal
# DataClassJsonMixin ABC
cls.from_json = classmethod(DataClassJsonMixin.from_json.__func__) # type: ignore
cls.to_dict = DataClassJsonMixin.to_dict
cls.from_dict = classmethod(DataClassJsonMixin.from_dict.__func__) # type: ignore
cls.schema = classmethod(DataClassJsonMixin.schema.__func__) # type: ignore
cls.from_json = classmethod(DataClassJsonMixin.from_json.__func__) # type: ignore[attr-defined]
cls.to_dict = DataClassJsonMixin.to_dict # type: ignore[attr-defined]
cls.from_dict = classmethod(DataClassJsonMixin.from_dict.__func__) # type: ignore[attr-defined]
cls.schema = classmethod(DataClassJsonMixin.schema.__func__) # type: ignore[attr-defined]

cls.__init__ = _handle_undefined_parameters_safe(cls, kvs=(), usage="init")
cls.__init__ = _handle_undefined_parameters_safe(cls, kvs=(), # type: ignore[attr-defined,method-assign]
usage="init")
# register cls as a virtual subclass of DataClassJsonMixin
DataClassJsonMixin.register(cls)
return cls

0 comments on commit dd4c414

Please sign in to comment.