@@ -1,10 +1,9 @@
 """Multitask Finetuning T0"""
 
-from multiprocessing.sharedctypes import Value
 import torch
 
 from megatron import get_args, get_tokenizer, print_rank_0, mpu
-from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets
+from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets, build_dataset_group
 from megatron.enums import PositionEmbeddingType, AttnMaskType
 from megatron.model import GPTModelPipe
 from megatron.training import pretrain
@@ -123,6 +122,40 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
             seed=args.seed,
             skip_warmup=(not args.mmap_warmup)
         )
+    # Option 2 of data loading using --(train|valid|test)-weighted-split-paths
+    elif args.train_weighted_split_paths:
+        assigned_train_valid_test = []
+        if args.train_weighted_split_paths is not None:
+            train_ds = []
+            assigned_train_valid_test.append("train")
+        if args.valid_weighted_split_paths is not None:
+            valid_ds = []
+            assigned_train_valid_test.append("valid")
+        if args.test_weighted_split_paths is not None:
+            test_ds = []
+            assigned_train_valid_test.append("test")
+
+        for s in assigned_train_valid_test:
+            data_groups = zip(eval(f"args.{s}_weighted_split_paths"),
+                              eval(f"args.{s}_weighted_split_weights"),
+                              eval(f"args.{s}_weighted_split_splits"),
+                              eval(f"args.{s}_weighted_split_names"))
+            for paths, weights, splits, name in data_groups:
+                d = build_dataset_group(
+                    dataset_group_name=name,
+                    paths=paths,
+                    weights=weights,
+                    splits=splits,
+                    data_impl=args.data_impl,
+                    train_valid_test_num_samples=train_val_test_num_samples,
+                    seq_length=args.seq_length + 1,
+                    pad_token=tokenizer.pad,
+                    eos_token=tokenizer.eos,
+                    seed=args.seed,
+                    skip_warmup=(not args.mmap_warmup),
+                    train_valid_test=s
+                )
+                eval(f"{s}_ds").append(d)
     else:
         raise NotImplementedError("No dataloading argument passed")
 
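The new `elif` branch builds one or more weighted dataset groups per split instead of a single `--data-path` mixture. As a minimal sketch (not part of the commit), the `eval()` lookups above could be written with `getattr()` and a dict in place of the `train_ds`/`valid_ds`/`test_ds` locals; `build_dataset_group` and the `args` fields are the ones used in the diff, while the function name and its signature here are hypothetical:

    # Sketch only: same dispatch as the diff, without eval().
    def build_weighted_split_datasets(args, train_val_test_num_samples, tokenizer):
        datasets = {}  # maps "train"/"valid"/"test" -> list of dataset groups
        for s in ("train", "valid", "test"):
            if getattr(args, f"{s}_weighted_split_paths") is None:
                continue
            datasets[s] = []
            groups = zip(getattr(args, f"{s}_weighted_split_paths"),
                         getattr(args, f"{s}_weighted_split_weights"),
                         getattr(args, f"{s}_weighted_split_splits"),
                         getattr(args, f"{s}_weighted_split_names"))
            for paths, weights, splits, name in groups:
                datasets[s].append(build_dataset_group(
                    dataset_group_name=name,
                    paths=paths,
                    weights=weights,
                    splits=splits,
                    data_impl=args.data_impl,
                    train_valid_test_num_samples=train_val_test_num_samples,
                    seq_length=args.seq_length + 1,
                    pad_token=tokenizer.pad,
                    eos_token=tokenizer.eos,
                    seed=args.seed,
                    skip_warmup=(not args.mmap_warmup),
                    train_valid_test=s,
                ))
        return datasets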
|
|
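One detail both loading paths share: samples are requested with `seq_length=args.seq_length + 1`. Assuming the usual causal-LM convention in Megatron-style training (an assumption; the diff itself does not explain the `+ 1`), the extra token lets each packed sample be split into length-`seq_length` inputs and next-token targets:

    # Illustration of the assumed convention, not code from this commit.
    seq_length = 8
    sample = list(range(seq_length + 1))        # one packed sample of L + 1 tokens
    inputs, targets = sample[:-1], sample[1:]   # shift by one for next-token prediction
    assert len(inputs) == len(targets) == seq_length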