
Commit 3d5d151

thomasw21, Lintang Sutawika (lintangsutawika), and Muennighoff authored
MTF train script (#295)
Co-authored-by: Lintang Sutawika <[email protected]>
Co-authored-by: Muennighoff <[email protected]>
1 parent e1c479e commit 3d5d151

28 files changed: +1394, -695 lines

finetune_t0_non_causal_decoder.py

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
+"""Multitask Finetuning T0"""
+
+from multiprocessing.sharedctypes import Value
+import torch
+
+from megatron import get_args, get_tokenizer, print_rank_0, mpu
+from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets
+from megatron.enums import PositionEmbeddingType, AttnMaskType
+from megatron.model import GPTModelPipe
+from megatron.training import pretrain
+from megatron.utils import get_ltor_masks_and_position_ids, get_packed_attention_mask
+
+import deepspeed
+from deepspeed.runtime.utils import see_memory_usage
+
+try:
+    from torch.distributed.elastic.multiprocessing.errors import record
+except ImportError:
+    # noop
+    def record(fn):
+        return fn
+
+def model_provider(pre_process=True, post_process=True):
+    """Build the model."""
+
+    print_rank_0("building GPT model ...")
+    see_memory_usage(f"Before Building Model", force=True)
+
+    args = get_args()
+
+    with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
+                             remote_device=None if args.remote_device == "none" else args.remote_device,
+                             config_dict_or_path=args.deepspeed_config,
+                             enabled=args.zero_stage == 3,
+                             mpu=mpu):
+        if args.deepspeed:
+            model = GPTModelPipe(
+                num_tokentypes=0,
+                parallel_output=True,
+                attn_mask_type=AttnMaskType.custom
+            )
+            # This is a hack to give us a reference to get_batch_pipe from within training.py
+            # We need to call model.set_batch_fn after deepspeed.initialize
+            model._megatron_batch_fn = get_batch_pipe
+        else:
+            raise NotImplementedError("DeepSpeed is required for T0")
+
+    see_memory_usage(f"After Building Model", force=True)
+    return model
+
+def get_batch_pipe(data):
+    """
+    Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator` & in packed fashion
+
+    data:
+        decoder_tokens = [[6, 7, 8, 3, 4, 5, 0]]
+        decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]
+        decoder_is_inputs = [[1, 1, 0, 1, 1, 0, 0]]
+    """
+    args = get_args()
+    tokenizer = get_tokenizer()
+
+    # Broadcast data.
+    data_b = mpu.broadcast_data(["decoder_token_ids", "decoder_segment_ids"], data, torch.int64)
+    data_c = mpu.broadcast_data(["decoder_is_inputs"], data, torch.bool)
+
+    # Unpack.
+    tokens_ = data_b["decoder_token_ids"].long()
+    labels = tokens_[:, 1:].contiguous()
+    tokens = tokens_[:, :-1].contiguous()
+
+    segment_ids = data_b["decoder_segment_ids"].long()[:, :-1]
+    decoder_is_inputs = data_c["decoder_is_inputs"][:, :-1]
+
+    # Get the masks and position ids.
+    causal_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
+        tokens,
+        tokenizer.eod,
+        args.reset_position_ids,
+        args.reset_attention_mask,
+        args.eod_mask_loss,
+        prefix_indices=None,
+        loss_on_targets_only=False # This is done below
+    )
+    # Only compute loss over causal target tokens, i.e. ignore input_tokens & padding
+    loss_on_targets_only = ~data_c["decoder_is_inputs"][:, 1:]
+    loss_on_non_pad_only = (tokens != tokenizer.pad)
+    loss_mask *= loss_on_targets_only * loss_on_non_pad_only
+
+    attention_mask = get_packed_attention_mask(
+        # Run non-causal decoder
+        is_causal=False,
+        causal_mask=~(causal_mask.bool()),
+        decoder_is_inputs=decoder_is_inputs.bool(),
+        segment_ids=segment_ids.long(),
+    )
+
+    if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
+        raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.")
+
+    return (tokens, position_ids, attention_mask), (labels, loss_mask)
+
+
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+    """Build train, valid, and test datasets."""
+    args = get_args()
+    train_ds, valid_ds, test_ds = None, None, None
+
+    tokenizer = get_tokenizer()
+
+    print_rank_0("> building train, validation, and test datasets for T0 ...")
+    # Option 1 of data loading using --data-path
+    if args.data_path:
+        # TODO: Not yet compatible with dataset weights (Will break at prefixes, weights = analyze_data_prefix(args.data_path))
+        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+            data_prefix=args.data_path,
+            data_impl=args.data_impl,
+            splits_string=args.split,
+            seq_length=args.seq_length + 1,
+            pad_token=tokenizer.pad,
+            eos_token=tokenizer.eos,
+            train_valid_test_num_samples=train_val_test_num_samples,
+            seed=args.seed,
+            skip_warmup=(not args.mmap_warmup)
+        )
+    else:
+        raise NotImplementedError("No dataloading argument passed")
+
+    print_rank_0("> finished creating T0 datasets ...")
+    return train_ds, valid_ds, test_ds
+
+@record
+def main():
+    pretrain(
+        train_valid_test_datasets_provider,
+        model_provider,
+        forward_step_func=None,
+        args_defaults={}
+    )
+
+if __name__ == "__main__":
+    main()
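To make the packed, non-causal batch format concrete, here is a minimal self-contained sketch of the masking semantics `get_batch_pipe` relies on, applied to the toy example from its docstring. It is an illustration only: the toy tensors and `pad_token = 0` are assumptions, and it approximates rather than reuses the repository's `get_ltor_masks_and_position_ids` / `get_packed_attention_mask`.

```python
# Illustration only: approximates the masking used by get_batch_pipe on the
# docstring example; names, shapes and pad_token=0 are simplified assumptions.
import torch

tokens_ = torch.tensor([[6, 7, 8, 3, 4, 5, 0]])           # doc1 = 6 7 8, doc2 = 3 4 5, then padding
segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 0]])       # 0 marks padding
is_inputs = torch.tensor([[1, 1, 0, 1, 1, 0, 0]]).bool()  # True = input token, False = target/pad
pad_token = 0

# Shift by one token, as get_batch_pipe does.
tokens = tokens_[:, :-1]
labels = tokens_[:, 1:]

# Loss only where the label is a target token and the current position is not padding
# (in the real script this is multiplied into the mask from get_ltor_masks_and_position_ids).
loss_mask = (~is_inputs[:, 1:]) & (tokens != pad_token)
print(loss_mask.long())   # tensor([[0, 1, 0, 0, 1, 1]])

# Packed non-causal attention: position i attends to j iff both belong to the same
# (non-padding) segment and either j <= i (causal) or both are input tokens.
seg = segment_ids[:, :-1]
inp = is_inputs[:, :-1]
seq_len = tokens.shape[1]
causal = torch.tril(torch.ones(seq_len, seq_len)).bool()[None]
same_segment = (seg[:, :, None] == seg[:, None, :]) & (seg[:, :, None] != 0)
both_inputs = inp[:, :, None] & inp[:, None, :]
attention_mask = (causal | both_inputs) & same_segment
print(attention_mask[0].long())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 0, 0, 1, 1, 0],
#         [0, 0, 0, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]])
```

Within each packed document the input tokens attend to each other bidirectionally, targets attend causally, and no position attends across documents or to padding.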

megatron/arguments.py

Lines changed: 1 addition & 1 deletion
@@ -557,7 +557,7 @@ def _add_training_args(parser):
                        'please refer https://github.com/facebookresearch/bitsandbytes.',
                        dest='use_bnb_optimizer')
     group.add_argument('--dataloader-type', type=str, default=None,
-                       choices=['single', 'cyclic', 'decoder_packed'],
+                       choices=['single', 'cyclic'],
                        help='Single pass vs multiple pass data loader')
     group.add_argument('--cpu-optimizer', action='store_true',
                        help='Run optimizer on CPU')

megatron/data/data_samplers.py

Lines changed: 3 additions & 160 deletions
@@ -15,77 +15,11 @@
 
 """Dataloaders."""
 
-from functools import partial
-
-import numpy as np
 import torch
 
-from megatron import get_args, get_tokenizer
+from megatron import get_args
 from megatron import mpu
-from megatron.data.mtf_dataset import MTFDataset
-
-
-def pack_samples(items, max_seq_len: int, micro_batch_size: int, pad_token: int):
-    """
-    Greedily packs samples.
-
-    Items:
-        [
-            {
-                'input_tokens': array([6, 7]),
-                'target_tokens': array([8])
-            },
-            {
-                'input_tokens': array([3, 4]),
-                'target_tokens': array([5])
-            }
-        ]
-
-    Output:
-        decoder_target_tokens = [[6, 7, 8, 3, 4, 5, <pad>]]: Concatenation of tokens followed with padding tokens.
-        decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]: Segment ids determine original documents.
-        decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]: `0` depicts inputs, `1` depicts target.
-    """
-
-    decoder_target_tokens = np.full((micro_batch_size, max_seq_len), pad_token)
-    decoder_segment_ids = np.zeros((micro_batch_size, max_seq_len))
-    decoder_causal_attention = np.zeros((micro_batch_size, max_seq_len))
-
-    batch_num = 0
-    # `0` is reserved for padding
-    item_num = 1
-    cur_len = 0
-    for token_dict in items:
-        input_token_len = len(token_dict["input_tokens"])
-        target_token_len = len(token_dict["target_tokens"])
-        total_len = input_token_len + target_token_len
-        if cur_len + total_len > max_seq_len:
-            len_diff = max_seq_len - cur_len
-            # Padding
-            if len_diff > 0:
-                decoder_target_tokens[batch_num][cur_len: max_seq_len] = pad_token
-                decoder_segment_ids[batch_num][cur_len: max_seq_len] = 0
-                decoder_causal_attention[batch_num][cur_len: max_seq_len] = 0
-            batch_num += 1
-            assert batch_num < micro_batch_size
-            item_num = 1
-            cur_len = 0
-
-        decoder_target_tokens[batch_num][cur_len: cur_len + input_token_len] = token_dict["input_tokens"]
-        decoder_target_tokens[batch_num][cur_len + input_token_len: cur_len + total_len] = token_dict["target_tokens"]
-        decoder_segment_ids[batch_num][cur_len: cur_len + total_len] = item_num
-        decoder_causal_attention[batch_num][cur_len: cur_len + input_token_len] = 1 # input
-        decoder_causal_attention[batch_num][cur_len + input_token_len: cur_len + total_len] = 0 # target
-
-        item_num += 1
-        cur_len += total_len
-        assert cur_len < max_seq_len
-
-    return {
-        "decoder_target_tokens": decoder_target_tokens,
-        "decoder_segment_ids": decoder_segment_ids,
-        "decoder_causal_attention": decoder_causal_attention,
-    }
+from megatron.data.decoder_packed_mtf_dataset import DecoderPackedMTFDataset
 
 
 def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
@@ -110,41 +44,23 @@ def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
             micro_batch_size=args.micro_batch_size,
             data_parallel_rank=mpu.get_data_parallel_rank(),
             data_parallel_size=mpu.get_data_parallel_world_size())
-    elif args.dataloader_type == 'decoder_packed':
-        assert isinstance(dataset, MTFDataset)
-        batch_sampler = MegatronDecoderPackedText2TextRandomSampler(
-            sequence_length=args.seq_length + 1,
-            dataset=dataset,
-            total_samples=len(dataset),
-            consumed_samples=consumed_samples,
-            micro_batch_size=args.micro_batch_size,
-            data_parallel_rank=mpu.get_data_parallel_rank(),
-            data_parallel_size=mpu.get_data_parallel_world_size())
     else:
         raise Exception('{} dataloader type is not supported.'.format(
             args.dataloader_type))
 
     if num_workers is None:
         num_workers = args.num_workers
 
-    collate_fn = None
-    if args.dataloader_type == 'decoder_packed':
-        assert isinstance(dataset, MTFDataset)
-        pad_token = get_tokenizer().pad
-        collate_fn = partial(pack_samples, max_seq_len=args.seq_length + 1, micro_batch_size=args.micro_batch_size,
-                             pad_token=pad_token)
-
     # Torch dataloader.
     return torch.utils.data.DataLoader(
         dataset,
         batch_sampler=batch_sampler,
         num_workers=num_workers,
         generator=torch.Generator().manual_seed(args.seed),
-        collate_fn=collate_fn,
+        collate_fn=None,
         pin_memory=True
     )
 
-
 class MegatronPretrainingSampler:
 
     def __init__(self, total_samples, consumed_samples, micro_batch_size,
@@ -246,76 +162,3 @@ def __iter__(self):
             self.consumed_samples += self.micro_batch_times_data_parallel_size
             yield batch
             batch = []
-
-
-class MegatronDecoderPackedText2TextRandomSampler(object):
-    """
-    Converts a two stream dataset with `input_tokens` and `target_tokens` and creates a batch that should be greedily
-    packed to be passed onto the decoder model.
-
-    To be used with `pack_samples` as collate_fn
-    """
-
-    def __init__(self, sequence_length, dataset, total_samples, consumed_samples, micro_batch_size,
-                 data_parallel_rank, data_parallel_size):
-        # Keep a copy of input params for later use.
-        self.dataset = dataset
-        self.sequence_length = sequence_length
-        self.total_samples = total_samples
-        self.consumed_samples = consumed_samples
-        self.micro_batch_size = micro_batch_size
-        self.data_parallel_rank = data_parallel_rank
-        self.data_parallel_size = data_parallel_size
-        self.micro_batch_times_data_parallel_size = \
-            self.micro_batch_size * data_parallel_size
-        self.last_batch_size = \
-            self.total_samples % self.micro_batch_times_data_parallel_size
-
-        # Sanity checks.
-        assert self.total_samples > 0, \
-            'no sample to consume: {}'.format(self.total_samples)
-        assert self.micro_batch_size > 0
-        assert data_parallel_size > 0
-        assert self.data_parallel_rank < data_parallel_size, \
-            'data_parallel_rank should be smaller than data size: {}, ' \
-            '{}'.format(self.data_parallel_rank, data_parallel_size)
-
-    def __len__(self):
-        return self.total_samples
-
-    def __iter__(self):
-        active_total_samples = self.total_samples - self.last_batch_size
-        self.epoch = self.consumed_samples // active_total_samples
-        current_epoch_samples = self.consumed_samples % active_total_samples
-        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
-
-        # data sharding and random sampling
-        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
-                      * self.micro_batch_size
-        bucket_offset = current_epoch_samples // self.data_parallel_size
-        start_idx = self.data_parallel_rank * bucket_size
-
-        g = torch.Generator()
-        g.manual_seed(self.epoch)
-
-        random_idx = torch.randperm(bucket_size, generator=g).tolist()
-        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
-
-        batch = []
-        batch_count = 0
-        token_lens = 0
-        # Last batch if not complete will be dropped.
-        for idx in idx_range:
-            tok_len = len(self.dataset[idx]['input_tokens']) + len(self.dataset[idx]['target_tokens'])
-            if token_lens + tok_len > self.sequence_length:
-                batch_count += 1
-                token_lens = 0
-
-            if batch_count == self.micro_batch_size:
-                self.consumed_samples += self.micro_batch_times_data_parallel_size
-                yield batch
-                batch_count = 0
-                batch = []
-            else:
-                token_lens += tok_len
-                batch.append(idx)
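For reference, the greedy packing layout described in the removed `pack_samples` docstring (the same layout that `DecoderPackedMTFDataset` is now expected to produce upstream of the dataloader; its `decoder_causal_attention` field corresponds to `decoder_is_inputs` in the new finetune script) can be reproduced with the docstring's toy items. The helper below, `pack_examples`, is a hypothetical sketch, not the removed collate_fn or the new dataset code.

```python
# Minimal sketch of the packing layout from the removed pack_samples docstring;
# pack_examples and its signature are illustrative, not repository code.
import numpy as np

def pack_examples(items, max_seq_len, pad_token=0):
    """Pack {'input_tokens', 'target_tokens'} dicts into a single padded row."""
    tokens = np.full(max_seq_len, pad_token)
    segment_ids = np.zeros(max_seq_len, dtype=np.int64)   # 0 is reserved for padding
    is_inputs = np.zeros(max_seq_len, dtype=np.int64)     # 1 = input token, 0 = target/pad
    cur = 0
    for seg, item in enumerate(items, start=1):
        inp, tgt = item["input_tokens"], item["target_tokens"]
        n = len(inp) + len(tgt)
        if cur + n > max_seq_len:   # a real packer would start a new row here instead
            break
        tokens[cur:cur + len(inp)] = inp
        tokens[cur + len(inp):cur + n] = tgt
        segment_ids[cur:cur + n] = seg
        is_inputs[cur:cur + len(inp)] = 1
        cur += n
    return tokens, segment_ids, is_inputs

items = [
    {"input_tokens": np.array([6, 7]), "target_tokens": np.array([8])},
    {"input_tokens": np.array([3, 4]), "target_tokens": np.array([5])},
]
tokens, segment_ids, is_inputs = pack_examples(items, max_seq_len=7)
print(tokens)       # [6 7 8 3 4 5 0]
print(segment_ids)  # [1 1 1 2 2 2 0]
print(is_inputs)    # [1 1 0 1 1 0 0]
```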
