feat: 打通全栈标注闭环、异步拆帧与模型状态

后端能力:

- 新增 Celery app、worker task、ProcessingTask 模型、/api/tasks 查询接口和 media_task_runner,将 /api/media/parse 改为创建后台任务并由 worker 执行 FFmpeg/OpenCV/pydicom 拆帧。

- 新增 Redis 进度事件模块和 FastAPI Redis pub/sub 订阅,将 worker 任务进度广播到 /ws/progress;Dashboard 后端概览接口改为聚合 projects/frames/annotations/templates/processing_tasks。

- 统一项目状态为 pending/parsing/ready/error,新增共享 status 常量,并让前端兼容归一化旧状态值。

- 扩展 AI 后端:新增 SAM registry、SAM2 真实运行状态、SAM3 状态检测与文本语义推理适配入口,以及 /api/ai/models/status GPU/模型状态接口。

- 补齐标注保存/更新/删除、COCO/PNG mask 导出相关后端契约和模板 mapping_rules 打包/解包行为。

前端能力:

- 新增运行时 API/WS 地址推导配置,前端 API 封装对齐 FastAPI 路由、字段映射、任务轮询、标注归档、导出下载和 AI 预测响应转换。

- Dashboard 改为读取 /api/dashboard/overview,并订阅 WebSocket progress/complete/error/status 更新解析队列和实时流转记录。

- 项目库导入视频/DICOM 后创建项目、上传媒体、触发异步解析并刷新真实项目列表。

- 工作区加载真实帧、无帧时触发解析任务、回显已保存标注、保存未归档 mask、更新 dirty mask、清空当前帧后端标注、导出 COCO JSON。

- Canvas 支持当前帧点/框提示调用后端 AI、渲染推理/已保存 mask、应用模板分类并维护保存状态计数;时间轴按项目 fps 播放。

- AI 页面新增 SAM2/SAM3 模型选择,预测请求携带 model;侧边栏和工作区新增真实 GPU/SAM 状态徽标。

- 模板库和本体面板接入真实模板 CRUD、分类编辑、拖拽排序、JSON 导入、默认腹腔镜分类和本地自定义分类选择。

测试与文档:

- 新增 Vitest 配置、前端测试 setup、API/config/websocket/store/组件测试,覆盖登录、项目库、Dashboard、Canvas、工作区、模型状态、时间轴、本体和模板库。

- 新增 pytest 后端测试夹具和 auth/projects/templates/media/AI/export/dashboard/tasks/progress 测试,使用 SQLite、fake MinIO、fake SAM registry 和 Redis monkeypatch 隔离外部服务。

- 新增 doc/ 文档结构,冻结当前需求、设计、接口契约、测试计划、前端逐元素审计、实现地图和后续实施计划,并同步更新 README 与 AGENTS。

验证:

- conda run -n seg_server pytest backend/tests:27 passed。

- npm run test:run:54 passed。

- npm run lint、npm run build、compileall、git diff --check 均通过;Vite 仅提示大 chunk 警告。
This commit is contained in:
2026-05-01 13:29:14 +08:00
parent 4d65c37c73
commit f020ff3b4f
78 changed files with 7089 additions and 456 deletions

21
backend/celery_app.py Normal file
View File

@@ -0,0 +1,21 @@
"""Celery application for background processing."""
from celery import Celery
from config import settings
celery_app = Celery(
"seg_server",
broker=settings.redis_url,
backend=settings.redis_url,
include=["worker_tasks"],
)
celery_app.conf.update(
task_serializer="json",
result_serializer="json",
accept_content=["json"],
timezone="Asia/Shanghai",
enable_utc=True,
task_track_started=True,
)

View File

@@ -18,9 +18,11 @@ class Settings(BaseSettings):
minio_secret_key: str = "minioadmin"
minio_secure: bool = False
# SAM2
# SAM
sam_default_model: str = "sam2"
sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
sam_model_config: str = "configs/sam2/sam2_hiera_t.yaml"
sam3_model_version: str = "sam3.1"
# App
app_env: str = "development"

View File

