Commit b7f1b05

Lint whole repo

Author: Neel Kant
Parent: c99fa80


63 files changed: +1103 / -990 lines
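The changes below are the kind a PEP 8 autoformatter produces: spaces added around arithmetic operators, inline-comment spacing normalized, bare except clauses made explicit, and stray blank lines removed. The commit does not record which tool was used; as a hedged illustration only, a whole-repo pass with autopep8 (an assumption, not something stated in the commit) could look like this:

from pathlib import Path

import autopep8  # assumption: autopep8 is one common tool for this kind of clean-up

# Hedged sketch of a whole-repo lint pass; the actual tool and flags behind this
# commit are not recorded, so treat this as an illustration only.
for path in Path('megatron').rglob('*.py'):
    original = path.read_text()
    fixed = autopep8.fix_code(original)  # normalises operator spacing, comments, blank lines
    if fixed != original:
        path.write_text(fixed)
        print(f'reformatted {path}')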

megatron/arguments.py
Lines changed: 0 additions & 3 deletions

@@ -357,7 +357,6 @@ def _add_gpt2_args(parser):
     return parser
 
 
-
 def add_data_args_(parser):
     """Train/valid/test data arguments."""
 
@@ -367,6 +366,4 @@ def add_data_args_(parser):
                        choices=['raw', 'lazy', 'tfrecords', 'numpy', 'binary'],
                        help='Which data loader to use. Default varies by model.')
 
-
     return parser
-
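For context, the choices/help lines kept in the second hunk belong to an argparse option. A minimal self-contained sketch follows; the flag name --data-loader and its default are hypothetical reconstructions for illustration, only the choices list and help text appear in the diff.

import argparse

# Hypothetical reconstruction of the option whose choices/help lines appear above;
# the flag name and default are assumptions, not taken from the diff.
parser = argparse.ArgumentParser()
parser.add_argument('--data-loader', type=str, default=None,
                    choices=['raw', 'lazy', 'tfrecords', 'numpy', 'binary'],
                    help='Which data loader to use. Default varies by model.')

args = parser.parse_args(['--data-loader', 'binary'])
print(args.data_loader)  # -> binary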

megatron/checkpointing.py
Lines changed: 3 additions & 3 deletions

@@ -67,7 +67,7 @@ def get_checkpoint_name(checkpoints_path, iteration,
     directory = 'iter_{:07d}'.format(iteration)
     return os.path.join(checkpoints_path, directory,
                         'mp_rank_{:02d}'.format(
-                            mpu.get_model_parallel_rank() if mp_rank is None \
+                            mpu.get_model_parallel_rank() if mp_rank is None
                             else mp_rank),
                         'model_optim_rng.pt')
 
@@ -179,7 +179,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
                'megatron.fp16.loss_scaler']
        state_dict = torch.load(checkpoint_name, map_location='cpu')
        sys.modules.pop('fp16.loss_scaler', None)
-    except:
+    except BaseException:
        print_rank_0('could not load the checkpoint')
        sys.exit()
 
@@ -190,7 +190,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
    try:
        iteration = state_dict['iteration']
    except KeyError:
-        try: # Backward compatible with older checkpoints
+        try:  # Backward compatible with older checkpoints
            iteration = state_dict['total_iters']
        except KeyError:
            print_rank_0('A metadata file exists but unable to load '
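The bare except: replaced above is what flake8/pycodestyle reports as E722. except BaseException: is behaviourally identical; it just spells out that the handler also catches SystemExit and KeyboardInterrupt. A minimal sketch of that reach:

# Minimal illustration: `except BaseException:` catches the same exceptions a
# bare `except:` would, including KeyboardInterrupt, which a narrower
# `except Exception:` would let propagate.
def interrupted():
    raise KeyboardInterrupt

try:
    interrupted()
except BaseException as exc:   # equivalent reach to a bare `except:`
    print(type(exc).__name__)  # -> KeyboardInterrupt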

megatron/data/__init__.py
Lines changed: 0 additions & 2 deletions

@@ -1,3 +1 @@
 from . import indexed_dataset
-
-

megatron/data/bert_dataset.py
Lines changed: 4 additions & 6 deletions

@@ -47,6 +47,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
 
     # Print stats about the splits.
     print_rank_0(' > dataset split:')
+
     def print_split_stats(name, index):
         print_rank_0(' {}:'.format(name))
         print_rank_0(' document indices in [{}, {}) total of {} '
@@ -113,7 +114,6 @@ def __init__(self, name, indexed_dataset, data_prefix,
         # Dataset.
         self.indexed_dataset = indexed_dataset
 
-
         # Build the samples mapping.
         self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
                                                     data_prefix,
@@ -133,11 +133,9 @@ def __init__(self, name, indexed_dataset, data_prefix,
         self.mask_id = tokenizer.mask
         self.pad_id = tokenizer.pad
 
-
     def __len__(self):
         return self.samples_mapping.shape[0]
 
-
     def __getitem__(self, idx):
 
         start_index, end_index, seq_length = self.samples_mapping[idx]
@@ -148,7 +146,7 @@ def __getitem__(self, idx):
         # python randint is inclusive whereas the numpy one is exclusive.
         np_rng = np.random.RandomState(seed=(self.seed + idx))
         return build_training_sample(sample, seq_length,
-                                     self.max_seq_length, # needed for padding
+                                     self.max_seq_length,  # needed for padding
                                      self.vocab_id_list,
                                      self.vocab_id_to_token_dict,
                                      self.cls_id, self.sep_id,
@@ -192,7 +190,7 @@ def get_train_valid_test_split_(splits_string, size):
     splits = splits[:3]
     splits_sum = sum(splits)
     assert splits_sum > 0.0
-    splits = [split/splits_sum for split in splits]
+    splits = [split / splits_sum for split in splits]
     splits_index = [0]
     for index, split in enumerate(splits):
         splits_index.append(splits_index[index] +
@@ -254,7 +252,7 @@ def get_samples_mapping_(indexed_dataset,
                                          indexed_dataset.sizes,
                                          num_epochs,
                                          max_num_samples,
-                                         max_seq_length-3, # account for added tokens
+                                         max_seq_length - 3,  # account for added tokens
                                          short_seq_prob,
                                          seed,
                                          verbose)
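The normalisation reformatted in get_train_valid_test_split_ divides each split weight by their sum, so the three splits become fractions of the document count before being turned into index boundaries. A standalone sketch of that arithmetic (an illustration, not the repository's exact helper):

# Hedged sketch of the split arithmetic around the reformatted line: weights are
# normalised to fractions and then mapped onto document-index boundaries.
def split_boundaries(weights, size):
    total = sum(weights)
    assert total > 0.0
    fractions = [w / total for w in weights]      # e.g. [0.949, 0.05, 0.001]
    bounds, acc = [0], 0.0
    for frac in fractions:
        acc += frac
        bounds.append(int(round(acc * size)))
    return bounds

print(split_boundaries([949, 50, 1], 10000))  # -> [0, 9490, 9990, 10000]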

megatron/data/gpt2_dataset.py
Lines changed: 8 additions & 10 deletions

@@ -42,6 +42,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
 
     # Print stats about the splits.
     print_rank_0(' > dataset split:')
+
     def print_split_stats(name, index):
         print_rank_0(' {}:'.format(name))
         print_rank_0(' document indices in [{}, {}) total of {} '
@@ -54,7 +55,7 @@ def print_split_stats(name, index):
     def build_dataset(index, name):
         dataset = None
         if splits[index + 1] > splits[index]:
-            documents = np.arange(start=splits[index], stop=splits[index+1],
+            documents = np.arange(start=splits[index], stop=splits[index + 1],
                                   step=1, dtype=np.int32)
             dataset = GPT2Dataset(name, data_prefix,
                                   documents, indexed_dataset,
@@ -102,21 +103,19 @@ def __init__(self, name, data_prefix, documents, indexed_dataset,
             self.name, data_prefix, documents, self.indexed_dataset.sizes,
             num_samples, seq_length, seed)
 
-
     def __len__(self):
         # -1 is due to data structure used to retieve the index:
         # sample i --> [sample_idx[i], sample_idx[i+1])
         return self.sample_idx.shape[0] - 1
 
-
     def __getitem__(self, idx):
         # Get the shuffled index.
         idx = self.shuffle_idx[idx]
         # Start and end documents and offsets.
         doc_index_f = self.sample_idx[idx][0]
-        doc_index_l = self.sample_idx[idx+1][0]
+        doc_index_l = self.sample_idx[idx + 1][0]
         offset_f = self.sample_idx[idx][1]
-        offset_l = self.sample_idx[idx+1][1]
+        offset_l = self.sample_idx[idx + 1][1]
         # If we are within the same document, just extract the chunk.
         if doc_index_f == doc_index_l:
             sample = self.indexed_dataset.get(self.doc_idx[doc_index_f],
@@ -127,18 +126,17 @@ def __getitem__(self, idx):
             sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f],
                                                     offset=offset_f)]
             # Loop over all in between documents and add the entire document.
-            for i in range(doc_index_f+1, doc_index_l):
+            for i in range(doc_index_f + 1, doc_index_l):
                 sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
             # And finally add the relevant portion of last document.
             sample_list.append(self.indexed_dataset.get(
                 self.doc_idx[doc_index_l],
-                length=offset_l+1))
+                length=offset_l + 1))
             sample = np.concatenate(sample_list)
 
         return {'text': np.array(sample, dtype=np.int64)}
 
 
-
 def _build_index_mappings(name, data_prefix, documents, sizes,
                           num_samples, seq_length, seed):
     """Build doc-idx, sample-idx, and shuffle-idx.
@@ -185,7 +183,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
         assert sizes.dtype == np.int32
         sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
                                               num_epochs, tokens_per_epoch)
-        #sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
+        # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
         #                               num_epochs, tokens_per_epoch)
         np.save(sample_idx_filename, sample_idx, allow_pickle=True)
         print_rank_0(' > elasped time to build and save sample-idx mapping '
@@ -194,7 +192,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
         start_time = time.time()
         # -1 is due to data structure used to retieve the index:
         # sample i --> [sample_idx[i], sample_idx[i+1])
-        shuffle_idx = _build_shuffle_idx(sample_idx.shape[0]-1, np_rng)
+        shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng)
         np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
         print_rank_0(' > elasped time to build and save shuffle-idx mapping'
                      ' (seconds): {:4f}'.format(time.time() - start_time))
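The __getitem__ hunks above rely on sample_idx holding a (document index, offset) pair per sample boundary, so sample i spans everything from sample_idx[i] up to and including sample_idx[i + 1]. A toy illustration of that stitching, with made-up data rather than the repo's code:

import numpy as np

# Toy data: three "documents" of token ids and hand-written sample boundaries.
docs = [np.array([0, 1, 2, 3]), np.array([4, 5]), np.array([6, 7, 8])]
sample_idx = np.array([[0, 0],   # sample 0 starts at doc 0, offset 0
                       [1, 1],   # sample 1 starts at doc 1, offset 1
                       [2, 2]])  # sentinel marking the end of sample 1

def get_sample(idx):
    doc_f, offset_f = sample_idx[idx]
    doc_l, offset_l = sample_idx[idx + 1]
    if doc_f == doc_l:
        return docs[doc_f][offset_f:offset_l + 1]
    pieces = [docs[doc_f][offset_f:]]                     # rest of the first doc
    pieces += [docs[i] for i in range(doc_f + 1, doc_l)]  # whole in-between docs
    pieces.append(docs[doc_l][:offset_l + 1])             # head of the last doc
    return np.concatenate(pieces)

print(get_sample(0))  # [0 1 2 3 4 5] -- crosses from doc 0 into doc 1
print(get_sample(1))  # [5 6 7 8]     -- shares its boundary token with sample 0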

megatron/data/indexed_dataset.py
Lines changed: 13 additions & 5 deletions

@@ -20,6 +20,7 @@
 import torch
 from megatron import print_rank_0
 
+
 def __best_fitting_dtype(vocab_size=None):
     if vocab_size is not None and vocab_size < 65500:
         return np.uint16
@@ -109,13 +110,15 @@ def index_file_path(prefix_path):
 def data_file_path(prefix_path):
     return prefix_path + '.bin'
 
+
 def create_doc_idx(sizes):
     doc_idx = [0]
     for i, s in enumerate(sizes):
         if s == 0:
-            doc_idx.append(i+1)
+            doc_idx.append(i + 1)
     return doc_idx
 
+
 class IndexedDataset(torch.utils.data.Dataset):
     """Loader for IndexedDataset"""
     _HDR_MAGIC = b'TNTIDX\x00\x00'
@@ -155,7 +158,7 @@ def __del__(self):
         if self.data_file:
             self.data_file.close()
 
-    #@lru_cache(maxsize=8)
+    # @lru_cache(maxsize=8)
     def __getitem__(self, idx):
         if not self.data_file:
             self.read_data(self.path)
@@ -235,7 +238,7 @@ def prefetch(self, indices):
            self.data_file.close()
            self.data_file = None
 
-    #@lru_cache(maxsize=8)
+    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            i = idx
@@ -399,13 +402,18 @@ def __init__(self, path, skip_warmup=False):
            self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            print_rank_0(" reading sizes...")
-            self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
+            self._sizes = np.frombuffer(
+                self._bin_buffer,
+                dtype=np.int32,
+                count=self._len,
+                offset=offset)
            print_rank_0(" reading pointers...")
            self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
                                           offset=offset + self._sizes.nbytes)
            print_rank_0(" reading document index...")
            self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
                                          offset=offset + self._sizes.nbytes + self._pointers.nbytes)
+
        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap
@@ -464,7 +472,7 @@ def __del__(self):
     def __len__(self):
         return len(self._index)
 
-    #@lru_cache(maxsize=8)
+    # @lru_cache(maxsize=8)
     def __getitem__(self, idx):
         if isinstance(idx, int):
             ptr, size = self._index[idx]
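The __best_fitting_dtype helper touched above picks np.uint16 whenever the vocabulary fits in 16 bits, which keeps the .bin token payload small. A short sketch of the effect; the int32 fallback here is an assumption for illustration:

import numpy as np

# Sketch of the dtype choice: token ids for a GPT-2-sized vocabulary (50257)
# fit in uint16, so each stored token costs 2 bytes instead of 4 or 8.
def best_fitting_dtype(vocab_size=None):
    if vocab_size is not None and vocab_size < 65500:
        return np.uint16
    return np.int32  # assumed fallback; the real helper's default may differ

tokens = np.array([10, 200, 50000], dtype=best_fitting_dtype(50257))
print(tokens.dtype, tokens.nbytes)  # -> uint16 6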

megatron/data/samplers.py
Lines changed: 4 additions & 3 deletions

@@ -81,6 +81,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
     sampler level. This allows wrapping of arbitrary data samplers
     (sequential, random, WeightedRandomSampler, etc.) with this batch
     sampler."""
+
     def __init__(self, sampler, batch_size, drop_last, rank=-1,
                  world_size=2, wrap_last=False):
         super(DistributedBatchSampler, self).__init__(sampler, batch_size,
@@ -120,7 +121,7 @@ def __iter__(self):
     def data_iterator(self, _iter, wrap_around=False):
         """iterates through data and handles wrap around"""
         for i, idx in enumerate(_iter):
-            if i < self.wrap_around%self.batch_size:
+            if i < self.wrap_around % self.batch_size:
                 continue
             if wrap_around:
                 self.wrap_around += 1
@@ -129,6 +130,6 @@ def data_iterator(self, _iter, wrap_around=False):
 
     def _batch(self, batch):
         """extracts samples only pertaining to this worker's batch"""
-        start = self.rank*self.batch_size//self.world_size
-        end = (self.rank+1)*self.batch_size//self.world_size
+        start = self.rank * self.batch_size // self.world_size
+        end = (self.rank + 1) * self.batch_size // self.world_size
         return batch[start:end]
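The _batch slice reformatted above carves one worker's share out of the global batch. A worked example with illustrative numbers (a batch of 8 split across two ranks):

# Worked example of the slice arithmetic in _batch; the values are illustrative.
batch = list(range(8))        # a "global" batch of 8 sample indices
batch_size, world_size = 8, 2

for rank in range(world_size):
    start = rank * batch_size // world_size
    end = (rank + 1) * batch_size // world_size
    print(rank, batch[start:end])
# 0 [0, 1, 2, 3]
# 1 [4, 5, 6, 7]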

megatron/data/test/test_indexed_dataset.py
Lines changed: 13 additions & 10 deletions

@@ -2,6 +2,8 @@
 # put some code used during development and manual testing of
 # indexed_dataset.
 
+from megatron.data import indexed_dataset
+from megatron.tokenizer import build_tokenizer
 import argparse
 import os
 import sys
@@ -11,8 +13,6 @@
 script_dir = os.path.dirname(os.path.realpath(__file__))
 sys.path.append(os.path.join(script_dir, "../../../"))
 
-from megatron.tokenizer import build_tokenizer
-from megatron.data import indexed_dataset
 
 def test_indexed_dataset(args):
     ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
@@ -23,12 +23,12 @@ def test_indexed_dataset(args):
     if ds.supports_prefetch:
         # just prefetch the whole thing in test (so assume it is small)
         ds.prefetch(range(len(ds)))
-    if args.count > len(ds.doc_idx)-1:
-        args.count = len(ds.doc_idx)-1
+    if args.count > len(ds.doc_idx) - 1:
+        args.count = len(ds.doc_idx) - 1
 
     for i in range(args.count):
         start = ds.doc_idx[i]
-        end = ds.doc_idx[i+1]
+        end = ds.doc_idx[i + 1]
         ids = ds[start:end]
         print(f"Document {i}:")
         print("--------------")
@@ -39,26 +39,27 @@ def test_indexed_dataset(args):
         print(text)
         print("---")
 
+
 def test_indexed_dataset_get(args):
     ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
     tokenizer = build_tokenizer(args)
     size = ds.sizes[0]
     print(f"size: {size}")
     full = ds.get(0)
     print(full)
-    #print(tokenizer.detokenize(full.data.tolist()))
+    # print(tokenizer.detokenize(full.data.tolist()))
     print("---")
-    end = ds.get(0, offset=size-10)
+    end = ds.get(0, offset=size - 10)
    print(end)
-    #print(tokenizer.detokenize(end.data.tolist()))
+    # print(tokenizer.detokenize(end.data.tolist()))
 
     start = ds.get(0, length=10)
     print(start)
-    #print(tokenizer.detokenize(start.data.tolist()))
+    # print(tokenizer.detokenize(start.data.tolist()))
 
     part = ds.get(0, offset=2, length=8)
     print(part)
-    #print(tokenizer.detokenize(part.data.tolist()))
+    # print(tokenizer.detokenize(part.data.tolist()))
 
 # def test_albert_dataset(args):
 #     # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
@@ -77,6 +78,7 @@ def test_indexed_dataset_get(args):
 #         if i >= args.count-1:
 #             exit()
 
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--data', type=str, help='prefix to data files')
@@ -118,5 +120,6 @@ def main():
     # test_albert_dataset(args)
     test_indexed_dataset_get(args)
 
+
 if __name__ == "__main__":
     main()
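Condensed sketch of what test_indexed_dataset() above exercises, with a hypothetical data prefix my-corpus and the 'lazy' implementation name; both are placeholders for illustration, not values recorded in this commit:

from megatron.data import indexed_dataset

# Hypothetical prefix and implementation name, mirroring the test's flow.
ds = indexed_dataset.make_dataset('my-corpus', 'lazy')
if ds.supports_prefetch:
    ds.prefetch(range(len(ds)))           # small test data: prefetch everything

count = min(3, len(ds.doc_idx) - 1)       # doc_idx delimits documents, as in the test
for i in range(count):
    start, end = ds.doc_idx[i], ds.doc_idx[i + 1]
    ids = ds[start:end]
    print(f"Document {i}: {len(ids)} segment(s)")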
