Files
Pre_Seg_Server/backend/services/sam2_engine.py
admin f88f9bdbb9 支持中空mask编辑和传播保洞
- 前端按 polygonRingCounts 维护外圈/内洞 ring 分组,中空 mask 在调整多边形时显示内洞顶点和插点手柄。

- 保存与回显标注时将中空结构拆分为 mask_data.polygons 和 mask_data.holes,导入/普通 mask 共享同一编辑体验。

- 自动传播 seed 携带 holes,SAM 2 seed 栅格化时扣除内洞,避免中空 mask 以实心形式传播。

- 传播结果轮廓提取改为保留层级内洞,并在同步传播和 Celery 传播落库时写回 holes 与 hasHoles。

- 传播 seed 签名纳入 holes,并加固保存结果时 holes 与原始 polygon 索引对齐。

- 补充前端保存/回显、Canvas 内洞编辑和后端 SAM 2 hole 处理测试。

- 更新 AGENTS、接口契约、需求冻结、设计冻结和测试计划文档,移除中空结构未实现的旧描述。
2026-05-03 18:28:46 +08:00

691 lines
27 KiB
Python

"""SAM 2 engine wrapper with lazy loading and explicit runtime status."""
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import numpy as np
from config import settings
logger = logging.getLogger(__name__)
DEFAULT_SAM2_MODEL_ID = "sam2.1_hiera_tiny"
@dataclass(frozen=True)
class SAM2Variant:
    """Immutable description of one selectable SAM 2.1 model size.

    Every variant carries two config/checkpoint pairs: the preferred
    SAM 2.1 pair and a legacy SAM 2 pair that is tried as a fallback when
    the preferred checkpoint file is not present on disk.
    """

    # Canonical identifier of the variant (e.g. "sam2.1_hiera_tiny").
    id: str
    # Human-readable name used in status messages.
    label: str
    # Compact display name (e.g. "tiny").
    short_label: str
    # Config path paired with the SAM 2.1 checkpoint.
    config: str
    # Config path paired with the legacy SAM 2 checkpoint.
    legacy_config: str
    # File name of the SAM 2.1 checkpoint inside the models directory.
    checkpoint_filename: str
    # File name of the legacy SAM 2 checkpoint inside the models directory.
    legacy_checkpoint_filename: str
# Registry of every supported variant, keyed by the variant's own id so the
# key and the record can never drift apart. Insertion order is tiny -> large.
SAM2_VARIANTS: dict[str, SAM2Variant] = {
    variant.id: variant
    for variant in (
        SAM2Variant(
            id="sam2.1_hiera_tiny",
            label="SAM 2.1 Tiny",
            short_label="tiny",
            config="configs/sam2.1/sam2.1_hiera_t.yaml",
            legacy_config="configs/sam2/sam2_hiera_t.yaml",
            checkpoint_filename="sam2.1_hiera_tiny.pt",
            legacy_checkpoint_filename="sam2_hiera_tiny.pt",
        ),
        SAM2Variant(
            id="sam2.1_hiera_small",
            label="SAM 2.1 Small",
            short_label="small",
            config="configs/sam2.1/sam2.1_hiera_s.yaml",
            legacy_config="configs/sam2/sam2_hiera_s.yaml",
            checkpoint_filename="sam2.1_hiera_small.pt",
            legacy_checkpoint_filename="sam2_hiera_small.pt",
        ),
        SAM2Variant(
            id="sam2.1_hiera_base_plus",
            label="SAM 2.1 Base+",
            short_label="base+",
            config="configs/sam2.1/sam2.1_hiera_b+.yaml",
            legacy_config="configs/sam2/sam2_hiera_b+.yaml",
            checkpoint_filename="sam2.1_hiera_base_plus.pt",
            legacy_checkpoint_filename="sam2_hiera_base_plus.pt",
        ),
        SAM2Variant(
            id="sam2.1_hiera_large",
            label="SAM 2.1 Large",
            short_label="large",
            config="configs/sam2.1/sam2.1_hiera_l.yaml",
            legacy_config="configs/sam2/sam2_hiera_l.yaml",
            checkpoint_filename="sam2.1_hiera_large.pt",
            legacy_checkpoint_filename="sam2_hiera_large.pt",
        ),
    )
}

