feat: 完善分割工作区传播与交互闭环

功能增加:新增后端传播任务执行器,支持异步自动传播、传播进度、结果统计、取消/重试状态同步。

功能增加:传播请求支持指定 SAM2.1 tiny/small/base+/large 权重,并记录 seed mask、source annotation 和传播范围。

功能增加:传播逻辑增加 seed 签名,未变化的 mask 二次传播会跳过,已变化的 mask 会先清理旧自动传播结果再重新生成,避免重复重叠。

功能增加:工作区增加传播范围二次选择、传播进度提示、人工/AI 标注帧红色标识、自动传播帧蓝色标识和当前帧双层边框。

功能增加:新增临时提示组件,让工具操作提示自动消失且不阻塞后续操作。

功能增加:补充项目删除、模板删除、任务失败详情、任务取消/重试等前后端联动状态。

功能增加:新增安装部署文档,补充当前需求冻结、设计冻结、接口契约、测试计划和 AGENTS/README 项目说明。

Bugfix:修复自动传播接口 404、传播后看不到任务进度、传播结果重复堆叠和已编辑帧提示不清晰的问题。

Bugfix:修复 AI 分割框选/点选交互、单候选 mask、删除选点、工作区保存与候选 mask 推送相关问题。

Bugfix:修复 Canvas 多边形顶点拖动告警、工具栏提示缺失、项目库 FPS 展示和若干 UI 文案/可用性问题。

测试:补充 AI 分割、Canvas、Dashboard、FrameTimeline、ProjectLibrary、TemplateRegistry、ToolsPalette、VideoWorkspace、API 和后端任务/AI/dashboard 测试。

验证:npm run lint;npm run test:run;python -m pytest backend/tests -q。
This commit is contained in:
2026-05-02 05:17:18 +08:00
parent b6a276cb8d
commit c8c59f7ede
38 changed files with 2852 additions and 212 deletions

View File

