
Commit 43ab0e0

Add support for weighted train (#299)
1 parent: 3d5d151

2 files changed: 38 additions & 3 deletions


finetune_t0_non_causal_decoder.py

Lines changed: 35 additions & 2 deletions
@@ -1,10 +1,9 @@
 """Multitask Finetuning T0"""
 
-from multiprocessing.sharedctypes import Value
 import torch
 
 from megatron import get_args, get_tokenizer, print_rank_0, mpu
-from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets
+from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets, build_dataset_group
 from megatron.enums import PositionEmbeddingType, AttnMaskType
 from megatron.model import GPTModelPipe
 from megatron.training import pretrain
@@ -123,6 +122,40 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
             seed=args.seed,
             skip_warmup=(not args.mmap_warmup)
         )
+    # Option 2 of data loading using --(train|valid|test)-weighted-split-paths
+    elif args.train_weighted_split_paths:
+        assigned_train_valid_test = []
+        if args.train_weighted_split_paths is not None:
+            train_ds = []
+            assigned_train_valid_test.append("train")
+        if args.valid_weighted_split_paths is not None:
+            valid_ds = []
+            assigned_train_valid_test.append("valid")
+        if args.test_weighted_split_paths is not None:
+            test_ds = []
+            assigned_train_valid_test.append("test")
+
+        for s in assigned_train_valid_test:
+            data_groups = zip(eval(f"args.{s}_weighted_split_paths"),
+                              eval(f"args.{s}_weighted_split_weights"),
+                              eval(f"args.{s}_weighted_split_splits"),
+                              eval(f"args.{s}_weighted_split_names"))
+            for paths, weights, splits, name in data_groups:
+                d = build_dataset_group(
+                    dataset_group_name=name,
+                    paths=paths,
+                    weights=weights,
+                    splits=splits,
+                    data_impl=args.data_impl,
+                    train_valid_test_num_samples=train_val_test_num_samples,
+                    seq_length=args.seq_length + 1,
+                    pad_token=tokenizer.pad,
+                    eos_token=tokenizer.eos,
+                    seed=args.seed,
+                    skip_warmup=(not args.mmap_warmup),
+                    train_valid_test=s
+                )
+                eval(f"{s}_ds").append(d)
     else:
         raise NotImplementedError("No dataloading argument passed")
 
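A side note on the pattern above: the new branch looks up the per-split arguments with eval(f"args.{s}_weighted_split_paths") and friends. Below is a minimal sketch of an eval-free equivalent using getattr; the args object is a hypothetical stand-in for Megatron's parsed arguments, and the tuple append stands in for the build_dataset_group call:

    from types import SimpleNamespace

    # Hypothetical stand-in for Megatron's parsed args (illustration only).
    args = SimpleNamespace(
        train_weighted_split_paths=[["data/ds_a", "data/ds_b"]],
        train_weighted_split_weights=[[0.3, 0.7]],
        train_weighted_split_splits=[["0:0.9", "0:0.9"]],
        train_weighted_split_names=["train_group_0"],
    )

    datasets = {"train": []}
    for s in datasets:
        # getattr(args, name) does the same dynamic lookup as eval(f"args.{name}")
        # without evaluating arbitrary code.
        data_groups = zip(getattr(args, f"{s}_weighted_split_paths"),
                          getattr(args, f"{s}_weighted_split_weights"),
                          getattr(args, f"{s}_weighted_split_splits"),
                          getattr(args, f"{s}_weighted_split_names"))
        for paths, weights, splits, name in data_groups:
            datasets[s].append((name, paths, weights, splits))  # build_dataset_group(...) in the real code

    print(datasets["train"][0][0])  # train_group_0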

megatron/data/decoder_packed_mtf_dataset.py

Lines changed: 3 additions & 1 deletion
@@ -4,13 +4,14 @@
 import numpy as np
 import torch
 
-from megatron import print_rank_0, mpu
+from megatron import print_rank_0, mpu, logging
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.dataset_utils import get_datasets_weights_and_num_samples, get_split_by_range_, \
     get_train_valid_test_split_
 from megatron.data.mtf_dataset import MTFDataset
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 
+logger = logging.get_logger(__name__)
 
 def build_train_valid_test_datasets(
     data_prefix,
@@ -487,6 +488,7 @@ def _build_sample_idx(mtf_dataset, document_ids, seq_length, row_offset, old_sam
             # TODO @thomasw21 handle the case where a single sample cannot fit inside a row. We can
             # - silently skip that value [currently implemented]
             # - truncate to `seq_length`, and keep the right part
+            logger.warning(f"Skipping sample id={document_id}. Maximum sequence length: {seq_length}, sample length: {tok_len}")
             current_sample_start = current_sample_end + 1 # skipping
             row_length = 0
             continue
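For context on the hunk above: the warning fires in the sample-packing loop when a single sample is longer than one seq_length row, and the TODO lists the two candidate policies (skip vs. truncate). Here is a simplified, self-contained sketch of the currently implemented skip policy; the pack helper below is hypothetical and not the repository's _build_sample_idx:

    def pack(sample_lengths, seq_length):
        # Greedily pack sample ids into rows of at most seq_length tokens,
        # dropping any sample that cannot fit into a row on its own.
        rows, row, row_length = [], [], 0
        for document_id, tok_len in enumerate(sample_lengths):
            if tok_len > seq_length:
                # Mirrors the new logger.warning: skip instead of truncating.
                print(f"Skipping sample id={document_id}. "
                      f"Maximum sequence length: {seq_length}, sample length: {tok_len}")
                continue
            if row_length + tok_len > seq_length:
                rows.append(row)
                row, row_length = [], 0
            row.append(document_id)
            row_length += tok_len
        if row:
            rows.append(row)
        return rows

    print(pack([5, 12, 3], seq_length=8))  # sample 1 is skipped -> [[0, 2]]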