@@ -1,11 +1,13 @@
"""FastAPI application entrypoint."""
import asyncio
import json
import logging
import os
import shutil
import tempfile
from contextlib import asynccontextmanager
from contextlib import asynccontextmanager, suppress
from datetime import datetime, timezone
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
@@ -13,9 +15,11 @@ from fastapi.middleware.cors import CORSMiddleware
from config import settings
from database import Base, engine, SessionLocal
from minio_client import ensure_bucket_exists, upload_file
from redis_client import ping as redis_ping
from progress_events import PROGRESS_CHANNEL
from redis_client import get_redis_client, ping as redis_ping
from statuses import PROJECT_STATUS_PENDING, PROJECT_STATUS_READY
from routers import projects, templates, media, ai, export, auth
from routers import projects, templates, media, ai, export, auth, dashboard, tasks
logging.basicConfig(
level=logging.INFO,
@@ -45,7 +49,7 @@ def _seed_default_project_sync() -> None:
project = Project(
name="Data_MyVideo_1",
description="默认演示视频",
status="pending",
status=PROJECT_STATUS_PENDING,
source_type="video",
parse_fps=30.0,
)
@@ -98,7 +102,7 @@ def _seed_default_project_sync() -> None:
)
db.add(frame)
project.status = "ready"
project.status = PROJECT_STATUS_READY
db.commit()
logger.info("Seeded default project id=%s with %d frames", project.id, len(object_names))
finally:
@@ -165,6 +169,7 @@ def _seed_default_templates_sync() -> None:
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan: startup and shutdown hooks."""
progress_listener: asyncio.Task | None = None
# Startup
logger.info("Starting up SegServer backend...")
@@ -187,6 +192,11 @@ async def lifespan(app: FastAPI):
else:
logger.warning("Redis connection failed.")
try:
progress_listener = asyncio.create_task(_progress_pubsub_loop())
except Exception as exc: # noqa: BLE001
logger.error("Failed to start Redis progress subscription: %s", exc)
# Seed default templates
try:
asyncio.create_task(asyncio.to_thread(_seed_default_templates_sync))
@@ -203,6 +213,10 @@ async def lifespan(app: FastAPI):
# Shutdown
logger.info("Shutting down SegServer backend...")
if progress_listener is not None:
progress_listener.cancel()
with suppress(asyncio.CancelledError):
await progress_listener
engine.dispose()
@@ -229,6 +243,8 @@ app.include_router(templates.router)
app.include_router(media.router)
app.include_router(ai.router)
app.include_router(export.router)
app.include_router(dashboard.router)
app.include_router(tasks.router)
@app.get("/health", tags=["Health"])
@@ -269,6 +285,34 @@ class ConnectionManager:
manager = ConnectionManager()
async def _progress_pubsub_loop() -> None:
"""Forward Redis task-progress events to connected WebSocket clients."""
while True:
pubsub = None
try:
pubsub = get_redis_client().pubsub()
await asyncio.to_thread(pubsub.subscribe, PROGRESS_CHANNEL)
logger.info("Subscribed to Redis progress channel: %s", PROGRESS_CHANNEL)
while True:
message = await asyncio.to_thread(pubsub.get_message, True, 1.0)
if message is None:
await asyncio.sleep(0)
continue
raw_data = message.get("data")
payload = json.loads(raw_data) if isinstance(raw_data, str) else raw_data
if isinstance(payload, dict):
await manager.broadcast(payload)
except asyncio.CancelledError:
raise
except Exception as exc: # noqa: BLE001
logger.error("Redis progress subscription failed: %s", exc)
await asyncio.sleep(5)
finally:
if pubsub is not None:
with suppress(Exception):
await asyncio.to_thread(pubsub.close)
@app.websocket("/ws/progress")
async def websocket_progress(websocket: WebSocket):
"""WebSocket endpoint for real-time parsing/AI progress updates."""
@@ -284,7 +328,7 @@ async def websocket_progress(websocket: WebSocket):
"type": "status",
"status": "connected",
"message": "Progress stream active",
"timestamp": str(logging.time.time() if hasattr(logging, 'time') else __import__('time').time()),
"timestamp": datetime.now(timezone.utc).isoformat(),
})
except WebSocketDisconnect:
manager.disconnect(websocket)

View File

@@ -14,6 +14,7 @@ from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from database import Base
from statuses import PROJECT_STATUS_PENDING
class Project(Base):
@@ -26,7 +27,7 @@ class Project(Base):
description = Column(Text, nullable=True)
video_path = Column(String(512), nullable=True)
thumbnail_url = Column(String(512), nullable=True)
status = Column(String(50), default="Ready", nullable=False)
status = Column(String(50), default=PROJECT_STATUS_PENDING, nullable=False)
source_type = Column(String(20), default="video", nullable=False) # video | dicom
original_fps = Column(Float, nullable=True)
parse_fps = Column(Float, default=30.0, nullable=False)
@@ -39,6 +40,9 @@ class Project(Base):
annotations = relationship(
"Annotation", back_populates="project", cascade="all, delete-orphan"
)
tasks = relationship(
"ProcessingTask", back_populates="project", cascade="all, delete-orphan"
)
class Frame(Base):
@@ -121,3 +125,30 @@ class Mask(Base):
created_at = Column(DateTime(timezone=True), server_default=func.now())
annotation = relationship("Annotation", back_populates="masks")
class ProcessingTask(Base):
"""Background task state persisted for dashboard and polling."""
__tablename__ = "processing_tasks"
id = Column(Integer, primary_key=True, index=True)
task_type = Column(String(80), nullable=False)
status = Column(String(40), default="queued", nullable=False)
progress = Column(Integer, default=0, nullable=False)
message = Column(Text, nullable=True)
project_id = Column(
Integer, ForeignKey("projects.id", ondelete="CASCADE"), nullable=True
)
celery_task_id = Column(String(255), nullable=True)
payload = Column(JSON, nullable=True)
result = Column(JSON, nullable=True)
error = Column(Text, nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
started_at = Column(DateTime(timezone=True), nullable=True)
finished_at = Column(DateTime(timezone=True), nullable=True)
updated_at = Column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
project = relationship("Project", back_populates="tasks")

View File

@@ -0,0 +1,64 @@
"""Progress event payloads and Redis publication helpers."""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from typing import Any
from redis_client import get_redis_client
from statuses import TASK_STATUS_FAILED, TASK_STATUS_SUCCESS
logger = logging.getLogger(__name__)
PROGRESS_CHANNEL = "seg:progress"
def _iso_now() -> str:
return datetime.now(timezone.utc).isoformat()
def _event_type(task_status: str) -> str:
if task_status == TASK_STATUS_SUCCESS:
return "complete"
if task_status == TASK_STATUS_FAILED:
return "error"
return "progress"
def task_progress_payload(task: Any) -> dict[str, Any]:
"""Build the WebSocket payload from a persisted processing task."""
project = getattr(task, "project", None)
project_name = getattr(project, "name", None)
status = getattr(task, "status", "")
updated_at = getattr(task, "updated_at", None)
timestamp = updated_at.isoformat() if updated_at is not None else _iso_now()
message = getattr(task, "message", None)
return {
"type": _event_type(status),
"taskId": f"task-{task.id}",
"task_id": task.id,
"project_id": getattr(task, "project_id", None),
"projectName": project_name,
"filename": project_name,
"progress": getattr(task, "progress", 0),
"status": message or status,
"message": message,
"error": getattr(task, "error", None),
"timestamp": timestamp,
}
def publish_progress_event(payload: dict[str, Any]) -> None:
"""Publish a JSON progress event without failing the worker on Redis errors."""
try:
get_redis_client().publish(PROGRESS_CHANNEL, json.dumps(payload, ensure_ascii=False))
except Exception as exc: # noqa: BLE001
logger.warning("Failed to publish progress event: %s", exc)
def publish_task_progress_event(task: Any) -> None:
"""Publish a progress event for a ProcessingTask ORM object."""
publish_progress_event(task_progress_payload(task))

View File

@@ -0,0 +1,2 @@
pytest
httpx

View File

@@ -1,18 +1,25 @@
"""AI inference endpoints using SAM 2."""
"""AI inference endpoints using selectable SAM runtimes."""
import logging
from typing import Any, List
import cv2
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, Response, status
from sqlalchemy.orm import Session
from database import get_db
from minio_client import download_file
from models import Frame, Annotation
from schemas import PredictRequest, PredictResponse, AnnotationOut, AnnotationCreate
from services.sam2_engine import sam_engine
from models import Project, Frame, Template, Annotation
from schemas import (
AiRuntimeStatus,
PredictRequest,
PredictResponse,
AnnotationOut,
AnnotationCreate,
AnnotationUpdate,
)
from services.sam_registry import ModelUnavailableError, sam_registry
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/ai", tags=["AI"])
@@ -35,14 +42,15 @@ def _load_frame_image(frame: Frame) -> np.ndarray:
@router.post(
"/predict",
response_model=PredictResponse,
summary="Run SAM 2 inference with a prompt",
summary="Run SAM inference with a prompt",
)
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
"""Execute SAM 2 segmentation given an image and a prompt.
"""Execute selected SAM segmentation given an image and a prompt.
- **point**: `prompt_data` is a list of `[[x, y], ...]` normalized coordinates.
- **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
coordinates or `{ "points": [[x, y], ...], "labels": [1, 0, ...] }`.
- **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
- **semantic**: Not yet implemented; falls back to auto segmentation.
- **semantic**: SAM 3 text prompt when model=`sam3`; SAM 2 falls back to auto.
"""
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
if not frame:
@@ -54,30 +62,57 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
polygons: List[List[List[float]]] = []
scores: List[float] = []
if prompt_type == "point":
points = payload.prompt_data
if not isinstance(points, list) or len(points) == 0:
raise HTTPException(status_code=400, detail="Invalid point prompt data")
labels = [1] * len(points)
polygons, scores = sam_engine.predict_points(image, points, labels)
try:
if prompt_type == "point":
point_payload = payload.prompt_data
if isinstance(point_payload, dict):
points = point_payload.get("points")
labels = point_payload.get("labels")
else:
points = point_payload
labels = None
elif prompt_type == "box":
box = payload.prompt_data
if not isinstance(box, list) or len(box) != 4:
raise HTTPException(status_code=400, detail="Invalid box prompt data")
polygons, scores = sam_engine.predict_box(image, box)
if not isinstance(points, list) or len(points) == 0:
raise HTTPException(status_code=400, detail="Invalid point prompt data")
if not isinstance(labels, list) or len(labels) != len(points):
labels = [1] * len(points)
polygons, scores = sam_registry.predict_points(payload.model, image, points, labels)
elif prompt_type == "semantic":
# Placeholder: use auto segmentation for now
logger.info("Semantic prompt not implemented; using auto segmentation")
polygons, scores = sam_engine.predict_auto(image)
elif prompt_type == "box":
box = payload.prompt_data
if not isinstance(box, list) or len(box) != 4:
raise HTTPException(status_code=400, detail="Invalid box prompt data")
polygons, scores = sam_registry.predict_box(payload.model, image, box)
else:
raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
elif prompt_type == "semantic":
text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
polygons, scores = sam_registry.predict_semantic(payload.model, image, text)
else:
raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
except ModelUnavailableError as exc:
raise HTTPException(status_code=503, detail=str(exc)) from exc
except NotImplementedError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return {"polygons": polygons, "scores": scores}
@router.get(
"/models/status",
response_model=AiRuntimeStatus,
summary="Get SAM model and GPU runtime status",
)
def model_status(selected_model: str | None = None) -> dict:
"""Return real runtime availability for GPU, SAM 2, and SAM 3."""
try:
return sam_registry.runtime_status(selected_model)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
@router.post(
"/auto",
response_model=PredictResponse,
@@ -90,7 +125,10 @@ def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
raise HTTPException(status_code=404, detail="Frame not found")
image = _load_frame_image(frame)
polygons, scores = sam_engine.predict_auto(image)
try:
polygons, scores = sam_registry.predict_auto(None, image)
except ModelUnavailableError as exc:
raise HTTPException(status_code=503, detail=str(exc)) from exc
return {"polygons": polygons, "scores": scores}
@@ -106,7 +144,7 @@ def save_annotation(
db: Session = Depends(get_db),
) -> Annotation:
"""Persist an annotation (mask, points, bbox) into the database."""
project = db.query(Frame).filter(Frame.id == payload.project_id).first()
project = db.query(Project).filter(Project.id == payload.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
@@ -121,3 +159,74 @@ def save_annotation(
db.refresh(annotation)
logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
return annotation
@router.get(
"/annotations",
response_model=List[AnnotationOut],
summary="List saved annotations for a project",
)
def list_annotations(
project_id: int,
frame_id: int | None = None,
db: Session = Depends(get_db),
) -> List[Annotation]:
"""Return persisted annotations for a project, optionally scoped to one frame."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
query = db.query(Annotation).filter(Annotation.project_id == project_id)
if frame_id is not None:
query = query.filter(Annotation.frame_id == frame_id)
return query.order_by(Annotation.id).all()
@router.patch(
"/annotations/{annotation_id}",
response_model=AnnotationOut,
summary="Update a saved annotation",
)
def update_annotation(
annotation_id: int,
payload: AnnotationUpdate,
db: Session = Depends(get_db),
) -> Annotation:
"""Update mutable annotation fields persisted in the database."""
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
if not annotation:
raise HTTPException(status_code=404, detail="Annotation not found")
updates = payload.model_dump(exclude_unset=True)
if "template_id" in updates and updates["template_id"] is not None:
template = db.query(Template).filter(Template.id == updates["template_id"]).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
for field, value in updates.items():
setattr(annotation, field, value)
db.commit()
db.refresh(annotation)
logger.info("Updated annotation id=%s", annotation.id)
return annotation
@router.delete(
"/annotations/{annotation_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete a saved annotation",
)
def delete_annotation(
annotation_id: int,
db: Session = Depends(get_db),
) -> Response:
"""Delete an annotation and its derived mask rows through ORM cascade."""
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
if not annotation:
raise HTTPException(status_code=404, detail="Annotation not found")
db.delete(annotation)
db.commit()
logger.info("Deleted annotation id=%s", annotation_id)
return Response(status_code=status.HTTP_204_NO_CONTENT)

View File

@@ -0,0 +1,137 @@
"""Dashboard overview endpoints."""
import os
from datetime import datetime, timezone
from typing import Any
from fastapi import APIRouter, Depends
from sqlalchemy import func
from sqlalchemy.orm import Session
from database import get_db
from models import Annotation, Frame, ProcessingTask, Project, Template
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
ACTIVE_TASK_STATUSES = {"queued", "running"}
def _system_load_percent() -> int:
"""Return a real host load estimate without adding a psutil dependency."""
try:
load_1m = os.getloadavg()[0]
cpu_count = os.cpu_count() or 1
return min(100, max(0, round((load_1m / cpu_count) * 100)))
except (AttributeError, OSError):
return 0
def _iso_or_none(value: datetime | None) -> str | None:
if value is None:
return None
if value.tzinfo is None:
value = value.replace(tzinfo=timezone.utc)
return value.isoformat()
def _task_payload(task: ProcessingTask) -> dict[str, Any]:
return {
"id": f"task-{task.id}",
"task_id": task.id,
"project_id": task.project_id or 0,
"name": task.project.name if task.project else f"任务 {task.id}",
"progress": task.progress,
"status": task.message or task.status,
"frame_count": (task.result or {}).get("frames_extracted", 0),
"updated_at": _iso_or_none(task.updated_at),
}
@router.get("/overview", summary="Get dashboard overview")
def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
"""Return live dashboard data derived from persisted backend records."""
project_count = db.query(func.count(Project.id)).scalar() or 0
frame_count = db.query(func.count(Frame.id)).scalar() or 0
annotation_count = db.query(func.count(Annotation.id)).scalar() or 0
template_count = db.query(func.count(Template.id)).scalar() or 0
active_task_count = (
db.query(func.count(ProcessingTask.id))
.filter(ProcessingTask.status.in_(ACTIVE_TASK_STATUSES))
.scalar()
or 0
)
projects = db.query(Project).order_by(Project.updated_at.desc()).all()
recent_tasks = (
db.query(ProcessingTask)
.order_by(ProcessingTask.created_at.desc())
.limit(50)
.all()
)
tasks = [_task_payload(task) for task in recent_tasks if task.status in ACTIVE_TASK_STATUSES]
activities: list[dict[str, Any]] = []
for task in recent_tasks[:10]:
project_name = task.project.name if task.project else f"项目 {task.project_id}"
activities.append({
"id": f"task-{task.id}",
"kind": "task",
"time": _iso_or_none(task.updated_at),
"message": task.message or f"任务状态: {task.status}",
"project": project_name,
})
for project in projects[:10]:
activities.append({
"id": f"project-{project.id}",
"kind": "project",
"time": _iso_or_none(project.updated_at),
"message": f"项目状态: {project.status}",
"project": project.name,
})
recent_annotations = (
db.query(Annotation)
.order_by(Annotation.updated_at.desc())
.limit(10)
.all()
)
for annotation in recent_annotations:
project_name = annotation.project.name if annotation.project else f"项目 {annotation.project_id}"
activities.append({
"id": f"annotation-{annotation.id}",
"kind": "annotation",
"time": _iso_or_none(annotation.updated_at),
"message": f"标注已更新 #{annotation.id}",
"project": project_name,
})
recent_templates = (
db.query(Template)
.order_by(Template.created_at.desc())
.limit(10)
.all()
)
for template in recent_templates:
activities.append({
"id": f"template-{template.id}",
"kind": "template",
"time": _iso_or_none(template.created_at),
"message": f"模板可用: {template.name}",
"project": "系统",
})
activities.sort(key=lambda item: item["time"] or "", reverse=True)
return {
"summary": {
"project_count": project_count,
"parsing_task_count": active_task_count,
"annotation_count": annotation_count,
"frame_count": frame_count,
"template_count": template_count,
"system_load_percent": _system_load_percent(),
},
"tasks": tasks,
"activity": activities[:10],
}

View File

@@ -1,10 +1,6 @@
"""Media upload and parsing endpoints."""
import logging
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import List, Optional
@@ -12,13 +8,12 @@ from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, s
from sqlalchemy.orm import Session
from database import get_db
from minio_client import upload_file, get_presigned_url, download_file
from models import Project, Frame
from schemas import FrameOut
from services.frame_parser import (
parse_video, parse_dicom, upload_frames_to_minio,
extract_thumbnail, get_video_fps,
)
from minio_client import upload_file, get_presigned_url
from models import ProcessingTask, Project
from progress_events import publish_task_progress_event
from schemas import ProcessingTaskOut
from statuses import PROJECT_STATUS_PARSING, PROJECT_STATUS_PENDING, TASK_STATUS_QUEUED
from worker_tasks import parse_project_media
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/media", tags=["Media"])
@@ -79,7 +74,7 @@ async def upload_media(
project = Project(
name=file.filename,
description="Auto-created from upload",
status="pending",
status=PROJECT_STATUS_PENDING,
video_path=object_name,
source_type="video",
)
@@ -135,7 +130,7 @@ async def upload_dicom_batch(
project = Project(
name=first_name,
description=f"DICOM series with {len(files)} files",
status="pending",
status=PROJECT_STATUS_PENDING,
source_type="dicom",
)
db.add(project)
@@ -168,19 +163,18 @@ async def upload_dicom_batch(
@router.post(
"/parse",
status_code=status.HTTP_202_ACCEPTED,
response_model=ProcessingTaskOut,
summary="Trigger frame extraction",
)
def parse_media(
project_id: int,
source_type: Optional[str] = None,
db: Session = Depends(get_db),
) -> dict:
"""Trigger frame extraction for a project's uploaded media.
) -> ProcessingTask:
"""Create a background task for media frame extraction.
* video: uses FFmpeg or OpenCV fallback, extracts thumbnail.
* dicom: uses pydicom to read DCM frames.
Extracted frames are uploaded to MinIO and registered in the database.
The Celery worker performs the heavy FFmpeg/OpenCV/pydicom work and
updates the persisted task record as it progresses.
"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
@@ -190,100 +184,24 @@ def parse_media(
raise HTTPException(status_code=400, detail="Project has no media uploaded")
effective_source = source_type or project.source_type or "video"
parse_fps = project.parse_fps or 30.0
tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project_id}_")
output_dir = os.path.join(tmp_dir, "frames")
os.makedirs(output_dir, exist_ok=True)
try:
if effective_source == "dicom":
# Download all dicom files from MinIO
dcm_dir = os.path.join(tmp_dir, "dcm")
os.makedirs(dcm_dir, exist_ok=True)
from minio_client import get_minio_client, BUCKET_NAME
client = get_minio_client()
prefix = project.video_path
objects = list(client.list_objects(BUCKET_NAME, prefix=prefix, recursive=True))
for obj in objects:
if obj.object_name.lower().endswith(".dcm"):
data = download_file(obj.object_name)
local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name))
with open(local_dcm, "wb") as f:
f.write(data)
frame_files = parse_dicom(dcm_dir, output_dir)
else:
# Video: download and parse
media_bytes = download_file(project.video_path)
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
with open(local_path, "wb") as f:
f.write(media_bytes)
frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
project.original_fps = original_fps
# Extract thumbnail from first frame
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
try:
extract_thumbnail(local_path, thumbnail_path)
with open(thumbnail_path, "rb") as f:
thumb_data = f.read()
thumb_object = f"projects/{project_id}/thumbnail.jpg"
upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
project.thumbnail_url = thumb_object
logger.info("Uploaded thumbnail for project_id=%s", project_id)
except Exception as exc: # noqa: BLE001
logger.warning("Thumbnail extraction failed: %s", exc)
except Exception as exc: # noqa: BLE001
logger.error("Frame extraction failed: %s", exc)
shutil.rmtree(tmp_dir, ignore_errors=True)
raise HTTPException(status_code=500, detail="Frame extraction failed") from exc
# Upload frames to MinIO
try:
object_names = upload_frames_to_minio(frame_files, project_id)
except Exception as exc: # noqa: BLE001
logger.error("Frame upload failed: %s", exc)
shutil.rmtree(tmp_dir, ignore_errors=True)
raise HTTPException(status_code=500, detail="Frame upload to storage failed") from exc
# Register frames in DB
frames_out = []
for idx, obj_name in enumerate(object_names):
local_frame = frame_files[idx]
try:
import cv2
img = cv2.imread(local_frame)
h, w = img.shape[:2] if img is not None else (None, None)
except Exception: # noqa: BLE001
h, w = None, None
frame = Frame(
project_id=project_id,
frame_index=idx,
image_url=obj_name,
width=w,
height=h,
)
db.add(frame)
frames_out.append(frame)
task = ProcessingTask(
task_type=f"parse_{effective_source}",
status=TASK_STATUS_QUEUED,
progress=0,
message="解析任务已入队",
project_id=project_id,
payload={"source_type": effective_source},
)
project.status = PROJECT_STATUS_PARSING
db.add(task)
db.commit()
for f in frames_out:
db.refresh(f)
db.refresh(task)
publish_task_progress_event(task)
# Cleanup temp files
shutil.rmtree(tmp_dir, ignore_errors=True)
project.status = "ready"
async_result = parse_project_media.delay(task.id)
task.celery_task_id = async_result.id
db.commit()
db.refresh(task)
logger.info("Parsed %d frames for project_id=%s", len(frames_out), project_id)
return {
"project_id": project_id,
"frames_extracted": len(frames_out),
"status": "ready",
"message": "Frame extraction completed successfully.",
}
logger.info("Queued parse task id=%s project_id=%s celery_id=%s", task.id, project_id, async_result.id)
return task

37
backend/routers/tasks.py Normal file
View File

@@ -0,0 +1,37 @@
"""Processing task query endpoints."""
from typing import List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from database import get_db
from models import ProcessingTask
from schemas import ProcessingTaskOut
router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
@router.get("", response_model=List[ProcessingTaskOut], summary="List processing tasks")
def list_tasks(
project_id: int | None = None,
status: str | None = None,
limit: int = 50,
db: Session = Depends(get_db),
) -> List[ProcessingTask]:
"""Return recent background processing tasks."""
query = db.query(ProcessingTask)
if project_id is not None:
query = query.filter(ProcessingTask.project_id == project_id)
if status is not None:
query = query.filter(ProcessingTask.status == status)
return query.order_by(ProcessingTask.created_at.desc()).limit(limit).all()
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
def get_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
"""Return one background task by id."""
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return task

View File

@@ -18,9 +18,9 @@ def _pack_mapping_rules(data: dict) -> dict:
"""Pack classes/rules into mapping_rules for DB storage."""
mapping = data.get("mapping_rules") or {}
if "classes" in data and data["classes"] is not None:
mapping["classes"] = data["classes"]
mapping["classes"] = data.pop("classes")
if "rules" in data and data["rules"] is not None:
mapping["rules"] = data["rules"]
mapping["rules"] = data.pop("rules")
data["mapping_rules"] = mapping
return data

View File

@@ -70,6 +70,7 @@ class FrameOut(FrameBase):
# ---------------------------------------------------------------------------
class TemplateBase(BaseModel):
name: str
description: Optional[str] = None
color: str
z_index: int = 0
mapping_rules: Optional[dict[str, Any]] = None
@@ -83,6 +84,7 @@ class TemplateCreate(TemplateBase):
class TemplateUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
color: Optional[str] = None
z_index: Optional[int] = None
mapping_rules: Optional[dict[str, Any]] = None
@@ -115,7 +117,7 @@ class AnnotationCreate(AnnotationBase):
class AnnotationUpdate(BaseModel):
mask_data: Optional[dict[str, Any]] = None
points: Optional[list[float]] = None
points: Optional[list[list[float]]] = None
bbox: Optional[list[float]] = None
template_id: Optional[int] = None
@@ -148,6 +150,28 @@ class MaskOut(MaskBase):
created_at: datetime
# ---------------------------------------------------------------------------
# Processing task schemas
# ---------------------------------------------------------------------------
class ProcessingTaskOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
task_type: str
status: str
progress: int
message: Optional[str] = None
project_id: Optional[int] = None
celery_task_id: Optional[str] = None
payload: Optional[dict[str, Any]] = None
result: Optional[dict[str, Any]] = None
error: Optional[str] = None
created_at: datetime
started_at: Optional[datetime] = None
finished_at: Optional[datetime] = None
updated_at: datetime
# ---------------------------------------------------------------------------
# AI schemas
# ---------------------------------------------------------------------------
@@ -155,6 +179,7 @@ class PredictRequest(BaseModel):
image_id: int
prompt_type: str # point / box / semantic
prompt_data: Any
model: Optional[str] = None
class PredictResponse(BaseModel):
@@ -162,6 +187,37 @@ class PredictResponse(BaseModel):
scores: Optional[list[float]] = None
class AiModelStatus(BaseModel):
id: str
label: str
available: bool
loaded: bool = False
device: str
supports: list[str]
message: str
package_available: bool = False
checkpoint_exists: bool = False
checkpoint_path: Optional[str] = None
python_ok: bool = True
torch_ok: bool = True
cuda_required: bool = False
class GpuStatus(BaseModel):
available: bool
device: str
name: Optional[str] = None
torch_available: bool
torch_version: Optional[str] = None
cuda_version: Optional[str] = None
class AiRuntimeStatus(BaseModel):
selected_model: str
gpu: GpuStatus
models: list[AiModelStatus]
# ---------------------------------------------------------------------------
# Export schemas
# ---------------------------------------------------------------------------

View File

@@ -0,0 +1,220 @@
"""Background media parsing runner used by Celery workers."""
import logging
import os
import shutil
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from sqlalchemy.orm import Session
from minio_client import BUCKET_NAME, download_file, get_minio_client, upload_file
from models import Frame, ProcessingTask, Project
from progress_events import publish_task_progress_event
from services.frame_parser import (
extract_thumbnail,
parse_dicom,
parse_video,
upload_frames_to_minio,
)
from statuses import (
PROJECT_STATUS_ERROR,
PROJECT_STATUS_PARSING,
PROJECT_STATUS_READY,
TASK_STATUS_FAILED,
TASK_STATUS_RUNNING,
TASK_STATUS_SUCCESS,
)
logger = logging.getLogger(__name__)
def _now() -> datetime:
return datetime.now(timezone.utc)
def _set_task_state(
db: Session,
task: ProcessingTask,
*,
status: str | None = None,
progress: int | None = None,
message: str | None = None,
result: dict[str, Any] | None = None,
error: str | None = None,
started: bool = False,
finished: bool = False,
) -> None:
if status is not None:
task.status = status
if progress is not None:
task.progress = max(0, min(100, progress))
if message is not None:
task.message = message
if result is not None:
task.result = result
if error is not None:
task.error = error
if started:
task.started_at = _now()
if finished:
task.finished_at = _now()
db.commit()
db.refresh(task)
publish_task_progress_event(task)
def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
"""Parse one project's media and update task progress in the database."""
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
if not task:
raise ValueError(f"Task not found: {task_id}")
if task.project_id is None:
_set_task_state(
db,
task,
status=TASK_STATUS_FAILED,
progress=100,
message="任务缺少 project_id",
error="Task has no project_id",
finished=True,
)
raise ValueError("Task has no project_id")
project = db.query(Project).filter(Project.id == task.project_id).first()
if not project:
_set_task_state(
db,
task,
status=TASK_STATUS_FAILED,
progress=100,
message="项目不存在",
error="Project not found",
finished=True,
)
raise ValueError(f"Project not found: {task.project_id}")
if not project.video_path:
_set_task_state(
db,
task,
status=TASK_STATUS_FAILED,
progress=100,
message="项目没有可解析媒体",
error="Project has no media uploaded",
finished=True,
)
project.status = PROJECT_STATUS_ERROR
db.commit()
raise ValueError("Project has no media uploaded")
project.status = PROJECT_STATUS_PARSING
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True)
effective_source = (task.payload or {}).get("source_type") or project.source_type or "video"
parse_fps = project.parse_fps or 30.0
tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project.id}_")
output_dir = os.path.join(tmp_dir, "frames")
os.makedirs(output_dir, exist_ok=True)
try:
_set_task_state(db, task, progress=15, message="正在下载媒体文件")
if effective_source == "dicom":
dcm_dir = os.path.join(tmp_dir, "dcm")
os.makedirs(dcm_dir, exist_ok=True)
client = get_minio_client()
objects = list(client.list_objects(BUCKET_NAME, prefix=project.video_path, recursive=True))
for obj in objects:
if obj.object_name.lower().endswith(".dcm"):
data = download_file(obj.object_name)
local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name))
with open(local_dcm, "wb") as f:
f.write(data)
_set_task_state(db, task, progress=35, message="正在解析 DICOM 序列")
frame_files = parse_dicom(dcm_dir, output_dir)
else:
media_bytes = download_file(project.video_path)
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
with open(local_path, "wb") as f:
f.write(media_bytes)
_set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧")
frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
project.original_fps = original_fps
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
try:
extract_thumbnail(local_path, thumbnail_path)
with open(thumbnail_path, "rb") as f:
thumb_data = f.read()
thumb_object = f"projects/{project.id}/thumbnail.jpg"
upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
project.thumbnail_url = thumb_object
except Exception as exc: # noqa: BLE001
logger.warning("Thumbnail extraction failed: %s", exc)
_set_task_state(db, task, progress=70, message="正在上传帧到对象存储")
object_names = upload_frames_to_minio(frame_files, project.id)
_set_task_state(db, task, progress=85, message="正在写入帧索引")
frames_out = []
for idx, obj_name in enumerate(object_names):
local_frame = frame_files[idx]
try:
import cv2
img = cv2.imread(local_frame)
h, w = img.shape[:2] if img is not None else (None, None)
except Exception: # noqa: BLE001
h, w = None, None
frame = Frame(
project_id=project.id,
frame_index=idx,
image_url=obj_name,
width=w,
height=h,
)
db.add(frame)
frames_out.append(frame)
project.status = PROJECT_STATUS_READY
db.commit()
result = {
"project_id": project.id,
"frames_extracted": len(frames_out),
"status": PROJECT_STATUS_READY,
"message": "Frame extraction completed successfully.",
}
_set_task_state(
db,
task,
status=TASK_STATUS_SUCCESS,
progress=100,
message="解析完成",
result=result,
finished=True,
)
logger.info("Parsed %d frames for project_id=%s", len(frames_out), project.id)
return result
except Exception as exc: # noqa: BLE001
project.status = PROJECT_STATUS_ERROR
_set_task_state(
db,
task,
status=TASK_STATUS_FAILED,
progress=100,
message="解析失败",
error=str(exc),
finished=True,
)
logger.error("Frame extraction failed: %s", exc)
raise
finally:
shutil.rmtree(tmp_dir, ignore_errors=True)

