feat: 打通全栈标注闭环、异步拆帧与模型状态

后端能力:

- 新增 Celery app、worker task、ProcessingTask 模型、/api/tasks 查询接口和 media_task_runner,将 /api/media/parse 改为创建后台任务并由 worker 执行 FFmpeg/OpenCV/pydicom 拆帧。

- 新增 Redis 进度事件模块和 FastAPI Redis pub/sub 订阅,将 worker 任务进度广播到 /ws/progress;Dashboard 后端概览接口改为聚合 projects/frames/annotations/templates/processing_tasks。

- 统一项目状态为 pending/parsing/ready/error,新增共享 status 常量,并让前端兼容归一化旧状态值。

- 扩展 AI 后端:新增 SAM registry、SAM2 真实运行状态、SAM3 状态检测与文本语义推理适配入口,以及 /api/ai/models/status GPU/模型状态接口。

- 补齐标注保存/更新/删除、COCO/PNG mask 导出相关后端契约和模板 mapping_rules 打包/解包行为。

前端能力:

- 新增运行时 API/WS 地址推导配置,前端 API 封装对齐 FastAPI 路由、字段映射、任务轮询、标注归档、导出下载和 AI 预测响应转换。

- Dashboard 改为读取 /api/dashboard/overview,并订阅 WebSocket progress/complete/error/status 更新解析队列和实时流转记录。

- 项目库导入视频/DICOM 后创建项目、上传媒体、触发异步解析并刷新真实项目列表。

- 工作区加载真实帧、无帧时触发解析任务、回显已保存标注、保存未归档 mask、更新 dirty mask、清空当前帧后端标注、导出 COCO JSON。

- Canvas 支持当前帧点/框提示调用后端 AI、渲染推理/已保存 mask、应用模板分类并维护保存状态计数;时间轴按项目 fps 播放。

- AI 页面新增 SAM2/SAM3 模型选择,预测请求携带 model;侧边栏和工作区新增真实 GPU/SAM 状态徽标。

- 模板库和本体面板接入真实模板 CRUD、分类编辑、拖拽排序、JSON 导入、默认腹腔镜分类和本地自定义分类选择。

测试与文档:

- 新增 Vitest 配置、前端测试 setup、API/config/websocket/store/组件测试,覆盖登录、项目库、Dashboard、Canvas、工作区、模型状态、时间轴、本体和模板库。

- 新增 pytest 后端测试夹具和 auth/projects/templates/media/AI/export/dashboard/tasks/progress 测试,使用 SQLite、fake MinIO、fake SAM registry 和 Redis monkeypatch 隔离外部服务。

- 新增 doc/ 文档结构,冻结当前需求、设计、接口契约、测试计划、前端逐元素审计、实现地图和后续实施计划,并同步更新 README 与 AGENTS。

验证:

- conda run -n seg_server pytest backend/tests:27 passed。

- npm run test:run:54 passed。

- npm run lint、npm run build、compileall、git diff --check 均通过;Vite 仅提示大 chunk 警告。
This commit is contained in:
2026-05-01 13:29:14 +08:00
parent 4d65c37c73
commit f020ff3b4f
78 changed files with 7089 additions and 456 deletions

View File

