- 新增基于 JWT 当前用户的登录恢复、角色权限、用户管理、审计日志和演示出厂重置后台接口与前端管理页。 - 重串 GT_label 导出和 GT Mask 导入逻辑:导出保留类别真实 maskid,导入仅接受灰度或 RGB 等通道 maskid 图,支持未知 maskid 策略、尺寸最近邻拉伸和导入预览。 - 统一分割结果导出体验:默认当前帧,按项目抽帧顺序和 XhXXmXXsXXXms 时间戳命名 ZIP 与图片,补齐 GT/Pro/Mix/分开 Mask 输出和映射 JSON。 - 调整工作区左侧工具栏:移除创建点/线段入口,新增画笔、橡皮擦及尺寸控制,并按绘制、布尔、导入/AI 工具分组分隔。 - 扩展 Canvas 编辑能力:画笔按语义分类绘制并可自动并入连通选中 mask,橡皮擦对选中区域扣除,优化布尔操作、选区、撤销重做和保存状态联动。 - 优化自动传播时间轴显示:同一蓝色系按传播新旧递进变暗,老传播记录达到阈值后统一旧记录色,并维护范围选择与清空后的历史显示。 - 将 AI 智能分割入口替换为更明确的 AI 元素图标,并同步侧栏、工作区和 AI 页面入口表现。 - 完善模板分类、maskid 工具函数、分类树联动、遮罩透明度、边缘平滑和传播链同步相关前端状态。 - 扩展后端项目、媒体、任务、Dashboard、模板和传播 runner 的用户隔离、任务控制、进度事件与兼容处理。 - 补充前后端测试,覆盖用户管理、GT_label 往返导入导出、GT Mask 校验和预览、画笔/橡皮擦、时间轴传播历史、导出范围、WebSocket 与 API 封装。 - 更新 AGENTS、README 和 doc 文档,记录当前接口契约、实现状态、测试计划、安装说明和 maskid/GT_label 规则。
168 lines
5.8 KiB
Python
168 lines
5.8 KiB
Python
"""Processing task query endpoints."""
|
||
|
||
import logging
|
||
from datetime import datetime, timezone
|
||
from typing import List
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, status
|
||
from sqlalchemy.orm import Session
|
||
|
||
from celery_app import celery_app
|
||
from database import get_db
|
||
from models import ProcessingTask, Project, User
|
||
from progress_events import publish_task_progress_event
|
||
from routers.auth import get_current_user, require_editor
|
||
from schemas import ProcessingTaskOut
|
||
from statuses import (
|
||
PROJECT_STATUS_PARSING,
|
||
PROJECT_STATUS_PENDING,
|
||
PROJECT_STATUS_READY,
|
||
TASK_ACTIVE_STATUSES,
|
||
TASK_STATUS_CANCELLED,
|
||
TASK_STATUS_FAILED,
|
||
TASK_STATUS_QUEUED,
|
||
)
|
||
from worker_tasks import parse_project_media, propagate_project_masks
|
||
|
||
router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _now() -> datetime:
|
||
return datetime.now(timezone.utc)
|
||
|
||
|
||
def _get_task_or_404(task_id: int, db: Session, current_user: User) -> ProcessingTask:
|
||
task = (
|
||
db.query(ProcessingTask)
|
||
.outerjoin(Project, Project.id == ProcessingTask.project_id)
|
||
.filter(
|
||
ProcessingTask.id == task_id,
|
||
(ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id),
|
||
)
|
||
.first()
|
||
)
|
||
if not task:
|
||
raise HTTPException(status_code=404, detail="Task not found")
|
||
return task
|
||
|
||
|
||
def _project_status_after_stop(project: Project) -> str:
|
||
return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING
|
||
|
||
|
||
@router.get("", response_model=List[ProcessingTaskOut], summary="List processing tasks")
|
||
def list_tasks(
|
||
project_id: int | None = None,
|
||
status: str | None = None,
|
||
limit: int = 50,
|
||
db: Session = Depends(get_db),
|
||
current_user: User = Depends(get_current_user),
|
||
) -> List[ProcessingTask]:
|
||
"""Return recent background processing tasks."""
|
||
query = db.query(ProcessingTask).outerjoin(Project, Project.id == ProcessingTask.project_id).filter(
|
||
(ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id)
|
||
)
|
||
if project_id is not None:
|
||
query = query.filter(ProcessingTask.project_id == project_id)
|
||
if status is not None:
|
||
query = query.filter(ProcessingTask.status == status)
|
||
return query.order_by(ProcessingTask.created_at.desc()).limit(limit).all()
|
||
|
||
|
||
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
|
||
def get_task(
|
||
task_id: int,
|
||
db: Session = Depends(get_db),
|
||
current_user: User = Depends(get_current_user),
|
||
) -> ProcessingTask:
|
||
"""Return one background task by id."""
|
||
return _get_task_or_404(task_id, db, current_user)
|
||
|
||
|
||
@router.post("/{task_id}/cancel", response_model=ProcessingTaskOut, summary="Cancel processing task")
|
||
def cancel_task(
|
||
task_id: int,
|
||
db: Session = Depends(get_db),
|
||
current_user: User = Depends(require_editor),
|
||
) -> ProcessingTask:
|
||
"""Cancel a queued/running background task and revoke the Celery job when possible."""
|
||
task = _get_task_or_404(task_id, db, current_user)
|
||
if task.status not in TASK_ACTIVE_STATUSES:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_409_CONFLICT,
|
||
detail=f"Task is not cancellable in status: {task.status}",
|
||
)
|
||
|
||
if task.celery_task_id:
|
||
try:
|
||
celery_app.control.revoke(task.celery_task_id, terminate=True, signal="SIGTERM")
|
||
except Exception as exc: # noqa: BLE001
|
||
logger.warning("Failed to revoke celery task %s: %s", task.celery_task_id, exc)
|
||
|
||
task.status = TASK_STATUS_CANCELLED
|
||
task.progress = 100
|
||
task.message = "任务已取消"
|
||
task.error = "Cancelled by user"
|
||
task.finished_at = _now()
|
||
if task.project:
|
||
task.project.status = _project_status_after_stop(task.project)
|
||
|
||
db.commit()
|
||
db.refresh(task)
|
||
publish_task_progress_event(task)
|
||
return task
|
||
|
||
|
||
@router.post("/{task_id}/retry", response_model=ProcessingTaskOut, status_code=status.HTTP_202_ACCEPTED, summary="Retry processing task")
|
||
def retry_task(
|
||
task_id: int,
|
||
db: Session = Depends(get_db),
|
||
current_user: User = Depends(require_editor),
|
||
) -> ProcessingTask:
|
||
"""Create a fresh queued task from a failed or cancelled task."""
|
||
previous = _get_task_or_404(task_id, db, current_user)
|
||
if previous.status not in {TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_409_CONFLICT,
|
||
detail=f"Task is not retryable in status: {previous.status}",
|
||
)
|
||
if previous.project_id is None:
|
||
raise HTTPException(status_code=400, detail="Task has no project_id")
|
||
|
||
project = db.query(Project).filter(
|
||
Project.id == previous.project_id,
|
||
Project.owner_user_id == current_user.id,
|
||
).first()
|
||
if not project:
|
||
raise HTTPException(status_code=404, detail="Project not found")
|
||
is_propagation_task = previous.task_type == "propagate_masks"
|
||
if not is_propagation_task and not project.video_path:
|
||
raise HTTPException(status_code=400, detail="Project has no media uploaded")
|
||
|
||
payload = dict(previous.payload or {})
|
||
payload.setdefault("source_type", project.source_type or "video")
|
||
payload["retry_of"] = previous.id
|
||
|
||
task = ProcessingTask(
|
||
task_type=previous.task_type,
|
||
status=TASK_STATUS_QUEUED,
|
||
progress=0,
|
||
message=f"重试任务已入队(源任务 #{previous.id})",
|
||
project_id=project.id,
|
||
payload=payload,
|
||
)
|
||
if not is_propagation_task:
|
||
project.status = PROJECT_STATUS_PARSING
|
||
db.add(task)
|
||
db.commit()
|
||
db.refresh(task)
|
||
publish_task_progress_event(task)
|
||
|
||
async_result = propagate_project_masks.delay(task.id) if is_propagation_task else parse_project_media.delay(task.id)
|
||
task.celery_task_id = async_result.id
|
||
db.commit()
|
||
db.refresh(task)
|
||
publish_task_progress_event(task)
|
||
return task
|