feat: 完善视频传播、标注编辑和拆帧闭环
- 接入 SAM2 视频传播能力:新增 /api/ai/propagate,支持用当前帧 mask/polygon/bbox 作为 seed,通过 SAM2 video predictor 向前、向后或双向传播,并可保存为真实 annotation。
- 接入 SAM3 video tracker:通过独立 Python 3.12 external worker 调用 SAM3 video predictor/tracker,使用本地 checkpoint 与 bbox seed 执行视频级跟踪,并在模型状态中标记 video_track 能力。
- 完善 SAM 模型分发:sam_registry 按 model_id 明确区分 sam2 propagation 与 sam3 video_track,避免两个模型链路混用。
- 打通前端“传播片段”:VideoWorkspace 使用当前选中 mask 和当前 AI 模型调用后端传播接口,传播结果回写并刷新工作区已保存标注。
- 增强 SAM3 本地 checkpoint 配置:新增 sam3_checkpoint_path 配置和 .env.example 示例,状态检查改为基于本地 checkpoint/独立环境/模型包可用性。
- 完善视频拆帧参数:/api/media/parse 支持 parse_fps、max_frames、target_width,后端任务保存帧时间戳、源帧号和 frame_sequence 元数据。
- 增加运行时 schema 兼容处理:启动时为旧 frames 表补充 timestamp_ms 和 source_frame_number 列,避免旧库升级后缺字段。
- 强化 Canvas 标注编辑:补齐多边形闭合、点工具、顶点拖拽、边中点插入、Delete/Backspace 删除、区域合并和重叠去除等交互。
- 增强语义分类联动:选中 mask 后可通过右侧语义分类树更新标签、颜色和 class metadata,并同步到保存/导出链路。
- 增加关键帧时间轴体验:FrameTimeline 显示具体时间信息,并支持键盘左右方向键切换关键帧。
- 完善 AI 交互分割参数:前端保留正向点、反向点、框选和 interactive prompt 的调用状态,支持 SAM2 细化候选区域与 SAM3 bbox 入口。
- 扩展后端/前端 API 类型:新增 propagateMasks、传播请求/响应 schema,并补齐 annotation、导出、模型状态和任务接口的测试覆盖。
- 更新项目文档:同步 README、AGENTS、接口契约、需求冻结、设计冻结、前端元素审计、实施计划和测试计划,标明真实功能边界与剩余风险。
- 增加测试覆盖:补充 SAM2/SAM3 传播、SAM3 状态、媒体拆帧参数、Canvas 编辑、语义标签切换、时间轴、工作区传播和 API 合约测试。
- 加强仓库安全边界:将 sam3权重/ 加入 .gitignore,避免本地模型权重被误提交。

验证:npm run test:run;pytest backend/tests;npm run lint;npm run build;python -m py_compile;git diff --check。
This commit is contained in:
@@ -24,6 +24,7 @@ except Exception as exc: # noqa: BLE001
|
||||
|
||||
try:
|
||||
from sam2.build_sam import build_sam2
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
|
||||
SAM2_AVAILABLE = True
|
||||
@@ -38,9 +39,12 @@ class SAM2Engine:
|
||||
|
||||
def __init__(self) -> None:
    """Create an engine with no predictor built yet.

    Image and video predictors are tracked independently: each has its own
    predictor handle, "load attempted" flag and last error message.
    """
    # Image predictor state.
    self._predictor: Optional[SAM2ImagePredictor] = None
    self._model_loaded = False
    self._last_error: str | None = None

    # Video predictor state (used for mask propagation).
    self._video_predictor = None
    self._video_model_loaded = False
    self._video_last_error: str | None = None

    # Device string recorded once a predictor has been built successfully.
    self._loaded_device: str | None = None
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
@@ -85,6 +89,40 @@ class SAM2Engine:
|
||||
logger.error("Failed to load SAM 2 model: %s", exc)
|
||||
self._model_loaded = True # Prevent repeated load attempts
|
||||
|
||||
def _load_video_model(self) -> None:
    """Build the SAM 2 video predictor the first time propagation needs it.

    The load is attempted at most once: every exit path, successful or not,
    sets ``_video_model_loaded`` so later calls return immediately. When the
    predictor cannot be built, the reason is stored in ``_video_last_error``
    and ``_video_predictor`` stays ``None``.
    """
    if self._video_model_loaded:
        return

    # Check the preconditions in order and stop at the first one that fails.
    blocker: str | None = None
    if not TORCH_AVAILABLE:
        blocker = "PyTorch is not installed."
    elif not SAM2_AVAILABLE:
        blocker = "sam2 package is not installed."
    elif not os.path.isfile(settings.sam_model_path):
        blocker = f"SAM2 checkpoint not found: {settings.sam_model_path}"

    if blocker is not None:
        self._video_last_error = blocker
        self._video_model_loaded = True
        return

    try:
        target_device = self._best_device()
        self._video_predictor = build_sam2_video_predictor(
            settings.sam_model_config,
            settings.sam_model_path,
            device=target_device,
        )
        self._video_model_loaded = True
        self._loaded_device = target_device
        self._video_last_error = None
        logger.info(
            "SAM 2 video predictor loaded from %s on %s",
            settings.sam_model_path,
            target_device,
        )
    except Exception as exc:  # noqa: BLE001
        # Keep the failure reason and still mark the load as attempted so a
        # broken setup is not retried on every propagation request.
        self._video_last_error = str(exc)
        self._video_model_loaded = True
        logger.error("Failed to load SAM 2 video predictor: %s", exc)