@@ -1,18 +1,25 @@
"""AI inference endpoints using SAM 2."""
"""AI inference endpoints using selectable SAM runtimes."""
import logging
from typing import Any, List
import cv2
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, Response, status
from sqlalchemy.orm import Session
from database import get_db
from minio_client import download_file
from models import Frame, Annotation
from schemas import PredictRequest, PredictResponse, AnnotationOut, AnnotationCreate
from services.sam2_engine import sam_engine
from models import Project, Frame, Template, Annotation
from schemas import (
AiRuntimeStatus,
PredictRequest,
PredictResponse,
AnnotationOut,
AnnotationCreate,
AnnotationUpdate,
)
from services.sam_registry import ModelUnavailableError, sam_registry
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/ai", tags=["AI"])
@@ -35,14 +42,15 @@ def _load_frame_image(frame: Frame) -> np.ndarray:
@router.post(
"/predict",
response_model=PredictResponse,
summary="Run SAM 2 inference with a prompt",
summary="Run SAM inference with a prompt",
)
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
"""Execute SAM 2 segmentation given an image and a prompt.
"""Execute selected SAM segmentation given an image and a prompt.
- **point**: `prompt_data` is a list of `[[x, y], ...]` normalized coordinates.
- **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
coordinates or `{ "points": [[x, y], ...], "labels": [1, 0, ...] }`.
- **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
- **semantic**: Not yet implemented; falls back to auto segmentation.
- **semantic**: SAM 3 text prompt when model=`sam3`; SAM 2 falls back to auto.
"""
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
if not frame:
@@ -54,30 +62,57 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
polygons: List[List[List[float]]] = []
scores: List[float] = []
if prompt_type == "point":
points = payload.prompt_data
if not isinstance(points, list) or len(points) == 0:
raise HTTPException(status_code=400, detail="Invalid point prompt data")
labels = [1] * len(points)
polygons, scores = sam_engine.predict_points(image, points, labels)
try:
if prompt_type == "point":
point_payload = payload.prompt_data
if isinstance(point_payload, dict):
points = point_payload.get("points")
labels = point_payload.get("labels")
else:
points = point_payload
labels = None
elif prompt_type == "box":
box = payload.prompt_data
if not isinstance(box, list) or len(box) != 4:
raise HTTPException(status_code=400, detail="Invalid box prompt data")
polygons, scores = sam_engine.predict_box(image, box)
if not isinstance(points, list) or len(points) == 0:
raise HTTPException(status_code=400, detail="Invalid point prompt data")
if not isinstance(labels, list) or len(labels) != len(points):
labels = [1] * len(points)
polygons, scores = sam_registry.predict_points(payload.model, image, points, labels)
elif prompt_type == "semantic":
# Placeholder: use auto segmentation for now
logger.info("Semantic prompt not implemented; using auto segmentation")
polygons, scores = sam_engine.predict_auto(image)
elif prompt_type == "box":
box = payload.prompt_data
if not isinstance(box, list) or len(box) != 4:
raise HTTPException(status_code=400, detail="Invalid box prompt data")
polygons, scores = sam_registry.predict_box(payload.model, image, box)
else:
raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
elif prompt_type == "semantic":
text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
polygons, scores = sam_registry.predict_semantic(payload.model, image, text)
else:
raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
except ModelUnavailableError as exc:
raise HTTPException(status_code=503, detail=str(exc)) from exc
except NotImplementedError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return {"polygons": polygons, "scores": scores}
@router.get(
"/models/status",
response_model=AiRuntimeStatus,
summary="Get SAM model and GPU runtime status",
)
def model_status(selected_model: str | None = None) -> dict:
"""Return real runtime availability for GPU, SAM 2, and SAM 3."""
try:
return sam_registry.runtime_status(selected_model)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
@router.post(
"/auto",
response_model=PredictResponse,
@@ -90,7 +125,10 @@ def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
raise HTTPException(status_code=404, detail="Frame not found")
image = _load_frame_image(frame)
polygons, scores = sam_engine.predict_auto(image)
try:
polygons, scores = sam_registry.predict_auto(None, image)
except ModelUnavailableError as exc:
raise HTTPException(status_code=503, detail=str(exc)) from exc
return {"polygons": polygons, "scores": scores}
@@ -106,7 +144,7 @@ def save_annotation(
db: Session = Depends(get_db),
) -> Annotation:
"""Persist an annotation (mask, points, bbox) into the database."""
project = db.query(Frame).filter(Frame.id == payload.project_id).first()
project = db.query(Project).filter(Project.id == payload.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
@@ -121,3 +159,74 @@ def save_annotation(
db.refresh(annotation)
logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
return annotation
@router.get(
"/annotations",
response_model=List[AnnotationOut],
summary="List saved annotations for a project",
)
def list_annotations(
project_id: int,
frame_id: int | None = None,
db: Session = Depends(get_db),
) -> List[Annotation]:
"""Return persisted annotations for a project, optionally scoped to one frame."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
query = db.query(Annotation).filter(Annotation.project_id == project_id)
if frame_id is not None:
query = query.filter(Annotation.frame_id == frame_id)
return query.order_by(Annotation.id).all()
@router.patch(
"/annotations/{annotation_id}",
response_model=AnnotationOut,
summary="Update a saved annotation",
)
def update_annotation(
annotation_id: int,
payload: AnnotationUpdate,
db: Session = Depends(get_db),
) -> Annotation:
"""Update mutable annotation fields persisted in the database."""
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
if not annotation:
raise HTTPException(status_code=404, detail="Annotation not found")
updates = payload.model_dump(exclude_unset=True)
if "template_id" in updates and updates["template_id"] is not None:
template = db.query(Template).filter(Template.id == updates["template_id"]).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
for field, value in updates.items():
setattr(annotation, field, value)
db.commit()
db.refresh(annotation)
logger.info("Updated annotation id=%s", annotation.id)
return annotation
@router.delete(
"/annotations/{annotation_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete a saved annotation",
)
def delete_annotation(
annotation_id: int,
db: Session = Depends(get_db),
) -> Response:
"""Delete an annotation and its derived mask rows through ORM cascade."""
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
if not annotation:
raise HTTPException(status_code=404, detail="Annotation not found")
db.delete(annotation)
db.commit()
logger.info("Deleted annotation id=%s", annotation_id)
return Response(status_code=status.HTTP_204_NO_CONTENT)

