Skip to content

Commit 5db4d7a

Browse files
cpgaffney1copybara-github
authored andcommitted
Internal change
PiperOrigin-RevId: 459553867
1 parent 47d27b5 commit 5db4d7a

15 files changed

+344
-333
lines changed

orbax/checkpoint/__init__.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,15 @@
1515
"""Defines exported symbols for the `orbax` package."""
1616

1717
from .abstract_checkpoint_manager import AbstractCheckpointManager
18+
from .checkpoint_handler import CheckpointHandler
1819
from .checkpoint_manager import CheckpointManager
1920
from .checkpoint_manager import CheckpointManagerOptions
20-
from .checkpointer import Checkpointer
21-
from .dataset_checkpointer import DatasetCheckpointer
22-
from .json_checkpointer import JsonCheckpointer
21+
from .dataset_checkpoint_handler import DatasetCheckpointHandler
22+
from .json_checkpoint_handler import JsonCheckpointHandler
2323
from orbax.checkpoint import lazy_array
24-
from .pytree_checkpointer import PyTreeCheckpointer
25-
from .pytree_checkpointer import RestoreArgs
26-
from .pytree_checkpointer import SaveArgs
24+
from .pytree_checkpoint_handler import PyTreeCheckpointHandler
25+
from .pytree_checkpoint_handler import RestoreArgs
26+
from .pytree_checkpoint_handler import SaveArgs
2727
from .transform_utils import apply_transformations
2828
from .transform_utils import Transform
2929
from .utils import checkpoints_iterator
30-
31-
# TODO(cpgaffney) Remove when handler classes are fully rolled out.
32-
CheckpointHandler = Checkpointer
33-
PyTreeCheckpointHandler = PyTreeCheckpointer

orbax/checkpoint/abstract_checkpoint_manager.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,22 @@
2222

2323

