feat: 建立 SAM2 标注闭环基线

- 打通工作区真实标注闭环:支持手工多边形、矩形、圆形、点区域和线段生成 mask,并可保存、回显、更新和删除后端 annotation。

- 增强 polygon 编辑器:支持顶点拖动、顶点删除、边中点插入、多 polygon 子区域选择编辑,以及区域合并和区域去除。

- 接入 GT mask 导入:后端支持二值/多类别 mask 拆分、contour 转 polygon、distance transform seed point,前端支持导入、回显和 seed point 拖动编辑。

- 完善导出能力:COCO JSON 导出对齐前端,PNG mask ZIP 同时包含单标注 mask、按 zIndex 融合的 semantic_frame 和 semantic_classes.json。

- 打通异步任务管理:新增任务取消、重试、失败详情接口与 Dashboard 控件,worker 支持取消状态检查并通过 Redis/WebSocket 推送 cancelled 事件。

- 对接 Dashboard 后端数据:概览统计、解析队列和实时流转记录从 FastAPI 聚合接口与 WebSocket 更新。

- 增强 AI 推理参数:前端发送 crop_to_prompt、auto_filter_background 和 min_score,后端支持点/框 prompt 局部裁剪推理、结果回映射和负向点/低分过滤。

- 接入 SAM3 基础设施:新增独立 Python 3.12 sam3 环境安装脚本、外部 worker helper、后端桥接和真实 Python/CUDA/包/HF checkpoint access 状态检测。

- 保留 SAM3 授权边界:当前官方 facebook/sam3 gated 权重未授权时状态接口会返回不可用,不伪装成可推理。

- 增强前端状态管理:新增 mask undo/redo 历史栈、AI 模型选择状态、保存状态 dirty/draft/saved 流转和项目状态归一化。

- 更新前端 API 封装:补充 annotation CRUD、GT mask import、mask ZIP export、task cancel/retry/detail、AI runtime status 和 prediction options。

- 更新 UI 控件:ToolsPalette、AISegmentation、VideoWorkspace 和 CanvasArea 接入真实操作、导入导出、撤销重做、任务控制和模型状态。

- 新增 polygon-clipping 依赖,用于前端区域 union/difference 几何运算。

- 完善后端 schemas/status/progress:补充 AI 模型外部状态字段、任务 cancelled 状态和进度事件 payload。

- 补充测试覆盖:新增后端任务控制、SAM3 桥接、GT mask、导出融合、AI options 测试;补充前端 Canvas、Dashboard、VideoWorkspace、ToolsPalette、API 和 store 测试。

- 更新 README、AGENTS 和 doc 文档:冻结当前需求/设计/测试计划,标注真实功能、剩余 Mock、SAM3 授权边界和后续实施顺序。
This commit is contained in:
2026-05-01 15:26:25 +08:00
parent f020ff3b4f
commit 689a9ba283
48 changed files with 3280 additions and 176 deletions

View File