# Shorthand model names; each one resolves to the default variant.
SAM2_MODEL_ALIASES = dict.fromkeys(
    ("sam2", "sam2.1", "sam2_tiny"),
    DEFAULT_SAM2_MODEL_ID,
)
# ---------------------------------------------------------------------------
# Attempt to import PyTorch and SAM 2. A failed import is recorded in the
# matching *_AVAILABLE flag and the imported names are bound to None, so the
# module always exposes them and callers can gate on the flags.
# ---------------------------------------------------------------------------
try:
    import torch
except Exception as exc:  # noqa: BLE001
    TORCH_AVAILABLE = False
    torch = None  # type: ignore[assignment]
    logger.warning("PyTorch import failed (%s). SAM2 will be unavailable.", exc)
else:
    TORCH_AVAILABLE = True

try:
    from sam2.build_sam import build_sam2
    from sam2.build_sam import build_sam2_video_predictor
    from sam2.sam2_image_predictor import SAM2ImagePredictor
except Exception as exc:  # noqa: BLE001
    SAM2_AVAILABLE = False
    # Mirror the torch fallback: keep the names defined even without sam2.
    build_sam2 = None  # type: ignore[assignment]
    build_sam2_video_predictor = None  # type: ignore[assignment]
    SAM2ImagePredictor = None  # type: ignore[assignment,misc]
    logger.warning("SAM2 import failed (%s). Using stub engine.", exc)
else:
    SAM2_AVAILABLE = True
    logger.info("SAM2 library imported successfully.")
class SAM2Engine:
"""Lazy-loaded SAM 2 inference engine."""
def __init__(self) -> None:
self._predictors: dict[str, Optional[SAM2ImagePredictor]] = {}
self._video_predictors: dict[str, object | None] = {}
self._model_loaded: dict[str, bool] = {}
self._video_model_loaded: dict[str, bool] = {}
self._loaded_device: dict[str, str] = {}
self._last_error: dict[str, str | None] = {}
self._video_last_error: dict[str, str | None] = {}
# -----------------------------------------------------------------------
# Internal helpers
# -----------------------------------------------------------------------
def variant_ids(self) -> list[str]:
    """Return the ids of all registered variants, in registry order."""
    return list(SAM2_VARIANTS)
def normalize_model_id(self, model_id: str | None) -> str:
    """Resolve a requested model name to a canonical variant id.

    A falsy ``model_id`` falls back to ``settings.sam_default_model`` and then
    to ``DEFAULT_SAM2_MODEL_ID``. The name is lower-cased and passed through
    the alias table before being checked against the registry.

    Raises:
        ValueError: If the resolved name is not a registered variant.
    """
    requested = model_id or settings.sam_default_model or DEFAULT_SAM2_MODEL_ID
    lowered = requested.lower()
    canonical = SAM2_MODEL_ALIASES.get(lowered, lowered)
    if canonical in SAM2_VARIANTS:
        return canonical
    raise ValueError(f"Unsupported SAM2 model: {model_id}")
def is_sam2_model(self, model_id: str | None) -> bool:
    """Tell whether ``model_id`` resolves to a registered SAM 2 variant."""
    try:
        self.normalize_model_id(model_id)
    except ValueError:
        return False
    return True
def _models_dir(self) -> Path:
    """Return the directory that is searched for variant checkpoints.

    This is the parent directory of ``settings.sam_model_path``. When that
    setting is empty or blank the directory defaults to ``models``.

    The previous ``parent if parent else Path("models")`` test could never
    take the fallback because ``Path`` objects are always truthy. A bare
    file name still resolves to the current directory, as it did before.
    """
    raw_path = settings.sam_model_path
    if not str(raw_path or "").strip():
        return Path("models")
    return Path(raw_path).parent
def _variant(self, model_id: str | None) -> SAM2Variant:
    """Return the registry record for ``model_id`` after normalization."""
    variant_id = self.normalize_model_id(model_id)
    return SAM2_VARIANTS[variant_id]
