Skip to content

Commit e6810a8

Browse files
daiyippyglove authors
authored andcommitted
Improves handling on using methods as the values for callable symbolic attributes.
- Frozen fields will be absent from __init__ signature. - Frozen fields will not be included during `format`. PiperOrigin-RevId: 639901050
1 parent 7b56371 commit e6810a8

File tree

7 files changed

+106
-38
lines changed

7 files changed

+106
-38
lines changed

pyglove/core/symbolic/dict.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,7 @@ def update(self,
796796

797797
def sym_jsonify(
798798
self,
799+
hide_frozen: bool = True,
799800
hide_default_values: bool = False,
800801
exclude_keys: Optional[Sequence[str]] = None,
801802
use_inferred: bool = False,
@@ -809,26 +810,30 @@ def sym_jsonify(
809810
# NOTE(daiyip): The key values of frozen field can safely be excluded
810811
# since they will be the same for a class.
811812
field = self._value_spec.schema[key_spec]
812-
if not field.frozen:
813-
for key in keys:
814-
if key not in exclude_keys:
815-
value = self.sym_getattr(key)
816-
if use_inferred and isinstance(value, base.Inferential):
817-
value = self.sym_inferred(key, default=value)
818-
if pg_typing.MISSING_VALUE == value:
819-
continue
820-
if hide_default_values and base.eq(value, field.default_value):
821-
continue
822-
json_repr[key] = base.to_json(
823-
value, hide_default_values=hide_default_values,
824-
use_inferred=use_inferred,
825-
**kwargs)
813+
if hide_frozen and field.frozen:
814+
continue
815+
for key in keys:
816+
if key not in exclude_keys:
817+
value = self.sym_getattr(key)
818+
if use_inferred and isinstance(value, base.Inferential):
819+
value = self.sym_inferred(key, default=value)
820+
if pg_typing.MISSING_VALUE == value:
821+
continue
822+
if hide_default_values and base.eq(value, field.default_value):
823+
continue
824+
json_repr[key] = base.to_json(
825+
value,
826+
hide_frozen=hide_frozen,
827+
hide_default_values=hide_default_values,
828+
use_inferred=use_inferred,
829+
**kwargs)
826830
return json_repr
827831
else:
828832
return {
829833
k: base.to_json(
830834
self.sym_inferred(k, default=v) if (
831835
use_inferred and isinstance(v, base.Inferential)) else v,
836+
hide_frozen=hide_frozen,
832837
hide_default_values=hide_default_values,
833838
use_inferred=use_inferred,
834839
**kwargs)
@@ -880,6 +885,7 @@ def format(
880885
*,
881886
python_format: bool = False,
882887
markdown: bool = False,
888+
hide_frozen: bool = True,
883889
hide_default_values: bool = False,
884890
hide_missing_values: bool = False,
885891
include_keys: Optional[Set[str]] = None,
@@ -910,6 +916,8 @@ def _should_include_key(key):
910916
for key in keys:
911917
if _should_include_key(key):
912918
field = self._value_spec.schema[key_spec]
919+
if hide_frozen and field.frozen:
920+
continue
913921
v = self.sym_getattr(key)
914922
if use_inferred and isinstance(v, base.Inferential):
915923
v = self.sym_inferred(key, default=v)
@@ -941,6 +949,7 @@ def _should_include_key(key):
941949
compact,
942950
verbose,
943951
root_indent + 1,
952+
hide_frozen=hide_frozen,
944953
hide_default_values=hide_default_values,
945954
hide_missing_values=hide_missing_values,
946955
python_format=python_format,
@@ -971,6 +980,7 @@ def _should_include_key(key):
971980
compact,
972981
verbose,
973982
root_indent + 1,
983+
hide_frozen=hide_frozen,
974984
hide_default_values=hide_default_values,
975985
hide_missing_values=hide_missing_values,
976986
python_format=python_format,

pyglove/core/symbolic/dict_test.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1907,6 +1907,36 @@ def __eq__(self, other):
19071907
self.assertEqual(sd.to_json_str(), '{"x": 1, "y": 2.0}')
19081908
self.assertEqual(base.from_json_str(sd.to_json_str(), value_spec=spec), sd)
19091909

1910+
def test_hide_frozen(self):
1911+
1912+
class A(pg_object.Object):
1913+
x: pg_typing.Int().freeze(1)
1914+
1915+
sd = Dict.partial(
1916+
a=A(),
1917+
value_spec=pg_typing.Dict([
1918+
('a', pg_typing.Object(A)),
1919+
('b', pg_typing.Bool(True).freeze()),
1920+
]))
1921+
self.assertEqual(
1922+
sd.to_json(),
1923+
{
1924+
'a': {
1925+
'_type': A.__type_name__
1926+
},
1927+
}
1928+
)
1929+
self.assertEqual(
1930+
sd.to_json(hide_frozen=False),
1931+
{
1932+
'a': {
1933+
'_type': A.__type_name__,
1934+
'x': 1,
1935+
},
1936+
'b': True
1937+
}
1938+
)
1939+
19101940
def test_hide_default_values(self):
19111941

19121942
class A(pg_object.Object):
@@ -2255,6 +2285,24 @@ def test_noncompact_with_inferred_value(self):
22552285
"""),
22562286
)
22572287

2288+
def test_hide_frozen(self):
2289+
d = Dict(x=1, value_spec=pg_typing.Dict([
2290+
('x', pg_typing.Int()),
2291+
('y', pg_typing.Bool(True).freeze()),
2292+
('z', pg_typing.Dict([
2293+
('v', pg_typing.Int(1)),
2294+
('w', pg_typing.Bool().freeze(True)),
2295+
]))
2296+
]))
2297+
self.assertEqual(
2298+
d.format(compact=True),
2299+
'{x=1, z={v=1}}'
2300+
)
2301+
self.assertEqual(
2302+
d.format(compact=False, hide_frozen=False),
2303+
'{\n x = 1,\n y = True,\n z = {\n v = 1,\n w = True\n }\n}'
2304+
)
2305+
22582306

22592307
def _on_change_callback(updates):
22602308
del updates

pyglove/core/symbolic/object.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -149,22 +149,27 @@ def _infer_fields_from_annotations(cls) -> List[pg_typing.Field]:
149149
fields = cls._end_annotation_inference(fields) # pytype: disable=attribute-error
150150
return fields
151151

152-
def _update_default_values_from_class_attributes(cls):
153-
"""Updates the symbolic attribute defaults from class attributes."""
154-
for field in cls.__schema__.fields.values():
152+
def _update_default_values_from_class_attributes(
153+
cls, schema: pg_typing.Schema):
154+
"""Freezes callable fields if their defaults are provided as methods."""
155+
for field in schema.fields.values():
155156
if isinstance(field.key, pg_typing.ConstStrKey):
156-
attr_name = field.key.text
157-
attr_value = cls.__dict__.get(attr_name, pg_typing.MISSING_VALUE)
158-
if (
159-
attr_value != pg_typing.MISSING_VALUE
160-
and not isinstance(attr_value, property)
161-
and (
162-
# This allows class methods to be used as callable
163-
# symbolic attributes.
164-
not inspect.isfunction(attr_value)
165-
or isinstance(field.value, pg_typing.Callable)
166-
)
167-
):
157+
attr_value = cls.__dict__.get(field.key.text, pg_typing.MISSING_VALUE)
158+
if (attr_value == pg_typing.MISSING_VALUE
159+
or isinstance(attr_value, property)):
160+
continue
161+
if inspect.isfunction(attr_value):
162+
# When users add a method that has the same name as as field, two
163+
# scenarios emerge. If the field is a callable type, the method will
164+
# serve as the default value for the field. As a result, we freeze the
165+
# field so it can't be provided from the constructor. If the field is
166+
# not a callable type, the symbolic field and the method will coexist,
167+
# meaning that the method has higher priority when being accessed,
168+
# while users still can use `sym_getattr` to access the value for the
169+
# symboic field.
170+
if isinstance(field.value, pg_typing.Callable):
171+
field.value.freeze(attr_value, apply_before_use=False)
172+
else:
168173
field.value.set_default(attr_value)
169174

170175

@@ -320,6 +325,9 @@ def __init_subclass__(cls):
320325
metadata={},
321326
)
322327
)
328+
# Freeze callable symbolic attributes if they are provided as methods.
329+
user_cls._update_default_values_from_class_attributes(cls_schema)
330+
323331
# NOTE(daiyip): When new fields are added through class attributes.
324332
# We invalidate `init_arg_list` so PyGlove could recompute it based
325333
# on its schema during `apply_schema`. Otherwise, we inherit the
@@ -333,10 +341,6 @@ def __init_subclass__(cls):
333341
@classmethod
334342
def _on_schema_update(cls):
335343
"""Customizable trait: handling schema change."""
336-
# Update the default value for each field after schema is updated. This is
337-
# because that users may change a field's default value via class attribute.
338-
cls._update_default_values_from_class_attributes() # pylint: disable=no-value-for-parameter
339-
340344
# Update all schema-based signatures.
341345
cls._update_signatures_based_on_schema()
342346

pyglove/core/symbolic/object_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def test_init_arg_list(self):
107107
self.assertEqual(
108108
self._A.init_arg_list, ['x', 'y', 'z', 'p'])
109109
self.assertEqual(
110-
self._B.init_arg_list, ['x', 'y', 'z', 'p', 'q'])
110+
self._B.init_arg_list, ['x', 'y', 'z', 'q'])
111111
self.assertEqual(
112112
self._C.init_arg_list, ['x', 'y', 'z', '*args'])
113113

@@ -370,6 +370,9 @@ def x(self, v):
370370
def y(self):
371371
return self.sym_init_args.y * 2
372372

373+
self.assertTrue(H.__schema__.fields['x'].frozen)
374+
self.assertFalse(H.__schema__.fields['y'].frozen)
375+
373376
h = H(y=1)
374377
self.assertEqual(h.x(1), 3)
375378
self.assertEqual(h.y(), 2)

pyglove/core/symbolic/schema_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,8 @@ def auto_init_arg_list(cls):
267267
for base_cls in cls.__bases__:
268268
schema = getattr(base_cls, '__schema__', None)
269269
if isinstance(schema, pg_typing.Schema):
270-
if list(schema.keys()) == list(cls.__schema__.keys()):
270+
if ([(k, f.frozen) for k, f in schema.fields.items()]
271+
== [(k, f.frozen) for k, f in cls.__schema__.fields.items()]):
271272
init_arg_list = base_cls.init_arg_list
272273
else:
273274
break
@@ -276,8 +277,8 @@ def auto_init_arg_list(cls):
276277
# declaration order from base classes to subclasses.
277278
init_arg_list = [
278279
str(key)
279-
for key in cls.__schema__.fields.keys()
280-
if isinstance(key, pg_typing.ConstStrKey)
280+
for key, field in cls.__schema__.fields.items()
281+
if isinstance(key, pg_typing.ConstStrKey) and not field.frozen
281282
]
282283
return init_arg_list
283284

pyglove/core/typing/callable_signature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def get_arg_spec(arg_name):
228228
kwonlyargs = []
229229
varkw = None
230230
for key, field in schema.fields.items():
231-
if key not in existing_names:
231+
if key not in existing_names and not field.frozen:
232232
if key.is_const:
233233
kwonlyargs.append(Argument(str(key), field.value))
234234
else:

pyglove/core/typing/callable_signature_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ def _get_signature(self, init_arg_list, is_method: bool = True):
230230
s = class_schema.Schema([
231231
class_schema.Field('x', vs.Int(), 'x'),
232232
class_schema.Field('y', vs.Int(), 'y'),
233+
# Frozen fields will be ignored.
234+
class_schema.Field('v', vs.Bool().freeze(True), 'v'),
233235
class_schema.Field('z', vs.List(vs.Int()), 'z'),
234236
class_schema.Field(ks.StrKey(), vs.Str(), 'kwargs'),
235237
], metadata=dict(init_arg_list=init_arg_list), allow_nonconst_keys=True)

0 commit comments

Comments
 (0)