feat: 完善视频传播、标注编辑和拆帧闭环
- 接入 SAM2 视频传播能力:新增 /api/ai/propagate,支持用当前帧 mask/polygon/bbox 作为 seed,通过 SAM2 video predictor 向前、向后或双向传播,并可保存为真实 annotation。 - 接入 SAM3 video tracker:通过独立 Python 3.12 external worker 调用 SAM3 video predictor/tracker,使用本地 checkpoint 与 bbox seed 执行视频级跟踪,并在模型状态中标记 video_track 能力。 - 完善 SAM 模型分发:sam_registry 按 model_id 明确区分 sam2 propagation 与 sam3 video_track,避免两个模型链路混用。 - 打通前端“传播片段”:VideoWorkspace 使用当前选中 mask 和当前 AI 模型调用后端传播接口,传播结果回写并刷新工作区已保存标注。 - 增强 SAM3 本地 checkpoint 配置:新增 sam3_checkpoint_path 配置和 .env.example 示例,状态检查改为基于本地 checkpoint/独立环境/模型包可用性。 - 完善视频拆帧参数:/api/media/parse 支持 parse_fps、max_frames、target_width,后端任务保存帧时间戳、源帧号和 frame_sequence 元数据。 - 增加运行时 schema 兼容处理:启动时为旧 frames 表补充 timestamp_ms 和 source_frame_number 列,避免旧库升级后缺字段。 - 强化 Canvas 标注编辑:补齐多边形闭合、点工具、顶点拖拽、边中点插入、Delete/Backspace 删除、区域合并和重叠去除等交互。 - 增强语义分类联动:选中 mask 后可通过右侧语义分类树更新标签、颜色和 class metadata,并同步到保存/导出链路。 - 增加关键帧时间轴体验:FrameTimeline 显示具体时间信息,并支持键盘左右方向键切换关键帧。 - 完善 AI 交互分割参数:前端保留正向点、反向点、框选和 interactive prompt 的调用状态,支持 SAM2 细化候选区域与 SAM3 bbox 入口。 - 扩展后端/前端 API 类型:新增 propagateMasks、传播请求/响应 schema,并补齐 annotation、导出、模型状态和任务接口的测试覆盖。 - 更新项目文档:同步 README、AGENTS、接口契约、需求冻结、设计冻结、前端元素审计、实施计划和测试计划,标明真实功能边界与剩余风险。 - 增加测试覆盖:补充 SAM2/SAM3 传播、SAM3 状态、媒体拆帧参数、Canvas 编辑、语义标签切换、时间轴、工作区传播和 API 合约测试。 - 加强仓库安全边界:将 sam3权重/ 加入 .gitignore,避免本地模型权重被误提交。 验证:npm run test:run;pytest backend/tests;npm run lint;npm run build;python -m py_compile;git diff --check。
This commit is contained in:
@@ -56,8 +56,22 @@ class SAM3Engine:
|
||||
def _gpu_ok(self) -> bool:
    """Report whether CUDA is usable from the current process.

    Returns:
        True only when the module-level ``TORCH_AVAILABLE`` flag is truthy,
        ``torch`` is not ``None``, and ``torch.cuda.is_available()`` is
        truthy; False otherwise.
    """
    # Bail out before touching ``torch.cuda`` when torch itself is missing.
    if not TORCH_AVAILABLE or torch is None:
        return False
    return bool(torch.cuda.is_available())
|
||||
|
||||
def _checkpoint_path(self) -> str | None:
    """Return the configured SAM 3 checkpoint path, or ``None`` when blank.

    Reads ``settings.sam3_checkpoint_path`` and strips surrounding
    whitespace; an empty result is reported as ``None``.
    """
    return settings.sam3_checkpoint_path.strip() or None
