-
Notifications
You must be signed in to change notification settings - Fork 0
/
check_dataloader.py
39 lines (29 loc) · 1.27 KB
/
check_dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import numpy as np
from torch.utils.data import WeightedRandomSampler, DataLoader
numDataPoints = 1000
data_dim = 5
bs = 100


def make_imbalanced_data(num_points=numDataPoints, dim=data_dim,
                         majority_frac=0.9):
    """Build a dummy two-class dataset with a majority/minority imbalance.

    Args:
        num_points: Total number of samples.
        dim: Feature dimension of each sample.
        majority_frac: Fraction of samples labelled 0 (default 0.9, i.e. 9:1).

    Returns:
        ``(data, target)`` where ``data`` is a ``(num_points, dim)`` tensor of
        standard-normal features and ``target`` is an int32 NumPy array of
        0/1 labels (all zeros first, then all ones).
    """
    n_major = int(num_points * majority_frac)
    # Derive the minority count from the total so there is exactly one label
    # per feature row; truncating both fractions separately can lose a sample
    # (e.g. 999 points -> 899 + 99) and make TensorDataset reject the inputs.
    n_minor = num_points - n_major
    # torch.FloatTensor(n, d) hands back *uninitialized* memory (possibly
    # NaN/inf); draw well-defined random values instead.
    data = torch.randn(num_points, dim)
    target = np.hstack((np.zeros(n_major, dtype=np.int32),
                        np.ones(n_minor, dtype=np.int32)))
    return data, target


def compute_sample_weights(target):
    """Return a per-sample weight equal to 1 / (size of that sample's class).

    Each sample is mapped to its class position via
    ``np.unique(..., return_inverse=True)`` rather than by using the label
    value itself as an index, so arbitrary (non-contiguous) labels work.

    Args:
        target: 1-D array-like of class labels.

    Returns:
        A float64 tensor of length ``len(target)``, suitable as the
        ``weights`` argument of ``WeightedRandomSampler``.
    """
    _, inverse, counts = np.unique(np.asarray(target), return_inverse=True,
                                   return_counts=True)
    class_weight = 1.0 / counts
    # ravel() guards against NumPy versions that shape `inverse` like the input.
    return torch.from_numpy(class_weight[inverse.ravel()]).double()


def main(num_workers=1, batch_size=bs):
    """Sample the imbalanced dataset with a WeightedRandomSampler and print
    the class balance of each batch.

    Args:
        num_workers: Number of DataLoader worker processes (default 1).
        batch_size: Samples per batch (default ``bs``).

    Returns:
        A list of ``(count_0, count_1)`` tuples, one per batch.
    """
    data, target = make_imbalanced_data()
    print('target train 0/1: {}/{}'.format(
        int(np.count_nonzero(target == 0)), int(np.count_nonzero(target == 1))))

    samples_weight = compute_sample_weights(target)
    # Draw as many samples as the dataset holds; rarer classes get a
    # proportionally larger weight, so batches come out roughly balanced.
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

    train_dataset = torch.utils.data.TensorDataset(
        data, torch.from_numpy(target).long())
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, num_workers=num_workers,
        sampler=sampler)

    batch_counts = []
    for i, (_, batch_target) in enumerate(train_loader):
        n0 = int((batch_target == 0).sum())
        n1 = int((batch_target == 1).sum())
        batch_counts.append((n0, n1))
        print("batch index {}, 0/1: {}/{}".format(i, n0, n1))
    return batch_counts


# DataLoader workers are separate processes; under the "spawn" start method
# (the default on Windows and macOS) each worker re-imports this module, so
# the loader must only be built behind the main guard.
if __name__ == "__main__":
    main()