Skip to content

Commit

Permalink
updated codes for classification
Browse files Browse the repository at this point in the history
  • Loading branch information
yanx27 committed Mar 20, 2021
1 parent b4e7951 commit e101947
Show file tree
Hide file tree
Showing 4 changed files with 268 additions and 162 deletions.
106 changes: 70 additions & 36 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@
This repo is implementation for [PointNet](http://openaccess.thecvf.com/content_cvpr_2017/papers/Qi_PointNet_Deep_Learning_CVPR_2017_paper.pdf) and [PointNet++](http://papers.nips.cc/paper/7095-pointnet-deep-hierarchical-feature-learning-on-point-sets-in-a-metric-space.pdf) in pytorch.

## Update
**2021/03/20:** Update codes for classification, including:

(1) Add codes for training **ModelNet10** dataset. Using setting of ``--num_category 10``.

(2) Add codes for running on CPU only. Using setting of ``--use_cpu``.

(3) Add codes for offline data preprocessing to accelerate training. Using setting of ``--process_data``.

(4) Add codes for training with uniform sampling. Using setting of ``--use_uniform_sample``.

**2019/11/26:**

(1) Fixed some errors in previous codes and added data augmentation tricks. Now classification by only 1024 points can achieve 92.8\%!
Expand All @@ -11,33 +21,37 @@ This repo is implementation for [PointNet](http://openaccess.thecvf.com/content_

(3) Organized all models into `./models` files for easy using.

If you find this repo useful in your research, please consider following and citing our other works:
```
@InProceedings{yan2020pointasnl,
title={PointASNL: Robust Point Clouds Processing using Nonlocal Neural Networks with Adaptive Sampling},
author={Yan, Xu and Zheng, Chaoda and Li, Zhen and Wang, Sheng and Cui, Shuguang},
journal={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
year={2020}
}
```
```
@InProceedings{yan2021sparse,
title={Sparse Single Sweep LiDAR Point Cloud Segmentation via Learning Contextual Shape Priors from Scene Completion},
author={Yan, Xu and Gao, Jiantao and Li, Jie and Zhang, Ruimao, and Li, Zhen and Huang, Rui and Cui, Shuguang},
journal={AAAI Conference on Artificial Intelligence ({AAAI})},
year={2021}
}
```
## Classification

## Classification (ModelNet10/40)
### Data Preparation
Download alignment **ModelNet** [here](https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip) and save in `data/modelnet40_normal_resampled/`.

### Run
```
## Check model in ./models
## E.g. pointnet2_msg
python train_cls.py --model pointnet2_cls_msg --normal --log_dir pointnet2_cls_msg
python test_cls.py --normal --log_dir pointnet2_cls_msg
You can run different modes with following codes.
* If you want to use offline processing of data, you can use `--process_data` in the first run.
* If you want to train on ModelNet10, you can use `--num_category 10`.
```shell
# ModelNet40
## Select different models in ./models

## e.g., pointnet2_ssg without normal features
python train_classification.py --model pointnet2_cls_ssg --log_dir pointnet2_cls_ssg
python test_classification.py --log_dir pointnet2_cls_ssg

## e.g., pointnet2_ssg with normal features
python train_classification.py --model pointnet2_cls_ssg --use_normals --log_dir pointnet2_cls_ssg_normal
python test_classification.py --use_normals --log_dir pointnet2_cls_ssg_normal

## e.g., pointnet2_ssg with uniform sampling
python train_classification.py --model pointnet2_cls_ssg --use_uniform_sample --log_dir pointnet2_cls_ssg_fps
python test_classification.py --use_uniform_sample --log_dir pointnet2_cls_ssg_fps

# ModelNet10
## Similar setting like ModelNet40, just using --num_category 10

## e.g., pointnet2_ssg without normal features
python train_classification.py --model pointnet2_cls_ssg --log_dir pointnet2_cls_ssg --num_category 10
python test_classification.py --log_dir pointnet2_cls_ssg --num_category 10
```

### Performance
Expand All @@ -57,7 +71,7 @@ Download alignment **ShapeNet** [here](https://shapenet.cs.stanford.edu/media/sh
### Run
```
## Check model in ./models
## E.g. pointnet2_msg
## e.g., pointnet2_msg
python train_partseg.py --model pointnet2_part_seg_msg --normal --log_dir pointnet2_part_seg_msg
python test_partseg.py --normal --log_dir pointnet2_part_seg_msg
```
Expand All @@ -73,29 +87,21 @@ python test_partseg.py --normal --log_dir pointnet2_part_seg_msg

## Semantic Segmentation
### Data Preparation
Download 3D indoor parsing dataset (**S3DIS**) [here](http://buildingparser.stanford.edu/dataset.html) and save in `data/Stanford3dDataset_v1.2_Aligned_Version/`.
Download 3D indoor parsing dataset (**S3DIS**) [here](http://buildingparser.stanford.edu/dataset.html) and save in `data/s3dis/Stanford3dDataset_v1.2_Aligned_Version/`.
```
cd data_utils
python collect_indoor3d_data.py
```
Processed data will save in `data/stanford_indoor3d/`.
Processed data will save in `data/s3dis/stanford_indoor3d/`.
### Run
```
## Check model in ./models
## E.g. pointnet2_ssg
## e.g., pointnet2_ssg
python train_semseg.py --model pointnet2_sem_seg --test_area 5 --log_dir pointnet2_sem_seg
python test_semseg.py --log_dir pointnet2_sem_seg --test_area 5 --visual
```
Visualization results will save in `log/sem_seg/pointnet2_sem_seg/visual/` and you can visualize these .obj file by [MeshLab](http://www.meshlab.net/).
### Performance on sub-points of raw dataset (processed by official PointNet [Link](https://shapenet.cs.stanford.edu/media/indoor3d_sem_seg_hdf5_data.zip))
|Model | Class avg IoU |
|--|--|
| PointNet (Official) | 41.1|
| PointNet (Pytorch) | 48.9|
| PointNet2 (Official) |N/A |
| PointNet2_ssg (Pytorch) | **53.2**|
### Performance on raw dataset
still on testing...


## Visualization
### Using show3d_balls.py
Expand All @@ -121,3 +127,31 @@ python show3d_balls.py
Ubuntu 16.04 <br>
Python 3.6.7 <br>
Pytorch 1.1.0


## Citation
If you find this repo useful in your research, please consider citing it and our other works:
```
@article{Pytorch_Pointnet_Pointnet2,
Author = {Xu Yan},
Title = {Pointnet/Pointnet++ Pytorch},
Journal = {https://github.com/yanx27/Pointnet_Pointnet2_pytorch},
Year = {2019}
}
```
```
@InProceedings{yan2020pointasnl,
title={PointASNL: Robust Point Clouds Processing using Nonlocal Neural Networks with Adaptive Sampling},
author={Yan, Xu and Zheng, Chaoda and Li, Zhen and Wang, Sheng and Cui, Shuguang},
journal={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
year={2020}
}
```
```
@InProceedings{yan2021sparse,
title={Sparse Single Sweep LiDAR Point Cloud Segmentation via Learning Contextual Shape Priors from Scene Completion},
author={Yan, Xu and Gao, Jiantao and Li, Jie and Zhang, Ruimao, and Li, Zhen and Huang, Rui and Cui, Shuguang},
journal={AAAI Conference on Artificial Intelligence ({AAAI})},
year={2021}
}
```
102 changes: 73 additions & 29 deletions data_utils/ModelNetDataLoader.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
'''
@author: Xu Yan
@file: ModelNet.py
@time: 2021/3/19 15:51
'''
import os
import numpy as np
import warnings
import os
import pickle

from tqdm import tqdm
from torch.utils.data import Dataset
warnings.filterwarnings('ignore')

warnings.filterwarnings('ignore')


def pc_normalize(pc):
Expand All @@ -13,6 +21,7 @@ def pc_normalize(pc):
pc = pc / m
return pc


def farthest_point_sample(point, npoint):
"""
Input:
Expand All @@ -36,68 +45,103 @@ def farthest_point_sample(point, npoint):
point = point[centroids.astype(np.int32)]
return point


class ModelNetDataLoader(Dataset):
def __init__(self, root, npoint=1024, split='train', uniform=False, normal_channel=True, cache_size=15000):
def __init__(self, root, args, split='train', process_data=False):
self.root = root
self.npoints = npoint
self.uniform = uniform
self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
self.npoints = args.num_point
self.process_data = process_data
self.uniform = args.use_uniform_sample
self.use_normals = args.use_normals
self.num_category = args.num_category

if self.num_category == 10:
self.catfile = os.path.join(self.root, 'modelnet10_shape_names.txt')
else:
self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')

self.cat = [line.rstrip() for line in open(self.catfile)]
self.classes = dict(zip(self.cat, range(len(self.cat))))
self.normal_channel = normal_channel

shape_ids = {}
shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]
if self.num_category == 10:
shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_train.txt'))]
shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_test.txt'))]
else:
shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]

assert (split == 'train' or split == 'test')
shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
# list of (shape_name, shape_txt_file_path) tuple
self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
in range(len(shape_ids[split]))]
print('The size of %s data is %d'%(split,len(self.datapath)))
print('The size of %s data is %d' % (split, len(self.datapath)))

self.cache_size = cache_size # how many data points to cache in memory
self.cache = {} # from index to (point_set, cls) tuple
if self.uniform:
self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, split, self.npoints))
else:
self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts.dat' % (self.num_category, split, self.npoints))

if self.process_data:
if not os.path.exists(self.save_path):
print('Processing data %s (only running in the first time)...' % self.save_path)
self.list_of_points = [None] * len(self.datapath)
self.list_of_labels = [None] * len(self.datapath)

for index in tqdm(range(len(self.datapath)), total=len(self.datapath)):
fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]]
cls = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)

if self.uniform:
point_set = farthest_point_sample(point_set, self.npoints)
else:
point_set = point_set[0:self.npoints, :]

self.list_of_points[index] = point_set
self.list_of_labels[index] = cls

with open(self.save_path, 'wb') as f:
pickle.dump([self.list_of_points, self.list_of_labels], f)
else:
print('Load processed data from %s...' % self.save_path)
with open(self.save_path, 'rb') as f:
self.list_of_points, self.list_of_labels = pickle.load(f)

def __len__(self):
return len(self.datapath)

def _get_item(self, index):
if index in self.cache:
point_set, cls = self.cache[index]
if self.process_data:
point_set, label = self.list_of_points[index], self.list_of_labels[index]
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
else:
fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]]
cls = np.array([cls]).astype(np.int32)
label = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)

if self.uniform:
point_set = farthest_point_sample(point_set, self.npoints)
else:
point_set = point_set[0:self.npoints,:]
point_set = point_set[0:self.npoints, :]

point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])

if not self.normal_channel:
point_set = point_set[:, 0:3]
if not self.use_normals:
point_set = point_set[:, 0:3]

if len(self.cache) < self.cache_size:
self.cache[index] = (point_set, cls)

return point_set, cls
return point_set, label[0]

def __getitem__(self, index):
return self._get_item(index)




if __name__ == '__main__':
import torch

data = ModelNetDataLoader('/data/modelnet40_normal_resampled/',split='train', uniform=False, normal_channel=True,)
data = ModelNetDataLoader('/data/modelnet40_normal_resampled/', split='train')
DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True)
for point,label in DataLoader:
for point, label in DataLoader:
print(point.shape)
print(label.shape)
print(label.shape)
Loading

0 comments on commit e101947

Please sign in to comment.