def _checkpoint_config(self, model_id: str | None) -> tuple[str, str]:
    """Pick the (config, checkpoint path) pair to use for a variant.

    Candidates are tried in this order and the first one whose checkpoint
    file exists wins:

    1. the explicitly configured checkpoint (default variant only, and only
       when that file exists),
    2. the SAM 2.1 checkpoint in the models directory,
    3. the legacy SAM 2 checkpoint in the models directory.

    When none exists, the first candidate is returned so callers can report
    the path that was expected.
    """
    key = self.normalize_model_id(model_id)
    spec = SAM2_VARIANTS[key]
    search_dir = self._models_dir()
    user_checkpoint = Path(settings.sam_model_path)

    options: list[tuple[str, str]] = [
        (spec.config, str(search_dir / spec.checkpoint_filename)),
        (spec.legacy_config, str(search_dir / spec.legacy_checkpoint_filename)),
    ]
    if key == DEFAULT_SAM2_MODEL_ID and user_checkpoint.is_file():
        # The configured file takes precedence over the conventional names.
        options.insert(0, (settings.sam_model_config, str(user_checkpoint)))

    present = (option for option in options if os.path.isfile(option[1]))
    return next(present, options[0])
def _load_model(self, model_id: str | None = None) -> None:
    """Build the image predictor for a variant the first time it is needed.

    Every outcome, success or failure, marks the variant as loaded so the
    load is attempted only once. A failure reason is stored in
    ``_last_error`` and no predictor is registered.
    """
    key = self.normalize_model_id(model_id)
    if self._model_loaded.get(key):
        return

    if not TORCH_AVAILABLE:
        self._last_error[key] = "PyTorch is not installed."
        logger.warning("PyTorch not available; skipping SAM2 model load.")
    elif not SAM2_AVAILABLE:
        self._last_error[key] = "sam2 package is not installed."
        logger.warning("SAM2 not available; skipping model load.")
    else:
        config, checkpoint_path = self._checkpoint_config(key)
        if not os.path.isfile(checkpoint_path):
            self._last_error[key] = f"SAM2 checkpoint not found: {checkpoint_path}"
            logger.error("SAM checkpoint not found at %s", checkpoint_path)
        else:
            try:
                device = self._best_device()
                predictor = SAM2ImagePredictor(
                    build_sam2(config, checkpoint_path, device=device)
                )
            except Exception as exc:  # noqa: BLE001
                self._last_error[key] = str(exc)
                logger.error("Failed to load SAM 2 model %s: %s", key, exc)
            else:
                self._predictors[key] = predictor
                self._loaded_device[key] = device
                self._last_error[key] = None
                logger.info("SAM 2 model %s loaded from %s on %s", key, checkpoint_path, device)

    # Reached on success and on every handled failure: never retry the load.
    self._model_loaded[key] = True
def _load_video_model(self, model_id: str | None = None) -> None:
    """Build the video predictor for a variant on first propagation use.

    Every outcome marks the variant as loaded so the load is attempted only
    once. A failure reason is stored in ``_video_last_error`` and no video
    predictor is registered.
    """
    key = self.normalize_model_id(model_id)
    if self._video_model_loaded.get(key):
        return

    if not TORCH_AVAILABLE:
        self._video_last_error[key] = "PyTorch is not installed."
    elif not SAM2_AVAILABLE:
        self._video_last_error[key] = "sam2 package is not installed."
    else:
        config, checkpoint_path = self._checkpoint_config(key)
        if not os.path.isfile(checkpoint_path):
            self._video_last_error[key] = f"SAM2 checkpoint not found: {checkpoint_path}"
        else:
            try:
                device = self._best_device()
                video_predictor = build_sam2_video_predictor(
                    config,
                    checkpoint_path,
                    device=device,
                )
            except Exception as exc:  # noqa: BLE001
                self._video_last_error[key] = str(exc)
                self._video_model_loaded[key] = True
                logger.error("Failed to load SAM 2 video predictor %s: %s", key, exc)
                return
            self._video_predictors[key] = video_predictor
            self._video_model_loaded[key] = True
            self._loaded_device[key] = device
            self._video_last_error[key] = None
            logger.info("SAM 2 video predictor %s loaded from %s on %s", key, checkpoint_path, device)
            return

    # Dependency or checkpoint missing: remember the attempt, do not retry.
    self._video_model_loaded[key] = True
