Skip to content

Commit 92c08b3

Browse files
JustinPan-googOrbax Authors
authored andcommitted
Add monitoring for foreground operations in async checkpointing.
PiperOrigin-RevId: 893037701
1 parent 09d2982 commit 92c08b3

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,13 @@ async def _save(
475475
directory,
476476
)
477477

478-
if await async_path.exists(directory):
478+
exists_start = time.time()
479+
dir_exists = await async_path.exists(directory)
480+
jax.monitoring.record_event_duration_secs(
481+
'/jax/orbax/write/async/foreground/check_dir_exists_secs',
482+
time.time() - exists_start,
483+
)
484+
if dir_exists:
479485
if force:
480486
if utils.is_primary_host(self._primary_host):
481487
logging.info(
@@ -498,7 +504,13 @@ async def _save(
498504
)
499505
)
500506
else:
507+
create_dir_start = time.time()
501508
await self.create_temporary_path(tmpdir)
509+
jax.monitoring.record_event_duration_secs(
510+
'/jax/orbax/write/async/foreground/create_dir_secs',
511+
time.time() - create_dir_start,
512+
)
513+
502514
# Run copy ops.
503515
# Try to save using new CheckpointArgs API if supported by the handler.
504516
ckpt_args = checkpointer.construct_checkpoint_args(

checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,18 @@ async def async_save(
776776
total_serialization_initiated_time - batch_requests_ready_time,
777777
async_save_end_time - total_serialization_initiated_time,
778778
)
779+
jax.monitoring.record_event_duration_secs(
780+
'/jax/orbax/write/async/foreground/batch_requests_ready_secs',
781+
batch_requests_ready_time - start_time,
782+
)
783+
jax.monitoring.record_event_duration_secs(
784+
'/jax/orbax/write/async/foreground/d2h_transfer_secs',
785+
total_serialization_initiated_time - batch_requests_ready_time,
786+
)
787+
jax.monitoring.record_event_duration_secs(
788+
'/jax/orbax/write/async/foreground/commit_write_metadata_prep_secs',
789+
async_save_end_time - total_serialization_initiated_time,
790+
)
779791
return chained_futures
780792

781793
def save(self, directory: epath.Path, *args, **kwargs):

0 commit comments

Comments
 (0)