@@ -12,7 +12,7 @@ from sqlalchemy.orm import Session
from database import get_db
from minio_client import download_file
from models import Project, Frame, Template, Annotation
from models import Project, Frame, Template, Annotation, ProcessingTask
from schemas import (
AiRuntimeStatus,
MaskAnalysisRequest,
@@ -21,10 +21,15 @@ from schemas import (
PredictResponse,
PropagateRequest,
PropagateResponse,
PropagateTaskRequest,
ProcessingTaskOut,
AnnotationOut,
AnnotationCreate,
AnnotationUpdate,
)
from progress_events import publish_task_progress_event
from statuses import TASK_STATUS_QUEUED
from worker_tasks import propagate_project_masks
from services.sam_registry import ModelUnavailableError, sam_registry
logger = logging.getLogger(__name__)
@@ -586,6 +591,66 @@ def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
}
@router.post(
"/propagate/task",
status_code=status.HTTP_202_ACCEPTED,
response_model=ProcessingTaskOut,
summary="Queue a background video propagation task",
)
def queue_propagate_task(payload: PropagateTaskRequest, db: Session = Depends(get_db)) -> ProcessingTaskOut:
"""Queue multiple seed/direction propagation steps as one background task."""
project = db.query(Project).filter(Project.id == payload.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
source_frame = db.query(Frame).filter(
Frame.id == payload.frame_id,
Frame.project_id == payload.project_id,
).first()
if not source_frame:
raise HTTPException(status_code=404, detail="Frame not found")
if not payload.steps:
raise HTTPException(status_code=400, detail="Propagation task requires at least one step")
try:
model_id = sam_registry.normalize_model_id(payload.model)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
for step in payload.steps:
direction = step.direction.lower()
if direction not in {"forward", "backward"}:
raise HTTPException(status_code=400, detail="direction must be forward or backward")
seed = step.seed.model_dump(exclude_none=True)
if not (seed.get("polygons") or seed.get("bbox") or seed.get("points")):
raise HTTPException(status_code=400, detail="Propagation requires seed polygons, bbox, or points")
task_payload = payload.model_dump(exclude_none=True)
task_payload["model"] = model_id
task = ProcessingTask(
task_type="propagate_masks",
status=TASK_STATUS_QUEUED,
progress=0,
message="自动传播任务已入队",
project_id=payload.project_id,
payload=task_payload,
)
db.add(task)
db.commit()
db.refresh(task)
publish_task_progress_event(task)
async_result = propagate_project_masks.delay(task.id)
task.celery_task_id = async_result.id
db.commit()
db.refresh(task)
publish_task_progress_event(task)
logger.info("Queued propagation task id=%s project_id=%s celery_id=%s", task.id, payload.project_id, async_result.id)
return task
@router.post(
"/auto",
response_model=PredictResponse,

View File

@@ -36,6 +36,7 @@ def _iso_or_none(value: datetime | None) -> str | None:
def _task_payload(task: ProcessingTask) -> dict[str, Any]:
result = task.result or {}
return {
"id": f"task-{task.id}",
"task_id": task.id,
@@ -44,7 +45,7 @@ def _task_payload(task: ProcessingTask) -> dict[str, Any]:
"progress": task.progress,
"status": task.message or task.status,
"raw_status": task.status,
"frame_count": (task.result or {}).get("frames_extracted", 0),
"frame_count": result.get("frames_extracted", result.get("processed_frame_count", 0)),
"error": task.error,
"updated_at": _iso_or_none(task.updated_at),
}

View File

@@ -21,7 +21,7 @@ from statuses import (
TASK_STATUS_FAILED,
TASK_STATUS_QUEUED,
)
from worker_tasks import parse_project_media
from worker_tasks import parse_project_media, propagate_project_masks
router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
logger = logging.getLogger(__name__)
@@ -109,7 +109,8 @@ def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
project = db.query(Project).filter(Project.id == previous.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
if not project.video_path:
is_propagation_task = previous.task_type == "propagate_masks"
if not is_propagation_task and not project.video_path:
raise HTTPException(status_code=400, detail="Project has no media uploaded")
payload = dict(previous.payload or {})
@@ -124,13 +125,14 @@ def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
project_id=project.id,
payload=payload,
)
project.status = PROJECT_STATUS_PARSING
if not is_propagation_task:
project.status = PROJECT_STATUS_PARSING
db.add(task)
db.commit()
db.refresh(task)
publish_task_progress_event(task)
async_result = parse_project_media.delay(task.id)
async_result = propagate_project_masks.delay(task.id) if is_propagation_task else parse_project_media.delay(task.id)
task.celery_task_id = async_result.id
db.commit()
db.refresh(task)

View File

@@ -218,6 +218,8 @@ class PropagationSeed(BaseModel):
color: Optional[str] = None
class_metadata: Optional[dict[str, Any]] = None
template_id: Optional[int] = None
source_mask_id: Optional[str] = None
source_annotation_id: Optional[int] = None
class PropagateRequest(BaseModel):
@@ -240,6 +242,21 @@ class PropagateResponse(BaseModel):
annotations: list[AnnotationOut]
class PropagateTaskStep(BaseModel):
seed: PropagationSeed
direction: str = "forward"
max_frames: int = 30
class PropagateTaskRequest(BaseModel):
project_id: int
frame_id: int
model: Optional[str] = "sam2.1_hiera_tiny"
steps: list[PropagateTaskStep]
include_source: bool = False
save_annotations: bool = True
class AiModelStatus(BaseModel):
id: str
label: str

View File

@@ -0,0 +1,512 @@
"""Background SAM video propagation runner used by Celery workers."""
import hashlib
import json
import logging
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from sqlalchemy.orm import Session
from minio_client import download_file
from models import Annotation, Frame, ProcessingTask, Project
from progress_events import publish_task_progress_event
from services.sam_registry import ModelUnavailableError, sam_registry
from statuses import (
TASK_STATUS_CANCELLED,
TASK_STATUS_FAILED,
TASK_STATUS_RUNNING,
TASK_STATUS_SUCCESS,
)
logger = logging.getLogger(__name__)
class PropagationTaskCancelled(RuntimeError):
"""Raised internally when a persisted propagation task has been cancelled."""
def _now() -> datetime:
return datetime.now(timezone.utc)
def _set_task_state(
db: Session,
task: ProcessingTask,
*,
status: str | None = None,
progress: int | None = None,
message: str | None = None,
result: dict[str, Any] | None = None,
error: str | None = None,
started: bool = False,
finished: bool = False,
) -> None:
if status is not None:
task.status = status
if progress is not None:
task.progress = max(0, min(100, progress))
if message is not None:
task.message = message
if result is not None:
task.result = result
if error is not None:
task.error = error
if started:
task.started_at = _now()
if finished:
task.finished_at = _now()
db.commit()
db.refresh(task)
publish_task_progress_event(task)
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
db.refresh(task)
if task.status == TASK_STATUS_CANCELLED:
raise PropagationTaskCancelled("Task was cancelled")
def _clamp01(value: float) -> float:
return min(max(float(value), 0.0), 1.0)
def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
xs = [_clamp01(point[0]) for point in polygon]
ys = [_clamp01(point[1]) for point in polygon]
left, right = min(xs), max(xs)
top, bottom = min(ys), max(ys)
return [left, top, max(right - left, 0.0), max(bottom - top, 0.0)]
def _stable_json(value: Any) -> str:
return json.dumps(value, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
def _seed_signature(seed: dict[str, Any]) -> str:
"""Return a stable signature for seed geometry and semantic attrs."""
signature_payload = {
"polygons": seed.get("polygons") or [],
"bbox": seed.get("bbox") or [],
"points": seed.get("points") or [],
"labels": seed.get("labels") or [],
"label": seed.get("label"),
"color": seed.get("color"),
"class_metadata": seed.get("class_metadata") or {},
"template_id": seed.get("template_id"),
}
return hashlib.sha256(_stable_json(signature_payload).encode("utf-8")).hexdigest()
def _seed_key(seed: dict[str, Any]) -> str:
"""Prefer stable persisted ids; fall back to semantic attrs for legacy callers."""
source_annotation_id = seed.get("source_annotation_id")
if source_annotation_id is not None:
return f"annotation:{source_annotation_id}"
source_mask_id = seed.get("source_mask_id")
if source_mask_id:
return f"mask:{source_mask_id}"
class_metadata = seed.get("class_metadata") or {}
class_id = class_metadata.get("id") or class_metadata.get("name")
return _stable_json({
"template_id": seed.get("template_id"),
"class_id": class_id,
"label": seed.get("label"),
"color": seed.get("color"),
})
def _legacy_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
"""Best-effort match for propagation annotations created before seed keys."""
class_metadata = seed.get("class_metadata") or {}
previous_class = mask_data.get("class") or {}
previous_class_id = previous_class.get("id") or previous_class.get("name")
class_id = class_metadata.get("id") or class_metadata.get("name")
return (
mask_data.get("label") == seed.get("label")
and mask_data.get("color") == seed.get("color")
and previous_class_id == class_id
)
def _is_propagation_annotation(
annotation: Annotation,
model_id: str,
source_frame: Frame,
seed_key: str,
seed: dict[str, Any],
) -> bool:
mask_data = annotation.mask_data or {}
source = str(mask_data.get("source") or "")
if source != f"{model_id}_propagation":
return False
if int(mask_data.get("propagated_from_frame_id") or 0) != int(source_frame.id):
return False
previous_seed_key = mask_data.get("propagation_seed_key")
if previous_seed_key is not None:
return previous_seed_key == seed_key
return _legacy_seed_matches(mask_data, seed)
def _direction_matches(mask_data: dict[str, Any], direction: str) -> bool:
previous_direction = mask_data.get("propagation_direction")
return previous_direction in {None, direction}
def _prepare_seed_propagation(
db: Session,
*,
payload: dict[str, Any],
model_id: str,
source_frame: Frame,
seed: dict[str, Any],
direction: str,
) -> dict[str, Any]:
seed_key = _seed_key(seed)
seed_signature = _seed_signature(seed)
previous_annotations = (
db.query(Annotation)
.filter(Annotation.project_id == int(payload["project_id"]))
.all()
)
matching = [
annotation for annotation in previous_annotations
if _is_propagation_annotation(annotation, model_id, source_frame, seed_key, seed)
and _direction_matches(annotation.mask_data or {}, direction)
]
if matching and all((annotation.mask_data or {}).get("propagation_seed_signature") == seed_signature for annotation in matching):
return {
"skip": True,
"seed_key": seed_key,
"seed_signature": seed_signature,
"deleted_annotation_count": 0,
}
deleted_count = 0
if matching:
for annotation in matching:
db.delete(annotation)
deleted_count += 1
db.commit()
return {
"skip": False,
"seed_key": seed_key,
"seed_signature": seed_signature,
"deleted_annotation_count": deleted_count,
}
def _frame_window(
frames: list[Frame],
source_position: int,
direction: str,
max_frames: int,
) -> tuple[list[Frame], int]:
count = max(1, min(max_frames, len(frames)))
if direction == "backward":
start = max(0, source_position - count + 1)
return frames[start:source_position + 1], source_position - start
end = min(len(frames), source_position + count)
return frames[source_position:end], 0
def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]:
paths = []
for index, frame in enumerate(frames):
data = download_file(frame.image_url)
# SAM2VideoPredictor sorts frames by converting the filename stem to int.
path = directory / f"{index:06d}.jpg"
path.write_bytes(data)
paths.append(str(path))
return paths
def _save_propagated_annotations(
db: Session,
*,
payload: dict[str, Any],
selected_frames: list[Frame],
source_frame: Frame,
propagated: list[dict[str, Any]],
seed: dict[str, Any],
) -> list[Annotation]:
created: list[Annotation] = []
if payload.get("save_annotations", True) is False:
return created
class_metadata = seed.get("class_metadata")
template_id = seed.get("template_id")
label = seed.get("label") or "Propagated Mask"
color = seed.get("color") or "#06b6d4"
model_id = sam_registry.normalize_model_id(payload.get("model"))
include_source = bool(payload.get("include_source", False))
seed_key = _seed_key(seed)
seed_signature = _seed_signature(seed)
source_annotation_id = seed.get("source_annotation_id")
source_mask_id = seed.get("source_mask_id")
direction = str(payload.get("current_direction") or "")
for frame_result in propagated:
relative_index = int(frame_result.get("frame_index", -1))
if relative_index < 0 or relative_index >= len(selected_frames):
continue
frame = selected_frames[relative_index]
if not include_source and frame.id == source_frame.id:
continue
result_polygons = frame_result.get("polygons") or []
scores = frame_result.get("scores") or []
for polygon_index, polygon in enumerate(result_polygons):
if len(polygon) < 3:
continue
annotation = Annotation(
project_id=int(payload["project_id"]),
frame_id=frame.id,
template_id=template_id,
mask_data={
"polygons": [polygon],
"label": label,
"color": color,
"source": f"{model_id}_propagation",
"propagated_from_frame_id": source_frame.id,
"propagated_from_frame_index": source_frame.frame_index,
"propagation_seed_key": seed_key,
"propagation_seed_signature": seed_signature,
"propagation_direction": direction,
"source_annotation_id": source_annotation_id,
"source_mask_id": source_mask_id,
"score": scores[polygon_index] if polygon_index < len(scores) else None,
**({"class": class_metadata} if class_metadata else {}),
},
points=None,
bbox=_polygon_bbox(polygon),
)
db.add(annotation)
created.append(annotation)
db.commit()
for annotation in created:
db.refresh(annotation)
return created
def _run_one_step(
db: Session,
*,
payload: dict[str, Any],
frames: list[Frame],
source_frame: Frame,
source_position: int,
step: dict[str, Any],
) -> dict[str, Any]:
direction = str(step.get("direction") or "forward").lower()
if direction not in {"forward", "backward"}:
raise ValueError("direction must be forward or backward")
max_frames = max(1, min(int(step.get("max_frames") or payload.get("max_frames") or 30), 500))
seed = step.get("seed") or {}
if not (seed.get("polygons") or seed.get("bbox") or seed.get("points")):
raise ValueError("Propagation requires seed polygons, bbox, or points")
model_id = sam_registry.normalize_model_id(payload.get("model"))
seed_state = _prepare_seed_propagation(
db,
payload=payload,
model_id=model_id,
source_frame=source_frame,
seed=seed,
direction=direction,
)
if seed_state["skip"]:
return {
"model": model_id,
"direction": direction,
"processed_frame_count": 0,
"created_annotation_count": 0,
"deleted_annotation_count": 0,
"skipped_seed_count": 1,
"seed_label": seed.get("label"),
"seed_key": seed_state["seed_key"],
}
selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames)
with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload['project_id']}_") as tmpdir:
frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir))
propagated = sam_registry.propagate_video(
model_id,
frame_paths,
source_relative_index,
seed,
direction,
len(selected_frames),
)
save_payload = {**payload, "current_direction": direction}
created = _save_propagated_annotations(
db,
payload=save_payload,
selected_frames=selected_frames,
source_frame=source_frame,
propagated=propagated,
seed=seed,
)
return {
"model": model_id,
"direction": direction,
"processed_frame_count": len(selected_frames),
"created_annotation_count": len(created),
"deleted_annotation_count": int(seed_state["deleted_annotation_count"]),
"skipped_seed_count": 0,
"seed_label": seed.get("label"),
"seed_key": seed_state["seed_key"],
}
def run_propagate_project_task(db: Session, task_id: int) -> dict[str, Any]:
"""Run one queued SAM propagation task and update persisted progress."""
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
if not task:
raise ValueError(f"Task not found: {task_id}")
if task.status == TASK_STATUS_CANCELLED:
return {"task_id": task.id, "status": TASK_STATUS_CANCELLED, "message": task.message or "任务已取消"}
payload = task.payload or {}
project_id = int(payload.get("project_id") or task.project_id or 0)
source_frame_id = int(payload.get("frame_id") or 0)
try:
model_id = sam_registry.normalize_model_id(payload.get("model"))
except ValueError as exc:
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
raise
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="项目不存在", error="Project not found", finished=True)
raise ValueError(f"Project not found: {project_id}")
source_frame = db.query(Frame).filter(Frame.id == source_frame_id, Frame.project_id == project_id).first()
if not source_frame:
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="参考帧不存在", error="Frame not found", finished=True)
raise ValueError(f"Frame not found: {source_frame_id}")
frames = db.query(Frame).filter(Frame.project_id == project_id).order_by(Frame.frame_index).all()
source_position = next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None)
if source_position is None:
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="参考帧不在项目帧序列中", error="Source frame is not in project frame sequence", finished=True)
raise ValueError("Source frame is not in project frame sequence")
steps = payload.get("steps") or []
if not steps:
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="传播任务缺少步骤", error="Propagation task has no steps", finished=True)
raise ValueError("Propagation task has no steps")
_ensure_not_cancelled(db, task)
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="自动传播任务已启动", started=True)
step_results: list[dict[str, Any]] = []
created_count = 0
processed_count = 0
deleted_count = 0
skipped_count = 0
total_steps = len(steps)
try:
for index, step in enumerate(steps, start=1):
_ensure_not_cancelled(db, task)
seed_label = (step.get("seed") or {}).get("label") or "mask"
direction_label = "向前传播" if step.get("direction") == "backward" else "向后传播"
progress_before = 5 + int(((index - 1) / total_steps) * 90)
_set_task_state(
db,
task,
progress=progress_before,
message=f"{direction_label} {seed_label} ({index}/{total_steps})",
result={
"project_id": project_id,
"source_frame_id": source_frame_id,
"model": model_id,
"total_steps": total_steps,
"completed_steps": index - 1,
"processed_frame_count": processed_count,
"created_annotation_count": created_count,
"deleted_annotation_count": deleted_count,
"skipped_seed_count": skipped_count,
"steps": step_results,
},
)
result = _run_one_step(
db,
payload=payload,
frames=frames,
source_frame=source_frame,
source_position=source_position,
step=step,
)
step_results.append(result)
created_count += int(result["created_annotation_count"])
processed_count += int(result["processed_frame_count"])
deleted_count += int(result.get("deleted_annotation_count") or 0)
skipped_count += int(result.get("skipped_seed_count") or 0)
_set_task_state(
db,
task,
progress=5 + int((index / total_steps) * 90),
message=f"{direction_label} {seed_label} 完成 ({index}/{total_steps})",
result={
"project_id": project_id,
"source_frame_id": source_frame_id,
"model": model_id,
"total_steps": total_steps,
"completed_steps": index,
"processed_frame_count": processed_count,
"created_annotation_count": created_count,
"deleted_annotation_count": deleted_count,
"skipped_seed_count": skipped_count,
"steps": step_results,
},
)
result = {
"project_id": project_id,
"source_frame_id": source_frame_id,
"model": model_id,
"total_steps": total_steps,
"completed_steps": total_steps,
"processed_frame_count": processed_count,
"created_annotation_count": created_count,
"deleted_annotation_count": deleted_count,
"skipped_seed_count": skipped_count,
"steps": step_results,
}
_set_task_state(
db,
task,
status=TASK_STATUS_SUCCESS,
progress=100,
message="自动传播完成" if created_count > 0 else (
"自动传播完成,未改变的 mask 已跳过" if skipped_count > 0 else "自动传播完成,但没有生成新的 mask"
),
result=result,
finished=True,
)
return result
except PropagationTaskCancelled:
task.status = TASK_STATUS_CANCELLED
task.progress = 100
task.message = task.message or "任务已取消"
task.error = task.error or "Cancelled by user"
task.finished_at = task.finished_at or _now()
db.commit()
db.refresh(task)
publish_task_progress_event(task)
return {"task_id": task.id, "project_id": project_id, "status": TASK_STATUS_CANCELLED, "message": task.message}
except (ModelUnavailableError, NotImplementedError, ValueError) as exc:
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
raise
except Exception as exc: # noqa: BLE001
logger.exception("Propagation task failed: task_id=%s", task.id)
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
raise

