Skip to content

Commit

Permalink
feat: Add support for Literal types.
Browse files Browse the repository at this point in the history
This commit extends support of dataclasses_json to dataclasses with
fields annotated with Literal types. Literal types allow users to
specify a list of valid values, e.g.,

```python
@dataclass
class DataClassWithLiteral(DataClassJsonMixin):
   languages: Literal["C", "C++", "Java"]
```

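When de-serializing data through the marshmallow schema (e.g.
`DataClassWithLiteral.schema().loads(...)`), this commit now validates that
each JSON value is one of those specified in the Literal type. Decoding with
`from_json` still passes values through without validation.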

Change in behavior:
Using literal types would previously give users the following warning:
```
dataclasses_json/mm.py:357: UserWarning: Unknown type C at Foo.langs: typing.Literal['C', 'C++', 'Java']. It's advised to pass the correct marshmallow type to `mm_field`.
```
  • Loading branch information
arunchaganty committed Jun 9, 2024
1 parent 538ff15 commit b128a02
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 7 deletions.
5 changes: 3 additions & 2 deletions dataclasses_json/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Tuple, TypeVar, Type)
from uuid import UUID

from typing_inspect import is_union_type # type: ignore
from typing_inspect import is_union_type, is_literal_type # type: ignore

from dataclasses_json import cfg
from dataclasses_json.utils import (_get_type_cons, _get_type_origin,
Expand Down Expand Up @@ -358,7 +358,8 @@ def _decode_dict_keys(key_type, xs, infer_missing):
# This is a special case for Python 3.7 and Python 3.8.
# By some reason, "unbound" dicts are counted
# as having key type parameter to be TypeVar('KT')
if key_type is None or key_type == Any or isinstance(key_type, TypeVar):
# Literal types are also passed through without any decoding.
if key_type is None or key_type == Any or isinstance(key_type, TypeVar) or is_literal_type(key_type):
decode_function = key_type = (lambda x: x)
# handle a nested python dict that has tuples for keys. E.g. for
# Dict[Tuple[int], int], key_type will be typing.Tuple[int], but
Expand Down
59 changes: 54 additions & 5 deletions dataclasses_json/mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
from uuid import UUID
from enum import Enum

from typing_inspect import is_union_type # type: ignore
from typing_inspect import is_union_type, is_literal_type # type: ignore

from marshmallow import fields, Schema, post_load # type: ignore
from marshmallow.exceptions import ValidationError # type: ignore

from dataclasses_json.core import (_is_supported_generic, _decode_dataclass,
_ExtendedEncoder, _user_overrides_or_exts)
from dataclasses_json.utils import (_is_collection, _is_optional,
from dataclasses_json.utils import (_get_type_args, _is_collection, _is_optional,
_issubclass_safe, _timestamp_to_dt_aware,
_is_new_type, _get_type_origin,
_handle_undefined_parameters_safe,
Expand Down Expand Up @@ -130,6 +130,46 @@ def _deserialize(self, value, attr, data, **kwargs):
return None if optional_list is None else tuple(optional_list)


class _LiteralField(fields.Field):
    """Marshmallow field for a dataclass attribute annotated with ``typing.Literal``.

    A value outside the literal set only produces a warning when serializing,
    but raises ``ValidationError`` when deserializing.
    """

    def __init__(self, literal_values, cls, field, *args, **kwargs):
        """Create a new Literal field.

        Literals allow you to specify the set of valid _values_ for a field. The field
        implementation validates against these values on deserialization.

        Example:
            >>> @dataclass
            ... class DataClassWithLiteral(DataClassJsonMixin):
            ...     read_mode: Literal["r", "w", "a"]

        Args:
            literal_values: A sequence of possible values for the field.
            cls: The dataclass that the field belongs to.
            field: The field that the schema describes.
        """
        self.literal_values = literal_values
        self.cls = cls
        self.field = field
        super().__init__(*args, **kwargs)

    def _is_literal_value(self, value):
        """Return whether ``value`` matches one of ``self.literal_values``.

        ``bool`` is a subclass of ``int`` and ``True == 1``, so a plain ``in``
        test would accept JSON ``true`` for ``Literal[0, 1]`` (and ``1`` for
        ``Literal[True]``). Booleans are therefore only matched against
        boolean literals, and non-booleans only against non-boolean literals.
        """
        return any(
            isinstance(value, bool) == isinstance(literal, bool)
            and value == literal
            for literal in self.literal_values
        )

    def _serialize(self, value, attr, obj, **kwargs):
        """Serialize ``value``, warning (not failing) if it is not a literal value."""
        if self.allow_none and value is None:
            return None
        if not self._is_literal_value(value):
            # Serialization stays lenient: the value is still emitted, the
            # warning only flags that loading it back will be rejected.
            warnings.warn(
                f'The value "{value}" is not one of the values of typing.Literal '
                f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
                f'Value will not be de-serialized properly.')
        return super()._serialize(value, attr, obj, **kwargs)

    def _deserialize(self, value, attr, data, **kwargs):
        """Deserialize ``value``.

        Raises:
            ValidationError: If ``value`` is not one of the literal values.
        """
        if not self._is_literal_value(value):
            raise ValidationError(
                f'Value "{value}" is not one in typing.Literal{self.literal_values} '
                f'(dataclass: {self.cls.__name__}, field: {self.field.name}).')
        return super()._deserialize(value, attr, data, **kwargs)


TYPES = {
typing.Mapping: fields.Mapping,
typing.MutableMapping: fields.Mapping,
Expand Down Expand Up @@ -259,9 +299,14 @@ def inner(type_, options):
f"`dataclass_json` decorator or mixin.")
return fields.Field(**options)

origin = getattr(type_, '__origin__', type_)
args = [inner(a, {}) for a in getattr(type_, '__args__', []) if
a is not type(None)]
origin = _get_type_origin(type_)

# Type arguments are typically types (e.g. int in list[int]) except for Literal
# types, where they are values.
if is_literal_type(type_):
args = []
else:
args = [inner(a, {}) for a in _get_type_args(type_) if a is not type(None)]

if type_ == Ellipsis:
return type_
Expand All @@ -279,6 +324,10 @@ def inner(type_, options):
if _issubclass_safe(origin, Enum):
return fields.Enum(enum=origin, by_value=True, *args, **options)

if is_literal_type(type_):
literal_values = _get_type_args(type_)
return _LiteralField(literal_values, cls, field, **options)

if is_union_type(type_):
union_types = [a for a in getattr(type_, '__args__', []) if
a is not type(None)]
Expand Down
103 changes: 103 additions & 0 deletions tests/test_literals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Test dataclasses_json handling of Literal types."""
import sys
import pytest

# This guard has to run before ``from typing import Literal`` below, because
# ``typing.Literal`` does not exist on interpreters older than 3.8.
if sys.version_info < (3, 8):
    pytest.skip(
        "Literal types are only supported in Python 3.8+",
        allow_module_level=True,
    )

import json
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional

from dataclasses_json import dataclass_json, DataClassJsonMixin
from marshmallow.exceptions import ValidationError  # type: ignore


@dataclass_json
@dataclass
class DataClassWithLiteral(DataClassJsonMixin):
    """Fixture dataclass whose fields are annotated directly with Literal types."""
    # Integer-only literal values.
    numeric_literals: Literal[0, 1]
    # String-only literal values.
    string_literals: Literal["one", "two", "three"]
    # Literal values of mixed types (int and str).
    mixed_literals: Literal[0, "one", 2]


# JSON / instance pair in which every value belongs to its field's Literal set.
with_valid_literal_json = (
    '{"numeric_literals": 0, "string_literals": "one", "mixed_literals": 2}'
)
with_valid_literal_data = DataClassWithLiteral(
    numeric_literals=0,
    string_literals="one",
    mixed_literals=2,
)

# JSON / instance pair in which every value falls outside its field's Literal set.
with_invalid_literal_json = (
    '{"numeric_literals": 9, "string_literals": "four", "mixed_literals": []}'
)
with_invalid_literal_data = DataClassWithLiteral(numeric_literals=9, string_literals="four", mixed_literals=[])  # type: ignore

@dataclass_json
@dataclass
class DataClassWithNestedLiteral(DataClassJsonMixin):
    """Fixture dataclass with Literal types nested in containers and Optional."""
    # ``typing.List`` / ``typing.Dict`` are used instead of ``list[...]`` /
    # ``dict[...]``: builtin generics are only subscriptable on Python 3.9+,
    # while this test module is meant to run on Python 3.8+.
    list_of_literals: List[Literal[0, 1]]
    dict_of_literals: Dict[Literal["one", "two", "three"], Literal[0, 1]]
    optional_literal: Optional[Literal[0, 1]]

# Nested JSON / instance pair in which every value belongs to its Literal set.
with_valid_nested_literal_json = (
    '{"list_of_literals": [0, 1], "dict_of_literals": {"one": 0, "two": 1}, "optional_literal": 1}'
)
with_valid_nested_literal_data = DataClassWithNestedLiteral(
    list_of_literals=[0, 1],
    dict_of_literals={"one": 0, "two": 1},
    optional_literal=1,
)

# Nested JSON / instance pair containing values outside the Literal sets.
with_invalid_nested_literal_json = (
    '{"list_of_literals": [0, 2], "dict_of_literals": {"one": 0, "four": 2}, "optional_literal": 2}'
)
with_invalid_nested_literal_data = DataClassWithNestedLiteral(list_of_literals=[0, 2], dict_of_literals={"one": 0, "four": 2}, optional_literal=2)  # type: ignore

class TestEncoder:
    """``to_json`` / ``to_dict`` emit field values as-is, valid literal or not."""

    @staticmethod
    def _assert_encodes_to(instance, expected_json):
        """Check both the JSON string and the JSON-encoded dict of ``instance``."""
        assert instance.to_json() == expected_json
        assert instance.to_dict(encode_json=True) == json.loads(expected_json)

    def test_valid_literal(self):
        self._assert_encodes_to(with_valid_literal_data, with_valid_literal_json)

    def test_invalid_literal(self):
        self._assert_encodes_to(with_invalid_literal_data, with_invalid_literal_json)

    def test_valid_nested_literal(self):
        self._assert_encodes_to(
            with_valid_nested_literal_data, with_valid_nested_literal_json)

    def test_invalid_nested_literal(self):
        self._assert_encodes_to(
            with_invalid_nested_literal_data, with_invalid_nested_literal_json)


class TestSchemaEncoder:
    """Dumping through the generated schema still emits invalid literal values."""

    def test_valid_literal(self):
        dumped = DataClassWithLiteral.schema().dumps(with_valid_literal_data)
        assert dumped == with_valid_literal_json

    def test_invalid_literal(self):
        # Invalid values are serialized unchanged rather than rejected.
        dumped = DataClassWithLiteral.schema().dumps(with_invalid_literal_data)
        assert dumped == with_invalid_literal_json

    def test_valid_nested_literal(self):
        dumped = DataClassWithNestedLiteral.schema().dumps(
            with_valid_nested_literal_data)
        assert dumped == with_valid_nested_literal_json

    def test_invalid_nested_literal(self):
        dumped = DataClassWithNestedLiteral.schema().dumps(
            with_invalid_nested_literal_data)
        assert dumped == with_invalid_nested_literal_json

class TestDecoder:
    """``from_json`` decodes Literal fields without validating their values."""

    def test_valid_literal(self):
        decoded = DataClassWithLiteral.from_json(with_valid_literal_json)
        assert decoded == with_valid_literal_data

    def test_invalid_literal(self):
        # Pins the current contract: out-of-set values are loaded unchanged
        # instead of raising.
        decoded = DataClassWithLiteral.from_json(with_invalid_literal_json)
        assert decoded == DataClassWithLiteral(numeric_literals=9, string_literals="four", mixed_literals=[])  # type: ignore


class TestSchemaDecoder:
    """Loading through the generated schema rejects values outside the Literal set."""

    def test_valid_literal(self):
        loaded = DataClassWithLiteral.schema().loads(with_valid_literal_json)
        assert loaded == with_valid_literal_data

    def test_invalid_literal(self):
        with pytest.raises(ValidationError):
            DataClassWithLiteral.schema().loads(with_invalid_literal_json)

    def test_valid_nested_literal(self):
        loaded = DataClassWithNestedLiteral.schema().loads(
            with_valid_nested_literal_json)
        assert loaded == with_valid_nested_literal_data

    def test_invalid_nested_literal(self):
        with pytest.raises(ValidationError):
            DataClassWithNestedLiteral.schema().loads(
                with_invalid_nested_literal_json)

0 comments on commit b128a02

Please sign in to comment.