Skip to content

Commit 1a7b2e9

Browse files
angel-coreOrbax Authors
authored andcommitted
Expose step from orbax.checkpoint.experimental.v1.path.
PiperOrigin-RevId: 889348008
1 parent bfb08c9 commit 1a7b2e9

File tree

7 files changed

+115
-7
lines changed

7 files changed

+115
-7
lines changed

checkpoint/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111

1212
- #v1 Add `use_load_and_broadcast` option.
1313
- Add PyTorch DCP (Distributed Checkpoint) to the benchmark suite.
14+
- #v1 Add `DeletionOptions` to configure V1 Checkpointer's checkpoint deletion
15+
behavior.
16+
- #v1 Add `cleanup_tmp_directories` setting to V1 checkpointer's `FileOptions`
17+
to manage temporary directory cleanup behavior.
18+
- #v1 Add `lightweight_initialize` which allows users to specify whether
19+
temporary directories should be cleaned up upon Checkpointer initialization
1420

1521
### Removed
1622

checkpoint/orbax/checkpoint/checkpoint_manager.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -364,12 +364,11 @@ class CheckpointManagerOptions:
364364
supposed to be created per process. This is used to support async
365365
directory creation. If True, `multiprocessing_options.primary_host` must be
366366
None.
367-
lightweight_initialize: If True, checkpoint step metadata is not
368-
read on
367+
lightweight_initialize: If True, checkpoint step metadata is not read on
369368
CheckpointManager initialization during checkpoint info loading. This is
370-
useful to improve init performance
371-
when there are O(1k) or more existing checkpoint step present and checkpoint
372-
info properties like `time` and `metrics` are not needed.
369+
useful to improve init performance when there are O(1k) or more existing
370+
checkpoint step present and checkpoint info properties like `time` and
371+
`metrics` are not needed.
373372
"""
374373

375374
save_interval_steps: int = 1

checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def __init__(
115115
checkpointables_options: options_lib.CheckpointablesOptions | None = None,
116116
pathways_options: options_lib.PathwaysOptions | None = None,
117117
checkpoint_layout: options_lib.CheckpointLayout | None = None,
118+
deletion_options: options_lib.DeletionOptions | None = None,
118119
):
119120
self._pytree_options = pytree_options or (
120121
context.pytree_options if context else options_lib.PyTreeOptions()
@@ -146,6 +147,9 @@ def __init__(
146147
if context
147148
else options_lib.CheckpointLayout.ORBAX
148149
)
150+
self._deletion_options = deletion_options or (
151+
context.deletion_options if context else options_lib.DeletionOptions()
152+
)
149153

150154
@property
151155
def pytree_options(self) -> options_lib.PyTreeOptions:
@@ -179,6 +183,10 @@ def pathways_options(self) -> options_lib.PathwaysOptions:
179183
def checkpoint_layout(self) -> options_lib.CheckpointLayout:
180184
return self._checkpoint_layout
181185

186+
@property
187+
def deletion_options(self) -> options_lib.DeletionOptions:
188+
return self._deletion_options
189+
182190
def operation_id(self) -> str:
183191
return synchronization.OperationIdGenerator.get_current_operation_id()
184192

checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,45 @@ class PathwaysOptions:
494494
checkpointing_impl: pathways_types.CheckpointingImpl | None = None
495495

496496

497+
@dataclasses.dataclass(frozen=True, kw_only=True)
498+
class DeletionOptions:
499+
"""Options used to configure checkpoint deletion behavior.
500+
501+
Attributes:
502+
gcs_deletion_options: Deletion options specific to GCS.
503+
"""
504+
505+
@dataclasses.dataclass(frozen=True, kw_only=True)
506+
class GcsDeletionOptions:
507+
"""Deletion options specific to GCS.
508+
509+
Attributes:
510+
todelete_full_path: Specifies a path relative to the bucket root for
511+
"soft-deleting" checkpoints on Google Cloud Storage (GCS). Instead of
512+
being permanently removed, checkpoints are moved to this new location
513+
within the same bucket. This is useful if direct deletion on GCS is
514+
time-consuming, as it allows an external component to
515+
manage the actual removal.
516+
517+
This option gathers all "deleted" items in a centralized path at the
518+
bucket level for future cleanup.
519+
520+
For instance, if a checkpoint is in
521+
gs://my-bucket/experiments/run1/, providing the value 'trash' will move
522+
a deleted step to gs://my-bucket/trash/<step_id>. Useful when direct
523+
deletion is time consuming. It gathers all deleted items in a
524+
centralized path for future cleanup.
525+
"""
526+
527+
todelete_full_path: str | None = None
528+
529+
530+
gcs_deletion_options: GcsDeletionOptions = dataclasses.field(
531+
default_factory=GcsDeletionOptions
532+
)
533+
534+
535+
497536
class CheckpointLayout(enum.Enum):
498537
"""The layout of the checkpoint.
499538

checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ def __init__(
8787
path_step_lib.NameFormat[path_step_lib.Metadata] | None
8888
) = None,
8989
custom_metadata: tree_types.JsonType | None = None,
90+
cleanup_tmp_directories: bool = False,
91+
lightweight_initialize: bool = False,
9092
):
9193
"""Initializes a Checkpointer.
9294
@@ -150,6 +152,13 @@ def __init__(
150152
custom_metadata: A JSON dictionary representing user-specified custom
151153
metadata. This should be information that is relevant to the entire
152154
sequence of checkpoints, rather than to any single checkpoint.
155+
cleanup_tmp_directories: If True, cleans up any existing temporary
156+
directories on Checkpointer creation.
157+
lightweight_initialize: If True, checkpoint step metadata is not read on
158+
Checkpointer initialization during checkpoint info loading. This is
159+
useful to improve init performance when there are O(1k) or more existing
160+
checkpoint step present and checkpoint info properties like `time` and
161+
`metrics` are not needed.
153162
"""
154163
context = context_lib.get_context()
155164

@@ -169,9 +178,10 @@ def __init__(
169178
save_decision_policy=save_decision_policy,
170179
preservation_policy=preservation_policy,
171180
step_name_format=step_name_format,
181+
cleanup_tmp_directories=cleanup_tmp_directories,
182+
lightweight_initialize=lightweight_initialize,
172183
max_to_keep=None, # Unlimited.
173-
# TODO(b/401541834) Configure todelete_subdir.
174-
# TODO(b/401541834) Enable background deletion.
184+
todelete_full_path=context.deletion_options.gcs_deletion_options.todelete_full_path,
175185
async_options=context.async_options.v0(),
176186
file_options=context.file_options.v0(),
177187
multiprocessing_options=context.multiprocessing_options.v0(),

checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,3 +669,48 @@ def test_preservation_metrics(self, policy, expected_steps):
669669
[all_metrics[step] for step in expected_steps],
670670
)
671671
checkpointer.close()
672+
673+
def test_gcs_deletion_options(self):
674+
deletion_options = ocp.options.DeletionOptions(
675+
gcs_deletion_options=ocp.options.DeletionOptions.GcsDeletionOptions(
676+
todelete_full_path='gs://bucket/trash'
677+
)
678+
)
679+
self.enter_context(ocp.Context(deletion_options=deletion_options))
680+
checkpointer = Checkpointer(self.directory)
681+
self.assertEqual(
682+
checkpointer._manager._options.todelete_full_path, 'gs://bucket/trash'
683+
)
684+
685+
686+
@parameterized.parameters(
687+
(True, True),
688+
(False, False),
689+
)
690+
def test_cleanup_tmp_directories(
691+
self, cleanup_tmp_directories, expected_cleanup_tmp_directories
692+
):
693+
checkpointer = Checkpointer(
694+
self.directory, cleanup_tmp_directories=cleanup_tmp_directories
695+
)
696+
self.assertEqual(
697+
checkpointer._manager._options.cleanup_tmp_directories,
698+
expected_cleanup_tmp_directories,
699+
)
700+
checkpointer.close()
701+
702+
@parameterized.parameters(
703+
(True, True),
704+
(False, False),
705+
)
706+
def test_lightweight_initialize(
707+
self, lightweight_initialize, expected_lightweight_initialize
708+
):
709+
checkpointer = Checkpointer(
710+
self.directory, lightweight_initialize=lightweight_initialize
711+
)
712+
self.assertEqual(
713+
checkpointer._manager._options.lightweight_initialize,
714+
expected_lightweight_initialize,
715+
)
716+
checkpointer.close()

checkpoint/orbax/checkpoint/experimental/v1/path.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@
2121
PathLike,
2222
PathAwaitingCreation,
2323
)
24+
from orbax.checkpoint.experimental.v1._src.path import step

0 commit comments

Comments
 (0)