
SAM2.0 Export Branch #6

Open
wants to merge 84 commits into base: main
Changes from all commits
Commits
84 commits
b7aa776
Export to onnx
kyakuno Aug 21, 2024
5c8c6cd
Export to onnx
kyakuno Aug 21, 2024
baa5202
Export to tflite
kyakuno Aug 21, 2024
3c72e90
Implement tflite export
kyakuno Aug 21, 2024
f1e5ab5
Remove tuple
kyakuno Aug 21, 2024
4ce30eb
Change dynamic axis
kyakuno Aug 21, 2024
7627c28
Improve broadcast error
kyakuno Aug 21, 2024
40de3d8
Fix export for mask decoder
kyakuno Aug 21, 2024
68613b0
Implement onnx inference
kyakuno Aug 22, 2024
dc014ed
Infer using onnx runtime
kyakuno Aug 22, 2024
8438f9f
Export dense pe
kyakuno Aug 22, 2024
1bf6439
Export image encoder to tflite
kyakuno Aug 23, 2024
98e8297
Export to tflite
kyakuno Aug 23, 2024
ed87f6a
Int8 quantization
kyakuno Aug 23, 2024
2e8d096
Quantize mask decoder
kyakuno Aug 23, 2024
f9c8a69
Implement video predictor
kyakuno Aug 23, 2024
98a19fb
Added short video
kyakuno Aug 23, 2024
bf5b16b
Implement onnx inference for video decoder
kyakuno Aug 24, 2024
bed8914
Implement multimask_output argument
kyakuno Aug 24, 2024
7fde3f1
Implement sparse embeddings dynamic axis
kyakuno Aug 24, 2024
fc6dfda
Connect import onnx
kyakuno Aug 24, 2024
cfcd64c
Only export image encoder core
kyakuno Aug 24, 2024
78e1c36
Fix onnx inference error
kyakuno Aug 24, 2024
37573c8
Add output name
kyakuno Aug 24, 2024
d388d2f
Export memory attention and encoder
kyakuno Aug 25, 2024
ebbda3d
Disable export memory attention
kyakuno Aug 25, 2024
08f84c5
Fix type error
kyakuno Aug 25, 2024
cc43208
Implement masks argument
kyakuno Aug 26, 2024
bcfca16
Replace repeat interleave
kyakuno Aug 26, 2024
6ec15f6
Fix tflite export error
kyakuno Aug 26, 2024
9fa36af
Fix torch inference error
kyakuno Aug 26, 2024
a02df7c
Fix torch inference error
kyakuno Aug 26, 2024
6fd0a6b
Support model size
kyakuno Aug 29, 2024
1f2c9da
Fix model id selection
kyakuno Aug 29, 2024
cdb71ec
Use dynamic quantize
kyakuno Aug 30, 2024
7d1ff38
Export to sub folder
kyakuno Sep 2, 2024
a224ae5
Update checkpoint
kyakuno Sep 2, 2024
7506a21
Export memory encoder
kyakuno Sep 2, 2024
ba20e68
Implement matmul version of memory attention
kyakuno Sep 2, 2024
0185597
Fix inference code
kyakuno Sep 2, 2024
e49a1e7
Test rotary enc
kyakuno Sep 2, 2024
249c7a1
Connect model_id
kyakuno Sep 2, 2024
a22a618
Implement tflite import
kyakuno Sep 2, 2024
3063ba5
Export memory attention to tflite
kyakuno Sep 2, 2024
1becd17
Fix checkpoint name
kyakuno Sep 2, 2024
340fff0
Export mlp
kyakuno Sep 2, 2024
0831211
Added assertion
kyakuno Sep 2, 2024
c724630
Added assertion
kyakuno Sep 2, 2024
19afdf0
Load onnx at once
kyakuno Sep 2, 2024
7a3772c
Load onnx at once
kyakuno Sep 2, 2024
a9f2d18
Fix export memory encoder
kyakuno Sep 3, 2024
cc9acbe
Implement inference mode
kyakuno Sep 3, 2024
f2d1dfc
Export mlp to tflite
kyakuno Sep 4, 2024
bbf23d1
Update usage
kyakuno Sep 4, 2024
2246959
Fix memory encoder tflite export
kyakuno Sep 4, 2024
41f688e
Export memory attention
kyakuno Sep 4, 2024
f36169e
Implement tflite inference
kyakuno Sep 4, 2024
a410735
Prepare rotenc weight
kyakuno Sep 4, 2024
b17e0f2
Separate memory 1 and memory 2
kyakuno Sep 5, 2024
1532106
Fix onnx dynamic shape
kyakuno Sep 5, 2024
a7a2792
Improve export code
kyakuno Sep 6, 2024
ccc4b4d
Fix shape
kyakuno Sep 6, 2024
dfed5fe
Fix num maskmem for tflite
kyakuno Sep 6, 2024
1dbe5c6
Added model link
kyakuno Sep 7, 2024
3e8dbf6
Merge branch 'onnx' into memory_attention_tflite
kyakuno Sep 9, 2024
0950d3e
Change model name to opt
kyakuno Sep 9, 2024
b571a15
Fix num maskmem for tflite
kyakuno Sep 9, 2024
32d55ab
Merge pull request #2 from axinc-ai/memory_attention_tflite
kyakuno Sep 9, 2024
cfe762f
Update inference example
kyakuno Sep 10, 2024
f932743
Implement image size option
kyakuno Sep 12, 2024
285a6bf
Implement image size 512
kyakuno Sep 12, 2024
8b96886
Implement image size for position encoder
kyakuno Sep 12, 2024
f945973
Fix export model path
kyakuno Sep 13, 2024
f8aa0e8
Merge pull request #3 from axinc-ai/resolution
kyakuno Sep 15, 2024
3ab1930
Fix tensor order of tflite
kyakuno Oct 8, 2024
d37e3c6
Update required version
kyakuno Oct 8, 2024
4919603
Merge pull request #4 from axinc-ai/fix_tflite
kyakuno Oct 8, 2024
b5cb1f9
Implement inference only code for tflite
kyakuno Oct 10, 2024
14c67f7
Download all models
kyakuno Oct 10, 2024
ea5dc39
Merge pull request #5 from axinc-ai/tflite_import
kyakuno Oct 10, 2024
6bb6b1c
Update checkpoint information
kyakuno Dec 4, 2024
70b93fa
Fix duplicated post process for torch
kyakuno Dec 8, 2024
7e6f905
Fix prompt encoder mismatch
kyakuno Dec 19, 2024
b898bd6
Update release note
kyakuno Dec 19, 2024
README.md: 216 changes (79 additions, 137 deletions)
@@ -1,186 +1,128 @@
# SAM 2: Segment Anything in Images and Videos
# SAM 2 Export to ONNX and TFLITE

**[AI at Meta, FAIR](https://ai.meta.com/research/)**
## Download model

[Nikhila Ravi](https://nikhilaravi.com/), [Valentin Gabeur](https://gabeur.github.io/), [Yuan-Ting Hu](https://scholar.google.com/citations?user=E8DVVYQAAAAJ&hl=en), [Ronghang Hu](https://ronghanghu.com/), [Chaitanya Ryali](https://scholar.google.com/citations?user=4LWx24UAAAAJ&hl=en), [Tengyu Ma](https://scholar.google.com/citations?user=VeTSl0wAAAAJ&hl=en), [Haitham Khedr](https://hkhedr.com/), [Roman Rädle](https://scholar.google.de/citations?user=Tpt57v0AAAAJ&hl=en), [Chloe Rolland](https://scholar.google.com/citations?hl=fr&user=n-SnMhoAAAAJ), [Laura Gustafson](https://scholar.google.com/citations?user=c8IpF9gAAAAJ&hl=en), [Eric Mintun](https://ericmintun.github.io/), [Junting Pan](https://junting.github.io/), [Kalyan Vasudev Alwala](https://scholar.google.co.in/citations?user=m34oaWEAAAAJ&hl=en), [Nicolas Carion](https://www.nicolascarion.com/), [Chao-Yuan Wu](https://chaoyuan.org/), [Ross Girshick](https://www.rossgirshick.info/), [Piotr Dollár](https://pdollar.github.io/), [Christoph Feichtenhofer](https://feichtenhofer.github.io/)

[[`Paper`](https://ai.meta.com/research/publications/sam-2-segment-anything-in-images-and-videos/)] [[`Project`](https://ai.meta.com/sam2)] [[`Demo`](https://sam2.metademolab.com/)] [[`Dataset`](https://ai.meta.com/datasets/segment-anything-video)] [[`Blog`](https://ai.meta.com/blog/segment-anything-2)] [[`BibTeX`](#citing-sam-2)]

![SAM 2 architecture](assets/model_diagram.png?raw=true)

**Segment Anything Model 2 (SAM 2)** is a foundation model towards solving promptable visual segmentation in images and videos. We extend SAM to video by considering images as a video with a single frame. The model design is a simple transformer architecture with streaming memory for real-time video processing. We build a model-in-the-loop data engine, which improves model and data via user interaction, to collect [**our SA-V dataset**](https://ai.meta.com/datasets/segment-anything-video), the largest video segmentation dataset to date. SAM 2 trained on our data provides strong performance across a wide range of tasks and visual domains.

![SA-V dataset](assets/sa_v_dataset.jpg?raw=true)

## Installation

SAM 2 needs to be installed first before use. The code requires `python>=3.10`, as well as `torch>=2.3.1` and `torchvision>=0.18.1`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. You can install SAM 2 on a GPU machine using:

```bash
git clone https://github.com/facebookresearch/segment-anything-2.git

cd segment-anything-2 && pip install -e .
```
If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.

To use the SAM 2 predictor and run the example notebooks, `jupyter` and `matplotlib` are required and can be installed by:

```bash
pip install -e ".[demo]"
```

Note:
1. It's recommended to create a new Python environment via [Anaconda](https://www.anaconda.com/) for this installation and install PyTorch 2.3.1 (or higher) via `pip` following https://pytorch.org/. If you have a PyTorch version lower than 2.3.1 in your current environment, the installation command above will try to upgrade it to the latest PyTorch version using `pip`.
2. The step above requires compiling a custom CUDA kernel with the `nvcc` compiler. If it isn't already available on your machine, please install the [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) with a version that matches your PyTorch CUDA version.
3. If you see a message like `Failed to build the SAM 2 CUDA extension` during installation, you can ignore it and still use SAM 2 (some post-processing functionality may be limited, but it doesn't affect the results in most cases).

Please see [`INSTALL.md`](./INSTALL.md) for FAQs on potential issues and solutions.

## Getting Started

### Download Checkpoints

First, we need to download a model checkpoint. All the model checkpoints can be downloaded by running:

```bash
cd checkpoints && \
./download_ckpts.sh && \
cd ..
```

or individually from:

- [sam2_hiera_tiny.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)
- [sam2_hiera_small.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)
- [sam2_hiera_base_plus.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)
- [sam2_hiera_large.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)

Then SAM 2 can be used in a few lines as follows for image and video prediction.
## Requirements

### Image prediction
onnx

SAM 2 has all the capabilities of [SAM](https://github.com/facebookresearch/segment-anything) on static images, and we provide image prediction APIs that closely resemble SAM for image use cases. The `SAM2ImagePredictor` class has an easy interface for image prompting.

```python
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(<your_image>)
    masks, _, _ = predictor.predict(<input_prompts>)
```
torch 2.2.1
onnx 1.16.2
```

Please refer to the examples in [image_predictor_example.ipynb](./notebooks/image_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/segment-anything-2/blob/main/notebooks/image_predictor_example.ipynb)) for static image use cases.

SAM 2 also supports automatic mask generation on images just like SAM. Please see [automatic_mask_generator_example.ipynb](./notebooks/automatic_mask_generator_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/segment-anything-2/blob/main/notebooks/automatic_mask_generator_example.ipynb)) for automatic mask generation in images.

### Video prediction

For promptable segmentation and tracking in videos, we provide a video predictor with APIs for example to add prompts and propagate masklets throughout a video. SAM 2 supports video inference on multiple objects and uses an inference state to keep track of the interactions in each video.

```python
import torch
from sam2.build_sam import build_sam2_video_predictor
tflite

checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
predictor = build_sam2_video_predictor(model_cfg, checkpoint)
```
torch 2.4.0
ai-edge-torch 0.2.0
tf-nightly 2.18.0.dev20240811 for image mode
tf-nightly 2.18.0.dev20240905 for video mode
```

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state(<your_video>)
## Export and Inference

    # add new prompts and instantly get the output on the same frame
    frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>)
onnx

    # propagate the prompts to get masklets throughout the video
    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
        ...
```
python3 export_image_predictor.py --framework onnx
python3 export_video_predictor.py --framework onnx
```
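
The export scripts presumably rely on `torch.onnx.export` for each SAM 2 sub-module (image encoder, prompt encoder, mask decoder, memory encoder, memory attention); the commit history mentions tuning the dynamic axes. The snippet below is only an illustrative sketch of that pattern with a hypothetical toy module, not the repository's actual export code; it shows how a variable-length axis such as the number of prompt points is declared dynamic.

```python
# Illustrative sketch with a toy module (not the repository's export code).
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    """Hypothetical stand-in for one exported SAM 2 sub-module."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(256, 256)

    def forward(self, sparse_embeddings):
        return self.proj(sparse_embeddings)

model = TinyDecoder().eval()
dummy = torch.zeros(1, 2, 256)  # (batch, num_points, embedding_dim)

torch.onnx.export(
    model,
    (dummy,),
    "tiny_decoder.onnx",
    input_names=["sparse_embeddings"],
    output_names=["embeddings_out"],
    # the number of prompt points varies at runtime, so that axis is dynamic
    dynamic_axes={"sparse_embeddings": {1: "num_points"},
                  "embeddings_out": {1: "num_points"}},
)
```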

Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/segment-anything-2/blob/main/notebooks/video_predictor_example.ipynb)) for details on how to add click or box prompts, make refinements, and track multiple objects in videos.
tflite

## Load from 🤗 Hugging Face
```
export PJRT_DEVICE=CPU
python3 export_image_predictor.py --framework tflite
python3 export_video_predictor.py --framework tflite
```
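
For the tflite path, ai-edge-torch converts a PyTorch module directly into a `.tflite` flatbuffer. The snippet below sketches the general ai-edge-torch conversion pattern on a hypothetical toy module, as an assumption about how the export works rather than the repository's actual code.

```python
# Illustrative sketch with a toy module (not the repository's export code).
# Assumes ai-edge-torch 0.2.0 and the tf-nightly versions listed above,
# with PJRT_DEVICE=CPU set as shown.
import ai_edge_torch
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    """Hypothetical stand-in for one exported SAM 2 sub-module."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, image):
        return self.conv(image)

model = TinyEncoder().eval()
sample_inputs = (torch.zeros(1, 3, 512, 512),)  # tflite export uses a fixed image size (1024 or 512)

edge_model = ai_edge_torch.convert(model, sample_inputs)
edge_model.export("tiny_encoder.tflite")
```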

Alternatively, models can also be loaded from [Hugging Face](https://huggingface.co/models?search=facebook/sam2) (requires `pip install huggingface_hub`).
## Inference only

For image prediction:
onnx

```python
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor
```
download_onnx_models.sh
python3 export_image_predictor.py --framework onnx --mode import
python3 export_video_predictor.py --framework onnx --mode import
```

predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
tflite

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(<your_image>)
    masks, _, _ = predictor.predict(<input_prompts>)
```
download_tflite_models.sh
python3 export_image_predictor.py --framework tflite --mode import
python3 export_video_predictor.py --framework tflite --mode import
python3 export_image_predictor.py --framework tflite --mode import --image_size 512
python3 export_video_predictor.py --framework tflite --mode import --image_size 512
```

For video prediction:
## Test

```python
import torch
from sam2.sam2_video_predictor import SAM2VideoPredictor
The complex-valued tensor in RotaryEnc is replaced with a matmul-based equivalent. To test this behavior, you can also run the pipeline with torch.

predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")
```
python3 export_video_predictor.py --framework torch
```
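
The rewrite uses the fact that multiplying a channel pair (x1, x2) by a unit complex number cos θ + i sin θ is the same as applying a 2x2 rotation matrix, which exports as plain MatMul ops instead of complex-valued tensors. The snippet below is a minimal illustration of that identity, not the repository's memory-attention code.

```python
# Minimal illustration (not the repository's code):
# complex RoPE multiplication == batched 2x2 rotation via matmul.
import torch

def rope_complex(x, freqs_cis):
    # Reference complex form: consecutive channels are paired as (real, imag).
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(x_c * freqs_cis).reshape(x.shape).type_as(x)

def rope_matmul(x, cos, sin):
    # Same rotation as a batched 2x2 matmul, which maps onto plain MatMul ops
    # in the exported ONNX/TFLite graph.
    rot = torch.stack(
        (torch.stack((cos, -sin), dim=-1),
         torch.stack((sin, cos), dim=-1)),
        dim=-2)                                             # (..., d/2, 2, 2)
    pairs = x.reshape(*x.shape[:-1], -1, 2).unsqueeze(-1)   # (..., d/2, 2, 1)
    return torch.matmul(rot, pairs).reshape(x.shape)

# quick equivalence check
n, d = 4, 8
angles = torch.rand(n, d // 2) * 6.283
freqs_cis = torch.polar(torch.ones_like(angles), angles)
x = torch.randn(n, d)
assert torch.allclose(rope_complex(x, freqs_cis),
                      rope_matmul(x, angles.cos(), angles.sin()), atol=1e-5)
```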

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state(<your_video>)
## Artifacts

    # add new prompts and instantly get the output on the same frame
    frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>)
The exported models are stored in the following directories.

    # propagate the prompts to get masklets throughout the video
    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
        ...
```
output/*
model/*
```

## Model Description
You can also download them from the following links.

| **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
| :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
| sam2_hiera_tiny | 38.9 | 47.2 | 75.0 | 70.9 | 75.3 |
| sam2_hiera_small | 46 | 43.3 (53.0 compiled\*) | 74.9 | 71.5 | 76.4 |
| sam2_hiera_base_plus | 80.8 | 34.8 (43.8 compiled\*) | 74.7 | 72.8 | 75.8 |
| sam2_hiera_large | 224.4 | 24.2 (30.2 compiled\*) | 76.0 | 74.6 | 79.8 |
### ONNX

\* Compile the model by setting `compile_image_encoder: True` in the config.
- https://storage.googleapis.com/ailia-models/segment-anything-2/image_encoder_hiera_t.onnx
- https://storage.googleapis.com/ailia-models/segment-anything-2/prompt_encoder_hiera_t.onnx
- https://storage.googleapis.com/ailia-models/segment-anything-2/mask_decoder_hiera_t.onnx
- https://storage.googleapis.com/ailia-models/segment-anything-2/memory_encoder_hiera_t.onnx
- https://storage.googleapis.com/ailia-models/segment-anything-2/mlp_hiera_t.onnx
- https://storage.googleapis.com/ailia-models/segment-anything-2/memory_attention_hiera_t.onnx (6dim matmul, batch = N)
- https://storage.googleapis.com/ailia-models/segment-anything-2/memory_attention_hiera_t.opt.onnx (4dim matmul, batch = 1)

## Segment Anything Video Dataset
(The Prompt Encoder model was replaced on 2024/12/19 due to a problem found in it.)
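
For reference, a minimal sketch of loading one of the ONNX files listed above with onnxruntime (the file path and the 1024x1024 dummy input are assumptions; read the actual input and output names from the exported graph):

```python
# Minimal sketch, assuming the exported file name and a 1024x1024 input
# (not the repository's inference code).
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model/image_encoder_hiera_t.onnx")

# Query the real input/output names instead of hard-coding them.
for inp in session.get_inputs():
    print(inp.name, inp.shape, inp.type)

dummy_image = np.zeros((1, 3, 1024, 1024), dtype=np.float32)  # assumed NCHW layout
outputs = session.run(None, {session.get_inputs()[0].name: dummy_image})
print([o.shape for o in outputs])
```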

See [sav_dataset/README.md](sav_dataset/README.md) for details.
### TFLITE

## License
- https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/image_encoder_hiera_t.tflite
- https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/prompt_encoder_hiera_t.tflite
- https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/mask_decoder_hiera_t.tflite
- https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/mlp_hiera_t.tflite
- https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/memory_encoder_hiera_t.tflite
- https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/memory_attention_hiera_t.tflite (4dim matmul, batch = 1, num_maskmem = 1)

The models are licensed under the [Apache 2.0 license](./LICENSE). Please refer to our research paper for more details on the models.
The memory attention model in tflite does not support dynamic shapes, so `num_maskmem` and `max_obj_ptrs_in_encoder` must be fixed to 1.

## Contributing
(The Prompt Encoder model was replaced on 2024/12/19 due to a problem found in it.)
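
For reference, a minimal sketch of running one of the tflite files listed above with the TensorFlow Lite interpreter (the file path is an assumption; shapes and dtypes are read from the model because the tflite export uses fixed shapes):

```python
# Minimal sketch, assuming the exported file name (not the repository's code).
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="model/image_encoder_hiera_t.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print([(d["name"], d["shape"]) for d in input_details])

# Shapes are static in the tflite export, so the input must match exactly.
dummy = np.zeros(input_details[0]["shape"], dtype=input_details[0]["dtype"])
interpreter.set_tensor(input_details[0]["index"], dummy)
interpreter.invoke()
features = interpreter.get_tensor(output_details[0]["index"])
print(features.shape)
```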

See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
## Inference Example

## Contributors
- [ailia-models](https://github.com/axinc-ai/ailia-models/tree/master/image_segmentation/segment-anything-2)
- [ailia-models-tflite](https://github.com/axinc-ai/ailia-models-tflite/pull/90)

The SAM 2 project was made possible with the help of many contributors (alphabetical):
## Original document

Karen Bergan, Daniel Bolya, Alex Bosenberg, Kai Brown, Vispi Cassod, Christopher Chedeau, Ida Cheng, Luc Dahlin, Shoubhik Debnath, Rene Martinez Doehner, Grant Gardner, Sahir Gomez, Rishi Godugu, Baishan Guo, Caleb Ho, Andrew Huang, Somya Jain, Bob Kamma, Amanda Kallet, Jake Kinney, Alexander Kirillov, Shiva Koduvayur, Devansh Kukreja, Robert Kuo, Aohan Lin, Parth Malani, Jitendra Malik, Mallika Malhotra, Miguel Martin, Alexander Miller, Sasha Mitts, William Ngan, George Orlin, Joelle Pineau, Kate Saenko, Rodrick Shepard, Azita Shokrpour, David Soofian, Jonathan Torres, Jenny Truong, Sagar Vaze, Meng Wang, Claudette Ward, Pengchuan Zhang.
- [README_ORIGINAL.md](README_ORIGINAL.md)

Third-party code: we use a GPU-based connected component algorithm adapted from [`cc_torch`](https://github.com/zsef123/Connected_components_PyTorch) (with its license in [`LICENSE_cctorch`](./LICENSE_cctorch)) as an optional post-processing step for the mask predictions.
## Tags

## Citing SAM 2
### 4dim matmul

If you use SAM 2 or the SA-V dataset in your research, please use the following BibTeX entry.
main branch

```bibtex
@article{ravi2024sam2,
  title={SAM 2: Segment Anything in Images and Videos},
  author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph},
  journal={arXiv preprint arXiv:2408.00714},
  url={https://arxiv.org/abs/2408.00714},
  year={2024}
}
```
### 6dim matmul

https://github.com/axinc-ai/segment-anything-2/tree/f36169e87ec302c75279fadc60cda1c3763165eb