SAM2 video-to-video real-time pipeline #280

Open · wants to merge 3 commits into main
2 changes: 1 addition & 1 deletion runner/.devcontainer/devcontainer.json
@@ -4,8 +4,8 @@
"name": "ai-runner",
// Image to use for the dev container. More info: https://containers.dev/guide/dockerfile.
"build": {
// "dockerfile": "../Dockerfile.live-base-segment_anything_2"
"dockerfile": "../Dockerfile",
// "dockerfile": "../docker/Dockerfile.text_to_speech",
"context": ".."
},
"runArgs": [
30 changes: 30 additions & 0 deletions runner/Dockerfile.live-app__PIPELINE__
@@ -0,0 +1,30 @@
ARG PIPELINE=streamdiffusion
ARG BASE_IMAGE=livepeer/ai-runner:live-base-${PIPELINE}
FROM ${BASE_IMAGE}

# Install system dependencies
RUN apt-get update && apt-get install -y \
wget \
libcairo2-dev \
libgirepository1.0-dev \
pkg-config \
&& apt-get clean && rm -rf /var/lib/apt/lists/*

# Install any additional Python packages
COPY requirements.live-ai.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt

# Set environment variables
ENV MAX_WORKERS=1
ENV HUGGINGFACE_HUB_CACHE=/models
ENV DIFFUSERS_CACHE=/models
ENV MODEL_DIR=/models

# Copy application files
COPY app/ /app/app
COPY images/ /app/images
COPY bench.py /app/bench.py

WORKDIR /app

CMD ["uvicorn", "app.main:app", "--log-config", "app/cfg/uvicorn_logging_config.json", "--host", "0.0.0.0", "--port", "8000"]
3 changes: 3 additions & 0 deletions runner/app/live/Sam2Wrapper/__init__.py
@@ -0,0 +1,3 @@
from .wrapper import Sam2Wrapper

__all__ = ["Sam2Wrapper"]
111 changes: 111 additions & 0 deletions runner/app/live/Sam2Wrapper/wrapper.py
@@ -0,0 +1,111 @@
import logging
from typing import List, Optional, Tuple
from PIL import Image
import torch
from sam2.build_sam import build_sam2_camera_predictor  # import required by Hydra even if it appears unused
from omegaconf import OmegaConf
from hydra.utils import instantiate
from hydra import initialize_config_dir, compose
from hydra.core.global_hydra import GlobalHydra

MODEL_MAPPING = {
"facebook/sam2-hiera-tiny": {
"config": "sam2_hiera_t.yaml",
"checkpoint": "sam2_hiera_tiny.pt"
},
"facebook/sam2-hiera-small": {
"config": "sam2_hiera_s.yaml",
"checkpoint": "sam2_hiera_small.pt"
},
"facebook/sam2-hiera-base": {
"config": "sam2_hiera_b.yaml",
"checkpoint": "sam2_hiera_base.pt"
},
"facebook/sam2-hiera-large": {
"config": "sam2_hiera_l.yaml",
"checkpoint": "sam2_hiera_large.pt"
}
}

# Enable bfloat16 autocast and TF32 where supported; guarded so CPU-only machines can still import this module
if torch.cuda.is_available():
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

# Initialize Hydra to load the configuration
if GlobalHydra.instance().is_initialized():
GlobalHydra.instance().clear()

# NOTE: hard-coded to the "tiny" model for now; MODEL_MAPPING above is not yet consulted
config_path = "/workspaces/ai-worker/runner/models/sam2_configs"
sam2_checkpoint = "/models/checkpoints/sam2_hiera_tiny.pt"
model_cfg = "sam2_hiera_t.yaml"

class Sam2Wrapper:
def __init__(
self,
model_id_or_path: str,
device: Optional[str] = None,
**kwargs
):
self.model_id = model_id_or_path
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Adapted from sam2.build_sam.build_sam2_camera_predictor so Hydra resolves the local config dir
        with initialize_config_dir(config_dir=config_path, version_base=None):

hydra_overrides = [
"++model._target_=sam2.sam2_camera_predictor.SAM2CameraPredictor",
]
hydra_overrides_extra = [
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
"++model.binarize_mask_from_pts_for_mem_enc=true",
"++model.fill_hole_area=8",
]
hydra_overrides.extend(hydra_overrides_extra)

cfg = compose(config_name=model_cfg, overrides=hydra_overrides)
OmegaConf.resolve(cfg)

            # Load the model
model = instantiate(cfg.model, _recursive_=True)
load_checkpoint(model, sam2_checkpoint, self.device)
model.to(self.device)
model.eval()

        # Keep a handle to the predictor on the instance
self.predictor = model

    def __call__(
        self,
        image: Image.Image,
        **kwargs
    ) -> Tuple[List[Image.Image], List[Optional[bool]]]:
        # The live pipeline drives self.predictor directly; this entry point is not implemented yet
        raise NotImplementedError("Sam2Wrapper is used via its .predictor attribute")

def __str__(self) -> str:
return f"Sam2Wrapper model_id={self.model_id}"


def load_checkpoint(model, ckpt_path, device):
    if ckpt_path is not None:
        sd = torch.load(ckpt_path, map_location=device)["model"]
missing_keys, unexpected_keys = model.load_state_dict(sd)
if missing_keys:
logging.error(f"Missing keys: {missing_keys}")
raise RuntimeError("Missing keys while loading checkpoint.")
if unexpected_keys:
logging.error(f"Unexpected keys: {unexpected_keys}")
raise RuntimeError("Unexpected keys while loading checkpoint.")
logging.info("Loaded checkpoint successfully.")

def get_torch_device():
if torch.cuda.is_available():
return torch.device("cuda")
elif torch.backends.mps.is_available():
return torch.device("mps")
else:
return torch.device("cpu")
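For reviewers, a minimal usage sketch of the wrapper as the live pipeline drives it. This is a sketch under assumptions: the hard-coded config/checkpoint paths above exist, the frame path here is hypothetical, and the three-value unpack of add_new_prompt mirrors how Sam2Live uses it below.

from PIL import Image
from Sam2Wrapper import Sam2Wrapper

wrapper = Sam2Wrapper(model_id_or_path="facebook/sam2-hiera-tiny")

# The live pipeline talks to the underlying camera predictor directly
first_frame = Image.open("frame_000.png")  # hypothetical input frame
wrapper.predictor.load_first_frame(first_frame)
_, obj_ids, mask_logits = wrapper.predictor.add_new_prompt(
    frame_idx=0, obj_id=1, points=[[1, 1]], labels=[1]
)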
3 changes: 3 additions & 0 deletions runner/app/live/pipelines/loader.py
@@ -13,4 +13,7 @@ def load_pipeline(name: str, **params) -> Pipeline:
elif name == "noop":
from .noop import Noop
return Noop(**params)
elif name == "segment_anything_2":
from .segment_anything_2 import Sam2Live
return Sam2Live(**params)
raise ValueError(f"Unknown pipeline: {name}")
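A short sketch of how the new branch is exercised; the import path and keyword arguments are assumptions based on Sam2LiveParams below, not a documented API.

from pipelines.loader import load_pipeline

# Params are forwarded to Sam2Live.update_params(); point_coords and
# point_labels must be provided together
pipeline = load_pipeline(
    "segment_anything_2",
    point_coords=[[1, 1]],
    point_labels=[1],
)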
169 changes: 169 additions & 0 deletions runner/app/live/pipelines/segment_anything_2.py
@@ -0,0 +1,169 @@
import io
import logging
from typing import List, Optional

import cv2
import numpy as np
import torch
from PIL import Image
from pydantic import BaseModel

from Sam2Wrapper import Sam2Wrapper

from .interface import Pipeline

logger = logging.getLogger(__name__)

class Sam2LiveParams(BaseModel):
class Config:
extra = "forbid"

model_id: str = "facebook/sam2-hiera-tiny"
point_coords: Optional[List[List[int]]] = [[1,1]]
point_labels: Optional[List[int]] = [1]
obj_ids: Optional[List[int]] = [1]
show_point: bool = False
show_overlay: bool = True


class Sam2Live(Pipeline):
def __init__(self, **params):
super().__init__(**params)
self.pipe: Optional[Sam2Wrapper] = None
self.first_frame = True
self.update_params(**params)

def update_params(self, **params):
        # point_coords and point_labels must always be updated together
        if ('point_coords' in params) != ('point_labels' in params):
            raise ValueError("Both point_coords and point_labels must be updated together")

# Preserve existing values if neither is provided
if hasattr(self, 'params'):
if 'point_coords' not in params and hasattr(self.params, 'point_coords'):
params["point_coords"] = self.params.point_coords

if 'point_labels' not in params and hasattr(self.params, 'point_labels'):
params["point_labels"] = self.params.point_labels

new_params = Sam2LiveParams(**params)
self.params = new_params

        # TODO: only reload the model if the points, labels, or model id actually changed
        self.first_frame = True

        logger.info("Setting parameters for SAM2")
self.pipe = Sam2Wrapper(
model_id_or_path=self.params.model_id,
point_coords=self.params.point_coords,
point_labels=self.params.point_labels,
obj_ids=self.params.obj_ids,
show_point=self.params.show_point,
show_overlay=self.params.show_overlay

            # Add additional parameters as needed
)


    def _process_mask(self, mask: torch.Tensor, frame_shape: tuple, color: List[int]) -> np.ndarray:
        """Convert mask logits to a colored BGR overlay, resized to the frame if needed."""
        logger.debug(f"Mask input shape: {mask.shape}")
        if mask.shape[0] == 0:
            logger.warning("Empty mask received")
            return np.zeros((frame_shape[0], frame_shape[1], 3), dtype="uint8")

        # Initialize the combined colored mask with an alpha channel
        combined_colored_mask = np.zeros((frame_shape[0], frame_shape[1], 4), dtype="uint8")

        # Process each mask channel, tinting it with the caller-supplied color
        for i in range(mask.shape[0]):
            current_mask = (mask[i, 0] > 0).cpu().numpy().astype("uint8") * 255
            if current_mask.shape[:2] != frame_shape[:2]:
                current_mask = cv2.resize(current_mask, (frame_shape[1], frame_shape[0]))

            # Create a BGRA mask with transparency
            colored_mask = np.zeros((frame_shape[0], frame_shape[1], 4), dtype="uint8")
            colored_mask[current_mask > 0] = color + [128]  # alpha value of 128

            # Alpha blend with the masks accumulated so far
            alpha = colored_mask[:, :, 3:4] / 255.0
            combined_colored_mask = (1 - alpha) * combined_colored_mask + alpha * colored_mask

        # Drop the alpha channel and return BGR for display
        return combined_colored_mask[:, :, :3].astype("uint8")

def process_frame(self, frame: Image.Image, **params) -> Image.Image:
try:
if params:
self.update_params(**params)

            # Convert the PIL frame to BGR for OpenCV (force RGBA so cvtColor always sees 4 channels)
            frame_array = np.array(frame.convert("RGBA"))
            frame_bgr = cv2.cvtColor(frame_array, cv2.COLOR_RGBA2BGR)
if self.first_frame:
self.pipe.predictor.load_first_frame(frame)

# For each obj_id, add a new point and label
for idx, obj_id in enumerate(self.params.obj_ids):
if self.params.point_coords and self.params.point_labels:
point = self.params.point_coords[idx]
label = self.params.point_labels[idx]
                        # add_new_prompt returns (frame_idx, obj_ids, mask_logits); obj_ids feed the overlay loop below
                        _, out_obj_ids, mask_logits = self.pipe.predictor.add_new_prompt(
frame_idx=0,
obj_id=obj_id,
points=[point],
labels=[label]
)

# logger.info(f"First frame mask_logits shape: {mask_logits.shape}")
self.first_frame = False
else:
out_obj_ids, mask_logits = self.pipe.predictor.track(frame)
# logger.info(f"Tracking mask_logits shape: {mask_logits.shape}")

# Initialize overlay with original frame
overlay = frame_bgr.copy()

            # Only apply the mask overlay if show_overlay is True
            if self.params.show_overlay:
                colors = [
                    [255, 0, 255],  # Magenta
                    [0, 255, 255],  # Yellow
                    [255, 255, 0],  # Cyan
                    [0, 255, 0],    # Green
                    [255, 0, 0],    # Blue
                ]

                # Loop through each object ID and blend in its mask
                for i, obj_id in enumerate(out_obj_ids):
                    logger.debug(f"Processing mask for object ID: {obj_id}")
                    single_mask_logits = mask_logits[i:i + 1]
                    colored_mask = self._process_mask(single_mask_logits, frame_bgr.shape, colors[i % len(colors)])
                    overlay = cv2.addWeighted(overlay, 0.7, colored_mask, 0.3, 0)

            # Draw the prompt points on the overlay if requested
            if self.params.show_point and self.params.point_coords:
                for point in self.params.point_coords:
                    cv2.circle(overlay, tuple(point), radius=5, color=(0, 0, 255), thickness=-1)

            # Convert back to a PIL image directly, avoiding a lossy JPEG round-trip
            return Image.fromarray(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))

        except Exception:
            logger.exception("Error processing frame")
return Image.new("RGB", frame.size, (255, 255, 255))
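A sketch of the per-frame flow under the same assumptions (hypothetical frame paths; frames decoded to PIL images upstream): the first call registers the prompts, later calls track, and passing params mid-stream re-prompts on the next frame.

pipe = Sam2Live(model_id="facebook/sam2-hiera-tiny")

# First frame: prompts are registered via add_new_prompt
out0 = pipe.process_frame(Image.open("frame_000.png"))
# Subsequent frames: the predictor tracks the prompted objects
out1 = pipe.process_frame(Image.open("frame_001.png"))
# Updating points mid-stream (coords and labels together) re-prompts
out2 = pipe.process_frame(
    Image.open("frame_002.png"),
    point_coords=[[64, 64]],
    point_labels=[1],
)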
35 changes: 35 additions & 0 deletions runner/docker/Dockerfile.live-base-segment_anything_2
@@ -0,0 +1,35 @@
ARG BASE_IMAGE=livepeer/ai-runner:live-base
FROM ${BASE_IMAGE}

# Install required Python version
ARG PYTHON_VERSION=3.10
RUN pyenv install $PYTHON_VERSION && \
pyenv global $PYTHON_VERSION && \
pyenv rehash

# Upgrade pip and install required packages
ARG PIP_VERSION=23.3.2
ENV PIP_PREFER_BINARY=1
RUN pip install --no-cache-dir --upgrade pip==${PIP_VERSION} setuptools==69.5.1 wheel==0.43.0

# Install g++ compiler
RUN apt-get update && apt-get install -y \
g++-11 \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
ENV CXX=/usr/bin/g++-11

# Install Sam2 dependencies
RUN pip install --no-cache-dir torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 xformers==0.0.27.post2 zstd==1.5.5.1

RUN pip install --no-cache-dir huggingface-hub==0.23.2 ninja

# Set TORCH_CUDA_ARCH_LIST environment variable, fixes build error in segment-anything-2-real-time
ENV TORCH_CUDA_ARCH_LIST="6.0 7.0 7.5 8.0 8.6+PTX"

RUN pip install --no-cache-dir --no-build-isolation \
git+https://github.com/eliteprox/segment-anything-2-real-time@main
# git+https://github.com/pschroedl/segment-anything-2-real-time@notebook_test

# Set environment variables for NVIDIA drivers
ENV NVIDIA_VISIBLE_DEVICES=all
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility,video
1 change: 1 addition & 0 deletions worker/docker.go
@@ -54,6 +54,7 @@ var livePipelineToImage = map[string]string{
"streamdiffusion": "livepeer/ai-runner:live-app-streamdiffusion",
"liveportrait": "livepeer/ai-runner:live-app-liveportrait",
"comfyui": "livepeer/ai-runner:live-app-comfyui",
"segment_anything_2" : "livepeer/ai-runner:live-app-segment_anything_2",
}

type DockerManager struct {