Files
Pre_Seg_Server/backend/services/sam3_engine.py
admin 8a9247075e feat: 完善 AI 分割与工作区标注闭环
功能增加:

- 将视频导入和生成帧拆成两个明确动作,项目库生成帧时选择 FPS,工作区不再自动触发拆帧。

- 为工作区新增调整多边形工具,支持选中 mask、拖动顶点、边中点插点、双击边界按位置插点,并保留多 polygon 子区域编辑。

- 打通 AI 页 SAM2/SAM3 结果到工作区的联动,生成 mask 后自动选中,可在右侧分类树换标签,并推送到工作区继续编辑。

- 增强 Dashboard WebSocket 连接状态与心跳,使用真实 onopen/onclose/onerror 状态驱动前端显示。

- 完善 SAM3 external worker 适配,支持 box prompt、semantic 请求级阈值和 video tracker 路径。

bugfix:

- 修复 SAM2 文本语义误走自动分割的问题,改为提示使用点提示或切换 SAM3。

- 修复 SAM2 多候选重叠显示的问题,点提示和 auto fallback 默认只采用最高分候选。

- 修复 SAM2 反向点看起来无效的问题,带负点时启用背景过滤,过滤为空时移除旧候选。

- 修复 SAM3 单个 2D mask 结果无法转 polygon、低阈值 semantic 返回被默认阈值吞掉的问题。

- 修复 AI 页 mask 未选中导致分类树无法修改 SAM2 结果标签的问题。

测试和文档:

- 补充 CanvasArea、AISegmentation、ProjectLibrary、VideoWorkspace、Dashboard、websocket 和 SAM engine/API 测试。

- 新增 backend/tests/test_sam2_engine.py,覆盖 SAM2 单候选请求和 auto fallback 行为。

- 更新 README、AGENTS 和 doc 需求/设计/接口/测试矩阵,按当前实现冻结功能状态。
2026-05-01 21:50:17 +08:00

448 lines
17 KiB
Python

"""SAM 3 engine adapter and runtime status.
The official facebookresearch/sam3 package currently targets Python 3.12+
and CUDA-capable PyTorch. This adapter reports those requirements honestly and
only performs inference when the local runtime can actually import and execute
the package.
"""
from __future__ import annotations
import importlib.util
import json
import logging
import os
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import Any
import numpy as np
from PIL import Image
from config import settings
from services.sam2_engine import SAM2Engine
logger = logging.getLogger(__name__)

# PyTorch is optional at import time: any failure while importing it is
# recorded in TORCH_AVAILABLE (and logged) instead of propagating, and the
# name ``torch`` is bound to None so later references do not raise NameError.
try:
    import torch
except Exception as exc:  # noqa: BLE001
    TORCH_AVAILABLE = False
    torch = None  # type: ignore[assignment]
    logger.warning("PyTorch import failed (%s). SAM3 will be unavailable.", exc)
else:
    TORCH_AVAILABLE = True

