Skip to content

Commit

Permalink
SAM2 video-to-video real-time pipeline (#280)
Browse files Browse the repository at this point in the history
Add SAM2 video-to-video real-time pipeline
  • Loading branch information
eliteprox authored Dec 3, 2024
1 parent ff220bb commit d11b114
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 5 deletions.
3 changes: 3 additions & 0 deletions runner/app/live/Sam2Wrapper/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .wrapper import Sam2Wrapper

__all__ = ["Sam2Wrapper"]
94 changes: 94 additions & 0 deletions runner/app/live/Sam2Wrapper/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import logging
from typing import List, Optional, Tuple
from PIL import Image
import torch
from sam2.build_sam import build_sam2_camera_predictor
from omegaconf import OmegaConf
from hydra.utils import instantiate
from hydra import initialize_config_dir, compose
from hydra.core.global_hydra import GlobalHydra
import os

# Supported SAM2 model IDs, each mapped to the config file name and checkpoint
# file name used for that size variant. The config name uses the variant's
# first letter ("t", "s", "l"); the checkpoint name uses the full variant name.
MODEL_MAPPING = {
    f"facebook/sam2-hiera-{variant}": {
        "config": f"sam2_hiera_{variant[0]}.yaml",
        "checkpoint": f"sam2_hiera_{variant}.pt",
    }
    for variant in ("tiny", "small", "large")
}

class Sam2Wrapper:
    """Load a SAM2 camera predictor for one of the supported model IDs.

    Attributes:
        device: Device string the model was moved to.
        model_id: The model ID passed to the constructor.
        predictor: The instantiated model, in eval mode, with weights loaded.
    """

    def __init__(
        self,
        model_id_or_path: str,
        device: str,
        **kwargs
    ):
        """Build the predictor from the on-disk config and checkpoint.

        Args:
            model_id_or_path: A key of ``MODEL_MAPPING``.
            device: Torch device string, e.g. ``"cuda"``, ``"cuda:0"`` or ``"cpu"``.
            **kwargs: Accepted for call-site compatibility; not used here.

        Raises:
            ValueError: If ``model_id_or_path`` is not a supported model ID.
            RuntimeError: Propagated from ``load_checkpoint`` on a key mismatch.
        """
        self.device = device
        self.model_id = model_id_or_path
        if model_id_or_path not in MODEL_MAPPING:
            raise ValueError(f"Model ID {model_id_or_path} not supported")

        model_info = MODEL_MAPPING[model_id_or_path]
        # The config YAML and the checkpoint live in the same per-model directory.
        config_path = os.path.join("/models/sam2--checkpoints/", model_id_or_path.replace("/", "--"))
        model_cfg = model_info["config"]
        sam2_checkpoint = os.path.join(config_path, model_info["checkpoint"])

        logging.info(f"Initializing segment-anything-2 with model_id {self.model_id}")

        # autocast only accepts the bare device type, so strip any ":N" index.
        torch_device = torch.device(device)
        # Entered and deliberately never exited: autocast stays active for the
        # calling thread after construction.
        torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16).__enter__()
        # Only query CUDA properties when actually running on CUDA; the call
        # raises on CPU/MPS hosts. TF32 is enabled for compute capability >= 8.
        if torch_device.type == "cuda" and torch.cuda.get_device_properties(torch_device).major >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

        # Initialize Hydra to load the configuration
        if GlobalHydra.instance().is_initialized():
            GlobalHydra.instance().clear()

        hydra_overrides = [
            "++model._target_=sam2.sam2_camera_predictor.SAM2CameraPredictor",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
            "++model.binarize_mask_from_pts_for_mem_enc=true",
            "++model.fill_hole_area=8",
        ]

        # Code from sam2.build_sam.build_sam2_camera_predictor to appease Hydra.
        # The config is composed once, directly with the overrides applied.
        with initialize_config_dir(config_dir=config_path, version_base=None):
            cfg = compose(config_name=model_cfg, overrides=hydra_overrides)
            OmegaConf.resolve(cfg)
            model = instantiate(cfg.model, _recursive_=True)

        # Load the weights, then move the model to the target device.
        load_checkpoint(model, sam2_checkpoint, self.device)
        model.to(self.device)
        model.eval()

        # Set the model in memory
        self.predictor = model

