Files
Pre_Seg_Server/backend/services/sam2_engine.py
admin 29a1a87e52 feat: 完善 SAM2.1 模型选择与标注工作流
- 后端 SAM2 引擎新增 sam2.1_hiera_tiny、sam2.1_hiera_small、sam2.1_hiera_base_plus、sam2.1_hiera_large 四个变体定义,并按变体维护 checkpoint/config、image predictor、video predictor、加载状态、错误信息和真实状态回报。

- 后端 SAM registry 仅暴露当前产品启用的 SAM2.1 变体,保留 sam2 作为 tiny 兼容别名,拒绝 sam3 产品入口,并把 point、box、interactive、auto、propagate 都分发到所选 SAM2.1 变体。

- 后端默认配置和下载脚本切换到 SAM2.1 checkpoint 命名,支持 legacy SAM2 checkpoint fallback,并在状态消息中标出 fallback 使用情况。

- 前端全局 AI 模型状态新增 SAM2.1 tiny/small/base+/large 类型和默认 tiny,API 请求默认携带 sam2.1_hiera_tiny,AI 页面提供模型变体选择和所选模型状态展示。

- AI 智能分割页移除当前产品不使用的 SAM3/文本提示入口,保留正向点、反向点、框选和参数开关;AI 页只展示本页生成的候选 mask,并支持遮罩清晰度调节、候选 mask 上继续加正/反点、清空本页候选、推送到工作区编辑。

- 工作区和 Canvas 补强 SAM2 交互式细化链路:框选后正/反点继续细化同一个候选 mask,反向点请求启用背景过滤,空结果会移除被否定候选;AI 推送到工作区后保留选中态和未保存 draft mask。

- 工作区标注保存闭环补强:未保存 mask 可归档保存,dirty saved mask 可更新,保存后用后端 saved annotation 替换已提交 draft,清空/删除已保存 mask 时同步后端删除。

- Dashboard 任务进度区改为展示 queued、running、success、failed、cancelled 最近任务,处理中统计只计算 queued/running,并保留近期完成记录。

- 时间轴在顶部时间进度条和底部缩略图导航轴之间新增已编辑帧标记带,基于当前项目帧内 masks 标出已有编辑/标注的帧,并支持点击标记跳转。

- 前端测试覆盖 SAM2.1 变体选择、模型状态徽标、AI 页候选隔离、遮罩透明度、候选上追加正/反点、推送工作区保留选择、Canvas 交互式细化、VideoWorkspace 传播/保存、Dashboard 进度和时间轴已编辑帧标记。

- 后端测试覆盖 SAM2.1 变体状态、sam2 alias 兼容、sam3 禁用、semantic 禁用、传播标注保存、Dashboard 最近任务状态和 SAM3 历史测试跳过说明。

- README、AGENTS 和 doc 文档同步当前真实进度,更新 SAM2.1 变体、SAM3 禁用、接口契约、设计冻结、需求冻结、前端元素审计、实施计划、FastAPI docs 说明和测试矩阵。
2026-05-01 23:39:53 +08:00

638 lines
24 KiB
Python

"""SAM 2 engine wrapper with lazy loading and explicit runtime status."""
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import numpy as np
from config import settings
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Canonical variant id used when neither the caller nor the settings name a
# model; the aliases defined in this module also resolve to it.
DEFAULT_SAM2_MODEL_ID = "sam2.1_hiera_tiny"
@dataclass(frozen=True)
class SAM2Variant:
    """One selectable SAM 2.1 runtime variant.

    Instances are immutable (``frozen=True``). Each one bundles the display
    names of a variant with the config / checkpoint names used to build it,
    plus the legacy SAM 2 equivalents used as a fallback.
    """

    # Canonical identifier, e.g. "sam2.1_hiera_tiny".
    id: str
    # Human-readable name used in status messages.
    label: str
    # Compact size name, e.g. "tiny".
    short_label: str
    # Model config path paired with ``checkpoint_filename``.
    config: str
    # Model config path paired with ``legacy_checkpoint_filename``.
    legacy_config: str
    # Checkpoint file name looked up inside the models directory.
    checkpoint_filename: str
    # Legacy SAM 2 checkpoint file name tried when the 2.1 file is absent.
    legacy_checkpoint_filename: str
