|
14 | 14 | limitations under the License.
|
15 | 15 | """
|
16 | 16 |
|
17 |
| -"""Integration tests for test_checkpointing.sh""" |
| 17 | +""" |
| 18 | +Integration tests for test_checkpointing.sh |
| 19 | +
|
| 20 | +Note: Make sure to run |
| 21 | + `bash setup_gcsfuse.sh DATASET_GCS_BUCKET=gs://maxtext-dataset MOUNT_PATH=/tmp/gcsfuse/` |
| 22 | +before running tests locally. |
| 23 | +""" |
| 24 | + |
18 | 25 | from datetime import datetime
|
19 |
| -import subprocess |
| 26 | +import json |
| 27 | +from math import isclose |
20 | 28 | import os.path
|
21 | 29 | import pytest
|
22 | 30 | from MaxText.globals import PKG_DIR
|
23 |
| -from MaxText.tests.globals import TEST_DISABLE_SUBPROCESS, TEST_DISABLE_SUBPROCESS_STR |
| 31 | +from MaxText.train import main as train_main |
24 | 32 |
|
25 | 33 |
|
26 |
| -def run_checkpointing(attention_type): |
27 |
| - """Tests grain checkpoint determinism.""" |
| 34 | +def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention_type): |
| 35 | + model_params = [ |
| 36 | + "base_emb_dim=384", |
| 37 | + "base_num_query_heads=8", |
| 38 | + "base_num_kv_heads=8", |
| 39 | + "base_mlp_dim=192", |
| 40 | + "base_num_decoder_layers=8", |
| 41 | + "head_dim=128", |
| 42 | + ] |
| 43 | + return [ |
| 44 | + None, |
| 45 | + os.path.join(PKG_DIR, "configs", "base.yml"), |
| 46 | + f"hardware={hardware}", |
| 47 | + f"run_name=runner_{run_date}", |
| 48 | + f"steps={steps}", |
| 49 | + "max_target_length=128", |
| 50 | + "per_device_batch_size=1", |
| 51 | + f"metrics_file={metrics_file}", |
| 52 | + "checkpoint_period=3", |
| 53 | + "base_output_directory=gs://runner-maxtext-logs", |
| 54 | + "dataset_path=/tmp/gcsfuse/", |
| 55 | + "async_checkpointing=False", |
| 56 | + f"attention={attention_type}", |
| 57 | + ] + model_params |
28 | 58 |
|
| 59 | + |
| 60 | +def check_loss(metrics_file, target): |
| 61 | + """Asserts over loss values from loaded checkpoint""" |
| 62 | + metrics_file_saved = "saved_" + metrics_file |
| 63 | + metrics_file_restored = "restored_" + metrics_file |
| 64 | + |
| 65 | + with ( |
| 66 | + open(metrics_file_saved, "rt", encoding="utf8") as saved, |
| 67 | + open(metrics_file_restored, "rt", encoding="utf8") as restored, |
| 68 | + ): |
| 69 | + saved_loss = json.loads(saved.readlines()[-1])[target] |
| 70 | + restored_loss = json.loads(restored.readlines()[0])[target] |
| 71 | + # Checks that checkpoint restore was successful by comparing loss of last |
| 72 | + # step in saved checkpoint to loss of first step in restored checkpoint |
| 73 | + print("saved loss: ", saved_loss) |
| 74 | + print("restored loss: ", restored_loss) |
| 75 | + assert isclose(saved_loss, restored_loss, rel_tol=0.1) |
| 76 | + |
| 77 | + |
| 78 | +def run_checkpointing(hardware, attention_type): |
| 79 | + """Tests grain checkpoint determinism.""" |
29 | 80 | run_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
30 |
| - script_path = os.path.join(os.path.dirname(PKG_DIR), "end_to_end", "test_checkpointing.sh") |
31 |
| - if not os.path.isfile(script_path): |
32 |
| - raise FileNotFoundError(script_path) |
33 |
| - command = [ |
34 |
| - "bash", |
35 |
| - script_path, |
36 |
| - f"runner_{run_date}", # run_name |
37 |
| - "gs://runner-maxtext-logs", # output_path |
38 |
| - "gs://maxtext-dataset", # dataset_path |
39 |
| - "False", # collect_stack_trace |
40 |
| - "grain", # dataset_type |
41 |
| - attention_type, |
42 |
| - "False", # async_checkpointing" |
| 81 | + grain_command = [ |
| 82 | + "grain_worker_count=0", |
| 83 | + "dataset_type=grain", |
| 84 | + "grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*", |
43 | 85 | ]
|
| 86 | + train_main( |
| 87 | + get_checkpointing_command( |
| 88 | + run_date, |
| 89 | + hardware=hardware, |
| 90 | + steps=5, |
| 91 | + metrics_file="saved_metrics.txt", |
| 92 | + attention_type=attention_type, |
| 93 | + ) |
| 94 | + + grain_command |
| 95 | + ) |
| 96 | + |
| 97 | + train_main( |
| 98 | + get_checkpointing_command( |
| 99 | + run_date, |
| 100 | + hardware=hardware, |
| 101 | + steps=10, |
| 102 | + metrics_file="restored_metrics.txt", |
| 103 | + attention_type=attention_type, |
| 104 | + ) |
| 105 | + + grain_command |
| 106 | + ) |
44 | 107 |
|
45 |
| - subprocess.run(command, check=True, cwd=os.path.dirname(PKG_DIR)) |
| 108 | + check_loss("metrics.txt", "learning/loss") |
46 | 109 |
|
47 | 110 |
|
48 | 111 | @pytest.mark.integration_test
|
49 | 112 | @pytest.mark.tpu_only
|
50 |
| -@pytest.mark.skipif(TEST_DISABLE_SUBPROCESS, reason=TEST_DISABLE_SUBPROCESS_STR) |
51 | 113 | def test_autoselected_attention():
|
52 |
| - run_checkpointing("autoselected") |
| 114 | + run_checkpointing("tpu", "autoselected") |
53 | 115 |
|
54 | 116 |
|
55 | 117 | @pytest.mark.integration_test
|
56 | 118 | @pytest.mark.gpu_only
|
57 |
| -@pytest.mark.skipif(TEST_DISABLE_SUBPROCESS, reason=TEST_DISABLE_SUBPROCESS_STR) |
58 | 119 | def test_with_dot_product():
|
59 |
| - run_checkpointing("dot_product") |
| 120 | + run_checkpointing("gpu", "dot_product") |
0 commit comments