- 打通工作区真实标注闭环:支持手工多边形、矩形、圆形、点区域和线段生成 mask,并可保存、回显、更新和删除后端 annotation。 - 增强 polygon 编辑器:支持顶点拖动、顶点删除、边中点插入、多 polygon 子区域选择编辑,以及区域合并和区域去除。 - 接入 GT mask 导入:后端支持二值/多类别 mask 拆分、contour 转 polygon、distance transform seed point,前端支持导入、回显和 seed point 拖动编辑。 - 完善导出能力:COCO JSON 导出对齐前端,PNG mask ZIP 同时包含单标注 mask、按 zIndex 融合的 semantic_frame 和 semantic_classes.json。 - 打通异步任务管理:新增任务取消、重试、失败详情接口与 Dashboard 控件,worker 支持取消状态检查并通过 Redis/WebSocket 推送 cancelled 事件。 - 对接 Dashboard 后端数据:概览统计、解析队列和实时流转记录从 FastAPI 聚合接口与 WebSocket 更新。 - 增强 AI 推理参数:前端发送 crop_to_prompt、auto_filter_background 和 min_score,后端支持点/框 prompt 局部裁剪推理、结果回映射和负向点/低分过滤。 - 接入 SAM3 基础设施:新增独立 Python 3.12 sam3 环境安装脚本、外部 worker helper、后端桥接和真实 Python/CUDA/包/HF checkpoint access 状态检测。 - 保留 SAM3 授权边界:当前官方 facebook/sam3 gated 权重未授权时状态接口会返回不可用,不伪装成可推理。 - 增强前端状态管理:新增 mask undo/redo 历史栈、AI 模型选择状态、保存状态 dirty/draft/saved 流转和项目状态归一化。 - 更新前端 API 封装:补充 annotation CRUD、GT mask import、mask ZIP export、task cancel/retry/detail、AI runtime status 和 prediction options。 - 更新 UI 控件:ToolsPalette、AISegmentation、VideoWorkspace 和 CanvasArea 接入真实操作、导入导出、撤销重做、任务控制和模型状态。 - 新增 polygon-clipping 依赖,用于前端区域 union/difference 几何运算。 - 完善后端 schemas/status/progress:补充 AI 模型外部状态字段、任务 cancelled 状态和进度事件 payload。 - 补充测试覆盖:新增后端任务控制、SAM3 桥接、GT mask、导出融合、AI options 测试;补充前端 Canvas、Dashboard、VideoWorkspace、ToolsPalette、API 和 store 测试。 - 更新 README、AGENTS 和 doc 文档:冻结当前需求/设计/测试计划,标注真实功能、剩余 Mock、SAM3 授权边界和后续实施顺序。
139 lines
4.8 KiB
Python
139 lines
4.8 KiB
Python
"""Processing task query endpoints."""
|
||
|
||
import logging
|
||
from datetime import datetime, timezone
|
||
from typing import List
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, status
|
||
from sqlalchemy.orm import Session
|
||
|
||
from celery_app import celery_app
|
||
from database import get_db
|
||
from models import ProcessingTask, Project
|
||
from progress_events import publish_task_progress_event
|
||
from schemas import ProcessingTaskOut
|
||
from statuses import (
|
||
PROJECT_STATUS_PARSING,
|
||
PROJECT_STATUS_PENDING,
|
||
PROJECT_STATUS_READY,
|
||
TASK_ACTIVE_STATUSES,
|
||
TASK_STATUS_CANCELLED,
|
||
TASK_STATUS_FAILED,
|
||
TASK_STATUS_QUEUED,
|
||
)
|
||
from worker_tasks import parse_project_media
|
||
|
||
router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _now() -> datetime:
|
||
return datetime.now(timezone.utc)
|
||
|
||
|
||
def _get_task_or_404(task_id: int, db: Session) -> ProcessingTask:
|
||
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
|
||
if not task:
|
||
raise HTTPException(status_code=404, detail="Task not found")
|
||
return task
|
||
|
||
|
||
def _project_status_after_stop(project: Project) -> str:
|
||
return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING
|
||
|
||
|
||
@router.get("", response_model=List[ProcessingTaskOut], summary="List processing tasks")
|
||
def list_tasks(
|
||
project_id: int | None = None,
|
||
status: str | None = None,
|
||
limit: int = 50,
|
||
db: Session = Depends(get_db),
|
||
) -> List[ProcessingTask]:
|
||
"""Return recent background processing tasks."""
|
||
query = db.query(ProcessingTask)
|
||
if project_id is not None:
|
||
query = query.filter(ProcessingTask.project_id == project_id)
|
||
if status is not None:
|
||
query = query.filter(ProcessingTask.status == status)
|
||
return query.order_by(ProcessingTask.created_at.desc()).limit(limit).all()
|
||
|
||
|
||
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
|
||
def get_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||
"""Return one background task by id."""
|
||
return _get_task_or_404(task_id, db)
|
||
|
||
|
||
@router.post("/{task_id}/cancel", response_model=ProcessingTaskOut, summary="Cancel processing task")
|
||
def cancel_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||
"""Cancel a queued/running background task and revoke the Celery job when possible."""
|
||
task = _get_task_or_404(task_id, db)
|
||
if task.status not in TASK_ACTIVE_STATUSES:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_409_CONFLICT,
|
||
detail=f"Task is not cancellable in status: {task.status}",
|
||
)
|
||
|
||
if task.celery_task_id:
|
||
try:
|
||
celery_app.control.revoke(task.celery_task_id, terminate=True, signal="SIGTERM")
|
||
except Exception as exc: # noqa: BLE001
|
||
logger.warning("Failed to revoke celery task %s: %s", task.celery_task_id, exc)
|
||
|
||
task.status = TASK_STATUS_CANCELLED
|
||
task.progress = 100
|
||
task.message = "任务已取消"
|
||
task.error = "Cancelled by user"
|
||
task.finished_at = _now()
|
||
if task.project:
|
||
task.project.status = _project_status_after_stop(task.project)
|
||
|
||
db.commit()
|
||
db.refresh(task)
|
||
publish_task_progress_event(task)
|
||
return task
|
||
|
||
|
||
@router.post("/{task_id}/retry", response_model=ProcessingTaskOut, status_code=status.HTTP_202_ACCEPTED, summary="Retry processing task")
|
||
def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||
"""Create a fresh queued task from a failed or cancelled task."""
|
||
previous = _get_task_or_404(task_id, db)
|
||
if previous.status not in {TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_409_CONFLICT,
|
||
detail=f"Task is not retryable in status: {previous.status}",
|
||
)
|
||
if previous.project_id is None:
|
||
raise HTTPException(status_code=400, detail="Task has no project_id")
|
||
|
||
project = db.query(Project).filter(Project.id == previous.project_id).first()
|
||
if not project:
|
||
raise HTTPException(status_code=404, detail="Project not found")
|
||
if not project.video_path:
|
||
raise HTTPException(status_code=400, detail="Project has no media uploaded")
|
||
|
||
payload = dict(previous.payload or {})
|
||
payload.setdefault("source_type", project.source_type or "video")
|
||
payload["retry_of"] = previous.id
|
||
|
||
task = ProcessingTask(
|
||
task_type=previous.task_type,
|
||
status=TASK_STATUS_QUEUED,
|
||
progress=0,
|
||
message=f"重试任务已入队(源任务 #{previous.id})",
|
||
project_id=project.id,
|
||
payload=payload,
|
||
)
|
||
project.status = PROJECT_STATUS_PARSING
|
||
db.add(task)
|
||
db.commit()
|
||
db.refresh(task)
|
||
publish_task_progress_event(task)
|
||
|
||
async_result = parse_project_media.delay(task.id)
|
||
task.celery_task_id = async_result.id
|
||
db.commit()
|
||
db.refresh(task)
|
||
publish_task_progress_event(task)
|
||
return task
|