2424
class AbstractCheckpointManager(abc.ABC):
25-
"""An interface that manages multiple Checkpointer classes.
25+
"""An interface that manages multiple CheckpointHandler classes.
2626
2727
CheckpointManager coordinates save/restore operations across multiple
28-
Checkpointer classes, and also provides useful methods describing checkpoint
28+
CheckpointHandler classes, and also provides useful methods describing
29+
checkpoint
2930
states.
3031
3132
For example, CheckpointManager may be responsible for managing a parameter
3233
state in the form of a PyTree and a dataset iterator state in the form of
3334
tf.data.Iterator.
3435
35-
Each item should be handled by a separate Checkpointer.
36+
Each item should be handled by a separate CheckpointHandler.
3637
37-
For instance, item "a" is handled by Checkpointer A, while item "b" is handled
38-
by Checkpointer B.
38+
For instance, item "a" is handled by CheckpointHandler A, while item "b" is
39+
handled
40+
by CheckpointHandler B.
3941
"""
4042

4143
@abc.abstractmethod
@@ -54,27 +56,28 @@ def save(
5456
...
5557
}
5658
Each of these values is a saveable item that should be written with a
57-
specific Checkpointer.
59+
specific CheckpointHandler.
5860
5961
Similarly, save_kwargs takes the form:
6062
{
6163
'params': {
62-
<kwargs for PyTreeCheckpointer.save>
64+
<kwargs for PyTreeCheckpointHandler.save>
6365
},
6466
'dataset': {
65-
<kwargs for DatasetCheckpointer.save>
67+
<kwargs for DatasetCheckpointHandler.save>
6668
}
6769
...
6870
}
6971
The dict of kwargs for each key in save_kwargs is provided as extra
70-
arguments to the save method of the corresponding Checkpointer.
72+
arguments to the save method of the corresponding CheckpointHandler.
7173
7274
Args:
7375
step: current step, int
7476
items: a savable object, or a dictionary of object name to savable object.
75-
save_kwargs: save kwargs for a single Checkpointer, or a dictionary of
76-
object name to kwargs needed by the Checkpointer implementation to save
77-
the object.
77+
save_kwargs: save kwargs for a single CheckpointHandler, or a dictionary
78+
of object name to kwargs needed by the CheckpointHandler implementation
79+
to save the object.
80+
7881
Returns:
7982
bool indicating whether save was performed or not.
8083
"""
@@ -97,29 +100,30 @@ def restore(
97100
...
98101
}
99102
Each of these values is a restoreable item that should be read with a
100-
specific Checkpointer. Implementations should support items=None, and the
103+
specific CheckpointHandler. Implementations should support items=None, and
104+
the
101105
ability to restore an item which is not provided in this dict.
102106
103107
Similarly, restore_kwargs takes the form:
104108
{
105109
'params': {
106-
<kwargs for PyTreeCheckpointer.restore>
110+
<kwargs for PyTreeCheckpointHandler.restore>
107111
},
108112
'dataset': {
109-
<kwargs for DatasetCheckpointer.restore>
113+
<kwargs for DatasetCheckpointHandler.restore>
110114
}
111115
...
112116
}
113117
The dict of kwargs for each key in restore_kwargs is provided as extra
114-
arguments to the restore method of the corresponding Checkpointer.
118+
arguments to the restore method of the corresponding CheckpointHandler.
115119
116120
Args:
117121
step: current step, int
118122
items: a restoreable object, or a dictionary of object name to restoreable
119123
object.
120-
restore_kwargs: restore kwargs for a single Checkpointer, or a dictionary
121-
of object name to kwargs needed by the Checkpointer implementation to
122-
restore the object.
124+
restore_kwargs: restore kwargs for a single CheckpointHandler, or a
125+
dictionary of object name to kwargs needed by the CheckpointHandler
126+
implementation to restore the object.
123127
124128
Returns:
125129
A dictionary mapping name to restored object, or a single restored object.
@@ -128,13 +132,15 @@ def restore(
128132

129133
@abc.abstractmethod
130134
def structure(self) -> Union[Any, Mapping[str, Any]]:
131-
"""For all Checkpointers, returns the saved structure.
135+
"""For all CheckpointHandlers, returns the saved structure.
132136
133-
Calls the `structure` method for each Checkpointer and returns a mapping of
137+
Calls the `structure` method for each CheckpointHandler and returns a
138+
mapping of
134139
each item name to the restored structure. If the manager only manages a
135140
single item, a single structure will be returned instead.
136141
137-
Note that any items for which the corresponding Checkpointer does not have
142+
Note that any items for which the corresponding CheckpointHandler does not
143+
have
138144
an implemented `structure` method, these items will simply not be contained
139145
in the result.
140146
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2022 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""AsyncCheckpointHandler interface."""
16+
17+
import abc
18+
import asyncio
19+
from typing import Any, Optional, Sequence
20+
21+
22+
class AsyncCheckpointHandler(abc.ABC):
23+
"""An interface providing async methods that can be used with CheckpointHandler."""
24+
25+
@abc.abstractmethod
26+
async def async_save(self, directory: str, item: Any, *args,
27+
**kwargs) -> Optional[Sequence[asyncio.Future]]:
28+
"""Constructs a save operation.
29+
30+
Synchronously awaits a copy of the item, before returning commit futures
31+
necessary to save the item.
32+
33+
Args:
34+
directory: the directory to save to.
35+
item: the item to be saved.
36+
*args: additional arguments for save.
37+
**kwargs: additional arguments for save.
38+
"""
39+
pass
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2022 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""CheckpointHandler interface."""
16+
17+
import abc
18+
from typing import Any, Optional
19+
20+
21+
class CheckpointHandler(abc.ABC):
22+
"""An interface providing save/restore methods used on a savable item.
23+
24+
Item may be a PyTree, Dataset, or any other supported object.
25+
"""
26+
27+
@abc.abstractmethod
28+
def save(self, directory: str, item: Any, *args, **kwargs):
29+
"""Saves the provided item synchronously.
30+
31+
Args:
32+
directory: the directory to save to.
33+
item: the item to be saved.
34+
*args: additional arguments for save.
35+
**kwargs: additional arguments for save.
36+
"""
37+
pass
38+
39+
@abc.abstractmethod
40+
def restore(self,
41+
directory: str,
42+
item: Optional[Any] = None,
43+
**kwargs) -> Any:
44+
"""Restores the provided item synchronously.
45+
46+
Args:
47+
directory: the directory to restore from.
48+
item: an item with the same structure as that to be restored.
49+
**kwargs: additional arguments for restore.
50+
51+
Returns:
52+
The restored item.
53+
"""
54+
pass
55+
56+
@abc.abstractmethod
57+
def structure(self, directory: str) -> Any:
58+
"""Returns the structure of the item.
59+
60+
This should represent the checkpointed item format without needing to fully
61+
restore the entire item and all its values.
62+
63+
Args:
64+
directory: the directory where the checkpoint is located.
65+
66+
Returns:
67+
item structure
68+
"""
69+
pass

0 commit comments

Comments
 (0)