View File

@@ -0,0 +1,137 @@
"""Dashboard overview endpoints."""
import os
from datetime import datetime, timezone
from typing import Any
from fastapi import APIRouter, Depends
from sqlalchemy import func
from sqlalchemy.orm import Session
from database import get_db
from models import Annotation, Frame, ProcessingTask, Project, Template
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
ACTIVE_TASK_STATUSES = {"queued", "running"}
def _system_load_percent() -> int:
"""Return a real host load estimate without adding a psutil dependency."""
try:
load_1m = os.getloadavg()[0]
cpu_count = os.cpu_count() or 1
return min(100, max(0, round((load_1m / cpu_count) * 100)))
except (AttributeError, OSError):
return 0
def _iso_or_none(value: datetime | None) -> str | None:
if value is None:
return None
if value.tzinfo is None:
value = value.replace(tzinfo=timezone.utc)
return value.isoformat()
def _task_payload(task: ProcessingTask) -> dict[str, Any]:
return {
"id": f"task-{task.id}",
"task_id": task.id,
"project_id": task.project_id or 0,
"name": task.project.name if task.project else f"任务 {task.id}",
"progress": task.progress,
"status": task.message or task.status,
"frame_count": (task.result or {}).get("frames_extracted", 0),
"updated_at": _iso_or_none(task.updated_at),
}
@router.get("/overview", summary="Get dashboard overview")
def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
"""Return live dashboard data derived from persisted backend records."""
project_count = db.query(func.count(Project.id)).scalar() or 0
frame_count = db.query(func.count(Frame.id)).scalar() or 0
annotation_count = db.query(func.count(Annotation.id)).scalar() or 0
template_count = db.query(func.count(Template.id)).scalar() or 0
active_task_count = (
db.query(func.count(ProcessingTask.id))
.filter(ProcessingTask.status.in_(ACTIVE_TASK_STATUSES))
.scalar()
or 0
)
projects = db.query(Project).order_by(Project.updated_at.desc()).all()
recent_tasks = (
db.query(ProcessingTask)
.order_by(ProcessingTask.created_at.desc())
.limit(50)
.all()
)
tasks = [_task_payload(task) for task in recent_tasks if task.status in ACTIVE_TASK_STATUSES]
activities: list[dict[str, Any]] = []
for task in recent_tasks[:10]:
project_name = task.project.name if task.project else f"项目 {task.project_id}"
activities.append({
"id": f"task-{task.id}",
"kind": "task",
"time": _iso_or_none(task.updated_at),
"message": task.message or f"任务状态: {task.status}",
"project": project_name,
})
for project in projects[:10]:
activities.append({
"id": f"project-{project.id}",
"kind": "project",
"time": _iso_or_none(project.updated_at),
"message": f"项目状态: {project.status}",
"project": project.name,
})
recent_annotations = (
db.query(Annotation)
.order_by(Annotation.updated_at.desc())
.limit(10)
.all()
)
for annotation in recent_annotations:
project_name = annotation.project.name if annotation.project else f"项目 {annotation.project_id}"
activities.append({
"id": f"annotation-{annotation.id}",
"kind": "annotation",
"time": _iso_or_none(annotation.updated_at),
"message": f"标注已更新 #{annotation.id}",
"project": project_name,
})
recent_templates = (
db.query(Template)
.order_by(Template.created_at.desc())
.limit(10)
.all()
)
for template in recent_templates:
activities.append({
"id": f"template-{template.id}",
"kind": "template",
"time": _iso_or_none(template.created_at),
"message": f"模板可用: {template.name}",
"project": "系统",
})
activities.sort(key=lambda item: item["time"] or "", reverse=True)
return {
"summary": {
"project_count": project_count,
"parsing_task_count": active_task_count,
"annotation_count": annotation_count,
"frame_count": frame_count,
"template_count": template_count,
"system_load_percent": _system_load_percent(),
},
"tasks": tasks,
"activity": activities[:10],
}

