"""SAM 2 engine wrapper with lazy loading and explicit runtime status."""

import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import numpy as np

from config import settings

logger = logging.getLogger(__name__)

DEFAULT_SAM2_MODEL_ID = "sam2.1_hiera_tiny"


@dataclass(frozen=True)
class SAM2Variant:
    """One selectable SAM 2.1 runtime variant."""

    id: str
    label: str
    short_label: str
    # Primary config/checkpoint pair, tried before the legacy pair.
    config: str
    legacy_config: str
    checkpoint_filename: str
    legacy_checkpoint_filename: str


SAM2_VARIANTS: dict[str, SAM2Variant] = {
    "sam2.1_hiera_tiny": SAM2Variant(
        id="sam2.1_hiera_tiny",
        label="SAM 2.1 Tiny",
        short_label="tiny",
        config="configs/sam2.1/sam2.1_hiera_t.yaml",
        legacy_config="configs/sam2/sam2_hiera_t.yaml",
        checkpoint_filename="sam2.1_hiera_tiny.pt",
        legacy_checkpoint_filename="sam2_hiera_tiny.pt",
    ),
    "sam2.1_hiera_small": SAM2Variant(
        id="sam2.1_hiera_small",
        label="SAM 2.1 Small",
        short_label="small",
        config="configs/sam2.1/sam2.1_hiera_s.yaml",
        legacy_config="configs/sam2/sam2_hiera_s.yaml",
        checkpoint_filename="sam2.1_hiera_small.pt",
        legacy_checkpoint_filename="sam2_hiera_small.pt",
    ),
    "sam2.1_hiera_base_plus": SAM2Variant(
        id="sam2.1_hiera_base_plus",
        label="SAM 2.1 Base+",
        short_label="base+",
        config="configs/sam2.1/sam2.1_hiera_b+.yaml",
        legacy_config="configs/sam2/sam2_hiera_b+.yaml",
        checkpoint_filename="sam2.1_hiera_base_plus.pt",
        legacy_checkpoint_filename="sam2_hiera_base_plus.pt",
    ),
    "sam2.1_hiera_large": SAM2Variant(
        id="sam2.1_hiera_large",
        label="SAM 2.1 Large",
        short_label="large",
        config="configs/sam2.1/sam2.1_hiera_l.yaml",
        legacy_config="configs/sam2/sam2_hiera_l.yaml",
        checkpoint_filename="sam2.1_hiera_large.pt",
        legacy_checkpoint_filename="sam2_hiera_large.pt",
    ),
}

# Short names accepted by ``SAM2Engine.normalize_model_id``; all map to the default.
SAM2_MODEL_ALIASES = {
    "sam2": DEFAULT_SAM2_MODEL_ID,
    "sam2.1": DEFAULT_SAM2_MODEL_ID,
    "sam2_tiny": DEFAULT_SAM2_MODEL_ID,
}

# ---------------------------------------------------------------------------
# Attempt to import PyTorch and SAM 2; fall back to stubs if unavailable.
# ---------------------------------------------------------------------------
try:
    import torch

    TORCH_AVAILABLE = True
except Exception as exc:  # noqa: BLE001
    TORCH_AVAILABLE = False
    torch = None  # type: ignore[assignment]
    logger.warning("PyTorch import failed (%s). SAM2 will be unavailable.", exc)

try:
    from sam2.build_sam import build_sam2
    from sam2.build_sam import build_sam2_video_predictor
    from sam2.sam2_image_predictor import SAM2ImagePredictor

    SAM2_AVAILABLE = True
    logger.info("SAM2 library imported successfully.")
except Exception as exc:  # noqa: BLE001
    SAM2_AVAILABLE = False
    logger.warning("SAM2 import failed (%s). Using stub engine.", exc)