@@ -5,7 +5,7 @@ from typing import Any, List
import cv2
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, Response, status
from fastapi import APIRouter, Depends, File, Form, HTTPException, Response, UploadFile, status
from sqlalchemy.orm import Session
from database import get_db
@@ -39,6 +39,140 @@ def _load_frame_image(frame: Frame) -> np.ndarray:
raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
def _normalized_contour(contour: np.ndarray, width: int, height: int) -> list[list[float]]:
"""Approximate a contour and convert it to normalized polygon coordinates."""
arc_length = cv2.arcLength(contour, True)
epsilon = max(1.0, arc_length * 0.01)
approx = cv2.approxPolyDP(contour, epsilon, True)
points = approx.reshape(-1, 2)
if len(points) < 3:
points = contour.reshape(-1, 2)
return [
[
min(max(float(x) / max(width, 1), 0.0), 1.0),
min(max(float(y) / max(height, 1), 0.0), 1.0),
]
for x, y in points
]
def _contour_bbox(contour: np.ndarray, width: int, height: int) -> list[float]:
x, y, w, h = cv2.boundingRect(contour)
return [
min(max(float(x) / max(width, 1), 0.0), 1.0),
min(max(float(y) / max(height, 1), 0.0), 1.0),
min(max(float(w) / max(width, 1), 0.0), 1.0),
min(max(float(h) / max(height, 1), 0.0), 1.0),
]
def _component_seed_point(component_mask: np.ndarray, width: int, height: int) -> list[float]:
"""Reduce a binary component to one positive prompt point using distance transform."""
dist = cv2.distanceTransform(component_mask.astype(np.uint8), cv2.DIST_L2, 5)
_, _, _, max_loc = cv2.minMaxLoc(dist)
x, y = max_loc
return [
min(max(float(x) / max(width, 1), 0.0), 1.0),
min(max(float(y) / max(height, 1), 0.0), 1.0),
]
def _clamp01(value: float) -> float:
return min(max(float(value), 0.0), 1.0)
def _point_in_polygon(point: list[float], polygon: list[list[float]]) -> bool:
"""Return whether a normalized point is inside a normalized polygon."""
if len(polygon) < 3:
return False
x, y = point
inside = False
j = len(polygon) - 1
for i, current in enumerate(polygon):
xi, yi = current
xj, yj = polygon[j]
intersects = ((yi > y) != (yj > y)) and (
x < (xj - xi) * (y - yi) / ((yj - yi) or 1e-9) + xi
)
if intersects:
inside = not inside
j = i
return inside
def _crop_bounds_from_points(points: list[list[float]], margin: float) -> tuple[float, float, float, float]:
xs = [_clamp01(point[0]) for point in points]
ys = [_clamp01(point[1]) for point in points]
x1 = max(0.0, min(xs) - margin)
y1 = max(0.0, min(ys) - margin)
x2 = min(1.0, max(xs) + margin)
y2 = min(1.0, max(ys) + margin)
if x2 - x1 < 0.05:
center = (x1 + x2) / 2
x1 = max(0.0, center - 0.025)
x2 = min(1.0, center + 0.025)
if y2 - y1 < 0.05:
center = (y1 + y2) / 2
y1 = max(0.0, center - 0.025)
y2 = min(1.0, center + 0.025)
return x1, y1, x2, y2
def _crop_image(image: np.ndarray, bounds: tuple[float, float, float, float]) -> np.ndarray:
height, width = image.shape[:2]
x1, y1, x2, y2 = bounds
left = int(round(x1 * width))
top = int(round(y1 * height))
right = max(left + 1, int(round(x2 * width)))
bottom = max(top + 1, int(round(y2 * height)))
return image[top:bottom, left:right]
def _to_crop_point(point: list[float], bounds: tuple[float, float, float, float]) -> list[float]:
x1, y1, x2, y2 = bounds
return [
_clamp01((float(point[0]) - x1) / max(x2 - x1, 1e-9)),
_clamp01((float(point[1]) - y1) / max(y2 - y1, 1e-9)),
]
def _from_crop_polygon(
polygon: list[list[float]],
bounds: tuple[float, float, float, float],
) -> list[list[float]]:
x1, y1, x2, y2 = bounds
return [
[
_clamp01(x1 + float(point[0]) * (x2 - x1)),
_clamp01(y1 + float(point[1]) * (y2 - y1)),
]
for point in polygon
]
def _filter_predictions(
polygons: list[list[list[float]]],
scores: list[float],
options: dict[str, Any],
negative_points: list[list[float]] | None = None,
) -> tuple[list[list[list[float]]], list[float]]:
if not options.get("auto_filter_background"):
return polygons, scores
min_score = float(options.get("min_score", 0.0) or 0.0)
next_polygons: list[list[list[float]]] = []
next_scores: list[float] = []
for index, polygon in enumerate(polygons):
score = scores[index] if index < len(scores) else 0.0
if score < min_score:
continue
if negative_points and any(_point_in_polygon(point, polygon) for point in negative_points):
continue
next_polygons.append(polygon)
next_scores.append(score)
return next_polygons, next_scores
@router.post(
"/predict",
response_model=PredictResponse,
@@ -58,9 +192,11 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
image = _load_frame_image(frame)
prompt_type = payload.prompt_type.lower()
options = payload.options or {}
polygons: List[List[List[float]]] = []
scores: List[float] = []
negative_points: list[list[float]] = []
try:
if prompt_type == "point":
@@ -76,13 +212,39 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
raise HTTPException(status_code=400, detail="Invalid point prompt data")
if not isinstance(labels, list) or len(labels) != len(points):
labels = [1] * len(points)
polygons, scores = sam_registry.predict_points(payload.model, image, points, labels)
negative_points = [
point for point, label in zip(points, labels) if label == 0
]
inference_image = image
inference_points = points
crop_bounds = None
if options.get("crop_to_prompt"):
margin = float(options.get("crop_margin", 0.25) or 0.25)
crop_bounds = _crop_bounds_from_points(points, margin)
inference_image = _crop_image(image, crop_bounds)
inference_points = [_to_crop_point(point, crop_bounds) for point in points]
polygons, scores = sam_registry.predict_points(payload.model, inference_image, inference_points, labels)
if crop_bounds:
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
elif prompt_type == "box":
box = payload.prompt_data
if not isinstance(box, list) or len(box) != 4:
raise HTTPException(status_code=400, detail="Invalid box prompt data")
polygons, scores = sam_registry.predict_box(payload.model, image, box)
inference_image = image
inference_box = box
crop_bounds = None
if options.get("crop_to_prompt"):
margin = float(options.get("crop_margin", 0.05) or 0.05)
crop_bounds = _crop_bounds_from_points([[box[0], box[1]], [box[2], box[3]]], margin)
inference_image = _crop_image(image, crop_bounds)
inference_box = [
*_to_crop_point([box[0], box[1]], crop_bounds),
*_to_crop_point([box[2], box[3]], crop_bounds),
]
polygons, scores = sam_registry.predict_box(payload.model, inference_image, inference_box)
if crop_bounds:
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
elif prompt_type == "semantic":
text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
@@ -95,8 +257,9 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
except NotImplementedError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
raise HTTPException(status_code=400, detail=str(exc)) from exc
polygons, scores = _filter_predictions(polygons, scores, options, negative_points)
return {"polygons": polygons, "scores": scores}
@@ -161,6 +324,100 @@ def save_annotation(
return annotation
@router.post(
"/import-gt-mask",
response_model=List[AnnotationOut],
status_code=status.HTTP_201_CREATED,
summary="Import a GT mask and reduce components to editable point regions",
)
async def import_gt_mask(
project_id: int = Form(...),
frame_id: int = Form(...),
template_id: int | None = Form(None),
label: str = Form("GT Mask"),
color: str = Form("#22c55e"),
file: UploadFile = File(...),
db: Session = Depends(get_db),
) -> List[Annotation]:
"""Convert a binary/label mask image into persisted polygon annotations.
Each connected component becomes one annotation. The `points` field stores a
positive seed point at the component's distance-transform center, which gives
the frontend an editable point-region representation instead of a static
bitmap layer.
"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
frame = db.query(Frame).filter(Frame.id == frame_id, Frame.project_id == project_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
if template_id is not None:
template = db.query(Template).filter(Template.id == template_id).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
data = await file.read()
image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
if image is None:
raise HTTPException(status_code=400, detail="Invalid mask image")
width = int(frame.width or image.shape[1])
height = int(frame.height or image.shape[0])
label_values = [int(value) for value in np.unique(image) if int(value) > 0]
if not label_values:
raise HTTPException(status_code=400, detail="No foreground mask regions found")
has_multiple_labels = len(label_values) > 1
annotations: list[Annotation] = []
for label_value in label_values:
binary = np.where(image == label_value, 255, 0).astype(np.uint8)
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
annotation_label = f"{label} {label_value}" if has_multiple_labels else label
for contour in contours:
if cv2.contourArea(contour) < 1:
continue
polygon = _normalized_contour(contour, image.shape[1], image.shape[0])
if len(polygon) < 3:
continue
component = np.zeros_like(binary, dtype=np.uint8)
cv2.drawContours(component, [contour], -1, 1, thickness=-1)
seed_point = _component_seed_point(component, image.shape[1], image.shape[0])
bbox = _contour_bbox(contour, image.shape[1], image.shape[0])
annotation = Annotation(
project_id=project_id,
frame_id=frame_id,
template_id=template_id,
mask_data={
"polygons": [polygon],
"label": annotation_label,
"color": color,
"source": "gt_mask",
"gt_label_value": label_value,
"image_size": {"width": width, "height": height},
},
points=[seed_point],
bbox=bbox,
)
db.add(annotation)
annotations.append(annotation)
if not annotations:
raise HTTPException(status_code=400, detail="No foreground mask regions found")
db.commit()
for annotation in annotations:
db.refresh(annotation)
logger.info("Imported %s GT mask annotations for project_id=%s frame_id=%s", len(annotations), project_id, frame_id)
return annotations
@router.get(
"/annotations",
response_model=List[AnnotationOut],

View File

@@ -14,6 +14,7 @@ from models import Annotation, Frame, ProcessingTask, Project, Template
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
ACTIVE_TASK_STATUSES = {"queued", "running"}
MONITORED_TASK_STATUSES = {"queued", "running", "failed", "cancelled"}
def _system_load_percent() -> int:
@@ -42,7 +43,9 @@ def _task_payload(task: ProcessingTask) -> dict[str, Any]:
"name": task.project.name if task.project else f"任务 {task.id}",
"progress": task.progress,
"status": task.message or task.status,
"raw_status": task.status,
"frame_count": (task.result or {}).get("frames_extracted", 0),
"error": task.error,
"updated_at": _iso_or_none(task.updated_at),
}
@@ -68,7 +71,7 @@ def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
.limit(50)
.all()
)
tasks = [_task_payload(task) for task in recent_tasks if task.status in ACTIVE_TASK_STATUSES]
tasks = [_task_payload(task) for task in recent_tasks if task.status in MONITORED_TASK_STATUSES]
activities: list[dict[str, Any]] = []
for task in recent_tasks[:10]:

View File

@@ -37,6 +37,54 @@ def _mask_from_polygon(
return mask
def _annotation_z_index(annotation: Annotation) -> int:
class_meta = (annotation.mask_data or {}).get("class") or {}
if isinstance(class_meta, dict) and class_meta.get("zIndex") is not None:
try:
return int(class_meta["zIndex"])
except (TypeError, ValueError):
pass
if annotation.template and annotation.template.z_index is not None:
return int(annotation.template.z_index)
return 0
def _annotation_class_key(annotation: Annotation) -> str:
class_meta = (annotation.mask_data or {}).get("class") or {}
if isinstance(class_meta, dict):
if class_meta.get("id"):
return f"class:{class_meta['id']}"
if class_meta.get("name"):
return f"name:{class_meta['name']}"
if annotation.template_id:
return f"template:{annotation.template_id}"
return f"annotation:{annotation.id}"
def _annotation_label(annotation: Annotation) -> str:
mask_data = annotation.mask_data or {}
class_meta = mask_data.get("class") or {}
if isinstance(class_meta, dict) and class_meta.get("name"):
return str(class_meta["name"])
if mask_data.get("label"):
return str(mask_data["label"])
if annotation.template and annotation.template.name:
return str(annotation.template.name)
return f"Annotation {annotation.id}"
def _annotation_color(annotation: Annotation) -> str:
mask_data = annotation.mask_data or {}
class_meta = mask_data.get("class") or {}
if isinstance(class_meta, dict) and class_meta.get("color"):
return str(class_meta["color"])
if mask_data.get("color"):
return str(mask_data["color"])
if annotation.template and annotation.template.color:
return str(annotation.template.color)
return "#ffffff"
@router.get(
"/{project_id}/coco",
summary="Export annotations in COCO format",
@@ -150,19 +198,46 @@ def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResp
summary="Export PNG masks as a ZIP archive",
)
def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
"""Export all annotation masks as individual PNG files inside a ZIP archive."""
"""Export individual masks plus z-index fused semantic masks inside a ZIP."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
import cv2
annotations = (
db.query(Annotation)
.filter(Annotation.project_id == project_id)
.all()
)
frames = (
db.query(Frame)
.filter(Frame.project_id == project_id)
.order_by(Frame.frame_index)
.all()
)
class_values: dict[str, int] = {}
semantic_classes: list[dict[str, Any]] = []
def class_value(annotation: Annotation) -> int:
key = _annotation_class_key(annotation)
if key not in class_values:
value = len(class_values) + 1
class_values[key] = value
semantic_classes.append({
"value": value,
"key": key,
"label": _annotation_label(annotation),
"color": _annotation_color(annotation),
"zIndex": _annotation_z_index(annotation),
"template_id": annotation.template_id,
})
return class_values[key]
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {}
for ann in annotations:
if not ann.mask_data:
continue
@@ -178,11 +253,28 @@ def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingRes
mask = _mask_from_polygon(poly, width, height)
combined = np.maximum(combined, mask)
# Encode PNG
import cv2
_, encoded = cv2.imencode(".png", combined)
fname = f"mask_{ann.id:06d}.png"
zf.writestr(fname, encoded.tobytes())
if ann.frame_id is not None:
frame_masks.setdefault(ann.frame_id, []).append((ann, combined))
for frame in frames:
entries = frame_masks.get(frame.id, [])
if not entries:
continue
width = frame.width or 1920
height = frame.height or 1080
semantic = np.zeros((height, width), dtype=np.uint8)
for ann, mask in sorted(entries, key=lambda item: _annotation_z_index(item[0])):
semantic[mask > 0] = class_value(ann)
_, encoded = cv2.imencode(".png", semantic)
zf.writestr(f"semantic_frame_{frame.frame_index:06d}.png", encoded.tobytes())
zf.writestr(
"semantic_classes.json",
json.dumps({"classes": semantic_classes}, ensure_ascii=False, indent=2).encode("utf-8"),
)
zip_buffer.seek(0)
filename = f"project_{project_id}_masks.zip"

View File

@@ -1,15 +1,45 @@
"""Processing task query endpoints."""
import logging
from datetime import datetime, timezone
from typing import List
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from celery_app import celery_app
from database import get_db
from models import ProcessingTask
from models import ProcessingTask, Project
from progress_events import publish_task_progress_event
from schemas import ProcessingTaskOut
from statuses import (
PROJECT_STATUS_PARSING,
PROJECT_STATUS_PENDING,
PROJECT_STATUS_READY,
TASK_ACTIVE_STATUSES,
TASK_STATUS_CANCELLED,
TASK_STATUS_FAILED,
TASK_STATUS_QUEUED,
)
from worker_tasks import parse_project_media
router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
logger = logging.getLogger(__name__)
def _now() -> datetime:
return datetime.now(timezone.utc)
def _get_task_or_404(task_id: int, db: Session) -> ProcessingTask:
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return task
def _project_status_after_stop(project: Project) -> str:
return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING
@router.get("", response_model=List[ProcessingTaskOut], summary="List processing tasks")
@@ -31,7 +61,78 @@ def list_tasks(
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
def get_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
"""Return one background task by id."""
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return _get_task_or_404(task_id, db)
@router.post("/{task_id}/cancel", response_model=ProcessingTaskOut, summary="Cancel processing task")
def cancel_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
"""Cancel a queued/running background task and revoke the Celery job when possible."""
task = _get_task_or_404(task_id, db)
if task.status not in TASK_ACTIVE_STATUSES:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Task is not cancellable in status: {task.status}",
)
if task.celery_task_id:
try:
celery_app.control.revoke(task.celery_task_id, terminate=True, signal="SIGTERM")
except Exception as exc: # noqa: BLE001
logger.warning("Failed to revoke celery task %s: %s", task.celery_task_id, exc)
task.status = TASK_STATUS_CANCELLED
task.progress = 100
task.message = "任务已取消"
task.error = "Cancelled by user"
task.finished_at = _now()
if task.project:
task.project.status = _project_status_after_stop(task.project)
db.commit()
db.refresh(task)
publish_task_progress_event(task)
return task
@router.post("/{task_id}/retry", response_model=ProcessingTaskOut, status_code=status.HTTP_202_ACCEPTED, summary="Retry processing task")
def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
"""Create a fresh queued task from a failed or cancelled task."""
previous = _get_task_or_404(task_id, db)
if previous.status not in {TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Task is not retryable in status: {previous.status}",
)
if previous.project_id is None:
raise HTTPException(status_code=400, detail="Task has no project_id")
project = db.query(Project).filter(Project.id == previous.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
if not project.video_path:
raise HTTPException(status_code=400, detail="Project has no media uploaded")
payload = dict(previous.payload or {})
payload.setdefault("source_type", project.source_type or "video")
payload["retry_of"] = previous.id
task = ProcessingTask(
task_type=previous.task_type,
status=TASK_STATUS_QUEUED,
progress=0,
message=f"重试任务已入队(源任务 #{previous.id}",
project_id=project.id,
payload=payload,
)
project.status = PROJECT_STATUS_PARSING
db.add(task)
db.commit()
db.refresh(task)
publish_task_progress_event(task)
async_result = parse_project_media.delay(task.id)
task.celery_task_id = async_result.id
db.commit()
db.refresh(task)
publish_task_progress_event(task)
return task