 """Dataloaders."""

-from functools import partial
-
-import numpy as np
 import torch

-from megatron import get_args, get_tokenizer
+from megatron import get_args
 from megatron import mpu
-from megatron.data.mtf_dataset import MTFDataset
-
-
-def pack_samples(items, max_seq_len: int, micro_batch_size: int, pad_token: int):
-    """
-    Greedily packs samples.
-
-    Items:
-        [
-            {
-                'input_tokens': array([6, 7]),
-                'target_tokens': array([8])
-            },
-            {
-                'input_tokens': array([3, 4]),
-                'target_tokens': array([5])
-            }
-        ]
-
-    Output:
-        decoder_target_tokens = [[6, 7, 8, 3, 4, 5, <pad>]]: Concatenation of tokens followed by padding tokens.
-        decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]: Segment ids identify the original documents.
-        decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]: `1` marks input positions, `0` marks target positions.
-    """
-
-    decoder_target_tokens = np.full((micro_batch_size, max_seq_len), pad_token)
-    decoder_segment_ids = np.zeros((micro_batch_size, max_seq_len))
-    decoder_causal_attention = np.zeros((micro_batch_size, max_seq_len))
-
-    batch_num = 0
-    # `0` is reserved for padding
-    item_num = 1
-    cur_len = 0
-    for token_dict in items:
-        input_token_len = len(token_dict["input_tokens"])
-        target_token_len = len(token_dict["target_tokens"])
-        total_len = input_token_len + target_token_len
-        if cur_len + total_len > max_seq_len:
-            len_diff = max_seq_len - cur_len
-            # Padding
-            if len_diff > 0:
-                decoder_target_tokens[batch_num][cur_len: max_seq_len] = pad_token
-                decoder_segment_ids[batch_num][cur_len: max_seq_len] = 0
-                decoder_causal_attention[batch_num][cur_len: max_seq_len] = 0
-            batch_num += 1
-            assert batch_num < micro_batch_size
-            item_num = 1
-            cur_len = 0
-
-        decoder_target_tokens[batch_num][cur_len: cur_len + input_token_len] = token_dict["input_tokens"]
-        decoder_target_tokens[batch_num][cur_len + input_token_len: cur_len + total_len] = token_dict["target_tokens"]
-        decoder_segment_ids[batch_num][cur_len: cur_len + total_len] = item_num
-        decoder_causal_attention[batch_num][cur_len: cur_len + input_token_len] = 1  # input
-        decoder_causal_attention[batch_num][cur_len + input_token_len: cur_len + total_len] = 0  # target
-
-        item_num += 1
-        cur_len += total_len
-        assert cur_len < max_seq_len
-
-    return {
-        "decoder_target_tokens": decoder_target_tokens,
-        "decoder_segment_ids": decoder_segment_ids,
-        "decoder_causal_attention": decoder_causal_attention,
-    }
+from megatron.data.decoder_packed_mtf_dataset import DecoderPackedMTFDataset


 def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
@@ -110,41 +44,23 @@ def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
             micro_batch_size=args.micro_batch_size,
             data_parallel_rank=mpu.get_data_parallel_rank(),
             data_parallel_size=mpu.get_data_parallel_world_size())
-    elif args.dataloader_type == 'decoder_packed':
-        assert isinstance(dataset, MTFDataset)
-        batch_sampler = MegatronDecoderPackedText2TextRandomSampler(
-            sequence_length=args.seq_length + 1,
-            dataset=dataset,
-            total_samples=len(dataset),
-            consumed_samples=consumed_samples,
-            micro_batch_size=args.micro_batch_size,
-            data_parallel_rank=mpu.get_data_parallel_rank(),
-            data_parallel_size=mpu.get_data_parallel_world_size())
     else:
         raise Exception('{} dataloader type is not supported.'.format(
             args.dataloader_type))

     if num_workers is None:
         num_workers = args.num_workers

-    collate_fn = None
-    if args.dataloader_type == 'decoder_packed':
-        assert isinstance(dataset, MTFDataset)
-        pad_token = get_tokenizer().pad
-        collate_fn = partial(pack_samples, max_seq_len=args.seq_length + 1, micro_batch_size=args.micro_batch_size,
-                             pad_token=pad_token)
-
     # Torch dataloader.
     return torch.utils.data.DataLoader(
         dataset,
         batch_sampler=batch_sampler,
         num_workers=num_workers,
         generator=torch.Generator().manual_seed(args.seed),
-        collate_fn=collate_fn,
+        collate_fn=None,
         pin_memory=True
     )

-
 class MegatronPretrainingSampler:

     def __init__(self, total_samples, consumed_samples, micro_batch_size,
@@ -246,76 +162,3 @@ def __iter__(self):
                 self.consumed_samples += self.micro_batch_times_data_parallel_size
                 yield batch
                 batch = []
-
-
-class MegatronDecoderPackedText2TextRandomSampler(object):
-    """
-    Converts a two stream dataset with `input_tokens` and `target_tokens` and creates a batch that should be greedily
-    packed to be passed onto the decoder model.
-
-    To be used with `pack_samples` as collate_fn
-    """
-
-    def __init__(self, sequence_length, dataset, total_samples, consumed_samples, micro_batch_size,
-                 data_parallel_rank, data_parallel_size):
-        # Keep a copy of input params for later use.
-        self.dataset = dataset
-        self.sequence_length = sequence_length
-        self.total_samples = total_samples
-        self.consumed_samples = consumed_samples
-        self.micro_batch_size = micro_batch_size
-        self.data_parallel_rank = data_parallel_rank
-        self.data_parallel_size = data_parallel_size
-        self.micro_batch_times_data_parallel_size = \
-            self.micro_batch_size * data_parallel_size
-        self.last_batch_size = \
-            self.total_samples % self.micro_batch_times_data_parallel_size
-
-        # Sanity checks.
-        assert self.total_samples > 0, \
-            'no sample to consume: {}'.format(self.total_samples)
-        assert self.micro_batch_size > 0
-        assert data_parallel_size > 0
-        assert self.data_parallel_rank < data_parallel_size, \
-            'data_parallel_rank should be smaller than data size: {}, ' \
-            '{}'.format(self.data_parallel_rank, data_parallel_size)
-
-    def __len__(self):
-        return self.total_samples
-
-    def __iter__(self):
-        active_total_samples = self.total_samples - self.last_batch_size
-        self.epoch = self.consumed_samples // active_total_samples
-        current_epoch_samples = self.consumed_samples % active_total_samples
-        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
-
-        # data sharding and random sampling
-        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
-                      * self.micro_batch_size
-        bucket_offset = current_epoch_samples // self.data_parallel_size
-        start_idx = self.data_parallel_rank * bucket_size
-
-        g = torch.Generator()
-        g.manual_seed(self.epoch)
-
-        random_idx = torch.randperm(bucket_size, generator=g).tolist()
-        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
-
-        batch = []
-        batch_count = 0
-        token_lens = 0
-        # Last batch if not complete will be dropped.
-        for idx in idx_range:
-            tok_len = len(self.dataset[idx]['input_tokens']) + len(self.dataset[idx]['target_tokens'])
-            if token_lens + tok_len > self.sequence_length:
-                batch_count += 1
-                token_lens = 0
-
-            if batch_count == self.micro_batch_size:
-                self.consumed_samples += self.micro_batch_times_data_parallel_size
-                yield batch
-                batch_count = 0
-                batch = []
-            else:
-                token_lens += tok_len
-                batch.append(idx)
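
For reference, a minimal self-contained sketch (not part of this commit) of the greedy packing convention that the removed pack_samples collate_fn and its docstring describe. The helper name pack_examples, the toy token values, and pad_token=0 are illustrative assumptions; only the output convention (decoder_target_tokens, decoder_segment_ids, decoder_causal_attention) comes from the docstring above.

import numpy as np

def pack_examples(items, max_seq_len, pad_token):
    # Greedy single-row packer mirroring the removed pack_samples behaviour (illustrative only).
    tokens = np.full(max_seq_len, pad_token)
    segment_ids = np.zeros(max_seq_len, dtype=int)
    causal_attention = np.zeros(max_seq_len, dtype=int)
    cur = 0
    for seg, item in enumerate(items, start=1):  # segment id 0 is reserved for padding
        inp, tgt = item["input_tokens"], item["target_tokens"]
        end = cur + len(inp) + len(tgt)
        assert end <= max_seq_len, "sample does not fit; the real sampler sizes batches to avoid this"
        tokens[cur:end] = np.concatenate([inp, tgt])
        segment_ids[cur:end] = seg
        causal_attention[cur:cur + len(inp)] = 1  # 1 marks input (prompt) positions, 0 marks targets/padding
        cur = end
    return tokens, segment_ids, causal_attention

items = [
    {"input_tokens": np.array([6, 7]), "target_tokens": np.array([8])},
    {"input_tokens": np.array([3, 4]), "target_tokens": np.array([5])},
]
tokens, segment_ids, causal_attention = pack_examples(items, max_seq_len=7, pad_token=0)
print(tokens)             # [6 7 8 3 4 5 0]
print(segment_ids)        # [1 1 1 2 2 2 0]
print(causal_attention)   # [1 1 0 1 1 0 0]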
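A second hedged sketch, again not from the repository: it re-derives the per-rank index-sharding arithmetic of the removed MegatronDecoderPackedText2TextRandomSampler.__iter__ with made-up toy sizes, to show how each data-parallel rank reads a disjoint, epoch-seeded shuffled bucket of sample indices.

import torch

# Hypothetical toy numbers (not from the repo): 10 samples, micro_batch_size=2, 2 data-parallel ranks.
total_samples, micro_batch_size, data_parallel_size = 10, 2, 2
consumed_samples, data_parallel_rank = 0, 0

micro_times_dp = micro_batch_size * data_parallel_size            # 4
last_batch_size = total_samples % micro_times_dp                  # 2 trailing samples are dropped
active_total_samples = total_samples - last_batch_size            # 8
epoch = consumed_samples // active_total_samples                  # 0
current_epoch_samples = consumed_samples % active_total_samples   # 0

# Each rank owns a contiguous bucket of indices and shuffles it with an epoch-seeded generator,
# so the permutation is deterministic per epoch while ranks read disjoint slices.
bucket_size = (total_samples // micro_times_dp) * micro_batch_size   # 4 indices per rank
bucket_offset = current_epoch_samples // data_parallel_size          # skip already-consumed part
start_idx = data_parallel_rank * bucket_size                         # rank 0 -> 0, rank 1 -> 4

g = torch.Generator()
g.manual_seed(epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
print(idx_range)  # a shuffled permutation of [0, 1, 2, 3] for rank 0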