View File

@@ -1,4 +1,4 @@
"""SAM 2 engine wrapper with lazy loading and fallback stubs."""
"""SAM 2 engine wrapper with lazy loading and explicit runtime status."""
import logging
import os
@@ -11,10 +11,18 @@ from config import settings
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Attempt to import SAM 2; fall back to stubs if unavailable.
# Attempt to import PyTorch and SAM 2; fall back to stubs if unavailable.
# ---------------------------------------------------------------------------
try:
import torch
TORCH_AVAILABLE = True
except Exception as exc: # noqa: BLE001
TORCH_AVAILABLE = False
torch = None # type: ignore[assignment]
logger.warning("PyTorch import failed (%s). SAM2 will be unavailable.", exc)
try:
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
@@ -31,6 +39,8 @@ class SAM2Engine:
def __init__(self) -> None:
self._predictor: Optional[SAM2ImagePredictor] = None
self._model_loaded = False
self._loaded_device: str | None = None
self._last_error: str | None = None
# -----------------------------------------------------------------------
# Internal helpers
@@ -40,34 +50,87 @@ class SAM2Engine:
if self._model_loaded:
return
if not TORCH_AVAILABLE:
self._last_error = "PyTorch is not installed."
logger.warning("PyTorch not available; skipping SAM2 model load.")
self._model_loaded = True
return
if not SAM2_AVAILABLE:
self._last_error = "sam2 package is not installed."
logger.warning("SAM2 not available; skipping model load.")
self._model_loaded = True
return
if not os.path.isfile(settings.sam_model_path):
self._last_error = f"SAM2 checkpoint not found: {settings.sam_model_path}"
logger.error("SAM checkpoint not found at %s", settings.sam_model_path)
self._model_loaded = True
return
try:
device = self._best_device()
model = build_sam2(
settings.sam_model_config,
settings.sam_model_path,
device="cuda",
device=device,
)
self._predictor = SAM2ImagePredictor(model)
self._model_loaded = True
logger.info("SAM 2 model loaded from %s", settings.sam_model_path)
self._loaded_device = device
self._last_error = None
logger.info("SAM 2 model loaded from %s on %s", settings.sam_model_path, device)
except Exception as exc: # noqa: BLE001
self._last_error = str(exc)
logger.error("Failed to load SAM 2 model: %s", exc)
self._model_loaded = True # Prevent repeated load attempts
def _best_device(self) -> str:
if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available():
return "cuda"
return "cpu"
def _ensure_ready(self) -> bool:
"""Ensure the model is loaded; return whether it is usable."""
self._load_model()
return SAM2_AVAILABLE and self._predictor is not None
def status(self) -> dict:
"""Return lightweight, real runtime status without forcing model load."""
checkpoint_exists = os.path.isfile(settings.sam_model_path)
device = self._loaded_device or self._best_device()
available = bool(TORCH_AVAILABLE and SAM2_AVAILABLE and checkpoint_exists)
if self._predictor is not None:
message = "SAM 2 model loaded and ready."
elif available:
message = "SAM 2 dependencies and checkpoint are present; model will load on first inference."
else:
missing = []
if not TORCH_AVAILABLE:
missing.append("PyTorch")
if not SAM2_AVAILABLE:
missing.append("sam2 package")
if not checkpoint_exists:
missing.append("checkpoint")
message = f"SAM 2 unavailable: missing {', '.join(missing)}."
if self._last_error and not self._predictor:
message = self._last_error
return {
"id": "sam2",
"label": "SAM 2",
"available": available,
"loaded": self._predictor is not None,
"device": device,
"supports": ["point", "box", "auto"],
"message": message,
"package_available": SAM2_AVAILABLE,
"checkpoint_exists": checkpoint_exists,
"checkpoint_path": settings.sam_model_path,
"python_ok": True,
"torch_ok": TORCH_AVAILABLE,
"cuda_required": False,
}
# -----------------------------------------------------------------------
# Public API
# -----------------------------------------------------------------------

