-
Notifications
You must be signed in to change notification settings - Fork 100
/
base_dataset.py
128 lines (108 loc) · 4.77 KB
/
base_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import random
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
class BaseDataset(Dataset):
    """Characterizes a dataset for PyTorch -- this dataset pre-loads all paths in memory.

    Only the image references and labels are held in memory; each image is
    opened and decoded on demand in ``__getitem__``.
    """

    def __init__(self, data, transform, class_indices=None):
        """Initialization.

        Args:
            data: mapping with key ``'x'`` (sequence of image references, each
                passed to ``PIL.Image.open``) and key ``'y'`` (sequence of
                labels aligned index-by-index with ``'x'``).
            transform: callable applied to every decoded RGB image.
            class_indices: stored on the instance as-is; not used by this
                class itself.
        """
        self.labels = data['y']
        self.images = data['x']
        self.transform = transform
        self.class_indices = class_indices

    def __len__(self):
        """Denotes the total number of samples"""
        return len(self.images)

    def __getitem__(self, index):
        """Generates one sample of data.

        Returns:
            ``(transform(rgb_image), label)`` for position ``index``.
        """
        # Close the underlying file deterministically. convert() returns a new
        # image object, so it remains valid after the file is closed.
        with Image.open(self.images[index]) as img:
            x = img.convert('RGB')
        x = self.transform(x)
        y = self.labels[index]
        return x, y
