Skip to content

Commit 08dd65b

Browse files
cpgaffney1, Orbax Authors
authored and committed
Ensure AsyncCheckpointer completion logs on every host instead of just the leader.
PiperOrigin-RevId: 724466022
1 parent a1718d2 commit 08dd65b

File tree

3 files changed

+20
-9
lines changed

3 files changed

+20
-9
lines changed

checkpoint/CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
## [0.11.4] - 2025-02-07
11+
1012
### Changed
1113

1214
- Updated orbax-checkpoint PyPI package to exclude tests.
1315

16+
### Fixed
17+
18+
- `AsyncCheckpointer` completion logging, to log on all hosts instead of just
19+
the leader.
20+
1421
## [0.11.3] - 2025-02-06
1522

1623
### Changed

checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,6 @@ def _on_commit_callback(
5858
'/jax/checkpoint/write/async/total_duration_secs',
5959
total_duration_secs,
6060
)
61-
logging.info(
62-
'Finished asynchronous save in %.2f seconds to %s',
63-
total_duration_secs,
64-
tmpdir.get_final(),
65-
)
6661

6762

6863
def _add_deadline_exceeded_notes(e: jax.errors.JaxRuntimeError):
@@ -405,7 +400,8 @@ def _callback() -> None:
405400
# Update StepMetadata after the handler save is complete.
406401
# (blocking write)
407402
self._save_step_metadata(tmpdir.get(), custom_metadata=custom_metadata)
408-
logging.info(
403+
logging.vlog(
404+
1,
409405
'[process=%s][thread=%s] Async Save Callback [1/3]: Finalizing'
410406
' Handler: %s on %s',
411407
multihost.process_index(),
@@ -415,7 +411,8 @@ def _callback() -> None:
415411
)
416412
# Finalize does a final StepMetadata update.
417413
self._handler.finalize(tmpdir.get())
418-
logging.info(
414+
logging.vlog(
415+
1,
419416
'[process=%s][thread=%s] Async Save Callback [2/3]: Running'
420417
' post_finalization_callback: %s on %s',
421418
multihost.process_index(),
@@ -425,7 +422,8 @@ def _callback() -> None:
425422
)
426423
if self._post_finalization_callback is not None:
427424
self._post_finalization_callback()
428-
logging.info(
425+
logging.vlog(
426+
1,
429427
'[process=%s][thread=%s] Async Save Callback [3/3]: Finalizing'
430428
' checkpoint directory: %s',
431429
multihost.process_index(),
@@ -436,6 +434,12 @@ def _callback() -> None:
436434
tmpdir,
437435
checkpoint_start_time,
438436
)
437+
logging.info(
438+
'Finished asynchronous save (blocking + background) in %.2f seconds'
439+
' to %s',
440+
time.time() - checkpoint_start_time,
441+
directory,
442+
)
439443

440444
self._async_manager.start_async_commit(
441445
directory,

checkpoint/orbax/checkpoint/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
# A new PyPI release will be pushed every time `__version__` is increased.
1818
# Also modify version and date in CHANGELOG.
19-
__version__ = '0.11.3'
19+
__version__ = '0.11.4'
2020

2121

2222
# TODO: b/362813406 - Add latest change timestamp and commit number.

0 commit comments

Comments
 (0)