# Registry of every selectable SAM 2.1 variant. The mapping is keyed by each
# variant's own ``id`` so the key and the id can never drift apart; insertion
# order (tiny, small, base+, large) is preserved.
SAM2_VARIANTS: dict[str, SAM2Variant] = {
    variant.id: variant
    for variant in (
        SAM2Variant(
            id="sam2.1_hiera_tiny",
            label="SAM 2.1 Tiny",
            short_label="tiny",
            config="configs/sam2.1/sam2.1_hiera_t.yaml",
            legacy_config="configs/sam2/sam2_hiera_t.yaml",
            checkpoint_filename="sam2.1_hiera_tiny.pt",
            legacy_checkpoint_filename="sam2_hiera_tiny.pt",
        ),
        SAM2Variant(
            id="sam2.1_hiera_small",
            label="SAM 2.1 Small",
            short_label="small",
            config="configs/sam2.1/sam2.1_hiera_s.yaml",
            legacy_config="configs/sam2/sam2_hiera_s.yaml",
            checkpoint_filename="sam2.1_hiera_small.pt",
            legacy_checkpoint_filename="sam2_hiera_small.pt",
        ),
        SAM2Variant(
            id="sam2.1_hiera_base_plus",
            label="SAM 2.1 Base+",
            short_label="base+",
            config="configs/sam2.1/sam2.1_hiera_b+.yaml",
            legacy_config="configs/sam2/sam2_hiera_b+.yaml",
            checkpoint_filename="sam2.1_hiera_base_plus.pt",
            legacy_checkpoint_filename="sam2_hiera_base_plus.pt",
        ),
        SAM2Variant(
            id="sam2.1_hiera_large",
            label="SAM 2.1 Large",
            short_label="large",
            config="configs/sam2.1/sam2.1_hiera_l.yaml",
            legacy_config="configs/sam2/sam2_hiera_l.yaml",
            checkpoint_filename="sam2.1_hiera_large.pt",
            legacy_checkpoint_filename="sam2_hiera_large.pt",
        ),
    )
}
# Alternative model names accepted from callers; every alias resolves to the
# default variant id.
SAM2_MODEL_ALIASES = dict.fromkeys(
    ("sam2", "sam2.1", "sam2_tiny"),
    DEFAULT_SAM2_MODEL_ID,
)
# ---------------------------------------------------------------------------
# Attempt to import PyTorch and SAM 2; fall back to stubs if unavailable.
# ---------------------------------------------------------------------------
# Both imports are guarded so this module can still be imported when the
# optional dependencies are missing; the *_AVAILABLE flags record the outcome.
try:
    import torch
    TORCH_AVAILABLE = True
except Exception as exc:  # noqa: BLE001
    TORCH_AVAILABLE = False
    # Keep the name bound so later ``torch is not None`` checks do not fail.
    torch = None  # type: ignore[assignment]
    logger.warning("PyTorch import failed (%s). SAM2 will be unavailable.", exc)
try:
    from sam2.build_sam import build_sam2
    from sam2.build_sam import build_sam2_video_predictor
    from sam2.sam2_image_predictor import SAM2ImagePredictor
    SAM2_AVAILABLE = True
    logger.info("SAM2 library imported successfully.")
except Exception as exc:  # noqa: BLE001
    # NOTE: build_sam2, build_sam2_video_predictor and SAM2ImagePredictor stay
    # undefined on this path; code must check SAM2_AVAILABLE before using them.
    SAM2_AVAILABLE = False
    logger.warning("SAM2 import failed (%s). Using stub engine.", exc)