def load_checkpoint(model, ckpt_path, device):
    """Load the ``"model"`` state dict stored at ``ckpt_path`` into ``model``.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        ckpt_path: Path to the checkpoint file, or ``None`` to do nothing.
        device: ``map_location`` used when deserializing the tensors.

    Raises:
        RuntimeError: If the checkpoint has missing or unexpected keys.
    """
    if ckpt_path is None:
        return

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    sd = torch.load(ckpt_path, map_location=device, weights_only=True)["model"]
    missing_keys, unexpected_keys = model.load_state_dict(sd)
    if missing_keys:
        logging.error(f"Missing keys: {missing_keys}")
        raise RuntimeError("Missing keys while loading checkpoint.")
    if unexpected_keys:
        logging.error(f"Unexpected keys: {unexpected_keys}")
        raise RuntimeError("Unexpected keys while loading checkpoint.")
    logging.info("Loaded checkpoint successfully.")
3 changes: 3 additions & 0 deletions runner/app/live/pipelines/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,7 @@ def load_pipeline(name: str, **params) -> Pipeline:
elif name == "noop":
from .noop import Noop
return Noop(**params)
elif name == "segment_anything_2":
from .segment_anything_2 import Sam2Live
return Sam2Live(**params)
raise ValueError(f"Unknown pipeline: {name}")
96 changes: 96 additions & 0 deletions runner/app/live/pipelines/segment_anything_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import io
import logging
import threading
import time
from typing import List, Optional
import cv2, torch
import numpy as np
from PIL import Image
from pydantic import BaseModel
from Sam2Wrapper import Sam2Wrapper
from .interface import Pipeline

logger = logging.getLogger(__name__)

class Sam2LiveParams(BaseModel):
    """Parameters accepted by the SAM2 live pipeline.

    The previous pass-through ``__init__`` only forwarded to ``BaseModel`` and
    has been dropped; construction via ``Sam2LiveParams(**data)`` is unchanged.
    """

    class Config:
        # Reject unknown parameter names instead of silently ignoring them.
        extra = "forbid"

    # Model ID handed to Sam2Wrapper as model_id_or_path.
    model_id: str = "facebook/sam2-hiera-tiny"
    # Prompt points; each one is drawn as an (x, y) pixel position on the frame.
    point_coords: List[List[int]] = [[1, 1]]
    # One label per prompt point, matched by index.
    # NOTE(review): presumably 1 = foreground / 0 = background per SAM2 — confirm.
    point_labels: List[int] = [1]
    # Whether to draw the prompt points on the output frame.
    show_points: bool = True

