"""Model registry for SAM runtimes and GPU status.""" from __future__ import annotations from typing import Any from config import settings from services.sam2_engine import DEFAULT_SAM2_MODEL_ID, TORCH_AVAILABLE, sam_engine as sam2_engine # SAM 3 integration is intentionally disabled for the current product flow. # The source files are kept in the repository so the integration can be # restored later, but the active registry only exposes SAM 2. # from services.sam3_engine import sam3_engine try: import torch except Exception: # noqa: BLE001 torch = None # type: ignore[assignment] class ModelUnavailableError(RuntimeError): """Raised when a selected model cannot run in this environment.""" class SAMRegistry: """Dispatch predictions to the selected SAM backend.""" def __init__(self) -> None: self._engines = { "sam2": sam2_engine, # "sam3": sam3_engine, } def normalize_model_id(self, model_id: str | None) -> str: selected = (model_id or settings.sam_default_model or DEFAULT_SAM2_MODEL_ID).lower() if self._engines["sam2"].is_sam2_model(selected): return self._engines["sam2"].normalize_model_id(selected) if selected not in self._engines: raise ValueError(f"Unsupported model: {model_id}") return selected def runtime_status(self, selected_model: str | None = None) -> dict[str, Any]: selected = self.normalize_model_id(selected_model) return { "selected_model": selected, "gpu": self.gpu_status(), "models": [sam2_engine.status(model_id) for model_id in sam2_engine.variant_ids()], } def gpu_status(self) -> dict[str, Any]: cuda_available = bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available()) return { "available": cuda_available, "device": "cuda" if cuda_available else "cpu", "name": torch.cuda.get_device_name(0) if cuda_available else None, "torch_available": bool(TORCH_AVAILABLE), "torch_version": getattr(torch, "__version__", None) if torch is not None else None, "cuda_version": getattr(torch.version, "cuda", None) if torch is not None else None, } def _engine(self, model_id: str | None) -> Any: normalized = self.normalize_model_id(model_id) if self._engines["sam2"].is_sam2_model(normalized): return self._engines["sam2"] return self._engines[normalized] def _ensure_available(self, model_id: str | None) -> Any: normalized = self.normalize_model_id(model_id) engine = self._engine(model_id) status = engine.status(normalized) if engine is sam2_engine else engine.status() if not status["available"]: raise ModelUnavailableError(status["message"]) return engine def predict_points(self, model_id: str | None, image: Any, points: list[list[float]], labels: list[int]): model = self.normalize_model_id(model_id) return self._ensure_available(model).predict_points(model, image, points, labels) def predict_box(self, model_id: str | None, image: Any, box: list[float]): model = self.normalize_model_id(model_id) return self._ensure_available(model).predict_box(model, image, box) def predict_interactive( self, model_id: str | None, image: Any, box: list[float] | None, points: list[list[float]], labels: list[int], ): model = self.normalize_model_id(model_id) if not sam2_engine.is_sam2_model(model): raise NotImplementedError("Interactive box + point refinement is currently supported by SAM 2.") return self._ensure_available(model).predict_interactive(model, image, box, points, labels) def predict_auto(self, model_id: str | None, image: Any): model = self.normalize_model_id(model_id) return self._ensure_available(model).predict_auto(model, image) def predict_semantic( self, model_id: str | None, image: Any, text: str, confidence_threshold: float | None = None, ): self.normalize_model_id(model_id) raise NotImplementedError("Semantic text prompting is disabled; use SAM 2 point or box prompts.") def propagate_video( self, model_id: str | None, frame_paths: list[str], source_frame_index: int, seed: dict[str, Any], direction: str, max_frames: int | None, ): model = self.normalize_model_id(model_id) return self._ensure_available(model).propagate_video( model, frame_paths, source_frame_index, seed, direction=direction, max_frames=max_frames, ) sam_registry = SAMRegistry()