Skip to content

Commit 1276ef2

Browse files
ds-hwangwangkuiyi
authored andcommitted
Improve Required type definition for stricter type checking
GitOrigin-RevId: 6e75838d04b9b825c5473aad68fe30898f3a05de
1 parent 46c95f9 commit 1276ef2

1 file changed

Lines changed: 104 additions & 3 deletions

File tree

axlearn/common/config.py

Lines changed: 104 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ class Config(ConfigBase):
6868
import re
6969
import types
7070
from collections import defaultdict
71-
from collections.abc import Collection, Iterable
71+
from collections.abc import Collection, Iterable, Iterator
7272
from functools import cache, lru_cache
73-
from typing import Any, Callable, Generic, Optional, Protocol, Sequence, TypeVar, Union
73+
from typing import Any, Callable, Generic, NoReturn, Optional, Protocol, Sequence, TypeVar, Union
7474

7575
# attr provides similar features as Python dataclass. Unlike
7676
# dataclass, however, it provides a richer set of features to regulate
@@ -143,6 +143,13 @@ def overlaps(name: str, key: str) -> float:
143143

144144

145145
class RequiredFieldValue:
146+
"""Sentinel for required config fields that have not been set."""
147+
148+
def _raise(self, op: str) -> NoReturn:
149+
raise TypeError(
150+
f"Cannot use a required config field before setting its value (attempted: {op})."
151+
)
152+
146153
def __deepcopy__(self, memo):
147154
return self
148155

@@ -152,10 +159,104 @@ def __bool__(self):
152159
def __repr__(self):
153160
return "REQUIRED"
154161

162+
# Attribute access — covers .items(), .keys(), .values(), .set(), etc.
163+
def __getattr__(self, name: str) -> Any:
164+
raise AttributeError(
165+
f"Cannot access attribute '{name}' on a required config field "
166+
"before setting its value."
167+
)
168+
169+
# Subscript — covers [key].
170+
def __getitem__(self, key: Any) -> Any:
171+
self._raise("subscript")
172+
173+
def __setitem__(self, key: Any, value: Any) -> None:
174+
self._raise("subscript")
175+
176+
# Iteration and containment.
177+
def __iter__(self) -> Iterator: # pylint: disable=non-iterator-returned
178+
self._raise("iterate over")
179+
180+
def __contains__(self, item: Any) -> bool:
181+
self._raise("check membership of")
182+
183+
def __len__(self) -> int: # pylint: disable=invalid-length-returned
184+
self._raise("get length of")
185+
186+
# Callable.
187+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
188+
self._raise("call")
189+
190+
# Arithmetic operators.
191+
def __add__(self, other: Any) -> Any:
192+
self._raise("add")
193+
194+
def __radd__(self, other: Any) -> Any:
195+
self._raise("add")
196+
197+
def __sub__(self, other: Any) -> Any:
198+
self._raise("subtract")
199+
200+
def __rsub__(self, other: Any) -> Any:
201+
self._raise("subtract")
202+
203+
def __mul__(self, other: Any) -> Any:
204+
self._raise("multiply")
205+
206+
def __rmul__(self, other: Any) -> Any:
207+
self._raise("multiply")
208+
209+
def __truediv__(self, other: Any) -> Any:
210+
self._raise("divide")
211+
212+
def __rtruediv__(self, other: Any) -> Any:
213+
self._raise("divide")
214+
215+
def __floordiv__(self, other: Any) -> Any:
216+
self._raise("floor-divide")
217+
218+
def __mod__(self, other: Any) -> Any:
219+
self._raise("modulo")
220+
221+
def __pow__(self, other: Any) -> Any:
222+
self._raise("exponentiate")
223+
224+
def __neg__(self) -> Any:
225+
self._raise("negate")
226+
227+
def __pos__(self) -> Any:
228+
self._raise("apply unary +")
229+
230+
def __abs__(self) -> Any:
231+
self._raise("apply abs() to")
232+
233+
# Comparison operators.
234+
def __lt__(self, other: Any) -> Any:
235+
self._raise("compare")
236+
237+
def __le__(self, other: Any) -> Any:
238+
self._raise("compare")
239+
240+
def __gt__(self, other: Any) -> Any:
241+
self._raise("compare")
242+
243+
def __ge__(self, other: Any) -> Any:
244+
self._raise("compare")
245+
246+
# Type conversions.
247+
def __int__(self) -> int:
248+
self._raise("convert to int")
249+
250+
def __float__(self) -> float:
251+
self._raise("convert to float")
252+
253+
def __str__(self) -> str:
254+
return "REQUIRED"
255+
155256

156257
# TODO(markblee): Raise if trying to set attributes on REQUIRED.
157258
REQUIRED = RequiredFieldValue()
158-
Required = Union[T, RequiredFieldValue, Any]
259+
Required = Union[T, RequiredFieldValue]
159260

160261

161262
class MissingConfigClassDecoratorError(TypeError):

0 commit comments

Comments
 (0)