feat: 建立 SAM2 标注闭环基线

- 打通工作区真实标注闭环:支持手工多边形、矩形、圆形、点区域和线段生成 mask,并可保存、回显、更新和删除后端 annotation。

- 增强 polygon 编辑器:支持顶点拖动、顶点删除、边中点插入、多 polygon 子区域选择编辑,以及区域合并和区域去除。

- 接入 GT mask 导入:后端支持二值/多类别 mask 拆分、contour 转 polygon、distance transform seed point,前端支持导入、回显和 seed point 拖动编辑。

- 完善导出能力:COCO JSON 导出对齐前端,PNG mask ZIP 同时包含单标注 mask、按 zIndex 融合的 semantic_frame 和 semantic_classes.json。

- 打通异步任务管理:新增任务取消、重试、失败详情接口与 Dashboard 控件,worker 支持取消状态检查并通过 Redis/WebSocket 推送 cancelled 事件。

- 对接 Dashboard 后端数据:概览统计、解析队列和实时流转记录从 FastAPI 聚合接口与 WebSocket 更新。

- 增强 AI 推理参数:前端发送 crop_to_prompt、auto_filter_background 和 min_score,后端支持点/框 prompt 局部裁剪推理、结果回映射和负向点/低分过滤。

- 接入 SAM3 基础设施:新增独立 Python 3.12 sam3 环境安装脚本、外部 worker helper、后端桥接和真实 Python/CUDA/包/HF checkpoint access 状态检测。

- 保留 SAM3 授权边界:当前官方 facebook/sam3 gated 权重未授权时状态接口会返回不可用,不伪装成可推理。

- 增强前端状态管理:新增 mask undo/redo 历史栈、AI 模型选择状态、保存状态 dirty/draft/saved 流转和项目状态归一化。

- 更新前端 API 封装:补充 annotation CRUD、GT mask import、mask ZIP export、task cancel/retry/detail、AI runtime status 和 prediction options。

- 更新 UI 控件:ToolsPalette、AISegmentation、VideoWorkspace 和 CanvasArea 接入真实操作、导入导出、撤销重做、任务控制和模型状态。

- 新增 polygon-clipping 依赖,用于前端区域 union/difference 几何运算。

- 完善后端 schemas/status/progress:补充 AI 模型外部状态字段、任务 cancelled 状态和进度事件 payload。

- 补充测试覆盖:新增后端任务控制、SAM3 桥接、GT mask、导出融合、AI options 测试;补充前端 Canvas、Dashboard、VideoWorkspace、ToolsPalette、API 和 store 测试。

- 更新 README、AGENTS 和 doc 文档:冻结当前需求/设计/测试计划,标注真实功能、剩余 Mock、SAM3 授权边界和后续实施顺序。
This commit is contained in:
2026-05-01 15:26:25 +08:00
parent f020ff3b4f
commit 689a9ba283
48 changed files with 3280 additions and 176 deletions

View File

