Skip to content

Commit 5147104

Browse files
angel-coreOrbax Authors
authored andcommitted
Add DeletionOptions to Orbax v1 Context.
PiperOrigin-RevId: 886843908
1 parent bfb08c9 commit 5147104

File tree

5 files changed

+63
-2
lines changed

5 files changed

+63
-2
lines changed

checkpoint/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111

1212
- #v1 Add `use_load_and_broadcast` option.
1313
- Add PyTorch DCP (Distributed Checkpoint) to the benchmark suite.
14+
- #v1 Add `DeletionOptions` to configure V1 Checkpointer's checkpoint deletion
15+
behavior.
1416

1517
### Removed
1618

checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def __init__(
115115
checkpointables_options: options_lib.CheckpointablesOptions | None = None,
116116
pathways_options: options_lib.PathwaysOptions | None = None,
117117
checkpoint_layout: options_lib.CheckpointLayout | None = None,
118+
deletion_options: options_lib.DeletionOptions | None = None,
118119
):
119120
self._pytree_options = pytree_options or (
120121
context.pytree_options if context else options_lib.PyTreeOptions()
@@ -146,6 +147,9 @@ def __init__(
146147
if context
147148
else options_lib.CheckpointLayout.ORBAX
148149
)
150+
self._deletion_options = deletion_options or (
151+
context.deletion_options if context else options_lib.DeletionOptions()
152+
)
149153

150154
@property
151155
def pytree_options(self) -> options_lib.PyTreeOptions:
@@ -179,6 +183,10 @@ def pathways_options(self) -> options_lib.PathwaysOptions:
179183
def checkpoint_layout(self) -> options_lib.CheckpointLayout:
180184
return self._checkpoint_layout
181185

186+
@property
187+
def deletion_options(self) -> options_lib.DeletionOptions:
188+
return self._deletion_options
189+
182190
def operation_id(self) -> str:
183191
return synchronization.OperationIdGenerator.get_current_operation_id()
184192

checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,45 @@ class PathwaysOptions:
494494
checkpointing_impl: pathways_types.CheckpointingImpl | None = None
495495

496496

497+
@dataclasses.dataclass(frozen=True, kw_only=True)
498+
class DeletionOptions:
499+
"""Options used to configure checkpoint deletion behavior.
500+
501+
Attributes:
502+
gcs_deletion_options: Deletion options specific to GCS.
503+
"""
504+
505+
@dataclasses.dataclass(frozen=True, kw_only=True)
506+
class GcsDeletionOptions:
507+
"""Deletion options specific to GCS.
508+
509+
Attributes:
510+
todelete_full_path: Specifies a path relative to the bucket root for
511+
"soft-deleting" checkpoints on Google Cloud Storage (GCS). Instead of
512+
being permanently removed, checkpoints are moved to this new location
513+
within the same bucket. This is useful if direct deletion on GCS is
514+
time-consuming, as it allows an external component to
515+
manage the actual removal.
516+
517+
This option gathers all "deleted" items in a centralized path at the
518+
bucket level for future cleanup.
519+
520+
For instance, if a checkpoint is in
521+
gs://my-bucket/experiments/run1/, providing the value 'trash' will move
522+
a deleted step to gs://my-bucket/trash/<step_id>. Useful when direct
523+
deletion is time consuming. It gathers all deleted items in a
524+
centralized path for future cleanup.
525+
"""
526+
527+
todelete_full_path: str | None = None
528+
529+
530+
gcs_deletion_options: GcsDeletionOptions = dataclasses.field(
531+
default_factory=GcsDeletionOptions
532+
)
533+
534+
535+
497536
class CheckpointLayout(enum.Enum):
498537
"""The layout of the checkpoint.
499538

checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,7 @@ def __init__(
170170
preservation_policy=preservation_policy,
171171
step_name_format=step_name_format,
172172
max_to_keep=None, # Unlimited.
173-
# TODO(b/401541834) Configure todelete_subdir.
174-
# TODO(b/401541834) Enable background deletion.
173+
todelete_full_path=context.deletion_options.gcs_deletion_options.todelete_full_path,
175174
async_options=context.async_options.v0(),
176175
file_options=context.file_options.v0(),
177176
multiprocessing_options=context.multiprocessing_options.v0(),

checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,3 +669,16 @@ def test_preservation_metrics(self, policy, expected_steps):
669669
[all_metrics[step] for step in expected_steps],
670670
)
671671
checkpointer.close()
672+
673+
def test_gcs_deletion_options(self):
674+
deletion_options = ocp.options.DeletionOptions(
675+
gcs_deletion_options=ocp.options.DeletionOptions.GcsDeletionOptions(
676+
todelete_full_path='gs://bucket/trash'
677+
)
678+
)
679+
self.enter_context(ocp.Context(deletion_options=deletion_options))
680+
checkpointer = Checkpointer(self.directory)
681+
self.assertEqual(
682+
checkpointer._manager._options.todelete_full_path, 'gs://bucket/trash'
683+
)
684+

0 commit comments

Comments
 (0)