"""Processing task endpoints: list, inspect, cancel and retry background tasks."""

import logging
from datetime import datetime, timezone
from typing import List

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session

from celery_app import celery_app
from database import get_db
from models import ProcessingTask, Project
from progress_events import publish_task_progress_event
from schemas import ProcessingTaskOut
from statuses import (
    PROJECT_STATUS_PARSING,
    PROJECT_STATUS_PENDING,
    PROJECT_STATUS_READY,
    TASK_ACTIVE_STATUSES,
    TASK_STATUS_CANCELLED,
    TASK_STATUS_FAILED,
    TASK_STATUS_QUEUED,
)
from worker_tasks import parse_project_media

router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
logger = logging.getLogger(__name__)


def _now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def _get_task_or_404(task_id: int, db: Session) -> ProcessingTask:
    """Load a ProcessingTask by primary key.

    Raises:
        HTTPException: 404 when no task with ``task_id`` exists.
    """
    task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")
    return task


def _project_status_after_stop(project: Project) -> str:
    """Return the status a project falls back to when its task stops.

    A project that already has frames goes to READY; otherwise it goes back
    to PENDING.
    """
    return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING


@router.get("", response_model=List[ProcessingTaskOut], summary="List processing tasks")
def list_tasks(
    project_id: int | None = None,
    status: str | None = None,
    limit: int = 50,
    db: Session = Depends(get_db),
) -> List[ProcessingTask]:
    """Return recent background processing tasks.

    \f
    Args:
        project_id: If given, only tasks belonging to this project.
        status: If given, only tasks whose status equals this value. This name
            shadows the ``fastapi.status`` module inside this function; it is
            kept because it is the public query-parameter name.
        limit: Maximum number of tasks to return, newest (``created_at``) first.
        db: Database session supplied by ``get_db``.

    Raises:
        HTTPException: 422 if ``limit`` is negative.
    """
    if limit < 0:
        # A negative LIMIT is an error on some databases and "unlimited" on
        # others; reject it explicitly. Literal 422 because ``status`` is
        # shadowed by the query parameter here.
        raise HTTPException(status_code=422, detail="limit must be non-negative")

    query = db.query(ProcessingTask)
    if project_id is not None:
        query = query.filter(ProcessingTask.project_id == project_id)
    if status is not None:
        query = query.filter(ProcessingTask.status == status)
    return query.order_by(ProcessingTask.created_at.desc()).limit(limit).all()


@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
def get_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
    """Return one background task by id.

    \f
    Raises:
        HTTPException: 404 if the task does not exist.
    """
    return _get_task_or_404(task_id, db)


@router.post(
    "/{task_id}/cancel",
    response_model=ProcessingTaskOut,
    summary="Cancel processing task",
)
def cancel_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
    """Cancel a queued/running background task and revoke the Celery job when possible.

    \f
    Revocation is best-effort: if ``celery_app.control.revoke`` raises, the
    error is logged and the task is still marked cancelled in the database.
    The owning project (if any) is moved back to READY or PENDING.

    Raises:
        HTTPException: 404 if the task does not exist; 409 if its status is
            not in ``TASK_ACTIVE_STATUSES``.
    """
    task = _get_task_or_404(task_id, db)
    if task.status not in TASK_ACTIVE_STATUSES:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Task is not cancellable in status: {task.status}",
        )

    if task.celery_task_id:
        try:
            celery_app.control.revoke(task.celery_task_id, terminate=True, signal="SIGTERM")
        except Exception as exc:  # noqa: BLE001
            # Best-effort: record the cancellation below even if the broker
            # could not be reached.
            logger.warning("Failed to revoke celery task %s: %s", task.celery_task_id, exc)

    task.status = TASK_STATUS_CANCELLED
    task.progress = 100
    task.message = "任务已取消"
    task.error = "Cancelled by user"
    task.finished_at = _now()
    if task.project:
        task.project.status = _project_status_after_stop(task.project)

    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)
    return task


@router.post(
    "/{task_id}/retry",
    response_model=ProcessingTaskOut,
    status_code=status.HTTP_202_ACCEPTED,
    summary="Retry processing task",
)
def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
    """Create a fresh queued task from a failed or cancelled task.

    \f
    The new task copies ``task_type`` and ``payload`` from the previous task,
    with ``retry_of`` set to the previous task's id. ``source_type`` defaults
    to the project's source type, or ``"video"`` if that is unset. The new
    task is then dispatched via ``parse_project_media.delay``. If dispatching
    raises, the new task is marked failed and the project status is restored,
    so neither is left QUEUED/PARSING indefinitely.

    Raises:
        HTTPException: 404 if the task or its project is missing; 409 if the
            task is not failed/cancelled; 400 if it has no ``project_id`` or
            the project has no uploaded media; 503 if the job could not be
            enqueued.
    """
    previous = _get_task_or_404(task_id, db)
    if previous.status not in {TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Task is not retryable in status: {previous.status}",
        )
    if previous.project_id is None:
        raise HTTPException(status_code=400, detail="Task has no project_id")

    project = db.query(Project).filter(Project.id == previous.project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")
    if not project.video_path:
        raise HTTPException(status_code=400, detail="Project has no media uploaded")

    # Copy so the previous task's stored payload is not mutated.
    payload = dict(previous.payload or {})
    payload.setdefault("source_type", project.source_type or "video")
    payload["retry_of"] = previous.id

    task = ProcessingTask(
        task_type=previous.task_type,
        status=TASK_STATUS_QUEUED,
        progress=0,
        message=f"重试任务已入队(源任务 #{previous.id})",
        project_id=project.id,
        payload=payload,
    )
    project.status = PROJECT_STATUS_PARSING
    db.add(task)
    # The worker is handed only ``task.id``, so the row is committed before
    # dispatch (presumably the worker loads it by that id).
    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)

    # NOTE(review): this always dispatches parse_project_media regardless of
    # previous.task_type — confirm every retryable task type maps to that job.
    try:
        async_result = parse_project_media.delay(task.id)
    except Exception as exc:  # noqa: BLE001 - broker/transport errors vary
        logger.exception("Failed to dispatch retry task %s", task.id)
        # Without this, the task would stay QUEUED with no celery_task_id
        # and the project would stay PARSING forever.
        task.status = TASK_STATUS_FAILED
        task.message = "任务派发失败"
        task.error = f"Failed to dispatch task: {exc}"
        task.finished_at = _now()
        project.status = _project_status_after_stop(project)
        db.commit()
        db.refresh(task)
        publish_task_progress_event(task)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Failed to enqueue retry task",
        ) from exc

    task.celery_task_id = async_result.id
    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)
    return task