"""SAM 3 engine adapter and runtime status. The official facebookresearch/sam3 package currently targets Python 3.12+ and CUDA-capable PyTorch. This adapter reports those requirements honestly and only performs inference when the local runtime can actually import and execute the package. """ from __future__ import annotations import importlib.util import logging import sys from typing import Any import numpy as np from PIL import Image from config import settings from services.sam2_engine import SAM2Engine logger = logging.getLogger(__name__) try: import torch TORCH_AVAILABLE = True except Exception as exc: # noqa: BLE001 TORCH_AVAILABLE = False torch = None # type: ignore[assignment] logger.warning("PyTorch import failed (%s). SAM3 will be unavailable.", exc) SAM3_PACKAGE_AVAILABLE = importlib.util.find_spec("sam3") is not None class SAM3Engine: """Lazy SAM 3 image inference adapter.""" def __init__(self) -> None: self._model: Any | None = None self._processor: Any | None = None self._model_loaded = False self._last_error: str | None = None def _python_ok(self) -> bool: return sys.version_info >= (3, 12) def _gpu_ok(self) -> bool: return bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available()) def _can_load(self) -> bool: return bool(SAM3_PACKAGE_AVAILABLE and TORCH_AVAILABLE and self._python_ok() and self._gpu_ok()) def _load_model(self) -> None: if self._model_loaded: return if not self._can_load(): self._last_error = self._status_message() self._model_loaded = True return try: from sam3.model.sam3_image_processor import Sam3Processor from sam3.model_builder import build_sam3_image_model self._model = build_sam3_image_model() self._processor = Sam3Processor(self._model) self._model_loaded = True self._last_error = None logger.info("SAM 3 image model loaded with version setting %s", settings.sam3_model_version) except Exception as exc: # noqa: BLE001 self._last_error = str(exc) self._model_loaded = True logger.error("Failed to load SAM 3 model: %s", exc) def _ensure_ready(self) -> bool: self._load_model() return self._processor is not None def _status_message(self) -> str: missing = [] if not SAM3_PACKAGE_AVAILABLE: missing.append("sam3 package") if not self._python_ok(): missing.append("Python 3.12+ runtime") if not TORCH_AVAILABLE: missing.append("PyTorch") if not self._gpu_ok(): missing.append("CUDA GPU") if missing: return f"SAM 3 unavailable: missing {', '.join(missing)}." return "SAM 3 dependencies are present; model will load on first inference." def status(self) -> dict: available = self._can_load() return { "id": "sam3", "label": "SAM 3", "available": available, "loaded": self._processor is not None, "device": "cuda" if self._gpu_ok() else "unavailable", "supports": ["semantic"], "message": "SAM 3 model loaded and ready." if self._processor is not None else (self._last_error or self._status_message()), "package_available": SAM3_PACKAGE_AVAILABLE, "checkpoint_exists": SAM3_PACKAGE_AVAILABLE, "checkpoint_path": f"official/HuggingFace ({settings.sam3_model_version})", "python_ok": self._python_ok(), "torch_ok": TORCH_AVAILABLE, "cuda_required": True, } def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]: if not text.strip(): raise ValueError("SAM 3 semantic prompt requires non-empty text.") if not self._ensure_ready(): raise RuntimeError(self.status()["message"]) pil_image = Image.fromarray(image) with torch.inference_mode(): # type: ignore[union-attr] state = self._processor.set_image(pil_image) output = self._processor.set_text_prompt(state=state, prompt=text.strip()) masks = output.get("masks", []) scores = output.get("scores", []) polygons = [] for mask in masks: if hasattr(mask, "detach"): mask = mask.detach().cpu().numpy() if mask.ndim == 3: mask = mask[0] poly = SAM2Engine._mask_to_polygon(mask) if poly: polygons.append(poly) if hasattr(scores, "detach"): scores = scores.detach().cpu().tolist() elif hasattr(scores, "tolist"): scores = scores.tolist() return polygons, list(scores) def predict_points(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]: raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for point prompts.") def predict_box(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]: raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for box prompts.") sam3_engine = SAM3Engine()