Aoran Xiao*,
Weihao Xuan*,
Heli Qi,
Yun Xing,
Naoto Yokoya^,
Shijian Lu^
(* indicates co-first authors with equal contributions. ^ indicates the corresponding authors.)
Robust and accurate segmentation of scenes has become a core functionality in various visual recognition and navigation tasks. This has inspired the recent development of Segment Anything Model (SAM), a foundation model for general mask segmentation. However, SAM is largely tailored for single-modal RGB images, limiting its applicability to multi-modal data captured with widely-adopted sensor suites, such as LiDAR plus RGB, depth plus RGB, thermal plus RGB, etc. We develop MM-SAM, an extension and expansion of SAM that supports cross-modal and multi-modal processing for robust and enhanced segmentation with different sensor suites. MM-SAM features two key designs, namely, unsupervised cross-modal transfer and weakly-supervised multi-modal fusion, enabling label-efficient and parameter-efficient adaptation toward various sensor modalities. It addresses three main challenges: 1) adaptation toward diverse non-RGB sensors for single-modal processing, 2) synergistic processing of multi-modal data via sensor fusion, and 3) mask-free training for different downstream tasks. Extensive experiments show that MM-SAM consistently outperforms SAM by large margins, demonstrating its effectiveness and robustness across various sensors and data modalities.
- (2024/9) We released the testing code. Thank you for waiting!
Please clone our project to your local machine and set up the environment with the following commands:
conda create -n mm_sam python=3.10 -y
conda activate mm_sam
cd /your/path/to/destination/mm-sam
(mm_sam) python -m pip install -e .
(mm_sam) conda install -c conda-forge gdal==3.8.3
Note: if your local CUDA version is 11.8, you may need to manually install torch and torchvision by
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
before running python -m pip install -e .
The code has been tested on A100/A6000/V100 GPUs with Python 3.10, CUDA 11.8/12.1, and PyTorch 2.1.2. Other devices and environments may require code updates for compatibility.
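Optionally, you can run the following quick sanity check (a minimal sketch, not part of the repo) to confirm that the expected PyTorch/CUDA stack is visible:
# Optional environment sanity check (illustrative only)
import torch
import torchvision

print("torch:", torch.__version__)              # tested with 2.1.2
print("torchvision:", torchvision.__version__)  # tested with 0.16.2
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))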
Please register the following 3 paths in utilbox/global_config.py (an illustrative example is given below this list):
- EXP_ROOT: the root directory where the checkpoint files of trained models will be saved.
- DATA_ROOT: the root directory where you want to place the datasets.
- PRETRAINED_ROOT: the root directory where you want to place the SAM pretrained model files.
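A minimal sketch of what utilbox/global_config.py could look like (the paths are placeholders; replace them with your own):
# utilbox/global_config.py (illustrative placeholder values)
EXP_ROOT = "/path/to/experiments"        # checkpoints of trained models are saved here
DATA_ROOT = "/path/to/datasets"          # all datasets are placed here
PRETRAINED_ROOT = "/path/to/pretrained"  # SAM pretrained model files are placed here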
Please download SAM pretrained models from this link. You should have the following file structure:
{your PRETRAINED_ROOT}
|___sam_vit_b_01ec64.pth
|___sam_vit_l_0b3195.pth
|___sam_vit_h_4b8939.pth
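If you prefer to script the download, the sketch below (not part of this repo) fetches the checkpoints from the official Segment Anything release URLs into PRETRAINED_ROOT; it assumes the package is installed so that utilbox.global_config is importable:
# Illustrative download helper for the official SAM checkpoints
import os
import urllib.request

from utilbox.global_config import PRETRAINED_ROOT

SAM_URLS = [
    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
]

os.makedirs(PRETRAINED_ROOT, exist_ok=True)
for url in SAM_URLS:
    dst = os.path.join(PRETRAINED_ROOT, os.path.basename(url))
    if not os.path.exists(dst):  # skip checkpoints that were already downloaded
        print(f"Downloading {os.path.basename(url)} ...")
        urllib.request.urlretrieve(url, dst)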
Please follow the instructions below to prepare each dataset.
conda activate mm_sam
(mm_sam): python -m pyscripts.sunrgbd_setup
You are expected to have the following file structure:
{your DATA_ROOT}
|___sunrgbd
|___SUNRGBD
| |___kv1
| |___kv2
| |___realsense
| |___xtion
|___SUNRGBDtoolbox
|___test_depth.txt
|___test_label.txt
|___test_rgb.txt
|___train_depth.txt
|___train_label.txt
|___train_rgb.txt
Please download ir_seg_dataset from this link to {your DATA_ROOT}, then run
cd {your DATA_ROOT}
unzip ir_seg_dataset.zip
mv ir_seg_dataset MFNet
rm ir_seg_dataset.zip
Then, open {your DATA_ROOT}/MFNet/make_flip.py, change root_dir to {your DATA_ROOT}/MFNet, and run
cd MFNet
conda activate mm_sam
(mm_sam): python make_flip.py
You are expected to have the following file structure:
{your DATA_ROOT}
|___MFNet
|___images
| |___00001D.png
| |___00001D_flip.png
| |___...
| |___01606D.png
|___labels
| |___00001D.png
| |___00001D_flip.png
| |___...
| |___01606D.png
|___train.txt
|___val.txt
|___test.txt
|___test_day.txt
|___test_night.txt
cd {your DATA_ROOT}
wget http://hyperspectral.ee.uh.edu/QZ23es1aMPH/2018IEEE/phase2.zip
unzip phase2.zip
mv 2018IEEE_Contest/ dfc18
rm phase2.zip
Then, download test_gt_mask_ori_scale.png from this link to {your DATA_ROOT}/dfc18/Phase2/TrainingGT.
You are expected to have the following file structure:
{your DATA_ROOT}
|___dfc18
|___Phase2
|___Final RGB HR Imagery
|___FullHSIDataset
|___Lidar GeoTiff Rasters
|___Lidar Point Cloud Tiles
|___TrainingGT
|___...
|___test_gt_mask_ori_scale.png
Finally, run
conda activate mm_sam
(mm_sam) python -m pyscripts.dfc18_setup
dfc18_dump will be created in {your DATA_ROOT} with the following file structure:
{your DATA_ROOT}
|___dfc18_dump
|___test
| |___row0_col0.npz
| |___...
| |___row2_col5.npz
|___train
| |___trow0_tcol0_angle0_scale0.80_urow0_ucol0.npz
| |___...
| |___trow1_tcol6_angle160_scale1.20_urow3_ucol3.npz
|___Visualization
|___test.json
|___train.json
Note: dfc18_dump will consume around 170 GB of disk space.
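To verify the dump, you can inspect one of the generated .npz tiles with the minimal sketch below; the stored array names are whatever pyscripts.dfc18_setup writes, so they are printed rather than assumed:
# Illustrative check of one dumped test tile
import numpy as np
from utilbox.global_config import DATA_ROOT

sample = np.load(f"{DATA_ROOT}/dfc18_dump/test/row0_col0.npz")
for key in sample.files:
    print(key, sample[key].shape, sample[key].dtype)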
Download train.tgz from this link to {your DATA_ROOT}/dfc23 and run
cd {your DATA_ROOT}/dfc23
tar -xzvf train.tgz
rm train.tgz
You are expected to have the following file structure:
{your DATA_ROOT}
|___dfc23
|___train
| |___rgb
| |___sar
|___roof_fine_train.json
|___roof_fine_train_corrected.json
Note: If you are decompressing train.tgz on a Linux OS, you may need to run
cd {your DATA_ROOT}/dfc23
find ./ -name ".*.tif" -type f -delete
find ./ -name ".*.json" -type f -delete
to remove the hidden files introduced by the OS mismatch.
Then, run
conda activate mm_sam
(mm_sam): python -m pyscripts.dfc23_setup
Finally, the file structure will be:
{your DATA_ROOT}
|___dfc23
|___train
| |___rgb
| |___sar
|___roof_fine_train.json
|___roof_fine_train_corrected.json
|___metadata.json
Coming soon.
Testing is automatically done after each training job, and the results are printed to the console. If you only want to evaluate a trained checkpoint, run the following command after training:
conda activate mm_sam
(mm_sam): python -m pyscripts.launch --config_name {cm_transfer or mm_fusion}/{dataset}_{gpu_config} --test_only True
We also provide the checkpoints used in our paper. You can find two kinds of checkpoints on our HuggingFace page:
- {dataset}_{modality}_encoder_vit_b.pth: the X encoder of {modality} data trained by UCMT.
- {dataset}_{modality}_sfg_vit_b.pth: the SFG module of {modality} data trained by WMMF based on {dataset}_{modality}_encoder_vit_b.pth.
You can download them to your local machine and evaluate their performance with our configuration files using the following commands:
- For UCMT experiments,
conda activate mm_sam
(mm_sam): python -m pyscripts.launch --config_name cm_transfer/{dataset}_{gpu_config} --test_only True --ckpt_path /your/path/to/{dataset}_{modality}_encoder_vit_b.pth
- For WMMF experiments, first modify agent_kwargs[x_encoder_ckpt_path] in your target configuration ./config/mm_fusion/{dataset}_{gpu_config}.yaml to /your/path/to/{dataset}_{modality}_encoder_vit_b.pth. Then, run
conda activate mm_sam
(mm_sam): python -m pyscripts.launch --config_name mm_fusion/{dataset}_{gpu_config} --test_only True --ckpt_path /your/path/to/{dataset}_{modality}_encoder_vit_b.pth
Note: we recommend downloading the model checkpoints to {your PRETRAINED_ROOT}/mmsam_ckpt for easy management.
We provide a user-friendly API for model inference via HuggingFace integration.
Below is an example of running inference with a UCMT model on an image sample from the SunRGBD dataset.
import torch
from mm_sam.models.sam import SAMbyUCMT
from utilbox.global_config import DATA_ROOT
from utilbox.data_load.read_utils import read_depth_from_disk
ucmt_sam = SAMbyUCMT.from_pretrained("weihao1115/ucmt_sam_on_depth")
ucmt_sam = ucmt_sam.to("cuda").eval()
# depth_image: (H, W, 1)
depth_image_path = f"{DATA_ROOT}/sunrgbd/SUNRGBD/kv2/kinect2data/000002_2014-05-26_14-23-37_260595134347_rgbf000103-resize/depth_bfx/0000103.png"
depth_image = read_depth_from_disk(depth_image_path, return_tensor=True)
ucmt_sam.set_infer_img(img=depth_image, channel_last=True)
# 1. Single-box Inference
# box_coords: (4,)
box_coords = torch.Tensor([291, 53, 729, 388])
pred_masks, pred_ious = ucmt_sam.infer(box_coords=box_coords)
# pred_mask: (H, W) 0-1 binary mask in torch.Tensor type
pred_mask = pred_masks[0].squeeze()
# 2. Multi-box Inference
# 2.1. ensemble predicted mask
# box_coords: (2, 4)
box_coords = torch.Tensor([[291, 53, 729, 388], [23, 289, 729, 529]])
pred_masks, pred_ious = ucmt_sam.infer(box_coords=box_coords)
# pred_mask: (H, W) 0-1 binary mask
pred_mask = pred_masks[0].squeeze()
# 2.2. separate predicted mask
pred_masks, pred_ious = ucmt_sam.infer(box_coords=box_coords, return_all_prompt_masks=True)
# pred_mask: (2, H, W) 0-1 binary mask in torch.Tensor type
pred_mask = pred_masks[0].squeeze()
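For a quick visual check, the (H, W) mask from the single-box example above can be saved as an image; the sketch below assumes Pillow is installed:
# Save the 0-1 binary mask from the single-box example as a grayscale PNG
import numpy as np
from PIL import Image

mask_np = (pred_mask.cpu().numpy() * 255).astype(np.uint8)  # (H, W) values in {0, 255}
Image.fromarray(mask_np).save("pred_mask.png")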
If you want to run inference with your own trained checkpoint files, please initialize the UCMT model by
from mm_sam.models.sam import SAMbyUCMT
from utilbox.global_config import EXP_ROOT
ucmt_sam = SAMbyUCMT.from_pretrained("weihao1115/ucmt_sam_on_depth")
ucmt_sam.load_x_encoder(f"{EXP_ROOT}/cm_transfer/sunrgbd_1x4090/checkpoints/best_mean_nonzero_fore_iu_models/your_checkpoint.pth")
Above is an example for the UCMT models trained on SunRGBD. Please refer to our HuggingFace page for more available checkpoints.
Below is an example of running inference with a WMMF model on an image sample from the SunRGBD dataset.
import torch
from mm_sam.models.sam import SAMbyWMMF
from utilbox.global_config import DATA_ROOT
from utilbox.data_load.read_utils import read_image_as_rgb_from_disk, read_depth_from_disk
wmmf_sam = SAMbyWMMF.from_pretrained("weihao1115/wmmf_sam_on_depth")
wmmf_sam = wmmf_sam.to("cuda").eval()
# rgb_image: (H, W, 3) 0-255 RGB image
rgb_image_path = f"{DATA_ROOT}/sunrgbd/SUNRGBD/kv2/kinect2data/000002_2014-05-26_14-23-37_260595134347_rgbf000103-resize/image/0000103.jpg"
rgb_image = read_image_as_rgb_from_disk(rgb_image_path, return_tensor=True)
# depth_image: (H, W, 1)
depth_image_path = f"{DATA_ROOT}/sunrgbd/SUNRGBD/kv2/kinect2data/000002_2014-05-26_14-23-37_260595134347_rgbf000103-resize/depth_bfx/0000103.png"
depth_image = read_depth_from_disk(depth_image_path, return_tensor=True)
wmmf_sam.set_infer_img(rgb_img=rgb_image, x_img=depth_image, channel_last=True)
# 1. Single-box Inference
# box_coords: (4,)
box_coords = torch.Tensor([291, 53, 729, 388])
pred_masks, pred_ious = wmmf_sam.infer(box_coords=box_coords)
# pred_mask: (H, W)
pred_mask = pred_masks[0].squeeze()
# 2. Multi-box Inference
# 2.1. ensemble predicted mask
# box_coords: (2, 4)
box_coords = torch.Tensor([[291, 53, 729, 388], [23, 289, 729, 529]])
pred_masks, pred_ious = wmmf_sam.infer(box_coords=box_coords)
# pred_mask: (H, W) 0-1 binary mask
pred_mask = pred_masks[0].squeeze()
# 2.2. separate predicted mask
pred_masks, pred_ious = wmmf_sam.infer(box_coords=box_coords, return_all_prompt_masks=True)
# pred_mask: (2, H, W) 0-1 binary mask
pred_mask = pred_masks[0].squeeze()
If you want to run inference with your own trained checkpoint files, please initialize the WMMF model by
from mm_sam.models.sam import SAMbyWMMF
from utilbox.global_config import EXP_ROOT
wmmf_sam = SAMbyWMMF.from_pretrained("weihao1115/wmmf_sam_on_depth")
# if you want to change the weights of the X encoder
wmmf_sam.load_x_encoder(f"{EXP_ROOT}/cm_transfer/sunrgbd_1x4090/checkpoints/best_mean_nonzero_fore_iu_models/your_checkpoint.pth")
# if you want to change the weights of the SFG module
wmmf_sam.load_sfg(f"{EXP_ROOT}/mm_fusion/sunrgbd_1x4090/checkpoints/best_mean_nonzero_fore_iu_models/your_checkpoint.pth")
Above is an example for the WMMF models trained on SUNRGBD. Please refer to our HuggingFace page for more available checkpoints.
If you find this work helpful, please consider citing our paper:
@article{mmsam,
title={Segment Anything with Multiple Modalities},
author={Aoran Xiao and Weihao Xuan and Heli Qi and Yun Xing and Naoto Yokoya and Shijian Lu},
journal={arXiv preprint arXiv:2408.09085},
year={2024}
}
We acknowledge the use of the following public resources throughout this work: Segment Anything Model, and LoRA.
Check out our other projects on visual foundation models!
CAT-SAM: Conditional Tuning for Few-Shot Adaptation of Segment Anything Model, ECCV 2024, Oral Paper.