Commit 8a85d59

Add Megatron-LM pretrain function for the core.
1 parent 9d83398 commit 8a85d59

File tree

1 file changed: +127 -0 lines changed


pretrain_gpt_core.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Pretrain GPT"""

import torch
from functools import partial
from megatron import get_args
from megatron.arguments import core_transformer_config_from_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.core.models.gpt import GPTModel
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    args = get_args()
    config = core_transformer_config_from_args(args)

    print_rank_0('building GPT model ...')
    model = GPTModel(
        config=config,
        vocab_size=args.padded_vocab_size,
        max_sequence_length=args.max_position_embeddings,
        pre_process=pre_process,
        post_process=post_process,
        fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
        parallel_output=True,
        share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights
    )
    return model


def get_batch(data_iterator):
    """Generate a batch"""
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, labels, loss_mask, attention_mask, position_ids


def loss_func(loss_mask, output_tensor):
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {'lm loss': averaged_loss[0]}


def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator)
    timers('batch-generator').stop()

    output_tensor = model(tokens, position_ids, attention_mask,
                          labels=labels)

    return output_tensor, partial(loss_func, loss_mask)


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for GPT ...')
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        seq_length=args.seq_length,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        train_data_prefix=args.train_data_path,
        valid_data_prefix=args.valid_data_path,
        test_data_prefix=args.test_data_path)
    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":

    pretrain(train_valid_test_datasets_provider, model_provider,
             ModelType.encoder_or_decoder,
             forward_step,
             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
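
A quick aside, not part of the commit: the next-token setup in get_batch can be seen with a tiny standalone snippet. The tensor values below are invented for illustration, and torch.tril only stands in conceptually for the causal mask that get_ltor_masks_and_position_ids builds (the real helper also handles EOD resets, position-id resets, and loss masking).

# Illustration only: toy batch of one sequence, values invented for this sketch.
import torch

tokens_ = torch.tensor([[10, 11, 12, 13]])   # raw tokens, length seq_length + 1
tokens = tokens_[:, :-1].contiguous()        # model input:     [[10, 11, 12]]
labels = tokens_[:, 1:].contiguous()         # shifted targets: [[11, 12, 13]]

# Conceptual left-to-right (causal) mask: position i may attend to positions <= i.
causal = torch.tril(torch.ones(3, 3)).bool()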

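Similarly, forward_step hands the training loop the unreduced output tensor plus partial(loss_func, loss_mask), so the loop can evaluate the loss with only the model output. Below is a minimal sketch of that binding, with invented numbers and a hypothetical masked_mean helper that mirrors loss_func's arithmetic without the data-parallel averaging.

# Illustration only: how the bound loss callable behaves, with made-up values.
from functools import partial

import torch

def masked_mean(loss_mask, output_tensor):
    # Same arithmetic as loss_func above, minus the data-parallel averaging.
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    return torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

loss_mask = torch.tensor([[1., 1., 0.]])            # last position excluded (e.g. EOD)
per_token_loss = torch.tensor([[2.0, 4.0, 100.0]])  # toy per-token losses

loss_callable = partial(masked_mean, loss_mask)     # analogous to forward_step's return
print(loss_callable(per_token_loss))                # tensor(3.) -- masked position ignored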
0 commit comments