def _classes_per_task(num_classes, num_tasks, nc_first_task):
    """Return an int array holding the number of classes assigned to each task."""
    if nc_first_task is None:
        # Spread classes as evenly as possible; the first tasks absorb the remainder.
        cpertask = np.array([num_classes // num_tasks] * num_tasks)
        for i in range(num_classes % num_tasks):
            cpertask[i] += 1
    else:
        assert nc_first_task < num_classes, "first task wants more classes than exist"
        # Without this check a single task would hit a bare ZeroDivisionError below.
        assert num_tasks > 1, "nc_first_task requires at least two tasks"
        remaining_classes = num_classes - nc_first_task
        assert remaining_classes >= (num_tasks - 1), "at least one class is needed per task"  # better minimum 2
        # Classes not in task 0 are spread evenly over tasks 1..num_tasks-1.
        cpertask = np.array([nc_first_task] + [remaining_classes // (num_tasks - 1)] * (num_tasks - 1))
        for i in range(remaining_classes % (num_tasks - 1)):
            cpertask[i + 1] += 1
    assert num_classes == cpertask.sum(), "something went wrong, the split does not match num classes"
    return cpertask


def _build_label_map(class_order, cpertask):
    """Map each raw label in ``class_order`` to ``(task id, label local to that task)``."""
    cpertask_cumsum = np.cumsum(cpertask)
    init_class = np.concatenate(([0], cpertask_cumsum[:-1]))
    label_map = {}
    for pos, label in enumerate(class_order):
        if label in label_map:
            # Keep the first occurrence, matching list.index semantics.
            continue
        task = int((pos >= cpertask_cumsum).sum())
        label_map[label] = (task, pos - init_class[task])
    return label_map


def _add_samples(lines, path, label_map, data, split):
    """Append each ``(image, label)`` row of ``lines`` to ``data[task][split]``."""
    for this_image, this_label in lines:
        if not os.path.isabs(this_image):
            this_image = os.path.join(path, this_image)
        mapped = label_map.get(int(this_label))
        if mapped is None:
            # Labels not present in class_order are skipped.
            continue
        this_task, local_label = mapped
        data[this_task][split]['x'].append(this_image)
        data[this_task][split]['y'].append(local_label)


def _split_validation(task_data, validation):
    """Move a ``validation`` fraction of each class's training samples into 'val'."""
    trn, val = task_data['trn'], task_data['val']
    trn_y = np.asarray(trn['y'])
    moved = set()
    for cc in range(task_data['ncla']):
        cls_idx = np.where(trn_y == cc)[0].tolist()
        rnd_img = random.sample(cls_idx, int(np.round(len(cls_idx) * validation)))
        # Descending order keeps the same 'val' ordering as removing samples
        # one by one from the back of the training list.
        rnd_img.sort(reverse=True)
        for ii in rnd_img:
            val['x'].append(trn['x'][ii])
            val['y'].append(trn['y'][ii])
        moved.update(rnd_img)
    # Rebuild the training lists once instead of popping (O(n) per pop).
    trn['x'] = [v for i, v in enumerate(trn['x']) if i not in moved]
    trn['y'] = [v for i, v in enumerate(trn['y']) if i not in moved]


def get_data(path, num_tasks, nc_first_task, validation, shuffle_classes, class_order=None):
    """Prepare data: dataset splits, task partition, class order.

    Reads ``train.txt`` and ``test.txt`` from ``path``. Each row holds an image
    path and an integer label separated by whitespace; relative image paths
    are joined onto ``path``. Classes are reordered by ``class_order`` and
    partitioned into ``num_tasks`` consecutive groups.

    Args:
        path: directory containing ``train.txt`` and ``test.txt``.
        num_tasks: number of tasks to split the classes into.
        nc_first_task: number of classes in task 0, or None to split evenly.
        validation: fraction of each class's training samples moved into the
            validation split; 0.0 disables it.
        shuffle_classes: if True, shuffle the class order with ``np.random``.
        class_order: optional sequence of raw labels defining the class order;
            labels not listed are dropped. When omitted it is
            ``[0, ..., num_classes - 1]`` where num_classes is the number of
            distinct training labels.
            NOTE(review): the default assumes labels are contiguous from 0.

    Returns:
        ``(data, taskcla, class_order)``. ``data[t]`` holds 'name', 'ncla' and
        'trn'/'val'/'tst' dicts with 'x' (image paths) and 'y' (labels local to
        task t). ``data['ncla']`` is the total number of classes. ``taskcla`` is
        a list of ``(task, ncla)`` pairs. ``class_order`` is the (possibly
        shuffled) list actually used.

    Raises:
        AssertionError: if the requested task split is inconsistent with the data.
    """
    data = {}
    taskcla = []
    # read filenames and labels; ndmin=2 keeps a one-line file as a (1, 2) array
    trn_lines = np.loadtxt(os.path.join(path, 'train.txt'), dtype=str, ndmin=2)
    tst_lines = np.loadtxt(os.path.join(path, 'test.txt'), dtype=str, ndmin=2)
    if class_order is None:
        num_classes = len(np.unique(trn_lines[:, 1]))
        class_order = list(range(num_classes))
    else:
        num_classes = len(class_order)
        # list() copies (never mutates the caller's object) and accepts tuples/arrays
        class_order = list(class_order)
    if shuffle_classes:
        np.random.shuffle(class_order)
    # compute classes per task and the raw-label -> (task, local label) mapping
    cpertask = _classes_per_task(num_classes, num_tasks, nc_first_task)
    label_map = _build_label_map(class_order, cpertask)
    # initialize data structure
    for tt in range(num_tasks):
        data[tt] = {
            'name': 'task-' + str(tt),
            'trn': {'x': [], 'y': []},
            'val': {'x': [], 'y': []},
            'tst': {'x': [], 'y': []},
        }
    _add_samples(trn_lines, path, label_map, data, 'trn')
    _add_samples(tst_lines, path, label_map, data, 'tst')
    # check classes
    for tt in range(num_tasks):
        data[tt]['ncla'] = len(np.unique(data[tt]['trn']['y']))
        assert data[tt]['ncla'] == cpertask[tt], "something went wrong splitting classes"
    # validation
    if validation > 0.0:
        for tt in range(num_tasks):
            _split_validation(data[tt], validation)
    # other
    n = 0
    for t in range(num_tasks):
        taskcla.append((t, data[t]['ncla']))
        n += data[t]['ncla']
    data['ncla'] = n
    return data, taskcla, class_order