-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataloader.py
29 lines (22 loc) · 766 Bytes
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from torch.utils.data import Dataset
from torchvision import transforms
class HousesDataset(Dataset):
    """Map-style dataset pairing house samples with their targets.

    Each item is a ``(sample, target)`` tuple taken from the same index of
    ``house_data`` and ``house_target``.
    """

    def __init__(self, house_data, house_target, transform=None):
        """
        Args:
            house_data: Indexable collection of samples (e.g. a numpy array);
                must support ``len()`` and integer indexing.
            house_target: Indexable collection of targets aligned index-wise
                with ``house_data``; must have the same length.
            transform (callable, optional): Optional transform applied to
                each sample (not to the target) in ``__getitem__``.

        Raises:
            ValueError: If ``house_data`` and ``house_target`` differ in length.
        """
        # Fail fast: a mismatch would otherwise surface as an IndexError
        # part-way through an epoch, or silently ignore trailing targets.
        if len(house_data) != len(house_target):
            raise ValueError(
                "house_data and house_target must have the same length, "
                f"got {len(house_data)} and {len(house_target)}"
            )
        self.dataset = house_data
        self.targets = house_target
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the ``(sample, target)`` pair at position ``idx``.

        If a ``transform`` was given, it is applied to the sample only.
        """
        sample = self.dataset[idx]
        target = self.targets[idx]
        # Compare against None rather than relying on truthiness: a callable
        # that defines __len__/__bool__ could otherwise be skipped.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target