View File

@@ -0,0 +1,148 @@
"""SAM 3 engine adapter and runtime status.
The official facebookresearch/sam3 package currently targets Python 3.12+
and CUDA-capable PyTorch. This adapter reports those requirements honestly and
only performs inference when the local runtime can actually import and execute
the package.
"""
from __future__ import annotations
import importlib.util
import logging
import sys
from typing import Any
import numpy as np
from PIL import Image
from config import settings
from services.sam2_engine import SAM2Engine
logger = logging.getLogger(__name__)
try:
import torch
TORCH_AVAILABLE = True
except Exception as exc: # noqa: BLE001
TORCH_AVAILABLE = False
torch = None # type: ignore[assignment]
logger.warning("PyTorch import failed (%s). SAM3 will be unavailable.", exc)
SAM3_PACKAGE_AVAILABLE = importlib.util.find_spec("sam3") is not None
class SAM3Engine:
"""Lazy SAM 3 image inference adapter."""
def __init__(self) -> None:
self._model: Any | None = None
self._processor: Any | None = None
self._model_loaded = False
self._last_error: str | None = None
def _python_ok(self) -> bool:
return sys.version_info >= (3, 12)
def _gpu_ok(self) -> bool:
return bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available())
def _can_load(self) -> bool:
return bool(SAM3_PACKAGE_AVAILABLE and TORCH_AVAILABLE and self._python_ok() and self._gpu_ok())
def _load_model(self) -> None:
if self._model_loaded:
return
if not self._can_load():
self._last_error = self._status_message()
self._model_loaded = True
return
try:
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.model_builder import build_sam3_image_model
self._model = build_sam3_image_model()
self._processor = Sam3Processor(self._model)
self._model_loaded = True
self._last_error = None
logger.info("SAM 3 image model loaded with version setting %s", settings.sam3_model_version)
except Exception as exc: # noqa: BLE001
self._last_error = str(exc)
self._model_loaded = True
logger.error("Failed to load SAM 3 model: %s", exc)
def _ensure_ready(self) -> bool:
self._load_model()
return self._processor is not None
def _status_message(self) -> str:
missing = []
if not SAM3_PACKAGE_AVAILABLE:
missing.append("sam3 package")
if not self._python_ok():
missing.append("Python 3.12+ runtime")
if not TORCH_AVAILABLE:
missing.append("PyTorch")
if not self._gpu_ok():
missing.append("CUDA GPU")
if missing:
return f"SAM 3 unavailable: missing {', '.join(missing)}."
return "SAM 3 dependencies are present; model will load on first inference."
def status(self) -> dict:
available = self._can_load()
return {
"id": "sam3",
"label": "SAM 3",
"available": available,
"loaded": self._processor is not None,
"device": "cuda" if self._gpu_ok() else "unavailable",
"supports": ["semantic"],
"message": "SAM 3 model loaded and ready." if self._processor is not None else (self._last_error or self._status_message()),
"package_available": SAM3_PACKAGE_AVAILABLE,
"checkpoint_exists": SAM3_PACKAGE_AVAILABLE,
"checkpoint_path": f"official/HuggingFace ({settings.sam3_model_version})",
"python_ok": self._python_ok(),
"torch_ok": TORCH_AVAILABLE,
"cuda_required": True,
}
def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
if not text.strip():
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
if not self._ensure_ready():
raise RuntimeError(self.status()["message"])
pil_image = Image.fromarray(image)
with torch.inference_mode(): # type: ignore[union-attr]
state = self._processor.set_image(pil_image)
output = self._processor.set_text_prompt(state=state, prompt=text.strip())
masks = output.get("masks", [])
scores = output.get("scores", [])
polygons = []
for mask in masks:
if hasattr(mask, "detach"):
mask = mask.detach().cpu().numpy()
if mask.ndim == 3:
mask = mask[0]
poly = SAM2Engine._mask_to_polygon(mask)
if poly:
polygons.append(poly)
if hasattr(scores, "detach"):
scores = scores.detach().cpu().tolist()
elif hasattr(scores, "tolist"):
scores = scores.tolist()
return polygons, list(scores)
def predict_points(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for point prompts.")
def predict_box(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for box prompts.")
sam3_engine = SAM3Engine()

View File

@@ -0,0 +1,80 @@
"""Model registry for SAM runtimes and GPU status."""
from __future__ import annotations
from typing import Any
from config import settings
from services.sam2_engine import TORCH_AVAILABLE, sam_engine as sam2_engine
from services.sam3_engine import sam3_engine
try:
import torch
except Exception: # noqa: BLE001
torch = None # type: ignore[assignment]
class ModelUnavailableError(RuntimeError):
"""Raised when a selected model cannot run in this environment."""
class SAMRegistry:
"""Dispatch predictions to the selected SAM backend."""
def __init__(self) -> None:
self._engines = {
"sam2": sam2_engine,
"sam3": sam3_engine,
}
def normalize_model_id(self, model_id: str | None) -> str:
selected = (model_id or settings.sam_default_model or "sam2").lower()
if selected not in self._engines:
raise ValueError(f"Unsupported model: {model_id}")
return selected
def runtime_status(self, selected_model: str | None = None) -> dict[str, Any]:
return {
"selected_model": self.normalize_model_id(selected_model),
"gpu": self.gpu_status(),
"models": [engine.status() for engine in self._engines.values()],
}
def gpu_status(self) -> dict[str, Any]:
cuda_available = bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available())
return {
"available": cuda_available,
"device": "cuda" if cuda_available else "cpu",
"name": torch.cuda.get_device_name(0) if cuda_available else None,
"torch_available": bool(TORCH_AVAILABLE),
"torch_version": getattr(torch, "__version__", None) if torch is not None else None,
"cuda_version": getattr(torch.version, "cuda", None) if torch is not None else None,
}
def _engine(self, model_id: str | None) -> Any:
return self._engines[self.normalize_model_id(model_id)]
def _ensure_available(self, model_id: str | None) -> Any:
engine = self._engine(model_id)
status = engine.status()
if not status["available"]:
raise ModelUnavailableError(status["message"])
return engine
def predict_points(self, model_id: str | None, image: Any, points: list[list[float]], labels: list[int]):
return self._ensure_available(model_id).predict_points(image, points, labels)
def predict_box(self, model_id: str | None, image: Any, box: list[float]):
return self._ensure_available(model_id).predict_box(image, box)
def predict_auto(self, model_id: str | None, image: Any):
return self._ensure_available(model_id).predict_auto(image)
def predict_semantic(self, model_id: str | None, image: Any, text: str):
model = self.normalize_model_id(model_id)
if model == "sam3":
return self._ensure_available(model).predict_semantic(image, text)
return self._ensure_available(model).predict_auto(image)
sam_registry = SAMRegistry()

11
backend/statuses.py Normal file
View File

@@ -0,0 +1,11 @@
"""Shared status constants used across backend project/task flows."""
PROJECT_STATUS_PENDING = "pending"
PROJECT_STATUS_PARSING = "parsing"
PROJECT_STATUS_READY = "ready"
PROJECT_STATUS_ERROR = "error"
TASK_STATUS_QUEUED = "queued"
TASK_STATUS_RUNNING = "running"
TASK_STATUS_SUCCESS = "success"
TASK_STATUS_FAILED = "failed"

72
backend/tests/conftest.py Normal file
View File

@@ -0,0 +1,72 @@
"""Shared pytest fixtures for backend API tests."""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Iterator
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
BACKEND_DIR = Path(__file__).resolve().parents[1]
if str(BACKEND_DIR) not in sys.path:
sys.path.insert(0, str(BACKEND_DIR))
from database import Base, get_db # noqa: E402
from main import websocket_progress # noqa: E402
from routers import ai, auth, dashboard, export, media, projects, tasks, templates # noqa: E402
@pytest.fixture()
def db_session() -> Iterator[Session]:
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
session = TestingSessionLocal()
try:
yield session
finally:
session.close()
Base.metadata.drop_all(bind=engine)
engine.dispose()
@pytest.fixture()
def app(db_session: Session) -> FastAPI:
test_app = FastAPI()
def override_get_db() -> Iterator[Session]:
yield db_session
test_app.dependency_overrides[get_db] = override_get_db
test_app.include_router(auth.router)
test_app.include_router(projects.router)
test_app.include_router(templates.router)
test_app.include_router(media.router)
test_app.include_router(ai.router)
test_app.include_router(export.router)
test_app.include_router(dashboard.router)
test_app.include_router(tasks.router)
@test_app.get("/health")
def health_check() -> dict[str, str]:
return {"status": "ok", "service": "SegServer"}
test_app.add_api_websocket_route("/ws/progress", websocket_progress)
return test_app
@pytest.fixture()
def client(app: FastAPI) -> Iterator[TestClient]:
with TestClient(app) as test_client:
yield test_client

248
backend/tests/test_ai.py Normal file
View File

@@ -0,0 +1,248 @@
import numpy as np
def _create_project_and_frame(client):
project = client.post("/api/projects", json={"name": "AI Project"}).json()
frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": 0,
"image_url": "frames/0.jpg",
"width": 640,
"height": 360,
}).json()
template = client.post("/api/templates", json={
"name": "Template",
"color": "#06b6d4",
"z_index": 0,
"classes": [],
"rules": [],
}).json()
return project, frame, template
def test_predict_accepts_point_object_with_labels(client, monkeypatch):
_, frame, _ = _create_project_and_frame(client)
calls = {}
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
def fake_predict_points(image, points, labels):
calls["args"] = (points, labels)
return (
[[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
[0.95],
)
monkeypatch.setattr("routers.ai.sam_registry.predict_points", lambda model, image, points, labels: fake_predict_points(image, points, labels))
response = client.post("/api/ai/predict", json={
"image_id": frame["id"],
"prompt_type": "point",
"prompt_data": {"points": [[0.5, 0.5], [0.1, 0.1]], "labels": [1, 0]},
})
assert response.status_code == 200
assert response.json()["scores"] == [0.95]
assert calls["args"] == ([[0.5, 0.5], [0.1, 0.1]], [1, 0])
def test_predict_box_and_semantic_fallback(client, monkeypatch):
_, frame, _ = _create_project_and_frame(client)
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
monkeypatch.setattr("routers.ai.sam_registry.predict_box", lambda model, image, box: (
[[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
[0.8],
))
monkeypatch.setattr("routers.ai.sam_registry.predict_semantic", lambda model, image, text: (
[[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]],
[0.5],
))
box_response = client.post("/api/ai/predict", json={
"image_id": frame["id"],
"prompt_type": "box",
"prompt_data": [0.2, 0.2, 0.8, 0.8],
})
semantic_response = client.post("/api/ai/predict", json={
"image_id": frame["id"],
"prompt_type": "semantic",
"prompt_data": "胆囊",
})
assert box_response.status_code == 200
assert box_response.json()["scores"] == [0.8]
assert semantic_response.status_code == 200
assert semantic_response.json()["scores"] == [0.5]
def test_model_status_reports_runtime(client, monkeypatch):
monkeypatch.setattr("routers.ai.sam_registry.runtime_status", lambda selected_model=None: {
"selected_model": selected_model or "sam2",
"gpu": {
"available": False,
"device": "cpu",
"name": None,
"torch_available": True,
"torch_version": "2.x",
"cuda_version": None,
},
"models": [
{
"id": "sam2",
"label": "SAM 2",
"available": True,
"loaded": False,
"device": "cpu",
"supports": ["point", "box", "auto"],
"message": "ready",
"package_available": True,
"checkpoint_exists": True,
"checkpoint_path": "model.pt",
"python_ok": True,
"torch_ok": True,
"cuda_required": False,
},
{
"id": "sam3",
"label": "SAM 3",
"available": False,
"loaded": False,
"device": "unavailable",
"supports": ["semantic"],
"message": "missing Python 3.12+ runtime",
"package_available": False,
"checkpoint_exists": False,
"checkpoint_path": None,
"python_ok": False,
"torch_ok": True,
"cuda_required": True,
},
],
})
response = client.get("/api/ai/models/status?selected_model=sam3")
assert response.status_code == 200
body = response.json()
assert body["selected_model"] == "sam3"
assert body["models"][1]["id"] == "sam3"
assert body["models"][1]["available"] is False
def test_predict_validation_errors(client, monkeypatch):
project, _, _ = _create_project_and_frame(client)
assert client.post("/api/ai/predict", json={
"image_id": 999,
"prompt_type": "point",
"prompt_data": [[0.5, 0.5]],
}).status_code == 404
frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": 1,
"image_url": "frames/1.jpg",
}).json()
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
assert client.post("/api/ai/predict", json={
"image_id": frame["id"],
"prompt_type": "box",
"prompt_data": [0.1, 0.2],
}).status_code == 400
def test_save_annotation_validates_project_and_frame(client):
project, frame, template = _create_project_and_frame(client)
saved = client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": frame["id"],
"template_id": template["id"],
"mask_data": {"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]]},
"points": [[0.5, 0.5]],
"bbox": [0.1, 0.1, 0.8, 0.8],
})
assert saved.status_code == 201
assert saved.json()["project_id"] == project["id"]
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
assert listing.status_code == 200
assert listing.json()[0]["id"] == saved.json()["id"]
frame_listing = client.get(f"/api/ai/annotations?project_id={project['id']}&frame_id={frame['id']}")
assert frame_listing.status_code == 200
assert len(frame_listing.json()) == 1
missing_project = client.post("/api/ai/annotate", json={"project_id": 999})
assert missing_project.status_code == 404
missing_frame = client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": 999,
})
assert missing_frame.status_code == 404
missing_project_list = client.get("/api/ai/annotations?project_id=999")
assert missing_project_list.status_code == 404
def test_update_and_delete_annotation(client):
project, frame, template = _create_project_and_frame(client)
saved = client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": frame["id"],
"template_id": template["id"],
"mask_data": {
"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
"label": "AI Mask",
"color": "#06b6d4",
},
"points": [[0.5, 0.5]],
"bbox": [0.1, 0.1, 0.8, 0.8],
}).json()
updated = client.patch(f"/api/ai/annotations/{saved['id']}", json={
"template_id": template["id"],
"mask_data": {
"polygons": [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
"label": "胆囊",
"color": "#ff0000",
"class": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
},
"points": [[0.4, 0.4]],
"bbox": [0.2, 0.2, 0.6, 0.6],
})
assert updated.status_code == 200
body = updated.json()
assert body["mask_data"]["label"] == "胆囊"
assert body["mask_data"]["class"]["id"] == "c1"
assert body["points"] == [[0.4, 0.4]]
assert body["bbox"] == [0.2, 0.2, 0.6, 0.6]
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
assert listing.status_code == 200
assert listing.json()[0]["mask_data"]["class"]["name"] == "胆囊"
deleted = client.delete(f"/api/ai/annotations/{saved['id']}")
assert deleted.status_code == 204
empty_listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
assert empty_listing.status_code == 200
assert empty_listing.json() == []
def test_update_and_delete_annotation_validation(client):
project, frame, template = _create_project_and_frame(client)
saved = client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": frame["id"],
"template_id": template["id"],
}).json()
assert client.patch("/api/ai/annotations/999", json={"bbox": [0, 0, 1, 1]}).status_code == 404
assert client.delete("/api/ai/annotations/999").status_code == 404
assert client.patch(
f"/api/ai/annotations/{saved['id']}",
json={"template_id": 999},
).status_code == 404

