feat: 打通全栈标注闭环、异步拆帧与模型状态

后端能力:

- 新增 Celery app、worker task、ProcessingTask 模型、/api/tasks 查询接口和 media_task_runner,将 /api/media/parse 改为创建后台任务并由 worker 执行 FFmpeg/OpenCV/pydicom 拆帧。

- 新增 Redis 进度事件模块和 FastAPI Redis pub/sub 订阅,将 worker 任务进度广播到 /ws/progress;Dashboard 后端概览接口改为聚合 projects/frames/annotations/templates/processing_tasks。

- 统一项目状态为 pending/parsing/ready/error,新增共享 status 常量,并让前端兼容归一化旧状态值。

- 扩展 AI 后端:新增 SAM registry、SAM2 真实运行状态、SAM3 状态检测与文本语义推理适配入口,以及 /api/ai/models/status GPU/模型状态接口。

- 补齐标注保存/更新/删除、COCO/PNG mask 导出相关后端契约和模板 mapping_rules 打包/解包行为。

前端能力:

- 新增运行时 API/WS 地址推导配置,前端 API 封装对齐 FastAPI 路由、字段映射、任务轮询、标注归档、导出下载和 AI 预测响应转换。

- Dashboard 改为读取 /api/dashboard/overview,并订阅 WebSocket progress/complete/error/status 更新解析队列和实时流转记录。

- 项目库导入视频/DICOM 后创建项目、上传媒体、触发异步解析并刷新真实项目列表。

- 工作区加载真实帧、无帧时触发解析任务、回显已保存标注、保存未归档 mask、更新 dirty mask、清空当前帧后端标注、导出 COCO JSON。

- Canvas 支持当前帧点/框提示调用后端 AI、渲染推理/已保存 mask、应用模板分类并维护保存状态计数;时间轴按项目 fps 播放。

- AI 页面新增 SAM2/SAM3 模型选择,预测请求携带 model;侧边栏和工作区新增真实 GPU/SAM 状态徽标。

- 模板库和本体面板接入真实模板 CRUD、分类编辑、拖拽排序、JSON 导入、默认腹腔镜分类和本地自定义分类选择。

测试与文档:

- 新增 Vitest 配置、前端测试 setup、API/config/websocket/store/组件测试,覆盖登录、项目库、Dashboard、Canvas、工作区、模型状态、时间轴、本体和模板库。

- 新增 pytest 后端测试夹具和 auth/projects/templates/media/AI/export/dashboard/tasks/progress 测试,使用 SQLite、fake MinIO、fake SAM registry 和 Redis monkeypatch 隔离外部服务。

- 新增 doc/ 文档结构,冻结当前需求、设计、接口契约、测试计划、前端逐元素审计、实现地图和后续实施计划,并同步更新 README 与 AGENTS。

验证:

- conda run -n seg_server pytest backend/tests:27 passed。

- npm run test:run:54 passed。

- npm run lint、npm run build、compileall、git diff --check 均通过;Vite 仅提示大 chunk 警告。
This commit is contained in:
2026-05-01 13:29:14 +08:00
parent 4d65c37c73
commit f020ff3b4f
78 changed files with 7089 additions and 456 deletions

View File