# True when a ``sam3`` package can be located by the import system of this
# interpreter; the package itself is not imported here.
SAM3_PACKAGE_AVAILABLE = importlib.util.find_spec("sam3") is not None
class SAM3Engine:
"""Lazy SAM 3 image inference adapter."""
def __init__(self) -> None:
self._model: Any | None = None
self._processor: Any | None = None
self._model_loaded = False
self._last_error: str | None = None
self._external_status_cache: dict[str, Any] | None = None
self._external_status_checked_at = 0.0
def _python_ok(self) -> bool:
return sys.version_info >= (3, 12)
def _gpu_ok(self) -> bool:
return bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available())
def _checkpoint_path(self) -> str | None:
path = settings.sam3_checkpoint_path.strip()
return path if path else None
def _checkpoint_exists(self) -> bool:
path = self._checkpoint_path()
return bool(path and os.path.isfile(path))
def _can_load(self) -> bool:
return bool(
SAM3_PACKAGE_AVAILABLE
and TORCH_AVAILABLE
and self._python_ok()
and self._gpu_ok()
and self._checkpoint_exists()
)
def _worker_path(self) -> Path:
return Path(__file__).with_name("sam3_external_worker.py")
def _external_python_exists(self) -> bool:
return bool(settings.sam3_external_enabled and os.path.isfile(settings.sam3_external_python))
def _external_status(self, force: bool = False) -> dict[str, Any]:
now = time.monotonic()
if (
not force
and self._external_status_cache is not None
and now - self._external_status_checked_at < settings.sam3_status_cache_seconds
):
return self._external_status_cache
if not settings.sam3_external_enabled:
status = {
"available": False,
"package_available": False,
"python_ok": False,
"torch_ok": False,
"cuda_available": False,
"device": "unavailable",
"message": "SAM 3 external runtime is disabled.",
}
elif not self._external_python_exists():
status = {
"available": False,
"package_available": False,
"python_ok": False,
"torch_ok": False,
"cuda_available": False,
"device": "unavailable",
"message": f"SAM 3 external Python not found: {settings.sam3_external_python}",
}
else:
try:
env = os.environ.copy()
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
if self._checkpoint_path():
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
completed = subprocess.run(
[settings.sam3_external_python, str(self._worker_path()), "--status"],
capture_output=True,
text=True,
timeout=min(settings.sam3_timeout_seconds, 30),
check=False,
env=env,
)
if completed.returncode != 0:
detail = completed.stderr.strip() or completed.stdout.strip()
status = {
"available": False,
"package_available": False,
"python_ok": False,
"torch_ok": False,
"cuda_available": False,
"device": "unavailable",
"message": f"SAM 3 external status failed: {detail}",
}
else:
status = json.loads(completed.stdout)
except Exception as exc: # noqa: BLE001
status = {
"available": False,
"package_available": False,
"python_ok": False,
"torch_ok": False,
"cuda_available": False,
"device": "unavailable",
"message": f"SAM 3 external status failed: {exc}",
}
self._external_status_cache = status
self._external_status_checked_at = now
return status
def _load_model(self) -> None:
if self._model_loaded:
return
if not self._can_load():
self._last_error = self._status_message()
self._model_loaded = True
return
try:
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.model_builder import build_sam3_image_model
self._model = build_sam3_image_model(
checkpoint_path=self._checkpoint_path(),
load_from_HF=False,
)
self._processor = Sam3Processor(self._model)
self._model_loaded = True
self._last_error = None
logger.info("SAM 3 image model loaded with version setting %s", settings.sam3_model_version)
except Exception as exc: # noqa: BLE001
self._last_error = str(exc)
self._model_loaded = True
logger.error("Failed to load SAM 3 model: %s", exc)
def _ensure_ready(self) -> bool:
self._load_model()
return self._processor is not None
def _status_message(self) -> str:
missing = []
if not SAM3_PACKAGE_AVAILABLE:
missing.append("sam3 package")
if not self._python_ok():
missing.append("Python 3.12+ runtime")
if not TORCH_AVAILABLE:
missing.append("PyTorch")
if not self._gpu_ok():
missing.append("CUDA GPU")
if not self._checkpoint_exists():
missing.append(f"local checkpoint ({settings.sam3_checkpoint_path})")
if missing:
return f"SAM 3 unavailable: missing {', '.join(missing)}."
return "SAM 3 dependencies are present; model will load on first inference."
def status(self) -> dict:
external_status = self._external_status()
available = bool(self._can_load() or external_status.get("available"))
external_ready = bool(external_status.get("available"))
message = self._last_error or self._status_message()
if self._processor is not None:
message = "SAM 3 model loaded and ready."
elif external_ready:
message = "SAM 3 external runtime is ready; local checkpoint will load in the helper process on inference."
elif external_status.get("message") and not self._can_load():
message = str(external_status["message"])
return {
"id": "sam3",
"label": "SAM 3",
"available": available,
"loaded": self._processor is not None,
"device": "cuda" if self._gpu_ok() else str(external_status.get("device", "unavailable")),
"supports": ["semantic", "box", "video_track"],
"message": message,
"package_available": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("package_available")),
"checkpoint_exists": bool(self._checkpoint_exists() or external_status.get("checkpoint_access")),
"checkpoint_path": self._checkpoint_path() or f"official/HuggingFace ({settings.sam3_model_version})",
"python_ok": bool(self._python_ok() or external_status.get("python_ok")),
"torch_ok": bool(TORCH_AVAILABLE or external_status.get("torch_ok")),
"cuda_required": True,
"external_available": external_ready,
"external_python": settings.sam3_external_python if settings.sam3_external_enabled else None,
}
def _xyxy_to_cxcywh(self, box: list[float]) -> list[float]:
if len(box) != 4:
raise ValueError("SAM 3 box prompt requires [x1, y1, x2, y2].")
x1, y1, x2, y2 = [min(max(float(value), 0.0), 1.0) for value in box]
left, right = sorted([x1, x2])
top, bottom = sorted([y1, y2])
width = max(right - left, 1e-6)
height = max(bottom - top, 1e-6)
return [left + width / 2, top + height / 2, width, height]
def _prediction_to_polygons(self, output: Any) -> tuple[list[list[list[float]]], list[float]]:
masks = output.get("masks", [])
scores = output.get("scores", [])
polygons = []
for mask in masks:
if hasattr(mask, "detach"):
mask = mask.detach().cpu().numpy()
if mask.ndim == 3:
mask = mask[0]
poly = SAM2Engine._mask_to_polygon(mask)
if poly:
polygons.append(poly)
if hasattr(scores, "detach"):
scores = scores.detach().cpu().tolist()
elif hasattr(scores, "tolist"):
scores = scores.tolist()
return polygons, list(scores)
def _predict_external(
self,
image: np.ndarray,
prompt_type: str,
*,
text: str = "",
box: list[float] | None = None,
confidence_threshold: float | None = None,
) -> tuple[list[list[list[float]]], list[float]]:
status = self._external_status(force=True)
if not status.get("available"):
raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")
with tempfile.TemporaryDirectory(prefix="sam3_") as tmpdir:
tmp_path = Path(tmpdir)
image_path = tmp_path / "image.png"
request_path = tmp_path / "request.json"
Image.fromarray(image).save(image_path)
request_path.write_text(
json.dumps(
{
"image_path": str(image_path),
"prompt_type": prompt_type,
"text": text.strip(),
"box": box,
"model_version": settings.sam3_model_version,
"checkpoint_path": self._checkpoint_path(),
"confidence_threshold": (
confidence_threshold
if confidence_threshold is not None
else settings.sam3_confidence_threshold
),
},
ensure_ascii=False,
),
encoding="utf-8",
)
env = os.environ.copy()
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
if self._checkpoint_path():
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
completed = subprocess.run(
[settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)],
capture_output=True,
text=True,
timeout=settings.sam3_timeout_seconds,
check=False,
env=env,
)
if completed.returncode != 0:
detail = completed.stderr.strip() or completed.stdout.strip()
try:
parsed = json.loads(detail)
detail = parsed.get("error", detail)
except Exception: # noqa: BLE001
pass
raise RuntimeError(f"SAM 3 external inference failed: {detail}")
payload = json.loads(completed.stdout)
if payload.get("error"):
raise RuntimeError(str(payload["error"]))
return payload.get("polygons", []), payload.get("scores", [])
def _predict_semantic_external(
self,
image: np.ndarray,
text: str,
confidence_threshold: float | None = None,
) -> tuple[list[list[list[float]]], list[float]]:
return self._predict_external(
image,
"semantic",
text=text,
confidence_threshold=confidence_threshold,
)
def _predict_box_external(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]:
return self._predict_external(image, "box", box=box)
def _propagate_video_external(
self,
frame_paths: list[str],
source_frame_index: int,
seed: dict[str, Any],
direction: str,
max_frames: int | None,
) -> list[dict[str, Any]]:
status = self._external_status(force=True)
if not status.get("available"):
raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")
if not frame_paths:
return []
with tempfile.TemporaryDirectory(prefix="sam3_video_") as tmpdir:
request_path = Path(tmpdir) / "request.json"
request_path.write_text(
json.dumps(
{
"prompt_type": "video_track",
"frame_dir": str(Path(frame_paths[0]).parent),
"source_frame_index": source_frame_index,
"seed": seed,
"direction": direction,
"max_frames": max_frames,
"model_version": settings.sam3_model_version,
"checkpoint_path": self._checkpoint_path(),
"confidence_threshold": settings.sam3_confidence_threshold,
},
ensure_ascii=False,
),
encoding="utf-8",
)
env = os.environ.copy()
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
if self._checkpoint_path():
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
completed = subprocess.run(
[settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)],
capture_output=True,
text=True,
timeout=settings.sam3_timeout_seconds,
check=False,
env=env,
)
if completed.returncode != 0:
detail = completed.stderr.strip() or completed.stdout.strip()
try:
parsed = json.loads(detail)
detail = parsed.get("error", detail)
except Exception: # noqa: BLE001
pass
raise RuntimeError(f"SAM 3 external video tracking failed: {detail}")
payload = json.loads(completed.stdout)
if payload.get("error"):
raise RuntimeError(str(payload["error"]))
return payload.get("frames", [])
def predict_semantic(
self,
image: np.ndarray,
text: str,
confidence_threshold: float | None = None,
) -> tuple[list[list[list[float]]], list[float]]:
if not text.strip():
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
if not self._can_load() and self._external_status().get("available"):
return self._predict_semantic_external(image, text, confidence_threshold=confidence_threshold)
if not self._ensure_ready():
raise RuntimeError(self.status()["message"])
pil_image = Image.fromarray(image)
with torch.inference_mode(): # type: ignore[union-attr]
state = self._processor.set_image(pil_image)
output = self._processor.set_text_prompt(state=state, prompt=text.strip())
return self._prediction_to_polygons(output)
def predict_points(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for point prompts.")
def predict_box(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]:
if not self._can_load() and self._external_status().get("available"):
return self._predict_box_external(image, box)
if not self._ensure_ready():
raise RuntimeError(self.status()["message"])
pil_image = Image.fromarray(image)
with torch.inference_mode(): # type: ignore[union-attr]
state = self._processor.set_image(pil_image)
output = self._processor.add_geometric_prompt(
state=state,
box=self._xyxy_to_cxcywh(box),
label=True,
)
return self._prediction_to_polygons(output)
def propagate_video(
self,
frame_paths: list[str],
source_frame_index: int,
seed: dict[str, Any],
direction: str = "forward",
max_frames: int | None = None,
) -> list[dict[str, Any]]:
return self._propagate_video_external(frame_paths, source_frame_index, seed, direction, max_frames)
# Shared module-level engine instance. Constructing it only initialises
# attributes; the model is loaded lazily on the first in-process inference.
sam3_engine = SAM3Engine()