def _best_device(self) -> str:
    """Return ``"cuda"`` when PyTorch reports a usable GPU, else ``"cpu"``."""
    cuda_usable = TORCH_AVAILABLE and torch is not None and torch.cuda.is_available()
    return "cuda" if cuda_usable else "cpu"
def _ensure_ready(self, model_id: str | None = None) -> bool:
    """Trigger the lazy image-model load and report whether a predictor exists."""
    key = self.normalize_model_id(model_id)
    self._load_model(key)
    predictor = self._predictors.get(key)
    return SAM2_AVAILABLE and predictor is not None
def _ensure_video_ready(self, model_id: str | None = None) -> bool:
    """Trigger the lazy video-model load and report whether a predictor exists."""
    key = self.normalize_model_id(model_id)
    self._load_video_model(key)
    video_predictor = self._video_predictors.get(key)
    return SAM2_AVAILABLE and video_predictor is not None
def status(self, model_id: str | None = None) -> dict:
    """Describe a variant's runtime state without triggering a model load.

    The report covers dependency and checkpoint availability, whether the
    image predictor is already loaded, the device in use (or the one that
    would be chosen), and a human-readable message. A stored load error
    replaces the message while no predictor is loaded.
    """
    key = self.normalize_model_id(model_id)
    spec = SAM2_VARIANTS[key]
    checkpoint_path = self._checkpoint_config(key)[1]
    has_checkpoint = os.path.isfile(checkpoint_path)
    predictor = self._predictors.get(key)
    is_loaded = predictor is not None
    usable = bool(TORCH_AVAILABLE and SAM2_AVAILABLE and has_checkpoint)

    if is_loaded:
        message = f"{spec.label} model loaded and ready."
    elif usable:
        message = f"{spec.label} dependencies and checkpoint are present; model will load on first inference."
        if Path(checkpoint_path).name == spec.legacy_checkpoint_filename:
            message += " Using legacy SAM 2 checkpoint fallback."
    else:
        requirements = (
            ("PyTorch", TORCH_AVAILABLE),
            ("sam2 package", SAM2_AVAILABLE),
            ("checkpoint", has_checkpoint),
        )
        missing = [name for name, present in requirements if not present]
        message = f"{spec.label} unavailable: missing {', '.join(missing)}."

    # A recorded load failure is more specific than the generic message.
    last_error = self._last_error.get(key)
    if last_error and not predictor:
        message = last_error

    return {
        "id": spec.id,
        "label": spec.label,
        "available": usable,
        "loaded": is_loaded,
        "device": self._loaded_device.get(key) or self._best_device(),
        "supports": ["point", "box", "interactive", "auto", "propagate"],
        "message": message,
        "package_available": SAM2_AVAILABLE,
        "checkpoint_exists": has_checkpoint,
        "checkpoint_path": checkpoint_path,
        "python_ok": True,
        "torch_ok": TORCH_AVAILABLE,
        "cuda_required": False,
    }
# -----------------------------------------------------------------------
# Public API
# -----------------------------------------------------------------------
def predict_points(
    self,
    model_id: str | None,
    image: np.ndarray,
    points: list[list[float]],
    labels: list[int],
) -> tuple[list[list[list[float]]], list[float]]:
    """Run point-prompt segmentation.

    Args:
        model_id: Variant id or alias; a falsy value selects the default.
        image: HWC numpy array (uint8).
        points: List of [x, y] normalized coordinates (0-1).
        labels: 1 for foreground, 0 for background.

    Returns:
        Tuple of (polygons, scores). The two lists are index-aligned: a mask
        that yields no usable contour is dropped together with its score.
        When the model is not ready or inference fails, a placeholder
        rectangle with score 0.5 is returned instead.

    Raises:
        ValueError: If ``model_id`` is not a supported variant.
    """
    variant_id = self.normalize_model_id(model_id)
    if not self._ensure_ready(variant_id):
        logger.warning("SAM2 not ready; returning dummy masks.")
        return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
    try:
        predictor = self._predictors[variant_id]
        h, w = image.shape[:2]
        # Prompts arrive normalized; the predictor expects pixel coordinates.
        pts = np.array([[p[0] * w, p[1] * h] for p in points], dtype=np.float32)
        lbls = np.array(labels, dtype=np.int32)
        with torch.inference_mode():  # type: ignore[name-defined]
            predictor.set_image(image)
            masks, scores, _ = predictor.predict(
                point_coords=pts,
                point_labels=lbls,
                multimask_output=False,
            )
        polygons: list[list[list[float]]] = []
        kept_scores: list[float] = []
        for mask, score in zip(masks, np.asarray(scores).reshape(-1).tolist()):
            polygon = self._mask_to_polygon(mask)
            if polygon:
                # Keep the score only with its polygon so both stay aligned.
                polygons.append(polygon)
                kept_scores.append(score)
        return polygons, kept_scores
    except Exception as exc:  # noqa: BLE001
        # Best-effort API: log with traceback and fall back to the placeholder.
        logger.exception("SAM2 point prediction failed: %s", exc)
        return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