View File

@@ -1,5 +1,8 @@
import numpy as np
import cv2
from pathlib import Path
from models import Annotation, ProcessingTask
from services.propagation_task_runner import run_propagate_project_task
def _create_project_and_frame(client):
@@ -294,6 +297,245 @@ def test_propagate_saves_tracked_annotations(client, monkeypatch):
assert len(listing.json()) == 1
def test_queue_propagation_task_creates_processing_task(client, monkeypatch):
project = client.post("/api/projects", json={"name": "Queued Propagation"}).json()
frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": 0,
"image_url": "frames/0.jpg",
"width": 640,
"height": 360,
}).json()
class FakeAsyncResult:
id = "celery-propagate-1"
queued = []
monkeypatch.setattr("routers.ai.propagate_project_masks.delay", lambda task_id: queued.append(task_id) or FakeAsyncResult())
monkeypatch.setattr("routers.ai.publish_task_progress_event", lambda task: None)
response = client.post("/api/ai/propagate/task", json={
"project_id": project["id"],
"frame_id": frame["id"],
"model": "sam2.1_hiera_tiny",
"steps": [{
"direction": "forward",
"max_frames": 2,
"seed": {
"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
"label": "胆囊",
},
}],
})
assert response.status_code == 202
body = response.json()
assert body["task_type"] == "propagate_masks"
assert body["status"] == "queued"
assert body["celery_task_id"] == "celery-propagate-1"
assert body["payload"]["model"] == "sam2.1_hiera_tiny"
assert body["payload"]["steps"][0]["seed"]["label"] == "胆囊"
assert queued == [body["id"]]
def test_queue_propagation_task_normalizes_model_and_rejects_unsupported(client, monkeypatch):
project = client.post("/api/projects", json={"name": "Propagation Model"}).json()
frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": 0,
"image_url": "frames/0.jpg",
"width": 640,
"height": 360,
}).json()
class FakeAsyncResult:
id = "celery-propagate-model"
monkeypatch.setattr("routers.ai.propagate_project_masks.delay", lambda task_id: FakeAsyncResult())
monkeypatch.setattr("routers.ai.publish_task_progress_event", lambda task: None)
response = client.post("/api/ai/propagate/task", json={
"project_id": project["id"],
"frame_id": frame["id"],
"model": "sam2",
"steps": [{
"direction": "forward",
"max_frames": 2,
"seed": {
"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
},
}],
})
assert response.status_code == 202
assert response.json()["payload"]["model"] == "sam2.1_hiera_tiny"
unsupported = client.post("/api/ai/propagate/task", json={
"project_id": project["id"],
"frame_id": frame["id"],
"model": "sam3",
"steps": [{
"direction": "forward",
"max_frames": 2,
"seed": {
"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
},
}],
})
assert unsupported.status_code == 400
assert "Unsupported model" in unsupported.json()["detail"]
def test_propagation_task_runner_saves_annotations_and_progress(client, db_session, monkeypatch):
project = client.post("/api/projects", json={"name": "Propagation Worker"}).json()
frames = [
client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": idx,
"image_url": f"frames/{idx}.jpg",
"width": 640,
"height": 360,
}).json()
for idx in range(2)
]
task = ProcessingTask(
task_type="propagate_masks",
status="queued",
progress=0,
project_id=project["id"],
payload={
"project_id": project["id"],
"frame_id": frames[0]["id"],
"model": "sam2.1_hiera_tiny",
"include_source": False,
"save_annotations": True,
"steps": [{
"direction": "forward",
"max_frames": 2,
"seed": {
"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
"label": "胆囊",
"color": "#ff0000",
"class_metadata": {"id": "c1", "name": "胆囊"},
},
}],
},
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
published = []
monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg")
monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: published.append((event_task.status, event_task.progress)))
def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
assert [Path(path).name for path in frame_paths] == ["000000.jpg", "000001.jpg"]
return [
{"frame_index": 0, "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]], "scores": [0.9]},
{"frame_index": 1, "polygons": [[[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]], "scores": [0.8]},
]
monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video)
result = run_propagate_project_task(db_session, task.id)
db_session.refresh(task)
assert task.status == "success"
assert task.progress == 100
assert task.result["model"] == "sam2.1_hiera_tiny"
assert task.result["steps"][0]["model"] == "sam2.1_hiera_tiny"
assert result["created_annotation_count"] == 1
assert result["processed_frame_count"] == 2
assert published[0][0] == "running"
assert published[-1] == ("success", 100)
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
assert listing.json()[0]["frame_id"] == frames[1]["id"]
assert listing.json()[0]["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation"
def test_propagation_task_runner_skips_unchanged_seed_and_replaces_changed_seed(client, db_session, monkeypatch):
project = client.post("/api/projects", json={"name": "Propagation Dedupe"}).json()
frames = [
client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": idx,
"image_url": f"frames/{idx}.jpg",
"width": 640,
"height": 360,
}).json()
for idx in range(2)
]
def make_task(seed_polygon):
task = ProcessingTask(
task_type="propagate_masks",
status="queued",
progress=0,
project_id=project["id"],
payload={
"project_id": project["id"],
"frame_id": frames[0]["id"],
"model": "sam2.1_hiera_tiny",
"include_source": False,
"save_annotations": True,
"steps": [{
"direction": "forward",
"max_frames": 2,
"seed": {
"polygons": [seed_polygon],
"label": "胆囊",
"color": "#ff0000",
"source_annotation_id": 7,
"source_mask_id": "annotation-7",
},
}],
},
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
return task
seed_polygon = [[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]
first_output_polygon = [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]
changed_seed_polygon = [[0.2, 0.2], [0.3, 0.2], [0.3, 0.3]]
replacement_output_polygon = [[0.22, 0.22], [0.32, 0.22], [0.32, 0.32]]
monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg")
monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: None)
propagate_calls = []
def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
propagate_calls.append(seed["polygons"][0])
output_polygon = replacement_output_polygon if seed["polygons"][0] == changed_seed_polygon else first_output_polygon
return [
{"frame_index": 0, "polygons": [seed["polygons"][0]], "scores": [0.9]},
{"frame_index": 1, "polygons": [output_polygon], "scores": [0.8]},
]
monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video)
first_result = run_propagate_project_task(db_session, make_task(seed_polygon).id)
assert first_result["created_annotation_count"] == 1
assert len(propagate_calls) == 1
unchanged_result = run_propagate_project_task(db_session, make_task(seed_polygon).id)
assert unchanged_result["created_annotation_count"] == 0
assert unchanged_result["skipped_seed_count"] == 1
assert len(propagate_calls) == 1
assert db_session.query(Annotation).filter(Annotation.project_id == project["id"]).count() == 1
changed_result = run_propagate_project_task(db_session, make_task(changed_seed_polygon).id)
assert changed_result["created_annotation_count"] == 1
assert changed_result["deleted_annotation_count"] == 1
assert len(propagate_calls) == 2
annotations = db_session.query(Annotation).filter(Annotation.project_id == project["id"]).all()
assert len(annotations) == 1
assert annotations[0].mask_data["polygons"] == [replacement_output_polygon]
assert annotations[0].mask_data["source_annotation_id"] == 7
def test_predict_validation_errors(client, monkeypatch):
project, _, _ = _create_project_and_frame(client)