|
||||
|
||||
def _best_device(self) -> str:
|
||||
if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available():
|
||||
return "cuda"
|
||||
@@ -95,6 +133,11 @@ class SAM2Engine:
|
||||
self._load_model()
|
||||
return SAM2_AVAILABLE and self._predictor is not None
|
||||
|
||||
def _ensure_video_ready(self) -> bool:
    """Trigger the lazy video-predictor load and report whether it can be used."""
    self._load_video_model()
    if not SAM2_AVAILABLE:
        return False
    return self._video_predictor is not None
|
||||
|
||||
def status(self) -> dict:
|
||||
"""Return lightweight, real runtime status without forcing model load."""
|
||||
checkpoint_exists = os.path.isfile(settings.sam_model_path)
|
||||
@@ -121,7 +164,7 @@ class SAM2Engine:
|
||||
"available": available,
|
||||
"loaded": self._predictor is not None,
|
||||
"device": device,
|
||||
"supports": ["point", "box", "auto"],
|
||||
"supports": ["point", "box", "interactive", "auto", "propagate"],
|
||||
"message": message,
|
||||
"package_available": SAM2_AVAILABLE,
|
||||
"checkpoint_exists": checkpoint_exists,
|
||||
@@ -221,6 +264,52 @@ class SAM2Engine:
|
||||
logger.error("SAM2 box prediction failed: %s", exc)
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
def predict_interactive(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
box: list[float] | None,
|
||||
points: list[list[float]],
|
||||
labels: list[int],
|
||||
) -> tuple[list[list[list[float]]], list[float]]:
|
||||
"""Run combined box and point prompt segmentation for refinement."""
|
||||
if not self._ensure_ready():
|
||||
logger.warning("SAM2 not ready; returning dummy masks.")
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
try:
|
||||
h, w = image.shape[:2]
|
||||
bbox = None
|
||||
if box:
|
||||
bbox = np.array(
|
||||
[box[0] * w, box[1] * h, box[2] * w, box[3] * h],
|
||||
dtype=np.float32,
|
||||
)
|
||||
pts = None
|
||||
lbls = None
|
||||
if points:
|
||||
pts = np.array([[p[0] * w, p[1] * h] for p in points], dtype=np.float32)
|
||||
lbls = np.array(labels, dtype=np.int32)
|
||||
|
||||
with torch.inference_mode(): # type: ignore[name-defined]
|
||||
self._predictor.set_image(image)
|
||||
masks, scores, _ = self._predictor.predict(
|
||||
point_coords=pts,
|
||||
point_labels=lbls,
|
||||
box=bbox,
|
||||
multimask_output=False,
|
||||
)
|
||||
|
||||
polygons = []
|
||||
for m in masks:
|
||||
poly = self._mask_to_polygon(m)
|
||||
if poly:
|
||||
polygons.append(poly)
|
||||
|
||||
return polygons, scores.tolist()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("SAM2 interactive prediction failed: %s", exc)
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
def predict_auto(self, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
|
||||
"""Run automatic mask generation (grid of points).
|
||||
|
||||
@@ -260,6 +349,89 @@ class SAM2Engine:
|
||||
logger.error("SAM2 auto prediction failed: %s", exc)
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
def propagate_video(
|
||||
self,
|
||||
frame_paths: list[str],
|
||||
source_frame_index: int,
|
||||
seed: dict,
|
||||
direction: str = "forward",
|
||||
max_frames: int | None = None,
|
||||
) -> list[dict]:
|
||||
"""Propagate one seed mask across a prepared frame directory with SAM 2 video."""
|
||||
if not self._ensure_video_ready():
|
||||
raise RuntimeError(self._video_last_error or self.status()["message"])
|
||||
if not frame_paths:
|
||||
return []
|
||||
if source_frame_index < 0 or source_frame_index >= len(frame_paths):
|
||||
raise ValueError("source_frame_index is outside the frame sequence.")
|
||||
|
||||
import cv2
|
||||
|
||||
source_image = cv2.imread(frame_paths[source_frame_index])
|
||||
if source_image is None:
|
||||
raise RuntimeError("Failed to decode source frame for SAM 2 propagation.")
|
||||
height, width = source_image.shape[:2]
|
||||
seed_mask = self._polygons_to_mask(seed.get("polygons") or [], width, height)
|
||||
if not seed_mask.any():
|
||||
bbox = seed.get("bbox")
|
||||
if isinstance(bbox, list) and len(bbox) == 4:
|
||||
seed_mask = self._bbox_to_mask(bbox, width, height)
|
||||
if not seed_mask.any():
|
||||
raise ValueError("SAM 2 propagation requires a non-empty seed polygon or bbox.")
|
||||
|
||||
inference_state = self._video_predictor.init_state(
|
||||
video_path=os.path.dirname(frame_paths[0]),
|
||||
offload_video_to_cpu=True,
|
||||
offload_state_to_cpu=True,
|
||||
)
|
||||
self._video_predictor.add_new_mask(
|
||||
inference_state,
|
||||
frame_idx=source_frame_index,
|
||||
obj_id=1,
|
||||
mask=seed_mask,
|
||||
)
|
||||
|
||||
results: dict[int, dict] = {}
|
||||
|
||||
def collect(reverse: bool) -> None:
|
||||
for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video(
|
||||
inference_state,
|
||||
start_frame_idx=source_frame_index,
|
||||
max_frame_num_to_track=max_frames,
|
||||
reverse=reverse,
|
||||
):
|
||||
masks = out_mask_logits
|
||||
if hasattr(masks, "detach"):
|
||||
masks = masks.detach().cpu().numpy()
|
||||
masks = np.asarray(masks)
|
||||
if masks.ndim == 4:
|
||||
masks = masks[:, 0]
|
||||
polygons = []
|
||||
scores = []
|
||||
for mask in masks:
|
||||
polygon = self._mask_to_polygon(mask > 0)
|
||||
if polygon:
|
||||
polygons.append(polygon)
|
||||
scores.append(1.0)
|
||||
results[int(out_frame_idx)] = {
|
||||
"frame_index": int(out_frame_idx),
|
||||
"polygons": polygons,
|
||||
"scores": scores,
|
||||
"object_ids": [int(obj_id) for obj_id in list(out_obj_ids)],
|
||||
}
|
||||
|
||||
normalized_direction = direction.lower()
|
||||
if normalized_direction in {"forward", "both"}:
|
||||
collect(reverse=False)
|
||||
if normalized_direction in {"backward", "both"}:
|
||||
collect(reverse=True)
|
||||
|
||||
try:
|
||||
self._video_predictor.reset_state(inference_state)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
return [results[index] for index in sorted(results)]
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Helpers
|
||||
# -----------------------------------------------------------------------
|
||||
@@ -292,6 +464,38 @@ class SAM2Engine:
|
||||
]
|
||||
]
|
||||
|
||||
@staticmethod
def _polygons_to_mask(polygons: list[list[list[float]]], width: int, height: int) -> np.ndarray:
    """Rasterize normalized polygons into one boolean mask of shape ``(height, width)``.

    Each polygon is a list of ``[x, y]`` points. Coordinates are clamped to
    ``[0, 1]`` and scaled to pixel indices. Polygons with fewer than three
    points are ignored. All remaining polygons are filled into the same mask.
    """
    import cv2

    # Scale factors map 1.0 to the last pixel index (floored at 1).
    x_scale = max(width - 1, 1)
    y_scale = max(height - 1, 1)

    def to_pixel(value, scale: int) -> int:
        """Clamp a normalized coordinate to [0, 1] and convert it to a pixel index."""
        clamped = min(max(float(value), 0.0), 1.0)
        return int(round(clamped * scale))

    canvas = np.zeros((height, width), dtype=np.uint8)
    for ring in polygons:
        if len(ring) >= 3:
            vertices = np.array(
                [[to_pixel(px, x_scale), to_pixel(py, y_scale)] for px, py in ring],
                dtype=np.int32,
            )
            cv2.fillPoly(canvas, [vertices], 1)
    return canvas.astype(bool)
|
||||
|
||||
@staticmethod
def _bbox_to_mask(bbox: list[float], width: int, height: int) -> np.ndarray:
    """Build a boolean ``(height, width)`` mask covering a normalized bbox.

    Args:
        bbox: ``[x, y, w, h]`` with every value clamped to ``[0, 1]``;
            ``x + w`` and ``y + h`` are additionally capped at 1.
        width: Mask width in pixels.
        height: Mask height in pixels.

    Returns:
        A mask that is ``True`` inside the box. The box always covers at
        least one pixel per axis, so a zero-size bbox still marks its
        anchor pixel.
    """
    x, y, w, h = [min(max(float(value), 0.0), 1.0) for value in bbox]

    x_scale = max(width - 1, 1)
    y_scale = max(height - 1, 1)
    # The scale is floored at 1, so on a 1-pixel axis a coordinate of 1.0
    # would map to index 1 and fall outside the image. Clamp every index
    # to the last valid row/column.
    last_col = max(width - 1, 0)
    last_row = max(height - 1, 0)

    left = min(int(round(x * x_scale)), last_col)
    top = min(int(round(y * y_scale)), last_row)
    right = min(int(round(min(x + w, 1.0) * x_scale)), last_col)
    bottom = min(int(round(min(y + h, 1.0) * y_scale)), last_row)

    mask = np.zeros((height, width), dtype=bool)
    mask[top:max(bottom + 1, top + 1), left:max(right + 1, left + 1)] = True
    return mask
|
||||
|
||||
|
||||
# Singleton instance
|
||||
sam_engine = SAM2Engine()
|
||||
|
||||
Reference in New Issue
Block a user