Commit 9a52a2e

added unit test for megatron (deepspeedai#102)
1 parent 515798f commit 9a52a2e
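
This commit adds an end-to-end generation script (tests/run_megatron.py) and a pytest (tests/test_megatron.py) that checks DeepSpeed inference output against stock Megatron. The test needs a GPT-2 345M checkpoint plus the GPT-2 vocab and merge files on disk. A minimal local driver might look like the sketch below; the /data/megatron path and its layout are assumptions inferred from the paths the test builds, not part of the commit:

import os
import subprocess

# Point the test at a directory containing gpt2-vocab.json, gpt2-merges.txt,
# and checkpoints/gpt2_345m (hypothetical location).
os.environ["MEGATRON_CKPT_DIR"] = "/data/megatron"
subprocess.run(["pytest", "-s", "tests/test_megatron.py"], check=True)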

File tree

3 files changed: 175 additions & 0 deletions


MANIFEST.in

Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
 include megatron/data/Makefile
 include megatron/data/helpers.cpp
+recursive-include megatron/fused_kernels *.cpp *.h *.cu *.tr *.cuh *.cc
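
The added recursive-include packages every fused-kernel source under megatron/fused_kernels into the source distribution, so the CUDA/C++ extensions can be built from an installed package rather than only from a repository checkout. As a quick sanity check, a sketch like the following lists what the patterns would match when run from the repository root (illustration only, not part of the commit):

import pathlib

# Mirror the MANIFEST.in patterns and print each file they pick up.
patterns = ("*.cpp", "*.h", "*.cu", "*.tr", "*.cuh", "*.cc")
root = pathlib.Path("megatron/fused_kernels")
for pattern in patterns:
    for path in sorted(root.rglob(pattern)):
        print(path)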

tests/run_megatron.py

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
import torch
import deepspeed
import megatron
from megatron import get_args
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_utils import generate_samples_eval


def model_provider(pre_process=True, post_process=True):
    # Build a GPT model for inference; parallel_output=False gathers logits
    # so generated text can be compared directly across configurations.
    model = GPTModel(
        num_tokentypes=0,
        parallel_output=False,
        pre_process=pre_process,
        post_process=post_process,
        return_moe_loss=False,
    )
    return model


def add_text_generate_args(parser):
    """Text generation arguments."""
    group = parser.add_argument_group(title="text generation")

    group.add_argument(
        "--temperature", type=float, default=1.0, help="Sampling temperature."
    )
    group.add_argument(
        "--greedy", action="store_true", default=False, help="Use greedy sampling."
    )
    group.add_argument("--top_p", type=float, default=0.0, help="Top p sampling.")
    group.add_argument("--top_k", type=int, default=0, help="Top k sampling.")
    group.add_argument(
        "--out-seq-length",
        type=int,
        default=1024,
        help="Size of the output generated text.",
    )
    group.add_argument(
        "--sample-input-file",
        type=str,
        default=None,
        help="Get input from file instead of interactive mode; "
        "each line is an input.",
    )
    group.add_argument(
        "--sample-output-file",
        type=str,
        default=None,
        help="Output file for results from --sample-input-file.",
    )
    group.add_argument(
        "--num-samples",
        type=int,
        default=0,
        help="Number of samples to generate unconditionally; "
        "defaults to 0 (interactive conditional sampling).",
    )
    group.add_argument(
        "--genfile", type=str, help="Output file when generating unconditionally."
    )
    group.add_argument(
        "--recompute",
        action="store_true",
        help="During generation recompute all attention "
        "instead of using previously computed keys/values.",
    )
    group.add_argument(
        "--context-tokens", type=str, default="DeepSpeed is the greatest"
    )
    group.add_argument("--max-tokens", type=int, default=50)

    return parser


if __name__ == "__main__":
    # Initialize Megatron; RNG and optimizer state are not needed for
    # inference, so skip loading them from the checkpoint.
    initialize_megatron(
        extra_args_provider=add_text_generate_args,
        args_defaults={
            "tokenizer_type": "GPT2BPETokenizer",
            "no_load_rng": True,
            "no_load_optim": True,
        },
    )
    args = get_args()

    # Set up the model and load the checkpoint.
    model = get_model(model_provider)
    _ = load_checkpoint(model, None, None)
    model = model[0]
    if args.ds_inference:
        # Wrap the model with the DeepSpeed inference engine, injecting fused
        # kernels in place of the stock transformer layers.
        engine = deepspeed.init_inference(
            model=model,
            mp_size=args.tensor_model_parallel_size,
            tensor_parallel={"mpu": mpu},
            dtype=torch.half,
            replace_with_kernel_inject=True,
            moe_experts=args.num_experts,
            moe_type=args.mlp_type,
        )
        model = engine.module

    # Generate output. The first call is a warm-up so DeepSpeed log output does
    # not land between the markers below (this should be removed once we
    # improve logging in DeepSpeed).
    generate_samples_eval(model, args.context_tokens, 1, 0)
    print("===START OUTPUT===")
    print(generate_samples_eval(model, args.context_tokens, args.max_tokens, 0))
    print("===END OUTPUT===")

tests/test_megatron.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
import pytest
import os
import re
import subprocess


@pytest.fixture(params=[1])
def moe_num_experts(request):
    return str(request.param)


@pytest.fixture(params=[1])
def mp_size(request):
    return str(request.param)


@pytest.fixture
def params(moe_num_experts, mp_size):
    base_dir = os.getenv("MEGATRON_CKPT_DIR")
    assert base_dir, "Please set MEGATRON_CKPT_DIR in your environment"

    vocab_file = os.path.join(base_dir, "gpt2-vocab.json")
    merge_file = os.path.join(base_dir, "gpt2-merges.txt")
    ckpt_path = os.path.join(base_dir, "checkpoints/gpt2_345m")

    return [
        "--micro-batch-size", "1",
        "--num-layers", "24",
        "--hidden-size", "1024",
        "--num-attention-heads", "16",
        "--max-position-embeddings", "1024",
        "--vocab-file", vocab_file,
        "--merge-file", merge_file,
        "--load", ckpt_path,
        "--seq-length", "1024",
        "--out-seq-length", "1024",
        "--tensor-model-parallel-size", mp_size,
        "--tokenizer-type", "GPT2BPETokenizer",
        "--num-experts", moe_num_experts,
        "--mlp-type", "standard",
        "--num-samples", "0",
        "--fp16",
    ]


def test_moe_megatron(params, mp_size):
    # Text between these markers is the sample printed by run_megatron.py.
    output_re = r"===START OUTPUT===([\S\s]*)===END OUTPUT==="

    # Run the baseline (stock Megatron forward pass).
    baseline_cmd = ["deepspeed", "--num_gpus", mp_size, "./run_megatron.py"] + params
    result = subprocess.run(baseline_cmd, stdout=subprocess.PIPE)
    baseline_output = re.search(output_re, result.stdout.decode("utf-8")).group(1)

    # Run again with DeepSpeed inference kernels enabled.
    deepspeed_cmd = baseline_cmd + ["--ds-inference"]
    result = subprocess.run(deepspeed_cmd, stdout=subprocess.PIPE)
    deepspeed_output = re.search(output_re, result.stdout.decode("utf-8")).group(1)

    assert (
        baseline_output == deepspeed_output
    ), f"outputs do not match: {baseline_output}\n{deepspeed_output}"
