Skip to content

Commit 668646c

Browse files
committed
all files
1 parent 144d851 commit 668646c

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

64 files changed

+2754
-0
lines changed

data/__init__.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import torch.utils.data
2+
from data.base_dataset import collate_fn
3+
4+
def CreateDataset(opt):
    """Instantiate the dataset selected by ``opt.dataset_mode``.

    Args:
        opt: options object; ``opt.dataset_mode`` must be
            'segmentation' or 'classification'.

    Returns:
        The constructed dataset instance.

    Raises:
        ValueError: if ``opt.dataset_mode`` is not a recognized mode.
    """
    # Imports are deferred so only the requested dataset module is loaded.
    if opt.dataset_mode == 'segmentation':
        from data.segmentation_data import SegmentationData
        dataset = SegmentationData(opt)
    elif opt.dataset_mode == 'classification':
        from data.classification_data import ClassificationData
        dataset = ClassificationData(opt)
    else:
        # Previously an unknown mode fell through to an UnboundLocalError
        # on 'dataset'; fail fast with a clear message instead.
        raise ValueError('unknown dataset_mode: %s' % opt.dataset_mode)
    return dataset
14+
15+
16+
class DataLoader:
    """Multi-threaded data loading.

    Thin wrapper that builds the dataset chosen by the options and
    drives it through a torch DataLoader using the project's custom
    ``collate_fn`` (batches samples into numpy arrays).
    """

    def __init__(self, opt):
        self.opt = opt
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads),
            collate_fn=collate_fn)

    def __len__(self):
        # Number of samples, not number of batches.
        return len(self.dataset)

    def __iter__(self):
        yield from self.dataloader

data/base_dataset.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import torch.utils.data as data
2+
import numpy as np
3+
import pickle
4+
import os
5+
6+
class BaseDataset(data.Dataset):
    """Common base for the mesh datasets: holds normalization statistics
    (mean/std per input channel) computed once from the training data and
    cached on disk."""

    def __init__(self, opt):
        self.opt = opt
        # Identity normalization until get_mean_std() loads the real stats.
        self.mean = 0
        self.std = 1
        # Number of per-edge input channels; filled in by get_mean_std().
        self.ninput_channels = None
        super(BaseDataset, self).__init__()

    def get_mean_std(self):
        """ Computes Mean and Standard Deviation from Training Data
        If mean/std file doesn't exist, will compute one
        :returns
        mean: N-dimensional mean
        std: N-dimensional standard deviation
        ninput_channels: N
        (here N=5)
        """
        # self.root is expected to be set by the subclass before calling this.
        mean_std_cache = os.path.join(self.root, 'mean_std_cache.p')
        if not os.path.isfile(mean_std_cache):
            print('computing mean std from train data...')
            # doesn't run augmentation during m/std computation
            num_aug = self.opt.num_aug
            self.opt.num_aug = 1
            mean, std = np.array(0), np.array(0)
            # Iterating self invokes the subclass __getitem__ for every sample.
            for i, data in enumerate(self):
                if i % 500 == 0:
                    print('{} of {}'.format(i, self.size))
                features = data['edge_features']
                mean = mean + features.mean(axis=1)
                std = std + features.std(axis=1)
            # Average of per-sample means/stds (NOTE: the std is an
            # approximation, not the pooled std over all samples).
            mean = mean / (i + 1)
            std = std / (i + 1)
            transform_dict = {'mean': mean[:, np.newaxis], 'std': std[:, np.newaxis],
                              'ninput_channels': len(mean)}
            with open(mean_std_cache, 'wb') as f:
                pickle.dump(transform_dict, f)
            print('saved: ', mean_std_cache)
            # Restore the caller's augmentation setting.
            self.opt.num_aug = num_aug
        # open mean / std from file
        with open(mean_std_cache, 'rb') as f:
            transform_dict = pickle.load(f)
            print('loaded mean / std from cache')
            self.mean = transform_dict['mean']
            self.std = transform_dict['std']
            self.ninput_channels = transform_dict['ninput_channels']
53+
54+
55+
def collate_fn(batch):
    """Collate a list of per-sample dicts into a single dict of numpy arrays.

    Every key present in the samples is stacked along a new leading (batch)
    axis. A custom collate is used instead of torch's default so that the
    samples' values (including non-tensor objects) batch as numpy arrays.
    """
    sample_keys = batch[0].keys()
    return {key: np.array([sample[key] for sample in batch]) for key in sample_keys}

data/classification_data.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import os
2+
import torch
3+
from data.base_dataset import BaseDataset
4+
from util.util import is_mesh_file, pad
5+
from models.layers.mesh import Mesh
6+
7+
class ClassificationData(BaseDataset):
    """Mesh classification dataset: one label per mesh, with class names
    taken from the top-level directory layout of the data root."""

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.opt = opt
        if opt.gpu_ids:
            self.device = torch.device('cuda:{}'.format(opt.gpu_ids[0]))
        else:
            self.device = torch.device('cpu')
        self.root = opt.dataroot
        self.dir = os.path.join(opt.dataroot)
        self.classes, self.class_to_idx = self.find_classes(self.dir)
        self.paths = self.make_dataset_by_class(self.dir, self.class_to_idx, opt.phase)
        self.nclasses = len(self.classes)
        self.size = len(self.paths)
        self.get_mean_std()
        # modify for network later.
        opt.nclasses = self.nclasses
        opt.input_nc = self.ninput_channels

    def __getitem__(self, index):
        path, label = self.paths[index]
        mesh = Mesh(file=path, opt=self.opt, hold_history=False,
                    export_folder=self.opt.export_folder)
        # Edge features, padded to a fixed edge count and normalized.
        features = pad(mesh.extract_features(), self.opt.ninput_edges)
        normalized = (features - self.mean) / self.std
        return {'mesh': mesh, 'label': label, 'edge_features': normalized}

    def __len__(self):
        return self.size

    # this is when the folders are organized by class...
    @staticmethod
    def find_classes(dir):
        """Map each class subdirectory name to an integer index (sorted order)."""
        entries = os.listdir(dir)
        classes = sorted(d for d in entries if os.path.isdir(os.path.join(dir, d)))
        class_to_idx = {name: idx for idx, name in enumerate(classes)}
        return classes, class_to_idx

    @staticmethod
    def make_dataset_by_class(dir, class_to_idx, phase):
        """Collect (mesh_path, class_index) pairs for the given phase.

        Only files under a directory whose path contains the phase name
        exactly once are included.
        """
        meshes = []
        dir = os.path.expanduser(dir)
        for target in sorted(os.listdir(dir)):
            class_dir = os.path.join(dir, target)
            if not os.path.isdir(class_dir):
                continue
            for root, _, fnames in sorted(os.walk(class_dir)):
                for fname in sorted(fnames):
                    if is_mesh_file(fname) and root.count(phase) == 1:
                        full_path = os.path.join(root, fname)
                        meshes.append((full_path, class_to_idx[target]))
        return meshes

data/segmentation_data.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import os
2+
import torch
3+
from data.base_dataset import BaseDataset
4+
from util.util import is_mesh_file, pad
5+
import numpy as np
6+
from models.layers.mesh import Mesh
7+
8+
class SegmentationData(BaseDataset):
    """Mesh segmentation dataset: every edge of a mesh carries a class
    label ('.eseg' files) plus a soft multi-membership label ('.seseg')."""

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.opt = opt
        if opt.gpu_ids:
            self.device = torch.device('cuda:{}'.format(opt.gpu_ids[0]))
        else:
            self.device = torch.device('cpu')
        self.root = opt.dataroot
        self.dir = os.path.join(opt.dataroot, opt.phase)
        self.paths = self.make_dataset(self.dir)
        self.seg_paths = self.get_seg_files(self.paths, os.path.join(self.root, 'seg'), seg_ext='.eseg')
        self.sseg_paths = self.get_seg_files(self.paths, os.path.join(self.root, 'sseg'), seg_ext='.seseg')
        self.classes, self.offset = self.get_n_segs(os.path.join(self.root, 'classes.txt'), self.seg_paths)
        self.nclasses = len(self.classes)
        self.size = len(self.paths)
        self.get_mean_std()
        # # modify for network later.
        opt.nclasses = self.nclasses
        opt.input_nc = self.ninput_channels

    def __getitem__(self, index):
        mesh_path = self.paths[index]
        mesh = Mesh(file=mesh_path, opt=self.opt, hold_history=True,
                    export_folder=self.opt.export_folder)
        # Hard labels, shifted so the smallest class id is 0; padded with -1.
        hard = read_seg(self.seg_paths[index]) - self.offset
        hard = pad(hard, self.opt.ninput_edges, val=-1, dim=0)
        # Soft membership labels, padded the same way.
        soft = pad(read_sseg(self.sseg_paths[index]), self.opt.ninput_edges, val=-1, dim=0)
        # get edge features
        features = pad(mesh.extract_features(), self.opt.ninput_edges)
        return {'mesh': mesh,
                'label': hard,
                'soft_label': soft,
                'edge_features': (features - self.mean) / self.std}

    def __len__(self):
        return self.size

    @staticmethod
    def get_seg_files(paths, seg_dir, seg_ext='.seg'):
        """Pair each mesh path with its label file in seg_dir (must exist)."""
        seg_files = []
        for mesh_path in paths:
            base = os.path.splitext(os.path.basename(mesh_path))[0]
            seg_file = os.path.join(seg_dir, base + seg_ext)
            assert(os.path.isfile(seg_file))
            seg_files.append(seg_file)
        return seg_files

    @staticmethod
    def get_n_segs(classes_file, seg_files):
        """Return (zero-based class ids, offset of the smallest raw label).

        On first use, scans every label file and caches the unique labels
        to classes_file.
        """
        if not os.path.isfile(classes_file):
            all_segs = np.array([], dtype='float64')
            for seg in seg_files:
                all_segs = np.concatenate((all_segs, read_seg(seg)))
            np.savetxt(classes_file, np.unique(all_segs), fmt='%d')
        classes = np.loadtxt(classes_file)
        offset = classes[0]
        return classes - offset, offset

    @staticmethod
    def make_dataset(path):
        """List every mesh file found under path (recursively)."""
        meshes = []
        assert os.path.isdir(path), '%s is not a valid directory' % path
        for root, _, fnames in sorted(os.walk(path)):
            for fname in fnames:
                if is_mesh_file(fname):
                    meshes.append(os.path.join(root, fname))
        return meshes
80+
81+
82+
def read_seg(seg):
    """Load hard per-edge segmentation labels from a whitespace text file.

    Args:
        seg: path to a label file (one numeric label per line/edge).

    Returns:
        1-D float64 numpy array of labels.
    """
    # Pass the path directly so numpy opens *and closes* the file itself;
    # the previous open(seg, 'r') leaked the file handle.
    return np.loadtxt(seg, dtype='float64')
85+
86+
87+
def read_sseg(sseg_file):
    """Load soft segmentation labels and binarize them to int32 {0, 1}.

    Any strictly positive raw value becomes 1; everything else becomes 0.
    """
    raw = read_seg(sseg_file)
    return (raw > 0).astype(np.int32)

docs/imgs/T18.png

253 KB
Loading

docs/imgs/T252.png

171 KB
Loading

docs/imgs/T76.png

304 KB
Loading

docs/imgs/alien.gif

811 KB
Loading

docs/imgs/coseg_alien.png

186 KB
Loading

docs/imgs/coseg_chair.png

341 KB
Loading

0 commit comments

Comments
 (0)