class SAM2Engine:
    """Lazy-loaded SAM 2 inference engine.

    Image and video predictors are built per variant on first use and cached.
    A failed load is recorded (see ``status``) and is not retried.
    """

    def __init__(self) -> None:
        # All dictionaries are keyed by the normalized variant id.
        self._predictors: dict[str, Optional[SAM2ImagePredictor]] = {}
        self._video_predictors: dict[str, object | None] = {}
        # "loaded" means "a load was attempted", whether or not it succeeded.
        self._model_loaded: dict[str, bool] = {}
        self._video_model_loaded: dict[str, bool] = {}
        self._loaded_device: dict[str, str] = {}
        self._last_error: dict[str, str | None] = {}
        self._video_last_error: dict[str, str | None] = {}

    # -----------------------------------------------------------------------
    # Internal helpers
    # -----------------------------------------------------------------------

    def variant_ids(self) -> list[str]:
        """Return the ids of all known SAM 2 variants."""
        return list(SAM2_VARIANTS.keys())

    def normalize_model_id(self, model_id: str | None) -> str:
        """Resolve *model_id* to a key of ``SAM2_VARIANTS``.

        Falls back to ``settings.sam_default_model`` and then to
        ``DEFAULT_SAM2_MODEL_ID`` when *model_id* is falsy. The value is
        lower-cased and passed through ``SAM2_MODEL_ALIASES``.

        Raises:
            ValueError: If the resolved id is not a known variant.
        """
        selected = (model_id or settings.sam_default_model or DEFAULT_SAM2_MODEL_ID).lower()
        selected = SAM2_MODEL_ALIASES.get(selected, selected)
        if selected not in SAM2_VARIANTS:
            raise ValueError(f"Unsupported SAM2 model: {model_id}")
        return selected

    def is_sam2_model(self, model_id: str | None) -> bool:
        """Return whether *model_id* resolves to a known SAM 2 variant."""
        try:
            self.normalize_model_id(model_id)
            return True
        except ValueError:
            return False

    def _models_dir(self) -> Path:
        """Return the directory searched for variant checkpoints.

        This is the parent directory of ``settings.sam_model_path``.
        """
        configured_path = Path(settings.sam_model_path)
        # NOTE(review): ``Path`` objects are always truthy, so the
        # ``Path("models")`` fallback is never taken (a bare filename yields
        # ``Path(".")``). Left unchanged to preserve checkpoint lookup.
        return configured_path.parent if configured_path.parent else Path("models")

    def _variant(self, model_id: str | None) -> SAM2Variant:
        """Return the ``SAM2Variant`` record for *model_id*."""
        return SAM2_VARIANTS[self.normalize_model_id(model_id)]

    def _checkpoint_config(self, model_id: str | None) -> tuple[str, str]:
        """Pick the ``(config, checkpoint_path)`` pair to load for a variant.

        Candidates, in priority order:

        1. the explicitly configured checkpoint (default variant only, and
           only when that file exists),
        2. the variant's primary checkpoint in the models directory,
        3. the variant's legacy checkpoint in the models directory.

        The first candidate whose checkpoint file exists is returned. When
        none exists the first candidate is returned anyway, so callers must
        check the path themselves.
        """
        variant_id = self.normalize_model_id(model_id)
        variant = SAM2_VARIANTS[variant_id]
        models_dir = self._models_dir()
        candidates: list[tuple[str, str]] = []

        configured_path = Path(settings.sam_model_path)
        if variant_id == DEFAULT_SAM2_MODEL_ID and configured_path.is_file():
            candidates.append((settings.sam_model_config, str(configured_path)))

        candidates.extend([
            (variant.config, str(models_dir / variant.checkpoint_filename)),
            (variant.legacy_config, str(models_dir / variant.legacy_checkpoint_filename)),
        ])

        for config, checkpoint_path in candidates:
            if os.path.isfile(checkpoint_path):
                return config, checkpoint_path
        return candidates[0]

    def _load_model(self, model_id: str | None = None) -> None:
        """Load the SAM 2 model and predictor on first use.

        Every outcome, including failure, marks the variant as loaded so the
        load is attempted at most once. Failures are stored in
        ``self._last_error``.
        """
        variant_id = self.normalize_model_id(model_id)
        if self._model_loaded.get(variant_id):
            return

        if not TORCH_AVAILABLE:
            self._last_error[variant_id] = "PyTorch is not installed."
            logger.warning("PyTorch not available; skipping SAM2 model load.")
            self._model_loaded[variant_id] = True
            return

        if not SAM2_AVAILABLE:
            self._last_error[variant_id] = "sam2 package is not installed."
            logger.warning("SAM2 not available; skipping model load.")
            self._model_loaded[variant_id] = True
            return

        config, checkpoint_path = self._checkpoint_config(variant_id)
        if not os.path.isfile(checkpoint_path):
            self._last_error[variant_id] = f"SAM2 checkpoint not found: {checkpoint_path}"
            logger.error("SAM checkpoint not found at %s", checkpoint_path)
            self._model_loaded[variant_id] = True
            return

        try:
            device = self._best_device()
            model = build_sam2(
                config,
                checkpoint_path,
                device=device,
            )
            self._predictors[variant_id] = SAM2ImagePredictor(model)
            self._model_loaded[variant_id] = True
            self._loaded_device[variant_id] = device
            self._last_error[variant_id] = None
            logger.info("SAM 2 model %s loaded from %s on %s", variant_id, checkpoint_path, device)
        except Exception as exc:  # noqa: BLE001
            self._last_error[variant_id] = str(exc)
            logger.error("Failed to load SAM 2 model %s: %s", variant_id, exc)
            self._model_loaded[variant_id] = True  # Prevent repeated load attempts

    def _load_video_model(self, model_id: str | None = None) -> None:
        """Load the SAM 2 video predictor on first propagation use.

        Mirrors ``_load_model``: the load is attempted at most once per
        variant and failures are stored in ``self._video_last_error``.
        """
        variant_id = self.normalize_model_id(model_id)
        if self._video_model_loaded.get(variant_id):
            return

        if not TORCH_AVAILABLE:
            self._video_last_error[variant_id] = "PyTorch is not installed."
            self._video_model_loaded[variant_id] = True
            return

        if not SAM2_AVAILABLE:
            self._video_last_error[variant_id] = "sam2 package is not installed."
            self._video_model_loaded[variant_id] = True
            return

        config, checkpoint_path = self._checkpoint_config(variant_id)
        if not os.path.isfile(checkpoint_path):
            self._video_last_error[variant_id] = f"SAM2 checkpoint not found: {checkpoint_path}"
            self._video_model_loaded[variant_id] = True
            return

        try:
            device = self._best_device()
            self._video_predictors[variant_id] = build_sam2_video_predictor(
                config,
                checkpoint_path,
                device=device,
            )
            self._video_model_loaded[variant_id] = True
            self._loaded_device[variant_id] = device
            self._video_last_error[variant_id] = None
            logger.info(
                "SAM 2 video predictor %s loaded from %s on %s",
                variant_id,
                checkpoint_path,
                device,
            )
        except Exception as exc:  # noqa: BLE001
            self._video_last_error[variant_id] = str(exc)
            self._video_model_loaded[variant_id] = True
            logger.error("Failed to load SAM 2 video predictor %s: %s", variant_id, exc)

    def _best_device(self) -> str:
        """Return ``"cuda"`` when torch reports CUDA is available, else ``"cpu"``."""
        if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available():
            return "cuda"
        return "cpu"

    def _ensure_ready(self, model_id: str | None = None) -> bool:
        """Ensure the model is loaded; return whether it is usable."""
        variant_id = self.normalize_model_id(model_id)
        self._load_model(variant_id)
        return SAM2_AVAILABLE and self._predictors.get(variant_id) is not None

    def _ensure_video_ready(self, model_id: str | None = None) -> bool:
        """Ensure the video predictor is loaded; return whether it is usable."""
        variant_id = self.normalize_model_id(model_id)
        self._load_video_model(variant_id)
        return SAM2_AVAILABLE and self._video_predictors.get(variant_id) is not None

    def status(self, model_id: str | None = None) -> dict:
        """Return lightweight, real runtime status without forcing model load.

        The ``message`` field is the last image-model load error when one is
        recorded and no predictor is loaded; otherwise it describes whether
        the model is loaded, loadable, or which prerequisites are missing.
        """
        variant_id = self.normalize_model_id(model_id)
        variant = SAM2_VARIANTS[variant_id]
        _, checkpoint_path = self._checkpoint_config(variant_id)
        checkpoint_exists = os.path.isfile(checkpoint_path)
        using_legacy_checkpoint = Path(checkpoint_path).name == variant.legacy_checkpoint_filename
        predictor = self._predictors.get(variant_id)
        # Report the device actually used if loaded, else the one that would be.
        device = self._loaded_device.get(variant_id) or self._best_device()
        available = bool(TORCH_AVAILABLE and SAM2_AVAILABLE and checkpoint_exists)

        if predictor is not None:
            message = f"{variant.label} model loaded and ready."
        elif available:
            message = f"{variant.label} dependencies and checkpoint are present; model will load on first inference."
            if using_legacy_checkpoint:
                message += " Using legacy SAM 2 checkpoint fallback."
        else:
            missing = []
            if not TORCH_AVAILABLE:
                missing.append("PyTorch")
            if not SAM2_AVAILABLE:
                missing.append("sam2 package")
            if not checkpoint_exists:
                missing.append("checkpoint")
            message = f"{variant.label} unavailable: missing {', '.join(missing)}."

        last_error = self._last_error.get(variant_id)
        if last_error and not predictor:
            message = last_error

        return {
            "id": variant.id,
            "label": variant.label,
            "available": available,
            "loaded": predictor is not None,
            "device": device,
            "supports": ["point", "box", "interactive", "auto", "propagate"],
            "message": message,
            "package_available": SAM2_AVAILABLE,
            "checkpoint_exists": checkpoint_exists,
            "checkpoint_path": checkpoint_path,
            "python_ok": True,
            "torch_ok": TORCH_AVAILABLE,
            "cuda_required": False,
        }

    def _fallback_result(self, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
        """Return the placeholder ``(polygons, scores)`` used when inference is unavailable or fails."""
        return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]

    def _masks_to_polygons(self, masks) -> list[list[list[float]]]:
        """Convert each mask to a polygon, dropping masks that yield none."""
        return [polygon for mask in masks if (polygon := self._mask_to_polygon(mask))]

    @staticmethod
    def _scale_points(points: list[list[float]], width: int, height: int) -> np.ndarray:
        """Scale normalized ``[x, y]`` points to pixel coordinates (float32)."""
        return np.array([[p[0] * width, p[1] * height] for p in points], dtype=np.float32)

    @staticmethod
    def _scale_box(box: list[float], width: int, height: int) -> np.ndarray:
        """Scale a normalized ``[x1, y1, x2, y2]`` box to pixel coordinates (float32)."""
        return np.array(
            [box[0] * width, box[1] * height, box[2] * width, box[3] * height],
            dtype=np.float32,
        )

    # -----------------------------------------------------------------------
    # Public API
    # -----------------------------------------------------------------------

    def predict_points(
        self,
        model_id: str | None,
        image: np.ndarray,
        points: list[list[float]],
        labels: list[int],
    ) -> tuple[list[list[list[float]]], list[float]]:
        """Run point-prompt segmentation.

        Args:
            model_id: Variant id or alias; falsy selects the default model.
            image: HWC numpy array (uint8).
            points: List of [x, y] normalized coordinates (0-1).
            labels: 1 for foreground, 0 for background.

        Returns:
            Tuple of (polygons, scores). When the model is not usable or
            inference raises, a dummy rectangle with score 0.5 is returned.

        Raises:
            ValueError: If *model_id* is not a supported variant.
        """
        variant_id = self.normalize_model_id(model_id)
        if not self._ensure_ready(variant_id):
            logger.warning("SAM2 not ready; returning dummy masks.")
            return self._fallback_result(image)

        try:
            predictor = self._predictors[variant_id]
            h, w = image.shape[:2]
            pts = self._scale_points(points, w, h)
            lbls = np.array(labels, dtype=np.int32)

            with torch.inference_mode():  # type: ignore[name-defined]
                predictor.set_image(image)
                masks, scores, _ = predictor.predict(
                    point_coords=pts,
                    point_labels=lbls,
                    multimask_output=False,
                )

            return self._masks_to_polygons(masks), scores.tolist()
        except Exception as exc:  # noqa: BLE001
            logger.error("SAM2 point prediction failed: %s", exc)
            return self._fallback_result(image)

    def predict_box(
        self,
        model_id: str | None,
        image: np.ndarray,
        box: list[float],
    ) -> tuple[list[list[list[float]]], list[float]]:
        """Run box-prompt segmentation.

        Args:
            model_id: Variant id or alias; falsy selects the default model.
            image: HWC numpy array (uint8).
            box: [x1, y1, x2, y2] normalized coordinates.

        Returns:
            Tuple of (polygons, scores). When the model is not usable or
            inference raises, a dummy rectangle with score 0.5 is returned.

        Raises:
            ValueError: If *model_id* is not a supported variant.
        """
        variant_id = self.normalize_model_id(model_id)
        if not self._ensure_ready(variant_id):
            logger.warning("SAM2 not ready; returning dummy masks.")
            return self._fallback_result(image)

        try:
            predictor = self._predictors[variant_id]
            h, w = image.shape[:2]
            bbox = self._scale_box(box, w, h)

            with torch.inference_mode():  # type: ignore[name-defined]
                predictor.set_image(image)
                masks, scores, _ = predictor.predict(
                    box=bbox[None, :],
                    multimask_output=False,
                )

            return self._masks_to_polygons(masks), scores.tolist()
        except Exception as exc:  # noqa: BLE001
            logger.error("SAM2 box prediction failed: %s", exc)
            return self._fallback_result(image)

    def predict_interactive(
        self,
        model_id: str | None,
        image: np.ndarray,
        box: list[float] | None,
        points: list[list[float]],
        labels: list[int],
    ) -> tuple[list[list[list[float]]], list[float]]:
        """Run combined box and point prompt segmentation for refinement.

        Args:
            model_id: Variant id or alias; falsy selects the default model.
            image: HWC numpy array (uint8).
            box: Optional [x1, y1, x2, y2] normalized box; omitted when falsy.
            points: Normalized [x, y] points; omitted (with *labels*) when empty.
            labels: 1 for foreground, 0 for background, one per point.

        Returns:
            Tuple of (polygons, scores). When the model is not usable or
            inference raises, a dummy rectangle with score 0.5 is returned.

        Raises:
            ValueError: If *model_id* is not a supported variant.
        """
        variant_id = self.normalize_model_id(model_id)
        if not self._ensure_ready(variant_id):
            logger.warning("SAM2 not ready; returning dummy masks.")
            return self._fallback_result(image)

        try:
            predictor = self._predictors[variant_id]
            h, w = image.shape[:2]
            bbox = self._scale_box(box, w, h) if box else None
            pts = None
            lbls = None
            if points:
                pts = self._scale_points(points, w, h)
                lbls = np.array(labels, dtype=np.int32)

            with torch.inference_mode():  # type: ignore[name-defined]
                predictor.set_image(image)
                masks, scores, _ = predictor.predict(
                    point_coords=pts,
                    point_labels=lbls,
                    box=bbox,
                    multimask_output=False,
                )

            return self._masks_to_polygons(masks), scores.tolist()
        except Exception as exc:  # noqa: BLE001
            logger.error("SAM2 interactive prediction failed: %s", exc)
            return self._fallback_result(image)

    def predict_auto(self, model_id: str | None, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
        """Run grid-prompted segmentation without user prompts.

        A uniform grid of foreground points covering the image is passed to
        the predictor in a single ``predict`` call, and only the first
        returned mask/score is kept.

        Args:
            model_id: Variant id or alias; falsy selects the default model.
            image: HWC numpy array (uint8).

        Returns:
            Tuple of (polygons, scores). When the model is not usable or
            inference raises, a dummy rectangle with score 0.5 is returned.

        Raises:
            ValueError: If *model_id* is not a supported variant.
        """
        variant_id = self.normalize_model_id(model_id)
        if not self._ensure_ready(variant_id):
            logger.warning("SAM2 not ready; returning dummy masks.")
            return self._fallback_result(image)

        try:
            predictor = self._predictors[variant_id]
            with torch.inference_mode():  # type: ignore[name-defined]
                predictor.set_image(image)
                # Uniform 17x17 grid of point prompts (16 intervals per axis,
                # both edges included), scaled from [0, 1] to pixel units.
                h, w = image.shape[:2]
                grid = np.mgrid[0:1:17j, 0:1:17j].reshape(2, -1).T
                pts = grid * np.array([w, h])
                lbls = np.ones(pts.shape[0], dtype=np.int32)
                masks, scores, _ = predictor.predict(
                    point_coords=pts,
                    point_labels=lbls,
                    multimask_output=False,
                )

            return self._masks_to_polygons(masks[:1]), scores[:1].tolist()
        except Exception as exc:  # noqa: BLE001
            logger.error("SAM2 auto prediction failed: %s", exc)
            return self._fallback_result(image)

    def _propagation_frame_result(self, out_frame_idx, out_obj_ids, out_mask_logits) -> dict:
        """Convert one ``propagate_in_video`` output into a plain result dict."""
        masks = out_mask_logits
        if hasattr(masks, "detach"):
            # Tensor-like output: move to host memory before NumPy conversion.
            masks = masks.detach().cpu().numpy()
        masks = np.asarray(masks)
        if masks.ndim == 4:
            # Drop the second axis so iteration yields one 2-D mask per entry.
            masks = masks[:, 0]

        polygons = []
        scores = []
        for mask in masks:
            # Logits are thresholded at zero to obtain the binary mask.
            polygon = self._mask_to_polygon(mask > 0)
            if polygon:
                polygons.append(polygon)
                scores.append(1.0)

        return {
            "frame_index": int(out_frame_idx),
            "polygons": polygons,
            "scores": scores,
            "object_ids": [int(obj_id) for obj_id in list(out_obj_ids)],
        }

    def propagate_video(
        self,
        model_id: str | None,
        frame_paths: list[str],
        source_frame_index: int,
        seed: dict,
        direction: str = "forward",
        max_frames: int | None = None,
    ) -> list[dict]:
        """Propagate one seed mask across a prepared frame directory with SAM 2 video.

        Args:
            model_id: Variant id or alias; falsy selects the default model.
            frame_paths: Frame image paths; the directory of the first path is
                handed to the video predictor.
            source_frame_index: Index into *frame_paths* of the seeded frame.
            seed: Mapping with normalized ``"polygons"`` and/or a normalized
                ``"bbox"`` (``[x, y, w, h]``) used when the polygons are empty.
            direction: ``"forward"``, ``"backward"`` or ``"both"``
                (case-insensitive).
            max_frames: Passed through as ``max_frame_num_to_track``.

        Returns:
            One dict per visited frame, sorted by frame index, with keys
            ``frame_index``, ``polygons``, ``scores`` and ``object_ids``.
            Empty when *frame_paths* is empty.

        Raises:
            RuntimeError: If the video predictor is unavailable or the source
                frame cannot be decoded.
            ValueError: If the model id, frame index, direction or seed is invalid.
        """
        variant_id = self.normalize_model_id(model_id)
        if not self._ensure_video_ready(variant_id):
            raise RuntimeError(self._video_last_error.get(variant_id) or self.status(variant_id)["message"])
        video_predictor = self._video_predictors[variant_id]

        if not frame_paths:
            return []
        if source_frame_index < 0 or source_frame_index >= len(frame_paths):
            raise ValueError("source_frame_index is outside the frame sequence.")

        # Reject unknown directions before any expensive work; previously they
        # silently produced an empty result after state initialization.
        normalized_direction = direction.lower()
        if normalized_direction not in {"forward", "backward", "both"}:
            raise ValueError(
                f"Unsupported propagation direction: {direction!r} "
                "(expected 'forward', 'backward' or 'both')."
            )

        import cv2

        source_image = cv2.imread(frame_paths[source_frame_index])
        if source_image is None:
            raise RuntimeError("Failed to decode source frame for SAM 2 propagation.")
        height, width = source_image.shape[:2]

        seed_mask = self._polygons_to_mask(seed.get("polygons") or [], width, height)
        if not seed_mask.any():
            bbox = seed.get("bbox")
            if isinstance(bbox, (list, tuple)) and len(bbox) == 4:
                seed_mask = self._bbox_to_mask(bbox, width, height)
        if not seed_mask.any():
            raise ValueError("SAM 2 propagation requires a non-empty seed polygon or bbox.")

        inference_state = video_predictor.init_state(
            video_path=os.path.dirname(frame_paths[0]),
            offload_video_to_cpu=True,
            offload_state_to_cpu=True,
        )

        results: dict[int, dict] = {}

        def collect(reverse: bool) -> None:
            for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(
                inference_state,
                start_frame_idx=source_frame_index,
                max_frame_num_to_track=max_frames,
                reverse=reverse,
            ):
                results[int(out_frame_idx)] = self._propagation_frame_result(
                    out_frame_idx, out_obj_ids, out_mask_logits
                )

        try:
            video_predictor.add_new_mask(
                inference_state,
                frame_idx=source_frame_index,
                obj_id=1,
                mask=seed_mask,
            )
            if normalized_direction in {"forward", "both"}:
                collect(reverse=False)
            if normalized_direction in {"backward", "both"}:
                collect(reverse=True)
        finally:
            # Always reset the state, including when propagation fails.
            # The reset itself is best-effort and must not mask the result.
            try:
                video_predictor.reset_state(inference_state)
            except Exception as exc:  # noqa: BLE001
                logger.warning("Failed to reset SAM 2 video inference state: %s", exc)

        return [results[index] for index in sorted(results)]

    # -----------------------------------------------------------------------
    # Helpers
    # -----------------------------------------------------------------------

    @staticmethod
    def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
        """Convert a binary mask to a normalized polygon.

        Only the external contour with the largest enclosed area is kept.

        Args:
            mask: 2-D mask; non-uint8 input is binarized with ``mask > 0``.

        Returns:
            List of ``[x, y]`` vertices divided by mask width/height, or an
            empty list when no contour with at least three vertices exists.
        """
        import cv2

        if mask.dtype != np.uint8:
            mask = (mask > 0).astype(np.uint8)

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return []

        # Rank by enclosed area, not vertex count: CHAIN_APPROX_SIMPLE reduces
        # a large rectangle to 4 vertices, which would lose to a tiny jagged
        # blob. Vertex count only breaks ties (e.g. all-zero areas).
        largest = max(contours, key=lambda cnt: (cv2.contourArea(cnt), len(cnt)))
        if len(largest) < 3:
            return []

        h, w = mask.shape[:2]
        return [[float(pt[0][0]) / w, float(pt[0][1]) / h] for pt in largest]

    @staticmethod
    def _dummy_polygons(w: int, h: int) -> list[list[list[float]]]:
        """Return a dummy rectangle polygon for fallback mode.

        The rectangle is in normalized coordinates, so *w* and *h* are unused.
        """
        return [
            [
                [0.25, 0.25],
                [0.75, 0.25],
                [0.75, 0.75],
                [0.25, 0.75],
            ]
        ]

    @staticmethod
    def _polygons_to_mask(polygons: list[list[list[float]]], width: int, height: int) -> np.ndarray:
        """Rasterize normalized polygons into a boolean ``(height, width)`` mask.

        Coordinates are clamped to [0, 1] and mapped onto pixel indices
        ``0..width-1`` / ``0..height-1``. Polygons with fewer than three
        vertices are skipped.
        """
        import cv2

        mask = np.zeros((height, width), dtype=np.uint8)
        for polygon in polygons:
            if len(polygon) < 3:
                continue
            pts = np.array(
                [
                    [
                        int(round(min(max(float(x), 0.0), 1.0) * max(width - 1, 1))),
                        int(round(min(max(float(y), 0.0), 1.0) * max(height - 1, 1))),
                    ]
                    for x, y in polygon
                ],
                dtype=np.int32,
            )
            cv2.fillPoly(mask, [pts], 1)
        return mask.astype(bool)

    @staticmethod
    def _bbox_to_mask(bbox: list[float], width: int, height: int) -> np.ndarray:
        """Rasterize a normalized ``[x, y, w, h]`` box into a boolean mask.

        Each value is clamped to [0, 1]; the filled region always covers at
        least one pixel row and column starting at the top-left corner.
        """
        x, y, w, h = [min(max(float(value), 0.0), 1.0) for value in bbox]
        left = int(round(x * max(width - 1, 1)))
        top = int(round(y * max(height - 1, 1)))
        right = int(round(min(x + w, 1.0) * max(width - 1, 1)))
        bottom = int(round(min(y + h, 1.0) * max(height - 1, 1)))
        mask = np.zeros((height, width), dtype=bool)
        mask[top:max(bottom + 1, top + 1), left:max(right + 1, left + 1)] = True
        return mask


# Singleton instance
sam_engine = SAM2Engine()