diff --git a/README.md b/README.md index bc13c1c9a..46f67e871 100644 --- a/README.md +++ b/README.md @@ -1,46 +1,7 @@ -# SAM 2: Segment Anything in Images and Videos +# SAM 2 Export to ONNX and TFLITE -**[AI at Meta, FAIR](https://ai.meta.com/research/)** +## Download model -[Nikhila Ravi](https://nikhilaravi.com/), [Valentin Gabeur](https://gabeur.github.io/), [Yuan-Ting Hu](https://scholar.google.com/citations?user=E8DVVYQAAAAJ&hl=en), [Ronghang Hu](https://ronghanghu.com/), [Chaitanya Ryali](https://scholar.google.com/citations?user=4LWx24UAAAAJ&hl=en), [Tengyu Ma](https://scholar.google.com/citations?user=VeTSl0wAAAAJ&hl=en), [Haitham Khedr](https://hkhedr.com/), [Roman Rädle](https://scholar.google.de/citations?user=Tpt57v0AAAAJ&hl=en), [Chloe Rolland](https://scholar.google.com/citations?hl=fr&user=n-SnMhoAAAAJ), [Laura Gustafson](https://scholar.google.com/citations?user=c8IpF9gAAAAJ&hl=en), [Eric Mintun](https://ericmintun.github.io/), [Junting Pan](https://junting.github.io/), [Kalyan Vasudev Alwala](https://scholar.google.co.in/citations?user=m34oaWEAAAAJ&hl=en), [Nicolas Carion](https://www.nicolascarion.com/), [Chao-Yuan Wu](https://chaoyuan.org/), [Ross Girshick](https://www.rossgirshick.info/), [Piotr Dollár](https://pdollar.github.io/), [Christoph Feichtenhofer](https://feichtenhofer.github.io/) - -[[`Paper`](https://ai.meta.com/research/publications/sam-2-segment-anything-in-images-and-videos/)] [[`Project`](https://ai.meta.com/sam2)] [[`Demo`](https://sam2.metademolab.com/)] [[`Dataset`](https://ai.meta.com/datasets/segment-anything-video)] [[`Blog`](https://ai.meta.com/blog/segment-anything-2)] [[`BibTeX`](#citing-sam-2)] - -![SAM 2 architecture](assets/model_diagram.png?raw=true) - -**Segment Anything Model 2 (SAM 2)** is a foundation model towards solving promptable visual segmentation in images and videos. We extend SAM to video by considering images as a video with a single frame. The model design is a simple transformer architecture with streaming memory for real-time video processing. We build a model-in-the-loop data engine, which improves model and data via user interaction, to collect [**our SA-V dataset**](https://ai.meta.com/datasets/segment-anything-video), the largest video segmentation dataset to date. SAM 2 trained on our data provides strong performance across a wide range of tasks and visual domains. - -![SA-V dataset](assets/sa_v_dataset.jpg?raw=true) - -## Installation - -SAM 2 needs to be installed first before use. The code requires `python>=3.10`, as well as `torch>=2.3.1` and `torchvision>=0.18.1`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. You can install SAM 2 on a GPU machine using: - -```bash -git clone https://github.com/facebookresearch/segment-anything-2.git - -cd segment-anything-2 & pip install -e . -``` -If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu. - -To use the SAM 2 predictor and run the example notebooks, `jupyter` and `matplotlib` are required and can be installed by: - -```bash -pip install -e ".[demo]" -``` - -Note: -1. It's recommended to create a new Python environment via [Anaconda](https://www.anaconda.com/) for this installation and install PyTorch 2.3.1 (or higher) via `pip` following https://pytorch.org/. 
If you have a PyTorch version lower than 2.3.1 in your current environment, the installation command above will try to upgrade it to the latest PyTorch version using `pip`. -2. The step above requires compiling a custom CUDA kernel with the `nvcc` compiler. If it isn't already available on your machine, please install the [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) with a version that matches your PyTorch CUDA version. -3. If you see a message like `Failed to build the SAM 2 CUDA extension` during installation, you can ignore it and still use SAM 2 (some post-processing functionality may be limited, but it doesn't affect the results in most cases). - -Please see [`INSTALL.md`](./INSTALL.md) for FAQs on potential issues and solutions. - -## Getting Started - -### Download Checkpoints - -First, we need to download a model checkpoint. All the model checkpoints can be downloaded by running: ```bash cd checkpoints && \ @@ -48,139 +9,120 @@ cd checkpoints && \ cd .. ``` -or individually from: - -- [sam2_hiera_tiny.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt) -- [sam2_hiera_small.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt) -- [sam2_hiera_base_plus.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt) -- [sam2_hiera_large.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt) - -Then SAM 2 can be used in a few lines as follows for image and video prediction. +## Requirements -### Image prediction +onnx -SAM 2 has all the capabilities of [SAM](https://github.com/facebookresearch/segment-anything) on static images, and we provide image prediction APIs that closely resemble SAM for image use cases. The `SAM2ImagePredictor` class has an easy interface for image prompting. - -```python -import torch -from sam2.build_sam import build_sam2 -from sam2.sam2_image_predictor import SAM2ImagePredictor - -checkpoint = "./checkpoints/sam2_hiera_large.pt" -model_cfg = "sam2_hiera_l.yaml" -predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint)) - -with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): - predictor.set_image() - masks, _, _ = predictor.predict() +``` +torch 2.2.1 +onnx 1.16.2 ``` -Please refer to the examples in [image_predictor_example.ipynb](./notebooks/image_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/segment-anything-2/blob/main/notebooks/image_predictor_example.ipynb)) for static image use cases. - -SAM 2 also supports automatic mask generation on images just like SAM. Please see [automatic_mask_generator_example.ipynb](./notebooks/automatic_mask_generator_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/segment-anything-2/blob/main/notebooks/automatic_mask_generator_example.ipynb)) for automatic mask generation in images. - -### Video prediction - -For promptable segmentation and tracking in videos, we provide a video predictor with APIs for example to add prompts and propagate masklets throughout a video. SAM 2 supports video inference on multiple objects and uses an inference state to keep track of the interactions in each video. 
-
-```python
-import torch
-from sam2.build_sam import build_sam2_video_predictor
+tflite
-checkpoint = "./checkpoints/sam2_hiera_large.pt"
-model_cfg = "sam2_hiera_l.yaml"
-predictor = build_sam2_video_predictor(model_cfg, checkpoint)
+```
+torch 2.4.0
+ai-edge-torch 0.2.0
+tf-nightly 2.18.0.dev20240811 for image mode
+tf-nightly 2.18.0.dev20240905 for video mode
+```
-with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
-    state = predictor.init_state()
+## Export and Inference
-    # add new prompts and instantly get the output on the same frame
-    frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, ):
+onnx
-    # propagate the prompts to get masklets throughout the video
-    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
-        ...
+```
+python3 export_image_predictor.py --framework onnx
+python3 export_video_predictor.py --framework onnx
 ```
-Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/segment-anything-2/blob/main/notebooks/video_predictor_example.ipynb)) for details on how to add click or box prompts, make refinements, and track multiple objects in videos.
+tflite
-## Load from 🤗 Hugging Face
+```
+export PJRT_DEVICE=CPU
+python3 export_image_predictor.py --framework tflite
+python3 export_video_predictor.py --framework tflite
+```
-Alternatively, models can also be loaded from [Hugging Face](https://huggingface.co/models?search=facebook/sam2) (requires `pip install huggingface_hub`).
+## Inference only
-For image prediction:
+onnx
-```python
-import torch
-from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+```
+bash download_onnx_models.sh
+python3 export_image_predictor.py --framework onnx --mode import
+python3 export_video_predictor.py --framework onnx --mode import
+```
-predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
+tflite
-with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
-    predictor.set_image()
-    masks, _, _ = predictor.predict()
+```
+bash download_tflite_models.sh
+python3 export_image_predictor.py --framework tflite --mode import
+python3 export_video_predictor.py --framework tflite --mode import
+python3 export_image_predictor.py --framework tflite --mode import --image_size 512
+python3 export_video_predictor.py --framework tflite --mode import --image_size 512
 ```
-For video prediction:
+## Test
-```python
-import torch
-from sam2.sam2_video_predictor import SAM2VideoPredictor
+
+The complex-tensor rotary position encoding (RotaryEnc) has been replaced with an equivalent matmul implementation. To verify this behavior, you can also run the pipeline with plain PyTorch:
-predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")
+```
+python3 export_video_predictor.py --framework torch
+```
-with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
-    state = predictor.init_state()
+## Artifacts
-    # add new prompts and instantly get the output on the same frame
-    frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, ):
+The exported models are written to the following directories:
-    # propagate the prompts to get masklets throughout the video
-    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
-        ...
+```
+output/*
+model/*
 ```
-## Model Description
+You can also download pre-exported models from the links below.
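+After exporting, or after downloading the files linked in the next two sections, you can quickly sanity-check an ONNX artifact by inspecting its input/output signature and running it once on dummy data with `onnxruntime`. The following is a minimal sketch, not part of this repository; it assumes `onnxruntime` is installed and that `image_encoder_hiera_t.onnx` is already in `./model/`.
+
+```python
+# Minimal sketch: inspect and smoke-test an exported/downloaded ONNX model.
+import numpy as np
+import onnxruntime
+
+session = onnxruntime.InferenceSession("model/image_encoder_hiera_t.onnx")
+
+# Print the graph signature so you know what the export expects.
+for inp in session.get_inputs():
+    print("input :", inp.name, inp.shape, inp.type)
+for out in session.get_outputs():
+    print("output:", out.name, out.shape, out.type)
+
+# Build zero-filled dummy inputs from the reported shapes and run once.
+# Dynamic axes are replaced by 1; if a spatial axis is dynamic, substitute
+# the real resolution (1024 or 512) instead. The dtype mapping is a
+# simplification (float inputs assumed unless the type string says otherwise).
+feeds = {}
+for inp in session.get_inputs():
+    shape = [d if isinstance(d, int) else 1 for d in inp.shape]
+    dtype = np.float32 if "float" in inp.type else np.int64
+    feeds[inp.name] = np.zeros(shape, dtype=dtype)
+
+outputs = session.run(None, feeds)
+print("ran OK,", len(outputs), "outputs")
+```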
-| **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** | -| :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: | -| sam2_hiera_tiny | 38.9 | 47.2 | 75.0 | 70.9 | 75.3 | -| sam2_hiera_small | 46 | 43.3 (53.0 compiled\*) | 74.9 | 71.5 | 76.4 | -| sam2_hiera_base_plus | 80.8 | 34.8 (43.8 compiled\*) | 74.7 | 72.8 | 75.8 | -| sam2_hiera_large | 224.4 | 24.2 (30.2 compiled\*) | 76.0 | 74.6 | 79.8 | +### ONNX -\* Compile the model by setting `compile_image_encoder: True` in the config. +- https://storage.googleapis.com/ailia-models/segment-anything-2/image_encoder_hiera_t.onnx +- https://storage.googleapis.com/ailia-models/segment-anything-2/prompt_encoder_hiera_t.onnx +- https://storage.googleapis.com/ailia-models/segment-anything-2/mask_decoder_hiera_t.onnx +- https://storage.googleapis.com/ailia-models/segment-anything-2/memory_encoder_hiera_t.onnx +- https://storage.googleapis.com/ailia-models/segment-anything-2/mlp_hiera_t.onnx +- https://storage.googleapis.com/ailia-models/segment-anything-2/memory_attention_hiera_t.onnx (6dim matmul, batch = N) +- https://storage.googleapis.com/ailia-models/segment-anything-2/memory_attention_hiera_t.opt.onnx (4dim matmul, batch = 1) -## Segment Anything Video Dataset +(The model of the Prompt Encoder was replaced on 2024/12/19 due to a problem found in the Prompt Encoder.) -See [sav_dataset/README.md](sav_dataset/README.md) for details. +### TFLITE -## License +- https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/image_encoder_hiera_t.tflite +- https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/prompt_encoder_hiera_t.tflite +- https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/mask_decoder_hiera_t.tflite +- https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/mlp_hiera_t.tflite +- https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/memory_encoder_hiera_t.tflite +- https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/memory_attention_hiera_t.tflite (4dim matmul, batch = 1, num_maskmem = 1) -The models are licensed under the [Apache 2.0 license](./LICENSE). Please refer to our research paper for more details on the models. +The memory attention in tflite does not support dynamic shapes, so num_maskmem and max_obj_ptrs_in_encoder need to be fixed to 1. -## Contributing +(The model of the Prompt Encoder was replaced on 2024/12/19 due to a problem found in the Prompt Encoder.) -See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md). 
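+The tflite artifacts can be checked the same way with the TensorFlow Lite interpreter. Again, this is a minimal sketch and not part of this repository; it assumes `tensorflow` (or the tf-nightly build listed in the requirements) is installed and that the file is already in `./model/`.
+
+```python
+# Minimal sketch: inspect and smoke-test a downloaded tflite model.
+import numpy as np
+import tensorflow as tf
+
+interpreter = tf.lite.Interpreter(model_path="model/image_encoder_hiera_t.tflite")
+interpreter.allocate_tensors()
+
+# Print the converted graph's input/output details.
+for detail in interpreter.get_input_details():
+    print("input :", detail["name"], detail["shape"], detail["dtype"])
+for detail in interpreter.get_output_details():
+    print("output:", detail["name"], detail["shape"], detail["dtype"])
+
+# Feed zero-filled dummy tensors of the reported shapes and run once.
+for detail in interpreter.get_input_details():
+    interpreter.set_tensor(detail["index"], np.zeros(detail["shape"], dtype=detail["dtype"]))
+interpreter.invoke()
+first_out = interpreter.get_output_details()[0]
+print("first output shape:", interpreter.get_tensor(first_out["index"]).shape)
+```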
+## Inference Example -## Contributors +- [ailia-models](https://github.com/axinc-ai/ailia-models/tree/master/image_segmentation/segment-anything-2) +- [ailia-models-tflite](https://github.com/axinc-ai/ailia-models-tflite/pull/90) -The SAM 2 project was made possible with the help of many contributors (alphabetical): +## Original document -Karen Bergan, Daniel Bolya, Alex Bosenberg, Kai Brown, Vispi Cassod, Christopher Chedeau, Ida Cheng, Luc Dahlin, Shoubhik Debnath, Rene Martinez Doehner, Grant Gardner, Sahir Gomez, Rishi Godugu, Baishan Guo, Caleb Ho, Andrew Huang, Somya Jain, Bob Kamma, Amanda Kallet, Jake Kinney, Alexander Kirillov, Shiva Koduvayur, Devansh Kukreja, Robert Kuo, Aohan Lin, Parth Malani, Jitendra Malik, Mallika Malhotra, Miguel Martin, Alexander Miller, Sasha Mitts, William Ngan, George Orlin, Joelle Pineau, Kate Saenko, Rodrick Shepard, Azita Shokrpour, David Soofian, Jonathan Torres, Jenny Truong, Sagar Vaze, Meng Wang, Claudette Ward, Pengchuan Zhang. +- [README_ORIGINAL.md](README_ORIGINAL.md) -Third-party code: we use a GPU-based connected component algorithm adapted from [`cc_torch`](https://github.com/zsef123/Connected_components_PyTorch) (with its license in [`LICENSE_cctorch`](./LICENSE_cctorch)) as an optional post-processing step for the mask predictions. +## Tags -## Citing SAM 2 +### 4dim matmul -If you use SAM 2 or the SA-V dataset in your research, please use the following BibTeX entry. +main -```bibtex -@article{ravi2024sam2, - title={SAM 2: Segment Anything in Images and Videos}, - author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph}, - journal={arXiv preprint arXiv:2408.00714}, - url={https://arxiv.org/abs/2408.00714}, - year={2024} -} -``` +### 6dim matmul + +https://github.com/axinc-ai/segment-anything-2/tree/f36169e87ec302c75279fadc60cda1c3763165eb diff --git a/README_ORIGINAL.md b/README_ORIGINAL.md new file mode 100644 index 000000000..bc13c1c9a --- /dev/null +++ b/README_ORIGINAL.md @@ -0,0 +1,186 @@ +# SAM 2: Segment Anything in Images and Videos + +**[AI at Meta, FAIR](https://ai.meta.com/research/)** + +[Nikhila Ravi](https://nikhilaravi.com/), [Valentin Gabeur](https://gabeur.github.io/), [Yuan-Ting Hu](https://scholar.google.com/citations?user=E8DVVYQAAAAJ&hl=en), [Ronghang Hu](https://ronghanghu.com/), [Chaitanya Ryali](https://scholar.google.com/citations?user=4LWx24UAAAAJ&hl=en), [Tengyu Ma](https://scholar.google.com/citations?user=VeTSl0wAAAAJ&hl=en), [Haitham Khedr](https://hkhedr.com/), [Roman Rädle](https://scholar.google.de/citations?user=Tpt57v0AAAAJ&hl=en), [Chloe Rolland](https://scholar.google.com/citations?hl=fr&user=n-SnMhoAAAAJ), [Laura Gustafson](https://scholar.google.com/citations?user=c8IpF9gAAAAJ&hl=en), [Eric Mintun](https://ericmintun.github.io/), [Junting Pan](https://junting.github.io/), [Kalyan Vasudev Alwala](https://scholar.google.co.in/citations?user=m34oaWEAAAAJ&hl=en), [Nicolas Carion](https://www.nicolascarion.com/), [Chao-Yuan Wu](https://chaoyuan.org/), [Ross Girshick](https://www.rossgirshick.info/), [Piotr Dollár](https://pdollar.github.io/), [Christoph Feichtenhofer](https://feichtenhofer.github.io/) + 
+[[`Paper`](https://ai.meta.com/research/publications/sam-2-segment-anything-in-images-and-videos/)] [[`Project`](https://ai.meta.com/sam2)] [[`Demo`](https://sam2.metademolab.com/)] [[`Dataset`](https://ai.meta.com/datasets/segment-anything-video)] [[`Blog`](https://ai.meta.com/blog/segment-anything-2)] [[`BibTeX`](#citing-sam-2)] + +![SAM 2 architecture](assets/model_diagram.png?raw=true) + +**Segment Anything Model 2 (SAM 2)** is a foundation model towards solving promptable visual segmentation in images and videos. We extend SAM to video by considering images as a video with a single frame. The model design is a simple transformer architecture with streaming memory for real-time video processing. We build a model-in-the-loop data engine, which improves model and data via user interaction, to collect [**our SA-V dataset**](https://ai.meta.com/datasets/segment-anything-video), the largest video segmentation dataset to date. SAM 2 trained on our data provides strong performance across a wide range of tasks and visual domains. + +![SA-V dataset](assets/sa_v_dataset.jpg?raw=true) + +## Installation + +SAM 2 needs to be installed first before use. The code requires `python>=3.10`, as well as `torch>=2.3.1` and `torchvision>=0.18.1`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. You can install SAM 2 on a GPU machine using: + +```bash +git clone https://github.com/facebookresearch/segment-anything-2.git + +cd segment-anything-2 & pip install -e . +``` +If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu. + +To use the SAM 2 predictor and run the example notebooks, `jupyter` and `matplotlib` are required and can be installed by: + +```bash +pip install -e ".[demo]" +``` + +Note: +1. It's recommended to create a new Python environment via [Anaconda](https://www.anaconda.com/) for this installation and install PyTorch 2.3.1 (or higher) via `pip` following https://pytorch.org/. If you have a PyTorch version lower than 2.3.1 in your current environment, the installation command above will try to upgrade it to the latest PyTorch version using `pip`. +2. The step above requires compiling a custom CUDA kernel with the `nvcc` compiler. If it isn't already available on your machine, please install the [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) with a version that matches your PyTorch CUDA version. +3. If you see a message like `Failed to build the SAM 2 CUDA extension` during installation, you can ignore it and still use SAM 2 (some post-processing functionality may be limited, but it doesn't affect the results in most cases). + +Please see [`INSTALL.md`](./INSTALL.md) for FAQs on potential issues and solutions. + +## Getting Started + +### Download Checkpoints + +First, we need to download a model checkpoint. All the model checkpoints can be downloaded by running: + +```bash +cd checkpoints && \ +./download_ckpts.sh && \ +cd .. 
+``` + +or individually from: + +- [sam2_hiera_tiny.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt) +- [sam2_hiera_small.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt) +- [sam2_hiera_base_plus.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt) +- [sam2_hiera_large.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt) + +Then SAM 2 can be used in a few lines as follows for image and video prediction. + +### Image prediction + +SAM 2 has all the capabilities of [SAM](https://github.com/facebookresearch/segment-anything) on static images, and we provide image prediction APIs that closely resemble SAM for image use cases. The `SAM2ImagePredictor` class has an easy interface for image prompting. + +```python +import torch +from sam2.build_sam import build_sam2 +from sam2.sam2_image_predictor import SAM2ImagePredictor + +checkpoint = "./checkpoints/sam2_hiera_large.pt" +model_cfg = "sam2_hiera_l.yaml" +predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint)) + +with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + predictor.set_image() + masks, _, _ = predictor.predict() +``` + +Please refer to the examples in [image_predictor_example.ipynb](./notebooks/image_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/segment-anything-2/blob/main/notebooks/image_predictor_example.ipynb)) for static image use cases. + +SAM 2 also supports automatic mask generation on images just like SAM. Please see [automatic_mask_generator_example.ipynb](./notebooks/automatic_mask_generator_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/segment-anything-2/blob/main/notebooks/automatic_mask_generator_example.ipynb)) for automatic mask generation in images. + +### Video prediction + +For promptable segmentation and tracking in videos, we provide a video predictor with APIs for example to add prompts and propagate masklets throughout a video. SAM 2 supports video inference on multiple objects and uses an inference state to keep track of the interactions in each video. + +```python +import torch +from sam2.build_sam import build_sam2_video_predictor + +checkpoint = "./checkpoints/sam2_hiera_large.pt" +model_cfg = "sam2_hiera_l.yaml" +predictor = build_sam2_video_predictor(model_cfg, checkpoint) + +with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + state = predictor.init_state() + + # add new prompts and instantly get the output on the same frame + frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, ): + + # propagate the prompts to get masklets throughout the video + for frame_idx, object_ids, masks in predictor.propagate_in_video(state): + ... +``` + +Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/segment-anything-2/blob/main/notebooks/video_predictor_example.ipynb)) for details on how to add click or box prompts, make refinements, and track multiple objects in videos. + +## Load from 🤗 Hugging Face + +Alternatively, models can also be loaded from [Hugging Face](https://huggingface.co/models?search=facebook/sam2) (requires `pip install huggingface_hub`). 
+ +For image prediction: + +```python +import torch +from sam2.sam2_image_predictor import SAM2ImagePredictor + +predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large") + +with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + predictor.set_image() + masks, _, _ = predictor.predict() +``` + +For video prediction: + +```python +import torch +from sam2.sam2_video_predictor import SAM2VideoPredictor + +predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large") + +with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + state = predictor.init_state() + + # add new prompts and instantly get the output on the same frame + frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, ): + + # propagate the prompts to get masklets throughout the video + for frame_idx, object_ids, masks in predictor.propagate_in_video(state): + ... +``` + +## Model Description + +| **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** | +| :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: | +| sam2_hiera_tiny | 38.9 | 47.2 | 75.0 | 70.9 | 75.3 | +| sam2_hiera_small | 46 | 43.3 (53.0 compiled\*) | 74.9 | 71.5 | 76.4 | +| sam2_hiera_base_plus | 80.8 | 34.8 (43.8 compiled\*) | 74.7 | 72.8 | 75.8 | +| sam2_hiera_large | 224.4 | 24.2 (30.2 compiled\*) | 76.0 | 74.6 | 79.8 | + +\* Compile the model by setting `compile_image_encoder: True` in the config. + +## Segment Anything Video Dataset + +See [sav_dataset/README.md](sav_dataset/README.md) for details. + +## License + +The models are licensed under the [Apache 2.0 license](./LICENSE). Please refer to our research paper for more details on the models. + +## Contributing + +See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md). + +## Contributors + +The SAM 2 project was made possible with the help of many contributors (alphabetical): + +Karen Bergan, Daniel Bolya, Alex Bosenberg, Kai Brown, Vispi Cassod, Christopher Chedeau, Ida Cheng, Luc Dahlin, Shoubhik Debnath, Rene Martinez Doehner, Grant Gardner, Sahir Gomez, Rishi Godugu, Baishan Guo, Caleb Ho, Andrew Huang, Somya Jain, Bob Kamma, Amanda Kallet, Jake Kinney, Alexander Kirillov, Shiva Koduvayur, Devansh Kukreja, Robert Kuo, Aohan Lin, Parth Malani, Jitendra Malik, Mallika Malhotra, Miguel Martin, Alexander Miller, Sasha Mitts, William Ngan, George Orlin, Joelle Pineau, Kate Saenko, Rodrick Shepard, Azita Shokrpour, David Soofian, Jonathan Torres, Jenny Truong, Sagar Vaze, Meng Wang, Claudette Ward, Pengchuan Zhang. + +Third-party code: we use a GPU-based connected component algorithm adapted from [`cc_torch`](https://github.com/zsef123/Connected_components_PyTorch) (with its license in [`LICENSE_cctorch`](./LICENSE_cctorch)) as an optional post-processing step for the mask predictions. + +## Citing SAM 2 + +If you use SAM 2 or the SA-V dataset in your research, please use the following BibTeX entry. 
+ +```bibtex +@article{ravi2024sam2, + title={SAM 2: Segment Anything in Images and Videos}, + author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph}, + journal={arXiv preprint arXiv:2408.00714}, + url={https://arxiv.org/abs/2408.00714}, + year={2024} +} +``` diff --git a/download_onnx_models.sh b/download_onnx_models.sh new file mode 100644 index 000000000..3a38c895f --- /dev/null +++ b/download_onnx_models.sh @@ -0,0 +1,7 @@ +wget https://storage.googleapis.com/ailia-models/segment-anything-2/image_encoder_hiera_t.onnx -P ./model/ +wget https://storage.googleapis.com/ailia-models/segment-anything-2/prompt_encoder_hiera_t.onnx -P ./model/ +wget https://storage.googleapis.com/ailia-models/segment-anything-2/mask_decoder_hiera_t.onnx -P ./model/ +wget https://storage.googleapis.com/ailia-models/segment-anything-2/memory_encoder_hiera_t.onnx -P ./model/ +wget https://storage.googleapis.com/ailia-models/segment-anything-2/mlp_hiera_t.onnx -P ./model/ +wget https://storage.googleapis.com/ailia-models/segment-anything-2/memory_attention_hiera_t.onnx -P ./model/ +wget https://storage.googleapis.com/ailia-models/segment-anything-2/memory_attention_hiera_t.opt.onnx -P ./model/ \ No newline at end of file diff --git a/download_tflite_models.sh b/download_tflite_models.sh new file mode 100644 index 000000000..f295991a9 --- /dev/null +++ b/download_tflite_models.sh @@ -0,0 +1,12 @@ +wget https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/image_encoder_hiera_t.tflite -P ./model/ +wget https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/prompt_encoder_hiera_t.tflite -P ./model/ +wget https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/mask_decoder_hiera_t.tflite -P ./model/ +wget https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/mlp_hiera_t.tflite -P ./model/ +wget https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/memory_encoder_hiera_t.tflite -P ./model/ +wget https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/memory_attention_hiera_t.tflite -P ./model/ +wget https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/image_encoder_hiera_t_512.tflite -P ./model/ +wget https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/prompt_encoder_hiera_t_512.tflite -P ./model/ +wget https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/mask_decoder_hiera_t_512.tflite -P ./model/ +wget https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/mlp_hiera_t_512.tflite -P ./model/ +wget https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/memory_encoder_hiera_t_512.tflite -P ./model/ +wget https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/memory_attention_hiera_t_512.tflite -P ./model/ \ No newline at end of file diff --git a/export_image_predictor.py b/export_image_predictor.py new file mode 100644 index 000000000..fb202a3ee --- /dev/null +++ b/export_image_predictor.py @@ -0,0 +1,145 @@ +# Export image encoder and prompt encoder and mask decoder +# Implemented by ax Inc. 
2024 + +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--model_id', default="hiera_t", choices=["hiera_l", "hiera_b+", "hiera_s", "hiera_t"]) +parser.add_argument('--framework', default="onnx", choices=["onnx", "tflite", "torch"]) +parser.add_argument('--accuracy', default="float", choices=["float", "int8"]) +parser.add_argument('--mode', default="both", choices=["both", "import", "export"]) +parser.add_argument('--image_size', default=1024, type=int, choices=[512, 1024]) +args = parser.parse_args() + +import os +import numpy as np +import torch +import matplotlib.pyplot as plt +from PIL import Image + +from sam2.build_sam import build_sam2 +from sam2.sam2_image_predictor import SAM2ImagePredictor + +# output +os.makedirs("output", exist_ok=True) +os.makedirs("model", exist_ok=True) + +# export settings +export_to_onnx_image_encoder = args.framework == "onnx" and (args.mode=="export" or args.mode=="both") +export_to_onnx_mask_decoder = args.framework == "onnx" and (args.mode=="export" or args.mode=="both") +import_from_onnx = args.framework == "onnx" and (args.mode=="import" or args.mode=="both") + +export_to_tflite_image_encoder = args.framework == "tflite" and (args.mode=="export" or args.mode=="both") +export_to_tflite_mask_decoder = args.framework == "tflite" and (args.mode=="export" or args.mode=="both") +import_from_tflite = args.framework == "tflite" and (args.mode=="import" or args.mode=="both") + +tflite_int8 = args.accuracy == "int8" + +# export PJRT_DEVICE=CPU + +# model settings +model_id = args.model_id +if model_id == "hiera_l": + sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt" + model_cfg = "sam2_hiera_l.yaml" +elif model_id == "hiera_b+": + sam2_checkpoint = "./checkpoints/sam2_hiera_base_plus.pt" + model_cfg = "sam2_hiera_b+.yaml" +elif model_id == "hiera_s": + sam2_checkpoint = "./checkpoints/sam2_hiera_small.pt" + model_cfg = "sam2_hiera_s.yaml" +elif model_id == "hiera_t": + sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt" + model_cfg = "sam2_hiera_t.yaml" +else: + print("unknown model id") + exit() + +# resolution settings +if args.image_size == 512: + model_id = model_id + "_512" + +# use cpu for export +device = torch.device("cpu") + +# utility +np.random.seed(3) + +def show_mask(mask, ax, random_color=False, borders = True): + if random_color: + color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) + else: + color = np.array([30/255, 144/255, 255/255, 0.6]) + h, w = mask.shape[-2:] + mask = mask.astype(np.uint8) + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + if borders: + import cv2 + contours, _ = cv2.findContours(mask,cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + # Try to smooth contours + contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours] + mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2) + ax.imshow(mask_image) + +def show_points(coords, labels, ax, marker_size=375): + pos_points = coords[labels==1] + neg_points = coords[labels==0] + ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) + ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) + +def show_box(box, ax): + x0, y0 = box[0], box[1] + w, h = box[2] - box[0], box[3] - box[1] + ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)) + +def show_masks(image, masks, scores, point_coords=None, 
box_coords=None, input_labels=None, borders=True, model_id=model_id): + for i, (mask, score) in enumerate(zip(masks, scores)): + plt.figure(figsize=(10, 10)) + plt.imshow(image) + show_mask(mask, plt.gca(), borders=borders) + if point_coords is not None: + assert input_labels is not None + show_points(point_coords, input_labels, plt.gca()) + if box_coords is not None: + # boxes + show_box(box_coords, plt.gca()) + if len(scores) > 1: + plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18) + plt.axis('off') + #plt.show() + plt.savefig(f'output/output{i+1}_'+model_id+'.png') + +# logic +image = Image.open('notebooks/images/truck.jpg') +image = np.array(image.convert("RGB")) + +sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device, image_size=args.image_size) + +predictor = SAM2ImagePredictor(sam2_model) + +predictor.set_image(image, export_to_onnx = export_to_onnx_image_encoder, + export_to_tflite = export_to_tflite_image_encoder, + import_from_onnx = import_from_onnx, import_from_tflite = import_from_tflite, + tflite_int8 = tflite_int8, model_id = model_id) + +input_point = np.array([[500, 375]]) +input_label = np.array([1]) + +masks, scores, logits = predictor.predict( + point_coords=input_point, + point_labels=input_label, + multimask_output=True, + export_to_onnx=export_to_onnx_mask_decoder, + export_to_tflite=export_to_tflite_mask_decoder, + import_from_onnx=import_from_onnx, + import_from_tflite=import_from_tflite, + tflite_int8=tflite_int8, + model_id=model_id +) +sorted_ind = np.argsort(scores)[::-1] +masks = masks[sorted_ind] +scores = scores[sorted_ind] +logits = logits[sorted_ind] + +show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label, borders=True, model_id=model_id) + +print("Success!") \ No newline at end of file diff --git a/export_video_predictor.py b/export_video_predictor.py new file mode 100644 index 000000000..8856b3173 --- /dev/null +++ b/export_video_predictor.py @@ -0,0 +1,150 @@ +# Export memory attention and memory encoder +# Implemented by ax Inc. 
2024 + +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--model_id', default="hiera_t", choices=["hiera_l", "hiera_b+", "hiera_s", "hiera_t"]) +parser.add_argument('--framework', default="onnx", choices=["onnx", "tflite", "torch"]) +parser.add_argument('--accuracy', default="float", choices=["float", "int8"]) +parser.add_argument('--mode', default="both", choices=["both", "import", "export"]) +parser.add_argument('--image_size', default=1024, type=int, choices=[512, 1024]) +args = parser.parse_args() + +import os +import numpy as np +import torch +import matplotlib.pyplot as plt +from PIL import Image + +# output +os.makedirs("output", exist_ok=True) +os.makedirs("model", exist_ok=True) + +# export settings +model_id = args.model_id + +export_to_onnx = args.framework=="onnx" and (args.mode=="export" or args.mode=="both") +import_from_onnx = args.framework=="onnx" and (args.mode=="import" or args.mode=="both") +export_to_tflite = args.framework=="tflite" and (args.mode=="export" or args.mode=="both") +import_from_tflite = args.framework=="tflite" and (args.mode=="import" or args.mode=="both") + +# import +if model_id == "hiera_l": + model_cfg = "sam2_hiera_l.yaml" + sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt" +elif model_id == "hiera_s": + model_cfg = "sam2_hiera_s.yaml" + sam2_checkpoint = "./checkpoints/sam2_hiera_small.pt" +elif model_id == "hiera_b+": + model_cfg = "sam2_hiera_b+.yaml" + sam2_checkpoint = "./checkpoints/sam2_hiera_base_plus.pt" +elif model_id == "hiera_t": + model_cfg = "sam2_hiera_t.yaml" + sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt" +else: + raise("unknown model type") + +# resolution settings +if args.image_size == 512: + model_id = model_id + "_512" + +device = torch.device("cpu") +print(f"using device: {device}") + +from sam2.build_sam import build_sam2_video_predictor + +predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device, image_size=args.image_size) + +if export_to_tflite or import_from_tflite: + predictor.set_num_maskmem(num_maskmem=1, max_obj_ptrs_in_encoder=1) + +def show_mask(mask, ax, obj_id=None, random_color=False): + if random_color: + color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) + else: + cmap = plt.get_cmap("tab10") + cmap_idx = 0 if obj_id is None else obj_id + color = np.array([*cmap(cmap_idx)[:3], 0.6]) + h, w = mask.shape[-2:] + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + ax.imshow(mask_image) + + +def show_points(coords, labels, ax, marker_size=200): + pos_points = coords[labels==1] + neg_points = coords[labels==0] + ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) + ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) + + +def show_box(box, ax): + x0, y0 = box[0], box[1] + w, h = box[2] - box[0], box[3] - box[1] + ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)) + +video_dir = "./notebooks/videos/bedroom_short" + +# scan all the JPEG frame names in this directory +frame_names = [ + p for p in os.listdir(video_dir) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] +] +frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + +inference_state = predictor.init_state(video_path=video_dir, import_from_onnx=import_from_onnx, import_from_tflite=import_from_tflite, model_id=model_id) +predictor.reset_state(inference_state) + 
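+# The flow below mirrors the original SAM 2 video predictor example: add point prompts
+# on the first frame, then propagate the resulting masklet through the remaining frames.
+# The tflite branch below uses a single positive click, while the other frameworks use two.
+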
+ann_frame_idx = 0 # the frame index we interact with +ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers) + +# Let's add a 2nd positive click at (x, y) = (250, 220) to refine the mask +# sending all clicks (and their labels) to `add_new_points_or_box` +# for labels, `1` means positive click and `0` means negative click +if args.framework == "tflite": + points = np.array([[210, 350]], dtype=np.float32) + labels = np.array([1], np.int32) +else: + points = np.array([[210, 350], [250, 220]], dtype=np.float32) + labels = np.array([1, 1], np.int32) + +_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( + inference_state=inference_state, + frame_idx=ann_frame_idx, + obj_id=ann_obj_id, + points=points, + labels=labels, + import_from_onnx=import_from_onnx, + export_to_onnx=export_to_onnx, + import_from_tflite=import_from_tflite, + export_to_tflite=export_to_tflite, + model_id=model_id +) + +# show the results on the current (interacted) frame +plt.figure(figsize=(9, 6)) +plt.title(f"frame {ann_frame_idx}") +plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx]))) +show_points(points, labels, plt.gca()) +show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0]) +#plt.show() +plt.savefig(f'output/video_'+model_id+'.png') + +# run propagation throughout the video and collect the results in a dict +video_segments = {} # video_segments contains the per-frame segmentation results +for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state, import_from_onnx=import_from_onnx, export_to_onnx=export_to_onnx, import_from_tflite=import_from_tflite, export_to_tflite=export_to_tflite, model_id=model_id): + video_segments[out_frame_idx] = { + out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() + for i, out_obj_id in enumerate(out_obj_ids) + } + +# render the segmentation results every few frames +vis_frame_stride = 1 +plt.close("all") +for out_frame_idx in range(0, len(frame_names), vis_frame_stride): + plt.figure(figsize=(6, 4)) + plt.title(f"frame {out_frame_idx}") + plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx]))) + for out_obj_id, out_mask in video_segments[out_frame_idx].items(): + show_mask(out_mask, plt.gca(), obj_id=out_obj_id) + #plt.show() + plt.savefig(f'output/video{out_frame_idx+1}_'+model_id+'.png') diff --git a/notebooks/videos/bedroom_short/00000.jpg b/notebooks/videos/bedroom_short/00000.jpg new file mode 100644 index 000000000..26c7b234c Binary files /dev/null and b/notebooks/videos/bedroom_short/00000.jpg differ diff --git a/notebooks/videos/bedroom_short/00001.jpg b/notebooks/videos/bedroom_short/00001.jpg new file mode 100644 index 000000000..2350255b6 Binary files /dev/null and b/notebooks/videos/bedroom_short/00001.jpg differ diff --git a/notebooks/videos/bedroom_short/00002.jpg b/notebooks/videos/bedroom_short/00002.jpg new file mode 100644 index 000000000..db3da8f11 Binary files /dev/null and b/notebooks/videos/bedroom_short/00002.jpg differ diff --git a/notebooks/videos/bedroom_short/00003.jpg b/notebooks/videos/bedroom_short/00003.jpg new file mode 100644 index 000000000..1066831d2 Binary files /dev/null and b/notebooks/videos/bedroom_short/00003.jpg differ diff --git a/notebooks/videos/bedroom_short/00004.jpg b/notebooks/videos/bedroom_short/00004.jpg new file mode 100644 index 000000000..b3da8f794 Binary files /dev/null and b/notebooks/videos/bedroom_short/00004.jpg differ diff --git a/sam2/build_sam.py 
b/sam2/build_sam.py index 3a29eda3c..69d165aef 100644 --- a/sam2/build_sam.py +++ b/sam2/build_sam.py @@ -19,6 +19,7 @@ def build_sam2( mode="eval", hydra_overrides_extra=[], apply_postprocessing=True, + image_size=1024, **kwargs, ): @@ -32,6 +33,7 @@ def build_sam2( ] # Read config and init model cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) + cfg.model.image_size = image_size OmegaConf.resolve(cfg) model = instantiate(cfg.model, _recursive_=True) _load_checkpoint(model, ckpt_path) @@ -48,6 +50,7 @@ def build_sam2_video_predictor( mode="eval", hydra_overrides_extra=[], apply_postprocessing=True, + image_size=1024, **kwargs, ): hydra_overrides = [ @@ -69,6 +72,7 @@ def build_sam2_video_predictor( # Read config and init model cfg = compose(config_name=config_file, overrides=hydra_overrides) + cfg.model.image_size = image_size OmegaConf.resolve(cfg) model = instantiate(cfg.model, _recursive_=True) _load_checkpoint(model, ckpt_path) diff --git a/sam2/modeling/memory_attention.py b/sam2/modeling/memory_attention.py index 0b07f9d87..84c23764f 100644 --- a/sam2/modeling/memory_attention.py +++ b/sam2/modeling/memory_attention.py @@ -59,23 +59,19 @@ def _forward_sa(self, tgt, query_pos): # Self-Attention tgt2 = self.norm1(tgt) q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 - tgt2 = self.self_attn(q, k, v=tgt2) + tgt2 = self.self_attn.self_attn(q, k = k, v = tgt2) tgt = tgt + self.dropout1(tgt2) return tgt - def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): - kwds = {} - if num_k_exclude_rope > 0: - assert isinstance(self.cross_attn_image, RoPEAttention) - kwds = {"num_k_exclude_rope": num_k_exclude_rope} - + def _forward_ca(self, tgt, memory_1, memory_2, query_pos, pos_1, pos_2): # Cross-Attention tgt2 = self.norm2(tgt) - tgt2 = self.cross_attn_image( + tgt2 = self.cross_attn_image.cross_attn( q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, - k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, - v=memory, - **kwds, + k_1=memory_1 + pos_1 if self.pos_enc_at_cross_attn_keys else memory_1, + v_1=memory_1, + k_2=memory_2 + pos_2 if self.pos_enc_at_cross_attn_keys else memory_2, + v_2=memory_2 ) tgt = tgt + self.dropout2(tgt2) return tgt @@ -83,15 +79,16 @@ def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): def forward( self, tgt, - memory, - pos: Optional[Tensor] = None, + memory_1, + memory_2, + pos_1: Optional[Tensor] = None, + pos_2: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, - num_k_exclude_rope: int = 0, ) -> torch.Tensor: # Self-Attn, Cross-Attn tgt = self._forward_sa(tgt, query_pos) - tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) + tgt = self._forward_ca(tgt, memory_1, memory_2, query_pos, pos_1, pos_2) # MLP tgt2 = self.norm3(tgt) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) @@ -116,13 +113,40 @@ def __init__( self.pos_enc_at_input = pos_enc_at_input self.batch_first = batch_first + def allocate_rope_attention_weight( + self, + curr: torch.Tensor, # self-attention inputs + curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs + image_size = 1024, + ): + if isinstance(curr, list): + assert isinstance(curr_pos, list) + assert len(curr) == len(curr_pos) == 1 + curr, curr_pos = ( + curr[0], + curr_pos[0], + ) + + output = curr + + if self.batch_first: + # Convert to batch first + output = output.transpose(0, 1) + + for layer in self.layers: + if isinstance(layer.cross_attn_image, RoPEAttention): + 
layer.cross_attn_image.allocate_rope_attention_weight(output, image_size = image_size) + if isinstance(layer.self_attn, RoPEAttention): + layer.self_attn.allocate_rope_attention_weight(output, image_size = image_size) + def forward( self, curr: torch.Tensor, # self-attention inputs - memory: torch.Tensor, # cross-attention inputs + memory_1: torch.Tensor, # cross-attention inputs + memory_2: torch.Tensor, # cross-attention inputs curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs - memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs - num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* + memory_pos_1: Optional[Tensor] = None, # pos_enc for cross-attention inputs + memory_pos_2: Optional[Tensor] = None, # pos_enc for cross-attention inputs ): if isinstance(curr, list): assert isinstance(curr_pos, list) @@ -133,7 +157,7 @@ def forward( ) assert ( - curr.shape[1] == memory.shape[1] + curr.shape[1] == memory_1.shape[1] ), "Batch size must be the same for curr and memory" output = curr @@ -144,20 +168,18 @@ def forward( # Convert to batch first output = output.transpose(0, 1) curr_pos = curr_pos.transpose(0, 1) - memory = memory.transpose(0, 1) - memory_pos = memory_pos.transpose(0, 1) + memory_1 = memory_1.transpose(0, 1) + memory_2 = memory_2.transpose(0, 1) + memory_pos_1 = memory_pos_1.transpose(0, 1) + memory_pos_2 = memory_pos_2.transpose(0, 1) for layer in self.layers: - kwds = {} - if isinstance(layer.cross_attn_image, RoPEAttention): - kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} - output = layer( tgt=output, - memory=memory, - pos=memory_pos, - query_pos=curr_pos, - **kwds, + memory_1=memory_1, + memory_2=memory_2, + pos_1=memory_pos_1, + pos_2=memory_pos_2, ) normed_output = self.norm(output) diff --git a/sam2/modeling/memory_encoder.py b/sam2/modeling/memory_encoder.py index f60202dfa..c2fe6f340 100644 --- a/sam2/modeling/memory_encoder.py +++ b/sam2/modeling/memory_encoder.py @@ -159,8 +159,9 @@ def forward( self, pix_feat: torch.Tensor, masks: torch.Tensor, - skip_mask_sigmoid: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: + skip_mask_sigmoid = True # Fix for tflite + ## Process masks # sigmoid, so that less domain shift from gt masks which are bool if not skip_mask_sigmoid: @@ -178,4 +179,5 @@ def forward( pos = self.position_encoding(x).to(x.dtype) - return {"vision_features": x, "vision_pos_enc": [pos]} + return x, pos + diff --git a/sam2/modeling/position_encoding.py b/sam2/modeling/position_encoding.py index 52ac22674..41269d6b5 100644 --- a/sam2/modeling/position_encoding.py +++ b/sam2/modeling/position_encoding.py @@ -198,16 +198,9 @@ def apply_rotary_enc( repeat_freqs_k: bool = False, ): xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = ( - torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - if xk.shape[-2] != 0 - else None - ) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) freqs_cis = reshape_for_broadcast(freqs_cis, xq_) xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - if xk_ is None: - # no keys to rotate, due to dropout - return xq_out.type_as(xq).to(xq.device), xk # repeat freqs along seq_len dim to match k seq_len if repeat_freqs_k: r = xk_.shape[-2] // xq_.shape[-2] @@ -219,3 +212,91 @@ def apply_rotary_enc( freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) + + +# 
Matrix version of rotary enc +# https://github.com/facebookresearch/segment-anything-2/issues/186 + +def get_rotation_matrices(dim, end_x, end_y, theta=10000.0, device=None, dtype=None): + + powers = torch.linspace(0, 1, 1 + (dim // 4), device=device, dtype=dtype)[:-1] + base_angles = torch.pow(theta, -powers) + + end_x, end_y = int(end_x), int(end_y) + x_mults = torch.arange(end_x, device=device, dtype=dtype).repeat(end_y) + y_mults = torch.arange(end_y, device=device, dtype=dtype).repeat_interleave(end_x) + angles_xy = (torch.outer(mults, base_angles) for mults in (x_mults, y_mults)) + + rotmats_list = [] + for angles in angles_xy: + sterm, cterm = torch.sin(-angles), torch.cos(-angles) + rotmat = torch.stack( + [ + torch.stack([cterm, -sterm], dim=-1), + torch.stack([sterm, cterm], dim=-1), + ], + dim=-1, + ) + rotmats_list.append(rotmat) + + return torch.cat(rotmats_list, dim=1).unsqueeze(0).unsqueeze(0) + + +def apply_rotary_matenc(xq, xk, rotmats, repeat_freqs_k=False): + # オリジナル実装 (6次元テンソル処理) + #bq, hq, nq, cq = xq.shape + #bk, hk, nk, ck = xk.shape + #q_out = torch.matmul(rotmats, xq.reshape(bq, hq, nq, cq // 2, 2, 1)).flatten(3) + #k_rotmat = rotmats.repeat(1, 1, nk // nq, 1, 1, 1) if repeat_freqs_k else rotmats + #k_out = torch.matmul(k_rotmat, xk.reshape(bk, hk, nk, ck // 2, 2, 1)).flatten(3) + + # tfliteでは4次元テンソルまでしか扱えないのでバッチサイズに制約をかける + + bq, hq, nq, cq = xq.shape + #torch._check_is_size(bq) + #torch._check_is_size(hq) + #torch._check_is_size(nq) + #torch._check_is_size(cq) + #torch._check(bq == 1) # for dynamo trace + #torch._check(hq == 1) # for dynamo trace + #torch._check(cq == 256) # for dynamo trace + + #print(rotmats.shape) + + q_rotmat = rotmats.reshape(4096, 128, 2, 2) + q_out = torch.matmul(q_rotmat, xq.reshape(nq, 128, 2, 1)).reshape(1, 1, 4096, 256) + #print(q_out.shape) + + bk, hk, nk, ck = xk.shape + k_rotmat = q_rotmat.repeat(nk // 4096, 1, 1, 1)# if repeat_freqs_k else rotmats # for tflite trace, repeat_freqs_k == Falseの場合は nk // nq == 1 なのでrepeatを常に呼び出しても等価になる + + bk, hk, nk, ck = xk.shape + #torch._check_is_size(bq == 1) + #torch._check_is_size(hq == 1) + #torch._check(ck == 256) + + #torch._check(xk.size(3) == 256) + + k_in = xk.reshape(nk, 128, 2, 1) + #k_in = k_in[:k_rotmat.shape[0], :, :, :] + k_out = torch.matmul(k_rotmat, k_in).reshape(1, 1, nk, 256) + + #print("k_rotmat", k_rotmat.shape) + #print("k_in", k_in.shape) + #print("k_out", k_out.shape) + + return q_out, k_out + + +def apply_rotary_matenc_512(xq, xk, rotmats, repeat_freqs_k=False): + bq, hq, nq, cq = xq.shape + q_rotmat = rotmats.reshape(1024, 128, 2, 2) + q_out = torch.matmul(q_rotmat, xq.reshape(nq, 128, 2, 1)).reshape(1, 1, 1024, 256) + + bk, hk, nk, ck = xk.shape + k_rotmat = q_rotmat.repeat(nk // 1024, 1, 1, 1) + + bk, hk, nk, ck = xk.shape + k_in = xk.reshape(nk, 128, 2, 1) + k_out = torch.matmul(k_rotmat, k_in).reshape(1, 1, nk, 256) + return q_out, k_out diff --git a/sam2/modeling/sam/mask_decoder.py b/sam2/modeling/sam/mask_decoder.py index b7c7dfdb3..43276f94f 100644 --- a/sam2/modeling/sam/mask_decoder.py +++ b/sam2/modeling/sam/mask_decoder.py @@ -107,7 +107,8 @@ def __init__( self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh - def forward( + # デフォルト実装 + def forward_normal( self, image_embeddings: torch.Tensor, image_pe: torch.Tensor, @@ -115,7 +116,8 @@ def forward( dense_prompt_embeddings: torch.Tensor, multimask_output: bool, repeat_image: bool, - high_res_features: 
Optional[List[torch.Tensor]] = None, + high_res_features1: Optional[torch.Tensor] = None, + high_res_features2: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Predict masks given image and prompt embeddings. @@ -139,7 +141,8 @@ def forward( sparse_prompt_embeddings=sparse_prompt_embeddings, dense_prompt_embeddings=dense_prompt_embeddings, repeat_image=repeat_image, - high_res_features=high_res_features, + high_res_features1=high_res_features1, + high_res_features2=high_res_features2, ) # Select the correct mask or masks for output @@ -165,6 +168,59 @@ def forward( # Prepare output return masks, iou_pred, sam_tokens_out, object_score_logits + # ONNXに変換するために推論とポスト処理を分離するバージョン + def forward_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + repeat_image: bool, + high_res_features1: Optional[torch.Tensor] = None, + high_res_features2: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + repeat_image=repeat_image, + high_res_features1=high_res_features1, + high_res_features2=high_res_features2, + ) + return masks, iou_pred, mask_tokens_out, object_score_logits + + def forward_postprocess( + self, + masks: torch.Tensor, + iou_pred: torch.Tensor, + mask_tokens_out: torch.Tensor, + object_score_logits: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Select the correct mask or masks for output + if multimask_output: + masks = masks[:, 1:, :, :] + iou_pred = iou_pred[:, 1:] + elif self.dynamic_multimask_via_stability and not self.training: + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + masks = masks[:, 0:1, :, :] + iou_pred = iou_pred[:, 0:1] + + if multimask_output and self.use_multimask_token_for_obj_ptr: + sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape + else: + # Take the mask output token. Here we *always* use the token for single mask output. + # At test time, even if we track after 1-click (and using multimask_output=True), + # we still take the single mask token here. The rationale is that we always track + # after multiple clicks during training, so the past tokens seen during training + # are always the single mask token (and we'll let it be the object-memory token). + sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape + + # Prepare output + return masks, iou_pred, sam_tokens_out, object_score_logits + def predict_masks( self, image_embeddings: torch.Tensor, @@ -172,7 +228,8 @@ def predict_masks( sparse_prompt_embeddings: torch.Tensor, dense_prompt_embeddings: torch.Tensor, repeat_image: bool, - high_res_features: Optional[List[torch.Tensor]] = None, + high_res_features1: Optional[torch.Tensor] = None, + high_res_features2: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Predicts masks. 
See 'forward' for more details.""" # Concatenate output tokens @@ -192,7 +249,7 @@ def predict_masks( [self.iou_token.weight, self.mask_tokens.weight], dim=0 ) output_tokens = output_tokens.unsqueeze(0).expand( - sparse_prompt_embeddings.size(0), -1, -1 + sparse_prompt_embeddings.shape[0], -1, -1 ) tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) @@ -204,9 +261,14 @@ def predict_masks( src = image_embeddings src = src + dense_prompt_embeddings assert ( - image_pe.size(0) == 1 + image_pe.shape[0] == 1 ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" - pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + + pos_src = torch.tensor((tokens.shape[0], image_pe.shape[1], image_pe.shape[2])) + pos_src = image_pe # batch broad cast + + #pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) # one_hotが生成responseえる + b, c, h, w = src.shape # Run the transformer @@ -220,7 +282,7 @@ def predict_masks( upscaled_embedding = self.output_upscaling(src) else: dc1, ln1, act1, dc2, act2 = self.output_upscaling - feat_s0, feat_s1 = high_res_features + feat_s0, feat_s1 = high_res_features1, high_res_features2 upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) diff --git a/sam2/modeling/sam/prompt_encoder.py b/sam2/modeling/sam/prompt_encoder.py index 6b3bbb95b..49c45b7af 100644 --- a/sam2/modeling/sam/prompt_encoder.py +++ b/sam2/modeling/sam/prompt_encoder.py @@ -92,12 +92,36 @@ def _embed_points( point_embedding = self.pe_layer.forward_with_coords( points, self.input_image_size ) - point_embedding[labels == -1] = 0.0 - point_embedding[labels == -1] += self.not_a_point_embed.weight - point_embedding[labels == 0] += self.point_embeddings[0].weight - point_embedding[labels == 1] += self.point_embeddings[1].weight - point_embedding[labels == 2] += self.point_embeddings[2].weight - point_embedding[labels == 3] += self.point_embeddings[3].weight + + # こっちだとonnxでbroadcast error + #point_embedding[labels == -1] = 0.0 + #point_embedding[labels == -1] += self.not_a_point_embed.weight + + # こっちだとonnxで動くが、tfliteでうごかない + #point_embedding[labels == -1] = self.not_a_point_embed.weight + + #point_embedding[labels == 0] += self.point_embeddings[0].weight + #point_embedding[labels == 1] += self.point_embeddings[1].weight + #point_embedding[labels == 2] += self.point_embeddings[2].weight + #point_embedding[labels == 3] += self.point_embeddings[3].weight + + # こっちだと、tfliteでも動く + + # Create the index mask for each label + labels = labels.int() + mask_neg1 = (labels == -1).unsqueeze(-1).expand_as(point_embedding) + mask_0 = (labels == 0).unsqueeze(-1).expand_as(point_embedding) + mask_1 = (labels == 1).unsqueeze(-1).expand_as(point_embedding) + mask_2 = (labels == 2).unsqueeze(-1).expand_as(point_embedding) + mask_3 = (labels == 3).unsqueeze(-1).expand_as(point_embedding) + + # Apply the weights according to the mask + point_embedding = torch.where(mask_neg1, self.not_a_point_embed.weight.expand_as(point_embedding), point_embedding) + point_embedding = torch.where(mask_0, point_embedding + self.point_embeddings[0].weight.expand_as(point_embedding), point_embedding) + point_embedding = torch.where(mask_1, point_embedding + self.point_embeddings[1].weight.expand_as(point_embedding), point_embedding) + point_embedding = torch.where(mask_2, point_embedding + self.point_embeddings[2].weight.expand_as(point_embedding), point_embedding) + point_embedding = torch.where(mask_3, point_embedding + 
self.point_embeddings[3].weight.expand_as(point_embedding), point_embedding) + return point_embedding def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: @@ -118,17 +142,15 @@ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: def _get_batch_size( self, - points: Optional[Tuple[torch.Tensor, torch.Tensor]], - boxes: Optional[torch.Tensor], + coords: Optional[torch.Tensor], + labels: Optional[torch.Tensor], masks: Optional[torch.Tensor], ) -> int: """ Gets the batch size of the output given the batch size of the input prompts. """ - if points is not None: - return points[0].shape[0] - elif boxes is not None: - return boxes.shape[0] + if coords is not None and labels is not None: + return coords.shape[0] elif masks is not None: return masks.shape[0] else: @@ -139,9 +161,10 @@ def _get_device(self) -> torch.device: def forward( self, - points: Optional[Tuple[torch.Tensor, torch.Tensor]], - boxes: Optional[torch.Tensor], + coords: Optional[torch.Tensor], + labels: Optional[torch.Tensor], masks: Optional[torch.Tensor], + masks_enable: Optional[torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: """ Embeds different types of prompts, returning both sparse and dense @@ -160,23 +183,22 @@ def forward( torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W) """ - bs = self._get_batch_size(points, boxes, masks) + if coords is None or labels is None: + raise("onnx not supported coords is None") + + bs = self._get_batch_size(coords, labels, masks) sparse_embeddings = torch.empty( (bs, 0, self.embed_dim), device=self._get_device() ) - if points is not None: - coords, labels = points - point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) - sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) - if boxes is not None: - box_embeddings = self._embed_boxes(boxes) - sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) - - if masks is not None: - dense_embeddings = self._embed_masks(masks) - else: - dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( - bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] - ) - return sparse_embeddings, dense_embeddings + point_embeddings = self._embed_points(coords, labels, pad=True) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + + dense_embeddings1 = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + dense_embeddings2 = self._embed_masks(masks) + + dense_embeddings = torch.where(masks_enable[0] == 1, dense_embeddings2, dense_embeddings1) + + return sparse_embeddings, dense_embeddings, self.get_dense_pe() diff --git a/sam2/modeling/sam/transformer.py b/sam2/modeling/sam/transformer.py index b5b6fa2f8..42b8bae3c 100644 --- a/sam2/modeling/sam/transformer.py +++ b/sam2/modeling/sam/transformer.py @@ -15,6 +15,7 @@ from torch import nn, Tensor from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis +from sam2.modeling.position_encoding import apply_rotary_matenc, get_rotation_matrices, apply_rotary_matenc_512 from sam2.modeling.sam2_utils import MLP from sam2.utils.misc import get_sdpa_settings @@ -24,6 +25,8 @@ # A fallback setting to allow all available kernels if Flash Attention fails ALLOW_ALL_KERNELS = False +# Use matrix version of rotrary enc +USE_MAT_ROTARY_ENC = True def sdp_kernel_context(dropout_p): """ @@ -265,17 +268,18 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> 
Tensor: dropout_p = self.dropout_p if self.training else 0.0 # Attention - try: - with sdp_kernel_context(dropout_p): - out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) - except Exception as e: + #try: + # with sdp_kernel_context(dropout_p): + # out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + #except Exception as e: + if True: # Fall back to all kernels if the Flash attention kernel fails - warnings.warn( - f"Flash Attention kernel failed due to: {e}\nFalling back to all available " - f"kernels for scaled_dot_product_attention (which may have a slower speed).", - category=UserWarning, - stacklevel=2, - ) + #warnings.warn( + # f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + # f"kernels for scaled_dot_product_attention (which may have a slower speed).", + # category=UserWarning, + # stacklevel=2, + #) global ALLOW_ALL_KERNELS ALLOW_ALL_KERNELS = True out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) @@ -301,15 +305,34 @@ def __init__( ): super().__init__(*args, **kwargs) - self.compute_cis = partial( - compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta - ) - freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) - self.freqs_cis = freqs_cis self.rope_k_repeat = rope_k_repeat - def forward( - self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0 + if USE_MAT_ROTARY_ENC: + rotmats = get_rotation_matrices(dim=self.internal_dim // self.num_heads, end_x=feat_sizes[0], end_y=feat_sizes[1], theta=rope_theta) + self.rotmats = rotmats + self.rope_theta = rope_theta + else: + self.compute_cis = partial( + compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta + ) + freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) + self.freqs_cis = freqs_cis + + def allocate_rope_attention_weight( + self, q: Tensor, image_size + ): + # prepare weight of rope attention for dynamo export + w = h = math.sqrt(q.shape[-2]) + if USE_MAT_ROTARY_ENC: + if self.rotmats.shape[2] != q.shape[-2]: + self.rotmats = get_rotation_matrices(dim=self.internal_dim // self.num_heads, end_x=w, end_y=h, theta=self.rope_theta) + else: + if self.freqs_cis.shape[0] != q.shape[-2]: + self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) + self.is_512 = image_size == 512 + + def self_attn( + self, q: Tensor, k: Tensor, v: Tensor ) -> Tensor: # Input projections q = self.q_proj(q) @@ -322,34 +345,141 @@ def forward( v = self._separate_heads(v, self.num_heads) # Apply rotary position encoding - w = h = math.sqrt(q.shape[-2]) - self.freqs_cis = self.freqs_cis.to(q.device) - if self.freqs_cis.shape[0] != q.shape[-2]: - self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) - if q.shape[-2] != k.shape[-2]: - assert self.rope_k_repeat - - num_k_rope = k.size(-2) - num_k_exclude_rope - q, k[:, :, :num_k_rope] = apply_rotary_enc( - q, - k[:, :, :num_k_rope], - freqs_cis=self.freqs_cis, - repeat_freqs_k=self.rope_k_repeat, - ) + #if USE_MAT_ROTARY_ENC: + # #self.rotmats = self.rotmats.to(q.device) + # if self.rotmats.shape[2] != q.shape[-2]: + # raise("rotmat shape error " + str(self.rotmats.shape[2]) + " " + str(q.shape[-2])) + #else: + # #self.freqs_cis = self.freqs_cis.to(q.device) + # if self.freqs_cis.shape[0] != q.shape[-2]: + # raise("freqs_cis shape error " + str(self.freqs_cis.shape[0]) + " " + str(q.shape[-2])) + + #if q.shape[-2] != k.shape[-2]: + # assert self.rope_k_repeat + + if USE_MAT_ROTARY_ENC: + if self.is_512: + q, k = 
apply_rotary_matenc_512( + q, + k, + rotmats=self.rotmats, + repeat_freqs_k=self.rope_k_repeat, + ) + else: + q, k = apply_rotary_matenc( + q, + k, + rotmats=self.rotmats, + repeat_freqs_k=self.rope_k_repeat, + ) + else: + q, k = apply_rotary_enc( + q, + k, + freqs_cis=self.freqs_cis, + repeat_freqs_k=self.rope_k_repeat, + ) dropout_p = self.dropout_p if self.training else 0.0 # Attention - try: - with sdp_kernel_context(dropout_p): - out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) - except Exception as e: + #try: + # with sdp_kernel_context(dropout_p): + # out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + #except Exception as e: + if True: # Fall back to all kernels if the Flash attention kernel fails - warnings.warn( - f"Flash Attention kernel failed due to: {e}\nFalling back to all available " - f"kernels for scaled_dot_product_attention (which may have a slower speed).", - category=UserWarning, - stacklevel=2, + #warnings.warn( + # f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + # f"kernels for scaled_dot_product_attention (which may have a slower speed).", + # category=UserWarning, + # stacklevel=2, + #) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + def cross_attn( + self, q: Tensor, k_1: Tensor, v_1: Tensor, k_2: Tensor = None, v_2: Tensor = None + ) -> Tensor: + # Input projections + q = self.q_proj(q) + k_1 = self.k_proj(k_1) + v_1 = self.v_proj(v_1) + k_2 = self.k_proj(k_2) + v_2 = self.v_proj(v_2) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k_1 = self._separate_heads(k_1, self.num_heads) + v_1 = self._separate_heads(v_1, self.num_heads) + k_2 = self._separate_heads(k_2, self.num_heads) + v_2 = self._separate_heads(v_2, self.num_heads) + + # Apply rotary position encoding + #if USE_MAT_ROTARY_ENC: + # #self.rotmats = self.rotmats.to(q.device) + # if self.rotmats.shape[2] != q.shape[-2]: + # raise("rotmat shape error " + str(self.rotmats.shape[2]) + " " + str(q.shape[-2])) + #else: + # #self.freqs_cis = self.freqs_cis.to(q.device) + # if self.freqs_cis.shape[0] != q.shape[-2]: + # raise("freqs_cis shape error " + str(self.freqs_cis.shape[0]) + " " + str(q.shape[-2])) + + #if q.shape[-2] != k_1.shape[-2]: + # assert self.rope_k_repeat + + if USE_MAT_ROTARY_ENC: + if self.is_512: + q, k_1 = apply_rotary_matenc_512( + q, + k_1, + rotmats=self.rotmats, + repeat_freqs_k=self.rope_k_repeat, + ) + else: + q, k_1 = apply_rotary_matenc( + q, + k_1, + rotmats=self.rotmats, + repeat_freqs_k=self.rope_k_repeat, + ) + else: + q, k_1 = apply_rotary_enc( + q, + k_1, + freqs_cis=self.freqs_cis, + repeat_freqs_k=self.rope_k_repeat, ) + + #print(k_1.shape, k_2.shape) + #if k_2.shape[2] == 0: + # k = k_1 + #else: + k = torch.concat((k_1, k_2), dim = 2) + #if v_2.shape[2] == 0: + # v = v_1 + #else: + v = torch.concat((v_1, v_2), dim = 2) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + #try: + # with sdp_kernel_context(dropout_p): + # out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + #except Exception as e: + if True: + # Fall back to all kernels if the Flash attention kernel fails + #warnings.warn( + # f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + # f"kernels for scaled_dot_product_attention (which may have a slower speed).", + # category=UserWarning, + # stacklevel=2, + 
#) global ALLOW_ALL_KERNELS ALLOW_ALL_KERNELS = True out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) @@ -358,3 +488,4 @@ def forward( out = self.out_proj(out) return out + \ No newline at end of file diff --git a/sam2/modeling/sam2_base.py b/sam2/modeling/sam2_base.py index 224a8c1bb..a1e0fbd95 100644 --- a/sam2/modeling/sam2_base.py +++ b/sam2/modeling/sam2_base.py @@ -115,10 +115,14 @@ def __init__( # Part 2: memory attention to condition current frame's visual features # with memories (and obj ptrs) from past frames self.memory_attention = memory_attention + self.memory_attention_onnx_exported = False + self.memory_attention_tflite_exported = False self.hidden_dim = memory_attention.d_model # Part 3: memory encoder for the previous frame's outputs self.memory_encoder = memory_encoder + self.memory_encoder_onnx_exported = False + self.memory_encoder_tflite_exported = False self.mem_dim = self.hidden_dim if hasattr(self.memory_encoder, "out_proj") and hasattr( self.memory_encoder.out_proj, "weight" @@ -175,6 +179,9 @@ def __init__( self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond self.max_cond_frames_in_attn = max_cond_frames_in_attn + self.mlp_onnx_exported = False + self.mlp_tflite_exported = False + # Model compilation if compile_image_encoder: # Compile the forward function (not the full module) to allow loading checkpoints. @@ -188,11 +195,73 @@ def __init__( dynamic=False, ) + # debug + self.debug = False + + # onnx + self.image_encoder_onnx = None + self.prompt_encoder_onnx = None + self.mask_decoder_onnx = None + self.mlp_onnx = None + self.memory_attention_onnx = None + self.memory_encoder_onnx = None + + # tflite + self.image_encoder_tflite = None + self.prompt_encoder_tflite = None + self.mask_decoder_tflite = None + self.mlp_tflite = None + self.memory_attention_tflite = None + self.memory_encoder_tflite = None + + # Check decoder sample parameter + assert(self.image_size == 512 or self.image_size == 1024) + assert(self.num_feature_levels == 3) + assert(self.hidden_dim == 256) + assert(self.num_maskmem == 7) + assert(self.directly_add_no_mem_embed == True) + #assert(self.training == False) + assert(self.mem_dim == 64) + assert(self.add_tpos_enc_to_obj_ptrs == False) + assert(self.use_obj_ptrs_in_encoder == True) + assert(self.add_all_frames_to_correct_as_cond == False) + assert(self.multimask_output_in_sam == True) + assert(self.multimask_min_pt_num == 0) + assert(self.multimask_max_pt_num == 1) + assert(self.sam_prompt_embed_dim == self.hidden_dim) + assert(self.backbone_stride == 16) + assert(self.sam_image_embedding_size == self.image_size // self.backbone_stride) + assert(self.pred_obj_scores == True) + assert(self.use_obj_ptrs_in_encoder == True) + assert(self.use_mlp_for_obj_ptr_proj == True) + assert(self.proj_tpos_enc_in_obj_ptrs == False) + assert(self.soft_no_obj_ptr == False) + assert(self.fixed_no_obj_ptr == True) + assert(self.non_overlap_masks_for_mem_enc == False) + assert(self.binarize_mask_from_pts_for_mem_enc == False or self.binarize_mask_from_pts_for_mem_enc == True) # True for video + assert(self.sigmoid_scale_for_mem_enc == 20) + assert(self.sigmoid_bias_for_mem_enc == -10.0) + assert(self.sam_mask_decoder.dynamic_multimask_via_stability == True) + assert(self.sam_mask_decoder.dynamic_multimask_stability_delta == 0.05) + assert(self.sam_mask_decoder.dynamic_multimask_stability_thresh == 0.98) + assert(self.max_cond_frames_in_attn == -1) + assert(self.memory_temporal_stride_for_eval == 1) + 
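
Looking back at the transformer.py hunk above: `USE_MAT_ROTARY_ENC` swaps the complex-valued RoPE (`compute_axial_cis` / `apply_rotary_enc`) for a real-valued rotation, since complex tensors are not representable in the ONNX and TFLite exporters. The repo's actual implementations are `get_rotation_matrices` / `apply_rotary_matenc` in `sam2/modeling/position_encoding.py`; the sketch below only illustrates the underlying idea with assumed shapes and is not that code.

```python
import torch

def axial_rope_tables(end_x: int, end_y: int, dim: int, theta: float = 10000.0):
    # Real-valued cos/sin tables for a 2D (end_x * end_y) token grid; half of the
    # channel pairs rotate with the x coordinate, the other half with y.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: dim // 4].float() / dim))
    tx = (torch.arange(end_x * end_y) % end_x).float()
    ty = torch.div(torch.arange(end_x * end_y), end_x, rounding_mode="floor").float()
    ang = torch.cat([torch.outer(tx, freqs), torch.outer(ty, freqs)], dim=-1)  # (L, dim//2)
    return ang.cos(), ang.sin()

def apply_axial_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (batch, heads, L, dim). Rotate consecutive channel pairs with sin/cos only,
    # which keeps the graph free of complex ops that ONNX/TFLite cannot handle.
    x0, x1 = x[..., 0::2], x[..., 1::2]
    return torch.stack([x0 * cos - x1 * sin, x0 * sin + x1 * cos], dim=-1).flatten(-2)
```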
assert(self.max_obj_ptrs_in_encoder == 16) + assert(self.only_obj_ptrs_in_the_past_for_eval == True) + assert(self.multimask_output_for_tracking == True) + assert(self.use_multimask_token_for_obj_ptr == True) + + def set_num_maskmem(self, num_maskmem, max_obj_ptrs_in_encoder): + self.num_maskmem = num_maskmem + self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder + assert(self.num_maskmem == 1 or self.num_maskmem == 7) + assert(self.max_obj_ptrs_in_encoder == 1 or self.max_obj_ptrs_in_encoder == 16) + @property def device(self): return next(self.parameters()).device - def forward(self, *args, **kwargs): + def forward(self): raise NotImplementedError( "Please use the corresponding methods in SAM2VideoPredictor for inference." "See notebooks/video_predictor_example.ipynb for an example." @@ -255,6 +324,11 @@ def _forward_sam_heads( mask_inputs=None, high_res_features=None, multimask_output=False, + export_to_onnx=False, + import_from_onnx=False, + export_to_tflite=False, + import_from_tflite=False, + model_id=None ): """ Forward SAM prompt encoders and mask heads. @@ -331,25 +405,160 @@ def _forward_sam_heads( # a learned `no_mask_embed` to indicate no mask input in this case). sam_mask_prompt = None - sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( - points=(sam_point_coords, sam_point_labels), - boxes=None, - masks=sam_mask_prompt, - ) - ( - low_res_multimasks, - ious, - sam_output_tokens, - object_score_logits, - ) = self.sam_mask_decoder( - image_embeddings=backbone_features, - image_pe=self.sam_prompt_encoder.get_dense_pe(), - sparse_prompt_embeddings=sparse_embeddings, - dense_prompt_embeddings=dense_embeddings, - multimask_output=multimask_output, - repeat_image=False, # the image is already batched - high_res_features=high_res_features, - ) + if sam_mask_prompt is None: + import numpy as np + mask_input_dummy = torch.Tensor(np.zeros((1, self.image_size // 4, self.image_size // 4))) + masks_enable = torch.tensor([0], dtype=torch.int) + else: + mask_input_dummy = sam_mask_prompt + masks_enable = torch.tensor([1], dtype=torch.int) + + if import_from_onnx: + if self.debug: + print("begin prompt encoder onnx") + import onnxruntime + if self.prompt_encoder_onnx == None: + self.prompt_encoder_onnx = onnxruntime.InferenceSession("model/prompt_encoder_"+model_id+".onnx") + sparse_embeddings, dense_embeddings, dense_pe = self.prompt_encoder_onnx.run(None, {"coords":sam_point_coords.numpy(), "labels":sam_point_labels.numpy(), "masks":mask_input_dummy.numpy(), "masks_enable":masks_enable.numpy()}) + sparse_embeddings = torch.Tensor(sparse_embeddings) + dense_embeddings = torch.Tensor(dense_embeddings) + dense_pe = torch.Tensor(dense_pe) + + if self.mask_decoder_onnx == None: + self.mask_decoder_onnx = onnxruntime.InferenceSession("model/mask_decoder_"+model_id+".onnx") + if self.debug: + print("backbone_features", backbone_features.shape) + print("begin mask decoder onnx") + print("begin mask decoder onnx") + print("backbone_features", np.sum(backbone_features.numpy())) + print("image_pe", np.sum(dense_pe.numpy())) + print("sparse_embeddings", np.sum(sparse_embeddings.numpy())) + print("dense_embeddings", np.sum(dense_embeddings.numpy())) + print("high_res_features", np.sum(high_res_features[0].numpy())) + print("high_res_features", np.sum(high_res_features[1].numpy())) + masks, iou_pred, sam_tokens_out, object_score_logits = self.mask_decoder_onnx.run(None, { + "image_embeddings":backbone_features.numpy(), + "image_pe": dense_pe.numpy(), + "sparse_prompt_embeddings": 
sparse_embeddings.numpy(), + "dense_prompt_embeddings": dense_embeddings.numpy(), + #repeat_image=False, # the image is already batched + "high_res_features1":high_res_features[0].numpy(), + "high_res_features2":high_res_features[1].numpy()}) + masks = torch.Tensor(masks) + iou_pred = torch.Tensor(iou_pred) + sam_tokens_out = torch.Tensor(sam_tokens_out) + object_score_logits = torch.Tensor(object_score_logits) + low_res_multimasks, ious, sam_output_tokens, object_score_logits = self.sam_mask_decoder.forward_postprocess(masks, iou_pred, sam_tokens_out, object_score_logits, multimask_output) + #print(low_res_multimasks.shape) + #print(ious.shape) + #print(sam_output_tokens.shape) + #print(object_score_logits.shape) + + if import_from_tflite: + if self.debug: + print("begin prompt encoder tflite") + + import tensorflow as tf + if self.prompt_encoder_tflite == None: + self.prompt_encoder_tflite = tf.lite.Interpreter(model_path="model/prompt_encoder_"+model_id+".tflite") + input_details = self.prompt_encoder_tflite.get_input_details() + self.prompt_encoder_tflite.resize_tensor_input( + input_details[2]["index"], + [1, sam_point_coords.shape[1], 2] + ) + self.prompt_encoder_tflite.allocate_tensors() + + input_details = self.prompt_encoder_tflite.get_input_details() + output_details = self.prompt_encoder_tflite.get_output_details() + + self.prompt_encoder_tflite.set_tensor(input_details[2]["index"], sam_point_coords) + self.prompt_encoder_tflite.set_tensor(input_details[3]["index"], sam_point_labels) + self.prompt_encoder_tflite.set_tensor(input_details[0]["index"], mask_input_dummy) + self.prompt_encoder_tflite.set_tensor(input_details[1]["index"], masks_enable) + self.prompt_encoder_tflite.invoke() + + sparse_embeddings = self.prompt_encoder_tflite.get_tensor(output_details[1]["index"]) + dense_embeddings = self.prompt_encoder_tflite.get_tensor(output_details[0]["index"]) + dense_pe = self.prompt_encoder_tflite.get_tensor(output_details[2]["index"]) + + if self.mask_decoder_tflite == None: + self.mask_decoder_tflite = tf.lite.Interpreter(model_path="model/mask_decoder_"+model_id+".tflite") + + input_details = self.mask_decoder_tflite.get_input_details() + self.mask_decoder_tflite.resize_tensor_input( + input_details[1]["index"], + [1, sparse_embeddings.shape[1], 256] + ) + self.mask_decoder_tflite.allocate_tensors() + + input_details = self.mask_decoder_tflite.get_input_details() + output_details = self.mask_decoder_tflite.get_output_details() + + batched_mode = False + + self.mask_decoder_tflite.set_tensor(input_details[3]["index"], backbone_features.numpy()) + self.mask_decoder_tflite.set_tensor(input_details[6]["index"], dense_pe) + self.mask_decoder_tflite.set_tensor(input_details[1]["index"], sparse_embeddings) + self.mask_decoder_tflite.set_tensor(input_details[2]["index"], dense_embeddings) + self.mask_decoder_tflite.set_tensor(input_details[5]["index"], batched_mode) + self.mask_decoder_tflite.set_tensor(input_details[0]["index"], high_res_features[0].numpy()) + self.mask_decoder_tflite.set_tensor(input_details[4]["index"], high_res_features[1].numpy()) + self.mask_decoder_tflite.invoke() + + masks = self.mask_decoder_tflite.get_tensor(output_details[2]["index"]) + iou_pred = self.mask_decoder_tflite.get_tensor(output_details[0]["index"]) + sam_tokens_out = self.mask_decoder_tflite.get_tensor(output_details[3]["index"]) + object_score_logits = self.mask_decoder_tflite.get_tensor(output_details[1]["index"]) + + masks = torch.Tensor(masks) + iou_pred = torch.Tensor(iou_pred) + 
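
The TFLite paths above and below address interpreter tensors by hard-coded positions such as `input_details[3]["index"]`, which silently depend on the tensor ordering the converter happened to produce. A small, hedged helper that resolves indices by tensor name may be more robust; the actual tensor names are not specified in this diff and must be inspected once against the exported models.

```python
import tensorflow as tf

def tensor_index_maps(interpreter: tf.lite.Interpreter):
    # Map tensor names to indices so set_tensor/get_tensor no longer rely on the
    # positional order of get_input_details()/get_output_details().
    in_map = {d["name"]: d["index"] for d in interpreter.get_input_details()}
    out_map = {d["name"]: d["index"] for d in interpreter.get_output_details()}
    return in_map, out_map

# Hypothetical usage (the model_id suffix and tensor names depend on the export run):
# interp = tf.lite.Interpreter(model_path="model/mask_decoder_" + model_id + ".tflite")
# interp.allocate_tensors()
# in_map, out_map = tensor_index_maps(interp)
# print(sorted(in_map))  # inspect once, then key by name instead of position
```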
sam_tokens_out = torch.Tensor(sam_tokens_out) + object_score_logits = torch.Tensor(object_score_logits) + low_res_multimasks, ious, sam_output_tokens, object_score_logits = self.sam_mask_decoder.forward_postprocess(masks, iou_pred, sam_tokens_out, object_score_logits, multimask_output) + #print(low_res_multimasks.shape) + #print(ious.shape) + #print(sam_output_tokens.shape) + #print(object_score_logits.shape) + + if not import_from_onnx and not import_from_tflite: + if self.debug: + print("begin mask decoder torch") + print("backbone_features", backbone_features.shape) + if sam_mask_prompt is None: + import numpy as np + mask_input_dummy = torch.Tensor(np.zeros((1, self.image_size // 4, self.image_size // 4))) + masks_enable = torch.tensor([0], dtype=torch.int) + else: + mask_input_dummy = sam_mask_prompt + masks_enable = torch.tensor([1], dtype=torch.int) + sparse_embeddings, dense_embeddings, dense_pe = self.sam_prompt_encoder.forward( + coords=sam_point_coords, + labels=sam_point_labels, + masks=mask_input_dummy, + masks_enable=masks_enable + ) + + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder.forward_normal( + image_embeddings=backbone_features, + image_pe=dense_pe, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features1=high_res_features[0], + high_res_features2=high_res_features[1], + ) + if self.debug: + print(low_res_multimasks.shape) + print(ious.shape) + print(sam_output_tokens.shape) + print(object_score_logits.shape) + if self.pred_obj_scores: is_obj_appearing = object_score_logits > 0 @@ -384,7 +593,55 @@ def _forward_sam_heads( low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks # Extract object pointer from the SAM output token (with occlusion handling) - obj_ptr = self.obj_ptr_proj(sam_output_token) + if export_to_onnx and not self.mlp_onnx_exported: + print("x", sam_output_token.shape) + self.mlp_onnx_exported = True + torch.onnx.export( + self.obj_ptr_proj, (sam_output_token), 'model/mlp_'+model_id+'.onnx', + input_names=["x"], + output_names=["x_out"], + dynamic_axes={ + 'x': {0: 'n'}, + 'obj_ptr': {0: 'n'} + }, + verbose=False, opset_version=17 + ) + + if import_from_onnx: + import onnxruntime + if self.mlp_onnx == None: + self.mlp_onnx = onnxruntime.InferenceSession("model/mlp_"+model_id+".onnx") + import numpy as np + obj_ptr = self.mlp_onnx.run(None, {"x":sam_output_token.numpy()})[0] + obj_ptr = torch.Tensor(obj_ptr) + + if export_to_tflite and not self.mlp_tflite_exported: + self.mlp_tflite_exported = True + import ai_edge_torch + import tensorflow as tf + sample_inputs = (sam_output_token,) + tfl_converter_flags = {'target_spec': {'supported_ops': [tf.lite.OpsSet.TFLITE_BUILTINS]}} + edge_model = ai_edge_torch.convert(self.obj_ptr_proj, sample_inputs, _ai_edge_converter_flags=tfl_converter_flags) + edge_model.export("model/mlp_"+model_id+".tflite") + + if import_from_tflite: + import tensorflow as tf + if self.mlp_tflite == None: + self.mlp_tflite = tf.lite.Interpreter(model_path="model/mlp_"+model_id+".tflite") + self.mlp_tflite.allocate_tensors() + + input_details = self.mlp_tflite.get_input_details() + output_details = self.mlp_tflite.get_output_details() + + self.mlp_tflite.set_tensor(input_details[0]["index"], sam_output_token.numpy()) + self.mlp_tflite.invoke() + + obj_ptr = self.mlp_tflite.get_tensor(output_details[0]["index"]) + obj_ptr 
= torch.Tensor(obj_ptr) + + if not import_from_onnx and not import_from_tflite: + obj_ptr = self.obj_ptr_proj(sam_output_token) + if self.pred_obj_scores: # Allow *soft* no obj ptr, unlike for masks if self.soft_no_obj_ptr: @@ -408,7 +665,7 @@ def _forward_sam_heads( object_score_logits, ) - def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs, export_to_onnx, import_from_onnx, export_to_tflite, import_from_tflite, model_id): """ Directly turn binary `mask_inputs` into a output mask logits without using SAM. (same input and output shapes as in _forward_sam_heads above). @@ -437,6 +694,11 @@ def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs) backbone_features=backbone_features, mask_inputs=self.mask_downsample(mask_inputs_float), high_res_features=high_res_features, + export_to_onnx=export_to_onnx, + import_from_onnx=import_from_onnx, + export_to_tflite=export_to_tflite, + import_from_tflite=import_from_tflite, + model_id=model_id ) # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying @@ -472,7 +734,7 @@ def forward_image(self, img_batch: torch.Tensor): backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( backbone_out["backbone_fpn"][1] ) - return backbone_out + return backbone_out["vision_features"], backbone_out["vision_pos_enc"][0], backbone_out["vision_pos_enc"][1], backbone_out["vision_pos_enc"][2], backbone_out["backbone_fpn"][0], backbone_out["backbone_fpn"][1], backbone_out["backbone_fpn"][2] def _prepare_backbone_features(self, backbone_out): """Prepare and flatten visual features.""" @@ -500,6 +762,11 @@ def _prepare_memory_conditioned_features( output_dict, num_frames, track_in_reverse=False, # tracking in reverse time order (for demo usage) + export_to_onnx=False, + import_from_onnx=False, + export_to_tflite=False, + import_from_tflite=False, + model_id=None ): """Fuse the current frame's visual feature map with previous memory.""" B = current_vision_feats[-1].size(1) # batch size on this frame @@ -650,13 +917,149 @@ def _prepare_memory_conditioned_features( memory = torch.cat(to_cat_memory, dim=0) memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) - pix_feat_with_mem = self.memory_attention( + # 標準の実装ではforwardの中でweightが確保されるが、エクスポート時に固定するために先に確保する + self.memory_attention.allocate_rope_attention_weight( curr=current_vision_feats, curr_pos=current_vision_pos_embeds, - memory=memory, - memory_pos=memory_pos_embed, - num_obj_ptr_tokens=num_obj_ptr_tokens, + image_size=self.image_size, ) + + # 4096の倍数のRoPEAttentionが適用される部分と手協されない部分を事前に分割する + # 動的なsliceがdynamoでエラーになるため + memory_1 = memory[:-num_obj_ptr_tokens,:,:] + memory_2 = memory[-num_obj_ptr_tokens:,:,:] + memory_pos_embed_1 = memory_pos_embed[:-num_obj_ptr_tokens,:,:] + memory_pos_embed_2 = memory_pos_embed[-num_obj_ptr_tokens:,:,:] + + if self.debug: + print("memory attention shape") + print("curr", current_vision_feats[0].shape) + print("memory", memory.shape) + print("curr_pos", current_vision_pos_embeds[0].shape) + print("memory_pos", memory_pos_embed.shape) + print("num_obj_ptr_tokens", num_obj_ptr_tokens) + + if export_to_onnx and not self.memory_attention_onnx_exported: + self.memory_attention_onnx_exported = True + #print("current_vision_feats", current_vision_feats[0].shape, 
current_vision_feats[0].dtype) + #print("memory", memory.shape, memory.dtype) + #print("current_vision_pos_embeds", current_vision_pos_embeds[0].shape, current_vision_pos_embeds[0].dtype) + #print("memory_pos_embed", memory_pos_embed.shape, memory_pos_embed.dtype) + #print("num_obj_ptr_tokens", num_obj_ptr_tokens) + torch.onnx.export( + self.memory_attention, (current_vision_feats[0], memory_1, memory_2, current_vision_pos_embeds[0], memory_pos_embed_1, memory_pos_embed_2), 'model/memory_attention_'+model_id+'.opt.onnx', + input_names=["curr", "memory_1", "memory_2", "curr_pos", "memory_pos_1", "memory_pos_2"], + output_names=["pix_feat"], + dynamic_axes={ + 'memory_1': {0: 'n_1'}, + 'memory_2': {0: 'n_2'}, + 'memory_pos_1': {0: 'n_1'}, + 'memory_pos_2': {0: 'n_2'} + }, + verbose=False, opset_version=17 + ) + #export_options = torch.onnx.ExportOptions(dynamic_shapes=True) + #onnx_program =torch.onnx.dynamo_export( + # self.memory_attention, current_vision_feats[0], memory_1, memory_2, current_vision_pos_embeds[0], memory_pos_embed_1, memory_pos_embed_2, export_options=export_options + #) + #onnx_program.save('model/memory_attention_'+model_id+'.onnx') + + if import_from_onnx: + if self.debug: + print("begin memory attention onnx") + import onnxruntime + if self.memory_attention_onnx == None: + self.memory_attention_onnx = onnxruntime.InferenceSession("model/memory_attention_"+model_id+".opt.onnx") + import numpy as np + #num_obj_ptr_tokens_numpy = np.array((num_obj_ptr_tokens)).astype(np.int64) + #print("curr", np.sum(current_vision_feats[0].numpy())) + #print("memory", np.sum(memory.numpy())) + #print("curr_pos", np.sum(current_vision_pos_embeds[0].numpy())) + #print("memory_pos", np.sum(memory_pos_embed.numpy())) + #print("num_obj_ptr_tokens", np.sum(num_obj_ptr_tokens_numpy)) + + pix_feat_with_mem = self.memory_attention_onnx.run(None, {"curr":current_vision_feats[0].numpy(), "memory_1":memory_1.numpy(), "memory_2":memory_2.numpy(), "curr_pos":current_vision_pos_embeds[0].numpy(), "memory_pos_1":memory_pos_embed_1.numpy(), "memory_pos_2":memory_pos_embed_2.numpy()}) + pix_feat_with_mem = torch.Tensor(pix_feat_with_mem[0]) + + if export_to_tflite and not self.memory_attention_tflite_exported: + self.memory_attention_tflite_exported = True + import ai_edge_torch + import tensorflow as tf + sample_inputs = (current_vision_feats[0], memory_1, memory_2, current_vision_pos_embeds[0], memory_pos_embed_1, memory_pos_embed_2) + tfl_converter_flags = {'target_spec': {'supported_ops': [tf.lite.OpsSet.TFLITE_BUILTINS]}} + if self.num_maskmem == 1 and self.max_obj_ptrs_in_encoder == 1: + edge_model = ai_edge_torch.convert(self.memory_attention, sample_inputs, _ai_edge_converter_flags=tfl_converter_flags) + else: + n_1 = torch.export.Dim("n_1", min=1, max=256) + n_4096 = n_1 * 4096 + n_2 = torch.export.Dim("n_2", min=1, max=256) + n_4 = n_2 * 4 + dynamic_shapes={ + 'curr': None, + 'memory_1': {0: n_4096}, + 'memory_2': {0: n_4}, + 'curr_pos': None, + 'memory_pos_1': {0: n_4096}, + 'memory_pos_2': {0: n_4} + } + edge_model = ai_edge_torch.convert(self.memory_attention, sample_inputs, _ai_edge_converter_flags=tfl_converter_flags, dynamic_shapes=dynamic_shapes) + edge_model.export("model/memory_attention_"+model_id+".tflite") + + if import_from_tflite: + if self.debug: + print("begin memory attention tflite") + import tensorflow as tf + if self.memory_attention_tflite == None: + self.memory_attention_tflite = tf.lite.Interpreter(model_path="model/memory_attention_"+model_id+".tflite") + 
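
Per the original comments above, the memory bank is split outside the graph into `memory_1` (the spatial memory tokens, a multiple of 4096 per memory frame) and `memory_2` (the object-pointer tokens), because a slice whose offset depends on `num_obj_ptr_tokens` cannot be traced by dynamo/ONNX export. A hedged restatement of that preprocessing step, mirroring the split performed above:

```python
import torch

def split_memory_for_export(memory: torch.Tensor, memory_pos: torch.Tensor,
                            num_obj_ptr_tokens: int):
    # memory / memory_pos: (num_tokens, batch, mem_dim). The exported memory-attention
    # graph receives the two parts as separate inputs; rotary encoding is applied only
    # to the spatial part, and the object-pointer tokens are concatenated afterwards.
    memory_1 = memory[:-num_obj_ptr_tokens]
    memory_2 = memory[-num_obj_ptr_tokens:]
    memory_pos_1 = memory_pos[:-num_obj_ptr_tokens]
    memory_pos_2 = memory_pos[-num_obj_ptr_tokens:]
    return memory_1, memory_2, memory_pos_1, memory_pos_2
```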
self.memory_attention_tflite.allocate_tensors() + input_details = self.memory_attention_tflite.get_input_details() + self.memory_attention_tflite.resize_tensor_input( + input_details[5]["index"], + [memory_1.shape[0], 1, 64] + ) + self.memory_attention_tflite.resize_tensor_input( + input_details[1]["index"], + [memory_2.shape[0], 1, 64] + ) + self.memory_attention_tflite.resize_tensor_input( + input_details[4]["index"], + [memory_pos_embed_1.shape[0], 1, 64] + ) + self.memory_attention_tflite.resize_tensor_input( + input_details[0]["index"], + [memory_pos_embed_2.shape[0], 1, 64] + ) + self.memory_attention_tflite.allocate_tensors() + + input_details = self.memory_attention_tflite.get_input_details() + output_details = self.memory_attention_tflite.get_output_details() + + self.memory_attention_tflite.set_tensor(input_details[3]["index"], current_vision_feats[0].numpy()) + self.memory_attention_tflite.set_tensor(input_details[5]["index"], memory_1.numpy()) + self.memory_attention_tflite.set_tensor(input_details[1]["index"], memory_2.numpy()) + self.memory_attention_tflite.set_tensor(input_details[2]["index"], current_vision_pos_embeds[0].numpy()) + self.memory_attention_tflite.set_tensor(input_details[4]["index"], memory_pos_embed_1.numpy()) + self.memory_attention_tflite.set_tensor(input_details[0]["index"], memory_pos_embed_2.numpy()) + self.memory_attention_tflite.invoke() + + pix_feat_with_mem = self.memory_attention_tflite.get_tensor(output_details[0]["index"]) + pix_feat_with_mem = torch.Tensor(pix_feat_with_mem) + + if not import_from_onnx and not import_from_tflite: + #print("begin memory attention torch") + #print("current_vision_feats", current_vision_feats[0].shape) + #print("current_vision_pos_embeds", current_vision_pos_embeds[0].shape) + #print("memory", memory.shape) + #print("memory_pos_embed", memory_pos_embed.shape) + pix_feat_with_mem = self.memory_attention( + curr=current_vision_feats, + memory_1=memory_1, + memory_2=memory_2, + curr_pos=current_vision_pos_embeds, + memory_pos_1=memory_pos_embed_1, + memory_pos_2=memory_pos_embed_2, + ) + # reshape the output (HW)BC => BCHW pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) return pix_feat_with_mem @@ -667,6 +1070,11 @@ def _encode_new_memory( feat_sizes, pred_masks_high_res, is_mask_from_pts, + export_to_onnx = False, + import_from_onnx = False, + export_to_tflite = False, + import_from_tflite = False, + model_id = None ): """Encode the current image and its prediction into a memory feature.""" B = current_vision_feats[-1].size(1) # batch size on this frame @@ -693,11 +1101,64 @@ def _encode_new_memory( mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc if self.sigmoid_bias_for_mem_enc != 0.0: mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc - maskmem_out = self.memory_encoder( - pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied - ) - maskmem_features = maskmem_out["vision_features"] - maskmem_pos_enc = maskmem_out["vision_pos_enc"] + + if export_to_onnx and not self.memory_encoder_onnx_exported: + self.memory_encoder_onnx_exported = True + torch.onnx.export( + self.memory_encoder, (pix_feat, mask_for_mem), 'model/memory_encoder_'+model_id+'.onnx', + input_names=["pix_feat", "masks"], + output_names=["vision_features", "vision_pos_enc"], + verbose=False, opset_version=17 + ) + + if import_from_onnx: + if self.debug: + print("begin memory encoder onnx") + import onnxruntime + if self.memory_encoder_onnx == None: + self.memory_encoder_onnx = 
onnxruntime.InferenceSession("model/memory_encoder_"+model_id+".onnx") + vision_features, vision_pos_enc = self.memory_encoder_onnx.run(None, {"pix_feat":pix_feat.numpy(), "masks":mask_for_mem.numpy()}) + vision_features = torch.Tensor(vision_features) + vision_pos_enc = torch.Tensor(vision_pos_enc) + + if export_to_tflite and not self.memory_encoder_tflite_exported: + self.memory_encoder_tflite_exported = True + import ai_edge_torch + import tensorflow as tf + sample_inputs = (pix_feat, mask_for_mem) + tfl_converter_flags = {'target_spec': {'supported_ops': [tf.lite.OpsSet.TFLITE_BUILTINS]}} + edge_model = ai_edge_torch.convert(self.memory_encoder, sample_inputs, _ai_edge_converter_flags=tfl_converter_flags) + edge_model.export("model/memory_encoder_"+model_id+".tflite") + + if import_from_tflite: + if self.debug: + print("begin memory encoder tflite") + import tensorflow as tf + if self.memory_encoder_tflite == None: + self.memory_encoder_tflite = tf.lite.Interpreter(model_path="model/memory_encoder_"+model_id+".tflite") + self.memory_encoder_tflite.allocate_tensors() + + input_details = self.memory_encoder_tflite.get_input_details() + output_details = self.memory_encoder_tflite.get_output_details() + + self.memory_encoder_tflite.set_tensor(input_details[0]["index"], pix_feat.numpy()) + self.memory_encoder_tflite.set_tensor(input_details[1]["index"], mask_for_mem.numpy()) + self.memory_encoder_tflite.invoke() + + vision_features = self.memory_encoder_tflite.get_tensor(output_details[1]["index"]) + vision_pos_enc = self.memory_encoder_tflite.get_tensor(output_details[0]["index"]) + vision_features = torch.Tensor(vision_features) + vision_pos_enc = torch.Tensor(vision_pos_enc) + + if not import_from_onnx and not import_from_tflite: + if self.debug: + print("begin memory encoder torch") + vision_features, vision_pos_enc = self.memory_encoder( + pix_feat, mask_for_mem#, skip_mask_sigmoid=True # sigmoid already applied (fixed to constant) + ) + + maskmem_features = vision_features + maskmem_pos_enc = [vision_pos_enc] return maskmem_features, maskmem_pos_enc @@ -721,6 +1182,12 @@ def track_step( run_mem_encoder=True, # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). 
prev_sam_mask_logits=None, + # ONNX Export + export_to_onnx=False, + import_from_onnx=False, + export_to_tflite=False, + import_from_tflite=False, + model_id=None ): current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW @@ -737,7 +1204,8 @@ def track_step( pix_feat = current_vision_feats[-1].permute(1, 2, 0) pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) sam_outputs = self._use_mask_as_output( - pix_feat, high_res_features, mask_inputs + pix_feat, high_res_features, mask_inputs, + export_to_onnx=export_to_onnx, import_from_onnx=import_from_onnx, export_to_tflite=export_to_tflite, import_from_tflite=import_from_tflite, model_id=model_id ) else: # fused the visual feature with previous memory features in the memory bank @@ -750,6 +1218,11 @@ def track_step( output_dict=output_dict, num_frames=num_frames, track_in_reverse=track_in_reverse, + export_to_onnx=export_to_onnx, + import_from_onnx=import_from_onnx, + export_to_tflite=export_to_tflite, + import_from_tflite=import_from_tflite, + model_id=model_id ) # apply SAM-style segmentation head # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, @@ -765,6 +1238,11 @@ def track_step( mask_inputs=mask_inputs, high_res_features=high_res_features, multimask_output=multimask_output, + export_to_onnx=export_to_onnx, + import_from_onnx=import_from_onnx, + export_to_tflite=export_to_tflite, + import_from_tflite=import_from_tflite, + model_id=model_id ) ( _, @@ -789,6 +1267,11 @@ def track_step( feat_sizes=feat_sizes, pred_masks_high_res=high_res_masks_for_mem_enc, is_mask_from_pts=(point_inputs is not None), + export_to_onnx=export_to_onnx, + import_from_onnx=import_from_onnx, + export_to_tflite=export_to_tflite, + import_from_tflite=import_from_tflite, + model_id=model_id ) current_out["maskmem_features"] = maskmem_features current_out["maskmem_pos_enc"] = maskmem_pos_enc diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py index 41ce53af5..071dca649 100644 --- a/sam2/sam2_image_predictor.py +++ b/sam2/sam2_image_predictor.py @@ -12,6 +12,8 @@ import torch from PIL.Image import Image +import onnxruntime + from sam2.modeling.sam2_base import SAM2Base from sam2.utils.transforms import SAM2Transforms @@ -24,6 +26,7 @@ def __init__( mask_threshold=0.0, max_hole_area=0.0, max_sprinkle_area=0.0, + image_size=1024, **kwargs, ) -> None: """ @@ -42,7 +45,7 @@ def __init__( super().__init__() self.model = sam_model self._transforms = SAM2Transforms( - resolution=self.model.image_size, + resolution=sam_model.image_size, mask_threshold=mask_threshold, max_hole_area=max_hole_area, max_sprinkle_area=max_sprinkle_area, @@ -60,11 +63,24 @@ def __init__( # Spatial dim for backbone feature maps self._bb_feat_sizes = [ - (256, 256), - (128, 128), - (64, 64), + (sam_model.image_size // 4, sam_model.image_size // 4), + (sam_model.image_size // 8, sam_model.image_size // 8), + (sam_model.image_size // 16, sam_model.image_size // 16), ] + # debug + self.debug = False + + # onnx + self.image_encoder_onnx = None + self.prompt_encoder_onnx = None + self.mask_decoder_onnx = None + + # tflite + self.image_encoder_tflite = None + self.prompt_encoder_tflite = None + self.mask_decoder_tflite = None + @classmethod def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": """ @@ -86,6 +102,12 @@ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": def set_image( self, image: 
Union[np.ndarray, Image], + export_to_onnx = False, + export_to_tflite = False, + import_from_onnx = False, + import_from_tflite = False, + tflite_int8=False, + model_id=None ) -> None: """ Calculates the image embeddings for the provided image, allowing @@ -114,7 +136,91 @@ def set_image( len(input_image.shape) == 4 and input_image.shape[1] == 3 ), f"input_image must be of size 1x3xHxW, got {input_image.shape}" logging.info("Computing image embeddings for the provided image...") - backbone_out = self.model.forward_image(input_image) + + if export_to_onnx: + #print("input_image", input_image.shape) + self.model.forward = self.model.forward_image + torch.onnx.export( + self.model, (input_image), 'model/image_encoder_'+model_id+'.onnx', + input_names=["input_image"], + output_names=["vision_features", "vision_pos_enc_0", "vision_pos_enc_1", "vision_pos_enc_2", "backbone_fpn_0", "backbone_fpn_1", "backbone_fpn_2"], + verbose=False, opset_version=17 + ) + + if import_from_onnx: + if self.image_encoder_onnx == None: + self.image_encoder_onnx = onnxruntime.InferenceSession("model/image_encoder_"+model_id+".onnx") + vision_features, vision_pos_enc_0, vision_pos_enc_1, vision_pos_enc_2, backbone_fpn_0, backbone_fpn_1, backbone_fpn_2 = self.image_encoder_onnx.run(None, {"input_image":input_image.numpy()}) + if self.debug: + print("vision_features", vision_features.shape) + print("vision_pos_enc_0", vision_pos_enc_0.shape) + print("vision_pos_enc_1", vision_pos_enc_1.shape) + print("vision_pos_enc_2", vision_pos_enc_2.shape) + print("backbone_fpn_0", backbone_fpn_0.shape) + print("backbone_fpn_1", backbone_fpn_1.shape) + print("backbone_fpn_2", backbone_fpn_2.shape) + + if export_to_tflite: + import ai_edge_torch + import tensorflow as tf + sample_inputs = (input_image,) + self.model.forward = self.model.forward_image + + if not tflite_int8: + tfl_converter_flags = {'target_spec': {'supported_ops': [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]}} + edge_model = ai_edge_torch.convert(self.model, sample_inputs, _ai_edge_converter_flags=tfl_converter_flags) + edge_model.export("model/image_encoder_"+model_id+".tflite") + + if tflite_int8: + from ai_edge_torch.quantize import pt2e_quantizer + from ai_edge_torch.quantize import quant_config + from torch.ao.quantization import quantize_pt2e + + quantizer = pt2e_quantizer.PT2EQuantizer().set_global( + pt2e_quantizer.get_symmetric_quantization_config(is_dynamic=True) + ) + model = torch._export.capture_pre_autograd_graph(self.model, sample_inputs) + model = quantize_pt2e.prepare_pt2e(model, quantizer) + model(input_image) # calibration (you need to edit reset_histogram function) + model = quantize_pt2e.convert_pt2e(model, fold_quantize=False) + + tfl_converter_flags = {'target_spec': {'supported_ops': [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]}} + with_quantizer = ai_edge_torch.convert( + model, + sample_inputs, + quant_config=quant_config.QuantConfig(pt2e_quantizer=quantizer), + _ai_edge_converter_flags=tfl_converter_flags + ) + with_quantizer.export("model/image_encoder_"+model_id+"_int8.tflite") + edge_model = model + + if import_from_tflite: + import tensorflow as tf + if self.image_encoder_tflite == None: + self.image_encoder_tflite = tf.lite.Interpreter(model_path="model/image_encoder_"+model_id+".tflite") + self.image_encoder_tflite.allocate_tensors() + + input_details = self.image_encoder_tflite.get_input_details() + output_details = self.image_encoder_tflite.get_output_details() + + 
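
Once `model/image_encoder_<model_id>.onnx` has been written by the export branch above, it can be sanity-checked against the PyTorch `forward_image` path. This is a minimal comparison sketch, assuming the input name `input_image` used in the export call; it is not part of the repo.

```python
import numpy as np
import onnxruntime
import torch

def compare_image_encoder(model, input_image: torch.Tensor, onnx_path: str) -> None:
    # forward_image (as patched above) returns 7 tensors:
    # vision_features, 3x vision_pos_enc, 3x backbone_fpn.
    with torch.no_grad():
        torch_outs = model.forward_image(input_image)
    sess = onnxruntime.InferenceSession(onnx_path)
    onnx_outs = sess.run(None, {"input_image": input_image.numpy()})
    for i, (ref, out) in enumerate(zip(torch_outs, onnx_outs)):
        diff = np.abs(ref.cpu().numpy() - out).max()
        print(f"output {i}: max abs diff = {diff:.3e}")  # typically small for fp32
```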
self.image_encoder_tflite.set_tensor(input_details[0]["index"], input_image.numpy()) + self.image_encoder_tflite.invoke() + + vision_features = self.image_encoder_tflite.get_tensor(output_details[4]["index"]) + vision_pos_enc_0 = self.image_encoder_tflite.get_tensor(output_details[1]["index"]) + vision_pos_enc_1 = self.image_encoder_tflite.get_tensor(output_details[5]["index"]) + vision_pos_enc_2 = self.image_encoder_tflite.get_tensor(output_details[3]["index"]) + backbone_fpn_0 = self.image_encoder_tflite.get_tensor(output_details[0]["index"]) + backbone_fpn_1 = self.image_encoder_tflite.get_tensor(output_details[2]["index"]) + backbone_fpn_2 = self.image_encoder_tflite.get_tensor(output_details[6]["index"]) + + if not import_from_onnx and not import_from_tflite: + vision_features, vision_pos_enc_0, vision_pos_enc_1, vision_pos_enc_2, backbone_fpn_0, backbone_fpn_1, backbone_fpn_2 = self.model.forward_image(input_image) + + backbone_out = {"vision_features":torch.Tensor(vision_features), + "vision_pos_enc":[torch.Tensor(vision_pos_enc_0), torch.Tensor(vision_pos_enc_1), torch.Tensor(vision_pos_enc_2)], + "backbone_fpn":[torch.Tensor(backbone_fpn_0), torch.Tensor(backbone_fpn_1), torch.Tensor(backbone_fpn_2)]} + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos if self.model.directly_add_no_mem_embed: @@ -124,6 +230,7 @@ def set_image( feat.permute(1, 2, 0).view(1, -1, *feat_size) for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} self._is_image_set = True logging.info("Image embeddings computed.") @@ -243,6 +350,12 @@ def predict( multimask_output: bool = True, return_logits: bool = False, normalize_coords=True, + export_to_onnx=False, + export_to_tflite=False, + import_from_onnx = False, + import_from_tflite = False, + tflite_int8=False, + model_id=None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Predict masks for the given input prompts, using the currently set image. @@ -283,7 +396,6 @@ def predict( ) # Transform input prompts - mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( point_coords, point_labels, box, mask_input, normalize_coords ) @@ -295,6 +407,12 @@ def predict( mask_input, multimask_output, return_logits=return_logits, + export_to_onnx=export_to_onnx, + export_to_tflite=export_to_tflite, + import_from_onnx=import_from_onnx, + import_from_tflite=import_from_tflite, + tflite_int8=tflite_int8, + model_id=model_id ) masks_np = masks.squeeze(0).float().detach().cpu().numpy() @@ -343,6 +461,12 @@ def _predict( multimask_output: bool = True, return_logits: bool = False, img_idx: int = -1, + export_to_onnx = False, + export_to_tflite = False, + tflite_int8 = False, + import_from_onnx = False, + import_from_tflite = False, + model_id = None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Predict masks for the given input prompts, using the currently set image. 
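
The next hunk wires the same flags into `_predict`. For orientation, this is roughly how the added keyword arguments are meant to be driven from user code; the sketch assumes a `SAM2ImagePredictor` that has already been constructed, an HxWx3 uint8 image, and treats `model_id` only as the filename suffix under `model/`.

```python
import numpy as np

def export_then_run_onnx(predictor, image: np.ndarray, model_id: str):
    point = np.array([[256, 256]], dtype=np.float32)
    label = np.array([1], dtype=np.int64)

    # Pass 1: run the PyTorch path and write the image encoder, prompt encoder and
    # mask decoder to model/*_<model_id>.onnx.
    predictor.set_image(image, export_to_onnx=True, model_id=model_id)
    predictor.predict(point_coords=point, point_labels=label,
                      multimask_output=True, export_to_onnx=True, model_id=model_id)

    # Pass 2: run the same prompts through onnxruntime instead of PyTorch.
    predictor.set_image(image, import_from_onnx=True, model_id=model_id)
    masks, scores, logits = predictor.predict(point_coords=point, point_labels=label,
                                              multimask_output=True,
                                              import_from_onnx=True, model_id=model_id)
    return masks, scores, logits
```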
@@ -403,11 +527,106 @@ def _predict( else: concat_points = (box_coords, box_labels) - sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( - points=concat_points, - boxes=None, - masks=mask_input, - ) + # New data for onnx + if concat_points is None: + raise ("concat points must be exists") # Noneの場合はtensorサイズが0のテンソルを返さないといけないためwhereで組めない + if mask_input is None: + mask_input_dummy = torch.Tensor(np.zeros((1, self.model.image_size // 4, self.model.image_size // 4))) + masks_enable = torch.tensor([0], dtype=torch.int) # boolだとonnxへのエクスポートのwhereでエラーになる + else: + mask_input_dummy = mask_input + masks_enable = torch.tensor([1], dtype=torch.int) + + if export_to_onnx: + #print("concat_points", concat_points.shape) + #print("mask_input", mask_input.shape) + torch.onnx.export( + self.model.sam_prompt_encoder, (concat_points[0], concat_points[1], mask_input_dummy, masks_enable), 'model/prompt_encoder_'+model_id+'.onnx', + input_names=["coords", "labels", "masks", "masks_enable"], + output_names=["sparse_embeddings", "dense_embeddings", "dense_pe"], + dynamic_axes={ + 'coords': {0: 'b', 1: 'n'}, + 'labels': {0: 'b', 1: 'n'}, + 'masks': {0: 'b', 1: 'h', 2: 'w'}, + }, + verbose=False, opset_version=17 + ) + + if import_from_onnx: + if self.prompt_encoder_onnx == None: + self.prompt_encoder_onnx = onnxruntime.InferenceSession("model/prompt_encoder_"+model_id+".onnx") + sparse_embeddings, dense_embeddings, dense_pe = self.prompt_encoder_onnx.run(None, {"coords":concat_points[0].numpy(), "labels":concat_points[1].numpy(), "masks": mask_input_dummy.numpy(), "masks_enable":masks_enable.numpy()}) + sparse_embeddings = torch.Tensor(sparse_embeddings) + dense_embeddings = torch.Tensor(dense_embeddings) + dense_pe = torch.Tensor(dense_pe) + + if export_to_tflite: + import ai_edge_torch + sample_inputs = (concat_points[0], concat_points[1], mask_input_dummy, masks_enable) + + if not tflite_int8: + edge_model = ai_edge_torch.convert(self.model.sam_prompt_encoder, sample_inputs) + edge_model.export("model/prompt_encoder_"+model_id+".tflite") + + if False:#tflite_int8: # labelがint64で量子化できない + from ai_edge_torch.quantize import pt2e_quantizer + from ai_edge_torch.quantize import quant_config + from torch.ao.quantization import quantize_pt2e + + quantizer = pt2e_quantizer.PT2EQuantizer().set_global( + pt2e_quantizer.get_symmetric_quantization_config() + ) + model = torch._export.capture_pre_autograd_graph(self.model.sam_prompt_encoder, sample_inputs) + model = quantize_pt2e.prepare_pt2e(model, quantizer) + model(concat_points[0], concat_points[1]) # calibration + model = quantize_pt2e.convert_pt2e(model, fold_quantize=False) + + with_quantizer = ai_edge_torch.convert( + model, + sample_inputs, + quant_config=quant_config.QuantConfig(pt2e_quantizer=quantizer), + ) + with_quantizer.export("model/prompt_encoder_"+model_id+"_int8.tflite") + + edge_model = model + + if import_from_tflite: + import tensorflow as tf + if self.prompt_encoder_tflite == None: + self.prompt_encoder_tflite = tf.lite.Interpreter(model_path="model/prompt_encoder_"+model_id+".tflite") + self.prompt_encoder_tflite.allocate_tensors() + input_details = self.prompt_encoder_tflite.get_input_details() + self.prompt_encoder_tflite.resize_tensor_input( + input_details[2]["index"], + [1, concat_points[0].shape[1], 2] + ) + self.prompt_encoder_tflite.allocate_tensors() + + input_details = self.prompt_encoder_tflite.get_input_details() + output_details = self.prompt_encoder_tflite.get_output_details() + + 
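
The `mask_input_dummy` / `masks_enable` pair above exists because the exported graph cannot branch on `mask_input is None`: a zero mask of shape (1, image_size//4, image_size//4) is always fed together with an int flag (an int rather than a bool, since a bool input broke the ONNX export of `where`, per the original comment), and the prompt encoder selects between the embedded mask and its learned `no_mask_embed` with `torch.where`. A minimal illustration of that selection pattern:

```python
import torch

def select_dense_embeddings(embedded_mask: torch.Tensor,
                            no_mask_embed: torch.Tensor,
                            masks_enable: torch.Tensor) -> torch.Tensor:
    # masks_enable: int tensor of shape (1,); 1 = a real mask prompt was supplied,
    # 0 = fall back to the learned "no mask" embedding. Both branches are computed,
    # and torch.where keeps the graph static for ONNX/TFLite.
    return torch.where(masks_enable[0] == 1, embedded_mask, no_mask_embed)
```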
+            self.prompt_encoder_tflite.set_tensor(input_details[2]["index"], concat_points[0])
+            self.prompt_encoder_tflite.set_tensor(input_details[3]["index"], concat_points[1])
+            self.prompt_encoder_tflite.set_tensor(input_details[0]["index"], mask_input_dummy)
+            self.prompt_encoder_tflite.set_tensor(input_details[1]["index"], masks_enable)
+            self.prompt_encoder_tflite.invoke()
+
+            sparse_embeddings = self.prompt_encoder_tflite.get_tensor(output_details[1]["index"])
+            dense_embeddings = self.prompt_encoder_tflite.get_tensor(output_details[0]["index"])
+            dense_pe = self.prompt_encoder_tflite.get_tensor(output_details[2]["index"])
+
+            sparse_embeddings = torch.Tensor(sparse_embeddings)
+            dense_embeddings = torch.Tensor(dense_embeddings)
+            dense_pe = torch.Tensor(dense_pe)
+
+        if not import_from_onnx and not import_from_tflite:
+            sparse_embeddings, dense_embeddings, dense_pe = self.model.sam_prompt_encoder.forward(
+                coords=concat_points[0],
+                labels=concat_points[1],
+                #boxes=None,
+                masks=mask_input_dummy,
+                masks_enable=masks_enable
+            )
 
         # Predict masks
         batched_mode = (
@@ -417,15 +636,125 @@ def _predict(
             feat_level[img_idx].unsqueeze(0)
             for feat_level in self._features["high_res_feats"]
         ]
-        low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
-            image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
-            image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
-            sparse_prompt_embeddings=sparse_embeddings,
-            dense_prompt_embeddings=dense_embeddings,
-            multimask_output=multimask_output,
-            repeat_image=batched_mode,
-            high_res_features=high_res_features,
-        )
+
+        #print("sparse_embeddings", sparse_embeddings.shape)
+        #print("dense_embeddings", dense_embeddings.shape)
+
+        if export_to_onnx:
+            self.model.sam_mask_decoder.forward = self.model.sam_mask_decoder.forward_masks  # use forward_masks here because multimask_output would otherwise be baked into the traced graph as a constant
+            torch.onnx.export(
+                self.model.sam_mask_decoder, (self._features["image_embed"][img_idx].unsqueeze(0), dense_pe, sparse_embeddings, dense_embeddings, batched_mode, high_res_features[0], high_res_features[1]),
+                'model/mask_decoder_'+model_id+'.onnx',
+                input_names=["image_embeddings", "image_pe", "sparse_prompt_embeddings", "dense_prompt_embeddings", "repeat_image", "high_res_features1", "high_res_features2"],
+                output_names=["masks", "iou_pred", "sam_tokens_out", "object_score_logits"],
+                dynamic_axes={
+                    'sparse_prompt_embeddings': {1: 'n'},
+                },
+                verbose=False, opset_version=17
+            )
+
+        if import_from_onnx:
+            if self.mask_decoder_onnx is None:
+                self.mask_decoder_onnx = onnxruntime.InferenceSession("model/mask_decoder_"+model_id+".onnx")
+            masks, iou_pred, sam_tokens_out, object_score_logits = self.mask_decoder_onnx.run(None, {
+                "image_embeddings":self._features["image_embed"][img_idx].unsqueeze(0).numpy(),
+                "image_pe": dense_pe.numpy(),
+                "sparse_prompt_embeddings": sparse_embeddings.numpy(),
+                "dense_prompt_embeddings": dense_embeddings.numpy(),
+                "high_res_features1":high_res_features[0].numpy(),
+                "high_res_features2":high_res_features[1].numpy()})
+            masks = torch.Tensor(masks)
+            iou_pred = torch.Tensor(iou_pred)
+            sam_tokens_out = torch.Tensor(sam_tokens_out)
+            object_score_logits = torch.Tensor(object_score_logits)
+            low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder.forward_postprocess(masks, iou_pred, sam_tokens_out, object_score_logits, multimask_output)
+
+        if export_to_tflite:
+            self.model.sam_mask_decoder.forward = self.model.sam_mask_decoder.forward_masks
+            sample_inputs = (self._features["image_embed"][img_idx].unsqueeze(0), dense_pe, sparse_embeddings, dense_embeddings, batched_mode, high_res_features[0], high_res_features[1])
+
+            if not tflite_int8:
+                import ai_edge_torch
+                edge_model = ai_edge_torch.convert(self.model.sam_mask_decoder, sample_inputs)
+                edge_model.export("model/mask_decoder_"+model_id+".tflite")
+
+            if tflite_int8:
+                from ai_edge_torch.quantize import pt2e_quantizer
+                from ai_edge_torch.quantize import quant_config
+                from torch.ao.quantization import quantize_pt2e
+                import ai_edge_torch  # also needed on this branch; the import in the non-int8 branch above does not run here
+
+                quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
+                    pt2e_quantizer.get_symmetric_quantization_config()
+                )
+                model = torch._export.capture_pre_autograd_graph(self.model.sam_mask_decoder, sample_inputs)
+                model = quantize_pt2e.prepare_pt2e(model, quantizer)
+                model(self._features["image_embed"][img_idx].unsqueeze(0), dense_pe, sparse_embeddings, dense_embeddings, batched_mode, high_res_features[0], high_res_features[1]) # calibration
+                model = quantize_pt2e.convert_pt2e(model, fold_quantize=False)
+
+                with_quantizer = ai_edge_torch.convert(
+                    model,
+                    sample_inputs,
+                    quant_config=quant_config.QuantConfig(pt2e_quantizer=quantizer),
+                )
+                with_quantizer.export("model/mask_decoder_"+model_id+"_int8.tflite")
+
+                edge_model = model
+
+        if import_from_tflite:
+            batched_mode_np = np.zeros((1), dtype=bool)
+            if batched_mode:
+                batched_mode_np[0] = True
+
+            import tensorflow as tf
+            if self.mask_decoder_tflite is None:
+                self.mask_decoder_tflite = tf.lite.Interpreter(model_path="model/mask_decoder_"+model_id+".tflite")
+                self.mask_decoder_tflite.allocate_tensors()
+                input_details = self.mask_decoder_tflite.get_input_details()
+                self.mask_decoder_tflite.resize_tensor_input(
+                    input_details[1]["index"],
+                    [1, sparse_embeddings.shape[1], 256]
+                )
+                self.mask_decoder_tflite.allocate_tensors()
+
+            input_details = self.mask_decoder_tflite.get_input_details()
+            output_details = self.mask_decoder_tflite.get_output_details()
+
+            batched_mode = False
+
+            self.mask_decoder_tflite.set_tensor(input_details[3]["index"], self._features["image_embed"][img_idx].unsqueeze(0).numpy())
+            self.mask_decoder_tflite.set_tensor(input_details[6]["index"], dense_pe)
+            self.mask_decoder_tflite.set_tensor(input_details[1]["index"], sparse_embeddings)
+            self.mask_decoder_tflite.set_tensor(input_details[2]["index"], dense_embeddings)
+            self.mask_decoder_tflite.set_tensor(input_details[5]["index"], batched_mode)
+            self.mask_decoder_tflite.set_tensor(input_details[0]["index"], high_res_features[0].numpy())
+            self.mask_decoder_tflite.set_tensor(input_details[4]["index"], high_res_features[1].numpy())
+            self.mask_decoder_tflite.invoke()
+
+            masks = self.mask_decoder_tflite.get_tensor(output_details[2]["index"])
+            iou_pred = self.mask_decoder_tflite.get_tensor(output_details[0]["index"])
+            sam_tokens_out = self.mask_decoder_tflite.get_tensor(output_details[3]["index"])
+            object_score_logits = self.mask_decoder_tflite.get_tensor(output_details[1]["index"])
+
+            masks = torch.Tensor(masks)
+            iou_pred = torch.Tensor(iou_pred)
+            sam_tokens_out = torch.Tensor(sam_tokens_out)
+            object_score_logits = torch.Tensor(object_score_logits)
+
+            low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder.forward_postprocess(masks, iou_pred, sam_tokens_out, object_score_logits, multimask_output)
+
+        if not import_from_onnx and not import_from_tflite:
+            self.model.sam_mask_decoder.forward = self.model.sam_mask_decoder.forward_normal
+            low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
+                image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
+                image_pe=dense_pe,
+                sparse_prompt_embeddings=sparse_embeddings,
+                dense_prompt_embeddings=dense_embeddings,
+                multimask_output=multimask_output,
+                repeat_image=batched_mode,
+                high_res_features1=high_res_features[0],
+                high_res_features2=high_res_features[1],
+            )
 
         # Upscale the masks to the original image resolution
         masks = self._transforms.postprocess_masks(
diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py
index 8b2fd6c4d..a4c40a86a 100644
--- a/sam2/sam2_video_predictor.py
+++ b/sam2/sam2_video_predictor.py
@@ -35,6 +35,7 @@ def __init__(
         self.non_overlap_masks = non_overlap_masks
         self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
         self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
+        self.image_encoder_onnx = None
 
     @torch.inference_mode()
     def init_state(
@@ -43,6 +44,9 @@ def init_state(
         offload_video_to_cpu=False,
         offload_state_to_cpu=False,
         async_loading_frames=False,
+        import_from_onnx=False,
+        import_from_tflite=False,
+        model_id=None
     ):
         """Initialize an inference state."""
         compute_device = self.device  # device of the model
@@ -103,7 +107,7 @@ def init_state(
         inference_state["tracking_has_started"] = False
         inference_state["frames_already_tracked"] = {}
         # Warm up the visual backbone and cache the image feature on frame 0
-        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
+        self._get_image_feature(inference_state, frame_idx=0, batch_size=1, import_from_onnx=import_from_onnx, import_from_tflite=import_from_tflite, model_id=model_id)
         return inference_state
 
     @classmethod
@@ -176,6 +180,11 @@ def add_new_points_or_box(
         clear_old_points=True,
         normalize_coords=True,
         box=None,
+        import_from_onnx=False,
+        export_to_onnx=False,
+        import_from_tflite=False,
+        export_to_tflite=False,
+        model_id=None
     ):
         """Add new points to a frame."""
         obj_idx = self._obj_id_to_idx(inference_state, obj_id)
@@ -291,6 +300,11 @@ def add_new_points_or_box(
             # them into memory.
             run_mem_encoder=False,
             prev_sam_mask_logits=prev_sam_mask_logits,
+            import_from_onnx=import_from_onnx,
+            export_to_onnx=export_to_onnx,
+            import_from_tflite=import_from_tflite,
+            export_to_tflite=export_to_tflite,
+            model_id=model_id
         )
         # Add the output to the output dict (to be used as future memory)
         obj_temp_output_dict[storage_key][frame_idx] = current_out
@@ -303,6 +317,9 @@ def add_new_points_or_box(
             is_cond=is_cond,
             run_mem_encoder=False,
             consolidate_at_video_res=True,
+            import_from_onnx=import_from_onnx, export_to_onnx=export_to_onnx,
+            import_from_tflite=import_from_tflite, export_to_tflite=export_to_tflite,
+            model_id=model_id
         )
         _, video_res_masks = self._get_orig_video_res_output(
             inference_state, consolidated_out["pred_masks_video_res"]
@@ -320,6 +337,11 @@ def add_new_mask(
         frame_idx,
         obj_id,
         mask,
+        import_from_onnx=False,
+        export_to_onnx=False,
+        import_from_tflite=False,
+        export_to_tflite=False,
+        model_id=None
     ):
         """Add new mask to a frame."""
         obj_idx = self._obj_id_to_idx(inference_state, obj_id)
@@ -379,6 +401,11 @@ def add_new_mask(
             # allows us to enforce non-overlapping constraints on all objects before encoding
             # them into memory.
             run_mem_encoder=False,
+            import_from_onnx=import_from_onnx,
+            export_to_onnx=export_to_onnx,
+            import_from_tflite=import_from_tflite,
+            export_to_tflite=export_to_tflite,
+            model_id=model_id
         )
         # Add the output to the output dict (to be used as future memory)
         obj_temp_output_dict[storage_key][frame_idx] = current_out
@@ -391,6 +418,9 @@ def add_new_mask(
             is_cond=is_cond,
             run_mem_encoder=False,
             consolidate_at_video_res=True,
+            import_from_onnx=import_from_onnx, export_to_onnx=export_to_onnx,
+            import_from_tflite=import_from_tflite, export_to_tflite=export_to_tflite,
+            model_id=model_id
         )
         _, video_res_masks = self._get_orig_video_res_output(
             inference_state, consolidated_out["pred_masks_video_res"]
@@ -426,6 +456,11 @@ def _consolidate_temp_output_across_obj(
         is_cond,
         run_mem_encoder,
         consolidate_at_video_res=False,
+        import_from_onnx=False,
+        export_to_onnx=False,
+        import_from_tflite=False,
+        export_to_tflite=False,
+        model_id=None
     ):
         """
         Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
@@ -492,7 +527,7 @@ def _consolidate_temp_output_across_obj(
                 if run_mem_encoder:
                     if empty_mask_ptr is None:
                         empty_mask_ptr = self._get_empty_mask_ptr(
-                            inference_state, frame_idx
+                            inference_state, frame_idx, import_from_onnx=import_from_onnx, export_to_onnx=export_to_onnx, import_from_tflite=import_from_tflite, export_to_tflite=export_to_tflite, model_id=model_id
                         )
                     # fill object pointer with a dummy pointer (based on an empty mask)
                     consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
@@ -531,13 +566,18 @@ def _consolidate_temp_output_across_obj(
             batch_size=batch_size,
             high_res_masks=high_res_masks,
             is_mask_from_pts=True,  # these frames are what the user interacted with
+            export_to_onnx=export_to_onnx,
+            import_from_onnx=import_from_onnx,
+            export_to_tflite=export_to_tflite,
+            import_from_tflite=import_from_tflite,
+            model_id=model_id
         )
         consolidated_out["maskmem_features"] = maskmem_features
         consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
 
         return consolidated_out
 
-    def _get_empty_mask_ptr(self, inference_state, frame_idx):
+    def _get_empty_mask_ptr(self, inference_state, frame_idx, import_from_onnx, export_to_onnx, import_from_tflite, export_to_tflite, model_id):
         """Get a dummy object pointer based on an empty mask on the current frame."""
         # A dummy (empty) mask with a single object
         batch_size = 1
@@ -554,7 +594,7 @@ def _get_empty_mask_ptr(self, inference_state, frame_idx):
             current_vision_feats,
             current_vision_pos_embeds,
             feat_sizes,
-        ) = self._get_image_feature(inference_state, frame_idx, batch_size)
+        ) = self._get_image_feature(inference_state, frame_idx, batch_size, import_from_onnx=import_from_onnx, import_from_tflite=import_from_tflite, model_id=model_id)
 
         # Feed the empty mask and image feature above to get a dummy object pointer
         current_out = self.track_step(
@@ -570,11 +610,16 @@ def _get_empty_mask_ptr(self, inference_state, frame_idx):
             track_in_reverse=False,
             run_mem_encoder=False,
             prev_sam_mask_logits=None,
+            import_from_onnx=import_from_onnx,
+            export_to_onnx=export_to_onnx,
+            import_from_tflite=import_from_tflite,
+            export_to_tflite=export_to_tflite,
+            model_id=model_id
         )
         return current_out["obj_ptr"]
 
     @torch.inference_mode()
-    def propagate_in_video_preflight(self, inference_state):
+    def propagate_in_video_preflight(self, inference_state, import_from_onnx=False, export_to_onnx=False, import_from_tflite=False, export_to_tflite=False, model_id=None):
         """Prepare inference_state and consolidate temporary outputs before tracking."""
         # Tracking has started and we don't allow adding new objects until session is reset.
         inference_state["tracking_has_started"] = True
@@ -601,7 +646,7 @@ def propagate_in_video_preflight(self, inference_state):
             # consolidate the temporary output across all objects on this frame
             for frame_idx in temp_frame_inds:
                 consolidated_out = self._consolidate_temp_output_across_obj(
-                    inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
+                    inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True, import_from_onnx=import_from_onnx, export_to_onnx=export_to_onnx, import_from_tflite=import_from_tflite, export_to_tflite=export_to_tflite, model_id=model_id
                 )
                 # merge them into "output_dict" and also create per-object slices
                 output_dict[storage_key][frame_idx] = consolidated_out
@@ -650,9 +695,14 @@ def propagate_in_video(
         start_frame_idx=None,
         max_frame_num_to_track=None,
         reverse=False,
+        import_from_onnx=False,
+        export_to_onnx=False,
+        import_from_tflite=False,
+        export_to_tflite=False,
+        model_id=None
     ):
         """Propagate the input points across frames to track in the entire video."""
-        self.propagate_in_video_preflight(inference_state)
+        self.propagate_in_video_preflight(inference_state, import_from_onnx=import_from_onnx, export_to_onnx=export_to_onnx, import_from_tflite=import_from_tflite, export_to_tflite=export_to_tflite, model_id=model_id)
 
         output_dict = inference_state["output_dict"]
         consolidated_frame_inds = inference_state["consolidated_frame_inds"]
@@ -712,6 +762,11 @@ def propagate_in_video(
                     mask_inputs=None,
                     reverse=reverse,
                     run_mem_encoder=True,
+                    import_from_onnx=import_from_onnx,
+                    export_to_onnx=export_to_onnx,
+                    import_from_tflite=import_from_tflite,
+                    export_to_tflite=export_to_tflite,
+                    model_id=model_id
                 )
             output_dict[storage_key][frame_idx] = current_out
             # Create slices of per-object outputs for subsequent interaction with each
@@ -788,7 +843,7 @@ def _reset_tracking_results(self, inference_state):
         inference_state["tracking_has_started"] = False
         inference_state["frames_already_tracked"].clear()
 
-    def _get_image_feature(self, inference_state, frame_idx, batch_size):
+    def _get_image_feature(self, inference_state, frame_idx, batch_size, import_from_onnx=False, import_from_tflite=False, model_id=None):
         """Compute the image features on a given frame."""
         # Look up in the cache first
         image, backbone_out = inference_state["cached_features"].get(
@@ -798,7 +853,45 @@ def _get_image_feature(self, inference_state, frame_idx, batch_size):
             # Cache miss -- we will run inference on a single image
             device = inference_state["device"]
             image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
-            backbone_out = self.forward_image(image)
+            if import_from_onnx:
+                if self.debug:
+                    print("begin image encoder onnx")
+                import onnxruntime
+                if self.image_encoder_onnx is None:
+                    self.image_encoder_onnx = onnxruntime.InferenceSession("model/image_encoder_"+model_id+".onnx")
+                vision_features, vision_pos_enc_0, vision_pos_enc_1, vision_pos_enc_2, backbone_fpn_0, backbone_fpn_1, backbone_fpn_2 = self.image_encoder_onnx.run(None, {"input_image":image.numpy()})
+
+            if import_from_tflite:
+                if self.debug:
+                    print("begin image encoder tflite")
+                import tensorflow as tf
+                if self.image_encoder_tflite is None:
+                    self.image_encoder_tflite = tf.lite.Interpreter(model_path="model/image_encoder_"+model_id+".tflite")
+                    self.image_encoder_tflite.allocate_tensors()
+
+                input_details = self.image_encoder_tflite.get_input_details()
+                output_details = self.image_encoder_tflite.get_output_details()
+
+                self.image_encoder_tflite.set_tensor(input_details[0]["index"], image.numpy())
+                self.image_encoder_tflite.invoke()
+
+                vision_features = self.image_encoder_tflite.get_tensor(output_details[4]["index"])
+                vision_pos_enc_0 = self.image_encoder_tflite.get_tensor(output_details[1]["index"])
+                vision_pos_enc_1 = self.image_encoder_tflite.get_tensor(output_details[5]["index"])
+                vision_pos_enc_2 = self.image_encoder_tflite.get_tensor(output_details[3]["index"])
+                backbone_fpn_0 = self.image_encoder_tflite.get_tensor(output_details[0]["index"])
+                backbone_fpn_1 = self.image_encoder_tflite.get_tensor(output_details[2]["index"])
+                backbone_fpn_2 = self.image_encoder_tflite.get_tensor(output_details[6]["index"])
+
+            if not import_from_onnx and not import_from_tflite:
+                if self.debug:
+                    print("begin image encoder torch")
+                vision_features, vision_pos_enc_0, vision_pos_enc_1, vision_pos_enc_2, backbone_fpn_0, backbone_fpn_1, backbone_fpn_2 = self.forward_image(image)
+
+            backbone_out = {"vision_features":torch.Tensor(vision_features),
+                "vision_pos_enc":[torch.Tensor(vision_pos_enc_0), torch.Tensor(vision_pos_enc_1), torch.Tensor(vision_pos_enc_2)],
+                "backbone_fpn":[torch.Tensor(backbone_fpn_0), torch.Tensor(backbone_fpn_1), torch.Tensor(backbone_fpn_2)]}
+
             # Cache the most recent frame's feature (for repeated interactions with
             # a frame; we can use an LRU cache for more frames in the future).
             inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
@@ -833,6 +926,11 @@ def _run_single_frame_inference(
         reverse,
         run_mem_encoder,
         prev_sam_mask_logits=None,
+        import_from_onnx=False,
+        export_to_onnx=False,
+        import_from_tflite=False,
+        export_to_tflite=False,
+        model_id=None
     ):
         """Run tracking on a single frame based on current inputs and previous memory."""
         # Retrieve correct image features
@@ -842,7 +940,7 @@ def _run_single_frame_inference(
             current_vision_feats,
             current_vision_pos_embeds,
             feat_sizes,
-        ) = self._get_image_feature(inference_state, frame_idx, batch_size)
+        ) = self._get_image_feature(inference_state, frame_idx, batch_size, import_from_onnx=import_from_onnx, import_from_tflite=import_from_tflite, model_id=model_id)
 
         # point and mask should not appear as input simultaneously on the same frame
         assert point_inputs is None or mask_inputs is None
@@ -859,6 +957,11 @@ def _run_single_frame_inference(
             track_in_reverse=reverse,
             run_mem_encoder=run_mem_encoder,
             prev_sam_mask_logits=prev_sam_mask_logits,
+            import_from_onnx=import_from_onnx,
+            export_to_onnx=export_to_onnx,
+            import_from_tflite=import_from_tflite,
+            export_to_tflite=export_to_tflite,
+            model_id=model_id
         )
 
         # optionally offload the output to CPU memory to save GPU space
@@ -888,7 +991,7 @@ def _run_single_frame_inference(
         return compact_current_out, pred_masks_gpu
 
     def _run_memory_encoder(
-        self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts
+        self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts, export_to_onnx, import_from_onnx, export_to_tflite, import_from_tflite, model_id
     ):
         """
         Run the memory encoder on `high_res_masks`. This is usually after applying
@@ -897,13 +1000,18 @@ def _run_memory_encoder(
         """
         # Retrieve correct image features
         _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
-            inference_state, frame_idx, batch_size
+            inference_state, frame_idx, batch_size, import_from_onnx=import_from_onnx, import_from_tflite=import_from_tflite, model_id=model_id
         )
         maskmem_features, maskmem_pos_enc = self._encode_new_memory(
             current_vision_feats=current_vision_feats,
             feat_sizes=feat_sizes,
             pred_masks_high_res=high_res_masks,
             is_mask_from_pts=is_mask_from_pts,
+            export_to_onnx=export_to_onnx,
+            import_from_onnx=import_from_onnx,
+            export_to_tflite=export_to_tflite,
+            import_from_tflite=import_from_tflite,
+            model_id=model_id
        )
         # optionally offload the output to CPU memory to save GPU space
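# Illustrative usage sketch (not taken from this patch): with the export/import flags
# threaded through the video predictor above, a video session driven from previously
# exported ONNX models could look roughly like the commented code below. The model_id
# value "hiera_l", the checkpoint/config names, and the video path are placeholder
# assumptions; actual values depend on how the models were exported.
#
#   from sam2.build_sam import build_sam2_video_predictor
#
#   model_cfg = "sam2_hiera_l.yaml"
#   checkpoint = "./checkpoints/sam2_hiera_large.pt"
#   predictor = build_sam2_video_predictor(model_cfg, checkpoint)
#
#   state = predictor.init_state(
#       video_path="./videos/example.mp4",
#       import_from_onnx=True, model_id="hiera_l",
#   )
#   predictor.add_new_points_or_box(
#       state, frame_idx=0, obj_id=1, points=[[210, 350]], labels=[1],
#       import_from_onnx=True, model_id="hiera_l",
#   )
#   for frame_idx, obj_ids, video_res_masks in predictor.propagate_in_video(
#       state, import_from_onnx=True, model_id="hiera_l",
#   ):
#       pass  # consume the per-frame masks here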