View File

@@ -110,3 +110,31 @@ def test_dashboard_overview_keeps_recent_success_tasks_in_progress_list(client,
"updated_at": body["tasks"][0]["updated_at"],
},
]
def test_dashboard_overview_uses_processed_frame_count_for_propagation_tasks(client, db_session):
from models import ProcessingTask
project = client.post("/api/projects", json={
"name": "Propagation Project",
"status": "ready",
}).json()
task = ProcessingTask(
task_type="propagate_masks",
status="running",
progress=45,
message="向后传播 胆囊 (1/2)",
project_id=project["id"],
payload={"project_id": project["id"]},
result={"processed_frame_count": 8, "created_annotation_count": 3},
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
response = client.get("/api/dashboard/overview")
assert response.status_code == 200
body = response.json()
assert body["tasks"][0]["task_id"] == task.id
assert body["tasks"][0]["frame_count"] == 8

View File

@@ -84,6 +84,42 @@ def test_retry_task_creates_fresh_parse_task(client, db_session, monkeypatch):
assert client.get(f"/api/projects/{project['id']}").json()["status"] == "parsing"
def test_retry_task_dispatches_propagation_worker_without_media_requirement(client, db_session, monkeypatch):
project = client.post("/api/projects", json={"name": "Retry Propagation"}).json()
task = ProcessingTask(
task_type="propagate_masks",
status="failed",
progress=100,
message="自动传播失败",
error="model unavailable",
project_id=project["id"],
payload={
"project_id": project["id"],
"frame_id": 1,
"steps": [],
},
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
class FakeAsyncResult:
id = "celery-propagation-retry"
queued = []
monkeypatch.setattr("routers.tasks.propagate_project_masks.delay", lambda task_id: queued.append(task_id) or FakeAsyncResult())
monkeypatch.setattr("routers.tasks.publish_task_progress_event", lambda event_task: None)
response = client.post(f"/api/tasks/{task.id}/retry")
assert response.status_code == 202
body = response.json()
assert body["task_type"] == "propagate_masks"
assert body["celery_task_id"] == "celery-propagation-retry"
assert queued == [body["id"]]
assert client.get(f"/api/projects/{project['id']}").json()["status"] == "pending"
def test_task_actions_reject_invalid_states(client, db_session):
project = client.post("/api/projects", json={
"name": "Done",

View File

@@ -5,6 +5,7 @@ import logging
from celery_app import celery_app
from database import SessionLocal
from services.media_task_runner import run_parse_media_task
from services.propagation_task_runner import run_propagate_project_task
logger = logging.getLogger(__name__)
@@ -20,3 +21,16 @@ def parse_project_media(task_id: int) -> dict:
raise exc
finally:
db.close()
@celery_app.task(name="ai.propagate_project")
def propagate_project_masks(task_id: int) -> dict:
"""Run SAM video propagation for one queued task."""
db = SessionLocal()
try:
return run_propagate_project_task(db, task_id)
except Exception as exc: # noqa: BLE001
logger.exception("Propagation task failed: task_id=%s", task_id)
raise exc
finally:
db.close()