
Commit e101947: "updated codes for classification"
1 parent b4e7951

File tree: 4 files changed (+268, -162 lines)


README.md

Lines changed: 70 additions & 36 deletions
````diff
@@ -3,6 +3,16 @@
 This repo is an implementation of [PointNet](http://openaccess.thecvf.com/content_cvpr_2017/papers/Qi_PointNet_Deep_Learning_CVPR_2017_paper.pdf) and [PointNet++](http://papers.nips.cc/paper/7095-pointnet-deep-hierarchical-feature-learning-on-point-sets-in-a-metric-space.pdf) in PyTorch.
 
 ## Update
+**2021/03/20:** Updated the classification code, including:
+
+(1) Added code for training on the **ModelNet10** dataset; use the setting ``--num_category 10``.
+
+(2) Added code for running on CPU only; use the setting ``--use_cpu``.
+
+(3) Added code for offline data preprocessing to accelerate training; use the setting ``--process_data``.
+
+(4) Added code for training with uniform sampling; use the setting ``--use_uniform_sample``.
+
 **2019/11/26:**
 
 (1) Fixed some errors in the previous code and added data augmentation tricks. Classification with only 1024 points can now achieve 92.8%!
````
````diff
@@ -11,33 +21,37 @@ This repo is an implementation of [PointNet](http://openaccess.thecvf.com/content_
 
 (3) Organized all models into the `./models` directory for easy use.
 
-If you find this repo useful in your research, please consider following and citing our other works:
-```
-@InProceedings{yan2020pointasnl,
-  title={PointASNL: Robust Point Clouds Processing using Nonlocal Neural Networks with Adaptive Sampling},
-  author={Yan, Xu and Zheng, Chaoda and Li, Zhen and Wang, Sheng and Cui, Shuguang},
-  journal={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
-  year={2020}
-}
-```
-```
-@InProceedings{yan2021sparse,
-  title={Sparse Single Sweep LiDAR Point Cloud Segmentation via Learning Contextual Shape Priors from Scene Completion},
-  author={Yan, Xu and Gao, Jiantao and Li, Jie and Zhang, Ruimao, and Li, Zhen and Huang, Rui and Cui, Shuguang},
-  journal={AAAI Conference on Artificial Intelligence ({AAAI})},
-  year={2021}
-}
-```
-## Classification
+
+## Classification (ModelNet10/40)
 ### Data Preparation
 Download the aligned **ModelNet** [here](https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip) and save it in `data/modelnet40_normal_resampled/`.
 
 ### Run
-```
-## Check model in ./models
-## E.g. pointnet2_msg
-python train_cls.py --model pointnet2_cls_msg --normal --log_dir pointnet2_cls_msg
-python test_cls.py --normal --log_dir pointnet2_cls_msg
+You can run the different modes with the following commands.
+* If you want offline data preprocessing, add `--process_data` on the first run.
+* If you want to train on ModelNet10, add `--num_category 10`.
+```shell
+# ModelNet40
+## Select different models in ./models
+
+## e.g., pointnet2_ssg without normal features
+python train_classification.py --model pointnet2_cls_ssg --log_dir pointnet2_cls_ssg
+python test_classification.py --log_dir pointnet2_cls_ssg
+
+## e.g., pointnet2_ssg with normal features
+python train_classification.py --model pointnet2_cls_ssg --use_normals --log_dir pointnet2_cls_ssg_normal
+python test_classification.py --use_normals --log_dir pointnet2_cls_ssg_normal
+
+## e.g., pointnet2_ssg with uniform sampling
+python train_classification.py --model pointnet2_cls_ssg --use_uniform_sample --log_dir pointnet2_cls_ssg_fps
+python test_classification.py --use_uniform_sample --log_dir pointnet2_cls_ssg_fps
+
+# ModelNet10
+## Same settings as ModelNet40; just add --num_category 10
+
+## e.g., pointnet2_ssg without normal features
+python train_classification.py --model pointnet2_cls_ssg --log_dir pointnet2_cls_ssg --num_category 10
+python test_classification.py --log_dir pointnet2_cls_ssg --num_category 10
 ```
 
 ### Performance
````
````diff
@@ -57,7 +71,7 @@ Download alignment **ShapeNet** [here](https://shapenet.cs.stanford.edu/media/sh
 ### Run
 ```
 ## Check model in ./models
-## E.g. pointnet2_msg
+## e.g., pointnet2_msg
 python train_partseg.py --model pointnet2_part_seg_msg --normal --log_dir pointnet2_part_seg_msg
 python test_partseg.py --normal --log_dir pointnet2_part_seg_msg
 ```
````
````diff
@@ -73,29 +87,21 @@ python test_partseg.py --normal --log_dir pointnet2_part_seg_msg
 
 ## Semantic Segmentation
 ### Data Preparation
-Download 3D indoor parsing dataset (**S3DIS**) [here](http://buildingparser.stanford.edu/dataset.html) and save in `data/Stanford3dDataset_v1.2_Aligned_Version/`.
+Download the 3D indoor parsing dataset (**S3DIS**) [here](http://buildingparser.stanford.edu/dataset.html) and save it in `data/s3dis/Stanford3dDataset_v1.2_Aligned_Version/`.
 ```
 cd data_utils
 python collect_indoor3d_data.py
 ```
-Processed data will save in `data/stanford_indoor3d/`.
+The processed data will be saved in `data/s3dis/stanford_indoor3d/`.
 ### Run
 ```
 ## Check model in ./models
-## E.g. pointnet2_ssg
+## e.g., pointnet2_ssg
 python train_semseg.py --model pointnet2_sem_seg --test_area 5 --log_dir pointnet2_sem_seg
 python test_semseg.py --log_dir pointnet2_sem_seg --test_area 5 --visual
 ```
 Visualization results will be saved in `log/sem_seg/pointnet2_sem_seg/visual/`, and you can visualize these .obj files with [MeshLab](http://www.meshlab.net/).
-### Performance on sub-points of raw dataset (processed by official PointNet [Link](https://shapenet.cs.stanford.edu/media/indoor3d_sem_seg_hdf5_data.zip))
-|Model | Class avg IoU |
-|--|--|
-| PointNet (Official) | 41.1|
-| PointNet (Pytorch) | 48.9|
-| PointNet2 (Official) | N/A |
-| PointNet2_ssg (Pytorch) | **53.2**|
-### Performance on raw dataset
-still on testing...
+
 
 ## Visualization
 ### Using show3d_balls.py
````
````diff
@@ -121,3 +127,31 @@ python show3d_balls.py
 Ubuntu 16.04 <br>
 Python 3.6.7 <br>
 Pytorch 1.1.0
+
+
+## Citation
+If you find this repo useful in your research, please consider citing it and our other works:
+```
+@article{Pytorch_Pointnet_Pointnet2,
+  Author = {Xu Yan},
+  Title = {Pointnet/Pointnet++ Pytorch},
+  Journal = {https://github.com/yanx27/Pointnet_Pointnet2_pytorch},
+  Year = {2019}
+}
+```
+```
+@InProceedings{yan2020pointasnl,
+  title={PointASNL: Robust Point Clouds Processing using Nonlocal Neural Networks with Adaptive Sampling},
+  author={Yan, Xu and Zheng, Chaoda and Li, Zhen and Wang, Sheng and Cui, Shuguang},
+  journal={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
+  year={2020}
+}
+```
+```
+@InProceedings{yan2021sparse,
+  title={Sparse Single Sweep LiDAR Point Cloud Segmentation via Learning Contextual Shape Priors from Scene Completion},
+  author={Yan, Xu and Gao, Jiantao and Li, Jie and Zhang, Ruimao and Li, Zhen and Huang, Rui and Cui, Shuguang},
+  journal={AAAI Conference on Artificial Intelligence ({AAAI})},
+  year={2021}
+}
+```
````
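Beyond the diff itself, it may help to see how the new flags reach the data loader, since `ModelNetDataLoader` below reads its settings from a parsed `args` object. A minimal argparse sketch: the flag names come from the README and loader code above, but the defaults, help strings, and exact parser layout in `train_classification.py` are assumptions, not part of this commit.

```python
# Sketch of CLI flags matching those referenced in the diff (details assumed).
import argparse

def parse_args():
    parser = argparse.ArgumentParser('training')
    parser.add_argument('--use_cpu', action='store_true', help='run on CPU only')
    parser.add_argument('--num_category', default=40, type=int, choices=[10, 40],
                        help='train on ModelNet10 or ModelNet40')
    parser.add_argument('--num_point', default=1024, type=int, help='points per cloud')
    parser.add_argument('--use_normals', action='store_true', help='use normal features')
    parser.add_argument('--process_data', action='store_true', help='offline data preprocessing')
    parser.add_argument('--use_uniform_sample', action='store_true', help='use uniform (FPS) sampling')
    parser.add_argument('--model', default='pointnet2_cls_ssg', help='model name in ./models')
    parser.add_argument('--log_dir', type=str, help='experiment log directory')
    return parser.parse_args()
```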

data_utils/ModelNetDataLoader.py

Lines changed: 73 additions & 29 deletions
```diff
@@ -1,9 +1,17 @@
+'''
+@author: Xu Yan
+@file: ModelNet.py
+@time: 2021/3/19 15:51
+'''
+import os
 import numpy as np
 import warnings
-import os
+import pickle
+
+from tqdm import tqdm
 from torch.utils.data import Dataset
-warnings.filterwarnings('ignore')
 
+warnings.filterwarnings('ignore')
 
 
 def pc_normalize(pc):
```
```diff
@@ -13,6 +21,7 @@ def pc_normalize(pc):
     pc = pc / m
     return pc
 
+
 def farthest_point_sample(point, npoint):
     """
     Input:
```
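Only the tails of these two helpers are visible in the hunks above. For reference, here is a sketch of both functions, consistent with the visible lines (`pc = pc / m`, `point = point[centroids.astype(np.int32)]`) but reconstructed rather than taken from this commit:

```python
import numpy as np

def pc_normalize(pc):
    """Center the cloud at its centroid and scale it into the unit sphere."""
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    pc = pc / m
    return pc

def farthest_point_sample(point, npoint):
    """Greedily pick npoint points, each farthest from the set chosen so far."""
    N, D = point.shape
    xyz = point[:, :3]
    centroids = np.zeros((npoint,))
    distance = np.ones((N,)) * 1e10
    farthest = np.random.randint(0, N)
    for i in range(npoint):
        centroids[i] = farthest                       # accept the current farthest point
        centroid = xyz[farthest, :]
        dist = np.sum((xyz - centroid) ** 2, axis=1)  # squared distance to it
        mask = dist < distance
        distance[mask] = dist[mask]                   # distance to nearest chosen point
        farthest = np.argmax(distance)                # next pick: max of those minima
    point = point[centroids.astype(np.int32)]
    return point
```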
```diff
@@ -36,68 +45,103 @@ def farthest_point_sample(point, npoint):
     point = point[centroids.astype(np.int32)]
     return point
 
+
 class ModelNetDataLoader(Dataset):
-    def __init__(self, root, npoint=1024, split='train', uniform=False, normal_channel=True, cache_size=15000):
+    def __init__(self, root, args, split='train', process_data=False):
         self.root = root
-        self.npoints = npoint
-        self.uniform = uniform
-        self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
+        self.npoints = args.num_point
+        self.process_data = process_data
+        self.uniform = args.use_uniform_sample
+        self.use_normals = args.use_normals
+        self.num_category = args.num_category
+
+        if self.num_category == 10:
+            self.catfile = os.path.join(self.root, 'modelnet10_shape_names.txt')
+        else:
+            self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
 
         self.cat = [line.rstrip() for line in open(self.catfile)]
         self.classes = dict(zip(self.cat, range(len(self.cat))))
-        self.normal_channel = normal_channel
 
         shape_ids = {}
-        shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
-        shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]
+        if self.num_category == 10:
+            shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_train.txt'))]
+            shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_test.txt'))]
+        else:
+            shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
+            shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]
 
         assert (split == 'train' or split == 'test')
         shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
-        # list of (shape_name, shape_txt_file_path) tuple
         self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
                          in range(len(shape_ids[split]))]
-        print('The size of %s data is %d'%(split,len(self.datapath)))
+        print('The size of %s data is %d' % (split, len(self.datapath)))
 
-        self.cache_size = cache_size  # how many data points to cache in memory
-        self.cache = {}  # from index to (point_set, cls) tuple
+        if self.uniform:
+            self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, split, self.npoints))
+        else:
+            self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts.dat' % (self.num_category, split, self.npoints))
+
+        if self.process_data:
+            if not os.path.exists(self.save_path):
+                print('Processing data %s (only running in the first time)...' % self.save_path)
+                self.list_of_points = [None] * len(self.datapath)
+                self.list_of_labels = [None] * len(self.datapath)
+
+                for index in tqdm(range(len(self.datapath)), total=len(self.datapath)):
+                    fn = self.datapath[index]
+                    cls = self.classes[self.datapath[index][0]]
+                    cls = np.array([cls]).astype(np.int32)
+                    point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
+
+                    if self.uniform:
+                        point_set = farthest_point_sample(point_set, self.npoints)
+                    else:
+                        point_set = point_set[0:self.npoints, :]
+
+                    self.list_of_points[index] = point_set
+                    self.list_of_labels[index] = cls
+
+                with open(self.save_path, 'wb') as f:
+                    pickle.dump([self.list_of_points, self.list_of_labels], f)
+            else:
+                print('Load processed data from %s...' % self.save_path)
+                with open(self.save_path, 'rb') as f:
+                    self.list_of_points, self.list_of_labels = pickle.load(f)
```
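The preprocessing branch above serializes all sampled clouds and labels into one pickle file whose name encodes the category count, split, and point count. A sketch of inspecting such a cache outside the loader; the concrete path and the 6-column shape are assumptions based on the `save_path` pattern and the xyz-plus-normals format of the `_normal_resampled` data:

```python
import pickle

# Illustrative path following the save_path pattern in the diff:
# modelnet{num_category}_{split}_{npoints}pts[_fps].dat
cache_file = 'data/modelnet40_normal_resampled/modelnet40_train_1024pts.dat'

with open(cache_file, 'rb') as f:
    list_of_points, list_of_labels = pickle.load(f)

print(len(list_of_points))      # one entry per shape in the split
print(list_of_points[0].shape)  # expected (1024, 6): xyz + normals per point
print(list_of_labels[0])        # class index as an np.int32 array of shape (1,)
```

The remainder of the hunk, covering `_get_item` and the `__main__` smoke test, continues below.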
```diff
 
     def __len__(self):
         return len(self.datapath)
 
     def _get_item(self, index):
-        if index in self.cache:
-            point_set, cls = self.cache[index]
+        if self.process_data:
+            point_set, label = self.list_of_points[index], self.list_of_labels[index]
+            point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
         else:
             fn = self.datapath[index]
             cls = self.classes[self.datapath[index][0]]
-            cls = np.array([cls]).astype(np.int32)
+            label = np.array([cls]).astype(np.int32)
             point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
+
             if self.uniform:
                 point_set = farthest_point_sample(point_set, self.npoints)
             else:
-                point_set = point_set[0:self.npoints,:]
+                point_set = point_set[0:self.npoints, :]
 
-            point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
-
-            if not self.normal_channel:
-                point_set = point_set[:, 0:3]
+            if not self.use_normals:
+                point_set = point_set[:, 0:3]
 
-            if len(self.cache) < self.cache_size:
-                self.cache[index] = (point_set, cls)
 
-        return point_set, cls
+        return point_set, label[0]
 
     def __getitem__(self, index):
         return self._get_item(index)
 
 
-
-
 if __name__ == '__main__':
     import torch
 
-    data = ModelNetDataLoader('/data/modelnet40_normal_resampled/',split='train', uniform=False, normal_channel=True,)
+    data = ModelNetDataLoader('/data/modelnet40_normal_resampled/', split='train')
     DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True)
-    for point,label in DataLoader:
+    for point, label in DataLoader:
         print(point.shape)
-        print(label.shape)
+        print(label.shape)
```
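One caveat with the new smoke test: `__init__` now requires an `args` namespace, so the call `ModelNetDataLoader('/data/modelnet40_normal_resampled/', split='train')` as written would raise a `TypeError`. A sketch of a working invocation, using a hypothetical `argparse.Namespace` whose fields mirror what the loader reads; the values are assumptions, not part of the commit:

```python
import argparse
import torch
from data_utils.ModelNetDataLoader import ModelNetDataLoader

# Hypothetical namespace standing in for the parsed CLI flags the loader expects.
args = argparse.Namespace(num_point=1024, use_uniform_sample=False,
                          use_normals=True, num_category=40)

# Path assumes the README's data layout; adjust to your setup.
data = ModelNetDataLoader('data/modelnet40_normal_resampled/', args, split='train')
loader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True)
for point, label in loader:
    print(point.shape)   # (12, 1024, 6) with normals kept
    print(label.shape)   # (12,) since _get_item returns a scalar label
    break
```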
