feat: 打通全栈标注闭环、异步拆帧与模型状态
后端能力: - 新增 Celery app、worker task、ProcessingTask 模型、/api/tasks 查询接口和 media_task_runner,将 /api/media/parse 改为创建后台任务并由 worker 执行 FFmpeg/OpenCV/pydicom 拆帧。 - 新增 Redis 进度事件模块和 FastAPI Redis pub/sub 订阅,将 worker 任务进度广播到 /ws/progress;Dashboard 后端概览接口改为聚合 projects/frames/annotations/templates/processing_tasks。 - 统一项目状态为 pending/parsing/ready/error,新增共享 status 常量,并让前端兼容归一化旧状态值。 - 扩展 AI 后端:新增 SAM registry、SAM2 真实运行状态、SAM3 状态检测与文本语义推理适配入口,以及 /api/ai/models/status GPU/模型状态接口。 - 补齐标注保存/更新/删除、COCO/PNG mask 导出相关后端契约和模板 mapping_rules 打包/解包行为。 前端能力: - 新增运行时 API/WS 地址推导配置,前端 API 封装对齐 FastAPI 路由、字段映射、任务轮询、标注归档、导出下载和 AI 预测响应转换。 - Dashboard 改为读取 /api/dashboard/overview,并订阅 WebSocket progress/complete/error/status 更新解析队列和实时流转记录。 - 项目库导入视频/DICOM 后创建项目、上传媒体、触发异步解析并刷新真实项目列表。 - 工作区加载真实帧、无帧时触发解析任务、回显已保存标注、保存未归档 mask、更新 dirty mask、清空当前帧后端标注、导出 COCO JSON。 - Canvas 支持当前帧点/框提示调用后端 AI、渲染推理/已保存 mask、应用模板分类并维护保存状态计数;时间轴按项目 fps 播放。 - AI 页面新增 SAM2/SAM3 模型选择,预测请求携带 model;侧边栏和工作区新增真实 GPU/SAM 状态徽标。 - 模板库和本体面板接入真实模板 CRUD、分类编辑、拖拽排序、JSON 导入、默认腹腔镜分类和本地自定义分类选择。 测试与文档: - 新增 Vitest 配置、前端测试 setup、API/config/websocket/store/组件测试,覆盖登录、项目库、Dashboard、Canvas、工作区、模型状态、时间轴、本体和模板库。 - 新增 pytest 后端测试夹具和 auth/projects/templates/media/AI/export/dashboard/tasks/progress 测试,使用 SQLite、fake MinIO、fake SAM registry 和 Redis monkeypatch 隔离外部服务。 - 新增 doc/ 文档结构,冻结当前需求、设计、接口契约、测试计划、前端逐元素审计、实现地图和后续实施计划,并同步更新 README 与 AGENTS。 验证: - conda run -n seg_server pytest backend/tests:27 passed。 - npm run test:run:54 passed。 - npm run lint、npm run build、compileall、git diff --check 均通过;Vite 仅提示大 chunk 警告。
This commit is contained in:
@@ -1,18 +1,25 @@
|
||||
"""AI inference endpoints using SAM 2."""
|
||||
"""AI inference endpoints using selectable SAM runtimes."""
|
||||
|
||||
import logging
|
||||
from typing import Any, List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import download_file
|
||||
from models import Frame, Annotation
|
||||
from schemas import PredictRequest, PredictResponse, AnnotationOut, AnnotationCreate
|
||||
from services.sam2_engine import sam_engine
|
||||
from models import Project, Frame, Template, Annotation
|
||||
from schemas import (
|
||||
AiRuntimeStatus,
|
||||
PredictRequest,
|
||||
PredictResponse,
|
||||
AnnotationOut,
|
||||
AnnotationCreate,
|
||||
AnnotationUpdate,
|
||||
)
|
||||
from services.sam_registry import ModelUnavailableError, sam_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/ai", tags=["AI"])
|
||||
@@ -35,14 +42,15 @@ def _load_frame_image(frame: Frame) -> np.ndarray:
|
||||
@router.post(
|
||||
"/predict",
|
||||
response_model=PredictResponse,
|
||||
summary="Run SAM 2 inference with a prompt",
|
||||
summary="Run SAM inference with a prompt",
|
||||
)
|
||||
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
"""Execute SAM 2 segmentation given an image and a prompt.
|
||||
"""Execute selected SAM segmentation given an image and a prompt.
|
||||
|
||||
- **point**: `prompt_data` is a list of `[[x, y], ...]` normalized coordinates.
|
||||
- **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
|
||||
coordinates or `{ "points": [[x, y], ...], "labels": [1, 0, ...] }`.
|
||||
- **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
|
||||
- **semantic**: Not yet implemented; falls back to auto segmentation.
|
||||
- **semantic**: SAM 3 text prompt when model=`sam3`; SAM 2 falls back to auto.
|
||||
"""
|
||||
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
|
||||
if not frame:
|
||||
@@ -54,30 +62,57 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
polygons: List[List[List[float]]] = []
|
||||
scores: List[float] = []
|
||||
|
||||
if prompt_type == "point":
|
||||
points = payload.prompt_data
|
||||
if not isinstance(points, list) or len(points) == 0:
|
||||
raise HTTPException(status_code=400, detail="Invalid point prompt data")
|
||||
labels = [1] * len(points)
|
||||
polygons, scores = sam_engine.predict_points(image, points, labels)
|
||||
try:
|
||||
if prompt_type == "point":
|
||||
point_payload = payload.prompt_data
|
||||
if isinstance(point_payload, dict):
|
||||
points = point_payload.get("points")
|
||||
labels = point_payload.get("labels")
|
||||
else:
|
||||
points = point_payload
|
||||
labels = None
|
||||
|
||||
elif prompt_type == "box":
|
||||
box = payload.prompt_data
|
||||
if not isinstance(box, list) or len(box) != 4:
|
||||
raise HTTPException(status_code=400, detail="Invalid box prompt data")
|
||||
polygons, scores = sam_engine.predict_box(image, box)
|
||||
if not isinstance(points, list) or len(points) == 0:
|
||||
raise HTTPException(status_code=400, detail="Invalid point prompt data")
|
||||
if not isinstance(labels, list) or len(labels) != len(points):
|
||||
labels = [1] * len(points)
|
||||
polygons, scores = sam_registry.predict_points(payload.model, image, points, labels)
|
||||
|
||||
elif prompt_type == "semantic":
|
||||
# Placeholder: use auto segmentation for now
|
||||
logger.info("Semantic prompt not implemented; using auto segmentation")
|
||||
polygons, scores = sam_engine.predict_auto(image)
|
||||
elif prompt_type == "box":
|
||||
box = payload.prompt_data
|
||||
if not isinstance(box, list) or len(box) != 4:
|
||||
raise HTTPException(status_code=400, detail="Invalid box prompt data")
|
||||
polygons, scores = sam_registry.predict_box(payload.model, image, box)
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
|
||||
elif prompt_type == "semantic":
|
||||
text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
|
||||
polygons, scores = sam_registry.predict_semantic(payload.model, image, text)
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
|
||||
except ModelUnavailableError as exc:
|
||||
raise HTTPException(status_code=503, detail=str(exc)) from exc
|
||||
except NotImplementedError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
return {"polygons": polygons, "scores": scores}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/models/status",
|
||||
response_model=AiRuntimeStatus,
|
||||
summary="Get SAM model and GPU runtime status",
|
||||
)
|
||||
def model_status(selected_model: str | None = None) -> dict:
|
||||
"""Return real runtime availability for GPU, SAM 2, and SAM 3."""
|
||||
try:
|
||||
return sam_registry.runtime_status(selected_model)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/auto",
|
||||
response_model=PredictResponse,
|
||||
@@ -90,7 +125,10 @@ def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
|
||||
image = _load_frame_image(frame)
|
||||
polygons, scores = sam_engine.predict_auto(image)
|
||||
try:
|
||||
polygons, scores = sam_registry.predict_auto(None, image)
|
||||
except ModelUnavailableError as exc:
|
||||
raise HTTPException(status_code=503, detail=str(exc)) from exc
|
||||
|
||||
return {"polygons": polygons, "scores": scores}
|
||||
|
||||
@@ -106,7 +144,7 @@ def save_annotation(
|
||||
db: Session = Depends(get_db),
|
||||
) -> Annotation:
|
||||
"""Persist an annotation (mask, points, bbox) into the database."""
|
||||
project = db.query(Frame).filter(Frame.id == payload.project_id).first()
|
||||
project = db.query(Project).filter(Project.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
@@ -121,3 +159,74 @@ def save_annotation(
|
||||
db.refresh(annotation)
|
||||
logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
|
||||
return annotation
|
||||
|
||||
|
||||
@router.get(
|
||||
"/annotations",
|
||||
response_model=List[AnnotationOut],
|
||||
summary="List saved annotations for a project",
|
||||
)
|
||||
def list_annotations(
|
||||
project_id: int,
|
||||
frame_id: int | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[Annotation]:
|
||||
"""Return persisted annotations for a project, optionally scoped to one frame."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
query = db.query(Annotation).filter(Annotation.project_id == project_id)
|
||||
if frame_id is not None:
|
||||
query = query.filter(Annotation.frame_id == frame_id)
|
||||
return query.order_by(Annotation.id).all()
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/annotations/{annotation_id}",
|
||||
response_model=AnnotationOut,
|
||||
summary="Update a saved annotation",
|
||||
)
|
||||
def update_annotation(
|
||||
annotation_id: int,
|
||||
payload: AnnotationUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
) -> Annotation:
|
||||
"""Update mutable annotation fields persisted in the database."""
|
||||
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
|
||||
if not annotation:
|
||||
raise HTTPException(status_code=404, detail="Annotation not found")
|
||||
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
if "template_id" in updates and updates["template_id"] is not None:
|
||||
template = db.query(Template).filter(Template.id == updates["template_id"]).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
for field, value in updates.items():
|
||||
setattr(annotation, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(annotation)
|
||||
logger.info("Updated annotation id=%s", annotation.id)
|
||||
return annotation
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/annotations/{annotation_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete a saved annotation",
|
||||
)
|
||||
def delete_annotation(
|
||||
annotation_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
) -> Response:
|
||||
"""Delete an annotation and its derived mask rows through ORM cascade."""
|
||||
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
|
||||
if not annotation:
|
||||
raise HTTPException(status_code=404, detail="Annotation not found")
|
||||
|
||||
db.delete(annotation)
|
||||
db.commit()
|
||||
logger.info("Deleted annotation id=%s", annotation_id)
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
Reference in New Issue
Block a user