Skip to content

Commit d97714a

Browse files
Committed by: maxtext authors (author)
Merge pull request #1754 from AI-Hypercomputer:mohit/memstats
PiperOrigin-RevId: 760684416
2 parents: c1806b4 + f8a0cb7; commit d97714a

File tree

1 file changed: +4 -13 lines changed

MaxText/convert_gpt3_ckpt_from_paxml.py

Lines changed: 4 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,7 @@
 from MaxText import checkpointing
 from MaxText import max_logging
 from MaxText import maxtext_utils
+from MaxText import max_utils
 from MaxText import optimizers
 from MaxText import pyconfig
 from MaxText.globals import PKG_DIR
@@ -68,16 +69,6 @@ def fmt_size(num_bytes: int) -> str:
   return f"{num_bytes:.2f} {unit}"
 
 
-def check_memory():
-  """print out cpu/tpu memory."""
-  cpu_bytes = Process().memory_info().rss
-  max_logging.log(f"cpu memory: {fmt_size(cpu_bytes)}")
-  for d in jax.local_devices():
-    stats = d.memory_stats()
-    used = stats["bytes_in_use"]
-    limit = stats["bytes_limit"]
-    max_logging.log(f"tpu memory: Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")
-
 
 def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name):
   """convert ckpt."""
@@ -113,7 +104,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
 
   state, _, _, _ = maxtext_utils.setup_training_state(model, None, tx, cfg, init_rng, mesh, checkpoint_manager)
   max_logging.log("start")
-  check_memory()
+  max_utils.print_mem_stats("After params initialized")
 
   # maxtext keystr: (paxml keystr, transform_fn)
   keystr_map = {
@@ -267,12 +258,12 @@ def map_fn(key_path, value):
     del arr
     gc.collect()
     max_logging.log(f"{key_path_str} finished")
-    check_memory()
+    max_utils.print_mem_stats("After params conversion")
     return result
 
   converted_state = jax.tree_util.tree_map_with_path(map_fn, state)
   max_logging.log("converted state finished")
-  check_memory()
+  max_utils.print_mem_stats("converted state finished")
 
   if save_checkpoint(checkpoint_manager, converted_state.step, converted_state):
     max_logging.log(f"saved a checkpoint at step {converted_state.step}")

0 commit comments

Comments
 (0)