Skip to content

Commit 540f696

Browse files
committed
feat(dataclass): Support frozen=True in py_class
1 parent f9fc4ad commit 540f696

File tree

6 files changed

+62
-4
lines changed

6 files changed

+62
-4
lines changed

python/mlc/_cython/base.py

+4
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,10 @@ def fget(this: typing.Any, _name: str = name) -> typing.Any:
377377
def fset(this: typing.Any, value: typing.Any, _name: str = name) -> None:
378378
setter(this, value) # type: ignore[misc]
379379

380+
fget.__name__ = fset.__name__ = name
381+
fget.__module__ = fset.__module__ = cls.__module__
382+
fget.__qualname__ = fset.__qualname__ = f"{cls.__qualname__}.{name}" # type: ignore[attr-defined]
383+
fget.__doc__ = fset.__doc__ = f"Property `{name}` of class `{cls.__qualname__}`" # type: ignore[attr-defined]
380384
prop = property(
381385
fget=fget if getter else None,
382386
fset=fset if (not frozen) and setter else None,

python/mlc/core/object.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import typing
44

5-
from mlc._cython import PyAny, c_class_core
5+
from mlc._cython import PyAny, TypeInfo, c_class_core
66

77

88
@c_class_core("object.Object")
@@ -65,6 +65,25 @@ def __eq__(self, other: typing.Any) -> bool:
6565
def __ne__(self, other: typing.Any) -> bool:
6666
return not self == other
6767

68+
def _mlc_setattr(self, name: str, value: typing.Any) -> None:
69+
type_info: TypeInfo = type(self)._mlc_type_info
70+
for field in type_info.fields:
71+
if field.name == name:
72+
if field.setter is None:
73+
raise AttributeError(f"Attribute `{name}` missing setter")
74+
field.setter(self, value)
75+
return
76+
raise AttributeError(f"Attribute `{name}` not found in `{type(self)}`")
77+
78+
def _mlc_getattr(self, name: str) -> typing.Any:
79+
type_info: TypeInfo = type(self)._mlc_type_info
80+
for field in type_info.fields:
81+
if field.name == name:
82+
if field.getter is None:
83+
raise AttributeError(f"Attribute `{name}` missing getter")
84+
return field.getter(self)
85+
raise AttributeError(f"Attribute `{name}` not found in `{type(self)}`")
86+
6887
def swap(self, other: typing.Any) -> None:
6988
if type(self) == type(other):
7089
self._mlc_swap(other)

python/mlc/dataclasses/c_class.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class type_cls(super_type_cls): # type: ignore[valid-type,misc]
3939

4040
if type_info.type_cls is not None:
4141
raise ValueError(f"Type is already registered: {type_key}")
42-
_, d_fields = inspect_dataclass_fields(type_key, type_cls, parent_type_info)
42+
_, d_fields = inspect_dataclass_fields(type_key, type_cls, parent_type_info, frozen=False)
4343
type_info.type_cls = type_cls
4444
type_info.d_fields = tuple(d_fields)
4545

python/mlc/dataclasses/py_class.py

+2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def py_class(
4949
*,
5050
init: bool = True,
5151
repr: bool = True,
52+
frozen: bool = False,
5253
structure: typing.Literal["bind", "nobind", "var"] | None = None,
5354
) -> Callable[[type[ClsType]], type[ClsType]]:
5455
if isinstance(type_key, type):
@@ -86,6 +87,7 @@ def decorator(super_type_cls: type[ClsType]) -> type[ClsType]:
8687
type_key,
8788
super_type_cls,
8889
parent_type_info,
90+
frozen=frozen,
8991
)
9092
num_bytes = _add_field_properties(fields)
9193
type_info.fields = tuple(fields)

python/mlc/dataclasses/utils.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def inspect_dataclass_fields( # noqa: PLR0912
124124
type_key: str,
125125
type_cls: type,
126126
parent_type_info: TypeInfo,
127+
frozen: bool,
127128
) -> tuple[list[TypeField], list[Field]]:
128129
def _get_num_bytes(field_ty: Any) -> int:
129130
if hasattr(field_ty, "_ctype"):
@@ -136,6 +137,7 @@ def _get_num_bytes(field_ty: Any) -> int:
136137
for type_field in parent_type_info.fields:
137138
field_name = type_field.name
138139
field_ty = type_field.ty
140+
field_frozen = type_field.frozen
139141
if type_hints.pop(field_name, None) is None:
140142
raise ValueError(
141143
f"Missing field `{type_key}::{field_name}`, "
@@ -146,7 +148,7 @@ def _get_num_bytes(field_ty: Any) -> int:
146148
name=field_name,
147149
offset=-1,
148150
num_bytes=_get_num_bytes(field_ty),
149-
frozen=False,
151+
frozen=field_frozen,
150152
ty=field_ty,
151153
)
152154
)
@@ -159,7 +161,7 @@ def _get_num_bytes(field_ty: Any) -> int:
159161
name=field_name,
160162
offset=-1,
161163
num_bytes=_get_num_bytes(field_ty),
162-
frozen=False,
164+
frozen=frozen,
163165
ty=field_ty,
164166
)
165167
)

tests/python/test_dataclasses_py_class.py

+31
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import mlc
44
import mlc.dataclasses as mlcd
5+
import pytest
56

67

78
@mlcd.py_class("mlc.testing.py_class_base")
@@ -57,6 +58,12 @@ def __post_init__(self) -> None:
5758
self.b = self.b.upper()
5859

5960

61+
@mlcd.py_class("mlc.testing.py_class_frozen", frozen=True)
62+
class Frozen(mlcd.PyClass):
63+
a: int
64+
b: str
65+
66+
6067
def test_base() -> None:
6168
base = Base(1, "a")
6269
base_str = "mlc.testing.py_class_base(base_a=1, base_b='a')"
@@ -120,6 +127,30 @@ def test_post_init() -> None:
120127
assert repr(post_init) == "mlc.testing.py_class_post_init(a=1, b='A')"
121128

122129

130+
def test_frozen_set_fail() -> None:
131+
frozen = Frozen(1, "a")
132+
with pytest.raises(AttributeError) as e:
133+
frozen.a = 2
134+
# depends on Python version, there are a few possible error messages
135+
assert str(e.value) in [
136+
"property 'a' of 'Frozen' object has no setter",
137+
"can't set attribute",
138+
]
139+
assert frozen.a == 1
140+
assert frozen.b == "a"
141+
142+
143+
def test_frozen_force_set() -> None:
144+
frozen = Frozen(1, "a")
145+
frozen._mlc_setattr("a", 2)
146+
assert frozen.a == 2
147+
assert frozen.b == "a"
148+
149+
frozen._mlc_setattr("b", "b")
150+
assert frozen.a == 2
151+
assert frozen.b == "b"
152+
153+
123154
def test_derived_derived() -> None:
124155
# __init__(base_a, derived_derived_a, base_b, derived_a, derived_b)
125156
obj = DerivedDerived(1, "a", [1, 2], 2, "b")

0 commit comments

Comments
 (0)