|
||||
|
||||
def _checkpoint_exists(self) -> bool:
|
||||
path = self._checkpoint_path()
|
||||
return bool(path and os.path.isfile(path))
|
||||
|
||||
def _can_load(self) -> bool:
    """Tell whether the SAM 3 model can be loaded inside this process.

    The stale single-line ``return`` that preceded this expression has been
    removed: it returned before the checkpoint check could ever run, so a
    missing local checkpoint was not taken into account.

    Returns:
        True only when the SAM 3 package flag and the torch flag are truthy,
        ``self._python_ok()`` and ``self._gpu_ok()`` pass, and the configured
        local checkpoint file exists.
    """
    return bool(
        SAM3_PACKAGE_AVAILABLE
        and TORCH_AVAILABLE
        and self._python_ok()
        and self._gpu_ok()
        and self._checkpoint_exists()
    )
|
||||
|
||||
def _worker_path(self) -> Path:
|
||||
return Path(__file__).with_name("sam3_external_worker.py")
|
||||
@@ -98,6 +112,8 @@ class SAM3Engine:
|
||||
try:
|
||||
env = os.environ.copy()
|
||||
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
|
||||
if self._checkpoint_path():
|
||||
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
|
||||
completed = subprocess.run(
|
||||
[settings.sam3_external_python, str(self._worker_path()), "--status"],
|
||||
capture_output=True,
|
||||
@@ -146,7 +162,10 @@ class SAM3Engine:
|
||||
from sam3.model.sam3_image_processor import Sam3Processor
|
||||
from sam3.model_builder import build_sam3_image_model
|
||||
|
||||
self._model = build_sam3_image_model()
|
||||
self._model = build_sam3_image_model(
|
||||
checkpoint_path=self._checkpoint_path(),
|
||||
load_from_HF=False,
|
||||
)
|
||||
self._processor = Sam3Processor(self._model)
|
||||
self._model_loaded = True
|
||||
self._last_error = None
|
||||
@@ -170,6 +189,8 @@ class SAM3Engine:
|
||||
missing.append("PyTorch")
|
||||
if not self._gpu_ok():
|
||||
missing.append("CUDA GPU")
|
||||
if not self._checkpoint_exists():
|
||||
missing.append(f"local checkpoint ({settings.sam3_checkpoint_path})")
|
||||
if missing:
|
||||
return f"SAM 3 unavailable: missing {', '.join(missing)}."
|
||||
return "SAM 3 dependencies are present; model will load on first inference."
|
||||
@@ -182,7 +203,7 @@ class SAM3Engine:
|
||||
if self._processor is not None:
|
||||
message = "SAM 3 model loaded and ready."
|
||||
elif external_ready:
|
||||
message = "SAM 3 external runtime is ready; model will load in the helper process on inference."
|
||||
message = "SAM 3 external runtime is ready; local checkpoint will load in the helper process on inference."
|
||||
elif external_status.get("message") and not self._can_load():
|
||||
message = str(external_status["message"])
|
||||
return {
|
||||
@@ -191,11 +212,11 @@ class SAM3Engine:
|
||||
"available": available,
|
||||
"loaded": self._processor is not None,
|
||||
"device": "cuda" if self._gpu_ok() else str(external_status.get("device", "unavailable")),
|
||||
"supports": ["semantic"],
|
||||
"supports": ["semantic", "box", "video_track"],
|
||||
"message": message,
|
||||
"package_available": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("package_available")),
|
||||
"checkpoint_exists": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("checkpoint_access")),
|
||||
"checkpoint_path": f"official/HuggingFace ({settings.sam3_model_version})",
|
||||
"checkpoint_exists": bool(self._checkpoint_exists() or external_status.get("checkpoint_access")),
|
||||
"checkpoint_path": self._checkpoint_path() or f"official/HuggingFace ({settings.sam3_model_version})",
|
||||
"python_ok": bool(self._python_ok() or external_status.get("python_ok")),
|
||||
"torch_ok": bool(TORCH_AVAILABLE or external_status.get("torch_ok")),
|
||||
"cuda_required": True,
|
||||
@@ -203,7 +224,43 @@ class SAM3Engine:
|
||||
"external_python": settings.sam3_external_python if settings.sam3_external_enabled else None,
|
||||
}
|
||||
|
||||
def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
|
||||
def _xyxy_to_cxcywh(self, box: list[float]) -> list[float]:
|
||||
if len(box) != 4:
|
||||
raise ValueError("SAM 3 box prompt requires [x1, y1, x2, y2].")
|
||||
x1, y1, x2, y2 = [min(max(float(value), 0.0), 1.0) for value in box]
|
||||
left, right = sorted([x1, x2])
|
||||
top, bottom = sorted([y1, y2])
|
||||
width = max(right - left, 1e-6)
|
||||
height = max(bottom - top, 1e-6)
|
||||
return [left + width / 2, top + height / 2, width, height]
|
||||
|
||||
def _prediction_to_polygons(self, output: Any) -> tuple[list[list[list[float]]], list[float]]:
|
||||
masks = output.get("masks", [])
|
||||
scores = output.get("scores", [])
|
||||
polygons = []
|
||||
for mask in masks:
|
||||
if hasattr(mask, "detach"):
|
||||
mask = mask.detach().cpu().numpy()
|
||||
if mask.ndim == 3:
|
||||
mask = mask[0]
|
||||
poly = SAM2Engine._mask_to_polygon(mask)
|
||||
if poly:
|
||||
polygons.append(poly)
|
||||
|
||||
if hasattr(scores, "detach"):
|
||||
scores = scores.detach().cpu().tolist()
|
||||
elif hasattr(scores, "tolist"):
|
||||
scores = scores.tolist()
|
||||
return polygons, list(scores)
|
||||
|
||||
def _predict_external(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
prompt_type: str,
|
||||
*,
|
||||
text: str = "",
|
||||
box: list[float] | None = None,
|
||||
) -> tuple[list[list[list[float]]], list[float]]:
|
||||
status = self._external_status(force=True)
|
||||
if not status.get("available"):
|
||||
raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")
|
||||
@@ -217,8 +274,11 @@ class SAM3Engine:
|
||||
json.dumps(
|
||||
{
|
||||
"image_path": str(image_path),
|
||||
"prompt_type": prompt_type,
|
||||
"text": text.strip(),
|
||||
"box": box,
|
||||
"model_version": settings.sam3_model_version,
|
||||
"checkpoint_path": self._checkpoint_path(),
|
||||
"confidence_threshold": settings.sam3_confidence_threshold,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
@@ -227,6 +287,8 @@ class SAM3Engine:
|
||||
)
|
||||
env = os.environ.copy()
|
||||
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
|
||||
if self._checkpoint_path():
|
||||
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
|
||||
completed = subprocess.run(
|
||||
[settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)],
|
||||
capture_output=True,
|
||||
@@ -250,6 +312,72 @@ class SAM3Engine:
|
||||
raise RuntimeError(str(payload["error"]))
|
||||
return payload.get("polygons", []), payload.get("scores", [])
|
||||
|
||||
def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
|
||||
return self._predict_external(image, "semantic", text=text)
|
||||
|
||||
def _predict_box_external(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]:
|
||||
return self._predict_external(image, "box", box=box)
|
||||
|
||||
def _propagate_video_external(
|
||||
self,
|
||||
frame_paths: list[str],
|
||||
source_frame_index: int,
|
||||
seed: dict[str, Any],
|
||||
direction: str,
|
||||
max_frames: int | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
status = self._external_status(force=True)
|
||||
if not status.get("available"):
|
||||
raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")
|
||||
if not frame_paths:
|
||||
return []
|
||||
|
||||
with tempfile.TemporaryDirectory(prefix="sam3_video_") as tmpdir:
|
||||
request_path = Path(tmpdir) / "request.json"
|
||||
request_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"prompt_type": "video_track",
|
||||
"frame_dir": str(Path(frame_paths[0]).parent),
|
||||
"source_frame_index": source_frame_index,
|
||||
"seed": seed,
|
||||
"direction": direction,
|
||||
"max_frames": max_frames,
|
||||
"model_version": settings.sam3_model_version,
|
||||
"checkpoint_path": self._checkpoint_path(),
|
||||
"confidence_threshold": settings.sam3_confidence_threshold,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
env = os.environ.copy()
|
||||
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
|
||||
if self._checkpoint_path():
|
||||
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
|
||||
completed = subprocess.run(
|
||||
[settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=settings.sam3_timeout_seconds,
|
||||
check=False,
|
||||
env=env,
|
||||
)
|
||||
|
||||
if completed.returncode != 0:
|
||||
detail = completed.stderr.strip() or completed.stdout.strip()
|
||||
try:
|
||||
parsed = json.loads(detail)
|
||||
detail = parsed.get("error", detail)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
raise RuntimeError(f"SAM 3 external video tracking failed: {detail}")
|
||||
|
||||
payload = json.loads(completed.stdout)
|
||||
if payload.get("error"):
|
||||
raise RuntimeError(str(payload["error"]))
|
||||
return payload.get("frames", [])
|
||||
|
||||
def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
|
||||
if not text.strip():
|
||||
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
|
||||
@@ -263,29 +391,37 @@ class SAM3Engine:
|
||||
state = self._processor.set_image(pil_image)
|
||||
output = self._processor.set_text_prompt(state=state, prompt=text.strip())
|
||||
|
||||
masks = output.get("masks", [])
|
||||
scores = output.get("scores", [])
|
||||
polygons = []
|
||||
for mask in masks:
|
||||
if hasattr(mask, "detach"):
|
||||
mask = mask.detach().cpu().numpy()
|
||||
if mask.ndim == 3:
|
||||
mask = mask[0]
|
||||
poly = SAM2Engine._mask_to_polygon(mask)
|
||||
if poly:
|
||||
polygons.append(poly)
|
||||
|
||||
if hasattr(scores, "detach"):
|
||||
scores = scores.detach().cpu().tolist()
|
||||
elif hasattr(scores, "tolist"):
|
||||
scores = scores.tolist()
|
||||
return polygons, list(scores)
|
||||
return self._prediction_to_polygons(output)
|
||||
|
||||
def predict_points(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
    """Reject point prompts; this engine does not implement them.

    Raises:
        NotImplementedError: Always, whatever arguments are given.
    """
    reason = "This backend currently exposes SAM 3 semantic text inference; use SAM 2 for point prompts."
    raise NotImplementedError(reason)
|
||||
|
||||
def predict_box(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]:
    """Segment the region indicated by a box prompt.

    The earlier stub that raised ``NotImplementedError`` under the same name
    has been removed; it was shadowed by this definition and only left two
    conflicting ``predict_box`` definitions side by side.

    When the model cannot be loaded in this process but the external runtime
    reports itself available, the request goes to the external worker.
    Otherwise the in-process processor is used.

    Args:
        image: Image array, passed to ``Image.fromarray`` on the in-process
            path or forwarded unchanged to the external path.
        box: ``[x1, y1, x2, y2]``; converted with ``_xyxy_to_cxcywh`` on the
            in-process path.

    Returns:
        ``(polygons, scores)`` as produced by ``_prediction_to_polygons`` or
        by the external worker.

    Raises:
        RuntimeError: If the in-process model is not ready; the message comes
            from ``self.status()``.
    """
    # The external status is only queried when in-process loading is ruled out.
    if not self._can_load() and self._external_status().get("available"):
        return self._predict_box_external(image, box)
    if not self._ensure_ready():
        raise RuntimeError(self.status()["message"])

    pil_image = Image.fromarray(image)
    with torch.inference_mode():  # type: ignore[union-attr]
        state = self._processor.set_image(pil_image)
        output = self._processor.add_geometric_prompt(
            state=state,
            box=self._xyxy_to_cxcywh(box),
            label=True,
        )

    return self._prediction_to_polygons(output)
|
||||
|
||||
def propagate_video(
|
||||
self,
|
||||
frame_paths: list[str],
|
||||
source_frame_index: int,
|
||||
seed: dict[str, Any],
|
||||
direction: str = "forward",
|
||||
max_frames: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
return self._propagate_video_external(frame_paths, source_frame_index, seed, direction, max_frames)
|
||||
|
||||
|
||||
sam3_engine = SAM3Engine()
|
||||
|
||||
Reference in New Issue
Block a user