class Sam2Live(Pipeline):
    """Live pipeline that segments prompted objects with SAM2 and tints the mask.

    The first frame after construction or a parameter update seeds the predictor
    with the configured point prompts; later frames call the predictor's
    ``track`` method.
    """

    def __init__(self, **params):
        """Initialise the base pipeline and load the model described by ``params``."""
        super().__init__(**params)
        self.pipe: Optional[Sam2Wrapper] = None
        self.first_frame = True
        self.update_params(**params)

    def update_params(self, **params):
        """Validate ``params``, rebuild the SAM2 wrapper and restart prompting.

        Raises:
            ValueError: If no prompt point is given or there are fewer labels
                than points. Validation errors from ``Sam2LiveParams`` also
                propagate from here.
        """
        new_params = Sam2LiveParams(**params)

        # Fail fast, before the expensive model load, on prompt sets that would
        # otherwise crash inside process_frame on the first frame.
        if not new_params.point_coords:
            raise ValueError("point_coords must contain at least one point")
        if len(new_params.point_labels) < len(new_params.point_coords):
            raise ValueError(
                f"point_labels has {len(new_params.point_labels)} entries but "
                f"point_coords has {len(new_params.point_coords)}; each point needs a label"
            )

        logger.info(f"Setting parameters for sam2: {new_params}")

        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"

        # NOTE(review): the wrapper (and therefore the model) is rebuilt on every
        # parameter update, even when model_id is unchanged. Reusing it would
        # need confirmation that the predictor can be re-seeded in place.
        self.pipe = Sam2Wrapper(
            model_id_or_path=new_params.model_id,
            point_coords=new_params.point_coords,
            point_labels=new_params.point_labels,
            show_points=new_params.show_points,
            device=device,
        )

        self.params = new_params
        self.first_frame = True

    def _process_mask(self, mask: torch.Tensor, frame_shape: tuple) -> np.ndarray:
        """Turn predictor mask logits into a uint8 mask matching the frame size.

        Args:
            mask: Mask logits tensor; only element ``[0, 0]`` is used.
                NOTE(review): assumes layout (objects, 1, H, W) — confirm.
            frame_shape: Shape of the frame array; the first two entries are
                treated as height and width.

        Returns:
            A 2-D ``uint8`` array, 255 where the logit is positive. All zeros
            when ``mask`` has no entries along its first axis.
        """
        if mask.shape[0] == 0:
            return np.zeros((frame_shape[0], frame_shape[1]), dtype="uint8")

        mask = (mask[0, 0] > 0).cpu().numpy().astype("uint8") * 255
        if mask.shape[:2] != frame_shape[:2]:
            # cv2.resize takes (width, height), the reverse of numpy's shape order.
            mask = cv2.resize(mask, (frame_shape[1], frame_shape[0]))
        return mask

    def process_frame(self, frame: Image.Image, **params) -> Image.Image:
        """Segment ``frame`` and return it with the mask tinted purple.

        Args:
            frame: Input frame.
            **params: Optional new parameters; when given they are applied via
                ``update_params`` before the frame is processed.

        Returns:
            An RGB image with the mask overlay and, if ``show_points`` is set,
            the prompt points drawn as red dots.
        """
        if params:
            self.update_params(**params)

        # Normalise to 3-channel RGB so frames in any PIL mode convert cleanly.
        # The predictor still receives the original frame object.
        frame_rgb = np.array(frame.convert("RGB"))

        predictor = self.pipe.predictor
        if self.first_frame:
            predictor.load_first_frame(frame)

            # One object per prompt point, with object ids starting at 1.
            prompts = zip(self.params.point_coords, self.params.point_labels)
            for obj_id, (point, label) in enumerate(prompts, start=1):
                _, _, mask_logits = predictor.add_new_prompt(
                    frame_idx=0,
                    obj_id=obj_id,
                    points=[point],
                    labels=[label],
                )
            self.first_frame = False
        else:
            _, mask_logits = predictor.track(frame)

        mask = self._process_mask(mask_logits, frame_rgb.shape)

        # Purple wherever the mask is set, black elsewhere, added onto the frame
        # with saturation.
        colored_mask = np.zeros_like(frame_rgb)
        colored_mask[mask > 0] = (255, 0, 255)
        overlay = cv2.addWeighted(frame_rgb, 1, colored_mask, 1, 0)

        if self.params.show_points:
            for point in self.params.point_coords:
                # Red in RGB order, since the overlay is kept in RGB.
                cv2.circle(overlay, tuple(point), radius=5, color=(255, 0, 0), thickness=-1)

        # Build the PIL image straight from the array: no lossy JPEG
        # encode/decode round trip per frame.
        return Image.fromarray(overlay)
