Commit e271e6b ("Initial commit", parent b431402)

47 files changed: +2714 −3 lines

README.md

Lines changed: 171 additions & 3 deletions
# SMD-Nets: Stereo Mixture Density Networks

![Alt text](./images/architecture2-1.jpg "architecture")

This repository contains a Pytorch implementation of "SMD-Nets: Stereo Mixture Density Networks" (CVPR 2021) by [Fabio Tosi](https://vision.disi.unibo.it/~ftosi/), [Yiyi Liao](https://www.is.mpg.de/person/yliao), [Carolin Schmitt](https://avg.is.tuebingen.mpg.de/person/cschmitt) and [Andreas Geiger](http://www.cvlibs.net/).
**Contributions:**

* A novel learning framework for stereo matching that exploits compactly parameterized bimodal mixture densities as output representation and can be trained using a simple likelihood-based loss function (see the sketch after this list). Our formulation lets us avoid bleeding artifacts at depth discontinuities and provides a measure of aleatoric uncertainty.

* A continuous function formulation aimed at estimating disparities at arbitrary spatial resolution with a constant memory footprint.

* A new large-scale synthetic binocular stereo dataset with ground truth disparities at 3840×2160 resolution, comprising photo-realistic renderings of indoor and outdoor environments.
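As a concrete illustration of the likelihood-based loss, below is a minimal PyTorch sketch of the negative log-likelihood of a bimodal Laplacian mixture. The function name, argument names and ```eps``` constant are our own choices for illustration; see the repository's ```lib``` folder for the actual loss implementation.

```python
import torch

def bimodal_laplacian_nll(pi, mu1, b1, mu2, b2, d_gt, eps=1e-8):
    """NLL of the ground-truth disparity d_gt under a two-mode Laplacian
    mixture. All arguments are tensors of the same shape: pi is the weight
    of the first mode, mu1/mu2 the modes, b1/b2 positive scale parameters."""
    # Per-mode log-density of a Laplacian: -log(2b) - |d - mu| / b
    log_p1 = torch.log(pi + eps) - torch.log(2 * b1) - torch.abs(d_gt - mu1) / b1
    log_p2 = torch.log(1 - pi + eps) - torch.log(2 * b2) - torch.abs(d_gt - mu2) / b2
    # Combine the two components with logsumexp for numerical stability
    return -torch.logsumexp(torch.stack([log_p1, log_p2], dim=0), dim=0).mean()
```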
For more details, please check:

[Paper](http://www.cvlibs.net/publications/Tosi2021CVPR.pdf)

[Supplementary material](http://www.cvlibs.net/publications/Tosi2021CVPR_supplementary.pdf)

If you find this code useful in your research, please cite:
```bibtex
@INPROCEEDINGS{Tosi2021CVPR,
  author = {Fabio Tosi and Yiyi Liao and Carolin Schmitt and Andreas Geiger},
  title = {SMD-Nets: Stereo Mixture Density Networks},
  booktitle = {Conference on Computer Vision and Pattern Recognition (CVPR)},
  year = {2021}
}
```
## Requirements

This code was tested with Python 3.8, Pytorch 1.8, CUDA 11.2 and Ubuntu 20.04. <br>All our experiments were performed on a single NVIDIA Titan V100 GPU.<br>Requirements can be installed using the following script:

```shell
pip install -r requirements.txt
```
## Datasets

We create our synthetic dataset, UnrealStereo4K, using the popular game engine [Unreal Engine](https://www.unrealengine.com/en-US/) combined with the open-source plugin [UnrealCV](https://unrealcv.org/).

### UnrealStereo4K

Our photo-realistic synthetic passive binocular UnrealStereo4K dataset consists of images of 8 scenes, including indoor and outdoor environments. We rendered stereo pairs at 3840×2160 resolution for each scene, with pixel-accurate ground truth (aligned with both the left and the right images!).

For the active monocular UnrealStereo4K dataset, we instead render 4 scenes using the intrinsic matrix of the IR camera on our structured light sensor. We then warp the reference dot pattern to each image to simulate the IR camera.

<u>***Both datasets will be publicly available soon!***</u>
### RealActive4K

Our real-world active dataset consists of 2570 images of an indoor room, captured with a Kinect-like structured light sensor at 4112×3008 resolution. To obtain pseudo ground truth as co-supervision during training, we perform block matching with a left-right consistency check (see the sketch below).

<u>***The dataset will be publicly available soon!***</u>
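As a rough sketch of how such pseudo ground truth can be produced, the snippet below runs OpenCV block matching in both directions and keeps only disparities that survive a left-right consistency check. Everything here (function name, parameters, threshold) is our assumption for illustration, not the pipeline used to build the dataset.

```python
import cv2
import numpy as np

def pseudo_gt(left_gray, right_gray, max_disp=256, block=9, thresh=1.0):
    """left_gray/right_gray: uint8 grayscale rectified images."""
    bm = cv2.StereoBM_create(numDisparities=max_disp, blockSize=block)
    flip = lambda a: np.ascontiguousarray(np.fliplr(a))
    # Left disparity, plus right disparity via the flipped-image trick
    disp_l = bm.compute(left_gray, right_gray).astype(np.float32) / 16.0
    disp_r = flip(bm.compute(flip(right_gray), flip(left_gray)).astype(np.float32) / 16.0)
    # Warp the right-view disparity into the left view ...
    h, w = disp_l.shape
    ys, xs = np.mgrid[0:h, 0:w].astype(np.float32)
    map_x = np.clip(xs - disp_l, 0, w - 1)
    disp_r_warped = cv2.remap(disp_r, map_x, ys, cv2.INTER_LINEAR)
    # ... and keep only pixels where both views agree within `thresh`
    valid = (disp_l > 0) & (np.abs(disp_l - disp_r_warped) < thresh)
    return np.where(valid, disp_l, 0.0)
```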
## Training

All training and testing scripts are provided in the ```scripts``` folder. <br>
As an example, use the following command to train SMD-Nets on our UnrealStereo4K dataset:

```shell
python apps/train.py --dataroot $dataroot \
                     --checkpoints_path $checkpoints_path \
                     --training_file $training_file \
                     --testing_file $testing_file \
                     --results_path $results_path \
                     --mode $mode \
                     --name $name \
                     --batch_size $batch_size \
                     --num_epoch $num_epoch \
                     --learning_rate $learning_rate \
                     --gamma $gamma \
                     --crop_height $crop_height \
                     --crop_width $crop_width \
                     --num_sample_inout $num_sample_inout \
                     --aspect_ratio $aspect_ratio \
                     --sampling $sampling \
                     --output_representation $output_representation \
                     --backbone $backbone
```

For a detailed description of the training options, please take a look at ```lib/options.py```.<br>
To monitor and visualize the training process, you can start a TensorBoard session with:

```shell
tensorboard --logdir checkpoints
```
## Evaluation

Use the following command to evaluate the trained SMD-Nets on our UnrealStereo4K dataset:

```shell
python apps/test.py --dataroot $dataroot \
                    --testing_file $testing_file \
                    --results_path $results_path \
                    --mode $mode \
                    --batch_size 1 \
                    --superes_factor $superes_factor \
                    --aspect_ratio $aspect_ratio \
                    --output_representation $output_representation \
                    --load_checkpoint_path $checkpoints_path \
                    --backbone $backbone
```

**Warning!** Computing the soft edge error (SEE) on the KITTI dataset requires instance segmentation maps from the KITTI 2015 dataset.

<u>**Stereo Ultra High-Resolution**</u>: if you want to estimate a disparity map at arbitrary spatial resolution given a low-resolution stereo pair at testing time, simply use a different value for the ```superes_factor``` parameter (e.g., 2, 4, 8, up to 32!). Below is a comparison of our model using the PSMNet backbone at 128Mpx resolution (top) and the original PSMNet at 0.5Mpx resolution (bottom), both taking stereo pairs at 0.5Mpx resolution as input.

<p align="center">
<img src="./images/super_resolution.jpg" width="700" />
</p>
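This ultra-high-resolution behaviour comes from the continuous formulation: mixture parameters are predicted at arbitrary sub-pixel query points rather than on a fixed output grid. Below is a minimal PyTorch sketch of the idea; the names, shapes and ```head_mlp``` module are our assumptions, not the repository's API.

```python
import torch.nn.functional as F

def query_mixture_params(features, head_mlp, pts_xy):
    """features: (B, C, H, W) stereo backbone output.
    pts_xy: (B, N, 2) continuous query coordinates, normalized to [-1, 1].
    Returns (B, N, K) mixture parameters, one set per query point."""
    grid = pts_xy.unsqueeze(2)                           # (B, N, 1, 2)
    feat = F.grid_sample(features, grid, mode='bilinear',
                         align_corners=True)             # (B, C, N, 1)
    feat = feat.squeeze(-1).permute(0, 2, 1)             # (B, N, C)
    # Output resolution is set by how densely pts_xy is sampled, so the
    # feature maps (and memory footprint) stay fixed while the output grows.
    return head_mlp(feat)
```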
## Pretrained models

You can download models pre-trained on our UnrealStereo4K dataset from the following links:

* [PSMNet + SMD Head](https://drive.google.com/file/d/1lDguePc4yVnVjwxxRhez3wgLhzfOJkyP/view?usp=sharing)

* [PSMNet + SMD Head](https://drive.google.com/file/d/1PHICTx08m3kIxNQHmIrgoUFfrSyLbTiQ/view?usp=sharing) (fine-tuned on KITTI)

* [HSMNet + SMD Head](https://drive.google.com/file/d/1NG7yX8YGDGemKm2M8TNceSxlkMhEvOBU/view?usp=sharing)
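To use one of these checkpoints programmatically, here is a minimal loading sketch mirroring ```apps/test.py```; the checkpoint path is a placeholder, and we assume ```SMDHead``` is importable from ```lib.model``` as in that script.

```python
import torch
from lib.options import BaseOptions
from lib.model import SMDHead

opt = BaseOptions().parse()
device = torch.device('cuda:0')

# Build the network and restore the downloaded weights, as apps/test.py does
net = SMDHead(opt).to(device=device)
ckpt = torch.load('checkpoints/psmnet_smd_head.pth', map_location=device)
net.load_state_dict(ckpt['state_dict'])
net.eval()
```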
## Qualitative results

**Disparity Visualization.** Some qualitative results of the proposed SMD-Nets using PSMNet as stereo backbone. From left to right: the input image from the UnrealStereo4K test set, the predicted disparity and the corresponding error map. Please zoom in to better perceive details near depth boundaries.

<p float="left">
  <img src="./images/img.jpg" width="290" />
  <img src="./images/pred.jpg" width="290" />
  <img src="./images/error.jpg" width="290" />
</p>
<p float="left">
  <img src="./images/img1.jpg" width="290" />
  <img src="./images/pred1.jpg" width="290" />
  <img src="./images/error1.jpg" width="290" />
</p>
<p float="left">
  <img src="./images/img2.jpg" width="290" />
  <img src="./images/pred2.jpg" width="290" />
  <img src="./images/error2.jpg" width="290" />
</p>
**Point Cloud Visualization.** Below we show point cloud visualizations on UnrealStereo4K for both the passive binocular stereo and the active depth datasets, adopting HSMNet as backbone. From left to right: the reference image, the results obtained using standard disparity regression (i.e., a disparity point estimate), a unimodal Laplacian distribution, and our bimodal Laplacian mixture distribution. Note that our bimodal representation notably alleviates bleeding artifacts near object boundaries compared to both disparity regression and the unimodal formulation.

<p float="left">
  <img src="./images/img4.jpg" width="220" />
  <img src="./images/pcl_hsm_l1_passive.jpg" width="220" />
  <img src="./images/pcl_hsm_unimodal_passive.jpg" width="220" />
  <img src="./images/pcl_hsm_ours_passive.jpg" width="220" />
</p>

<p float="left">
  <img src="./images/img3.jpg" width="250" />
  <img src="./images/pcl_hsm_l1_active.jpg" width="210" />
  <img src="./images/pcl_hsm_unimodal_active.jpg" width="210" />
  <img src="./images/pcl_hsm_ours_active.jpg" width="210" />
</p>
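For reference, one simple way to collapse the bimodal output into the point estimates visualized above is to take the mode with the larger mixture weight. The sketch below shows that selection rule as we understand it, without claiming it matches the paper's exact inference procedure.

```python
import torch

def dominant_mode(pi, mu1, mu2):
    # Pick, per pixel, the mean of the Laplacian component with the larger
    # mixture weight. Unlike a weighted average of mu1 and mu2, this never
    # blends the two modes, which avoids bleeding across depth boundaries.
    return torch.where(pi >= 0.5, mu1, mu2)
```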
## Contacts

For questions, please send an email to [email protected]

## Acknowledgements

We thank the authors who shared the code of their work. In particular:

* Jia-Ren Chang for providing the code of [PSMNet](https://github.com/JiaRenChang/PSMNet).
* Gengshan Yang for providing the code of [HSMNet](https://github.com/gengshan-y/high-res-stereo).
* Clement Godard for providing the code of [Monodepth](https://github.com/mrharicot/monodepth) (extended to Stereodepth).
* Shunsuke Saito for providing the code of [PIFu](https://github.com/shunsukesaito/PIFu).

apps/test.py

Lines changed: 46 additions & 0 deletions
```python
import sys
import os

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

import torch
from torch.utils.data import DataLoader
from lib.options import BaseOptions
from lib.data import Loader
from lib.model import *
from lib.evaluation_utils import *

# get options
opt = BaseOptions().parse()

def test(opt):
    # set cuda
    cuda = torch.device('cuda:%d' % opt.gpu_id)

    test_dataset = Loader(opt, phase='test')

    test_data_loader = DataLoader(test_dataset,
                                  batch_size=1, shuffle=False,
                                  num_workers=opt.num_threads, pin_memory=opt.pin_memory)
    print('test data size: ', len(test_data_loader))

    # create net
    net = SMDHead(opt).to(device=cuda)

    def set_eval():
        net.eval()

    # load checkpoints
    if opt.load_checkpoint_path is not None:
        print('loading weights ...', opt.load_checkpoint_path)
        net.load_state_dict(torch.load(opt.load_checkpoint_path, map_location=cuda)['state_dict'])

    os.makedirs('%s' % (opt.results_path), exist_ok=True)

    with torch.no_grad():
        set_eval()
        validation(opt, net, cuda, test_dataset, num_gen_test=len(test_dataset), save_imgs=True)

if __name__ == '__main__':
    test(opt)
```
