Skip to content

Commit 32f2c03

Browse files
niketkumarOrbax Authors
authored andcommitted
Resolve default value of Handler.typestr if method missing.
PiperOrigin-RevId: 710747415
1 parent c227788 commit 32f2c03

File tree

5 files changed

+79
-14
lines changed

5 files changed

+79
-14
lines changed

checkpoint/CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
## [0.11.0] - 2024-12-30
11+
12+
### Fixed
13+
- Resolve default value of Handler.typestr if method missing.
14+
### Added
15+
- Add announcement for grain version compatibility. See
16+
https://github.com/google/orbax/issues/1456.
17+
18+
1019
## [0.10.3] - 2024-12-19
1120

1221
### Added

checkpoint/orbax/checkpoint/_src/handlers/handler_type_registry.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,33 @@ def get(
6868

6969

7070
def register_handler_type(handler_cls):
71-
_GLOBAL_HANDLER_TYPE_REGISTRY.add(handler_cls.typestr(), handler_cls)
71+
"""Registers a checkpoint handler type in the global registry.
72+
73+
The registry is keyed by the handler's typestr. If the handler does not
74+
provide a typestr, the default typestr is resolved from the handler's
75+
module and class name.
76+
77+
Args:
78+
handler_cls: The checkpoint handler class to register.
79+
80+
Returns:
81+
The registered checkpoint handler class.
82+
"""
83+
# TODO(adamcogdell): Change HandlerTypeRegistry.add(typestr, type) to
84+
# HandlerTypeRegistry.add(handler_type) and move following logic into
85+
# HandlerTypeRegistry.add(). It will help to drop unit tests based on the
86+
# singleton HandlerTypeRegistry, which can be flaky.
87+
try:
88+
typestr = handler_cls.typestr()
89+
except AttributeError:
90+
typestr = f'{handler_cls.__module__}.{handler_cls.__qualname__}'
91+
logging.warning(
92+
'Handler class %s does not have a typestr method. '
93+
'Using the default typestr value "%s" instead.',
94+
handler_cls,
95+
typestr,
96+
)
97+
_GLOBAL_HANDLER_TYPE_REGISTRY.add(typestr, handler_cls)
7298
return handler_cls
7399

74100

checkpoint/orbax/checkpoint/_src/handlers/handler_type_registry_test.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Tests for CheckpointerHandler type registry."""
1616

17+
import copy
1718
from absl.testing import absltest
1819
from absl.testing import parameterized
1920
from etils import epath
@@ -26,6 +27,7 @@
2627

2728

2829
class TestHandler(checkpoint_handler.CheckpointHandler):
30+
2931
def save(self, directory: epath.Path, *args, **kwargs):
3032
pass
3133

@@ -34,7 +36,9 @@ def restore(self, directory: epath.Path, *args, **kwargs):
3436

3537

3638
class ParentHandler(checkpoint_handler.CheckpointHandler):
39+
3740
class TestHandler(checkpoint_handler.CheckpointHandler):
41+
3842
def save(self, directory: epath.Path, *args, **kwargs):
3943
pass
4044

@@ -43,6 +47,7 @@ def restore(self, directory: epath.Path, *args, **kwargs):
4347

4448

4549
class StandardCheckpointHandler(checkpoint_handler.CheckpointHandler):
50+
4651
def save(self, directory: epath.Path, *args, **kwargs):
4752
pass
4853

@@ -57,11 +62,16 @@ class ChildStandardCheckpointHandler(
5762

5863

5964
class TypestrOverrideHandler(checkpoint_handler.CheckpointHandler):
65+
6066
@classmethod
6167
def typestr(cls) -> str:
6268
return 'typestr_override'
6369

6470

71+
class NoTypestrHandler:
72+
pass
73+
74+
6575
class HandlerTypeRegistryTest(parameterized.TestCase):
6676

6777
def test_register_and_get(self):
@@ -78,13 +88,11 @@ def test_register_and_get(self):
7888
)
7989
self.assertTrue(
8090
'__main__.TestHandler' in registry._registry
81-
or
82-
'handler_type_registry_test.TestHandler' in registry._registry
91+
or 'handler_type_registry_test.TestHandler' in registry._registry
8392
)
8493
self.assertTrue(
8594
'__main__.ParentHandler.TestHandler' in registry._registry
86-
or
87-
'handler_type_registry_test.ParentHandler.TestHandler'
95+
or 'handler_type_registry_test.ParentHandler.TestHandler'
8896
in registry._registry
8997
)
9098

@@ -97,7 +105,7 @@ def test_register_different_modules(self):
97105
)
98106
registry.add(
99107
standard_checkpoint_handler.StandardCheckpointHandler.typestr(),
100-
standard_checkpoint_handler.StandardCheckpointHandler
108+
standard_checkpoint_handler.StandardCheckpointHandler,
101109
)
102110
self.assertEqual(
103111
registry.get(
@@ -107,14 +115,13 @@ def test_register_different_modules(self):
107115
)
108116
self.assertTrue(
109117
'__main__.StandardCheckpointHandler' in registry._registry
110-
or
111-
'handler_type_registry_test.StandardCheckpointHandler'
118+
or 'handler_type_registry_test.StandardCheckpointHandler'
112119
in registry._registry
113120
)
114121
self.assertIn(
115122
'orbax.checkpoint._src.handlers.standard_checkpoint_handler.'
116123
'StandardCheckpointHandler',
117-
registry._registry
124+
registry._registry,
118125
)
119126

120127
def test_register_duplicate_handler_type(self):
@@ -133,7 +140,7 @@ def test_register_duplicate_handler_type(self):
133140
r'<class \'(?:__main__|handler_type_registry_test)\.TestHandler\'>. '
134141
'Cannot add type '
135142
r'<class \'(?:__main__|handler_type_registry_test)\.'
136-
'ParentHandler.TestHandler\'>.',
143+
"ParentHandler.TestHandler'>.",
137144
):
138145
registry.add(TestHandler.typestr(), ParentHandler.TestHandler)
139146

@@ -151,11 +158,10 @@ def test_register_subclass_handler_type(self):
151158
registry = HandlerTypeRegistry()
152159
registry.add(
153160
standard_checkpoint_handler.StandardCheckpointHandler.typestr(),
154-
standard_checkpoint_handler.StandardCheckpointHandler
161+
standard_checkpoint_handler.StandardCheckpointHandler,
155162
)
156163
registry.add(
157-
ChildStandardCheckpointHandler.typestr(),
158-
ChildStandardCheckpointHandler
164+
ChildStandardCheckpointHandler.typestr(), ChildStandardCheckpointHandler
159165
)
160166
self.assertEqual(
161167
registry.get(
@@ -176,5 +182,23 @@ def test_typestr_override(self):
176182
TypestrOverrideHandler,
177183
)
178184

185+
def test_no_typestr(self):
186+
backup = copy.deepcopy(handler_type_registry._GLOBAL_HANDLER_TYPE_REGISTRY)
187+
try:
188+
# Clear the global registry to avoid side effects from other tests.
189+
handler_type_registry._GLOBAL_HANDLER_TYPE_REGISTRY._registry.clear()
190+
191+
handler_type_registry.register_handler_type(NoTypestrHandler)
192+
registry = handler_type_registry._GLOBAL_HANDLER_TYPE_REGISTRY._registry
193+
194+
expected_registry0 = {
195+
'handler_type_registry_test.NoTypestrHandler': NoTypestrHandler
196+
}
197+
expected_registry1 = {'__main__.NoTypestrHandler': NoTypestrHandler}
198+
self.assertIn(registry, [expected_registry0, expected_registry1])
199+
finally:
200+
handler_type_registry._GLOBAL_HANDLER_TYPE_REGISTRY = backup
201+
202+
179203
if __name__ == '__main__':
180204
absltest.main()

checkpoint/orbax/checkpoint/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
# A new PyPI release will be pushed everytime `__version__` is increased.
1818
# Also modify version and date in CHANGELOG.
19-
__version__ = '0.10.3'
19+
__version__ = '0.11.0'
2020

2121

2222
# TODO: b/362813406 - Add latest change timestamp and commit number.

docs/guides/checkpoint/orbax_checkpoint_announcements.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Announcements
22

3+
## 2024-12-30
4+
orbax-checkpoint version `0.10.3` and
5+
[grain](https://pypi.org/project/grain/) version `0.2.2` are not compatible.
6+
Either upgrade `grain>=0.2.3` or `orbax-checkpoint>=0.11.0`. Please see
7+
https://github.com/google/orbax/issues/1456 for error details.
8+
39
## 2024-10-25
410
A new option, `strict` has been added to `ArrayRestoreArgs` (and will be
511
present in the next version release). The option defaults to True. This

0 commit comments

Comments
 (0)