
Commit 62bc456

Merge azure branch manually (deepspeedai#65)
1 parent 78592ae commit 62bc456

File tree: 11 files changed (+502 −155 lines)


README.md

Lines changed: 6 additions & 2 deletions
@@ -1,5 +1,8 @@
 ## Megatron-DeepSpeed
-DeepSpeed version of NVIDIA's Megatron-LM that adds additional support for several features such as MoE, Curriculum Learning, 3D Parallelism, etc.
+DeepSpeed version of NVIDIA's Megatron-LM that adds additional support for several features such as MoE model training, Curriculum Learning, 3D Parallelism, and others.
+
+### Run on Azure and AzureML
+To try out DeepSpeed on Azure, this fork of Megatron offers easy-to-use recipes and bash scripts. We strongly recommend starting with the AzureML recipe in the ```examples/azureml``` folder. If you have a custom infrastructure (e.g. HPC clusters) or an Azure VM-based environment, please refer to the bash scripts in the ```examples/azure``` folder.

 ------

@@ -76,7 +79,8 @@ The models require vocabulary files to run. The BERT WordPiece vocab file can b
 Additional notes for DeepSpeed. We have added a helper script to download the checkpoints and make the example runnable.

 Steps to follow:
-- bash ds_download_ckpt.sh -- this will download and extract the checkpoint and GPT merges and vocab files.
+- bash dataset/download_ckpt.sh -- this will download and extract the checkpoint.
+- bash dataset/download_vocab.sh -- this will download the GPT merges and vocab files.
 - bash examples/generate_text.sh -- this will generate examples using the 345m GPT model.

 # Usage
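
For quick reference, the updated "Steps to follow" list above amounts to roughly the following sequence, run from the Megatron-DeepSpeed root (a minimal sketch, not part of the changed files):

```bash
# Sketch of the steps listed above, run from the repository root
bash dataset/download_ckpt.sh      # download and extract the 345M GPT-2 checkpoint
bash dataset/download_vocab.sh     # download gpt2-vocab.json and gpt2-merges.txt
bash examples/generate_text.sh     # generate sample text with the 345M model
```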

dataset/README.md

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+# Run the scripts below to set up the dataset
+
+bash download_books.sh
+
+bash download_vocab.sh
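
Since these helpers use plain wget with no output directory, the files land in the current working directory. A minimal sketch (assuming the Azure recipes below, which read the data from ```$PWD/dataset/``` relative to the repository root):

```bash
# Sketch: run the helpers from inside dataset/ so the .bin/.idx and vocab files
# end up where the Azure recipes look for them (BASE_PATH=$PWD/dataset/)
cd dataset
bash download_books.sh
bash download_vocab.sh
cd ..
```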

dataset/download_books.sh

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+wget https://the-eye.eu/public/AI/pile_neox/data/BookCorpusDataset_text_document.bin
+wget https://the-eye.eu/public/AI/pile_neox/data/BookCorpusDataset_text_document.idx

ds_download_ckpt.sh renamed to dataset/download_ckpt.sh

Lines changed: 0 additions & 4 deletions
@@ -1,7 +1,3 @@
-
-wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json
-wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt
-
 mkdir -p checkpoints/gpt2_345m

 cd checkpoints/gpt2_345m

dataset/download_vocab.sh

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json
+wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt

examples/azure/README.md

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+## Recipes for experimentation on Azure
+
+The recipes have been tested from the command line on a cluster set up using Azure VMs and VMSS, as well as inside Docker-based environments.
+
+To run any of the examples in this folder, please go to the base directory of Megatron-DeepSpeed and run as follows:
+
+```bash examples/azure/run-benchmark-model.sh```
+
+### Pre-requisites
+
+To run the above script, you will need to either set up your own dataset and modify the scripts, or use our helper scripts to download the publicly available Books dataset and GPT vocab files. Please use the following helper scripts from the ```dataset``` folder:
+
+```bash dataset/download_books.sh```
+
+```bash dataset/download_vocab.sh```
+
+### Run 175B and 1T models
+
+We have included two recipes, one for the 175B model and one for the 1T model. To train these models, we assume that users will modify and tune hyperparameters and configurations themselves. To facilitate initial training, we have made the recipes runnable with the Books dataset as follows.
+
+```bash examples/azure/run-175b.sh```
+
+```bash examples/azure/run-1t.sh```
+
+### Note about ZeRO stage 3 and CPU offload
+
+By default, we have enabled ZeRO Stage 3 for both recipes above. For the 1T model, we have also enabled the CPU-offload feature to save memory and enable a larger batch size, which offers better performance.
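
Taken together, an end-to-end run of one of these recipes could look roughly as follows (a sketch, not part of the commit; the recipes read data from ```$PWD/dataset/```, and hyperparameters are expected to be tuned by the user):

```bash
# Sketch: fetch the Books dataset and GPT-2 vocab files, then launch a recipe
(cd dataset && bash download_books.sh && bash download_vocab.sh)
bash examples/azure/run-175b.sh      # or: bash examples/azure/run-1t.sh
```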

examples/azure/run-175b.sh

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
+#!/bin/bash
+set -ex
+
+data_options=" \
+    --vocab-file ${VOCAB_PATH} \
+    --merge-file ${MERGE_PATH} \
+    --data-path ${DATA_PATH} \
+    --data-impl mmap"
+
+BASE_PATH=$PWD/dataset/
+DATA_PATH=${BASE_PATH}/BookCorpusDataset_text_document
+DS_CONFIG=ds_config.json
+
+# Hostfile path
+HF=/job/hostfile
+
+# Disabling tensor/pipeline parallelism
+TP=1
+PP=1
+
+# HEADS ~= HIDDEN/128
+
+# Model: 175B
+NLAYERS=96
+HIDDEN=12288
+HEADS=96
+SEQ=1024
+
+
+MICRO_BATCH=4
+NODES=1
+GPN=8
+GLOBAL_BATCH=$(( ${GPN} * ${MICRO_BATCH} * ${NODES} ))
+
+# Initial power scale for loss
+SP=15
+
+# Uncomment/comment one of the following blocks.
+
+# For 1T model, start with microbatch=1, try to get 2 and 4. If OOM w/ 4, use cpu-offloading
+
+# Set to cpu for offloading to cpu for larger models
+#OFFLOAD_DEVICE="cpu"
+#CPU_OPTIM=" --cpu-optimizer"
+
+# Set to none and empty string for no cpu offloading
+OFFLOAD_DEVICE="none"
+CPU_OPTIM=" "
+
+ZERO_STAGE=3
+OUTPUT_DIR=ds_z_off-${OFFLOAD_DEVICE}_stage_${ZERO_STAGE}_nl${NLAYERS}_hs${HIDDEN}_mb${MICRO_BATCH}_seq${SEQ}_gb${GLOBAL_BATCH}_nodes${NODES}
+#OUTPUT_DIR=baseline_nl${NLAYERS}_hs${HIDDEN}_gb${GLOBAL_BATCH}_mb${MICRO_BATCH}
+mkdir -p $OUTPUT_DIR
+
+cat <<EOT > $DS_CONFIG
+{
+  "train_batch_size" : $GLOBAL_BATCH,
+  "train_micro_batch_size_per_gpu": $MICRO_BATCH,
+  "steps_per_print": 1,
+  "gradient_accumulation_steps": 1,
+  "zero_optimization": {
+    "stage": 3,
+    "stage3_max_live_parameters": 3e9,
+    "stage3_max_reuse_distance": 3e9,
+    "stage3_param_persistence_threshold": 1e5,
+    "stage3_prefetch_bucket_size": 5e7,
+    "contiguous_gradients": true,
+    "overlap_comm": true,
+    "reduce_bucket_size": 90000000,
+    "sub_group_size": 1e9,
+    "offload_optimizer": {
+      "device": "$OFFLOAD_DEVICE",
+      "buffer_count": 4,
+      "pipeline_read": false,
+      "pipeline_write": false,
+      "pin_memory": true
+    }
+  },
+  "gradient_clipping": 1.0,
+  "fp16": {
+    "enabled": true,
+    "initial_scale_power" : $SP,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "wall_clock_breakdown": true,
+  "zero_allow_untested_optimizer": false,
+  "aio": {
+    "block_size": 1048576,
+    "queue_depth": 16,
+    "single_submit": false,
+    "overlap_events": true,
+    "thread_count": 2
+  }
+}
+EOT
+
+export NCCL_DEBUG=warn
+
+ds_args=" "
+ds_args=" --deepspeed ${ds_args}"
+ds_args=" --no-pipeline-parallel ${ds_args}"
+ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}"
+ds_args=" --zero-stage=$ZERO_STAGE ${ds_args}"
+ds_args=" --deepspeed-activation-checkpointing ${ds_args}"
+
+
+
+deepspeed --force_multi --num_nodes=$NODES --hostfile $HF pretrain_gpt.py \
+    --tensor-model-parallel-size $TP \
+    --pipeline-model-parallel-size $PP \
+    --num-layers $NLAYERS \
+    --hidden-size $HIDDEN \
+    --num-attention-heads $HEADS \
+    --seq-length $SEQ \
+    --loss-scale $SP \
+    --max-position-embeddings $SEQ \
+    --micro-batch-size $MICRO_BATCH \
+    --global-batch-size $GLOBAL_BATCH \
+    --train-iters 1000 \
+    --lr 6.0e-5 \
+    --min-lr 6.0e-6 \
+    --lr-decay-style cosine \
+    --log-interval 1 \
+    --eval-iters 40 \
+    --eval-interval 1000 \
+    --data-path $DATA_PATH \
+    --vocab-file $BASE_PATH/gpt2-vocab.json \
+    --merge-file $BASE_PATH/gpt2-merges.txt \
+    --save-interval 1000 \
+    --split 98,2,0 \
+    --clip-grad 1.0 \
+    --weight-decay 0.1 \
+    --adam-beta1 0.9 \
+    --adam-beta2 0.95 \
+    --init-method-std 0.006 \
+    --fp16 \
+    --checkpoint-activations \
+    --tensorboard-dir $OUTPUT_DIR \
+    $CPU_OPTIM $ds_args \
+    --exit-interval 5000 | tee ${OUTPUT_DIR}/output.log
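
The recipe above launches through ```deepspeed --hostfile /job/hostfile``` with NODES=1 by default. To scale out, one would raise NODES in the script and provide a matching hostfile; a minimal two-node sketch using DeepSpeed's hostfile format (hostnames here are placeholders, not from the commit):

```bash
# Sketch: write a two-node hostfile in DeepSpeed's "<hostname> slots=<gpus>" format
# (worker-1 / worker-2 are placeholder hostnames; slots=8 matches GPN=8 above),
# then set NODES=2 inside the recipe before launching it.
cat > /job/hostfile <<EOF
worker-1 slots=8
worker-2 slots=8
EOF
```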

examples/azure/run-1t.sh

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
+#!/bin/bash
+set -ex
+
+data_options=" \
+    --vocab-file ${VOCAB_PATH} \
+    --merge-file ${MERGE_PATH} \
+    --data-path ${DATA_PATH} \
+    --data-impl mmap"
+
+BASE_PATH=$PWD/dataset/
+DATA_PATH=${BASE_PATH}/BookCorpusDataset_text_document
+DS_CONFIG=ds_config.json
+
+# Hostfile path
+HF=/job/hostfile
+
+# Disabling tensor/pipeline parallelism
+TP=1
+PP=1
+
+# HEADS ~= HIDDEN/128
+
+# Refer to the Megatron table in the README.md file for model sizes
+# Model: 310B
+#NLAYERS=96
+#HIDDEN=16384
+#HEADS=128
+#SEQ=2048
+
+# Model 530B
+#NLAYERS=105
+#HIDDEN=20480
+#HEADS=160
+#SEQ=2048
+
+# Model 1T
+NLAYERS=128
+HIDDEN=25600
+HEADS=160
+SEQ=1024
+
+MICRO_BATCH=1
+NODES=1
+GPN=8
+GLOBAL_BATCH=$(( ${GPN} * ${MICRO_BATCH} * ${NODES} ))
+
+# Initial power scale for loss
+SP=15
+
+# Uncomment/comment one of the following blocks.
+
+# For 1T model, start with microbatch=1, try to get 2 and 4. If OOM w/ 4, use cpu-offloading
+
+# Set to cpu for offloading to cpu for larger models
+OFFLOAD_DEVICE="cpu"
+CPU_OPTIM=" --cpu-optimizer"
+
+# Set to none and empty string for no cpu offloading
+#OFFLOAD_DEVICE="none"
+#CPU_OPTIM=" "
+
+ZERO_STAGE=3
+OUTPUT_DIR=ds_z_off-${OFFLOAD_DEVICE}_stage_${ZERO_STAGE}_nl${NLAYERS}_hs${HIDDEN}_mb${MICRO_BATCH}_seq${SEQ}_gb${GLOBAL_BATCH}_nodes${NODES}
+#OUTPUT_DIR=baseline_nl${NLAYERS}_hs${HIDDEN}_gb${GLOBAL_BATCH}_mb${MICRO_BATCH}
+mkdir -p $OUTPUT_DIR
+
+cat <<EOT > $DS_CONFIG
+{
+  "train_batch_size" : $GLOBAL_BATCH,
+  "train_micro_batch_size_per_gpu": $MICRO_BATCH,
+  "steps_per_print": 1,
+  "gradient_accumulation_steps": 1,
+  "zero_optimization": {
+    "stage": 3,
+    "stage3_max_live_parameters": 3e9,
+    "stage3_max_reuse_distance": 3e9,
+    "stage3_param_persistence_threshold": 1e5,
+    "stage3_prefetch_bucket_size": 5e7,
+    "contiguous_gradients": true,
+    "overlap_comm": true,
+    "reduce_bucket_size": 90000000,
+    "sub_group_size": 1e9,
+    "offload_optimizer": {
+      "device": "$OFFLOAD_DEVICE",
+      "buffer_count": 4,
+      "pipeline_read": false,
+      "pipeline_write": false,
+      "pin_memory": true
+    }
+  },
+  "gradient_clipping": 1.0,
+  "fp16": {
+    "enabled": true,
+    "initial_scale_power" : $SP,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "wall_clock_breakdown": true,
+  "zero_allow_untested_optimizer": false,
+  "aio": {
+    "block_size": 1048576,
+    "queue_depth": 16,
+    "single_submit": false,
+    "overlap_events": true,
+    "thread_count": 2
+  }
+}
+EOT
+
+export NCCL_DEBUG=warn
+
+ds_args=" "
+ds_args=" --deepspeed ${ds_args}"
+ds_args=" --no-pipeline-parallel ${ds_args}"
+ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}"
+ds_args=" --zero-stage=$ZERO_STAGE ${ds_args}"
+ds_args=" --deepspeed-activation-checkpointing ${ds_args}"
+
+
+
+deepspeed --force_multi --num_nodes=$NODES --hostfile $HF pretrain_gpt.py \
+    --tensor-model-parallel-size $TP \
+    --pipeline-model-parallel-size $PP \
+    --num-layers $NLAYERS \
+    --hidden-size $HIDDEN \
+    --num-attention-heads $HEADS \
+    --seq-length $SEQ \
+    --loss-scale $SP \
+    --max-position-embeddings $SEQ \
+    --micro-batch-size $MICRO_BATCH \
+    --global-batch-size $GLOBAL_BATCH \
+    --train-iters 1000 \
+    --lr 6.0e-5 \
+    --min-lr 6.0e-6 \
+    --lr-decay-style cosine \
+    --log-interval 1 \
+    --eval-iters 40 \
+    --eval-interval 1000 \
+    --data-path $DATA_PATH \
+    --vocab-file $BASE_PATH/gpt2-vocab.json \
+    --merge-file $BASE_PATH/gpt2-merges.txt \
+    --save-interval 1000 \
+    --split 98,2,0 \
+    --clip-grad 1.0 \
+    --weight-decay 0.1 \
+    --adam-beta1 0.9 \
+    --adam-beta2 0.95 \
+    --init-method-std 0.006 \
+    --fp16 \
+    --checkpoint-activations \
+    --tensorboard-dir $OUTPUT_DIR \
+    $CPU_OPTIM $ds_args \
+    --exit-interval 5000 | tee ${OUTPUT_DIR}/output.log
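
As a sanity check on the 1T hyperparameters above (an editor's back-of-the-envelope estimate, not from the commit): with the common 12·L·H² approximation for decoder-only transformer parameters, NLAYERS=128 and HIDDEN=25600 land at roughly one trillion parameters before embeddings:

```bash
# Rough parameter count for NLAYERS=128, HIDDEN=25600 (ignores embeddings and biases):
# params ≈ 12 * NLAYERS * HIDDEN^2
echo $(( 12 * 128 * 25600 * 25600 ))   # 1006632960000, i.e. ~1.0T parameters
```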
