Skip to content

feat(dataclass): Support frozen=True in py_class #54

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/mlc/_cython/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,10 @@ def fget(this: typing.Any, _name: str = name) -> typing.Any:
def fset(this: typing.Any, value: typing.Any, _name: str = name) -> None:
setter(this, value) # type: ignore[misc]

fget.__name__ = fset.__name__ = name
fget.__module__ = fset.__module__ = cls.__module__
fget.__qualname__ = fset.__qualname__ = f"{cls.__qualname__}.{name}" # type: ignore[attr-defined]
fget.__doc__ = fset.__doc__ = f"Property `{name}` of class `{cls.__qualname__}`" # type: ignore[attr-defined]
prop = property(
fget=fget if getter else None,
fset=fset if (not frozen) and setter else None,
Expand Down
21 changes: 20 additions & 1 deletion python/mlc/core/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import typing

from mlc._cython import PyAny, c_class_core
from mlc._cython import PyAny, TypeInfo, c_class_core


@c_class_core("object.Object")
Expand Down Expand Up @@ -65,6 +65,25 @@ def __eq__(self, other: typing.Any) -> bool:
def __ne__(self, other: typing.Any) -> bool:
return not self == other

def _mlc_setattr(self, name: str, value: typing.Any) -> None:
type_info: TypeInfo = type(self)._mlc_type_info
for field in type_info.fields:
if field.name == name:
if field.setter is None:
raise AttributeError(f"Attribute `{name}` missing setter")
field.setter(self, value)
return
raise AttributeError(f"Attribute `{name}` not found in `{type(self)}`")

def _mlc_getattr(self, name: str) -> typing.Any:
type_info: TypeInfo = type(self)._mlc_type_info
for field in type_info.fields:
if field.name == name:
if field.getter is None:
raise AttributeError(f"Attribute `{name}` missing getter")
return field.getter(self)
raise AttributeError(f"Attribute `{name}` not found in `{type(self)}`")

def swap(self, other: typing.Any) -> None:
if type(self) == type(other):
self._mlc_swap(other)
Expand Down
2 changes: 1 addition & 1 deletion python/mlc/dataclasses/c_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class type_cls(super_type_cls): # type: ignore[valid-type,misc]

if type_info.type_cls is not None:
raise ValueError(f"Type is already registered: {type_key}")
_, d_fields = inspect_dataclass_fields(type_key, type_cls, parent_type_info)
_, d_fields = inspect_dataclass_fields(type_key, type_cls, parent_type_info, frozen=False)
type_info.type_cls = type_cls
type_info.d_fields = tuple(d_fields)

Expand Down
2 changes: 2 additions & 0 deletions python/mlc/dataclasses/py_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def py_class(
*,
init: bool = True,
repr: bool = True,
frozen: bool = False,
structure: typing.Literal["bind", "nobind", "var"] | None = None,
) -> Callable[[type[ClsType]], type[ClsType]]:
if isinstance(type_key, type):
Expand Down Expand Up @@ -86,6 +87,7 @@ def decorator(super_type_cls: type[ClsType]) -> type[ClsType]:
type_key,
super_type_cls,
parent_type_info,
frozen=frozen,
)
num_bytes = _add_field_properties(fields)
type_info.fields = tuple(fields)
Expand Down
6 changes: 4 additions & 2 deletions python/mlc/dataclasses/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def inspect_dataclass_fields( # noqa: PLR0912
type_key: str,
type_cls: type,
parent_type_info: TypeInfo,
frozen: bool,
) -> tuple[list[TypeField], list[Field]]:
def _get_num_bytes(field_ty: Any) -> int:
if hasattr(field_ty, "_ctype"):
Expand All @@ -136,6 +137,7 @@ def _get_num_bytes(field_ty: Any) -> int:
for type_field in parent_type_info.fields:
field_name = type_field.name
field_ty = type_field.ty
field_frozen = type_field.frozen
if type_hints.pop(field_name, None) is None:
raise ValueError(
f"Missing field `{type_key}::{field_name}`, "
Expand All @@ -146,7 +148,7 @@ def _get_num_bytes(field_ty: Any) -> int:
name=field_name,
offset=-1,
num_bytes=_get_num_bytes(field_ty),
frozen=False,
frozen=field_frozen,
ty=field_ty,
)
)
Expand All @@ -159,7 +161,7 @@ def _get_num_bytes(field_ty: Any) -> int:
name=field_name,
offset=-1,
num_bytes=_get_num_bytes(field_ty),
frozen=False,
frozen=frozen,
ty=field_ty,
)
)
Expand Down
31 changes: 31 additions & 0 deletions tests/python/test_dataclasses_py_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import mlc
import mlc.dataclasses as mlcd
import pytest


@mlcd.py_class("mlc.testing.py_class_base")
Expand Down Expand Up @@ -57,6 +58,12 @@ def __post_init__(self) -> None:
self.b = self.b.upper()


@mlcd.py_class("mlc.testing.py_class_frozen", frozen=True)
class Frozen(mlcd.PyClass):
a: int
b: str


def test_base() -> None:
base = Base(1, "a")
base_str = "mlc.testing.py_class_base(base_a=1, base_b='a')"
Expand Down Expand Up @@ -120,6 +127,30 @@ def test_post_init() -> None:
assert repr(post_init) == "mlc.testing.py_class_post_init(a=1, b='A')"


def test_frozen_set_fail() -> None:
frozen = Frozen(1, "a")
with pytest.raises(AttributeError) as e:
frozen.a = 2
# depends on Python version, there are a few possible error messages
assert str(e.value) in [
"property 'a' of 'Frozen' object has no setter",
"can't set attribute",
]
assert frozen.a == 1
assert frozen.b == "a"


def test_frozen_force_set() -> None:
frozen = Frozen(1, "a")
frozen._mlc_setattr("a", 2)
assert frozen.a == 2
assert frozen.b == "a"

frozen._mlc_setattr("b", "b")
assert frozen.a == 2
assert frozen.b == "b"


def test_derived_derived() -> None:
# __init__(base_a, derived_derived_a, base_b, derived_a, derived_b)
obj = DerivedDerived(1, "a", [1, 2], 2, "b")
Expand Down
Loading