@@ -1,18 +1,25 @@
"""AI inference endpoints using SAM 2."""
"""AI inference endpoints using selectable SAM runtimes."""
import logging
from typing import Any, List
import cv2
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, Response, status
from sqlalchemy.orm import Session
from database import get_db
from minio_client import download_file
from models import Frame, Annotation
from schemas import PredictRequest, PredictResponse, AnnotationOut, AnnotationCreate
from services.sam2_engine import sam_engine
from models import Project, Frame, Template, Annotation
from schemas import (
AiRuntimeStatus,
PredictRequest,
PredictResponse,
AnnotationOut,
AnnotationCreate,
AnnotationUpdate,
)
from services.sam_registry import ModelUnavailableError, sam_registry
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/ai", tags=["AI"])
@@ -35,14 +42,15 @@ def _load_frame_image(frame: Frame) -> np.ndarray:
@router.post(
"/predict",
response_model=PredictResponse,
summary="Run SAM 2 inference with a prompt",
summary="Run SAM inference with a prompt",
)
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
"""Execute SAM 2 segmentation given an image and a prompt.
"""Execute selected SAM segmentation given an image and a prompt.
- **point**: `prompt_data` is a list of `[[x, y], ...]` normalized coordinates.
- **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
coordinates or `{ "points": [[x, y], ...], "labels": [1, 0, ...] }`.
- **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
- **semantic**: Not yet implemented; falls back to auto segmentation.
- **semantic**: SAM 3 text prompt when model=`sam3`; SAM 2 falls back to auto.
"""
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
if not frame:
@@ -54,30 +62,57 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
polygons: List[List[List[float]]] = []
scores: List[float] = []
if prompt_type == "point":
points = payload.prompt_data
if not isinstance(points, list) or len(points) == 0:
raise HTTPException(status_code=400, detail="Invalid point prompt data")
labels = [1] * len(points)
polygons, scores = sam_engine.predict_points(image, points, labels)
try:
if prompt_type == "point":
point_payload = payload.prompt_data
if isinstance(point_payload, dict):
points = point_payload.get("points")
labels = point_payload.get("labels")
else:
points = point_payload
labels = None
elif prompt_type == "box":
box = payload.prompt_data
if not isinstance(box, list) or len(box) != 4:
raise HTTPException(status_code=400, detail="Invalid box prompt data")
polygons, scores = sam_engine.predict_box(image, box)
if not isinstance(points, list) or len(points) == 0:
raise HTTPException(status_code=400, detail="Invalid point prompt data")
if not isinstance(labels, list) or len(labels) != len(points):
labels = [1] * len(points)
polygons, scores = sam_registry.predict_points(payload.model, image, points, labels)
elif prompt_type == "semantic":
# Placeholder: use auto segmentation for now
logger.info("Semantic prompt not implemented; using auto segmentation")
polygons, scores = sam_engine.predict_auto(image)
elif prompt_type == "box":
box = payload.prompt_data
if not isinstance(box, list) or len(box) != 4:
raise HTTPException(status_code=400, detail="Invalid box prompt data")
polygons, scores = sam_registry.predict_box(payload.model, image, box)
else:
raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
elif prompt_type == "semantic":
text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
polygons, scores = sam_registry.predict_semantic(payload.model, image, text)
else:
raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
except ModelUnavailableError as exc:
raise HTTPException(status_code=503, detail=str(exc)) from exc
except NotImplementedError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return {"polygons": polygons, "scores": scores}
@router.get(
"/models/status",
response_model=AiRuntimeStatus,
summary="Get SAM model and GPU runtime status",
)
def model_status(selected_model: str | None = None) -> dict:
"""Return real runtime availability for GPU, SAM 2, and SAM 3."""
try:
return sam_registry.runtime_status(selected_model)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
@router.post(
"/auto",
response_model=PredictResponse,
@@ -90,7 +125,10 @@ def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
raise HTTPException(status_code=404, detail="Frame not found")
image = _load_frame_image(frame)
polygons, scores = sam_engine.predict_auto(image)
try:
polygons, scores = sam_registry.predict_auto(None, image)
except ModelUnavailableError as exc:
raise HTTPException(status_code=503, detail=str(exc)) from exc
return {"polygons": polygons, "scores": scores}
@@ -106,7 +144,7 @@ def save_annotation(
db: Session = Depends(get_db),
) -> Annotation:
"""Persist an annotation (mask, points, bbox) into the database."""
project = db.query(Frame).filter(Frame.id == payload.project_id).first()
project = db.query(Project).filter(Project.id == payload.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
@@ -121,3 +159,74 @@ def save_annotation(
db.refresh(annotation)
logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
return annotation
@router.get(
"/annotations",
response_model=List[AnnotationOut],
summary="List saved annotations for a project",
)
def list_annotations(
project_id: int,
frame_id: int | None = None,
db: Session = Depends(get_db),
) -> List[Annotation]:
"""Return persisted annotations for a project, optionally scoped to one frame."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
query = db.query(Annotation).filter(Annotation.project_id == project_id)
if frame_id is not None:
query = query.filter(Annotation.frame_id == frame_id)
return query.order_by(Annotation.id).all()
@router.patch(
"/annotations/{annotation_id}",
response_model=AnnotationOut,
summary="Update a saved annotation",
)
def update_annotation(
annotation_id: int,
payload: AnnotationUpdate,
db: Session = Depends(get_db),
) -> Annotation:
"""Update mutable annotation fields persisted in the database."""
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
if not annotation:
raise HTTPException(status_code=404, detail="Annotation not found")
updates = payload.model_dump(exclude_unset=True)
if "template_id" in updates and updates["template_id"] is not None:
template = db.query(Template).filter(Template.id == updates["template_id"]).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
for field, value in updates.items():
setattr(annotation, field, value)
db.commit()
db.refresh(annotation)
logger.info("Updated annotation id=%s", annotation.id)
return annotation
@router.delete(
"/annotations/{annotation_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete a saved annotation",
)
def delete_annotation(
annotation_id: int,
db: Session = Depends(get_db),
) -> Response:
"""Delete an annotation and its derived mask rows through ORM cascade."""
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
if not annotation:
raise HTTPException(status_code=404, detail="Annotation not found")
db.delete(annotation)
db.commit()
logger.info("Deleted annotation id=%s", annotation_id)
return Response(status_code=status.HTTP_204_NO_CONTENT)