Skip to content

Commit 7a3addb

Browse files
cpgaffney1copybara-github
authored andcommitted
Fix structure caching for enable_flax=False case by storing relative file names in structure rather than absolute paths.
PiperOrigin-RevId: 475845023
1 parent d6486f4 commit 7a3addb

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

orbax/checkpoint/pytree_checkpoint_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def _get_param_info(leaf):
207207
tspec = None
208208
elif isinstance(leaf, dict):
209209
tspec = None
210-
elif isinstance(leaf, epath.Path):
210+
elif isinstance(leaf, utils.Leaf):
211211
# Leaf is a param name.
212212
path = os.fspath(directory / leaf)
213213
tspec = serialization.get_tensorstore_spec(path)

orbax/checkpoint/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,34 +82,34 @@ def rmtree(path: epath.Path):
8282
path.rmdir()
8383

8484

85-
Leaf = epath.Path
85+
Leaf = str
8686

8787

8888
def pytree_structure(directory: Path) -> PyTree:
8989
"""Reconstruct state dict from saved model format in `directory`."""
9090
directory = epath.Path(directory)
9191

92-
def add_nested_key(subtree, nested_key, full_name):
92+
def add_nested_key(subtree, nested_key, key_name):
9393
if not nested_key:
9494
return subtree
9595

9696
current = nested_key[0]
9797

9898
if len(nested_key) == 1:
9999
assert current not in subtree
100-
subtree[current] = full_name
100+
subtree[current] = key_name
101101
return subtree
102102

103103
subkeys = nested_key[1:]
104104
if current not in subtree:
105105
subtree[current] = {}
106-
subtree[current] = add_nested_key(subtree[current], subkeys, full_name)
106+
subtree[current] = add_nested_key(subtree[current], subkeys, key_name)
107107
return subtree
108108

109109
keys = directory.iterdir()
110110
tree = {}
111111
for k in keys:
112-
tree = add_nested_key(tree, k.name.split('.'), k)
112+
tree = add_nested_key(tree, k.name.split('.'), k.name)
113113
return tree
114114

115115

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
setuptools.setup(
2424
name='orbax',
25-
version='0.0.10',
25+
version='0.0.11',
2626
description='Orbax',
2727
long_description=_LONG_DESCRIPTION,
2828
long_description_content_type='text/markdown',

0 commit comments

Comments
 (0)