-
Notifications
You must be signed in to change notification settings - Fork 4
/
Dataset.py
68 lines (64 loc) · 2.25 KB
/
Dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os, sys, torch
from PIL import Image
import numpy as np
from torchvision import transforms
import torch.utils.data as data
class Dataset(data.Dataset):
    """Dataset pairing each raw image with one or more expert-retouched versions.

    Optionally concatenates one-hot semantic-segmentation maps and a saliency
    map onto the raw image along the channel axis before transforms are applied.

    Args:
        dRaw: directory containing the raw input images.
        dExpert: directory (str) or list of directories, one per expert.
        listfn: path to a text file with one image filename per line.
        dSemSeg: optional directory of per-pixel class-index segmentation maps.
        dSaliency: optional directory of grayscale saliency maps.
        nclasses: number of semantic-segmentation classes (one-hot channels).
        trans: optional callable applied to the list of images; when falsy,
            each image is converted with torchvision's ToTensor.
        include_filenames: when True, __getitem__ also returns the filename.
    """
    def __init__(self, dRaw, dExpert, listfn, dSemSeg='', dSaliency='', nclasses=150, trans='', include_filenames=False):
        self.dRaw = dRaw
        # allow a single expert directory to be passed as a plain string
        if isinstance(dExpert, str):
            dExpert = [dExpert]
        self.dExpert = dExpert
        self.dSemSeg = dSemSeg
        self.dSaliency = dSaliency
        self.include_filenames = include_filenames
        self.listfn = listfn
        self.trans = trans
        self.nclasses = nclasses
        # read the list of image filenames, skipping empty lines
        # (with-statement closes the file even if reading raises)
        with open(listfn, "r") as in_file:
            self.fns = [l for l in in_file.read().split('\n') if l]

    def __getitem__(self, index):
        """Return [raw, expert1, ...] images (tensors unless `trans` says otherwise)."""
        fn = self.fns[index]
        # load the raw image as float32 in [0, 1]
        raw = np.array(Image.open(os.path.join(self.dRaw, fn)).convert('RGB'))
        raw = raw.astype(np.float32) / 255.
        images = [raw]
        # load the expert versions, if any
        if self.dExpert is not None:
            for cur_dexp in self.dExpert:
                cur_img = np.array(Image.open(os.path.join(cur_dexp, fn)).convert('RGB'))
                images.append(cur_img.astype(np.float32) / 255.)
        # append one-hot semantic-segmentation channels to the raw image
        if os.path.isdir(self.dSemSeg):
            semseg_img = np.array(Image.open(os.path.join(self.dSemSeg, fn)))
            chs = []
            for i in range(self.nclasses):
                cur_ch = np.zeros((images[0].shape[0], images[0].shape[1]))
                cur_ch[semseg_img == i] = 1
                chs.append(np.expand_dims(cur_ch, axis=2))
            semseg_maps = np.concatenate(chs, axis=2).astype(np.float32)
            images[0] = np.concatenate((images[0], semseg_maps), axis=2)
        # append the saliency map as one extra channel on the raw image
        if os.path.isdir(self.dSaliency):
            saliency_img = np.array(Image.open(os.path.join(self.dSaliency, fn)))
            # FIX: np.float was removed in NumPy 1.24; cast straight to float32
            saliency_img = saliency_img.astype(np.float32) / 255.
            saliency_img = np.expand_dims(saliency_img, axis=2)
            images[0] = np.concatenate((images[0], saliency_img), axis=2)
        # apply the user transform, or fall back to a plain ToTensor conversion
        if self.trans:
            images = self.trans(images)
        else:
            images = [transforms.ToTensor()(img) for img in images]
        if self.include_filenames:
            return images, fn
        return images  # first the raw, then the experts

    def __len__(self):
        """Number of filenames listed in `listfn`."""
        return len(self.fns)