generated from cavalleria/pytorch-template
-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathbase_data_loader.py
52 lines (39 loc) · 1.7 KB
/
base_data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler
class BaseDataLoader(DataLoader):
    """DataLoader that can hold out a reproducible validation subset.

    When ``validation_split`` is non-zero, the dataset indices are shuffled
    with a fixed seed and split into disjoint training and validation index
    sets. This loader iterates the training subset; ``split_validation()``
    returns a separate ``DataLoader`` over the validation subset.

    Args:
        dataset: Dataset to load from; must support ``len()``.
        batch_size: Samples per batch (shared by the train and validation loaders).
        shuffle: Whether to shuffle when no validation split is made. Forced to
            ``False`` when a split is made, because ``DataLoader`` rejects
            ``shuffle=True`` together with an explicit sampler (the
            ``SubsetRandomSampler`` already randomizes order).
        validation_split: ``0`` disables the split. A float in ``(0, 1)`` is
            the fraction of the dataset to hold out; an int in
            ``(0, len(dataset))`` is the absolute number of held-out samples.
        num_workers: Number of worker processes passed to ``DataLoader``.
        collate_fn: Batch collation function passed to ``DataLoader``.

    Raises:
        ValueError: If ``validation_split`` is negative or would leave no
            training samples.
    """

    def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers,
                 collate_fn=default_collate):
        self.validation_split = validation_split
        self.shuffle = shuffle
        self.batch_idx = 0
        self.n_samples = len(dataset)

        # Must run before init_kwargs is built: when a split is made it turns
        # self.shuffle off and shrinks self.n_samples to the training size.
        self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

        self.init_kwargs = {
            'dataset': dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers,
        }
        super().__init__(sampler=self.sampler, **self.init_kwargs)

    def _split_sampler(self, split):
        """Build ``(train_sampler, valid_sampler)`` for ``split``.

        Returns ``(None, None)`` when ``split`` is 0. Otherwise sets
        ``self.shuffle = False`` and ``self.n_samples`` to the number of
        training samples as side effects.

        Raises:
            ValueError: If ``split`` is out of range (see class docstring).
        """
        if split == 0.0:
            return None, None

        if isinstance(split, (int, np.integer)):
            # An integer is an absolute validation-set size.
            if not 0 < split < self.n_samples:
                raise ValueError(
                    f"validation_split={split} must be between 1 and "
                    f"{self.n_samples - 1} samples")
            len_valid = int(split)
        else:
            # A negative fraction would make the slices below overlap
            # (leaking validation data into training); >= 1 leaves no training data.
            if not 0.0 < split < 1.0:
                raise ValueError(
                    f"validation_split={split} must be in the range [0, 1)")
            len_valid = int(self.n_samples * split)

        idx_full = np.arange(self.n_samples)
        # A private generator seeded with 0 produces exactly the same permutation
        # as np.random.seed(0) + np.random.shuffle, so the split stays
        # reproducible without resetting the global NumPy RNG for the caller.
        np.random.RandomState(0).shuffle(idx_full)

        valid_idx = idx_full[:len_valid]
        train_idx = idx_full[len_valid:]

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self):
        """Return a DataLoader over the validation indices, or ``None`` if no split was made."""
        if self.valid_sampler is None:
            return None
        return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)