feat: 完善分割工作区传播与交互闭环
功能增加:新增后端传播任务执行器,支持异步自动传播、传播进度、结果统计、取消/重试状态同步。 功能增加:传播请求支持指定 SAM2.1 tiny/small/base+/large 权重,并记录 seed mask、source annotation 和传播范围。 功能增加:传播逻辑增加 seed 签名,未变化的 mask 二次传播会跳过,已变化的 mask 会先清理旧自动传播结果再重新生成,避免重复重叠。 功能增加:工作区增加传播范围二次选择、传播进度提示、人工/AI 标注帧红色标识、自动传播帧蓝色标识和当前帧双层边框。 功能增加:新增临时提示组件,让工具操作提示自动消失且不阻塞后续操作。 功能增加:补充项目删除、模板删除、任务失败详情、任务取消/重试等前后端联动状态。 功能增加:新增安装部署文档,补充当前需求冻结、设计冻结、接口契约、测试计划和 AGENTS/README 项目说明。 Bugfix:修复自动传播接口 404、传播后看不到任务进度、传播结果重复堆叠和已编辑帧提示不清晰的问题。 Bugfix:修复 AI 分割框选/点选交互、单候选 mask、删除选点、工作区保存与候选 mask 推送相关问题。 Bugfix:修复 Canvas 多边形顶点拖动告警、工具栏提示缺失、项目库 FPS 展示和若干 UI 文案/可用性问题。 测试:补充 AI 分割、Canvas、Dashboard、FrameTimeline、ProjectLibrary、TemplateRegistry、ToolsPalette、VideoWorkspace、API 和后端任务/AI/dashboard 测试。 验证:npm run lint;npm run test:run;python -m pytest backend/tests -q。
This commit is contained in:
@@ -12,7 +12,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import download_file
|
||||
from models import Project, Frame, Template, Annotation
|
||||
from models import Project, Frame, Template, Annotation, ProcessingTask
|
||||
from schemas import (
|
||||
AiRuntimeStatus,
|
||||
MaskAnalysisRequest,
|
||||
@@ -21,10 +21,15 @@ from schemas import (
|
||||
PredictResponse,
|
||||
PropagateRequest,
|
||||
PropagateResponse,
|
||||
PropagateTaskRequest,
|
||||
ProcessingTaskOut,
|
||||
AnnotationOut,
|
||||
AnnotationCreate,
|
||||
AnnotationUpdate,
|
||||
)
|
||||
from progress_events import publish_task_progress_event
|
||||
from statuses import TASK_STATUS_QUEUED
|
||||
from worker_tasks import propagate_project_masks
|
||||
from services.sam_registry import ModelUnavailableError, sam_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -586,6 +591,66 @@ def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/propagate/task",
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
response_model=ProcessingTaskOut,
|
||||
summary="Queue a background video propagation task",
|
||||
)
|
||||
def queue_propagate_task(payload: PropagateTaskRequest, db: Session = Depends(get_db)) -> ProcessingTaskOut:
|
||||
"""Queue multiple seed/direction propagation steps as one background task."""
|
||||
project = db.query(Project).filter(Project.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
source_frame = db.query(Frame).filter(
|
||||
Frame.id == payload.frame_id,
|
||||
Frame.project_id == payload.project_id,
|
||||
).first()
|
||||
if not source_frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
|
||||
if not payload.steps:
|
||||
raise HTTPException(status_code=400, detail="Propagation task requires at least one step")
|
||||
|
||||
try:
|
||||
model_id = sam_registry.normalize_model_id(payload.model)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
for step in payload.steps:
|
||||
direction = step.direction.lower()
|
||||
if direction not in {"forward", "backward"}:
|
||||
raise HTTPException(status_code=400, detail="direction must be forward or backward")
|
||||
seed = step.seed.model_dump(exclude_none=True)
|
||||
if not (seed.get("polygons") or seed.get("bbox") or seed.get("points")):
|
||||
raise HTTPException(status_code=400, detail="Propagation requires seed polygons, bbox, or points")
|
||||
|
||||
task_payload = payload.model_dump(exclude_none=True)
|
||||
task_payload["model"] = model_id
|
||||
task = ProcessingTask(
|
||||
task_type="propagate_masks",
|
||||
status=TASK_STATUS_QUEUED,
|
||||
progress=0,
|
||||
message="自动传播任务已入队",
|
||||
project_id=payload.project_id,
|
||||
payload=task_payload,
|
||||
)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
|
||||
async_result = propagate_project_masks.delay(task.id)
|
||||
task.celery_task_id = async_result.id
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
|
||||
logger.info("Queued propagation task id=%s project_id=%s celery_id=%s", task.id, payload.project_id, async_result.id)
|
||||
return task
|
||||
|
||||
|
||||
@router.post(
|
||||
"/auto",
|
||||
response_model=PredictResponse,
|
||||
|
||||
@@ -36,6 +36,7 @@ def _iso_or_none(value: datetime | None) -> str | None:
|
||||
|
||||
|
||||
def _task_payload(task: ProcessingTask) -> dict[str, Any]:
|
||||
result = task.result or {}
|
||||
return {
|
||||
"id": f"task-{task.id}",
|
||||
"task_id": task.id,
|
||||
@@ -44,7 +45,7 @@ def _task_payload(task: ProcessingTask) -> dict[str, Any]:
|
||||
"progress": task.progress,
|
||||
"status": task.message or task.status,
|
||||
"raw_status": task.status,
|
||||
"frame_count": (task.result or {}).get("frames_extracted", 0),
|
||||
"frame_count": result.get("frames_extracted", result.get("processed_frame_count", 0)),
|
||||
"error": task.error,
|
||||
"updated_at": _iso_or_none(task.updated_at),
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ from statuses import (
|
||||
TASK_STATUS_FAILED,
|
||||
TASK_STATUS_QUEUED,
|
||||
)
|
||||
from worker_tasks import parse_project_media
|
||||
from worker_tasks import parse_project_media, propagate_project_masks
|
||||
|
||||
router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -109,7 +109,8 @@ def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
project = db.query(Project).filter(Project.id == previous.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
if not project.video_path:
|
||||
is_propagation_task = previous.task_type == "propagate_masks"
|
||||
if not is_propagation_task and not project.video_path:
|
||||
raise HTTPException(status_code=400, detail="Project has no media uploaded")
|
||||
|
||||
payload = dict(previous.payload or {})
|
||||
@@ -124,13 +125,14 @@ def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
project_id=project.id,
|
||||
payload=payload,
|
||||
)
|
||||
project.status = PROJECT_STATUS_PARSING
|
||||
if not is_propagation_task:
|
||||
project.status = PROJECT_STATUS_PARSING
|
||||
db.add(task)
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
|
||||
async_result = parse_project_media.delay(task.id)
|
||||
async_result = propagate_project_masks.delay(task.id) if is_propagation_task else parse_project_media.delay(task.id)
|
||||
task.celery_task_id = async_result.id
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
|
||||
@@ -218,6 +218,8 @@ class PropagationSeed(BaseModel):
|
||||
color: Optional[str] = None
|
||||
class_metadata: Optional[dict[str, Any]] = None
|
||||
template_id: Optional[int] = None
|
||||
source_mask_id: Optional[str] = None
|
||||
source_annotation_id: Optional[int] = None
|
||||
|
||||
|
||||
class PropagateRequest(BaseModel):
|
||||
@@ -240,6 +242,21 @@ class PropagateResponse(BaseModel):
|
||||
annotations: list[AnnotationOut]
|
||||
|
||||
|
||||
class PropagateTaskStep(BaseModel):
|
||||
seed: PropagationSeed
|
||||
direction: str = "forward"
|
||||
max_frames: int = 30
|
||||
|
||||
|
||||
class PropagateTaskRequest(BaseModel):
|
||||
project_id: int
|
||||
frame_id: int
|
||||
model: Optional[str] = "sam2.1_hiera_tiny"
|
||||
steps: list[PropagateTaskStep]
|
||||
include_source: bool = False
|
||||
save_annotations: bool = True
|
||||
|
||||
|
||||
class AiModelStatus(BaseModel):
|
||||
id: str
|
||||
label: str
|
||||
|
||||
512
backend/services/propagation_task_runner.py
Normal file
512
backend/services/propagation_task_runner.py
Normal file
@@ -0,0 +1,512 @@
|
||||
"""Background SAM video propagation runner used by Celery workers."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from minio_client import download_file
|
||||
from models import Annotation, Frame, ProcessingTask, Project
|
||||
from progress_events import publish_task_progress_event
|
||||
from services.sam_registry import ModelUnavailableError, sam_registry
|
||||
from statuses import (
|
||||
TASK_STATUS_CANCELLED,
|
||||
TASK_STATUS_FAILED,
|
||||
TASK_STATUS_RUNNING,
|
||||
TASK_STATUS_SUCCESS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PropagationTaskCancelled(RuntimeError):
|
||||
"""Raised internally when a persisted propagation task has been cancelled."""
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _set_task_state(
|
||||
db: Session,
|
||||
task: ProcessingTask,
|
||||
*,
|
||||
status: str | None = None,
|
||||
progress: int | None = None,
|
||||
message: str | None = None,
|
||||
result: dict[str, Any] | None = None,
|
||||
error: str | None = None,
|
||||
started: bool = False,
|
||||
finished: bool = False,
|
||||
) -> None:
|
||||
if status is not None:
|
||||
task.status = status
|
||||
if progress is not None:
|
||||
task.progress = max(0, min(100, progress))
|
||||
if message is not None:
|
||||
task.message = message
|
||||
if result is not None:
|
||||
task.result = result
|
||||
if error is not None:
|
||||
task.error = error
|
||||
if started:
|
||||
task.started_at = _now()
|
||||
if finished:
|
||||
task.finished_at = _now()
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
|
||||
|
||||
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
|
||||
db.refresh(task)
|
||||
if task.status == TASK_STATUS_CANCELLED:
|
||||
raise PropagationTaskCancelled("Task was cancelled")
|
||||
|
||||
|
||||
def _clamp01(value: float) -> float:
|
||||
return min(max(float(value), 0.0), 1.0)
|
||||
|
||||
|
||||
def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
|
||||
xs = [_clamp01(point[0]) for point in polygon]
|
||||
ys = [_clamp01(point[1]) for point in polygon]
|
||||
left, right = min(xs), max(xs)
|
||||
top, bottom = min(ys), max(ys)
|
||||
return [left, top, max(right - left, 0.0), max(bottom - top, 0.0)]
|
||||
|
||||
|
||||
def _stable_json(value: Any) -> str:
|
||||
return json.dumps(value, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
|
||||
|
||||
|
||||
def _seed_signature(seed: dict[str, Any]) -> str:
|
||||
"""Return a stable signature for seed geometry and semantic attrs."""
|
||||
signature_payload = {
|
||||
"polygons": seed.get("polygons") or [],
|
||||
"bbox": seed.get("bbox") or [],
|
||||
"points": seed.get("points") or [],
|
||||
"labels": seed.get("labels") or [],
|
||||
"label": seed.get("label"),
|
||||
"color": seed.get("color"),
|
||||
"class_metadata": seed.get("class_metadata") or {},
|
||||
"template_id": seed.get("template_id"),
|
||||
}
|
||||
return hashlib.sha256(_stable_json(signature_payload).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _seed_key(seed: dict[str, Any]) -> str:
|
||||
"""Prefer stable persisted ids; fall back to semantic attrs for legacy callers."""
|
||||
source_annotation_id = seed.get("source_annotation_id")
|
||||
if source_annotation_id is not None:
|
||||
return f"annotation:{source_annotation_id}"
|
||||
source_mask_id = seed.get("source_mask_id")
|
||||
if source_mask_id:
|
||||
return f"mask:{source_mask_id}"
|
||||
class_metadata = seed.get("class_metadata") or {}
|
||||
class_id = class_metadata.get("id") or class_metadata.get("name")
|
||||
return _stable_json({
|
||||
"template_id": seed.get("template_id"),
|
||||
"class_id": class_id,
|
||||
"label": seed.get("label"),
|
||||
"color": seed.get("color"),
|
||||
})
|
||||
|
||||
|
||||
def _legacy_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
|
||||
"""Best-effort match for propagation annotations created before seed keys."""
|
||||
class_metadata = seed.get("class_metadata") or {}
|
||||
previous_class = mask_data.get("class") or {}
|
||||
previous_class_id = previous_class.get("id") or previous_class.get("name")
|
||||
class_id = class_metadata.get("id") or class_metadata.get("name")
|
||||
return (
|
||||
mask_data.get("label") == seed.get("label")
|
||||
and mask_data.get("color") == seed.get("color")
|
||||
and previous_class_id == class_id
|
||||
)
|
||||
|
||||
|
||||
def _is_propagation_annotation(
|
||||
annotation: Annotation,
|
||||
model_id: str,
|
||||
source_frame: Frame,
|
||||
seed_key: str,
|
||||
seed: dict[str, Any],
|
||||
) -> bool:
|
||||
mask_data = annotation.mask_data or {}
|
||||
source = str(mask_data.get("source") or "")
|
||||
if source != f"{model_id}_propagation":
|
||||
return False
|
||||
if int(mask_data.get("propagated_from_frame_id") or 0) != int(source_frame.id):
|
||||
return False
|
||||
previous_seed_key = mask_data.get("propagation_seed_key")
|
||||
if previous_seed_key is not None:
|
||||
return previous_seed_key == seed_key
|
||||
return _legacy_seed_matches(mask_data, seed)
|
||||
|
||||
|
||||
def _direction_matches(mask_data: dict[str, Any], direction: str) -> bool:
|
||||
previous_direction = mask_data.get("propagation_direction")
|
||||
return previous_direction in {None, direction}
|
||||
|
||||
|
||||
def _prepare_seed_propagation(
|
||||
db: Session,
|
||||
*,
|
||||
payload: dict[str, Any],
|
||||
model_id: str,
|
||||
source_frame: Frame,
|
||||
seed: dict[str, Any],
|
||||
direction: str,
|
||||
) -> dict[str, Any]:
|
||||
seed_key = _seed_key(seed)
|
||||
seed_signature = _seed_signature(seed)
|
||||
previous_annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == int(payload["project_id"]))
|
||||
.all()
|
||||
)
|
||||
matching = [
|
||||
annotation for annotation in previous_annotations
|
||||
if _is_propagation_annotation(annotation, model_id, source_frame, seed_key, seed)
|
||||
and _direction_matches(annotation.mask_data or {}, direction)
|
||||
]
|
||||
if matching and all((annotation.mask_data or {}).get("propagation_seed_signature") == seed_signature for annotation in matching):
|
||||
return {
|
||||
"skip": True,
|
||||
"seed_key": seed_key,
|
||||
"seed_signature": seed_signature,
|
||||
"deleted_annotation_count": 0,
|
||||
}
|
||||
|
||||
deleted_count = 0
|
||||
if matching:
|
||||
for annotation in matching:
|
||||
db.delete(annotation)
|
||||
deleted_count += 1
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"skip": False,
|
||||
"seed_key": seed_key,
|
||||
"seed_signature": seed_signature,
|
||||
"deleted_annotation_count": deleted_count,
|
||||
}
|
||||
|
||||
|
||||
def _frame_window(
|
||||
frames: list[Frame],
|
||||
source_position: int,
|
||||
direction: str,
|
||||
max_frames: int,
|
||||
) -> tuple[list[Frame], int]:
|
||||
count = max(1, min(max_frames, len(frames)))
|
||||
if direction == "backward":
|
||||
start = max(0, source_position - count + 1)
|
||||
return frames[start:source_position + 1], source_position - start
|
||||
end = min(len(frames), source_position + count)
|
||||
return frames[source_position:end], 0
|
||||
|
||||
|
||||
def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]:
|
||||
paths = []
|
||||
for index, frame in enumerate(frames):
|
||||
data = download_file(frame.image_url)
|
||||
# SAM2VideoPredictor sorts frames by converting the filename stem to int.
|
||||
path = directory / f"{index:06d}.jpg"
|
||||
path.write_bytes(data)
|
||||
paths.append(str(path))
|
||||
return paths
|
||||
|
||||
|
||||
def _save_propagated_annotations(
|
||||
db: Session,
|
||||
*,
|
||||
payload: dict[str, Any],
|
||||
selected_frames: list[Frame],
|
||||
source_frame: Frame,
|
||||
propagated: list[dict[str, Any]],
|
||||
seed: dict[str, Any],
|
||||
) -> list[Annotation]:
|
||||
created: list[Annotation] = []
|
||||
if payload.get("save_annotations", True) is False:
|
||||
return created
|
||||
|
||||
class_metadata = seed.get("class_metadata")
|
||||
template_id = seed.get("template_id")
|
||||
label = seed.get("label") or "Propagated Mask"
|
||||
color = seed.get("color") or "#06b6d4"
|
||||
model_id = sam_registry.normalize_model_id(payload.get("model"))
|
||||
include_source = bool(payload.get("include_source", False))
|
||||
seed_key = _seed_key(seed)
|
||||
seed_signature = _seed_signature(seed)
|
||||
source_annotation_id = seed.get("source_annotation_id")
|
||||
source_mask_id = seed.get("source_mask_id")
|
||||
direction = str(payload.get("current_direction") or "")
|
||||
|
||||
for frame_result in propagated:
|
||||
relative_index = int(frame_result.get("frame_index", -1))
|
||||
if relative_index < 0 or relative_index >= len(selected_frames):
|
||||
continue
|
||||
frame = selected_frames[relative_index]
|
||||
if not include_source and frame.id == source_frame.id:
|
||||
continue
|
||||
result_polygons = frame_result.get("polygons") or []
|
||||
scores = frame_result.get("scores") or []
|
||||
for polygon_index, polygon in enumerate(result_polygons):
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
annotation = Annotation(
|
||||
project_id=int(payload["project_id"]),
|
||||
frame_id=frame.id,
|
||||
template_id=template_id,
|
||||
mask_data={
|
||||
"polygons": [polygon],
|
||||
"label": label,
|
||||
"color": color,
|
||||
"source": f"{model_id}_propagation",
|
||||
"propagated_from_frame_id": source_frame.id,
|
||||
"propagated_from_frame_index": source_frame.frame_index,
|
||||
"propagation_seed_key": seed_key,
|
||||
"propagation_seed_signature": seed_signature,
|
||||
"propagation_direction": direction,
|
||||
"source_annotation_id": source_annotation_id,
|
||||
"source_mask_id": source_mask_id,
|
||||
"score": scores[polygon_index] if polygon_index < len(scores) else None,
|
||||
**({"class": class_metadata} if class_metadata else {}),
|
||||
},
|
||||
points=None,
|
||||
bbox=_polygon_bbox(polygon),
|
||||
)
|
||||
db.add(annotation)
|
||||
created.append(annotation)
|
||||
|
||||
db.commit()
|
||||
for annotation in created:
|
||||
db.refresh(annotation)
|
||||
return created
|
||||
|
||||
|
||||
def _run_one_step(
|
||||
db: Session,
|
||||
*,
|
||||
payload: dict[str, Any],
|
||||
frames: list[Frame],
|
||||
source_frame: Frame,
|
||||
source_position: int,
|
||||
step: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
direction = str(step.get("direction") or "forward").lower()
|
||||
if direction not in {"forward", "backward"}:
|
||||
raise ValueError("direction must be forward or backward")
|
||||
max_frames = max(1, min(int(step.get("max_frames") or payload.get("max_frames") or 30), 500))
|
||||
seed = step.get("seed") or {}
|
||||
if not (seed.get("polygons") or seed.get("bbox") or seed.get("points")):
|
||||
raise ValueError("Propagation requires seed polygons, bbox, or points")
|
||||
|
||||
model_id = sam_registry.normalize_model_id(payload.get("model"))
|
||||
seed_state = _prepare_seed_propagation(
|
||||
db,
|
||||
payload=payload,
|
||||
model_id=model_id,
|
||||
source_frame=source_frame,
|
||||
seed=seed,
|
||||
direction=direction,
|
||||
)
|
||||
if seed_state["skip"]:
|
||||
return {
|
||||
"model": model_id,
|
||||
"direction": direction,
|
||||
"processed_frame_count": 0,
|
||||
"created_annotation_count": 0,
|
||||
"deleted_annotation_count": 0,
|
||||
"skipped_seed_count": 1,
|
||||
"seed_label": seed.get("label"),
|
||||
"seed_key": seed_state["seed_key"],
|
||||
}
|
||||
|
||||
selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames)
|
||||
with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload['project_id']}_") as tmpdir:
|
||||
frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir))
|
||||
propagated = sam_registry.propagate_video(
|
||||
model_id,
|
||||
frame_paths,
|
||||
source_relative_index,
|
||||
seed,
|
||||
direction,
|
||||
len(selected_frames),
|
||||
)
|
||||
|
||||
save_payload = {**payload, "current_direction": direction}
|
||||
created = _save_propagated_annotations(
|
||||
db,
|
||||
payload=save_payload,
|
||||
selected_frames=selected_frames,
|
||||
source_frame=source_frame,
|
||||
propagated=propagated,
|
||||
seed=seed,
|
||||
)
|
||||
return {
|
||||
"model": model_id,
|
||||
"direction": direction,
|
||||
"processed_frame_count": len(selected_frames),
|
||||
"created_annotation_count": len(created),
|
||||
"deleted_annotation_count": int(seed_state["deleted_annotation_count"]),
|
||||
"skipped_seed_count": 0,
|
||||
"seed_label": seed.get("label"),
|
||||
"seed_key": seed_state["seed_key"],
|
||||
}
|
||||
|
||||
|
||||
def run_propagate_project_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
"""Run one queued SAM propagation task and update persisted progress."""
|
||||
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
|
||||
if not task:
|
||||
raise ValueError(f"Task not found: {task_id}")
|
||||
|
||||
if task.status == TASK_STATUS_CANCELLED:
|
||||
return {"task_id": task.id, "status": TASK_STATUS_CANCELLED, "message": task.message or "任务已取消"}
|
||||
|
||||
payload = task.payload or {}
|
||||
project_id = int(payload.get("project_id") or task.project_id or 0)
|
||||
source_frame_id = int(payload.get("frame_id") or 0)
|
||||
try:
|
||||
model_id = sam_registry.normalize_model_id(payload.get("model"))
|
||||
except ValueError as exc:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
|
||||
raise
|
||||
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="项目不存在", error="Project not found", finished=True)
|
||||
raise ValueError(f"Project not found: {project_id}")
|
||||
|
||||
source_frame = db.query(Frame).filter(Frame.id == source_frame_id, Frame.project_id == project_id).first()
|
||||
if not source_frame:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="参考帧不存在", error="Frame not found", finished=True)
|
||||
raise ValueError(f"Frame not found: {source_frame_id}")
|
||||
|
||||
frames = db.query(Frame).filter(Frame.project_id == project_id).order_by(Frame.frame_index).all()
|
||||
source_position = next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None)
|
||||
if source_position is None:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="参考帧不在项目帧序列中", error="Source frame is not in project frame sequence", finished=True)
|
||||
raise ValueError("Source frame is not in project frame sequence")
|
||||
|
||||
steps = payload.get("steps") or []
|
||||
if not steps:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="传播任务缺少步骤", error="Propagation task has no steps", finished=True)
|
||||
raise ValueError("Propagation task has no steps")
|
||||
|
||||
_ensure_not_cancelled(db, task)
|
||||
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="自动传播任务已启动", started=True)
|
||||
|
||||
step_results: list[dict[str, Any]] = []
|
||||
created_count = 0
|
||||
processed_count = 0
|
||||
deleted_count = 0
|
||||
skipped_count = 0
|
||||
total_steps = len(steps)
|
||||
|
||||
try:
|
||||
for index, step in enumerate(steps, start=1):
|
||||
_ensure_not_cancelled(db, task)
|
||||
seed_label = (step.get("seed") or {}).get("label") or "mask"
|
||||
direction_label = "向前传播" if step.get("direction") == "backward" else "向后传播"
|
||||
progress_before = 5 + int(((index - 1) / total_steps) * 90)
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
progress=progress_before,
|
||||
message=f"{direction_label} {seed_label} ({index}/{total_steps})",
|
||||
result={
|
||||
"project_id": project_id,
|
||||
"source_frame_id": source_frame_id,
|
||||
"model": model_id,
|
||||
"total_steps": total_steps,
|
||||
"completed_steps": index - 1,
|
||||
"processed_frame_count": processed_count,
|
||||
"created_annotation_count": created_count,
|
||||
"deleted_annotation_count": deleted_count,
|
||||
"skipped_seed_count": skipped_count,
|
||||
"steps": step_results,
|
||||
},
|
||||
)
|
||||
|
||||
result = _run_one_step(
|
||||
db,
|
||||
payload=payload,
|
||||
frames=frames,
|
||||
source_frame=source_frame,
|
||||
source_position=source_position,
|
||||
step=step,
|
||||
)
|
||||
step_results.append(result)
|
||||
created_count += int(result["created_annotation_count"])
|
||||
processed_count += int(result["processed_frame_count"])
|
||||
deleted_count += int(result.get("deleted_annotation_count") or 0)
|
||||
skipped_count += int(result.get("skipped_seed_count") or 0)
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
progress=5 + int((index / total_steps) * 90),
|
||||
message=f"{direction_label} {seed_label} 完成 ({index}/{total_steps})",
|
||||
result={
|
||||
"project_id": project_id,
|
||||
"source_frame_id": source_frame_id,
|
||||
"model": model_id,
|
||||
"total_steps": total_steps,
|
||||
"completed_steps": index,
|
||||
"processed_frame_count": processed_count,
|
||||
"created_annotation_count": created_count,
|
||||
"deleted_annotation_count": deleted_count,
|
||||
"skipped_seed_count": skipped_count,
|
||||
"steps": step_results,
|
||||
},
|
||||
)
|
||||
|
||||
result = {
|
||||
"project_id": project_id,
|
||||
"source_frame_id": source_frame_id,
|
||||
"model": model_id,
|
||||
"total_steps": total_steps,
|
||||
"completed_steps": total_steps,
|
||||
"processed_frame_count": processed_count,
|
||||
"created_annotation_count": created_count,
|
||||
"deleted_annotation_count": deleted_count,
|
||||
"skipped_seed_count": skipped_count,
|
||||
"steps": step_results,
|
||||
}
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
status=TASK_STATUS_SUCCESS,
|
||||
progress=100,
|
||||
message="自动传播完成" if created_count > 0 else (
|
||||
"自动传播完成,未改变的 mask 已跳过" if skipped_count > 0 else "自动传播完成,但没有生成新的 mask"
|
||||
),
|
||||
result=result,
|
||||
finished=True,
|
||||
)
|
||||
return result
|
||||
except PropagationTaskCancelled:
|
||||
task.status = TASK_STATUS_CANCELLED
|
||||
task.progress = 100
|
||||
task.message = task.message or "任务已取消"
|
||||
task.error = task.error or "Cancelled by user"
|
||||
task.finished_at = task.finished_at or _now()
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
return {"task_id": task.id, "project_id": project_id, "status": TASK_STATUS_CANCELLED, "message": task.message}
|
||||
except (ModelUnavailableError, NotImplementedError, ValueError) as exc:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
|
||||
raise
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception("Propagation task failed: task_id=%s", task.id)
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
|
||||
raise
|
||||
@@ -1,5 +1,8 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
from pathlib import Path
|
||||
from models import Annotation, ProcessingTask
|
||||
from services.propagation_task_runner import run_propagate_project_task
|
||||
|
||||
|
||||
def _create_project_and_frame(client):
|
||||
@@ -294,6 +297,245 @@ def test_propagate_saves_tracked_annotations(client, monkeypatch):
|
||||
assert len(listing.json()) == 1
|
||||
|
||||
|
||||
def test_queue_propagation_task_creates_processing_task(client, monkeypatch):
|
||||
project = client.post("/api/projects", json={"name": "Queued Propagation"}).json()
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/0.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
}).json()
|
||||
|
||||
class FakeAsyncResult:
|
||||
id = "celery-propagate-1"
|
||||
|
||||
queued = []
|
||||
monkeypatch.setattr("routers.ai.propagate_project_masks.delay", lambda task_id: queued.append(task_id) or FakeAsyncResult())
|
||||
monkeypatch.setattr("routers.ai.publish_task_progress_event", lambda task: None)
|
||||
|
||||
response = client.post("/api/ai/propagate/task", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"model": "sam2.1_hiera_tiny",
|
||||
"steps": [{
|
||||
"direction": "forward",
|
||||
"max_frames": 2,
|
||||
"seed": {
|
||||
"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
|
||||
"label": "胆囊",
|
||||
},
|
||||
}],
|
||||
})
|
||||
|
||||
assert response.status_code == 202
|
||||
body = response.json()
|
||||
assert body["task_type"] == "propagate_masks"
|
||||
assert body["status"] == "queued"
|
||||
assert body["celery_task_id"] == "celery-propagate-1"
|
||||
assert body["payload"]["model"] == "sam2.1_hiera_tiny"
|
||||
assert body["payload"]["steps"][0]["seed"]["label"] == "胆囊"
|
||||
assert queued == [body["id"]]
|
||||
|
||||
|
||||
def test_queue_propagation_task_normalizes_model_and_rejects_unsupported(client, monkeypatch):
|
||||
project = client.post("/api/projects", json={"name": "Propagation Model"}).json()
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/0.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
}).json()
|
||||
|
||||
class FakeAsyncResult:
|
||||
id = "celery-propagate-model"
|
||||
|
||||
monkeypatch.setattr("routers.ai.propagate_project_masks.delay", lambda task_id: FakeAsyncResult())
|
||||
monkeypatch.setattr("routers.ai.publish_task_progress_event", lambda task: None)
|
||||
|
||||
response = client.post("/api/ai/propagate/task", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"model": "sam2",
|
||||
"steps": [{
|
||||
"direction": "forward",
|
||||
"max_frames": 2,
|
||||
"seed": {
|
||||
"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
|
||||
},
|
||||
}],
|
||||
})
|
||||
|
||||
assert response.status_code == 202
|
||||
assert response.json()["payload"]["model"] == "sam2.1_hiera_tiny"
|
||||
|
||||
unsupported = client.post("/api/ai/propagate/task", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"model": "sam3",
|
||||
"steps": [{
|
||||
"direction": "forward",
|
||||
"max_frames": 2,
|
||||
"seed": {
|
||||
"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
|
||||
},
|
||||
}],
|
||||
})
|
||||
|
||||
assert unsupported.status_code == 400
|
||||
assert "Unsupported model" in unsupported.json()["detail"]
|
||||
|
||||
|
||||
def test_propagation_task_runner_saves_annotations_and_progress(client, db_session, monkeypatch):
|
||||
project = client.post("/api/projects", json={"name": "Propagation Worker"}).json()
|
||||
frames = [
|
||||
client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": idx,
|
||||
"image_url": f"frames/{idx}.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
}).json()
|
||||
for idx in range(2)
|
||||
]
|
||||
task = ProcessingTask(
|
||||
task_type="propagate_masks",
|
||||
status="queued",
|
||||
progress=0,
|
||||
project_id=project["id"],
|
||||
payload={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frames[0]["id"],
|
||||
"model": "sam2.1_hiera_tiny",
|
||||
"include_source": False,
|
||||
"save_annotations": True,
|
||||
"steps": [{
|
||||
"direction": "forward",
|
||||
"max_frames": 2,
|
||||
"seed": {
|
||||
"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
|
||||
"label": "胆囊",
|
||||
"color": "#ff0000",
|
||||
"class_metadata": {"id": "c1", "name": "胆囊"},
|
||||
},
|
||||
}],
|
||||
},
|
||||
)
|
||||
db_session.add(task)
|
||||
db_session.commit()
|
||||
db_session.refresh(task)
|
||||
|
||||
published = []
|
||||
monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg")
|
||||
monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: published.append((event_task.status, event_task.progress)))
|
||||
def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
|
||||
assert [Path(path).name for path in frame_paths] == ["000000.jpg", "000001.jpg"]
|
||||
return [
|
||||
{"frame_index": 0, "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]], "scores": [0.9]},
|
||||
{"frame_index": 1, "polygons": [[[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]], "scores": [0.8]},
|
||||
]
|
||||
|
||||
monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video)
|
||||
|
||||
result = run_propagate_project_task(db_session, task.id)
|
||||
|
||||
db_session.refresh(task)
|
||||
assert task.status == "success"
|
||||
assert task.progress == 100
|
||||
assert task.result["model"] == "sam2.1_hiera_tiny"
|
||||
assert task.result["steps"][0]["model"] == "sam2.1_hiera_tiny"
|
||||
assert result["created_annotation_count"] == 1
|
||||
assert result["processed_frame_count"] == 2
|
||||
assert published[0][0] == "running"
|
||||
assert published[-1] == ("success", 100)
|
||||
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
|
||||
assert listing.json()[0]["frame_id"] == frames[1]["id"]
|
||||
assert listing.json()[0]["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation"
|
||||
|
||||
|
||||
def test_propagation_task_runner_skips_unchanged_seed_and_replaces_changed_seed(client, db_session, monkeypatch):
|
||||
project = client.post("/api/projects", json={"name": "Propagation Dedupe"}).json()
|
||||
frames = [
|
||||
client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": idx,
|
||||
"image_url": f"frames/{idx}.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
}).json()
|
||||
for idx in range(2)
|
||||
]
|
||||
|
||||
def make_task(seed_polygon):
|
||||
task = ProcessingTask(
|
||||
task_type="propagate_masks",
|
||||
status="queued",
|
||||
progress=0,
|
||||
project_id=project["id"],
|
||||
payload={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frames[0]["id"],
|
||||
"model": "sam2.1_hiera_tiny",
|
||||
"include_source": False,
|
||||
"save_annotations": True,
|
||||
"steps": [{
|
||||
"direction": "forward",
|
||||
"max_frames": 2,
|
||||
"seed": {
|
||||
"polygons": [seed_polygon],
|
||||
"label": "胆囊",
|
||||
"color": "#ff0000",
|
||||
"source_annotation_id": 7,
|
||||
"source_mask_id": "annotation-7",
|
||||
},
|
||||
}],
|
||||
},
|
||||
)
|
||||
db_session.add(task)
|
||||
db_session.commit()
|
||||
db_session.refresh(task)
|
||||
return task
|
||||
|
||||
seed_polygon = [[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]
|
||||
first_output_polygon = [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]
|
||||
changed_seed_polygon = [[0.2, 0.2], [0.3, 0.2], [0.3, 0.3]]
|
||||
replacement_output_polygon = [[0.22, 0.22], [0.32, 0.22], [0.32, 0.32]]
|
||||
|
||||
monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg")
|
||||
monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: None)
|
||||
propagate_calls = []
|
||||
|
||||
def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
|
||||
propagate_calls.append(seed["polygons"][0])
|
||||
output_polygon = replacement_output_polygon if seed["polygons"][0] == changed_seed_polygon else first_output_polygon
|
||||
return [
|
||||
{"frame_index": 0, "polygons": [seed["polygons"][0]], "scores": [0.9]},
|
||||
{"frame_index": 1, "polygons": [output_polygon], "scores": [0.8]},
|
||||
]
|
||||
|
||||
monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video)
|
||||
|
||||
first_result = run_propagate_project_task(db_session, make_task(seed_polygon).id)
|
||||
assert first_result["created_annotation_count"] == 1
|
||||
assert len(propagate_calls) == 1
|
||||
|
||||
unchanged_result = run_propagate_project_task(db_session, make_task(seed_polygon).id)
|
||||
assert unchanged_result["created_annotation_count"] == 0
|
||||
assert unchanged_result["skipped_seed_count"] == 1
|
||||
assert len(propagate_calls) == 1
|
||||
assert db_session.query(Annotation).filter(Annotation.project_id == project["id"]).count() == 1
|
||||
|
||||
changed_result = run_propagate_project_task(db_session, make_task(changed_seed_polygon).id)
|
||||
assert changed_result["created_annotation_count"] == 1
|
||||
assert changed_result["deleted_annotation_count"] == 1
|
||||
assert len(propagate_calls) == 2
|
||||
annotations = db_session.query(Annotation).filter(Annotation.project_id == project["id"]).all()
|
||||
assert len(annotations) == 1
|
||||
assert annotations[0].mask_data["polygons"] == [replacement_output_polygon]
|
||||
assert annotations[0].mask_data["source_annotation_id"] == 7
|
||||
|
||||
|
||||
def test_predict_validation_errors(client, monkeypatch):
|
||||
project, _, _ = _create_project_and_frame(client)
|
||||
|
||||
|
||||
@@ -110,3 +110,31 @@ def test_dashboard_overview_keeps_recent_success_tasks_in_progress_list(client,
|
||||
"updated_at": body["tasks"][0]["updated_at"],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_dashboard_overview_uses_processed_frame_count_for_propagation_tasks(client, db_session):
|
||||
from models import ProcessingTask
|
||||
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Propagation Project",
|
||||
"status": "ready",
|
||||
}).json()
|
||||
task = ProcessingTask(
|
||||
task_type="propagate_masks",
|
||||
status="running",
|
||||
progress=45,
|
||||
message="向后传播 胆囊 (1/2)",
|
||||
project_id=project["id"],
|
||||
payload={"project_id": project["id"]},
|
||||
result={"processed_frame_count": 8, "created_annotation_count": 3},
|
||||
)
|
||||
db_session.add(task)
|
||||
db_session.commit()
|
||||
db_session.refresh(task)
|
||||
|
||||
response = client.get("/api/dashboard/overview")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["tasks"][0]["task_id"] == task.id
|
||||
assert body["tasks"][0]["frame_count"] == 8
|
||||
|
||||
@@ -84,6 +84,42 @@ def test_retry_task_creates_fresh_parse_task(client, db_session, monkeypatch):
|
||||
assert client.get(f"/api/projects/{project['id']}").json()["status"] == "parsing"
|
||||
|
||||
|
||||
def test_retry_task_dispatches_propagation_worker_without_media_requirement(client, db_session, monkeypatch):
|
||||
project = client.post("/api/projects", json={"name": "Retry Propagation"}).json()
|
||||
task = ProcessingTask(
|
||||
task_type="propagate_masks",
|
||||
status="failed",
|
||||
progress=100,
|
||||
message="自动传播失败",
|
||||
error="model unavailable",
|
||||
project_id=project["id"],
|
||||
payload={
|
||||
"project_id": project["id"],
|
||||
"frame_id": 1,
|
||||
"steps": [],
|
||||
},
|
||||
)
|
||||
db_session.add(task)
|
||||
db_session.commit()
|
||||
db_session.refresh(task)
|
||||
|
||||
class FakeAsyncResult:
|
||||
id = "celery-propagation-retry"
|
||||
|
||||
queued = []
|
||||
monkeypatch.setattr("routers.tasks.propagate_project_masks.delay", lambda task_id: queued.append(task_id) or FakeAsyncResult())
|
||||
monkeypatch.setattr("routers.tasks.publish_task_progress_event", lambda event_task: None)
|
||||
|
||||
response = client.post(f"/api/tasks/{task.id}/retry")
|
||||
|
||||
assert response.status_code == 202
|
||||
body = response.json()
|
||||
assert body["task_type"] == "propagate_masks"
|
||||
assert body["celery_task_id"] == "celery-propagation-retry"
|
||||
assert queued == [body["id"]]
|
||||
assert client.get(f"/api/projects/{project['id']}").json()["status"] == "pending"
|
||||
|
||||
|
||||
def test_task_actions_reject_invalid_states(client, db_session):
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Done",
|
||||
|
||||
@@ -5,6 +5,7 @@ import logging
|
||||
from celery_app import celery_app
|
||||
from database import SessionLocal
|
||||
from services.media_task_runner import run_parse_media_task
|
||||
from services.propagation_task_runner import run_propagate_project_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -20,3 +21,16 @@ def parse_project_media(task_id: int) -> dict:
|
||||
raise exc
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@celery_app.task(name="ai.propagate_project")
|
||||
def propagate_project_masks(task_id: int) -> dict:
|
||||
"""Run SAM video propagation for one queued task."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
return run_propagate_project_task(db, task_id)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception("Propagation task failed: task_id=%s", task_id)
|
||||
raise exc
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
Reference in New Issue
Block a user