View File

@@ -1,10 +1,6 @@
"""Media upload and parsing endpoints."""
import logging
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import List, Optional
@@ -12,13 +8,12 @@ from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, s
from sqlalchemy.orm import Session
from database import get_db
from minio_client import upload_file, get_presigned_url, download_file
from models import Project, Frame
from schemas import FrameOut
from services.frame_parser import (
parse_video, parse_dicom, upload_frames_to_minio,
extract_thumbnail, get_video_fps,
)
from minio_client import upload_file, get_presigned_url
from models import ProcessingTask, Project
from progress_events import publish_task_progress_event
from schemas import ProcessingTaskOut
from statuses import PROJECT_STATUS_PARSING, PROJECT_STATUS_PENDING, TASK_STATUS_QUEUED
from worker_tasks import parse_project_media
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/media", tags=["Media"])
@@ -79,7 +74,7 @@ async def upload_media(
project = Project(
name=file.filename,
description="Auto-created from upload",
status="pending",
status=PROJECT_STATUS_PENDING,
video_path=object_name,
source_type="video",
)
@@ -135,7 +130,7 @@ async def upload_dicom_batch(
project = Project(
name=first_name,
description=f"DICOM series with {len(files)} files",
status="pending",
status=PROJECT_STATUS_PENDING,
source_type="dicom",
)
db.add(project)
@@ -168,19 +163,18 @@ async def upload_dicom_batch(
@router.post(
"/parse",
status_code=status.HTTP_202_ACCEPTED,
response_model=ProcessingTaskOut,
summary="Trigger frame extraction",
)
def parse_media(
project_id: int,
source_type: Optional[str] = None,
db: Session = Depends(get_db),
) -> dict:
"""Trigger frame extraction for a project's uploaded media.
) -> ProcessingTask:
"""Create a background task for media frame extraction.
* video: uses FFmpeg or OpenCV fallback, extracts thumbnail.
* dicom: uses pydicom to read DCM frames.
Extracted frames are uploaded to MinIO and registered in the database.
The Celery worker performs the heavy FFmpeg/OpenCV/pydicom work and
updates the persisted task record as it progresses.
"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
@@ -190,100 +184,24 @@ def parse_media(
raise HTTPException(status_code=400, detail="Project has no media uploaded")
effective_source = source_type or project.source_type or "video"
parse_fps = project.parse_fps or 30.0
tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project_id}_")
output_dir = os.path.join(tmp_dir, "frames")
os.makedirs(output_dir, exist_ok=True)
try:
if effective_source == "dicom":
# Download all dicom files from MinIO
dcm_dir = os.path.join(tmp_dir, "dcm")
os.makedirs(dcm_dir, exist_ok=True)
from minio_client import get_minio_client, BUCKET_NAME
client = get_minio_client()
prefix = project.video_path
objects = list(client.list_objects(BUCKET_NAME, prefix=prefix, recursive=True))
for obj in objects:
if obj.object_name.lower().endswith(".dcm"):
data = download_file(obj.object_name)
local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name))
with open(local_dcm, "wb") as f:
f.write(data)
frame_files = parse_dicom(dcm_dir, output_dir)
else:
# Video: download and parse
media_bytes = download_file(project.video_path)
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
with open(local_path, "wb") as f:
f.write(media_bytes)
frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
project.original_fps = original_fps
# Extract thumbnail from first frame
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
try:
extract_thumbnail(local_path, thumbnail_path)
with open(thumbnail_path, "rb") as f:
thumb_data = f.read()
thumb_object = f"projects/{project_id}/thumbnail.jpg"
upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
project.thumbnail_url = thumb_object
logger.info("Uploaded thumbnail for project_id=%s", project_id)
except Exception as exc: # noqa: BLE001
logger.warning("Thumbnail extraction failed: %s", exc)
except Exception as exc: # noqa: BLE001
logger.error("Frame extraction failed: %s", exc)
shutil.rmtree(tmp_dir, ignore_errors=True)
raise HTTPException(status_code=500, detail="Frame extraction failed") from exc
# Upload frames to MinIO
try:
object_names = upload_frames_to_minio(frame_files, project_id)
except Exception as exc: # noqa: BLE001
logger.error("Frame upload failed: %s", exc)
shutil.rmtree(tmp_dir, ignore_errors=True)
raise HTTPException(status_code=500, detail="Frame upload to storage failed") from exc
# Register frames in DB
frames_out = []
for idx, obj_name in enumerate(object_names):
local_frame = frame_files[idx]
try:
import cv2
img = cv2.imread(local_frame)
h, w = img.shape[:2] if img is not None else (None, None)
except Exception: # noqa: BLE001
h, w = None, None
frame = Frame(
project_id=project_id,
frame_index=idx,
image_url=obj_name,
width=w,
height=h,
)
db.add(frame)
frames_out.append(frame)
task = ProcessingTask(
task_type=f"parse_{effective_source}",
status=TASK_STATUS_QUEUED,
progress=0,
message="解析任务已入队",
project_id=project_id,
payload={"source_type": effective_source},
)
project.status = PROJECT_STATUS_PARSING
db.add(task)
db.commit()
for f in frames_out:
db.refresh(f)
db.refresh(task)
publish_task_progress_event(task)
# Cleanup temp files
shutil.rmtree(tmp_dir, ignore_errors=True)
project.status = "ready"
async_result = parse_project_media.delay(task.id)
task.celery_task_id = async_result.id
db.commit()
db.refresh(task)
logger.info("Parsed %d frames for project_id=%s", len(frames_out), project_id)
return {
"project_id": project_id,
"frames_extracted": len(frames_out),
"status": "ready",
"message": "Frame extraction completed successfully.",
}
logger.info("Queued parse task id=%s project_id=%s celery_id=%s", task.id, project_id, async_result.id)
return task

37
backend/routers/tasks.py Normal file
View File

@@ -0,0 +1,37 @@
"""Processing task query endpoints."""
from typing import List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from database import get_db
from models import ProcessingTask
from schemas import ProcessingTaskOut
router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
@router.get("", response_model=List[ProcessingTaskOut], summary="List processing tasks")
def list_tasks(
project_id: int | None = None,
status: str | None = None,
limit: int = 50,
db: Session = Depends(get_db),
) -> List[ProcessingTask]:
"""Return recent background processing tasks."""
query = db.query(ProcessingTask)
if project_id is not None:
query = query.filter(ProcessingTask.project_id == project_id)
if status is not None:
query = query.filter(ProcessingTask.status == status)
return query.order_by(ProcessingTask.created_at.desc()).limit(limit).all()
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
def get_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
"""Return one background task by id."""
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return task

View File

@@ -18,9 +18,9 @@ def _pack_mapping_rules(data: dict) -> dict:
"""Pack classes/rules into mapping_rules for DB storage."""
mapping = data.get("mapping_rules") or {}
if "classes" in data and data["classes"] is not None:
mapping["classes"] = data["classes"]
mapping["classes"] = data.pop("classes")
if "rules" in data and data["rules"] is not None:
mapping["rules"] = data["rules"]
mapping["rules"] = data.pop("rules")
data["mapping_rules"] = mapping
return data