@@ -5,7 +5,7 @@ from typing import Any, List
import cv2
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, Response, status
from fastapi import APIRouter, Depends, File, Form, HTTPException, Response, UploadFile, status
from sqlalchemy.orm import Session
from database import get_db
@@ -39,6 +39,140 @@ def _load_frame_image(frame: Frame) -> np.ndarray:
raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
def _normalized_contour(contour: np.ndarray, width: int, height: int) -> list[list[float]]:
"""Approximate a contour and convert it to normalized polygon coordinates."""
arc_length = cv2.arcLength(contour, True)
epsilon = max(1.0, arc_length * 0.01)
approx = cv2.approxPolyDP(contour, epsilon, True)
points = approx.reshape(-1, 2)
if len(points) < 3:
points = contour.reshape(-1, 2)
return [
[
min(max(float(x) / max(width, 1), 0.0), 1.0),
min(max(float(y) / max(height, 1), 0.0), 1.0),
]
for x, y in points
]
def _contour_bbox(contour: np.ndarray, width: int, height: int) -> list[float]:
x, y, w, h = cv2.boundingRect(contour)
return [
min(max(float(x) / max(width, 1), 0.0), 1.0),
min(max(float(y) / max(height, 1), 0.0), 1.0),
min(max(float(w) / max(width, 1), 0.0), 1.0),
min(max(float(h) / max(height, 1), 0.0), 1.0),
]
def _component_seed_point(component_mask: np.ndarray, width: int, height: int) -> list[float]:
"""Reduce a binary component to one positive prompt point using distance transform."""
dist = cv2.distanceTransform(component_mask.astype(np.uint8), cv2.DIST_L2, 5)
_, _, _, max_loc = cv2.minMaxLoc(dist)
x, y = max_loc
return [
min(max(float(x) / max(width, 1), 0.0), 1.0),
min(max(float(y) / max(height, 1), 0.0), 1.0),
]
def _clamp01(value: float) -> float:
return min(max(float(value), 0.0), 1.0)
def _point_in_polygon(point: list[float], polygon: list[list[float]]) -> bool:
"""Return whether a normalized point is inside a normalized polygon."""
if len(polygon) < 3:
return False
x, y = point
inside = False
j = len(polygon) - 1
for i, current in enumerate(polygon):
xi, yi = current
xj, yj = polygon[j]
intersects = ((yi > y) != (yj > y)) and (
x < (xj - xi) * (y - yi) / ((yj - yi) or 1e-9) + xi
)
if intersects:
inside = not inside
j = i
return inside
def _crop_bounds_from_points(points: list[list[float]], margin: float) -> tuple[float, float, float, float]:
xs = [_clamp01(point[0]) for point in points]
ys = [_clamp01(point[1]) for point in points]
x1 = max(0.0, min(xs) - margin)
y1 = max(0.0, min(ys) - margin)
x2 = min(1.0, max(xs) + margin)
y2 = min(1.0, max(ys) + margin)
if x2 - x1 < 0.05:
center = (x1 + x2) / 2
x1 = max(0.0, center - 0.025)
x2 = min(1.0, center + 0.025)
if y2 - y1 < 0.05:
center = (y1 + y2) / 2
y1 = max(0.0, center - 0.025)
y2 = min(1.0, center + 0.025)
return x1, y1, x2, y2
def _crop_image(image: np.ndarray, bounds: tuple[float, float, float, float]) -> np.ndarray:
height, width = image.shape[:2]
x1, y1, x2, y2 = bounds
left = int(round(x1 * width))
top = int(round(y1 * height))
right = max(left + 1, int(round(x2 * width)))
bottom = max(top + 1, int(round(y2 * height)))
return image[top:bottom, left:right]
def _to_crop_point(point: list[float], bounds: tuple[float, float, float, float]) -> list[float]:
x1, y1, x2, y2 = bounds
return [
_clamp01((float(point[0]) - x1) / max(x2 - x1, 1e-9)),
_clamp01((float(point[1]) - y1) / max(y2 - y1, 1e-9)),
]
def _from_crop_polygon(
polygon: list[list[float]],
bounds: tuple[float, float, float, float],
) -> list[list[float]]:
x1, y1, x2, y2 = bounds
return [
[
_clamp01(x1 + float(point[0]) * (x2 - x1)),
_clamp01(y1 + float(point[1]) * (y2 - y1)),
]
for point in polygon
]
def _filter_predictions(
polygons: list[list[list[float]]],
scores: list[float],
options: dict[str, Any],
negative_points: list[list[float]] | None = None,
) -> tuple[list[list[list[float]]], list[float]]:
if not options.get("auto_filter_background"):
return polygons, scores
min_score = float(options.get("min_score", 0.0) or 0.0)
next_polygons: list[list[list[float]]] = []
next_scores: list[float] = []
for index, polygon in enumerate(polygons):
score = scores[index] if index < len(scores) else 0.0
if score < min_score:
continue
if negative_points and any(_point_in_polygon(point, polygon) for point in negative_points):
continue
next_polygons.append(polygon)
next_scores.append(score)
return next_polygons, next_scores
@router.post(
"/predict",
response_model=PredictResponse,
@@ -58,9 +192,11 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
image = _load_frame_image(frame)
prompt_type = payload.prompt_type.lower()
options = payload.options or {}
polygons: List[List[List[float]]] = []
scores: List[float] = []
negative_points: list[list[float]] = []
try:
if prompt_type == "point":
@@ -76,13 +212,39 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
raise HTTPException(status_code=400, detail="Invalid point prompt data")
if not isinstance(labels, list) or len(labels) != len(points):
labels = [1] * len(points)
polygons, scores = sam_registry.predict_points(payload.model, image, points, labels)
negative_points = [
point for point, label in zip(points, labels) if label == 0
]
inference_image = image
inference_points = points
crop_bounds = None
if options.get("crop_to_prompt"):
margin = float(options.get("crop_margin", 0.25) or 0.25)
crop_bounds = _crop_bounds_from_points(points, margin)
inference_image = _crop_image(image, crop_bounds)
inference_points = [_to_crop_point(point, crop_bounds) for point in points]
polygons, scores = sam_registry.predict_points(payload.model, inference_image, inference_points, labels)
if crop_bounds:
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
elif prompt_type == "box":
box = payload.prompt_data
if not isinstance(box, list) or len(box) != 4:
raise HTTPException(status_code=400, detail="Invalid box prompt data")
polygons, scores = sam_registry.predict_box(payload.model, image, box)
inference_image = image
inference_box = box
crop_bounds = None
if options.get("crop_to_prompt"):
margin = float(options.get("crop_margin", 0.05) or 0.05)
crop_bounds = _crop_bounds_from_points([[box[0], box[1]], [box[2], box[3]]], margin)
inference_image = _crop_image(image, crop_bounds)
inference_box = [
*_to_crop_point([box[0], box[1]], crop_bounds),
*_to_crop_point([box[2], box[3]], crop_bounds),
]
polygons, scores = sam_registry.predict_box(payload.model, inference_image, inference_box)
if crop_bounds:
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
elif prompt_type == "semantic":
text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
@@ -95,8 +257,9 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
except NotImplementedError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
raise HTTPException(status_code=400, detail=str(exc)) from exc
polygons, scores = _filter_predictions(polygons, scores, options, negative_points)
return {"polygons": polygons, "scores": scores}
@@ -161,6 +324,100 @@ def save_annotation(
return annotation
@router.post(
"/import-gt-mask",
response_model=List[AnnotationOut],
status_code=status.HTTP_201_CREATED,
summary="Import a GT mask and reduce components to editable point regions",
)
async def import_gt_mask(
project_id: int = Form(...),
frame_id: int = Form(...),
template_id: int | None = Form(None),
label: str = Form("GT Mask"),
color: str = Form("#22c55e"),
file: UploadFile = File(...),
db: Session = Depends(get_db),
) -> List[Annotation]:
"""Convert a binary/label mask image into persisted polygon annotations.
Each connected component becomes one annotation. The `points` field stores a
positive seed point at the component's distance-transform center, which gives
the frontend an editable point-region representation instead of a static
bitmap layer.
"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
frame = db.query(Frame).filter(Frame.id == frame_id, Frame.project_id == project_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
if template_id is not None:
template = db.query(Template).filter(Template.id == template_id).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
data = await file.read()
image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
if image is None:
raise HTTPException(status_code=400, detail="Invalid mask image")
width = int(frame.width or image.shape[1])
height = int(frame.height or image.shape[0])
label_values = [int(value) for value in np.unique(image) if int(value) > 0]
if not label_values:
raise HTTPException(status_code=400, detail="No foreground mask regions found")
has_multiple_labels = len(label_values) > 1
annotations: list[Annotation] = []
for label_value in label_values:
binary = np.where(image == label_value, 255, 0).astype(np.uint8)
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
annotation_label = f"{label} {label_value}" if has_multiple_labels else label
for contour in contours:
if cv2.contourArea(contour) < 1:
continue
polygon = _normalized_contour(contour, image.shape[1], image.shape[0])
if len(polygon) < 3:
continue
component = np.zeros_like(binary, dtype=np.uint8)
cv2.drawContours(component, [contour], -1, 1, thickness=-1)
seed_point = _component_seed_point(component, image.shape[1], image.shape[0])
bbox = _contour_bbox(contour, image.shape[1], image.shape[0])
annotation = Annotation(
project_id=project_id,
frame_id=frame_id,
template_id=template_id,
mask_data={
"polygons": [polygon],
"label": annotation_label,
"color": color,
"source": "gt_mask",
"gt_label_value": label_value,
"image_size": {"width": width, "height": height},
},
points=[seed_point],
bbox=bbox,
)
db.add(annotation)
annotations.append(annotation)
if not annotations:
raise HTTPException(status_code=400, detail="No foreground mask regions found")
db.commit()
for annotation in annotations:
db.refresh(annotation)
logger.info("Imported %s GT mask annotations for project_id=%s frame_id=%s", len(annotations), project_id, frame_id)
return annotations
@router.get(
"/annotations",
response_model=List[AnnotationOut],