Files
Pre_Seg_Server/backend/services/sam_registry.py
admin 29a1a87e52 feat: 完善 SAM2.1 模型选择与标注工作流
- 后端 SAM2 引擎新增 sam2.1_hiera_tiny、sam2.1_hiera_small、sam2.1_hiera_base_plus、sam2.1_hiera_large 四个变体定义,并按变体维护 checkpoint/config、image predictor、video predictor、加载状态、错误信息和真实状态回报。

- 后端 SAM registry 仅暴露当前产品启用的 SAM2.1 变体,保留 sam2 作为 tiny 兼容别名,拒绝 sam3 产品入口,并把 point、box、interactive、auto、propagate 都分发到所选 SAM2.1 变体。

- 后端默认配置和下载脚本切换到 SAM2.1 checkpoint 命名,支持 legacy SAM2 checkpoint fallback,并在状态消息中标出 fallback 使用情况。

- 前端全局 AI 模型状态新增 SAM2.1 tiny/small/base+/large 类型和默认 tiny,API 请求默认携带 sam2.1_hiera_tiny,AI 页面提供模型变体选择和所选模型状态展示。

- AI 智能分割页移除当前产品不使用的 SAM3/文本提示入口,保留正向点、反向点、框选和参数开关;AI 页只展示本页生成的候选 mask,并支持遮罩清晰度调节、候选 mask 上继续加正/反点、清空本页候选、推送到工作区编辑。

- 工作区和 Canvas 补强 SAM2 交互式细化链路:框选后正/反点继续细化同一个候选 mask,反向点请求启用背景过滤,空结果会移除被否定候选;AI 推送到工作区后保留选中态和未保存 draft mask。

- 工作区标注保存闭环补强:未保存 mask 可归档保存,dirty saved mask 可更新,保存后用后端 saved annotation 替换已提交 draft,清空/删除已保存 mask 时同步后端删除。

- Dashboard 任务进度区改为展示 queued、running、success、failed、cancelled 最近任务,处理中统计只计算 queued/running,并保留近期完成记录。

- 时间轴在顶部时间进度条和底部缩略图导航轴之间新增已编辑帧标记带,基于当前项目帧内 masks 标出已有编辑/标注的帧,并支持点击标记跳转。

- 前端测试覆盖 SAM2.1 变体选择、模型状态徽标、AI 页候选隔离、遮罩透明度、候选上追加正/反点、推送工作区保留选择、Canvas 交互式细化、VideoWorkspace 传播/保存、Dashboard 进度和时间轴已编辑帧标记。

- 后端测试覆盖 SAM2.1 变体状态、sam2 alias 兼容、sam3 禁用、semantic 禁用、传播标注保存、Dashboard 最近任务状态和 SAM3 历史测试跳过说明。

- README、AGENTS 和 doc 文档同步当前真实进度,更新 SAM2.1 变体、SAM3 禁用、接口契约、设计冻结、需求冻结、前端元素审计、实施计划、FastAPI docs 说明和测试矩阵。
2026-05-01 23:39:53 +08:00

131 lines
4.8 KiB
Python

"""Model registry for SAM runtimes and GPU status."""
from __future__ import annotations
from typing import Any
from config import settings
from services.sam2_engine import DEFAULT_SAM2_MODEL_ID, TORCH_AVAILABLE, sam_engine as sam2_engine
# SAM 3 integration is intentionally disabled for the current product flow.
# The source files are kept in the repository so the integration can be
# restored later, but the active registry only exposes SAM 2.
# from services.sam3_engine import sam3_engine
try:
import torch
except Exception: # noqa: BLE001
torch = None # type: ignore[assignment]
class ModelUnavailableError(RuntimeError):
"""Raised when a selected model cannot run in this environment."""
class SAMRegistry:
"""Dispatch predictions to the selected SAM backend."""
def __init__(self) -> None:
self._engines = {
"sam2": sam2_engine,
# "sam3": sam3_engine,
}
def normalize_model_id(self, model_id: str | None) -> str:
selected = (model_id or settings.sam_default_model or DEFAULT_SAM2_MODEL_ID).lower()
if self._engines["sam2"].is_sam2_model(selected):
return self._engines["sam2"].normalize_model_id(selected)
if selected not in self._engines:
raise ValueError(f"Unsupported model: {model_id}")
return selected
def runtime_status(self, selected_model: str | None = None) -> dict[str, Any]:
selected = self.normalize_model_id(selected_model)
return {
"selected_model": selected,
"gpu": self.gpu_status(),
"models": [sam2_engine.status(model_id) for model_id in sam2_engine.variant_ids()],
}
def gpu_status(self) -> dict[str, Any]:
cuda_available = bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available())
return {
"available": cuda_available,
"device": "cuda" if cuda_available else "cpu",
"name": torch.cuda.get_device_name(0) if cuda_available else None,
"torch_available": bool(TORCH_AVAILABLE),
"torch_version": getattr(torch, "__version__", None) if torch is not None else None,
"cuda_version": getattr(torch.version, "cuda", None) if torch is not None else None,
}
def _engine(self, model_id: str | None) -> Any:
normalized = self.normalize_model_id(model_id)
if self._engines["sam2"].is_sam2_model(normalized):
return self._engines["sam2"]
return self._engines[normalized]
def _ensure_available(self, model_id: str | None) -> Any:
normalized = self.normalize_model_id(model_id)
engine = self._engine(model_id)
status = engine.status(normalized) if engine is sam2_engine else engine.status()
if not status["available"]:
raise ModelUnavailableError(status["message"])
return engine
def predict_points(self, model_id: str | None, image: Any, points: list[list[float]], labels: list[int]):
model = self.normalize_model_id(model_id)
return self._ensure_available(model).predict_points(model, image, points, labels)
def predict_box(self, model_id: str | None, image: Any, box: list[float]):
model = self.normalize_model_id(model_id)
return self._ensure_available(model).predict_box(model, image, box)
def predict_interactive(
self,
model_id: str | None,
image: Any,
box: list[float] | None,
points: list[list[float]],
labels: list[int],
):
model = self.normalize_model_id(model_id)
if not sam2_engine.is_sam2_model(model):
raise NotImplementedError("Interactive box + point refinement is currently supported by SAM 2.")
return self._ensure_available(model).predict_interactive(model, image, box, points, labels)
def predict_auto(self, model_id: str | None, image: Any):
model = self.normalize_model_id(model_id)
return self._ensure_available(model).predict_auto(model, image)
def predict_semantic(
self,
model_id: str | None,
image: Any,
text: str,
confidence_threshold: float | None = None,
):
self.normalize_model_id(model_id)
raise NotImplementedError("Semantic text prompting is disabled; use SAM 2 point or box prompts.")
def propagate_video(
self,
model_id: str | None,
frame_paths: list[str],
source_frame_index: int,
seed: dict[str, Any],
direction: str,
max_frames: int | None,
):
model = self.normalize_model_id(model_id)
return self._ensure_available(model).propagate_video(
model,
frame_paths,
source_frame_index,
seed,
direction=direction,
max_frames=max_frames,
)
sam_registry = SAMRegistry()