Skip to content

Commit 1f0d7b0

Browse files
cpgaffney1copybara-github
authored andcommitted
Add support for tracking checkpoint metrics with Orbax in T5X.
PiperOrigin-RevId: 492213818
1 parent 42a9dcc commit 1f0d7b0

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

orbax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@
1515
"""Orbax API."""
1616

1717
# A new PyPI release will be pushed everytime `__version__` is increased.
18-
__version__ = '0.0.18'
18+
__version__ = '0.0.19'

orbax/checkpoint/checkpoint_manager.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ class CheckpointManagerOptions:
7878
function.
7979
best_mode: one of ['max', 'min']. The best metric is determine on the basis of
8080
this value.
81+
keep_checkpoints_without_metrics: If False, checkpoints with metrics present
82+
are eligible for cleanup. Otherwise, they will never be deleted.
8183
step_prefix: if provided, step directories will take the form
8284
f'{step_prefix}_<step>'. Otherwise, they will simply be an integer <step>.
8385
@@ -88,8 +90,15 @@ class CheckpointManagerOptions:
8890
keep_period: Optional[int] = None
8991
best_fn: Optional[Callable[[PyTree], float]] = None
9092
best_mode: str = 'max'
93+
keep_checkpoints_without_metrics: bool = True
9194
step_prefix: Optional[str] = None
9295

96+
def __post_init__(self):
97+
if self.best_mode not in ('min', 'max'):
98+
msg = ("`CheckpointManagerOptions.best_mode` must be one of None, 'min' "
99+
"or 'max'. Got {self.dtype}.")
100+
raise ValueError(msg)
101+
93102

94103
@dataclasses.dataclass
95104
class CheckpointInfo:
@@ -213,6 +222,8 @@ def best_step(self) -> Optional[int]:
213222
if not self._checkpoints:
214223
return None
215224
_, sorted_checkpoints = self._sort_checkpoints_by_metrics(self._checkpoints)
225+
if not sorted_checkpoints:
226+
return None
216227
return sorted_checkpoints[-1].step
217228

218229
def should_save(self, step: int) -> bool:
@@ -584,7 +595,7 @@ def get_metrics(step):
584595
for s, t, m in zip(steps, times, metrics)
585596
]
586597

587-
def _add_checkpoint_info(self, step, metrics):
598+
def _add_checkpoint_info(self, step: int, metrics: Optional[PyTree]):
588599
self._checkpoints.append(
589600
CheckpointInfo(step, datetime.datetime.now(tz=datetime.timezone.utc),
590601
metrics))
@@ -636,8 +647,12 @@ def _delete_directory(self, step: int):
636647

637648
def _remove_old_checkpoints(self):
638649
"""Keeps the `max_to_keep` most recent checkpoint steps."""
650+
# Must have set max_to_keep or keep_time_interval.
639651
if not self._options.max_to_keep and not self._options.keep_time_interval:
640652
return
653+
# Not enough checkpoints accumulated to consider deletion.
654+
if len(self._checkpoints) <= self._options.max_to_keep:
655+
return
641656
if self._track_best:
642657
# Best steps (to keep) are at the end, after sorting.
643658
checkpoints_without_metrics, sorted_checkpoints = self._sort_checkpoints_by_metrics(
@@ -647,12 +662,15 @@ def _remove_old_checkpoints(self):
647662
checkpoints_without_metrics = []
648663
sorted_checkpoints = self._checkpoints
649664

650-
to_remove = len(sorted_checkpoints) - self._options.max_to_keep
651-
if to_remove <= 0:
652-
return
653-
maybe_delete = sorted_checkpoints[:to_remove]
654-
active_checkpoints = checkpoints_without_metrics + sorted_checkpoints[
655-
to_remove:]
665+
keep = int(self._options.max_to_keep)
666+
if self._options.keep_checkpoints_without_metrics:
667+
maybe_delete = sorted_checkpoints[:-keep]
668+
active_checkpoints = checkpoints_without_metrics + sorted_checkpoints[
669+
-keep:]
670+
else:
671+
all_checkpoints = checkpoints_without_metrics + sorted_checkpoints
672+
maybe_delete = all_checkpoints[:-keep]
673+
active_checkpoints = all_checkpoints[-keep:]
656674

657675
kept_checkpoints = []
658676
for info in maybe_delete:

0 commit comments

Comments
 (0)