"""Model registry for SAM runtimes and GPU status.""" from __future__ import annotations from typing import Any from config import settings from services.sam2_engine import TORCH_AVAILABLE, sam_engine as sam2_engine from services.sam3_engine import sam3_engine try: import torch except Exception: # noqa: BLE001 torch = None # type: ignore[assignment] class ModelUnavailableError(RuntimeError): """Raised when a selected model cannot run in this environment.""" class SAMRegistry: """Dispatch predictions to the selected SAM backend.""" def __init__(self) -> None: self._engines = { "sam2": sam2_engine, "sam3": sam3_engine, } def normalize_model_id(self, model_id: str | None) -> str: selected = (model_id or settings.sam_default_model or "sam2").lower() if selected not in self._engines: raise ValueError(f"Unsupported model: {model_id}") return selected def runtime_status(self, selected_model: str | None = None) -> dict[str, Any]: return { "selected_model": self.normalize_model_id(selected_model), "gpu": self.gpu_status(), "models": [engine.status() for engine in self._engines.values()], } def gpu_status(self) -> dict[str, Any]: cuda_available = bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available()) return { "available": cuda_available, "device": "cuda" if cuda_available else "cpu", "name": torch.cuda.get_device_name(0) if cuda_available else None, "torch_available": bool(TORCH_AVAILABLE), "torch_version": getattr(torch, "__version__", None) if torch is not None else None, "cuda_version": getattr(torch.version, "cuda", None) if torch is not None else None, } def _engine(self, model_id: str | None) -> Any: return self._engines[self.normalize_model_id(model_id)] def _ensure_available(self, model_id: str | None) -> Any: engine = self._engine(model_id) status = engine.status() if not status["available"]: raise ModelUnavailableError(status["message"]) return engine def predict_points(self, model_id: str | None, image: Any, points: list[list[float]], labels: list[int]): return self._ensure_available(model_id).predict_points(image, points, labels) def predict_box(self, model_id: str | None, image: Any, box: list[float]): return self._ensure_available(model_id).predict_box(image, box) def predict_auto(self, model_id: str | None, image: Any): return self._ensure_available(model_id).predict_auto(image) def predict_semantic(self, model_id: str | None, image: Any, text: str): model = self.normalize_model_id(model_id) if model == "sam3": return self._ensure_available(model).predict_semantic(image, text) return self._ensure_available(model).predict_auto(image) sam_registry = SAMRegistry()