-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
39 lines (27 loc) · 1.38 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
# Default transform, used by DRIVEDataset below for both images and labels.
# transforms.ToTensor() converts a PIL image into a float tensor laid out as
# (C, H, W); per the torchvision docs, 8-bit images are scaled from [0, 255]
# to [0.0, 1.0]. Keeping inputs in a small, common value range generally makes
# training better behaved.
# Optional extra step (currently disabled): appending
#     transforms.Normalize((0.5,), (0.5,))
# computes (x - 0.5) / 0.5 per channel, i.e. maps [0, 1] onto [-1, 1].
data_transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)
class DRIVEDataset(Dataset):
    """Image/label pairs read from a DRIVE-style directory layout.

    ``root_dir`` is expected to contain two sub-directories:

    * ``images``     -- the input images
    * ``1st_manual`` -- the corresponding manual annotations (labels)

    Sample ``i`` pairs the ``i``-th image with the ``i``-th label after both
    directory listings are sorted by file name, so matching files must sort
    into the same order (for example by sharing a numeric prefix).

    ``transform`` is called separately on the image and on the label.
    NOTE(review): a *random* augmentation would therefore not be applied
    consistently to both; use a paired transform if augmentation is added.
    """

    def __init__(self, root_dir, transform=data_transform):
        """Index the image and label files under ``root_dir``.

        Args:
            root_dir: Dataset root containing ``images`` and ``1st_manual``.
            transform: Callable applied to every opened PIL image and label.
                Defaults to the module-level ``data_transform``.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.img_dir = os.path.join(root_dir, "images")
        self.label_dir = os.path.join(root_dir, "1st_manual")
        # os.listdir returns entries in arbitrary, filesystem-dependent order.
        # Sorting both listings makes the index-based image/label pairing in
        # __getitem__ deterministic and consistent between the two folders.
        self.img_names = sorted(os.listdir(self.img_dir))
        self.label_names = sorted(os.listdir(self.label_dir))

    def _load_item(self, path):
        """Open the image at ``path``, apply ``self.transform``, close the file."""
        with Image.open(path) as image:
            # Read the pixel data now so the result remains usable after the
            # file handle is released on leaving the ``with`` block.
            image.load()
            return self.transform(image)

    def __getitem__(self, idx):
        """Return the transformed ``(img, label)`` pair at position ``idx``.

        Raises:
            IndexError: if ``idx`` is out of range for the image list or the
                label list.
        """
        img_item_path = os.path.join(self.img_dir, self.img_names[idx])
        label_item_path = os.path.join(self.label_dir, self.label_names[idx])
        img = self._load_item(img_item_path)
        label = self._load_item(label_item_path)
        return img, label

    def __len__(self):
        """Number of samples, i.e. the number of entries in ``images``."""
        return len(self.img_names)