@@ -78,6 +78,8 @@ class CheckpointManagerOptions:
7878 function.
7979 best_mode: one of ['max', 'min']. The best metric is determine on the basis of
8080 this value.
81+ keep_checkpoints_without_metrics: If False, checkpoints with metrics present
82+ are eligible for cleanup. Otherwise, they will never be deleted.
8183 step_prefix: if provided, step directories will take the form
8284 f'{step_prefix}_<step>'. Otherwise, they will simply be an integer <step>.
8385
@@ -88,8 +90,15 @@ class CheckpointManagerOptions:
8890 keep_period : Optional [int ] = None
8991 best_fn : Optional [Callable [[PyTree ], float ]] = None
9092 best_mode : str = 'max'
93+ keep_checkpoints_without_metrics : bool = True
9194 step_prefix : Optional [str ] = None
9295
96+ def __post_init__ (self ):
97+ if self .best_mode not in ('min' , 'max' ):
98+ msg = ("`CheckpointManagerOptions.best_mode` must be one of None, 'min' "
99+ "or 'max'. Got {self.dtype}." )
100+ raise ValueError (msg )
101+
93102
94103@dataclasses .dataclass
95104class CheckpointInfo :
@@ -213,6 +222,8 @@ def best_step(self) -> Optional[int]:
213222 if not self ._checkpoints :
214223 return None
215224 _ , sorted_checkpoints = self ._sort_checkpoints_by_metrics (self ._checkpoints )
225+ if not sorted_checkpoints :
226+ return None
216227 return sorted_checkpoints [- 1 ].step
217228
218229 def should_save (self , step : int ) -> bool :
@@ -584,7 +595,7 @@ def get_metrics(step):
584595 for s , t , m in zip (steps , times , metrics )
585596 ]
586597
587- def _add_checkpoint_info (self , step , metrics ):
598+ def _add_checkpoint_info (self , step : int , metrics : Optional [ PyTree ] ):
588599 self ._checkpoints .append (
589600 CheckpointInfo (step , datetime .datetime .now (tz = datetime .timezone .utc ),
590601 metrics ))
@@ -636,8 +647,12 @@ def _delete_directory(self, step: int):
636647
637648 def _remove_old_checkpoints (self ):
638649 """Keeps the `max_to_keep` most recent checkpoint steps."""
650+ # Must have set max_to_keep or keep_time_interval.
639651 if not self ._options .max_to_keep and not self ._options .keep_time_interval :
640652 return
653+ # Not enough checkpoints accumulated to consider deletion.
654+ if len (self ._checkpoints ) <= self ._options .max_to_keep :
655+ return
641656 if self ._track_best :
642657 # Best steps (to keep) are at the end, after sorting.
643658 checkpoints_without_metrics , sorted_checkpoints = self ._sort_checkpoints_by_metrics (
@@ -647,12 +662,15 @@ def _remove_old_checkpoints(self):
647662 checkpoints_without_metrics = []
648663 sorted_checkpoints = self ._checkpoints
649664
650- to_remove = len (sorted_checkpoints ) - self ._options .max_to_keep
651- if to_remove <= 0 :
652- return
653- maybe_delete = sorted_checkpoints [:to_remove ]
654- active_checkpoints = checkpoints_without_metrics + sorted_checkpoints [
655- to_remove :]
665+ keep = int (self ._options .max_to_keep )
666+ if self ._options .keep_checkpoints_without_metrics :
667+ maybe_delete = sorted_checkpoints [:- keep ]
668+ active_checkpoints = checkpoints_without_metrics + sorted_checkpoints [
669+ - keep :]
670+ else :
671+ all_checkpoints = checkpoints_without_metrics + sorted_checkpoints
672+ maybe_delete = all_checkpoints [:- keep ]
673+ active_checkpoints = all_checkpoints [- keep :]
656674
657675 kept_checkpoints = []
658676 for info in maybe_delete :
0 commit comments