class SAM2Engine:
"""Lazy-loaded SAM 2 inference engine."""
def __init__(self) -> None:
self._predictors: dict[str, Optional[SAM2ImagePredictor]] = {}
self._video_predictors: dict[str, object | None] = {}
self._model_loaded: dict[str, bool] = {}
self._video_model_loaded: dict[str, bool] = {}
self._loaded_device: dict[str, str] = {}
self._last_error: dict[str, str | None] = {}
self._video_last_error: dict[str, str | None] = {}
# -----------------------------------------------------------------------
# Internal helpers
# -----------------------------------------------------------------------
def variant_ids(self) -> list[str]:
return list(SAM2_VARIANTS.keys())
def normalize_model_id(self, model_id: str | None) -> str:
selected = (model_id or settings.sam_default_model or DEFAULT_SAM2_MODEL_ID).lower()
selected = SAM2_MODEL_ALIASES.get(selected, selected)
if selected not in SAM2_VARIANTS:
raise ValueError(f"Unsupported SAM2 model: {model_id}")
return selected
def is_sam2_model(self, model_id: str | None) -> bool:
try:
self.normalize_model_id(model_id)
return True
except ValueError:
return False
def _models_dir(self) -> Path:
configured_path = Path(settings.sam_model_path)
return configured_path.parent if configured_path.parent else Path("models")
def _variant(self, model_id: str | None) -> SAM2Variant:
return SAM2_VARIANTS[self.normalize_model_id(model_id)]
def _checkpoint_config(self, model_id: str | None) -> tuple[str, str]:
variant_id = self.normalize_model_id(model_id)
variant = SAM2_VARIANTS[variant_id]
models_dir = self._models_dir()
candidates: list[tuple[str, str]] = []
configured_path = Path(settings.sam_model_path)
if variant_id == DEFAULT_SAM2_MODEL_ID and configured_path.is_file():
candidates.append((settings.sam_model_config, str(configured_path)))
candidates.extend([
(variant.config, str(models_dir / variant.checkpoint_filename)),
(variant.legacy_config, str(models_dir / variant.legacy_checkpoint_filename)),
])
for config, checkpoint_path in candidates:
if os.path.isfile(checkpoint_path):
return config, checkpoint_path
return candidates[0]
def _load_model(self, model_id: str | None = None) -> None:
"""Load the SAM 2 model and predictor on first use."""
variant_id = self.normalize_model_id(model_id)
if self._model_loaded.get(variant_id):
return
if not TORCH_AVAILABLE:
self._last_error[variant_id] = "PyTorch is not installed."
logger.warning("PyTorch not available; skipping SAM2 model load.")
self._model_loaded[variant_id] = True
return
if not SAM2_AVAILABLE:
self._last_error[variant_id] = "sam2 package is not installed."
logger.warning("SAM2 not available; skipping model load.")
self._model_loaded[variant_id] = True
return
config, checkpoint_path = self._checkpoint_config(variant_id)
if not os.path.isfile(checkpoint_path):
self._last_error[variant_id] = f"SAM2 checkpoint not found: {checkpoint_path}"
logger.error("SAM checkpoint not found at %s", checkpoint_path)
self._model_loaded[variant_id] = True
return
try:
device = self._best_device()
model = build_sam2(
config,
checkpoint_path,
device=device,
)
self._predictors[variant_id] = SAM2ImagePredictor(model)
self._model_loaded[variant_id] = True
self._loaded_device[variant_id] = device
self._last_error[variant_id] = None
logger.info("SAM 2 model %s loaded from %s on %s", variant_id, checkpoint_path, device)
except Exception as exc: # noqa: BLE001
self._last_error[variant_id] = str(exc)
logger.error("Failed to load SAM 2 model %s: %s", variant_id, exc)
self._model_loaded[variant_id] = True # Prevent repeated load attempts
def _load_video_model(self, model_id: str | None = None) -> None:
"""Load the SAM 2 video predictor on first propagation use."""
variant_id = self.normalize_model_id(model_id)
if self._video_model_loaded.get(variant_id):
return
if not TORCH_AVAILABLE:
self._video_last_error[variant_id] = "PyTorch is not installed."
self._video_model_loaded[variant_id] = True
return
if not SAM2_AVAILABLE:
self._video_last_error[variant_id] = "sam2 package is not installed."
self._video_model_loaded[variant_id] = True
return
config, checkpoint_path = self._checkpoint_config(variant_id)
if not os.path.isfile(checkpoint_path):
self._video_last_error[variant_id] = f"SAM2 checkpoint not found: {checkpoint_path}"
self._video_model_loaded[variant_id] = True
return
try:
device = self._best_device()
self._video_predictors[variant_id] = build_sam2_video_predictor(
config,
checkpoint_path,
device=device,
)
self._video_model_loaded[variant_id] = True
self._loaded_device[variant_id] = device
self._video_last_error[variant_id] = None
logger.info("SAM 2 video predictor %s loaded from %s on %s", variant_id, checkpoint_path, device)
except Exception as exc: # noqa: BLE001
self._video_last_error[variant_id] = str(exc)
self._video_model_loaded[variant_id] = True
logger.error("Failed to load SAM 2 video predictor %s: %s", variant_id, exc)
def _best_device(self) -> str:
if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available():
return "cuda"
return "cpu"
def _ensure_ready(self, model_id: str | None = None) -> bool:
"""Ensure the model is loaded; return whether it is usable."""
variant_id = self.normalize_model_id(model_id)
self._load_model(variant_id)
return SAM2_AVAILABLE and self._predictors.get(variant_id) is not None
def _ensure_video_ready(self, model_id: str | None = None) -> bool:
"""Ensure the video predictor is loaded; return whether it is usable."""
variant_id = self.normalize_model_id(model_id)
self._load_video_model(variant_id)
return SAM2_AVAILABLE and self._video_predictors.get(variant_id) is not None
def status(self, model_id: str | None = None) -> dict:
"""Return lightweight, real runtime status without forcing model load."""
variant_id = self.normalize_model_id(model_id)
variant = SAM2_VARIANTS[variant_id]
_, checkpoint_path = self._checkpoint_config(variant_id)
checkpoint_exists = os.path.isfile(checkpoint_path)
using_legacy_checkpoint = Path(checkpoint_path).name == variant.legacy_checkpoint_filename
predictor = self._predictors.get(variant_id)
device = self._loaded_device.get(variant_id) or self._best_device()
available = bool(TORCH_AVAILABLE and SAM2_AVAILABLE and checkpoint_exists)
if predictor is not None:
message = f"{variant.label} model loaded and ready."
elif available:
message = f"{variant.label} dependencies and checkpoint are present; model will load on first inference."
if using_legacy_checkpoint:
message += " Using legacy SAM 2 checkpoint fallback."
else:
missing = []
if not TORCH_AVAILABLE:
missing.append("PyTorch")
if not SAM2_AVAILABLE:
missing.append("sam2 package")
if not checkpoint_exists:
missing.append("checkpoint")
message = f"{variant.label} unavailable: missing {', '.join(missing)}."
last_error = self._last_error.get(variant_id)
if last_error and not predictor:
message = last_error
return {
"id": variant.id,
"label": variant.label,
"available": available,
"loaded": predictor is not None,
"device": device,
"supports": ["point", "box", "interactive", "auto", "propagate"],
"message": message,
"package_available": SAM2_AVAILABLE,
"checkpoint_exists": checkpoint_exists,
"checkpoint_path": checkpoint_path,
"python_ok": True,
"torch_ok": TORCH_AVAILABLE,
"cuda_required": False,
}
# -----------------------------------------------------------------------
# Public API
# -----------------------------------------------------------------------
def predict_points(
self,
model_id: str | None,
image: np.ndarray,
points: list[list[float]],
labels: list[int],
) -> tuple[list[list[list[float]]], list[float]]:
"""Run point-prompt segmentation.
Args:
image: HWC numpy array (uint8).
points: List of [x, y] normalized coordinates (0-1).
labels: 1 for foreground, 0 for background.
Returns:
Tuple of (polygons, scores).
"""
variant_id = self.normalize_model_id(model_id)
if not self._ensure_ready(variant_id):
logger.warning("SAM2 not ready; returning dummy masks.")
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
try:
predictor = self._predictors[variant_id]
h, w = image.shape[:2]
pts = np.array([[p[0] * w, p[1] * h] for p in points], dtype=np.float32)
lbls = np.array(labels, dtype=np.int32)
with torch.inference_mode(): # type: ignore[name-defined]
predictor.set_image(image)
masks, scores, _ = predictor.predict(
point_coords=pts,
point_labels=lbls,
multimask_output=False,
)
polygons = []
for m in masks:
poly = self._mask_to_polygon(m)
if poly:
polygons.append(poly)
return polygons, scores.tolist()
except Exception as exc: # noqa: BLE001
logger.error("SAM2 point prediction failed: %s", exc)
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
def predict_box(
self,
model_id: str | None,
image: np.ndarray,
box: list[float],
) -> tuple[list[list[list[float]]], list[float]]:
"""Run box-prompt segmentation.
Args:
image: HWC numpy array (uint8).
box: [x1, y1, x2, y2] normalized coordinates.
Returns:
Tuple of (polygons, scores).
"""
variant_id = self.normalize_model_id(model_id)
if not self._ensure_ready(variant_id):
logger.warning("SAM2 not ready; returning dummy masks.")
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
try:
predictor = self._predictors[variant_id]
h, w = image.shape[:2]
bbox = np.array(
[box[0] * w, box[1] * h, box[2] * w, box[3] * h],
dtype=np.float32,
)
with torch.inference_mode(): # type: ignore[name-defined]
predictor.set_image(image)
masks, scores, _ = predictor.predict(
box=bbox[None, :],
multimask_output=False,
)
polygons = []
for m in masks:
poly = self._mask_to_polygon(m)
if poly:
polygons.append(poly)
return polygons, scores.tolist()
except Exception as exc: # noqa: BLE001
logger.error("SAM2 box prediction failed: %s", exc)
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
def predict_interactive(
self,
model_id: str | None,
image: np.ndarray,
box: list[float] | None,
points: list[list[float]],
labels: list[int],
) -> tuple[list[list[list[float]]], list[float]]:
"""Run combined box and point prompt segmentation for refinement."""
variant_id = self.normalize_model_id(model_id)
if not self._ensure_ready(variant_id):
logger.warning("SAM2 not ready; returning dummy masks.")
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
try:
predictor = self._predictors[variant_id]
h, w = image.shape[:2]
bbox = None
if box:
bbox = np.array(
[box[0] * w, box[1] * h, box[2] * w, box[3] * h],
dtype=np.float32,
)
pts = None
lbls = None
if points:
pts = np.array([[p[0] * w, p[1] * h] for p in points], dtype=np.float32)
lbls = np.array(labels, dtype=np.int32)
with torch.inference_mode(): # type: ignore[name-defined]
predictor.set_image(image)
masks, scores, _ = predictor.predict(
point_coords=pts,
point_labels=lbls,
box=bbox,
multimask_output=False,
)
polygons = []
for m in masks:
poly = self._mask_to_polygon(m)
if poly:
polygons.append(poly)
return polygons, scores.tolist()
except Exception as exc: # noqa: BLE001
logger.error("SAM2 interactive prediction failed: %s", exc)
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
def predict_auto(self, model_id: str | None, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
"""Run automatic mask generation (grid of points).
Args:
image: HWC numpy array (uint8).
Returns:
Tuple of (polygons, scores).
"""
variant_id = self.normalize_model_id(model_id)
if not self._ensure_ready(variant_id):
logger.warning("SAM2 not ready; returning dummy masks.")
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
try:
predictor = self._predictors[variant_id]
with torch.inference_mode(): # type: ignore[name-defined]
predictor.set_image(image)
# Generate a uniform 16x16 grid of point prompts
h, w = image.shape[:2]
grid = np.mgrid[0:1:17j, 0:1:17j].reshape(2, -1).T
pts = grid * np.array([w, h])
lbls = np.ones(pts.shape[0], dtype=np.int32)
masks, scores, _ = predictor.predict(
point_coords=pts,
point_labels=lbls,
multimask_output=False,
)
polygons = []
for m in masks[:1]:
poly = self._mask_to_polygon(m)
if poly:
polygons.append(poly)
return polygons, scores[:1].tolist()
except Exception as exc: # noqa: BLE001
logger.error("SAM2 auto prediction failed: %s", exc)
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
def propagate_video(
self,
model_id: str | None,
frame_paths: list[str],
source_frame_index: int,
seed: dict,
direction: str = "forward",
max_frames: int | None = None,
) -> list[dict]:
"""Propagate one seed mask across a prepared frame directory with SAM 2 video."""
variant_id = self.normalize_model_id(model_id)
if not self._ensure_video_ready(variant_id):
raise RuntimeError(self._video_last_error.get(variant_id) or self.status(variant_id)["message"])
video_predictor = self._video_predictors[variant_id]
if not frame_paths:
return []
if source_frame_index < 0 or source_frame_index >= len(frame_paths):
raise ValueError("source_frame_index is outside the frame sequence.")
import cv2
source_image = cv2.imread(frame_paths[source_frame_index])
if source_image is None:
raise RuntimeError("Failed to decode source frame for SAM 2 propagation.")
height, width = source_image.shape[:2]
seed_mask = self._polygons_to_mask(seed.get("polygons") or [], width, height)
if not seed_mask.any():
bbox = seed.get("bbox")
if isinstance(bbox, list) and len(bbox) == 4:
seed_mask = self._bbox_to_mask(bbox, width, height)
if not seed_mask.any():
raise ValueError("SAM 2 propagation requires a non-empty seed polygon or bbox.")
inference_state = video_predictor.init_state(
video_path=os.path.dirname(frame_paths[0]),
offload_video_to_cpu=True,
offload_state_to_cpu=True,
)
video_predictor.add_new_mask(
inference_state,
frame_idx=source_frame_index,
obj_id=1,
mask=seed_mask,
)
results: dict[int, dict] = {}
def collect(reverse: bool) -> None:
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(
inference_state,
start_frame_idx=source_frame_index,
max_frame_num_to_track=max_frames,
reverse=reverse,
):
masks = out_mask_logits
if hasattr(masks, "detach"):
masks = masks.detach().cpu().numpy()
masks = np.asarray(masks)
if masks.ndim == 4:
masks = masks[:, 0]
polygons = []
scores = []
for mask in masks:
polygon = self._mask_to_polygon(mask > 0)
if polygon:
polygons.append(polygon)
scores.append(1.0)
results[int(out_frame_idx)] = {
"frame_index": int(out_frame_idx),
"polygons": polygons,
"scores": scores,
"object_ids": [int(obj_id) for obj_id in list(out_obj_ids)],
}
normalized_direction = direction.lower()
if normalized_direction in {"forward", "both"}:
collect(reverse=False)
if normalized_direction in {"backward", "both"}:
collect(reverse=True)
try:
video_predictor.reset_state(inference_state)
except Exception: # noqa: BLE001
pass
return [results[index] for index in sorted(results)]
# -----------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------
@staticmethod
def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
"""Convert a binary mask to a normalized polygon."""
import cv2
if mask.dtype != np.uint8:
mask = (mask > 0).astype(np.uint8)
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
h, w = mask.shape[:2]
largest = []
for cnt in contours:
if len(cnt) > len(largest):
largest = cnt
if len(largest) < 3:
return []
return [[float(pt[0][0]) / w, float(pt[0][1]) / h] for pt in largest]
@staticmethod
def _dummy_polygons(w: int, h: int) -> list[list[list[float]]]:
"""Return a dummy rectangle polygon for fallback mode."""
return [
[
[0.25, 0.25],
[0.75, 0.25],
[0.75, 0.75],
[0.25, 0.75],
]
]
@staticmethod
def _polygons_to_mask(polygons: list[list[list[float]]], width: int, height: int) -> np.ndarray:
import cv2
mask = np.zeros((height, width), dtype=np.uint8)
for polygon in polygons:
if len(polygon) < 3:
continue
pts = np.array(
[
[
int(round(min(max(float(x), 0.0), 1.0) * max(width - 1, 1))),
int(round(min(max(float(y), 0.0), 1.0) * max(height - 1, 1))),
]
for x, y in polygon
],
dtype=np.int32,
)
cv2.fillPoly(mask, [pts], 1)
return mask.astype(bool)
@staticmethod
def _bbox_to_mask(bbox: list[float], width: int, height: int) -> np.ndarray:
x, y, w, h = [min(max(float(value), 0.0), 1.0) for value in bbox]
left = int(round(x * max(width - 1, 1)))
top = int(round(y * max(height - 1, 1)))
right = int(round(min(x + w, 1.0) * max(width - 1, 1)))
bottom = int(round(min(y + h, 1.0) * max(height - 1, 1)))
mask = np.zeros((height, width), dtype=bool)
mask[top:max(bottom + 1, top + 1), left:max(right + 1, left + 1)] = True
return mask
# Singleton instance created at import time. Construction only initializes
# empty per-variant caches; models are loaded lazily on first use.
sam_engine = SAM2Engine()