View File

@@ -0,0 +1,15 @@
def test_login_success(client):
response = client.post("/api/auth/login", json={"username": "admin", "password": "123456"})
assert response.status_code == 200
assert response.json() == {
"token": "fake-jwt-token-for-admin",
"username": "admin",
}
def test_login_rejects_invalid_credentials(client):
response = client.post("/api/auth/login", json={"username": "admin", "password": "wrong"})
assert response.status_code == 401
assert response.json()["detail"] == "Invalid credentials"

View File

@@ -0,0 +1,69 @@
def test_dashboard_overview_uses_persisted_records(client, db_session):
from models import ProcessingTask
project_pending = client.post("/api/projects", json={
"name": "Pending Project",
"status": "pending",
}).json()
project_ready = client.post("/api/projects", json={
"name": "Ready Project",
"status": "ready",
}).json()
frame = client.post(f"/api/projects/{project_pending['id']}/frames", json={
"project_id": project_pending["id"],
"frame_index": 0,
"image_url": "frames/0.jpg",
"width": 640,
"height": 360,
}).json()
template = client.post("/api/templates", json={
"name": "Dashboard Template",
"color": "#06b6d4",
"z_index": 0,
"classes": [],
"rules": [],
}).json()
annotation = client.post("/api/ai/annotate", json={
"project_id": project_pending["id"],
"frame_id": frame["id"],
"template_id": template["id"],
"mask_data": {"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]]},
})
assert annotation.status_code == 201
task = ProcessingTask(
task_type="parse_video",
status="running",
progress=35,
message="正在使用 FFmpeg/OpenCV 拆帧",
project_id=project_pending["id"],
payload={"source_type": "video"},
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
response = client.get("/api/dashboard/overview")
assert response.status_code == 200
body = response.json()
assert body["summary"]["project_count"] == 2
assert body["summary"]["frame_count"] == 1
assert body["summary"]["annotation_count"] == 1
assert body["summary"]["template_count"] == 1
assert body["summary"]["parsing_task_count"] == 1
assert body["tasks"] == [
{
"id": f"task-{task.id}",
"task_id": task.id,
"project_id": project_pending["id"],
"name": "Pending Project",
"progress": 35,
"status": "正在使用 FFmpeg/OpenCV 拆帧",
"frame_count": 0,
"updated_at": body["tasks"][0]["updated_at"],
},
]
assert any(item["kind"] == "task" for item in body["activity"])
assert any(item["kind"] == "annotation" for item in body["activity"])
assert any(item["kind"] == "template" for item in body["activity"])
assert all(item["name"] != "Ready Project" for item in body["tasks"])

View File

@@ -0,0 +1,66 @@
import zipfile
from io import BytesIO
def _seed_export_data(client):
project = client.post("/api/projects", json={"name": "Export Project"}).json()
frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": 0,
"image_url": "frames/0.jpg",
"width": 100,
"height": 50,
}).json()
template = client.post("/api/templates", json={
"name": "Category",
"color": "#06b6d4",
"z_index": 0,
"classes": [],
"rules": [],
}).json()
annotation = client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": frame["id"],
"template_id": template["id"],
"mask_data": {"polygons": [[[0.1, 0.2], [0.9, 0.2], [0.9, 0.8], [0.1, 0.8]]]},
"points": [[0.5, 0.5]],
"bbox": [0.1, 0.2, 0.8, 0.6],
}).json()
return project, frame, template, annotation
def test_export_coco_json_structure(client):
project, frame, _, _ = _seed_export_data(client)
response = client.get(f"/api/export/{project['id']}/coco")
assert response.status_code == 200
assert response.headers["content-type"].startswith("application/json")
data = response.json()
assert data["info"]["description"] == "Annotations for Export Project"
assert data["images"][0] == {
"id": frame["id"],
"file_name": "frames/0.jpg",
"width": 100,
"height": 50,
"frame_index": 0,
}
assert data["annotations"][0]["segmentation"] == [[10.0, 10.0, 90.0, 10.0, 90.0, 40.0, 10.0, 40.0]]
assert data["annotations"][0]["bbox"] == [10.0, 10.0, 80.0, 30.000000000000004]
assert data["categories"][0]["name"] == "Category"
def test_export_masks_zip(client):
project, _, _, annotation = _seed_export_data(client)
response = client.get(f"/api/export/{project['id']}/masks")
assert response.status_code == 200
assert response.headers["content-type"].startswith("application/zip")
with zipfile.ZipFile(BytesIO(response.content)) as archive:
assert archive.namelist() == [f"mask_{annotation['id']:06d}.png"]
def test_export_missing_project_returns_404(client):
assert client.get("/api/export/999/coco").status_code == 404
assert client.get("/api/export/999/masks").status_code == 404

