-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
SAM2 video-to-video real-time pipeline (#280)
Add SAM2 video-to-video real-time pipeline
- Loading branch information
Showing
7 changed files
with
240 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .wrapper import Sam2Wrapper | ||
|
||
__all__ = ["Sam2Wrapper"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import logging | ||
from typing import List, Optional, Tuple | ||
from PIL import Image | ||
import torch | ||
from sam2.build_sam import build_sam2_camera_predictor | ||
from omegaconf import OmegaConf | ||
from hydra.utils import instantiate | ||
from hydra import initialize_config_dir, compose | ||
from hydra.core.global_hydra import GlobalHydra | ||
import os | ||
|
||
MODEL_MAPPING = { | ||
"facebook/sam2-hiera-tiny": { | ||
"config": "sam2_hiera_t.yaml", | ||
"checkpoint": "sam2_hiera_tiny.pt" | ||
}, | ||
"facebook/sam2-hiera-small": { | ||
"config": "sam2_hiera_s.yaml", | ||
"checkpoint": "sam2_hiera_small.pt" | ||
}, | ||
"facebook/sam2-hiera-large": { | ||
"config": "sam2_hiera_l.yaml", | ||
"checkpoint": "sam2_hiera_large.pt" | ||
} | ||
} | ||
|
||
class Sam2Wrapper: | ||
def __init__( | ||
self, | ||
model_id_or_path: str, | ||
device: str, | ||
**kwargs | ||
): | ||
self.device = device | ||
self.model_id = model_id_or_path | ||
if model_id_or_path in MODEL_MAPPING: | ||
model_info = MODEL_MAPPING[model_id_or_path] | ||
config_path = os.path.join("/models/sam2--checkpoints/", model_id_or_path.replace("/", "--")) | ||
model_cfg = model_info['config'] | ||
sam2_checkpoint = f"{config_path}/{model_info['checkpoint']}" | ||
else: | ||
raise ValueError(f"Model ID {model_id_or_path} not supported") | ||
|
||
logging.info(f"Initializing segment-anything-2 with model_id {self.model_id}") | ||
|
||
torch.autocast(device_type=device, dtype=torch.bfloat16).__enter__() | ||
if torch.cuda.get_device_properties(0).major >= 8: | ||
torch.backends.cuda.matmul.allow_tf32 = True | ||
torch.backends.cudnn.allow_tf32 = True | ||
|
||
# Initialize Hydra to load the configuration | ||
if GlobalHydra.instance().is_initialized(): | ||
GlobalHydra.instance().clear() | ||
|
||
# Code from sam2.build_sam.build_sam2_camera_predictor to appease Hydra | ||
with initialize_config_dir(config_dir=config_path, version_base=None): | ||
cfg = compose(config_name=model_cfg) | ||
|
||
hydra_overrides = [ | ||
"++model._target_=sam2.sam2_camera_predictor.SAM2CameraPredictor", | ||
] | ||
hydra_overrides_extra = [ | ||
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", | ||
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", | ||
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", | ||
"++model.binarize_mask_from_pts_for_mem_enc=true", | ||
"++model.fill_hole_area=8", | ||
] | ||
hydra_overrides.extend(hydra_overrides_extra) | ||
|
||
cfg = compose(config_name=model_cfg, overrides=hydra_overrides) | ||
OmegaConf.resolve(cfg) | ||
|
||
#Load the model | ||
model = instantiate(cfg.model, _recursive_=True) | ||
load_checkpoint(model, sam2_checkpoint, self.device) | ||
model.to(self.device) | ||
model.eval() | ||
|
||
# Set the model in memory | ||
self.predictor = model | ||
|
||
def load_checkpoint(model, ckpt_path, device): | ||
if ckpt_path is not None: | ||
|
||
sd = torch.load(ckpt_path, map_location=device)["model"] | ||
missing_keys, unexpected_keys = model.load_state_dict(sd) | ||
if missing_keys: | ||
logging.error(f"Missing keys: {missing_keys}") | ||
raise RuntimeError("Missing keys while loading checkpoint.") | ||
if unexpected_keys: | ||
logging.error(f"Unexpected keys: {unexpected_keys}") | ||
raise RuntimeError("Unexpected keys while loading checkpoint.") | ||
logging.info("Loaded checkpoint successfully.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import io | ||
import logging | ||
import threading | ||
import time | ||
from typing import List, Optional | ||
import cv2, torch | ||
import numpy as np | ||
from PIL import Image | ||
from pydantic import BaseModel | ||
from Sam2Wrapper import Sam2Wrapper | ||
from .interface import Pipeline | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
class Sam2LiveParams(BaseModel): | ||
class Config: | ||
extra = "forbid" | ||
|
||
model_id: str = "facebook/sam2-hiera-tiny" | ||
point_coords: List[List[int]] = [[1, 1]] | ||
point_labels: List[int] = [1] | ||
show_points: bool = True | ||
|
||
def __init__(self, **data): | ||
super().__init__(**data) | ||
|
||
class Sam2Live(Pipeline): | ||
def __init__(self, **params): | ||
super().__init__(**params) | ||
self.pipe: Optional[Sam2Wrapper] = None | ||
self.first_frame = True | ||
self.update_params(**params) | ||
|
||
def update_params(self, **params): | ||
new_params = Sam2LiveParams(**params) | ||
|
||
logging.info(f"Setting parameters for sam2: {new_params}") | ||
self.pipe = Sam2Wrapper( | ||
model_id_or_path=new_params.model_id, | ||
point_coords=new_params.point_coords, | ||
point_labels=new_params.point_labels, | ||
show_points=new_params.show_points, | ||
device=("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")), | ||
) | ||
|
||
self.params = new_params | ||
self.first_frame = True | ||
|
||
def _process_mask(self, mask: np.ndarray, frame_shape: tuple) -> np.ndarray: | ||
"""Process and resize mask if needed.""" | ||
if mask.shape[0] == 0: | ||
return np.zeros((frame_shape[0], frame_shape[1]), dtype="uint8") | ||
|
||
mask = (mask[0, 0] > 0).cpu().numpy().astype("uint8") * 255 | ||
if mask.shape[:2] != frame_shape[:2]: | ||
mask = cv2.resize(mask, (frame_shape[1], frame_shape[0])) | ||
return mask | ||
|
||
def process_frame(self, frame: Image.Image, **params) -> Image.Image: | ||
if params: | ||
self.update_params(**params) | ||
|
||
# Convert image formats | ||
frame_array = np.array(frame) | ||
frame_bgr = cv2.cvtColor(frame_array, cv2.COLOR_RGBA2BGR) | ||
if self.first_frame: | ||
self.pipe.predictor.load_first_frame(frame) | ||
|
||
for idx, point in enumerate(self.params.point_coords): | ||
_, _, mask_logits = self.pipe.predictor.add_new_prompt( | ||
frame_idx=0, | ||
obj_id=idx + 1, | ||
points=[point], | ||
labels=[self.params.point_labels[idx]] | ||
) | ||
self.first_frame = False | ||
else: | ||
_, mask_logits = self.pipe.predictor.track(frame) | ||
|
||
# Process mask and create overlay | ||
mask = self._process_mask(mask_logits, frame_bgr.shape) | ||
|
||
# Create an overlay by combining the original frame and the mask | ||
colored_mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) | ||
colored_mask[mask > 0] = [255, 0, 255] # Add purple tint to the mask | ||
overlay = cv2.addWeighted(frame_bgr, 1, colored_mask, 1, 0) | ||
# Draw points on the overlay | ||
if self.params.show_points: | ||
for point in self.params.point_coords: | ||
cv2.circle(overlay, tuple(point), radius=5, color=(0, 0, 255), thickness=-1) # Red dot | ||
|
||
# Convert back to PIL Image | ||
_, buffer = cv2.imencode('.jpg', overlay) | ||
result = Image.open(io.BytesIO(buffer.tobytes())) | ||
|
||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
ARG BASE_IMAGE=livepeer/ai-runner:live-base | ||
FROM ${BASE_IMAGE} | ||
|
||
# Install required Python version | ||
ARG PYTHON_VERSION=3.10 | ||
RUN pyenv install $PYTHON_VERSION && \ | ||
pyenv global $PYTHON_VERSION && \ | ||
pyenv rehash | ||
|
||
# Upgrade pip and install required packages | ||
ARG PIP_VERSION=23.3.2 | ||
ENV PIP_PREFER_BINARY=1 | ||
RUN pip install --no-cache-dir --upgrade pip==${PIP_VERSION} setuptools==69.5.1 wheel==0.43.0 | ||
|
||
# Install g++ compiler | ||
RUN apt-get update && apt-get install -y \ | ||
g++-11 \ | ||
&& apt-get clean && rm -rf /var/lib/apt/lists/* | ||
ENV CXX=/usr/bin/g++-11 | ||
|
||
# Install Sam2 dependencies | ||
RUN pip install --no-cache-dir torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 xformers==0.0.27.post2 zstd==1.5.5.1 | ||
|
||
RUN pip install --no-cache-dir huggingface-hub==0.23.2 ninja | ||
|
||
# Set TORCH_CUDA_ARCH_LIST environment variable, fixes build error in segment-anything-2-real-time | ||
ENV TORCH_CUDA_ARCH_LIST="6.0 7.0 7.5 8.0 8.6+PTX" | ||
|
||
RUN pip install --no-cache-dir --no-build-isolation \ | ||
git+https://github.com/pschroedl/segment-anything-2-real-time@main | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters