后端能力: - 新增 Celery app、worker task、ProcessingTask 模型、/api/tasks 查询接口和 media_task_runner,将 /api/media/parse 改为创建后台任务并由 worker 执行 FFmpeg/OpenCV/pydicom 拆帧。 - 新增 Redis 进度事件模块和 FastAPI Redis pub/sub 订阅,将 worker 任务进度广播到 /ws/progress;Dashboard 后端概览接口改为聚合 projects/frames/annotations/templates/processing_tasks。 - 统一项目状态为 pending/parsing/ready/error,新增共享 status 常量,并让前端兼容归一化旧状态值。 - 扩展 AI 后端:新增 SAM registry、SAM2 真实运行状态、SAM3 状态检测与文本语义推理适配入口,以及 /api/ai/models/status GPU/模型状态接口。 - 补齐标注保存/更新/删除、COCO/PNG mask 导出相关后端契约和模板 mapping_rules 打包/解包行为。 前端能力: - 新增运行时 API/WS 地址推导配置,前端 API 封装对齐 FastAPI 路由、字段映射、任务轮询、标注归档、导出下载和 AI 预测响应转换。 - Dashboard 改为读取 /api/dashboard/overview,并订阅 WebSocket progress/complete/error/status 更新解析队列和实时流转记录。 - 项目库导入视频/DICOM 后创建项目、上传媒体、触发异步解析并刷新真实项目列表。 - 工作区加载真实帧、无帧时触发解析任务、回显已保存标注、保存未归档 mask、更新 dirty mask、清空当前帧后端标注、导出 COCO JSON。 - Canvas 支持当前帧点/框提示调用后端 AI、渲染推理/已保存 mask、应用模板分类并维护保存状态计数;时间轴按项目 fps 播放。 - AI 页面新增 SAM2/SAM3 模型选择,预测请求携带 model;侧边栏和工作区新增真实 GPU/SAM 状态徽标。 - 模板库和本体面板接入真实模板 CRUD、分类编辑、拖拽排序、JSON 导入、默认腹腔镜分类和本地自定义分类选择。 测试与文档: - 新增 Vitest 配置、前端测试 setup、API/config/websocket/store/组件测试,覆盖登录、项目库、Dashboard、Canvas、工作区、模型状态、时间轴、本体和模板库。 - 新增 pytest 后端测试夹具和 auth/projects/templates/media/AI/export/dashboard/tasks/progress 测试,使用 SQLite、fake MinIO、fake SAM registry 和 Redis monkeypatch 隔离外部服务。 - 新增 doc/ 文档结构,冻结当前需求、设计、接口契约、测试计划、前端逐元素审计、实现地图和后续实施计划,并同步更新 README 与 AGENTS。 验证: - conda run -n seg_server pytest backend/tests:27 passed。 - npm run test:run:54 passed。 - npm run lint、npm run build、compileall、git diff --check 均通过;Vite 仅提示大 chunk 警告。
233 lines
8.2 KiB
Python
233 lines
8.2 KiB
Python
"""AI inference endpoints using selectable SAM runtimes."""
|
|
|
|
import logging
|
|
from typing import Any, List
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
|
from sqlalchemy.orm import Session
|
|
|
|
from database import get_db
|
|
from minio_client import download_file
|
|
from models import Project, Frame, Template, Annotation
|
|
from schemas import (
|
|
AiRuntimeStatus,
|
|
PredictRequest,
|
|
PredictResponse,
|
|
AnnotationOut,
|
|
AnnotationCreate,
|
|
AnnotationUpdate,
|
|
)
|
|
from services.sam_registry import ModelUnavailableError, sam_registry
|
|
|
|
# Module-level logger, named after this module so log records can be filtered by source.
logger = logging.getLogger(__name__)
|
|
# Router that every endpoint in this module registers on; all paths are served under /api/ai.
router = APIRouter(prefix="/api/ai", tags=["AI"])
|
|
|
|
|
|
def _load_frame_image(frame: Frame) -> np.ndarray:
    """Fetch the stored image for *frame* and return it as an RGB array.

    The raw bytes are obtained with ``download_file(frame.image_url)``,
    decoded by OpenCV as a colour image, and converted from OpenCV's BGR
    channel order to RGB.

    Raises:
        HTTPException: status 500 when the download, decode, or colour
            conversion fails for any reason. The underlying error is logged
            and chained as the cause.
    """
    try:
        raw_bytes = download_file(frame.image_url)
        bgr = cv2.imdecode(np.frombuffer(raw_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
        # imdecode signals failure by returning None rather than raising.
        if bgr is None:
            raise ValueError("OpenCV could not decode image")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    except Exception as exc:  # noqa: BLE001
        # Deliberately broad: any storage or decode failure is reported to the
        # client as the same generic 500.
        logger.error("Failed to load frame image: %s", exc)
        raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
|
|
|
|
|
|
def _split_point_prompt(prompt_data: Any) -> tuple:
    """Return ``(points, labels)`` taken from a point-prompt payload.

    ``prompt_data`` may be a bare list of points or a dict carrying
    ``"points"`` and ``"labels"`` keys. When labels are missing, not a list,
    or differ in length from the points, every point gets label ``1``.

    Raises:
        HTTPException: status 400 when the points are not a non-empty list.
    """
    if isinstance(prompt_data, dict):
        coords = prompt_data.get("points")
        flags = prompt_data.get("labels")
    else:
        coords, flags = prompt_data, None

    if not (isinstance(coords, list) and coords):
        raise HTTPException(status_code=400, detail="Invalid point prompt data")
    if not isinstance(flags, list) or len(flags) != len(coords):
        flags = [1] * len(coords)
    return coords, flags


@router.post(
    "/predict",
    response_model=PredictResponse,
    summary="Run SAM inference with a prompt",
)
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
    """Run the selected SAM model on one frame, guided by a prompt.

    Supported `prompt_type` values (case-insensitive):

    - **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
      coordinates or `{ "points": [[x, y], ...], "labels": [1, 0, ...] }`.
      Missing or mismatched labels default to `1` for every point.
    - **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
    - **semantic**: `prompt_data` is a text prompt (non-string values are
      treated as an empty string). Intended for model=`sam3`; how other
      models handle it is decided by the SAM registry.

    Responds with `polygons` and `scores` as returned by the registry.

    Errors: 404 when the frame does not exist; 400 for malformed prompt data,
    an unsupported `prompt_type`, or a `NotImplementedError`/`ValueError`
    from the registry; 503 when the model is unavailable; 500 when the frame
    image cannot be loaded.
    """
    frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
    if not frame:
        raise HTTPException(status_code=404, detail="Frame not found")

    image = _load_frame_image(frame)
    prompt_type = payload.prompt_type.lower()

    try:
        if prompt_type == "point":
            points, labels = _split_point_prompt(payload.prompt_data)
            polygons, scores = sam_registry.predict_points(payload.model, image, points, labels)
        elif prompt_type == "box":
            box = payload.prompt_data
            if not (isinstance(box, list) and len(box) == 4):
                raise HTTPException(status_code=400, detail="Invalid box prompt data")
            polygons, scores = sam_registry.predict_box(payload.model, image, box)
        elif prompt_type == "semantic":
            raw_text = payload.prompt_data
            text = raw_text if isinstance(raw_text, str) else ""
            polygons, scores = sam_registry.predict_semantic(payload.model, image, text)
        else:
            raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
    except ModelUnavailableError as exc:
        # Checked first so an unavailable runtime is always reported as 503.
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except (NotImplementedError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    return {"polygons": polygons, "scores": scores}
|
|
|
|
|
|
@router.get(
    "/models/status",
    response_model=AiRuntimeStatus,
    summary="Get SAM model and GPU runtime status",
)
def model_status(selected_model: str | None = None) -> dict:
    """Report the runtime status produced by the SAM registry.

    `selected_model` is an optional query parameter passed straight through
    to `sam_registry.runtime_status`.

    Errors: 400 when the registry rejects the value with a `ValueError`.
    """
    try:
        runtime = sam_registry.runtime_status(selected_model)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return runtime
|
|
|
|
|
|
@router.post(
    "/auto",
    response_model=PredictResponse,
    summary="Run automatic segmentation",
)
def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
    """Run automatic (prompt-free) mask generation on one frame.

    The frame is looked up by `image_id`, its image is loaded, and the SAM
    registry's `predict_auto` is called with no explicit model, so model
    selection is left to the registry.

    Responds with `polygons` and `scores` as returned by the registry.

    Errors: 404 when the frame does not exist; 503 when the model is
    unavailable; 400 when the registry raises `NotImplementedError` or
    `ValueError`; 500 when the frame image cannot be loaded.
    """
    frame = db.query(Frame).filter(Frame.id == image_id).first()
    if not frame:
        raise HTTPException(status_code=404, detail="Frame not found")

    image = _load_frame_image(frame)
    try:
        polygons, scores = sam_registry.predict_auto(None, image)
    except ModelUnavailableError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except (NotImplementedError, ValueError) as exc:
        # Same translation as /predict, so registry-level rejections are
        # reported as a 400 instead of escaping as an unhandled 500.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    return {"polygons": polygons, "scores": scores}
|
|
|
|
|
|
@router.post(
    "/annotate",
    response_model=AnnotationOut,
    status_code=status.HTTP_201_CREATED,
    summary="Save an AI-generated annotation",
)
def save_annotation(
    payload: AnnotationCreate,
    db: Session = Depends(get_db),
) -> Annotation:
    """Persist a new annotation and return the stored row.

    The referenced project must exist. When `frame_id` is supplied, that
    frame must exist as well. All payload fields are copied onto the new
    `Annotation` row.

    Errors: 404 when the project, or a supplied frame, does not exist.
    """
    project = db.query(Project).filter(Project.id == payload.project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    # Compare against None rather than truthiness so that frame_id=0 is still
    # validated instead of being stored as a dangling reference.
    # NOTE(review): this does not verify that the frame belongs to the project.
    if payload.frame_id is not None:
        frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
        if not frame:
            raise HTTPException(status_code=404, detail="Frame not found")

    annotation = Annotation(**payload.model_dump())
    db.add(annotation)
    db.commit()
    # Reload so database-generated values (such as the id) are populated.
    db.refresh(annotation)
    logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
    return annotation
|
|
|
|
|
|
@router.get(
    "/annotations",
    response_model=List[AnnotationOut],
    summary="List saved annotations for a project",
)
def list_annotations(
    project_id: int,
    frame_id: int | None = None,
    db: Session = Depends(get_db),
) -> List[Annotation]:
    """List the stored annotations of a project, ordered by annotation id.

    Pass `frame_id` to restrict the result to a single frame.

    Errors: 404 when the project does not exist.
    """
    if not db.query(Project).filter(Project.id == project_id).first():
        raise HTTPException(status_code=404, detail="Project not found")

    # Collect the criteria first; the frame filter is only added when requested.
    criteria = [Annotation.project_id == project_id]
    if frame_id is not None:
        criteria.append(Annotation.frame_id == frame_id)

    return db.query(Annotation).filter(*criteria).order_by(Annotation.id).all()
|
|
|
|
|
|
@router.patch(
    "/annotations/{annotation_id}",
    response_model=AnnotationOut,
    summary="Update a saved annotation",
)
def update_annotation(
    annotation_id: int,
    payload: AnnotationUpdate,
    db: Session = Depends(get_db),
) -> Annotation:
    """Apply a partial update to a stored annotation and return it.

    Only fields explicitly present in the request body are written. A
    non-null `template_id` must refer to an existing template.

    Errors: 404 when the annotation, or the referenced template, does not
    exist.
    """
    record = db.query(Annotation).filter(Annotation.id == annotation_id).first()
    if not record:
        raise HTTPException(status_code=404, detail="Annotation not found")

    # exclude_unset keeps omitted fields out, so they are never overwritten.
    changes = payload.model_dump(exclude_unset=True)

    new_template_id = changes.get("template_id")
    if new_template_id is not None:
        found = db.query(Template).filter(Template.id == new_template_id).first()
        if not found:
            raise HTTPException(status_code=404, detail="Template not found")

    for attr_name, attr_value in changes.items():
        setattr(record, attr_name, attr_value)

    db.commit()
    db.refresh(record)
    logger.info("Updated annotation id=%s", record.id)
    return record
|
|
|
|
|
|
@router.delete(
    "/annotations/{annotation_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a saved annotation",
)
def delete_annotation(
    annotation_id: int,
    db: Session = Depends(get_db),
) -> Response:
    """Delete one stored annotation and reply with an empty 204 response.

    The row is removed through the ORM session. Whether dependent rows are
    removed with it depends on the model's relationship configuration, which
    is defined outside this module.

    Errors: 404 when the annotation does not exist.
    """
    doomed = db.query(Annotation).filter(Annotation.id == annotation_id).first()
    if not doomed:
        raise HTTPException(status_code=404, detail="Annotation not found")

    db.delete(doomed)
    db.commit()
    # Log the path parameter: the ORM object has just been deleted.
    logger.info("Deleted annotation id=%s", annotation_id)
    return Response(status_code=status.HTTP_204_NO_CONTENT)
|