From a1de16180a059124a0c7789e4b17936810eee17f Mon Sep 17 00:00:00 2001
From: Elite Encoder
Date: Mon, 18 Nov 2024 16:57:03 -0500
Subject: [PATCH 1/3] squash sam-2 realtime

---
 runner/.devcontainer/devcontainer.json        |   2 +-
 runner/Dockerfile.live-app__PIPELINE__        |  30 +++++
 runner/app/live/Sam2Wrapper/__init__.py       |   3 +
 runner/app/live/Sam2Wrapper/wrapper.py        | 115 +++++++++++++++++
 runner/app/live/pipelines/loader.py           |   3 +
 .../app/live/pipelines/segment_anything_2.py  | 120 ++++++++++++++++++
 .../Dockerfile.live-base-segment_anything_2   |  35 +++++
 7 files changed, 307 insertions(+), 1 deletion(-)
 create mode 100644 runner/Dockerfile.live-app__PIPELINE__
 create mode 100644 runner/app/live/Sam2Wrapper/__init__.py
 create mode 100644 runner/app/live/Sam2Wrapper/wrapper.py
 create mode 100644 runner/app/live/pipelines/segment_anything_2.py
 create mode 100644 runner/docker/Dockerfile.live-base-segment_anything_2

diff --git a/runner/.devcontainer/devcontainer.json b/runner/.devcontainer/devcontainer.json
index 902d66ca..8ad0be8a 100644
--- a/runner/.devcontainer/devcontainer.json
+++ b/runner/.devcontainer/devcontainer.json
@@ -4,8 +4,8 @@
     "name": "ai-runner",
     // Image to use for the dev container. More info: https://containers.dev/guide/dockerfile.
     "build": {
+        // "dockerfile": "../docker/Dockerfile.live-base-segment_anything_2"
        "dockerfile": "../Dockerfile",
-        // "dockerfile": "../docker/Dockerfile.text_to_speech",
        "context": ".."
    },
    "runArgs": [
diff --git a/runner/Dockerfile.live-app__PIPELINE__ b/runner/Dockerfile.live-app__PIPELINE__
new file mode 100644
index 00000000..b5fdf981
--- /dev/null
+++ b/runner/Dockerfile.live-app__PIPELINE__
@@ -0,0 +1,30 @@
+ARG PIPELINE=streamdiffusion
+ARG BASE_IMAGE=livepeer/ai-runner:live-base-${PIPELINE}
+FROM ${BASE_IMAGE}
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    wget \
+    libcairo2-dev \
+    libgirepository1.0-dev \
+    pkg-config \
+    && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+# Install any additional Python packages
+COPY requirements.live-ai.txt /app/requirements.txt
+RUN pip install --no-cache-dir -r /app/requirements.txt
+
+# Set environment variables
+ENV MAX_WORKERS=1
+ENV HUGGINGFACE_HUB_CACHE=/models
+ENV DIFFUSERS_CACHE=/models
+ENV MODEL_DIR=/models
+
+# Copy application files
+COPY app/ /app/app
+COPY images/ /app/images
+COPY bench.py /app/bench.py
+
+WORKDIR /app
+
+CMD ["uvicorn", "app.main:app", "--log-config", "app/cfg/uvicorn_logging_config.json", "--host", "0.0.0.0", "--port", "8000"]
diff --git a/runner/app/live/Sam2Wrapper/__init__.py b/runner/app/live/Sam2Wrapper/__init__.py
new file mode 100644
index 00000000..07af6266
--- /dev/null
+++ b/runner/app/live/Sam2Wrapper/__init__.py
@@ -0,0 +1,3 @@
+from .wrapper import Sam2Wrapper
+
+__all__ = ["Sam2Wrapper"]
diff --git a/runner/app/live/Sam2Wrapper/wrapper.py b/runner/app/live/Sam2Wrapper/wrapper.py
new file mode 100644
index 00000000..9df4b75a
--- /dev/null
+++ b/runner/app/live/Sam2Wrapper/wrapper.py
@@ -0,0 +1,115 @@
+# Copied from StreamDiffusion/utils/wrapper.py
+import logging
+from typing import List, Optional, Tuple
+
+import torch
+from sam2.build_sam import build_sam2_camera_predictor
+from omegaconf import OmegaConf
+from hydra.utils import instantiate
+from hydra import initialize_config_dir, compose
+from hydra.core.global_hydra import GlobalHydra
+
+from PIL import Image
+from typing import Optional, Dict
+
+MODEL_MAPPING = {
+    "facebook/sam2-hiera-tiny": {
+        "config": "sam2_hiera_t.yaml",
+        "checkpoint": "sam2_hiera_tiny.pt"
"sam2_hiera_tiny.pt" + }, + "facebook/sam2-hiera-small": { + "config": "sam2_hiera_s.yaml", + "checkpoint": "sam2_hiera_small.pt" + }, + "facebook/sam2-hiera-base": { + "config": "sam2_hiera_b.yaml", + "checkpoint": "sam2_hiera_base.pt" + }, + "facebook/sam2-hiera-large": { + "config": "sam2_hiera_l.yaml", + "checkpoint": "sam2_hiera_large.pt" + } +} + +torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() +if torch.cuda.get_device_properties(0).major >= 8: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + +# Initialize Hydra to load the configuration +if GlobalHydra.instance().is_initialized(): + GlobalHydra.instance().clear() + +config_path = "/workspaces/ai-worker/runner/models/sam2_configs" +sam2_checkpoint = "/models/checkpoints/sam2_hiera_tiny.pt" +model_cfg = "sam2_hiera_t.yaml" + +class Sam2Wrapper: + def __init__( + self, + model_id_or_path: str, + device: Optional[str] = None, + **kwargs + ): + self.model_id = model_id_or_path + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + + # Code ripped out of sam2.build_sam.build_sam2_camera_predictor to appease Hydra + with initialize_config_dir(config_dir=config_path, version_base=None): + cfg = compose(config_name=model_cfg) + + hydra_overrides = [ + "++model._target_=sam2.sam2_camera_predictor.SAM2CameraPredictor", + ] + hydra_overrides_extra = [ + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + "++model.binarize_mask_from_pts_for_mem_enc=true", + "++model.fill_hole_area=8", + ] + hydra_overrides.extend(hydra_overrides_extra) + + cfg = compose(config_name=model_cfg, overrides=hydra_overrides) + OmegaConf.resolve(cfg) + + #Load the model + model = instantiate(cfg.model, _recursive_=True) + load_checkpoint(model, sam2_checkpoint, self.device) + model.to(self.device) + model.eval() + + # Set the model in memory + self.predictor = model + + def __call__( + self, + image: Image.Image, + **kwargs + ) -> Tuple[List[Image.Image], List[Optional[bool]]]: + pass + + def __str__(self) -> str: + return f"Sam2Wrapper model_id={self.model_id}" + + +def load_checkpoint(model, ckpt_path, device): + if ckpt_path is not None: + + sd = torch.load(ckpt_path, map_location=device)["model"] + missing_keys, unexpected_keys = model.load_state_dict(sd) + if missing_keys: + logging.error(f"Missing keys: {missing_keys}") + raise RuntimeError("Missing keys while loading checkpoint.") + if unexpected_keys: + logging.error(f"Unexpected keys: {unexpected_keys}") + raise RuntimeError("Unexpected keys while loading checkpoint.") + logging.info("Loaded checkpoint successfully.") + +def get_torch_device(): + if torch.cuda.is_available(): + return torch.device("cuda") + elif torch.backends.mps.is_available(): + return torch.device("mps") + else: + return torch.device("cpu") \ No newline at end of file diff --git a/runner/app/live/pipelines/loader.py b/runner/app/live/pipelines/loader.py index a656279d..bab137be 100644 --- a/runner/app/live/pipelines/loader.py +++ b/runner/app/live/pipelines/loader.py @@ -13,4 +13,7 @@ def load_pipeline(name: str, **params) -> Pipeline: elif name == "noop": from .noop import Noop return Noop(**params) + elif name == "segment_anything_2": + from .segment_anything_2 import Sam2Live + return Sam2Live(**params) raise ValueError(f"Unknown pipeline: {name}") diff --git 
diff --git a/runner/app/live/pipelines/segment_anything_2.py b/runner/app/live/pipelines/segment_anything_2.py
new file mode 100644
index 00000000..d6025ee8
--- /dev/null
+++ b/runner/app/live/pipelines/segment_anything_2.py
@@ -0,0 +1,120 @@
+import io
+import logging
+import threading
+import time
+from typing import List, Optional
+import cv2
+import numpy as np
+from PIL import Image
+from pydantic import BaseModel
+from Sam2Wrapper import Sam2Wrapper
+from .interface import Pipeline
+
+logger = logging.getLogger(__name__)
+
+class Sam2LiveParams(BaseModel):
+    class Config:
+        extra = "forbid"
+
+    model_id: str = "facebook/sam2-hiera-tiny"
+    point_coords: List[List[int]] = [[1, 1]]
+    point_labels: List[int] = [1]
+    show_point: bool = False
+
+    def __init__(self, **data):
+        super().__init__(**data)
+
+class Sam2Live(Pipeline):
+    def __init__(self, **params):
+        super().__init__(**params)
+        self.pipe: Optional[Sam2Wrapper] = None
+        self.first_frame = True
+        self.update_params(**params)
+
+    def update_params(self, **params):
+        new_params = Sam2LiveParams(**params)
+        self.params = new_params
+
+        logger.info("Setting parameters for sam2")
+        self.pipe = Sam2Wrapper(
+            model_id_or_path=self.params.model_id,
+            point_coords=self.params.point_coords,
+            point_labels=self.params.point_labels,
+            show_point=self.params.point_labels,
+            # Add additional parameters as needed
+        )
+
+        self.params = new_params
+        self.first_frame = True
+
+    def _process_mask(self, mask: np.ndarray, frame_shape: tuple) -> np.ndarray:
+        """Process and resize mask if needed."""
+        if mask.shape[0] == 0:
+            return np.zeros((frame_shape[0], frame_shape[1]), dtype="uint8")
+
+        mask = (mask[0, 0] > 0).cpu().numpy().astype("uint8") * 255
+        if mask.shape[:2] != frame_shape[:2]:
+            mask = cv2.resize(mask, (frame_shape[1], frame_shape[0]))
+        return mask
+
+    def process_frame(self, frame: Image.Image, **params) -> Image.Image:
+        frame_lock = threading.Lock()
+        start_time = time.time()
+        try:
+            if params:
+                self.update_params(**params)
+
+            # Convert image formats
+            t0 = time.time()
+            frame_array = np.array(frame)
+            frame_bgr = cv2.cvtColor(frame_array, cv2.COLOR_RGBA2BGR)
+            logger.debug(f"Image conversion took {(time.time() - t0)*1000:.2f}ms")
+
+            if self.first_frame:
+                t0 = time.time()
+                self.pipe.predictor.load_first_frame(frame)
+
+                for idx, point in enumerate(self.params.point_coords):
+                    _, _, mask_logits = self.pipe.predictor.add_new_prompt(
+                        frame_idx=0,
+                        obj_id=idx + 1,
+                        points=[point],
+                        labels=[self.params.point_labels[idx]]
+                    )
+                logger.debug(f"First frame processing took {(time.time() - t0)*1000:.2f}ms")
+                self.first_frame = False
+            else:
+                t0 = time.time()
+                _, mask_logits = self.pipe.predictor.track(frame)
+                logger.debug(f"Frame tracking took {(time.time() - t0)*1000:.2f}ms")
+
+            # Process mask and create overlay
+            t0 = time.time()
+            mask = self._process_mask(mask_logits, frame_bgr.shape)
+            logger.debug(f"Mask processing took {(time.time() - t0)*1000:.2f}ms")
+
+            # Create an overlay by combining the original frame and the mask
+            colored_mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
+
+            colored_mask[mask > 0] = [255, 0, 255]  # add a purple (BGR) tint to the mask
+            overlay = cv2.addWeighted(frame_bgr, 1, colored_mask, 1, 0)
+
+            # Draw points on the overlay
+            if hasattr(self.params, 'show_point') and self.params.show_point:
+                for point in self.params.point_coords:
+                    cv2.circle(overlay, tuple(point), radius=5, color=(0, 0, 255), thickness=-1)  # Red dot
+
+            # Convert back to PIL Image
+            t0 = time.time()
+            with frame_lock:
+                _, buffer = cv2.imencode('.jpg', overlay)
+                result = Image.open(io.BytesIO(buffer.tobytes()))
+            logger.debug(f"Final conversion took {(time.time() - t0)*1000:.2f}ms")
+
+            total_time = time.time() - start_time
+            logger.debug(f"Total frame processing time: {total_time*1000:.2f}ms")
+            return result
+
+        except Exception as e:
+            logger.error(f"Error processing frame: {str(e)}")
+            return Image.new("RGB", frame.size, (255, 255, 255))
\ No newline at end of file
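
[Editor's sketch] The overlay above is straight additive compositing: the binary mask is tinted and summed with the frame via `cv2.addWeighted`, so with both weights at 1 the masked pixels saturate toward the tint (patch 2 below softens this to a 0.7/0.3 blend). A self-contained sketch of the same idea on synthetic data; all names and values here are illustrative:

```python
import cv2
import numpy as np

frame_bgr = np.full((480, 640, 3), 80, dtype="uint8")  # stand-in for the camera frame
mask = np.zeros((480, 640), dtype="uint8")
cv2.circle(mask, (320, 240), 80, 255, -1)              # fake segmentation mask

colored_mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
colored_mask[mask > 0] = [255, 0, 255]                 # purple/magenta tint in BGR order
overlay = cv2.addWeighted(frame_bgr, 1, colored_mask, 1, 0)  # per-pixel saturating add
```
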
diff --git a/runner/docker/Dockerfile.live-base-segment_anything_2 b/runner/docker/Dockerfile.live-base-segment_anything_2
new file mode 100644
index 00000000..346dd01c
--- /dev/null
+++ b/runner/docker/Dockerfile.live-base-segment_anything_2
@@ -0,0 +1,35 @@
+ARG BASE_IMAGE=livepeer/ai-runner:live-base
+FROM ${BASE_IMAGE}
+
+# Install required Python version
+ARG PYTHON_VERSION=3.10
+RUN pyenv install $PYTHON_VERSION && \
+    pyenv global $PYTHON_VERSION && \
+    pyenv rehash
+
+# Upgrade pip and install required packages
+ARG PIP_VERSION=23.3.2
+ENV PIP_PREFER_BINARY=1
+RUN pip install --no-cache-dir --upgrade pip==${PIP_VERSION} setuptools==69.5.1 wheel==0.43.0
+
+# Install g++ compiler
+RUN apt-get update && apt-get install -y \
+    g++-11 \
+    && apt-get clean && rm -rf /var/lib/apt/lists/*
+ENV CXX=/usr/bin/g++-11
+
+# Install SAM2 dependencies
+RUN pip install --no-cache-dir torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 xformers==0.0.27.post2 zstd==1.5.5.1
+
+RUN pip install huggingface-hub==0.23.2 ninja
+
+# Set TORCH_CUDA_ARCH_LIST environment variable, fixes build error in segment-anything-2-real-time
+ENV TORCH_CUDA_ARCH_LIST="6.0 7.0 7.5 8.0 8.6+PTX"
+
+RUN pip install --no-cache-dir --no-build-isolation \
+    git+https://github.com/eliteprox/segment-anything-2-real-time@main
+    # git+https://github.com/pschroedl/segment-anything-2-real-time@notebook_test
+
+# Set environment variables for NVIDIA drivers
+ENV NVIDIA_VISIBLE_DEVICES all
+ENV NVIDIA_DRIVER_CAPABILITIES compute,utility,video
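
[Editor's sketch] Patch 1's frame loop boils down to three predictor calls: seed the first frame, attach a point prompt, then track. A hypothetical standalone driver using those same calls; the frame file names are placeholders, and the checkpoint/config paths hard-coded in wrapper.py must exist for this to run:

```python
from PIL import Image
from Sam2Wrapper import Sam2Wrapper

wrapper = Sam2Wrapper(model_id_or_path="facebook/sam2-hiera-tiny")
predictor = wrapper.predictor

first = Image.open("frame_000.png")              # placeholder frame source
predictor.load_first_frame(first)
_, obj_ids, mask_logits = predictor.add_new_prompt(
    frame_idx=0, obj_id=1, points=[[320, 240]], labels=[1]
)

for path in ["frame_001.png", "frame_002.png"]:  # placeholder follow-up frames
    obj_ids, mask_logits = predictor.track(Image.open(path))
```
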
"facebook/sam2-hiera-tiny" - point_coords: List[List[int]] = [[1, 1]] - point_labels: List[int] = [1] + point_coords: Optional[List[List[int]]] = [[1,1]] + point_labels: Optional[List[int]] = [1] + obj_ids: Optional[List[int]] = [1] show_point: bool = False + show_overlay: bool = True def __init__(self, **data): super().__init__(**data) @@ -32,87 +34,134 @@ def __init__(self, **params): self.update_params(**params) def update_params(self, **params): + # Only update point coordinates and labels if both are provided + if ('point_coords' in params) != ('point_labels' in params): + raise ValueError("Both point_coords and point_labels must be updated together") + + # Preserve existing values if neither is provided + if hasattr(self, 'params'): + if 'point_coords' not in params and hasattr(self.params, 'point_coords'): + params["point_coords"] = self.params.point_coords + + if 'point_labels' not in params and hasattr(self.params, 'point_labels'): + params["point_labels"] = self.params.point_labels + new_params = Sam2LiveParams(**params) self.params = new_params + + #TODO: Only reload the model if the point, label coordinates or model has changed + self.first_frame = True logging.info(f"Setting parameters for sam2") self.pipe = Sam2Wrapper( model_id_or_path=self.params.model_id, point_coords=self.params.point_coords, point_labels=self.params.point_labels, - show_point=self.params.point_labels, + obj_ids=self.params.obj_ids, + show_point=self.params.show_point, + show_overlay=self.params.show_overlay + # Add additional as needed ) - self.params = new_params - self.first_frame = True - def _process_mask(self, mask: np.ndarray, frame_shape: tuple) -> np.ndarray: + def _process_mask(self, mask: np.ndarray, frame_shape: tuple, color: list[list[int]]) -> np.ndarray: """Process and resize mask if needed.""" + logger.info(f"Mask input shape: {mask.shape}") if mask.shape[0] == 0: + logger.warning("Empty mask received") return np.zeros((frame_shape[0], frame_shape[1]), dtype="uint8") + + colors = [ + [255, 0, 255], # Purple + [0, 255, 255], # Yellow + [255, 255, 0], # Cyan + [0, 255, 0], # Green + [255, 0, 0], # Blue + ] + + # Initialize the combined colored mask with alpha channel + combined_colored_mask = np.zeros((frame_shape[0], frame_shape[1], 4), dtype="uint8") + + # Process each mask + for i in range(mask.shape[0]): + current_mask = (mask[i, 0] > 0).cpu().numpy().astype("uint8") * 255 + if current_mask.shape[:2] != frame_shape[:2]: + current_mask = cv2.resize(current_mask, (frame_shape[1], frame_shape[0])) - mask = (mask[0, 0] > 0).cpu().numpy().astype("uint8") * 255 - if mask.shape[:2] != frame_shape[:2]: - mask = cv2.resize(mask, (frame_shape[1], frame_shape[0])) - return mask + # Create BGRA mask with transparency + colored_mask = np.zeros((frame_shape[0], frame_shape[1], 4), dtype="uint8") + color = colors[i % len(colors)] + colored_mask[current_mask > 0] = color + [128] # Add alpha value of 128 + + # Alpha blend with existing masks + alpha = colored_mask[:, :, 3:4] / 255.0 + combined_colored_mask = (1 - alpha) * combined_colored_mask + alpha * colored_mask + + # Convert back to BGR for display + combined_colored_mask = combined_colored_mask[:, :, :3].astype("uint8") + return combined_colored_mask def process_frame(self, frame: Image.Image, **params) -> Image.Image: frame_lock = threading.Lock() - start_time = time.time() try: if params: self.update_params(**params) # Convert image formats - t0 = time.time() frame_array = np.array(frame) frame_bgr = cv2.cvtColor(frame_array, cv2.COLOR_RGBA2BGR) 
- logger.debug(f"Image conversion took {(time.time() - t0)*1000:.2f}ms") - if self.first_frame: - t0 = time.time() self.pipe.predictor.load_first_frame(frame) - for idx, point in enumerate(self.params.point_coords): - _, _, mask_logits = self.pipe.predictor.add_new_prompt( - frame_idx=0, - obj_id=idx + 1, - points=[point], - labels=[self.params.point_labels[idx]] - ) - logger.debug(f"First frame processing took {(time.time() - t0)*1000:.2f}ms") + # For each obj_id, add a new point and label + for idx, obj_id in enumerate(self.params.obj_ids): + if self.params.point_coords and self.params.point_labels: + point = self.params.point_coords[idx] + label = self.params.point_labels[idx] + _, _, mask_logits = self.pipe.predictor.add_new_prompt( + frame_idx=0, + obj_id=obj_id, + points=[point], + labels=[label] + ) + + # logger.info(f"First frame mask_logits shape: {mask_logits.shape}") self.first_frame = False else: - t0 = time.time() - _, mask_logits = self.pipe.predictor.track(frame) - logger.debug(f"Frame tracking took {(time.time() - t0)*1000:.2f}ms") - - # Process mask and create overlay - t0 = time.time() - mask = self._process_mask(mask_logits, frame_bgr.shape) - logger.debug(f"Mask processing took {(time.time() - t0)*1000:.2f}ms") - - # Create an overlay by combining the original frame and the mask - colored_mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) + out_obj_ids, mask_logits = self.pipe.predictor.track(frame) + # logger.info(f"Tracking mask_logits shape: {mask_logits.shape}") - colored_mask[mask > 0] = [255, 0, 255] # BGR format: # Add a purple tint to the mask - overlay = cv2.addWeighted(frame_bgr, 1, colored_mask, 1, 0) + # Initialize overlay with original frame + overlay = frame_bgr.copy() - # Draw points on the overlay + # Only apply mask overlay if show_overlay is True + if self.params.show_overlay: + # Loop through each object ID and apply the corresponding mask + for i, obj_id in enumerate(out_obj_ids): + logger.info(f"Processing mask for object ID: {obj_id}") + single_mask_logits = mask_logits[i:i+1] + colors = [ + [255, 0, 255], # Purple + [0, 255, 255], # Yellow + [255, 255, 0], # Cyan + [0, 255, 0], # Green + [255, 0, 0], # Blue + ] + + colored_mask = self._process_mask(single_mask_logits, frame_bgr.shape, colors[i]) + overlay = cv2.addWeighted(overlay, 0.7, colored_mask, 0.3, 0) + + # Draw points on the overlay if needed if hasattr(self.params, 'show_point') and self.params.show_point: for point in self.params.point_coords: - cv2.circle(overlay, tuple(point), radius=5, color=(0, 0, 255), thickness=-1) # Red dot + cv2.circle(overlay, tuple(point), radius=5, color=(0, 0, 255), thickness=-1) # Convert back to PIL Image - t0 = time.time() with frame_lock: _, buffer = cv2.imencode('.jpg', overlay) result = Image.open(io.BytesIO(buffer.tobytes())) - logger.debug(f"Final conversion took {(time.time() - t0)*1000:.2f}ms") - total_time = time.time() - start_time - logger.debug(f"Total frame processing time: {total_time*1000:.2f}ms") return result except Exception as e: From 3792855bef0554539281426c541350d062827fff Mon Sep 17 00:00:00 2001 From: Elite Encoder Date: Wed, 20 Nov 2024 23:43:08 -0500 Subject: [PATCH 3/3] add live container name for sam2-pipeline --- worker/docker.go | 1 + 1 file changed, 1 insertion(+) diff --git a/worker/docker.go b/worker/docker.go index 46a8cfb4..2af7b497 100644 --- a/worker/docker.go +++ b/worker/docker.go @@ -54,6 +54,7 @@ var livePipelineToImage = map[string]string{ "streamdiffusion": "livepeer/ai-runner:live-app-streamdiffusion", "liveportrait": 
"livepeer/ai-runner:live-app-liveportrait", "comfyui": "livepeer/ai-runner:live-app-comfyui", + "segment_anything_2" : "livepeer/ai-runner:live-app-segment_anything_2", } type DockerManager struct {