feat: 完善视频传播、标注编辑和拆帧闭环
- 接入 SAM2 视频传播能力:新增 /api/ai/propagate,支持用当前帧 mask/polygon/bbox 作为 seed,通过 SAM2 video predictor 向前、向后或双向传播,并可保存为真实 annotation。 - 接入 SAM3 video tracker:通过独立 Python 3.12 external worker 调用 SAM3 video predictor/tracker,使用本地 checkpoint 与 bbox seed 执行视频级跟踪,并在模型状态中标记 video_track 能力。 - 完善 SAM 模型分发:sam_registry 按 model_id 明确区分 sam2 propagation 与 sam3 video_track,避免两个模型链路混用。 - 打通前端“传播片段”:VideoWorkspace 使用当前选中 mask 和当前 AI 模型调用后端传播接口,传播结果回写并刷新工作区已保存标注。 - 增强 SAM3 本地 checkpoint 配置:新增 sam3_checkpoint_path 配置和 .env.example 示例,状态检查改为基于本地 checkpoint/独立环境/模型包可用性。 - 完善视频拆帧参数:/api/media/parse 支持 parse_fps、max_frames、target_width,后端任务保存帧时间戳、源帧号和 frame_sequence 元数据。 - 增加运行时 schema 兼容处理:启动时为旧 frames 表补充 timestamp_ms 和 source_frame_number 列,避免旧库升级后缺字段。 - 强化 Canvas 标注编辑:补齐多边形闭合、点工具、顶点拖拽、边中点插入、Delete/Backspace 删除、区域合并和重叠去除等交互。 - 增强语义分类联动:选中 mask 后可通过右侧语义分类树更新标签、颜色和 class metadata,并同步到保存/导出链路。 - 增加关键帧时间轴体验:FrameTimeline 显示具体时间信息,并支持键盘左右方向键切换关键帧。 - 完善 AI 交互分割参数:前端保留正向点、反向点、框选和 interactive prompt 的调用状态,支持 SAM2 细化候选区域与 SAM3 bbox 入口。 - 扩展后端/前端 API 类型:新增 propagateMasks、传播请求/响应 schema,并补齐 annotation、导出、模型状态和任务接口的测试覆盖。 - 更新项目文档:同步 README、AGENTS、接口契约、需求冻结、设计冻结、前端元素审计、实施计划和测试计划,标明真实功能边界与剩余风险。 - 增加测试覆盖:补充 SAM2/SAM3 传播、SAM3 状态、媒体拆帧参数、Canvas 编辑、语义标签切换、时间轴、工作区传播和 API 合约测试。 - 加强仓库安全边界:将 sam3权重/ 加入 .gitignore,避免本地模型权重被误提交。 验证:npm run test:run;pytest backend/tests;npm run lint;npm run build;python -m py_compile;git diff --check。
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
"""AI inference endpoints using selectable SAM runtimes."""
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, List
|
||||
|
||||
import cv2
|
||||
@@ -15,6 +17,8 @@ from schemas import (
|
||||
AiRuntimeStatus,
|
||||
PredictRequest,
|
||||
PredictResponse,
|
||||
PropagateRequest,
|
||||
PropagateResponse,
|
||||
AnnotationOut,
|
||||
AnnotationCreate,
|
||||
AnnotationUpdate,
|
||||
@@ -66,6 +70,48 @@ def _contour_bbox(contour: np.ndarray, width: int, height: int) -> list[float]:
|
||||
]
|
||||
|
||||
|
||||
def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
|
||||
xs = [_clamp01(point[0]) for point in polygon]
|
||||
ys = [_clamp01(point[1]) for point in polygon]
|
||||
left, right = min(xs), max(xs)
|
||||
top, bottom = min(ys), max(ys)
|
||||
return [left, top, max(right - left, 0.0), max(bottom - top, 0.0)]
|
||||
|
||||
|
||||
def _frame_window(
|
||||
frames: list[Frame],
|
||||
source_position: int,
|
||||
direction: str,
|
||||
max_frames: int,
|
||||
) -> tuple[list[Frame], int]:
|
||||
count = max(1, min(max_frames, len(frames)))
|
||||
if direction == "backward":
|
||||
start = max(0, source_position - count + 1)
|
||||
return frames[start:source_position + 1], source_position - start
|
||||
if direction == "both":
|
||||
before = (count - 1) // 2
|
||||
after = count - 1 - before
|
||||
start = max(0, source_position - before)
|
||||
end = min(len(frames), source_position + after + 1)
|
||||
while end - start < count and start > 0:
|
||||
start -= 1
|
||||
while end - start < count and end < len(frames):
|
||||
end += 1
|
||||
return frames[start:end], source_position - start
|
||||
end = min(len(frames), source_position + count)
|
||||
return frames[source_position:end], 0
|
||||
|
||||
|
||||
def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]:
|
||||
paths = []
|
||||
for index, frame in enumerate(frames):
|
||||
data = download_file(frame.image_url)
|
||||
path = directory / f"frame_{index:06d}.jpg"
|
||||
path.write_bytes(data)
|
||||
paths.append(str(path))
|
||||
return paths
|
||||
|
||||
|
||||
def _component_seed_point(component_mask: np.ndarray, width: int, height: int) -> list[float]:
|
||||
"""Reduce a binary component to one positive prompt point using distance transform."""
|
||||
dist = cv2.distanceTransform(component_mask.astype(np.uint8), cv2.DIST_L2, 5)
|
||||
@@ -184,6 +230,7 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
- **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
|
||||
coordinates or `{ "points": [[x, y], ...], "labels": [1, 0, ...] }`.
|
||||
- **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
|
||||
- **interactive**: `prompt_data` is `{ "box": [...], "points": [[x, y]], "labels": [1, 0] }`.
|
||||
- **semantic**: SAM 3 text prompt when model=`sam3`; SAM 2 falls back to auto.
|
||||
"""
|
||||
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
|
||||
@@ -246,6 +293,51 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
if crop_bounds:
|
||||
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
|
||||
|
||||
elif prompt_type == "interactive":
|
||||
prompt = payload.prompt_data
|
||||
if not isinstance(prompt, dict):
|
||||
raise HTTPException(status_code=400, detail="Invalid interactive prompt data")
|
||||
box = prompt.get("box")
|
||||
points = prompt.get("points") or []
|
||||
labels = prompt.get("labels")
|
||||
if box is not None and (not isinstance(box, list) or len(box) != 4):
|
||||
raise HTTPException(status_code=400, detail="Invalid interactive box prompt data")
|
||||
if not isinstance(points, list):
|
||||
raise HTTPException(status_code=400, detail="Invalid interactive point prompt data")
|
||||
if not box and len(points) == 0:
|
||||
raise HTTPException(status_code=400, detail="Interactive prompt requires a box or points")
|
||||
if not isinstance(labels, list) or len(labels) != len(points):
|
||||
labels = [1] * len(points)
|
||||
negative_points = [
|
||||
point for point, label in zip(points, labels) if label == 0
|
||||
]
|
||||
inference_image = image
|
||||
inference_box = box
|
||||
inference_points = points
|
||||
crop_bounds = None
|
||||
if options.get("crop_to_prompt"):
|
||||
margin = float(options.get("crop_margin", 0.05) or 0.05)
|
||||
crop_points = list(points)
|
||||
if box:
|
||||
crop_points.extend([[box[0], box[1]], [box[2], box[3]]])
|
||||
crop_bounds = _crop_bounds_from_points(crop_points, margin)
|
||||
inference_image = _crop_image(image, crop_bounds)
|
||||
inference_points = [_to_crop_point(point, crop_bounds) for point in points]
|
||||
if box:
|
||||
inference_box = [
|
||||
*_to_crop_point([box[0], box[1]], crop_bounds),
|
||||
*_to_crop_point([box[2], box[3]], crop_bounds),
|
||||
]
|
||||
polygons, scores = sam_registry.predict_interactive(
|
||||
payload.model,
|
||||
inference_image,
|
||||
inference_box,
|
||||
inference_points,
|
||||
labels,
|
||||
)
|
||||
if crop_bounds:
|
||||
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
|
||||
|
||||
elif prompt_type == "semantic":
|
||||
text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
|
||||
polygons, scores = sam_registry.predict_semantic(payload.model, image, text)
|
||||
@@ -276,6 +368,124 @@ def model_status(selected_model: str | None = None) -> dict:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/propagate",
|
||||
response_model=PropagateResponse,
|
||||
summary="Propagate one current-frame region across a video frame segment",
|
||||
)
|
||||
def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
|
||||
"""Track one selected region from the current frame across nearby frames.
|
||||
|
||||
SAM 2 uses the official video predictor with the selected mask as the seed.
|
||||
SAM 3 uses the external Python 3.12 video tracker with the seed bbox.
|
||||
"""
|
||||
direction = payload.direction.lower()
|
||||
if direction not in {"forward", "backward", "both"}:
|
||||
raise HTTPException(status_code=400, detail="direction must be forward, backward, or both")
|
||||
max_frames = max(1, min(int(payload.max_frames or 30), 500))
|
||||
|
||||
project = db.query(Project).filter(Project.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
source_frame = db.query(Frame).filter(
|
||||
Frame.id == payload.frame_id,
|
||||
Frame.project_id == payload.project_id,
|
||||
).first()
|
||||
if not source_frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
|
||||
seed = payload.seed.model_dump(exclude_none=True)
|
||||
polygons = seed.get("polygons") or []
|
||||
bbox = seed.get("bbox")
|
||||
points = seed.get("points") or []
|
||||
if not polygons and not bbox and not points:
|
||||
raise HTTPException(status_code=400, detail="Propagation requires seed polygons, bbox, or points")
|
||||
|
||||
frames = db.query(Frame).filter(Frame.project_id == payload.project_id).order_by(Frame.frame_index).all()
|
||||
source_position = next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None)
|
||||
if source_position is None:
|
||||
raise HTTPException(status_code=404, detail="Source frame is not in project frame sequence")
|
||||
|
||||
selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames)
|
||||
if len(selected_frames) == 0:
|
||||
raise HTTPException(status_code=400, detail="No frames available for propagation")
|
||||
|
||||
try:
|
||||
with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload.project_id}_") as tmpdir:
|
||||
frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir))
|
||||
propagated = sam_registry.propagate_video(
|
||||
payload.model,
|
||||
frame_paths,
|
||||
source_relative_index,
|
||||
seed,
|
||||
direction,
|
||||
len(selected_frames),
|
||||
)
|
||||
except ModelUnavailableError as exc:
|
||||
raise HTTPException(status_code=503, detail=str(exc)) from exc
|
||||
except NotImplementedError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Video propagation failed: %s", exc)
|
||||
raise HTTPException(status_code=500, detail=f"Video propagation failed: {exc}") from exc
|
||||
|
||||
created: list[Annotation] = []
|
||||
if payload.save_annotations:
|
||||
class_metadata = seed.get("class_metadata")
|
||||
template_id = seed.get("template_id")
|
||||
label = seed.get("label") or "Propagated Mask"
|
||||
color = seed.get("color") or "#06b6d4"
|
||||
model_id = sam_registry.normalize_model_id(payload.model)
|
||||
|
||||
for frame_result in propagated:
|
||||
relative_index = int(frame_result.get("frame_index", -1))
|
||||
if relative_index < 0 or relative_index >= len(selected_frames):
|
||||
continue
|
||||
frame = selected_frames[relative_index]
|
||||
if not payload.include_source and frame.id == source_frame.id:
|
||||
continue
|
||||
result_polygons = frame_result.get("polygons") or []
|
||||
scores = frame_result.get("scores") or []
|
||||
for polygon_index, polygon in enumerate(result_polygons):
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
annotation = Annotation(
|
||||
project_id=payload.project_id,
|
||||
frame_id=frame.id,
|
||||
template_id=template_id,
|
||||
mask_data={
|
||||
"polygons": [polygon],
|
||||
"label": label,
|
||||
"color": color,
|
||||
"source": f"{model_id}_propagation",
|
||||
"propagated_from_frame_id": source_frame.id,
|
||||
"propagated_from_frame_index": source_frame.frame_index,
|
||||
"score": scores[polygon_index] if polygon_index < len(scores) else None,
|
||||
**({"class": class_metadata} if class_metadata else {}),
|
||||
},
|
||||
points=None,
|
||||
bbox=_polygon_bbox(polygon),
|
||||
)
|
||||
db.add(annotation)
|
||||
created.append(annotation)
|
||||
|
||||
db.commit()
|
||||
for annotation in created:
|
||||
db.refresh(annotation)
|
||||
|
||||
return {
|
||||
"model": sam_registry.normalize_model_id(payload.model),
|
||||
"direction": direction,
|
||||
"source_frame_id": source_frame.id,
|
||||
"processed_frame_count": len(selected_frames),
|
||||
"created_annotation_count": len(created),
|
||||
"annotations": created,
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/auto",
|
||||
response_model=PredictResponse,
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
@@ -169,6 +169,9 @@ async def upload_dicom_batch(
|
||||
def parse_media(
|
||||
project_id: int,
|
||||
source_type: Optional[str] = None,
|
||||
parse_fps: Optional[float] = Query(None, gt=0, le=120),
|
||||
max_frames: Optional[int] = Query(None, gt=0),
|
||||
target_width: int = Query(640, ge=64, le=4096),
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProcessingTask:
|
||||
"""Create a background task for media frame extraction.
|
||||
@@ -184,14 +187,21 @@ def parse_media(
|
||||
raise HTTPException(status_code=400, detail="Project has no media uploaded")
|
||||
|
||||
effective_source = source_type or project.source_type or "video"
|
||||
effective_parse_fps = parse_fps or project.parse_fps or 30.0
|
||||
task = ProcessingTask(
|
||||
task_type=f"parse_{effective_source}",
|
||||
status=TASK_STATUS_QUEUED,
|
||||
progress=0,
|
||||
message="解析任务已入队",
|
||||
project_id=project_id,
|
||||
payload={"source_type": effective_source},
|
||||
payload={
|
||||
"source_type": effective_source,
|
||||
"parse_fps": effective_parse_fps,
|
||||
"max_frames": max_frames,
|
||||
"target_width": target_width,
|
||||
},
|
||||
)
|
||||
project.parse_fps = effective_parse_fps
|
||||
project.status = PROJECT_STATUS_PARSING
|
||||
db.add(task)
|
||||
db.commit()
|
||||
|
||||
Reference in New Issue
Block a user