feat: 打通全栈标注闭环、异步拆帧与模型状态
后端能力: - 新增 Celery app、worker task、ProcessingTask 模型、/api/tasks 查询接口和 media_task_runner,将 /api/media/parse 改为创建后台任务并由 worker 执行 FFmpeg/OpenCV/pydicom 拆帧。 - 新增 Redis 进度事件模块和 FastAPI Redis pub/sub 订阅,将 worker 任务进度广播到 /ws/progress;Dashboard 后端概览接口改为聚合 projects/frames/annotations/templates/processing_tasks。 - 统一项目状态为 pending/parsing/ready/error,新增共享 status 常量,并让前端兼容归一化旧状态值。 - 扩展 AI 后端:新增 SAM registry、SAM2 真实运行状态、SAM3 状态检测与文本语义推理适配入口,以及 /api/ai/models/status GPU/模型状态接口。 - 补齐标注保存/更新/删除、COCO/PNG mask 导出相关后端契约和模板 mapping_rules 打包/解包行为。 前端能力: - 新增运行时 API/WS 地址推导配置,前端 API 封装对齐 FastAPI 路由、字段映射、任务轮询、标注归档、导出下载和 AI 预测响应转换。 - Dashboard 改为读取 /api/dashboard/overview,并订阅 WebSocket progress/complete/error/status 更新解析队列和实时流转记录。 - 项目库导入视频/DICOM 后创建项目、上传媒体、触发异步解析并刷新真实项目列表。 - 工作区加载真实帧、无帧时触发解析任务、回显已保存标注、保存未归档 mask、更新 dirty mask、清空当前帧后端标注、导出 COCO JSON。 - Canvas 支持当前帧点/框提示调用后端 AI、渲染推理/已保存 mask、应用模板分类并维护保存状态计数;时间轴按项目 fps 播放。 - AI 页面新增 SAM2/SAM3 模型选择,预测请求携带 model;侧边栏和工作区新增真实 GPU/SAM 状态徽标。 - 模板库和本体面板接入真实模板 CRUD、分类编辑、拖拽排序、JSON 导入、默认腹腔镜分类和本地自定义分类选择。 测试与文档: - 新增 Vitest 配置、前端测试 setup、API/config/websocket/store/组件测试,覆盖登录、项目库、Dashboard、Canvas、工作区、模型状态、时间轴、本体和模板库。 - 新增 pytest 后端测试夹具和 auth/projects/templates/media/AI/export/dashboard/tasks/progress 测试,使用 SQLite、fake MinIO、fake SAM registry 和 Redis monkeypatch 隔离外部服务。 - 新增 doc/ 文档结构,冻结当前需求、设计、接口契约、测试计划、前端逐元素审计、实现地图和后续实施计划,并同步更新 README 与 AGENTS。 验证: - conda run -n seg_server pytest backend/tests:27 passed。 - npm run test:run:54 passed。 - npm run lint、npm run build、compileall、git diff --check 均通过;Vite 仅提示大 chunk 警告。
This commit is contained in:
220
backend/services/media_task_runner.py
Normal file
220
backend/services/media_task_runner.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""Background media parsing runner used by Celery workers."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from minio_client import BUCKET_NAME, download_file, get_minio_client, upload_file
|
||||
from models import Frame, ProcessingTask, Project
|
||||
from progress_events import publish_task_progress_event
|
||||
from services.frame_parser import (
|
||||
extract_thumbnail,
|
||||
parse_dicom,
|
||||
parse_video,
|
||||
upload_frames_to_minio,
|
||||
)
|
||||
from statuses import (
|
||||
PROJECT_STATUS_ERROR,
|
||||
PROJECT_STATUS_PARSING,
|
||||
PROJECT_STATUS_READY,
|
||||
TASK_STATUS_FAILED,
|
||||
TASK_STATUS_RUNNING,
|
||||
TASK_STATUS_SUCCESS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _set_task_state(
|
||||
db: Session,
|
||||
task: ProcessingTask,
|
||||
*,
|
||||
status: str | None = None,
|
||||
progress: int | None = None,
|
||||
message: str | None = None,
|
||||
result: dict[str, Any] | None = None,
|
||||
error: str | None = None,
|
||||
started: bool = False,
|
||||
finished: bool = False,
|
||||
) -> None:
|
||||
if status is not None:
|
||||
task.status = status
|
||||
if progress is not None:
|
||||
task.progress = max(0, min(100, progress))
|
||||
if message is not None:
|
||||
task.message = message
|
||||
if result is not None:
|
||||
task.result = result
|
||||
if error is not None:
|
||||
task.error = error
|
||||
if started:
|
||||
task.started_at = _now()
|
||||
if finished:
|
||||
task.finished_at = _now()
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
|
||||
|
||||
def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
"""Parse one project's media and update task progress in the database."""
|
||||
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
|
||||
if not task:
|
||||
raise ValueError(f"Task not found: {task_id}")
|
||||
|
||||
if task.project_id is None:
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
status=TASK_STATUS_FAILED,
|
||||
progress=100,
|
||||
message="任务缺少 project_id",
|
||||
error="Task has no project_id",
|
||||
finished=True,
|
||||
)
|
||||
raise ValueError("Task has no project_id")
|
||||
|
||||
project = db.query(Project).filter(Project.id == task.project_id).first()
|
||||
if not project:
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
status=TASK_STATUS_FAILED,
|
||||
progress=100,
|
||||
message="项目不存在",
|
||||
error="Project not found",
|
||||
finished=True,
|
||||
)
|
||||
raise ValueError(f"Project not found: {task.project_id}")
|
||||
|
||||
if not project.video_path:
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
status=TASK_STATUS_FAILED,
|
||||
progress=100,
|
||||
message="项目没有可解析媒体",
|
||||
error="Project has no media uploaded",
|
||||
finished=True,
|
||||
)
|
||||
project.status = PROJECT_STATUS_ERROR
|
||||
db.commit()
|
||||
raise ValueError("Project has no media uploaded")
|
||||
|
||||
project.status = PROJECT_STATUS_PARSING
|
||||
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True)
|
||||
|
||||
effective_source = (task.payload or {}).get("source_type") or project.source_type or "video"
|
||||
parse_fps = project.parse_fps or 30.0
|
||||
tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project.id}_")
|
||||
output_dir = os.path.join(tmp_dir, "frames")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
_set_task_state(db, task, progress=15, message="正在下载媒体文件")
|
||||
if effective_source == "dicom":
|
||||
dcm_dir = os.path.join(tmp_dir, "dcm")
|
||||
os.makedirs(dcm_dir, exist_ok=True)
|
||||
|
||||
client = get_minio_client()
|
||||
objects = list(client.list_objects(BUCKET_NAME, prefix=project.video_path, recursive=True))
|
||||
for obj in objects:
|
||||
if obj.object_name.lower().endswith(".dcm"):
|
||||
data = download_file(obj.object_name)
|
||||
local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name))
|
||||
with open(local_dcm, "wb") as f:
|
||||
f.write(data)
|
||||
|
||||
_set_task_state(db, task, progress=35, message="正在解析 DICOM 序列")
|
||||
frame_files = parse_dicom(dcm_dir, output_dir)
|
||||
else:
|
||||
media_bytes = download_file(project.video_path)
|
||||
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
|
||||
with open(local_path, "wb") as f:
|
||||
f.write(media_bytes)
|
||||
|
||||
_set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧")
|
||||
frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
|
||||
project.original_fps = original_fps
|
||||
|
||||
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
|
||||
try:
|
||||
extract_thumbnail(local_path, thumbnail_path)
|
||||
with open(thumbnail_path, "rb") as f:
|
||||
thumb_data = f.read()
|
||||
thumb_object = f"projects/{project.id}/thumbnail.jpg"
|
||||
upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
|
||||
project.thumbnail_url = thumb_object
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Thumbnail extraction failed: %s", exc)
|
||||
|
||||
_set_task_state(db, task, progress=70, message="正在上传帧到对象存储")
|
||||
object_names = upload_frames_to_minio(frame_files, project.id)
|
||||
|
||||
_set_task_state(db, task, progress=85, message="正在写入帧索引")
|
||||
frames_out = []
|
||||
for idx, obj_name in enumerate(object_names):
|
||||
local_frame = frame_files[idx]
|
||||
try:
|
||||
import cv2
|
||||
|
||||
img = cv2.imread(local_frame)
|
||||
h, w = img.shape[:2] if img is not None else (None, None)
|
||||
except Exception: # noqa: BLE001
|
||||
h, w = None, None
|
||||
|
||||
frame = Frame(
|
||||
project_id=project.id,
|
||||
frame_index=idx,
|
||||
image_url=obj_name,
|
||||
width=w,
|
||||
height=h,
|
||||
)
|
||||
db.add(frame)
|
||||
frames_out.append(frame)
|
||||
|
||||
project.status = PROJECT_STATUS_READY
|
||||
db.commit()
|
||||
|
||||
result = {
|
||||
"project_id": project.id,
|
||||
"frames_extracted": len(frames_out),
|
||||
"status": PROJECT_STATUS_READY,
|
||||
"message": "Frame extraction completed successfully.",
|
||||
}
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
status=TASK_STATUS_SUCCESS,
|
||||
progress=100,
|
||||
message="解析完成",
|
||||
result=result,
|
||||
finished=True,
|
||||
)
|
||||
logger.info("Parsed %d frames for project_id=%s", len(frames_out), project.id)
|
||||
return result
|
||||
except Exception as exc: # noqa: BLE001
|
||||
project.status = PROJECT_STATUS_ERROR
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
status=TASK_STATUS_FAILED,
|
||||
progress=100,
|
||||
message="解析失败",
|
||||
error=str(exc),
|
||||
finished=True,
|
||||
)
|
||||
logger.error("Frame extraction failed: %s", exc)
|
||||
raise
|
||||
finally:
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
@@ -1,4 +1,4 @@
|
||||
"""SAM 2 engine wrapper with lazy loading and fallback stubs."""
|
||||
"""SAM 2 engine wrapper with lazy loading and explicit runtime status."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
@@ -11,10 +11,18 @@ from config import settings
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attempt to import SAM 2; fall back to stubs if unavailable.
|
||||
# Attempt to import PyTorch and SAM 2; fall back to stubs if unavailable.
|
||||
# ---------------------------------------------------------------------------
|
||||
try:
|
||||
import torch
|
||||
|
||||
TORCH_AVAILABLE = True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
TORCH_AVAILABLE = False
|
||||
torch = None # type: ignore[assignment]
|
||||
logger.warning("PyTorch import failed (%s). SAM2 will be unavailable.", exc)
|
||||
|
||||
try:
|
||||
from sam2.build_sam import build_sam2
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
|
||||
@@ -31,6 +39,8 @@ class SAM2Engine:
|
||||
def __init__(self) -> None:
|
||||
self._predictor: Optional[SAM2ImagePredictor] = None
|
||||
self._model_loaded = False
|
||||
self._loaded_device: str | None = None
|
||||
self._last_error: str | None = None
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
@@ -40,34 +50,87 @@ class SAM2Engine:
|
||||
if self._model_loaded:
|
||||
return
|
||||
|
||||
if not TORCH_AVAILABLE:
|
||||
self._last_error = "PyTorch is not installed."
|
||||
logger.warning("PyTorch not available; skipping SAM2 model load.")
|
||||
self._model_loaded = True
|
||||
return
|
||||
|
||||
if not SAM2_AVAILABLE:
|
||||
self._last_error = "sam2 package is not installed."
|
||||
logger.warning("SAM2 not available; skipping model load.")
|
||||
self._model_loaded = True
|
||||
return
|
||||
|
||||
if not os.path.isfile(settings.sam_model_path):
|
||||
self._last_error = f"SAM2 checkpoint not found: {settings.sam_model_path}"
|
||||
logger.error("SAM checkpoint not found at %s", settings.sam_model_path)
|
||||
self._model_loaded = True
|
||||
return
|
||||
|
||||
try:
|
||||
device = self._best_device()
|
||||
model = build_sam2(
|
||||
settings.sam_model_config,
|
||||
settings.sam_model_path,
|
||||
device="cuda",
|
||||
device=device,
|
||||
)
|
||||
self._predictor = SAM2ImagePredictor(model)
|
||||
self._model_loaded = True
|
||||
logger.info("SAM 2 model loaded from %s", settings.sam_model_path)
|
||||
self._loaded_device = device
|
||||
self._last_error = None
|
||||
logger.info("SAM 2 model loaded from %s on %s", settings.sam_model_path, device)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self._last_error = str(exc)
|
||||
logger.error("Failed to load SAM 2 model: %s", exc)
|
||||
self._model_loaded = True # Prevent repeated load attempts
|
||||
|
||||
def _best_device(self) -> str:
|
||||
if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available():
|
||||
return "cuda"
|
||||
return "cpu"
|
||||
|
||||
def _ensure_ready(self) -> bool:
|
||||
"""Ensure the model is loaded; return whether it is usable."""
|
||||
self._load_model()
|
||||
return SAM2_AVAILABLE and self._predictor is not None
|
||||
|
||||
def status(self) -> dict:
|
||||
"""Return lightweight, real runtime status without forcing model load."""
|
||||
checkpoint_exists = os.path.isfile(settings.sam_model_path)
|
||||
device = self._loaded_device or self._best_device()
|
||||
available = bool(TORCH_AVAILABLE and SAM2_AVAILABLE and checkpoint_exists)
|
||||
if self._predictor is not None:
|
||||
message = "SAM 2 model loaded and ready."
|
||||
elif available:
|
||||
message = "SAM 2 dependencies and checkpoint are present; model will load on first inference."
|
||||
else:
|
||||
missing = []
|
||||
if not TORCH_AVAILABLE:
|
||||
missing.append("PyTorch")
|
||||
if not SAM2_AVAILABLE:
|
||||
missing.append("sam2 package")
|
||||
if not checkpoint_exists:
|
||||
missing.append("checkpoint")
|
||||
message = f"SAM 2 unavailable: missing {', '.join(missing)}."
|
||||
if self._last_error and not self._predictor:
|
||||
message = self._last_error
|
||||
return {
|
||||
"id": "sam2",
|
||||
"label": "SAM 2",
|
||||
"available": available,
|
||||
"loaded": self._predictor is not None,
|
||||
"device": device,
|
||||
"supports": ["point", "box", "auto"],
|
||||
"message": message,
|
||||
"package_available": SAM2_AVAILABLE,
|
||||
"checkpoint_exists": checkpoint_exists,
|
||||
"checkpoint_path": settings.sam_model_path,
|
||||
"python_ok": True,
|
||||
"torch_ok": TORCH_AVAILABLE,
|
||||
"cuda_required": False,
|
||||
}
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Public API
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
148
backend/services/sam3_engine.py
Normal file
148
backend/services/sam3_engine.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""SAM 3 engine adapter and runtime status.
|
||||
|
||||
The official facebookresearch/sam3 package currently targets Python 3.12+
|
||||
and CUDA-capable PyTorch. This adapter reports those requirements honestly and
|
||||
only performs inference when the local runtime can actually import and execute
|
||||
the package.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import logging
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from config import settings
|
||||
from services.sam2_engine import SAM2Engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
TORCH_AVAILABLE = True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
TORCH_AVAILABLE = False
|
||||
torch = None # type: ignore[assignment]
|
||||
logger.warning("PyTorch import failed (%s). SAM3 will be unavailable.", exc)
|
||||
|
||||
SAM3_PACKAGE_AVAILABLE = importlib.util.find_spec("sam3") is not None
|
||||
|
||||
|
||||
class SAM3Engine:
|
||||
"""Lazy SAM 3 image inference adapter."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._model: Any | None = None
|
||||
self._processor: Any | None = None
|
||||
self._model_loaded = False
|
||||
self._last_error: str | None = None
|
||||
|
||||
def _python_ok(self) -> bool:
|
||||
return sys.version_info >= (3, 12)
|
||||
|
||||
def _gpu_ok(self) -> bool:
|
||||
return bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available())
|
||||
|
||||
def _can_load(self) -> bool:
|
||||
return bool(SAM3_PACKAGE_AVAILABLE and TORCH_AVAILABLE and self._python_ok() and self._gpu_ok())
|
||||
|
||||
def _load_model(self) -> None:
|
||||
if self._model_loaded:
|
||||
return
|
||||
if not self._can_load():
|
||||
self._last_error = self._status_message()
|
||||
self._model_loaded = True
|
||||
return
|
||||
|
||||
try:
|
||||
from sam3.model.sam3_image_processor import Sam3Processor
|
||||
from sam3.model_builder import build_sam3_image_model
|
||||
|
||||
self._model = build_sam3_image_model()
|
||||
self._processor = Sam3Processor(self._model)
|
||||
self._model_loaded = True
|
||||
self._last_error = None
|
||||
logger.info("SAM 3 image model loaded with version setting %s", settings.sam3_model_version)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self._last_error = str(exc)
|
||||
self._model_loaded = True
|
||||
logger.error("Failed to load SAM 3 model: %s", exc)
|
||||
|
||||
def _ensure_ready(self) -> bool:
|
||||
self._load_model()
|
||||
return self._processor is not None
|
||||
|
||||
def _status_message(self) -> str:
|
||||
missing = []
|
||||
if not SAM3_PACKAGE_AVAILABLE:
|
||||
missing.append("sam3 package")
|
||||
if not self._python_ok():
|
||||
missing.append("Python 3.12+ runtime")
|
||||
if not TORCH_AVAILABLE:
|
||||
missing.append("PyTorch")
|
||||
if not self._gpu_ok():
|
||||
missing.append("CUDA GPU")
|
||||
if missing:
|
||||
return f"SAM 3 unavailable: missing {', '.join(missing)}."
|
||||
return "SAM 3 dependencies are present; model will load on first inference."
|
||||
|
||||
def status(self) -> dict:
|
||||
available = self._can_load()
|
||||
return {
|
||||
"id": "sam3",
|
||||
"label": "SAM 3",
|
||||
"available": available,
|
||||
"loaded": self._processor is not None,
|
||||
"device": "cuda" if self._gpu_ok() else "unavailable",
|
||||
"supports": ["semantic"],
|
||||
"message": "SAM 3 model loaded and ready." if self._processor is not None else (self._last_error or self._status_message()),
|
||||
"package_available": SAM3_PACKAGE_AVAILABLE,
|
||||
"checkpoint_exists": SAM3_PACKAGE_AVAILABLE,
|
||||
"checkpoint_path": f"official/HuggingFace ({settings.sam3_model_version})",
|
||||
"python_ok": self._python_ok(),
|
||||
"torch_ok": TORCH_AVAILABLE,
|
||||
"cuda_required": True,
|
||||
}
|
||||
|
||||
def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
|
||||
if not text.strip():
|
||||
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
|
||||
if not self._ensure_ready():
|
||||
raise RuntimeError(self.status()["message"])
|
||||
|
||||
pil_image = Image.fromarray(image)
|
||||
with torch.inference_mode(): # type: ignore[union-attr]
|
||||
state = self._processor.set_image(pil_image)
|
||||
output = self._processor.set_text_prompt(state=state, prompt=text.strip())
|
||||
|
||||
masks = output.get("masks", [])
|
||||
scores = output.get("scores", [])
|
||||
polygons = []
|
||||
for mask in masks:
|
||||
if hasattr(mask, "detach"):
|
||||
mask = mask.detach().cpu().numpy()
|
||||
if mask.ndim == 3:
|
||||
mask = mask[0]
|
||||
poly = SAM2Engine._mask_to_polygon(mask)
|
||||
if poly:
|
||||
polygons.append(poly)
|
||||
|
||||
if hasattr(scores, "detach"):
|
||||
scores = scores.detach().cpu().tolist()
|
||||
elif hasattr(scores, "tolist"):
|
||||
scores = scores.tolist()
|
||||
return polygons, list(scores)
|
||||
|
||||
def predict_points(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
|
||||
raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for point prompts.")
|
||||
|
||||
def predict_box(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
|
||||
raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for box prompts.")
|
||||
|
||||
|
||||
sam3_engine = SAM3Engine()
|
||||
80
backend/services/sam_registry.py
Normal file
80
backend/services/sam_registry.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Model registry for SAM runtimes and GPU status."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from config import settings
|
||||
from services.sam2_engine import TORCH_AVAILABLE, sam_engine as sam2_engine
|
||||
from services.sam3_engine import sam3_engine
|
||||
|
||||
try:
|
||||
import torch
|
||||
except Exception: # noqa: BLE001
|
||||
torch = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class ModelUnavailableError(RuntimeError):
|
||||
"""Raised when a selected model cannot run in this environment."""
|
||||
|
||||
|
||||
class SAMRegistry:
|
||||
"""Dispatch predictions to the selected SAM backend."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._engines = {
|
||||
"sam2": sam2_engine,
|
||||
"sam3": sam3_engine,
|
||||
}
|
||||
|
||||
def normalize_model_id(self, model_id: str | None) -> str:
|
||||
selected = (model_id or settings.sam_default_model or "sam2").lower()
|
||||
if selected not in self._engines:
|
||||
raise ValueError(f"Unsupported model: {model_id}")
|
||||
return selected
|
||||
|
||||
def runtime_status(self, selected_model: str | None = None) -> dict[str, Any]:
|
||||
return {
|
||||
"selected_model": self.normalize_model_id(selected_model),
|
||||
"gpu": self.gpu_status(),
|
||||
"models": [engine.status() for engine in self._engines.values()],
|
||||
}
|
||||
|
||||
def gpu_status(self) -> dict[str, Any]:
|
||||
cuda_available = bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available())
|
||||
return {
|
||||
"available": cuda_available,
|
||||
"device": "cuda" if cuda_available else "cpu",
|
||||
"name": torch.cuda.get_device_name(0) if cuda_available else None,
|
||||
"torch_available": bool(TORCH_AVAILABLE),
|
||||
"torch_version": getattr(torch, "__version__", None) if torch is not None else None,
|
||||
"cuda_version": getattr(torch.version, "cuda", None) if torch is not None else None,
|
||||
}
|
||||
|
||||
def _engine(self, model_id: str | None) -> Any:
|
||||
return self._engines[self.normalize_model_id(model_id)]
|
||||
|
||||
def _ensure_available(self, model_id: str | None) -> Any:
|
||||
engine = self._engine(model_id)
|
||||
status = engine.status()
|
||||
if not status["available"]:
|
||||
raise ModelUnavailableError(status["message"])
|
||||
return engine
|
||||
|
||||
def predict_points(self, model_id: str | None, image: Any, points: list[list[float]], labels: list[int]):
|
||||
return self._ensure_available(model_id).predict_points(image, points, labels)
|
||||
|
||||
def predict_box(self, model_id: str | None, image: Any, box: list[float]):
|
||||
return self._ensure_available(model_id).predict_box(image, box)
|
||||
|
||||
def predict_auto(self, model_id: str | None, image: Any):
|
||||
return self._ensure_available(model_id).predict_auto(image)
|
||||
|
||||
def predict_semantic(self, model_id: str | None, image: Any, text: str):
|
||||
model = self.normalize_model_id(model_id)
|
||||
if model == "sam3":
|
||||
return self._ensure_available(model).predict_semantic(image, text)
|
||||
return self._ensure_available(model).predict_auto(image)
|
||||
|
||||
|
||||
sam_registry = SAMRegistry()
|
||||
Reference in New Issue
Block a user