Skip to content

Commit d702bfc

Browse files
author
maxtext authors
committed
Merge pull request #1770 from AI-Hypercomputer:checkpointing_integ_test
PiperOrigin-RevId: 762199934
2 parents 9b5fca5 + 5748748 commit d702bfc

File tree

1 file changed

+84
-23
lines changed

1 file changed

+84
-23
lines changed

MaxText/tests/integration_tests/checkpointing_test.py

Lines changed: 84 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -14,46 +14,107 @@
1414
limitations under the License.
1515
"""
1616

17-
"""Integration tests for test_checkpointing.sh"""
17+
"""
18+
Integration tests for test_checkpointing.sh
19+
20+
Note: Make sure to run
21+
`bash setup_gcsfuse.sh DATASET_GCS_BUCKET=gs://maxtext-dataset MOUNT_PATH=/tmp/gcsfuse/`
22+
before running tests locally.
23+
"""
24+
1825
from datetime import datetime
19-
import subprocess
26+
import json
27+
from math import isclose
2028
import os.path
2129
import pytest
2230
from MaxText.globals import PKG_DIR
23-
from MaxText.tests.globals import TEST_DISABLE_SUBPROCESS, TEST_DISABLE_SUBPROCESS_STR
31+
from MaxText.train import main as train_main
2432

2533

26-
def run_checkpointing(attention_type):
27-
"""Tests grain checkpoint determinism."""
34+
def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention_type):
35+
model_params = [
36+
"base_emb_dim=384",
37+
"base_num_query_heads=8",
38+
"base_num_kv_heads=8",
39+
"base_mlp_dim=192",
40+
"base_num_decoder_layers=8",
41+
"head_dim=128",
42+
]
43+
return [
44+
None,
45+
os.path.join(PKG_DIR, "configs", "base.yml"),
46+
f"hardware={hardware}",
47+
f"run_name=runner_{run_date}",
48+
f"steps={steps}",
49+
"max_target_length=128",
50+
"per_device_batch_size=1",
51+
f"metrics_file={metrics_file}",
52+
"checkpoint_period=3",
53+
"base_output_directory=gs://runner-maxtext-logs",
54+
"dataset_path=/tmp/gcsfuse/",
55+
"async_checkpointing=False",
56+
f"attention={attention_type}",
57+
] + model_params
2858

59+
60+
def check_loss(metrics_file, target):
61+
"""Asserts over loss values from loaded checkpoint"""
62+
metrics_file_saved = "saved_" + metrics_file
63+
metrics_file_restored = "restored_" + metrics_file
64+
65+
with (
66+
open(metrics_file_saved, "rt", encoding="utf8") as saved,
67+
open(metrics_file_restored, "rt", encoding="utf8") as restored,
68+
):
69+
saved_loss = json.loads(saved.readlines()[-1])[target]
70+
restored_loss = json.loads(restored.readlines()[0])[target]
71+
# Checks that checkpoint restore was successful by comparing loss of last
72+
# step in saved checkpoint to loss of first step in restored checkpoint
73+
print("saved loss: ", saved_loss)
74+
print("restored loss: ", restored_loss)
75+
assert isclose(saved_loss, restored_loss, rel_tol=0.1)
76+
77+
78+
def run_checkpointing(hardware, attention_type):
79+
"""Tests grain checkpoint determinism."""
2980
run_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
30-
script_path = os.path.join(os.path.dirname(PKG_DIR), "end_to_end", "test_checkpointing.sh")
31-
if not os.path.isfile(script_path):
32-
raise FileNotFoundError(script_path)
33-
command = [
34-
"bash",
35-
script_path,
36-
f"runner_{run_date}", # run_name
37-
"gs://runner-maxtext-logs", # output_path
38-
"gs://maxtext-dataset", # dataset_path
39-
"False", # collect_stack_trace
40-
"grain", # dataset_type
41-
attention_type,
42-
"False", # async_checkpointing"
81+
grain_command = [
82+
"grain_worker_count=0",
83+
"dataset_type=grain",
84+
"grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*",
4385
]
86+
train_main(
87+
get_checkpointing_command(
88+
run_date,
89+
hardware=hardware,
90+
steps=5,
91+
metrics_file="saved_metrics.txt",
92+
attention_type=attention_type,
93+
)
94+
+ grain_command
95+
)
96+
97+
train_main(
98+
get_checkpointing_command(
99+
run_date,
100+
hardware=hardware,
101+
steps=10,
102+
metrics_file="restored_metrics.txt",
103+
attention_type=attention_type,
104+
)
105+
+ grain_command
106+
)
44107

45-
subprocess.run(command, check=True, cwd=os.path.dirname(PKG_DIR))
108+
check_loss("metrics.txt", "learning/loss")
46109

47110

48111
@pytest.mark.integration_test
49112
@pytest.mark.tpu_only
50-
@pytest.mark.skipif(TEST_DISABLE_SUBPROCESS, reason=TEST_DISABLE_SUBPROCESS_STR)
51113
def test_autoselected_attention():
52-
run_checkpointing("autoselected")
114+
run_checkpointing("tpu", "autoselected")
53115

54116

55117
@pytest.mark.integration_test
56118
@pytest.mark.gpu_only
57-
@pytest.mark.skipif(TEST_DISABLE_SUBPROCESS, reason=TEST_DISABLE_SUBPROCESS_STR)
58119
def test_with_dot_product():
59-
run_checkpointing("dot_product")
120+
run_checkpointing("gpu", "dot_product")

0 commit comments

Comments
 (0)