-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_loader.py
122 lines (100 loc) · 3.68 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import numpy as np
from utils import plot_images
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
def get_train_valid_loader(
    data_dir,
    batch_size,
    random_seed,
    valid_size=0.1,
    shuffle=True,
    show_sample=False,
    num_workers=4,
    pin_memory=False,
):
    """Build the MNIST training and validation data loaders.

    The MNIST training set is split by index into a validation part (the
    first ``floor(valid_size * N)`` indices) and a training part (the
    remaining indices). Both loaders share the same underlying dataset and
    draw from their own ``SubsetRandomSampler``.

    If using CUDA, num_workers should be set to 1 and pin_memory to True.

    Args:
        data_dir: path directory to the dataset. Passed to
            ``datasets.MNIST`` with ``download=True``.
        batch_size: how many samples per batch to load.
        random_seed: fix seed for reproducibility. Only used when
            ``shuffle`` is true. Note that it is applied with
            ``np.random.seed``, i.e. it reseeds the global NumPy RNG.
        valid_size: percentage split of the training set used for
            the validation set. Should be a float in the range [0, 1].
            In the paper, this number is set to 0.1.
        shuffle: whether to shuffle the train/validation indices before
            splitting. Also forwarded as ``shuffle`` to the sample loader
            used by ``show_sample``.
        show_sample: plot a sample of 9 images (one batch of size 9) of
            the dataset via ``plot_images``.
        num_workers: number of subprocesses to use when loading the dataset.
        pin_memory: whether to copy tensors into CUDA pinned memory. Set it
            to True if using GPU.

    Returns:
        A tuple ``(train_loader, valid_loader)``.

    Raises:
        AssertionError: if ``valid_size`` is outside the range [0, 1].
    """
    error_msg = "[!] valid_size should be in the range [0, 1]."
    assert 0 <= valid_size <= 1, error_msg

    # define transforms
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    trans = transforms.Compose([transforms.ToTensor(), normalize])

    # load dataset
    dataset = datasets.MNIST(data_dir, train=True, download=True, transform=trans)

    num_train = len(dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        # Seeding right before the shuffle makes the split reproducible for
        # a given random_seed.
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    valid_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=valid_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    # visualize some images
    if show_sample:
        sample_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=9,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )
        # Use the builtin next(): loader iterators implement __next__, and
        # the Python 2 style .next() alias is not available on newer
        # PyTorch releases.
        images, labels = next(iter(sample_loader))
        # Reorder axes (0, 2, 3, 1) before plotting; for ToTensor output
        # this presumably turns (N, C, H, W) into (N, H, W, C).
        X = np.transpose(images.numpy(), [0, 2, 3, 1])
        plot_images(X, labels)

    return (train_loader, valid_loader)
def get_test_loader(data_dir, batch_size, num_workers=4, pin_memory=False):
    """Build the MNIST test data loader.

    The test split is iterated in its stored order (no shuffling).

    If using CUDA, num_workers should be set to 1 and pin_memory to True.

    Args:
        data_dir: path directory to the dataset. Passed to
            ``datasets.MNIST`` with ``download=True``.
        batch_size: how many samples per batch to load.
        num_workers: number of subprocesses to use when loading the dataset.
        pin_memory: whether to copy tensors into CUDA pinned memory. Set it
            to True if using GPU.

    Returns:
        A ``torch.utils.data.DataLoader`` over the MNIST test split.
    """
    # Preprocessing pipeline: tensor conversion followed by normalization.
    preprocessing = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )

    # Test split of MNIST.
    test_set = datasets.MNIST(
        data_dir, train=False, download=True, transform=preprocessing
    )

    return torch.utils.data.DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )