forked from yanx27/Pointnet_Pointnet2_pytorch
Commit
Showing 27 changed files with 2,420 additions and 1,162 deletions.
@@ -1,67 +1,103 @@
import numpy as np
import warnings
import h5py
import os
from torch.utils.data import Dataset

warnings.filterwarnings('ignore')

def load_h5(h5_filename):
    # Read one ModelNet40 HDF5 shard: points, class labels, and segmentation labels (returned empty here).
    f = h5py.File(h5_filename, 'r')
    data = f['data'][:]
    label = f['label'][:]
    seg = []
    return (data, label, seg)

def load_data(dir, classification=False):
    # Load the five training shards and two test shards of the HDF5 ModelNet40 release.
    data_train0, label_train0, Seglabel_train0 = load_h5(dir + 'ply_data_train0.h5')
    data_train1, label_train1, Seglabel_train1 = load_h5(dir + 'ply_data_train1.h5')
    data_train2, label_train2, Seglabel_train2 = load_h5(dir + 'ply_data_train2.h5')
    data_train3, label_train3, Seglabel_train3 = load_h5(dir + 'ply_data_train3.h5')
    data_train4, label_train4, Seglabel_train4 = load_h5(dir + 'ply_data_train4.h5')
    data_test0, label_test0, Seglabel_test0 = load_h5(dir + 'ply_data_test0.h5')
    data_test1, label_test1, Seglabel_test1 = load_h5(dir + 'ply_data_test1.h5')
    train_data = np.concatenate([data_train0, data_train1, data_train2, data_train3, data_train4])
    train_label = np.concatenate([label_train0, label_train1, label_train2, label_train3, label_train4])
    train_Seglabel = np.concatenate([Seglabel_train0, Seglabel_train1, Seglabel_train2, Seglabel_train3, Seglabel_train4])
    test_data = np.concatenate([data_test0, data_test1])
    test_label = np.concatenate([label_test0, label_test1])
    test_Seglabel = np.concatenate([Seglabel_test0, Seglabel_test1])

    if classification:
        return train_data, train_label, test_data, test_label
    else:
        return train_data, train_Seglabel, test_data, test_Seglabel

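# Example call (hypothetical directory path): the loader expects a folder holding the
# ply_data_train{0..4}.h5 and ply_data_test{0,1}.h5 shards listed above, e.g.
#   train_x, train_y, test_x, test_y = load_data('./data/modelnet40_ply_hdf5_2048/', classification=True)
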
def pc_normalize(pc):
    # Center the cloud at the origin and scale it to fit inside the unit sphere.
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    pc = pc / m
    return pc

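# Illustrative check of the invariant: after normalization no point lies farther
# than 1 from the origin.
#   pts = pc_normalize(np.random.rand(1024, 3).astype(np.float32))
#   assert np.sqrt((pts ** 2).sum(axis=1)).max() <= 1.0 + 1e-6
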
def farthest_point_sample(point, npoint):
    """
    Iterative farthest point sampling on the xyz columns of a point cloud.
    Input:
        point: pointcloud data, [N, D]
        npoint: number of samples
    Return:
        point: sampled pointcloud, [npoint, D]
    """
    N, D = point.shape
    xyz = point[:, :3]
    centroids = np.zeros((npoint,))
    distance = np.ones((N,)) * 1e10
    farthest = np.random.randint(0, N)
    for i in range(npoint):
        centroids[i] = farthest                      # record the current farthest point
        centroid = xyz[farthest, :]
        dist = np.sum((xyz - centroid) ** 2, -1)     # squared distance to the new centroid
        mask = dist < distance
        distance[mask] = dist[mask]                  # keep the distance to the nearest chosen point
        farthest = np.argmax(distance, -1)           # next sample: point farthest from all chosen so far
    point = point[centroids.astype(np.int32)]
    return point

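# Example (illustrative): the sampling keys off the xyz columns only but keeps all
# D input channels, so a 2048 x 6 cloud (xyz + normals) comes back as 1024 x 6.
#   cloud = np.random.rand(2048, 6).astype(np.float32)
#   sampled = farthest_point_sample(cloud, 1024)   # shape (1024, 6)
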
class ModelNetDataLoader(Dataset):
    def __init__(self, root, npoint=1024, split='train', uniform=False, normal_channel=True, cache_size=15000):
        self.root = root
        self.npoints = npoint
        self.uniform = uniform
        self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')

        self.cat = [line.rstrip() for line in open(self.catfile)]
        self.classes = dict(zip(self.cat, range(len(self.cat))))
        self.normal_channel = normal_channel

        shape_ids = {}
        shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
        shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]

        assert (split == 'train' or split == 'test')
        shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
        # list of (shape_name, shape_txt_file_path) tuples
        self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt')
                         for i in range(len(shape_ids[split]))]
        print('The size of %s data is %d' % (split, len(self.datapath)))

        self.cache_size = cache_size  # how many samples to cache in memory
        self.cache = {}  # from index to (point_set, cls) tuple

    def __len__(self):
        return len(self.datapath)

    def rotate_point_cloud_by_angle(self, data, rotation_angle):
        """
        Rotate a point cloud around the up (y) axis by a given angle.
        Kept as an optional augmentation helper; not called by __getitem__.
        :param data: Nx3 array, original point cloud
        :param rotation_angle: rotation angle in radians
        :return: Nx3 array, rotated point cloud
        """
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array([[cosval, 0, sinval],
                                    [0, 1, 0],
                                    [-sinval, 0, cosval]])
        rotated_data = np.dot(data, rotation_matrix)
        return rotated_data

    def _get_item(self, index):
        if index in self.cache:
            point_set, cls = self.cache[index]
        else:
            fn = self.datapath[index]
            cls = self.classes[self.datapath[index][0]]
            cls = np.array([cls]).astype(np.int32)
            point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
            if self.uniform:
                point_set = farthest_point_sample(point_set, self.npoints)
            else:
                point_set = point_set[0:self.npoints, :]

            point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])

            if not self.normal_channel:
                point_set = point_set[:, 0:3]

            if len(self.cache) < self.cache_size:
                self.cache[index] = (point_set, cls)

        return point_set, cls

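    # Shape note (assuming the modelnet40_normal_resampled txt files, which store
    # x, y, z, nx, ny, nz per row): _get_item returns point_set of shape
    # (npoints, 6) when normal_channel=True, (npoints, 3) otherwise, and cls of
    # shape (1,) as int32.
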
    def __getitem__(self, index):
        return self._get_item(index)


if __name__ == '__main__':
    import torch

    data = ModelNetDataLoader('/data/modelnet40_normal_resampled/', split='train', uniform=False, normal_channel=True)
    DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True)
    for point, label in DataLoader:
        print(point.shape)
        print(label.shape)
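
A minimal sketch of how the batches might be consumed in a classification loop (illustrative only; `model` is a placeholder, and the label tensor is assumed to collate to shape (B, 1) int32 as produced by _get_item):

    import torch
    from torch.utils.data import DataLoader

    loader = DataLoader(data, batch_size=12, shuffle=True, num_workers=4)
    for points, label in loader:
        points = points.float()           # (B, npoints, C)
        target = label[:, 0].long()       # (B,) class indices for cross-entropy
        # logits = model(points)
        # loss = torch.nn.functional.cross_entropy(logits, target)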