def predict_box(
    self,
    model_id: str | None,
    image: np.ndarray,
    box: list[float],
) -> tuple[list[list[list[float]]], list[float]]:
    """Run box-prompt segmentation.

    Args:
        model_id: Variant id or alias; a falsy value selects the default.
        image: HWC numpy array (uint8).
        box: [x1, y1, x2, y2] normalized coordinates.

    Returns:
        Tuple of (polygons, scores). The two lists are index-aligned: a mask
        that yields no usable contour is dropped together with its score.
        When the model is not ready or inference fails, a placeholder
        rectangle with score 0.5 is returned instead.

    Raises:
        ValueError: If ``model_id`` is not a supported variant.
    """
    variant_id = self.normalize_model_id(model_id)
    if not self._ensure_ready(variant_id):
        logger.warning("SAM2 not ready; returning dummy masks.")
        return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
    try:
        predictor = self._predictors[variant_id]
        h, w = image.shape[:2]
        # The box arrives normalized; the predictor expects pixel coordinates.
        bbox = np.array(
            [box[0] * w, box[1] * h, box[2] * w, box[3] * h],
            dtype=np.float32,
        )
        with torch.inference_mode():  # type: ignore[name-defined]
            predictor.set_image(image)
            masks, scores, _ = predictor.predict(
                box=bbox[None, :],
                multimask_output=False,
            )
        polygons: list[list[list[float]]] = []
        kept_scores: list[float] = []
        for mask, score in zip(masks, np.asarray(scores).reshape(-1).tolist()):
            polygon = self._mask_to_polygon(mask)
            if polygon:
                # Keep the score only with its polygon so both stay aligned.
                polygons.append(polygon)
                kept_scores.append(score)
        return polygons, kept_scores
    except Exception as exc:  # noqa: BLE001
        # Best-effort API: log with traceback and fall back to the placeholder.
        logger.exception("SAM2 box prediction failed: %s", exc)
        return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
def predict_interactive(
    self,
    model_id: str | None,
    image: np.ndarray,
    box: list[float] | None,
    points: list[list[float]],
    labels: list[int],
) -> tuple[list[list[list[float]]], list[float]]:
    """Run combined box and point prompt segmentation for refinement.

    Args:
        model_id: Variant id or alias; a falsy value selects the default.
        image: HWC numpy array (uint8).
        box: Optional [x1, y1, x2, y2] normalized coordinates; a falsy value
            means no box prompt.
        points: List of [x, y] normalized coordinates; an empty list means
            no point prompt.
        labels: One label per point, 1 for foreground and 0 for background.

    Returns:
        Tuple of (polygons, scores). The two lists are index-aligned: a mask
        that yields no usable contour is dropped together with its score.
        When the model is not ready or inference fails, a placeholder
        rectangle with score 0.5 is returned instead.

    Raises:
        ValueError: If ``model_id`` is not a supported variant.
    """
    variant_id = self.normalize_model_id(model_id)
    if not self._ensure_ready(variant_id):
        logger.warning("SAM2 not ready; returning dummy masks.")
        return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
    try:
        predictor = self._predictors[variant_id]
        h, w = image.shape[:2]
        # Prompts arrive normalized; the predictor expects pixel coordinates.
        bbox = None
        if box:
            bbox = np.array(
                [box[0] * w, box[1] * h, box[2] * w, box[3] * h],
                dtype=np.float32,
            )
        pts = None
        lbls = None
        if points:
            pts = np.array([[p[0] * w, p[1] * h] for p in points], dtype=np.float32)
            lbls = np.array(labels, dtype=np.int32)
        with torch.inference_mode():  # type: ignore[name-defined]
            predictor.set_image(image)
            masks, scores, _ = predictor.predict(
                point_coords=pts,
                point_labels=lbls,
                box=bbox,
                multimask_output=False,
            )
        polygons: list[list[list[float]]] = []
        kept_scores: list[float] = []
        for mask, score in zip(masks, np.asarray(scores).reshape(-1).tolist()):
            polygon = self._mask_to_polygon(mask)
            if polygon:
                # Keep the score only with its polygon so both stay aligned.
                polygons.append(polygon)
                kept_scores.append(score)
        return polygons, kept_scores
    except Exception as exc:  # noqa: BLE001
        # Best-effort API: log with traceback and fall back to the placeholder.
        logger.exception("SAM2 interactive prediction failed: %s", exc)
        return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
