1414
1515"""Tests for CheckpointerHandler type registry."""
1616
17+ import copy
1718from absl .testing import absltest
1819from absl .testing import parameterized
1920from etils import epath
2627
2728
2829class TestHandler (checkpoint_handler .CheckpointHandler ):
30+
2931 def save (self , directory : epath .Path , * args , ** kwargs ):
3032 pass
3133
@@ -34,7 +36,9 @@ def restore(self, directory: epath.Path, *args, **kwargs):
3436
3537
3638class ParentHandler (checkpoint_handler .CheckpointHandler ):
39+
3740 class TestHandler (checkpoint_handler .CheckpointHandler ):
41+
3842 def save (self , directory : epath .Path , * args , ** kwargs ):
3943 pass
4044
@@ -43,6 +47,7 @@ def restore(self, directory: epath.Path, *args, **kwargs):
4347
4448
4549class StandardCheckpointHandler (checkpoint_handler .CheckpointHandler ):
50+
4651 def save (self , directory : epath .Path , * args , ** kwargs ):
4752 pass
4853
@@ -57,11 +62,16 @@ class ChildStandardCheckpointHandler(
5762
5863
5964class TypestrOverrideHandler (checkpoint_handler .CheckpointHandler ):
65+
6066 @classmethod
6167 def typestr (cls ) -> str :
6268 return 'typestr_override'
6369
6470
71+ class NoTypestrHandler :
72+ pass
73+
74+
6575class HandlerTypeRegistryTest (parameterized .TestCase ):
6676
6777 def test_register_and_get (self ):
@@ -78,13 +88,11 @@ def test_register_and_get(self):
7888 )
7989 self .assertTrue (
8090 '__main__.TestHandler' in registry ._registry
81- or
82- 'handler_type_registry_test.TestHandler' in registry ._registry
91+ or 'handler_type_registry_test.TestHandler' in registry ._registry
8392 )
8493 self .assertTrue (
8594 '__main__.ParentHandler.TestHandler' in registry ._registry
86- or
87- 'handler_type_registry_test.ParentHandler.TestHandler'
95+ or 'handler_type_registry_test.ParentHandler.TestHandler'
8896 in registry ._registry
8997 )
9098
@@ -97,7 +105,7 @@ def test_register_different_modules(self):
97105 )
98106 registry .add (
99107 standard_checkpoint_handler .StandardCheckpointHandler .typestr (),
100- standard_checkpoint_handler .StandardCheckpointHandler
108+ standard_checkpoint_handler .StandardCheckpointHandler ,
101109 )
102110 self .assertEqual (
103111 registry .get (
@@ -107,14 +115,13 @@ def test_register_different_modules(self):
107115 )
108116 self .assertTrue (
109117 '__main__.StandardCheckpointHandler' in registry ._registry
110- or
111- 'handler_type_registry_test.StandardCheckpointHandler'
118+ or 'handler_type_registry_test.StandardCheckpointHandler'
112119 in registry ._registry
113120 )
114121 self .assertIn (
115122 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.'
116123 'StandardCheckpointHandler' ,
117- registry ._registry
124+ registry ._registry ,
118125 )
119126
120127 def test_register_duplicate_handler_type (self ):
@@ -133,7 +140,7 @@ def test_register_duplicate_handler_type(self):
133140 r'<class \'(?:__main__|handler_type_registry_test)\.TestHandler\'>. '
134141 'Cannot add type '
135142 r'<class \'(?:__main__|handler_type_registry_test)\.'
136- ' ParentHandler.TestHandler\ ' >.' ,
143+ " ParentHandler.TestHandler'>." ,
137144 ):
138145 registry .add (TestHandler .typestr (), ParentHandler .TestHandler )
139146
@@ -151,11 +158,10 @@ def test_register_subclass_handler_type(self):
151158 registry = HandlerTypeRegistry ()
152159 registry .add (
153160 standard_checkpoint_handler .StandardCheckpointHandler .typestr (),
154- standard_checkpoint_handler .StandardCheckpointHandler
161+ standard_checkpoint_handler .StandardCheckpointHandler ,
155162 )
156163 registry .add (
157- ChildStandardCheckpointHandler .typestr (),
158- ChildStandardCheckpointHandler
164+ ChildStandardCheckpointHandler .typestr (), ChildStandardCheckpointHandler
159165 )
160166 self .assertEqual (
161167 registry .get (
@@ -176,5 +182,23 @@ def test_typestr_override(self):
176182 TypestrOverrideHandler ,
177183 )
178184
185+ def test_no_typestr (self ):
186+ backup = copy .deepcopy (handler_type_registry ._GLOBAL_HANDLER_TYPE_REGISTRY )
187+ try :
188+ # Clear the global registry to avoid side effects from other tests.
189+ handler_type_registry ._GLOBAL_HANDLER_TYPE_REGISTRY ._registry .clear ()
190+
191+ handler_type_registry .register_handler_type (NoTypestrHandler )
192+ registry = handler_type_registry ._GLOBAL_HANDLER_TYPE_REGISTRY ._registry
193+
194+ expected_registry0 = {
195+ 'handler_type_registry_test.NoTypestrHandler' : NoTypestrHandler
196+ }
197+ expected_registry1 = {'__main__.NoTypestrHandler' : NoTypestrHandler }
198+ self .assertIn (registry , [expected_registry0 , expected_registry1 ])
199+ finally :
200+ handler_type_registry ._GLOBAL_HANDLER_TYPE_REGISTRY = backup
201+
202+
179203if __name__ == '__main__' :
180204 absltest .main ()
0 commit comments