View File

@@ -0,0 +1,15 @@
def test_health_endpoint(client):
response = client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok", "service": "SegServer"}
def test_websocket_progress_heartbeat(client):
with client.websocket_connect("/ws/progress") as websocket:
websocket.send_text("ping")
data = websocket.receive_json()
assert data["type"] == "status"
assert data["status"] == "connected"
assert data["message"] == "Progress stream active"

142
backend/tests/test_media.py Normal file
View File

@@ -0,0 +1,142 @@
def test_upload_rejects_unsupported_file_type(client):
response = client.post(
"/api/media/upload",
files={"file": ("notes.txt", b"text", "text/plain")},
)
assert response.status_code == 400
assert "Unsupported file type" in response.json()["detail"]
def test_upload_auto_creates_project(client, monkeypatch):
uploaded = []
monkeypatch.setattr("routers.media.upload_file", lambda object_name, data, content_type, length: uploaded.append(object_name))
monkeypatch.setattr("routers.media.get_presigned_url", lambda object_name, expires=3600: f"http://storage/{object_name}")
response = client.post(
"/api/media/upload",
files={"file": ("clip.mp4", b"video", "video/mp4")},
)
assert response.status_code == 201
data = response.json()
assert data["project_id"] is not None
assert data["object_name"] == f"uploads/{data['project_id']}/clip.mp4"
assert uploaded == ["uploads/general/clip.mp4", f"uploads/{data['project_id']}/clip.mp4"]
def test_upload_links_existing_project(client, monkeypatch):
project = client.post("/api/projects", json={"name": "Existing"}).json()
monkeypatch.setattr("routers.media.upload_file", lambda *args, **kwargs: None)
monkeypatch.setattr("routers.media.get_presigned_url", lambda object_name, expires=3600: f"http://storage/{object_name}")
response = client.post(
"/api/media/upload",
data={"project_id": str(project["id"])},
files={"file": ("clip.mp4", b"video", "video/mp4")},
)
assert response.status_code == 201
detail = client.get(f"/api/projects/{project['id']}").json()
assert detail["video_path"] == f"uploads/{project['id']}/clip.mp4"
def test_upload_dicom_batch_filters_files_and_creates_project(client, monkeypatch):
uploaded = []
monkeypatch.setattr("routers.media.upload_file", lambda object_name, data, content_type, length: uploaded.append(object_name))
response = client.post(
"/api/media/upload/dicom",
files=[
("files", ("a.dcm", b"dcm", "application/dicom")),
("files", ("skip.txt", b"text", "text/plain")),
],
)
assert response.status_code == 201
data = response.json()
assert data["uploaded_count"] == 1
assert uploaded == [f"uploads/{data['project_id']}/dicom/a.dcm"]
def test_parse_media_queues_background_task(client, monkeypatch):
project = client.post("/api/projects", json={
"name": "Parse Me",
"video_path": "uploads/1/clip.mp4",
"source_type": "video",
"parse_fps": 5,
}).json()
class FakeAsyncResult:
id = "celery-1"
queued = []
monkeypatch.setattr("routers.media.parse_project_media.delay", lambda task_id: queued.append(task_id) or FakeAsyncResult())
published = []
monkeypatch.setattr("routers.media.publish_task_progress_event", lambda task: published.append(task.id))
response = client.post(f"/api/media/parse?project_id={project['id']}")
assert response.status_code == 202
data = response.json()
assert data["task_type"] == "parse_video"
assert data["status"] == "queued"
assert data["progress"] == 0
assert data["project_id"] == project["id"]
assert data["celery_task_id"] == "celery-1"
assert queued == [data["id"]]
assert published == [data["id"]]
detail = client.get(f"/api/tasks/{data['id']}")
assert detail.status_code == 200
assert detail.json()["status"] == "queued"
project_detail = client.get(f"/api/projects/{project['id']}").json()
assert project_detail["status"] == "parsing"
def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp_path):
from models import ProcessingTask
from services.media_task_runner import run_parse_media_task
project = client.post("/api/projects", json={
"name": "Parse Me",
"video_path": "uploads/1/clip.mp4",
"source_type": "video",
"parse_fps": 5,
}).json()
task = ProcessingTask(
task_type="parse_video",
status="queued",
progress=0,
project_id=project["id"],
payload={"source_type": "video"},
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
frame_file = tmp_path / "frame_000001.jpg"
frame_file.write_bytes(b"fake image")
monkeypatch.setattr("services.media_task_runner.download_file", lambda object_name: b"video")
monkeypatch.setattr("services.media_task_runner.parse_video", lambda local_path, output_dir, fps: ([str(frame_file)], 25.0))
monkeypatch.setattr("services.media_task_runner.extract_thumbnail", lambda local_path, thumbnail_path: open(thumbnail_path, "wb").write(b"thumb"))
monkeypatch.setattr("services.media_task_runner.upload_file", lambda *args, **kwargs: None)
monkeypatch.setattr("services.media_task_runner.upload_frames_to_minio", lambda frame_files, project_id: [f"projects/{project_id}/frames/frame_000001.jpg"])
published = []
monkeypatch.setattr(
"services.media_task_runner.publish_task_progress_event",
lambda event_task: published.append((event_task.status, event_task.progress, event_task.message)),
)
result = run_parse_media_task(db_session, task.id)
assert result["frames_extracted"] == 1
db_session.refresh(task)
assert task.status == "success"
assert task.progress == 100
assert ("running", 5, "后台解析已启动") in published
assert ("success", 100, "解析完成") in published
project_detail = client.get(f"/api/projects/{project['id']}").json()
assert project_detail["status"] == "ready"
frames = client.get(f"/api/projects/{project['id']}/frames").json()
assert "frame_000001.jpg" in frames[0]["image_url"]

View File

@@ -0,0 +1,42 @@
from types import SimpleNamespace
from progress_events import PROGRESS_CHANNEL, publish_progress_event, task_progress_payload
def test_task_progress_payload_uses_dashboard_task_id_and_project_name():
task = SimpleNamespace(
id=12,
project_id=7,
project=SimpleNamespace(name="demo.mp4"),
status="success",
progress=100,
message="解析完成",
error=None,
updated_at=None,
)
payload = task_progress_payload(task)
assert payload["type"] == "complete"
assert payload["taskId"] == "task-12"
assert payload["task_id"] == 12
assert payload["project_id"] == 7
assert payload["filename"] == "demo.mp4"
assert payload["projectName"] == "demo.mp4"
assert payload["status"] == "解析完成"
def test_publish_progress_event_writes_json_to_redis(monkeypatch):
calls = []
class FakeRedis:
def publish(self, channel, payload):
calls.append((channel, payload))
monkeypatch.setattr("progress_events.get_redis_client", lambda: FakeRedis())
publish_progress_event({"type": "progress", "message": "正在下载媒体文件"})
assert calls
assert calls[0][0] == PROGRESS_CHANNEL
assert "正在下载媒体文件" in calls[0][1]

View File

@@ -0,0 +1,56 @@
def test_project_crud_and_frames(client, monkeypatch):
monkeypatch.setattr("routers.projects.get_presigned_url", lambda key, expires=3600: f"http://storage/{key}")
created = client.post("/api/projects", json={
"name": "Demo",
"description": "desc",
"thumbnail_url": "thumb.jpg",
"parse_fps": 12,
})
assert created.status_code == 201
project_id = created.json()["id"]
frame = client.post(f"/api/projects/{project_id}/frames", json={
"project_id": project_id,
"frame_index": 0,
"image_url": "frames/0.jpg",
"width": 640,
"height": 360,
})
assert frame.status_code == 201
frame_id = frame.json()["id"]
listing = client.get("/api/projects")
assert listing.status_code == 200
assert listing.json()[0]["frame_count"] == 1
assert listing.json()[0]["thumbnail_url"] == "http://storage/thumb.jpg"
frames = client.get(f"/api/projects/{project_id}/frames")
assert frames.status_code == 200
assert frames.json()[0]["image_url"] == "http://storage/frames/0.jpg"
single_frame = client.get(f"/api/projects/{project_id}/frames/{frame_id}")
assert single_frame.status_code == 200
assert single_frame.json()["frame_index"] == 0
updated = client.patch(f"/api/projects/{project_id}", json={"name": "Renamed", "status": "ready"})
assert updated.status_code == 200
assert updated.json()["name"] == "Renamed"
assert updated.json()["status"] == "ready"
deleted = client.delete(f"/api/projects/{project_id}")
assert deleted.status_code == 204
assert client.get(f"/api/projects/{project_id}").status_code == 404
def test_project_and_frame_404s(client):
assert client.get("/api/projects/999").status_code == 404
assert client.patch("/api/projects/999", json={"name": "x"}).status_code == 404
assert client.delete("/api/projects/999").status_code == 404
assert client.post("/api/projects/999/frames", json={
"project_id": 999,
"frame_index": 0,
"image_url": "missing.jpg",
}).status_code == 404
assert client.get("/api/projects/999/frames").status_code == 404
assert client.get("/api/projects/999/frames/1").status_code == 404

View File

@@ -0,0 +1,39 @@
def test_template_crud_packs_and_unpacks_mapping_rules(client):
payload = {
"name": "Template",
"color": "#06b6d4",
"z_index": 0,
"classes": [{"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 10}],
"rules": [{"id": "r1", "name": "rule"}],
}
created = client.post("/api/templates", json=payload)
assert created.status_code == 201
template_id = created.json()["id"]
assert created.json()["classes"][0]["name"] == "胆囊"
assert created.json()["rules"][0]["id"] == "r1"
listing = client.get("/api/templates")
assert listing.status_code == 200
assert listing.json()[0]["classes"][0]["name"] == "胆囊"
detail = client.get(f"/api/templates/{template_id}")
assert detail.status_code == 200
assert detail.json()["name"] == "Template"
updated = client.patch(f"/api/templates/{template_id}", json={
"classes": [{"id": "c2", "name": "肝脏", "color": "#00ff00", "zIndex": 20}],
"rules": [],
})
assert updated.status_code == 200
assert updated.json()["classes"][0]["name"] == "肝脏"
deleted = client.delete(f"/api/templates/{template_id}")
assert deleted.status_code == 204
assert client.get(f"/api/templates/{template_id}").status_code == 404
def test_template_404s(client):
assert client.get("/api/templates/999").status_code == 404
assert client.patch("/api/templates/999", json={"name": "x"}).status_code == 404
assert client.delete("/api/templates/999").status_code == 404

22
backend/worker_tasks.py Normal file
View File

@@ -0,0 +1,22 @@
"""Celery task definitions."""
import logging
from celery_app import celery_app
from database import SessionLocal
from services.media_task_runner import run_parse_media_task
logger = logging.getLogger(__name__)
@celery_app.task(name="media.parse_project")
def parse_project_media(task_id: int) -> dict:
"""Run media parsing for one queued task."""
db = SessionLocal()
try:
return run_parse_media_task(db, task_id)
except Exception as exc: # noqa: BLE001
logger.exception("Parse media task failed: task_id=%s", task_id)
raise exc
finally:
db.close()