def predict_auto(self, model_id: str | None, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
    """Run automatic segmentation from a uniform grid of foreground points.

    All grid points are submitted as one foreground prompt and only the
    first returned mask is used.

    Args:
        model_id: Variant id or alias; a falsy value selects the default.
        image: HWC numpy array (uint8).

    Returns:
        Tuple of (polygons, scores). The two lists are index-aligned: if the
        mask yields no usable contour, both lists are empty. When the model
        is not ready or inference fails, a placeholder rectangle with score
        0.5 is returned instead.

    Raises:
        ValueError: If ``model_id`` is not a supported variant.
    """
    variant_id = self.normalize_model_id(model_id)
    if not self._ensure_ready(variant_id):
        logger.warning("SAM2 not ready; returning dummy masks.")
        return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
    try:
        predictor = self._predictors[variant_id]
        with torch.inference_mode():  # type: ignore[name-defined]
            predictor.set_image(image)
            # 17 samples per axis over [0, 1] inclusive -> 17x17 = 289 prompts
            # (16 intervals per axis), scaled to pixel coordinates.
            h, w = image.shape[:2]
            grid = np.mgrid[0:1:17j, 0:1:17j].reshape(2, -1).T
            pts = grid * np.array([w, h])
            lbls = np.ones(pts.shape[0], dtype=np.int32)
            masks, scores, _ = predictor.predict(
                point_coords=pts,
                point_labels=lbls,
                multimask_output=False,
            )
        polygons: list[list[list[float]]] = []
        kept_scores: list[float] = []
        first_scores = np.asarray(scores).reshape(-1)[:1].tolist()
        for mask, score in zip(masks[:1], first_scores):
            polygon = self._mask_to_polygon(mask)
            if polygon:
                # Keep the score only with its polygon so both stay aligned.
                polygons.append(polygon)
                kept_scores.append(score)
        return polygons, kept_scores
    except Exception as exc:  # noqa: BLE001
        # Best-effort API: log with traceback and fall back to the placeholder.
        logger.exception("SAM2 auto prediction failed: %s", exc)
        return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
def propagate_video(
    self,
    model_id: str | None,
    frame_paths: list[str],
    source_frame_index: int,
    seed: dict,
    direction: str = "forward",
    max_frames: int | None = None,
) -> list[dict]:
    """Propagate one seed mask across a prepared frame directory with SAM 2 video.

    Args:
        model_id: Variant id or alias; a falsy value selects the default.
        frame_paths: Frame image paths; the directory of the first path is
            handed to the video predictor as the video source.
        source_frame_index: Index into ``frame_paths`` of the seed frame.
        seed: Dict read for ``"polygons"`` (normalized outer rings),
            ``"holes"`` (hole rings per polygon) and, as a fallback when the
            polygons rasterize to nothing, ``"bbox"`` (a 4-item list).
        direction: ``"forward"``, ``"backward"`` or ``"both"``
            (case-insensitive). Any other value propagates nothing.
        max_frames: Passed through as ``max_frame_num_to_track``.

    Returns:
        One dict per produced frame, sorted by frame index, with the keys
        ``frame_index``, ``polygons``, ``holes``, ``scores`` and
        ``object_ids``. ``polygons``, ``holes`` and ``scores`` are
        index-aligned. An empty ``frame_paths`` yields ``[]``.

    Raises:
        RuntimeError: If the video predictor is unavailable or the source
            frame cannot be decoded.
        ValueError: If the model id is unsupported, the source index is out
            of range, or the seed rasterizes to an empty mask.
    """
    variant_id = self.normalize_model_id(model_id)
    if not self._ensure_video_ready(variant_id):
        raise RuntimeError(self._video_last_error.get(variant_id) or self.status(variant_id)["message"])
    video_predictor = self._video_predictors[variant_id]
    if not frame_paths:
        return []
    if source_frame_index < 0 or source_frame_index >= len(frame_paths):
        raise ValueError("source_frame_index is outside the frame sequence.")

    import cv2

    # The source frame is decoded only to learn the raster size of the seed mask.
    source_image = cv2.imread(frame_paths[source_frame_index])
    if source_image is None:
        raise RuntimeError("Failed to decode source frame for SAM 2 propagation.")
    height, width = source_image.shape[:2]

    seed_mask = self._polygons_to_mask(seed.get("polygons") or [], width, height, seed.get("holes") or [])
    if not seed_mask.any():
        bbox = seed.get("bbox")
        if isinstance(bbox, list) and len(bbox) == 4:
            seed_mask = self._bbox_to_mask(bbox, width, height)
    if not seed_mask.any():
        raise ValueError("SAM 2 propagation requires a non-empty seed polygon or bbox.")

    inference_state = video_predictor.init_state(
        video_path=os.path.dirname(frame_paths[0]),
        offload_video_to_cpu=True,
        offload_state_to_cpu=True,
    )
    results: dict[int, dict] = {}

    def collect(reverse: bool) -> None:
        """Run one propagation pass and store its per-frame output in ``results``."""
        for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(
            inference_state,
            start_frame_idx=source_frame_index,
            max_frame_num_to_track=max_frames,
            reverse=reverse,
        ):
            masks = out_mask_logits
            if hasattr(masks, "detach"):
                masks = masks.detach().cpu().numpy()
            masks = np.asarray(masks)
            if masks.ndim == 4:
                # Drop the singleton channel axis so each item is one 2-D mask.
                masks = masks[:, 0]
            polygons = []
            holes = []
            scores = []
            for mask in masks:
                # Positive logits are foreground.
                mask_polygons, mask_holes = self._mask_to_polygon_data(mask > 0)
                for polygon_index, polygon in enumerate(mask_polygons):
                    polygons.append(polygon)
                    holes.append(mask_holes[polygon_index] if polygon_index < len(mask_holes) else [])
                    scores.append(1.0)
            # A later pass over the same frame replaces the earlier entry.
            results[int(out_frame_idx)] = {
                "frame_index": int(out_frame_idx),
                "polygons": polygons,
                "holes": holes,
                "scores": scores,
                "object_ids": [int(obj_id) for obj_id in list(out_obj_ids)],
            }

    try:
        video_predictor.add_new_mask(
            inference_state,
            frame_idx=source_frame_index,
            obj_id=1,
            mask=seed_mask,
        )
        normalized_direction = direction.lower()
        if normalized_direction in {"forward", "both"}:
            collect(reverse=False)
        if normalized_direction in {"backward", "both"}:
            collect(reverse=True)
    finally:
        # Release the tracking state even when propagation raises; a failed
        # reset is logged but never masks the original outcome.
        try:
            video_predictor.reset_state(inference_state)
        except Exception as exc:  # noqa: BLE001
            logger.warning("Failed to reset SAM 2 inference state: %s", exc)
    return [results[index] for index in sorted(results)]
# -----------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------
@staticmethod
def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
    """Return the largest outer contour of a binary mask as a normalized polygon.

    Holes are ignored. An empty list is returned when no contour is found.
    """
    outer_rings, _ = SAM2Engine._mask_to_polygon_data(mask)
    if not outer_rings:
        return []
    return outer_rings[0]
@staticmethod
def _mask_to_polygon_data(mask: np.ndarray) -> tuple[list[list[list[float]]], list[list[list[list[float]]]]]:
    """Extract normalized outer polygons and their hole rings from a mask.

    Contours are found with the two-level ``RETR_CCOMP`` hierarchy. Outer
    contours are returned largest-area first, and ``holes[i]`` lists the
    hole rings that are direct children of ``polygons[i]``. Contours with
    fewer than three points are discarded. Coordinates are divided by the
    mask width and height.
    """
    import cv2

    binary = mask if mask.dtype == np.uint8 else (mask > 0).astype(np.uint8)
    contours, hierarchy = cv2.findContours(binary, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
    if hierarchy is None:
        return [], []
    height, width = binary.shape[:2]

    def normalize(contour: np.ndarray) -> list[list[float]]:
        """Turn an (N, 1, 2) pixel contour into [[x, y], ...] in 0-1 units."""
        if len(contour) < 3:
            return []
        return [[float(x) / width, float(y) / height] for x, y in contour[:, 0]]

    # Each row is [next, previous, first_child, parent]; parent < 0 marks an outer ring.
    nodes = hierarchy[0]
    outer_order = sorted(
        (
            index
            for index, node in enumerate(nodes)
            if int(node[3]) < 0 and len(contours[index]) >= 3
        ),
        key=lambda index: cv2.contourArea(contours[index]),
        reverse=True,
    )

    polygons: list[list[list[float]]] = []
    holes: list[list[list[list[float]]]] = []
    for index in outer_order:
        ring = normalize(contours[index])
        if not ring:
            continue
        inner_rings: list[list[list[float]]] = []
        # Walk the sibling chain that starts at this ring's first child.
        child = int(nodes[index][2])
        while child >= 0:
            inner = normalize(contours[child])
            if inner:
                inner_rings.append(inner)
            child = int(nodes[child][0])
        polygons.append(ring)
        holes.append(inner_rings)
    return polygons, holes
@staticmethod
def _dummy_polygons(w: int, h: int) -> list[list[list[float]]]:
"""Return a dummy rectangle polygon for fallback mode."""
return [
[
[0.25, 0.25],
[0.75, 0.25],
[0.75, 0.75],
[0.25, 0.75],
]
]
@staticmethod
def _polygons_to_mask(
    polygons: list[list[list[float]]],
    width: int,
    height: int,
    holes_by_polygon: list[list[list[list[float]]]] | None = None,
) -> np.ndarray:
    """Rasterize normalized polygons into a boolean mask, carving out holes.

    Each polygon is filled and then the hole rings stored at the same index
    in ``holes_by_polygon`` are cleared, before the next polygon is drawn.
    Rings with fewer than three points are skipped. Coordinates are clamped
    to 0-1 and mapped onto pixel indices ``0 .. size - 1``.
    """
    import cv2

    x_scale = max(width - 1, 1)
    y_scale = max(height - 1, 1)

    def to_pixels(ring: list[list[float]]) -> np.ndarray:
        """Convert one normalized ring to an int32 array of pixel vertices."""
        return np.array(
            [
                [
                    int(round(min(max(float(x), 0.0), 1.0) * x_scale)),
                    int(round(min(max(float(y), 0.0), 1.0) * y_scale)),
                ]
                for x, y in ring
            ],
            dtype=np.int32,
        )

    canvas = np.zeros((height, width), dtype=np.uint8)
    for index, polygon in enumerate(polygons):
        if len(polygon) < 3:
            continue
        cv2.fillPoly(canvas, [to_pixels(polygon)], 1)
        if holes_by_polygon and index < len(holes_by_polygon):
            ring_holes = holes_by_polygon[index]
        else:
            ring_holes = []
        for hole in ring_holes:
            if len(hole) >= 3:
                cv2.fillPoly(canvas, [to_pixels(hole)], 0)
    return canvas.astype(bool)
@staticmethod
def _bbox_to_mask(bbox: list[float], width: int, height: int) -> np.ndarray:
x, y, w, h = [min(max(float(value), 0.0), 1.0) for value in bbox]
left = int(round(x * max(width - 1, 1)))
top = int(round(y * max(height - 1, 1)))
right = int(round(min(x + w, 1.0) * max(width - 1, 1)))
bottom = int(round(min(y + h, 1.0) * max(height - 1, 1)))
mask = np.zeros((height, width), dtype=bool)
mask[top:max(bottom + 1, top + 1), left:max(right + 1, left + 1)] = True
return mask
# Module-level singleton shared by importers. Constructing it only creates
# empty caches; model weights are loaded lazily on the first inference call.
sam_engine = SAM2Engine()