9 changes: 8 additions & 1 deletion runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ function download_live_models() {
huggingface-cli download stabilityai/sd-turbo --include "*.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
huggingface-cli download warmshao/FasterLivePortrait --local-dir models/FasterLivePortrait--checkpoints
huggingface-cli download yuvraj108c/Depth-Anything-Onnx --include depth_anything_vitl14.onnx --local-dir models/ComfyUI--models/Depth-Anything-Onnx
download_sam2_checkpoints
}

# Download every supported SAM2 size variant into its own directory under
# models/sam2--checkpoints.
function download_sam2_checkpoints() {
  local variant
  for variant in tiny small large; do
    huggingface-cli download "facebook/sam2-hiera-${variant}" --local-dir "models/sam2--checkpoints/facebook--sam2-hiera-${variant}"
  done
}

function build_tensorrt_models() {
Expand Down Expand Up @@ -183,7 +190,7 @@ done
echo "Starting livepeer AI subnet model downloader..."
echo "Creating 'models' directory in the current working directory..."
mkdir -p models
mkdir -p models/StreamDiffusion--engines models/FasterLivePortrait--checkpoints models/ComfyUI--models
mkdir -p models/StreamDiffusion--engines models/FasterLivePortrait--checkpoints models/ComfyUI--models models/sam2--checkpoints

# Ensure 'huggingface-cli' is installed.
echo "Checking if 'huggingface-cli' is installed..."
Expand Down
31 changes: 31 additions & 0 deletions runner/docker/Dockerfile.live-base-segment_anything_2
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Image for the segment_anything_2 live pipeline, layered on the shared live base image.
# BASE_IMAGE can be overridden at build time.
ARG BASE_IMAGE=livepeer/ai-runner:live-base
FROM ${BASE_IMAGE}

# Install the required Python version with pyenv and make it the global default
ARG PYTHON_VERSION=3.10
RUN pyenv install $PYTHON_VERSION && \
pyenv global $PYTHON_VERSION && \
pyenv rehash

# Upgrade pip and install pinned packaging tools
ARG PIP_VERSION=23.3.2
ENV PIP_PREFER_BINARY=1
RUN pip install --no-cache-dir --upgrade pip==${PIP_VERSION} setuptools==69.5.1 wheel==0.43.0

# Install the g++ compiler and point CXX at it
# (presumably for the native extension build in the final step — confirm)
RUN apt-get update && apt-get install -y \
g++-11 \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
ENV CXX=/usr/bin/g++-11

# Install Sam2 dependencies (pinned versions)
RUN pip install --no-cache-dir torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 xformers==0.0.27.post2 zstd==1.5.5.1

RUN pip install --no-cache-dir huggingface-hub==0.23.2 ninja

# Set TORCH_CUDA_ARCH_LIST environment variable, fixes build error in segment-anything-2-real-time
ENV TORCH_CUDA_ARCH_LIST="6.0 7.0 7.5 8.0 8.6+PTX"

# --no-build-isolation makes the build use the packages installed above
# instead of a fresh isolated build environment.
# NOTE(review): the dependency tracks the moving "main" branch, so rebuilds
# are not reproducible; consider pinning a commit.
RUN pip install --no-cache-dir --no-build-isolation \
git+https://github.com/pschroedl/segment-anything-2-real-time@main

9 changes: 5 additions & 4 deletions worker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ var pipelineToImage = map[string]string{
}

// livePipelineToImage maps each live pipeline name to the Docker image used
// to run it. Each key appears exactly once: Go rejects duplicate constant
// keys in a map literal at compile time.
var livePipelineToImage = map[string]string{
	"streamdiffusion":    "livepeer/ai-runner:live-app-streamdiffusion",
	"liveportrait":       "livepeer/ai-runner:live-app-liveportrait",
	"comfyui":            "livepeer/ai-runner:live-app-comfyui",
	"segment_anything_2": "livepeer/ai-runner:live-app-segment_anything_2",
	"noop":               "livepeer/ai-runner:live-app-noop",
}

// DockerClient is an interface for the Docker client, allowing for mocking in tests.
Expand Down

0 comments on commit d11b114

Please sign in to comment.