51
51
from MaxText import checkpointing
52
52
from MaxText import max_logging
53
53
from MaxText import maxtext_utils
54
+ from MaxText import max_utils
54
55
from MaxText import optimizers
55
56
from MaxText import pyconfig
56
57
from MaxText .globals import PKG_DIR
@@ -68,16 +69,6 @@ def fmt_size(num_bytes: int) -> str:
68
69
return f"{ num_bytes :.2f} { unit } "
69
70
70
71
71
- def check_memory ():
72
- """print out cpu/tpu memory."""
73
- cpu_bytes = Process ().memory_info ().rss
74
- max_logging .log (f"cpu memory: { fmt_size (cpu_bytes )} " )
75
- for d in jax .local_devices ():
76
- stats = d .memory_stats ()
77
- used = stats ["bytes_in_use" ]
78
- limit = stats ["bytes_limit" ]
79
- max_logging .log (f"tpu memory: Using { fmt_size (used )} / { fmt_size (limit )} ({ used / limit :%} ) on { d } " )
80
-
81
72
82
73
def convert (paxml_ckpt_path , maxtext_model_name , base_output_directory , run_name ):
83
74
"""convert ckpt."""
@@ -113,7 +104,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
113
104
114
105
state , _ , _ , _ = maxtext_utils .setup_training_state (model , None , tx , cfg , init_rng , mesh , checkpoint_manager )
115
106
max_logging .log ("start" )
116
- check_memory ( )
107
+ max_utils . print_mem_stats ( "After params initialized" )
117
108
118
109
# maxtext keystr: (paxml keystr, transform_fn)
119
110
keystr_map = {
@@ -267,12 +258,12 @@ def map_fn(key_path, value):
267
258
del arr
268
259
gc .collect ()
269
260
max_logging .log (f"{ key_path_str } finished" )
270
- check_memory ( )
261
+ max_utils . print_mem_stats ( "After params conversion" )
271
262
return result
272
263
273
264
converted_state = jax .tree_util .tree_map_with_path (map_fn , state )
274
265
max_logging .log ("converted state finished" )
275
- check_memory ( )
266
+ max_utils . print_mem_stats ( "converted state finished" )
276
267
277
268
if save_checkpoint (checkpoint_manager , converted_state .step , converted_state ):
278
269
max_logging .log (f"saved a checkpoint at step { converted_state .step } " )
0 commit comments