feat: 打通全栈标注闭环、异步拆帧与模型状态

后端能力:

- 新增 Celery app、worker task、ProcessingTask 模型、/api/tasks 查询接口和 media_task_runner,将 /api/media/parse 改为创建后台任务并由 worker 执行 FFmpeg/OpenCV/pydicom 拆帧。

- 新增 Redis 进度事件模块和 FastAPI Redis pub/sub 订阅,将 worker 任务进度广播到 /ws/progress;Dashboard 后端概览接口改为聚合 projects/frames/annotations/templates/processing_tasks。

- 统一项目状态为 pending/parsing/ready/error,新增共享 status 常量,并让前端兼容归一化旧状态值。

- 扩展 AI 后端:新增 SAM registry、SAM2 真实运行状态、SAM3 状态检测与文本语义推理适配入口,以及 /api/ai/models/status GPU/模型状态接口。

- 补齐标注保存/更新/删除、COCO/PNG mask 导出相关后端契约和模板 mapping_rules 打包/解包行为。

前端能力:

- 新增运行时 API/WS 地址推导配置,前端 API 封装对齐 FastAPI 路由、字段映射、任务轮询、标注归档、导出下载和 AI 预测响应转换。

- Dashboard 改为读取 /api/dashboard/overview,并订阅 WebSocket progress/complete/error/status 更新解析队列和实时流转记录。

- 项目库导入视频/DICOM 后创建项目、上传媒体、触发异步解析并刷新真实项目列表。

- 工作区加载真实帧、无帧时触发解析任务、回显已保存标注、保存未归档 mask、更新 dirty mask、清空当前帧后端标注、导出 COCO JSON。

- Canvas 支持当前帧点/框提示调用后端 AI、渲染推理/已保存 mask、应用模板分类并维护保存状态计数;时间轴按项目 fps 播放。

- AI 页面新增 SAM2/SAM3 模型选择,预测请求携带 model;侧边栏和工作区新增真实 GPU/SAM 状态徽标。

- 模板库和本体面板接入真实模板 CRUD、分类编辑、拖拽排序、JSON 导入、默认腹腔镜分类和本地自定义分类选择。

测试与文档:

- 新增 Vitest 配置、前端测试 setup、API/config/websocket/store/组件测试,覆盖登录、项目库、Dashboard、Canvas、工作区、模型状态、时间轴、本体和模板库。

- 新增 pytest 后端测试夹具和 auth/projects/templates/media/AI/export/dashboard/tasks/progress 测试,使用 SQLite、fake MinIO、fake SAM registry 和 Redis monkeypatch 隔离外部服务。

- 新增 doc/ 文档结构,冻结当前需求、设计、接口契约、测试计划、前端逐元素审计、实现地图和后续实施计划,并同步更新 README 与 AGENTS。

验证:

- conda run -n seg_server pytest backend/tests:27 passed。

- npm run test:run:54 passed。

- npm run lint、npm run build、compileall、git diff --check 均通过;Vite 仅提示大 chunk 警告。
This commit is contained in:
2026-05-01 13:29:14 +08:00
parent 4d65c37c73
commit f020ff3b4f
78 changed files with 7089 additions and 456 deletions

View File

@@ -0,0 +1,220 @@
"""Background media parsing runner used by Celery workers."""
import logging
import os
import shutil
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from sqlalchemy.orm import Session
from minio_client import BUCKET_NAME, download_file, get_minio_client, upload_file
from models import Frame, ProcessingTask, Project
from progress_events import publish_task_progress_event
from services.frame_parser import (
extract_thumbnail,
parse_dicom,
parse_video,
upload_frames_to_minio,
)
from statuses import (
PROJECT_STATUS_ERROR,
PROJECT_STATUS_PARSING,
PROJECT_STATUS_READY,
TASK_STATUS_FAILED,
TASK_STATUS_RUNNING,
TASK_STATUS_SUCCESS,
)
logger = logging.getLogger(__name__)
def _now() -> datetime:
return datetime.now(timezone.utc)
def _set_task_state(
db: Session,
task: ProcessingTask,
*,
status: str | None = None,
progress: int | None = None,
message: str | None = None,
result: dict[str, Any] | None = None,
error: str | None = None,
started: bool = False,
finished: bool = False,
) -> None:
if status is not None:
task.status = status
if progress is not None:
task.progress = max(0, min(100, progress))
if message is not None:
task.message = message
if result is not None:
task.result = result
if error is not None:
task.error = error
if started:
task.started_at = _now()
if finished:
task.finished_at = _now()
db.commit()
db.refresh(task)
publish_task_progress_event(task)
def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
"""Parse one project's media and update task progress in the database."""
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
if not task:
raise ValueError(f"Task not found: {task_id}")
if task.project_id is None:
_set_task_state(
db,
task,
status=TASK_STATUS_FAILED,
progress=100,
message="任务缺少 project_id",
error="Task has no project_id",
finished=True,
)
raise ValueError("Task has no project_id")
project = db.query(Project).filter(Project.id == task.project_id).first()
if not project:
_set_task_state(
db,
task,
status=TASK_STATUS_FAILED,
progress=100,
message="项目不存在",
error="Project not found",
finished=True,
)
raise ValueError(f"Project not found: {task.project_id}")
if not project.video_path:
_set_task_state(
db,
task,
status=TASK_STATUS_FAILED,
progress=100,
message="项目没有可解析媒体",
error="Project has no media uploaded",
finished=True,
)
project.status = PROJECT_STATUS_ERROR
db.commit()
raise ValueError("Project has no media uploaded")
project.status = PROJECT_STATUS_PARSING
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True)
effective_source = (task.payload or {}).get("source_type") or project.source_type or "video"
parse_fps = project.parse_fps or 30.0
tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project.id}_")
output_dir = os.path.join(tmp_dir, "frames")
os.makedirs(output_dir, exist_ok=True)
try:
_set_task_state(db, task, progress=15, message="正在下载媒体文件")
if effective_source == "dicom":
dcm_dir = os.path.join(tmp_dir, "dcm")
os.makedirs(dcm_dir, exist_ok=True)
client = get_minio_client()
objects = list(client.list_objects(BUCKET_NAME, prefix=project.video_path, recursive=True))
for obj in objects:
if obj.object_name.lower().endswith(".dcm"):
data = download_file(obj.object_name)
local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name))
with open(local_dcm, "wb") as f:
f.write(data)
_set_task_state(db, task, progress=35, message="正在解析 DICOM 序列")
frame_files = parse_dicom(dcm_dir, output_dir)
else:
media_bytes = download_file(project.video_path)
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
with open(local_path, "wb") as f:
f.write(media_bytes)
_set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧")
frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
project.original_fps = original_fps
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
try:
extract_thumbnail(local_path, thumbnail_path)
with open(thumbnail_path, "rb") as f:
thumb_data = f.read()
thumb_object = f"projects/{project.id}/thumbnail.jpg"
upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
project.thumbnail_url = thumb_object
except Exception as exc: # noqa: BLE001
logger.warning("Thumbnail extraction failed: %s", exc)
_set_task_state(db, task, progress=70, message="正在上传帧到对象存储")
object_names = upload_frames_to_minio(frame_files, project.id)
_set_task_state(db, task, progress=85, message="正在写入帧索引")
frames_out = []
for idx, obj_name in enumerate(object_names):
local_frame = frame_files[idx]
try:
import cv2
img = cv2.imread(local_frame)
h, w = img.shape[:2] if img is not None else (None, None)
except Exception: # noqa: BLE001
h, w = None, None
frame = Frame(
project_id=project.id,
frame_index=idx,
image_url=obj_name,
width=w,
height=h,
)
db.add(frame)
frames_out.append(frame)
project.status = PROJECT_STATUS_READY
db.commit()
result = {
"project_id": project.id,
"frames_extracted": len(frames_out),
"status": PROJECT_STATUS_READY,
"message": "Frame extraction completed successfully.",
}
_set_task_state(
db,
task,
status=TASK_STATUS_SUCCESS,
progress=100,
message="解析完成",
result=result,
finished=True,
)
logger.info("Parsed %d frames for project_id=%s", len(frames_out), project.id)
return result
except Exception as exc: # noqa: BLE001
project.status = PROJECT_STATUS_ERROR
_set_task_state(
db,
task,
status=TASK_STATUS_FAILED,
progress=100,
message="解析失败",
error=str(exc),
finished=True,
)
logger.error("Frame extraction failed: %s", exc)
raise
finally:
shutil.rmtree(tmp_dir, ignore_errors=True)

View File

@@ -1,4 +1,4 @@
"""SAM 2 engine wrapper with lazy loading and fallback stubs."""
"""SAM 2 engine wrapper with lazy loading and explicit runtime status."""
import logging
import os
@@ -11,10 +11,18 @@ from config import settings
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Attempt to import SAM 2; fall back to stubs if unavailable.
# Attempt to import PyTorch and SAM 2; fall back to stubs if unavailable.
# ---------------------------------------------------------------------------
try:
import torch
TORCH_AVAILABLE = True
except Exception as exc: # noqa: BLE001
TORCH_AVAILABLE = False
torch = None # type: ignore[assignment]
logger.warning("PyTorch import failed (%s). SAM2 will be unavailable.", exc)
try:
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
@@ -31,6 +39,8 @@ class SAM2Engine:
def __init__(self) -> None:
self._predictor: Optional[SAM2ImagePredictor] = None
self._model_loaded = False
self._loaded_device: str | None = None
self._last_error: str | None = None
# -----------------------------------------------------------------------
# Internal helpers
@@ -40,34 +50,87 @@ class SAM2Engine:
if self._model_loaded:
return
if not TORCH_AVAILABLE:
self._last_error = "PyTorch is not installed."
logger.warning("PyTorch not available; skipping SAM2 model load.")
self._model_loaded = True
return
if not SAM2_AVAILABLE:
self._last_error = "sam2 package is not installed."
logger.warning("SAM2 not available; skipping model load.")
self._model_loaded = True
return
if not os.path.isfile(settings.sam_model_path):
self._last_error = f"SAM2 checkpoint not found: {settings.sam_model_path}"
logger.error("SAM checkpoint not found at %s", settings.sam_model_path)
self._model_loaded = True
return
try:
device = self._best_device()
model = build_sam2(
settings.sam_model_config,
settings.sam_model_path,
device="cuda",
device=device,
)
self._predictor = SAM2ImagePredictor(model)
self._model_loaded = True
logger.info("SAM 2 model loaded from %s", settings.sam_model_path)
self._loaded_device = device
self._last_error = None
logger.info("SAM 2 model loaded from %s on %s", settings.sam_model_path, device)
except Exception as exc: # noqa: BLE001
self._last_error = str(exc)
logger.error("Failed to load SAM 2 model: %s", exc)
self._model_loaded = True # Prevent repeated load attempts
def _best_device(self) -> str:
if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available():
return "cuda"
return "cpu"
def _ensure_ready(self) -> bool:
"""Ensure the model is loaded; return whether it is usable."""
self._load_model()
return SAM2_AVAILABLE and self._predictor is not None
def status(self) -> dict:
"""Return lightweight, real runtime status without forcing model load."""
checkpoint_exists = os.path.isfile(settings.sam_model_path)
device = self._loaded_device or self._best_device()
available = bool(TORCH_AVAILABLE and SAM2_AVAILABLE and checkpoint_exists)
if self._predictor is not None:
message = "SAM 2 model loaded and ready."
elif available:
message = "SAM 2 dependencies and checkpoint are present; model will load on first inference."
else:
missing = []
if not TORCH_AVAILABLE:
missing.append("PyTorch")
if not SAM2_AVAILABLE:
missing.append("sam2 package")
if not checkpoint_exists:
missing.append("checkpoint")
message = f"SAM 2 unavailable: missing {', '.join(missing)}."
if self._last_error and not self._predictor:
message = self._last_error
return {
"id": "sam2",
"label": "SAM 2",
"available": available,
"loaded": self._predictor is not None,
"device": device,
"supports": ["point", "box", "auto"],
"message": message,
"package_available": SAM2_AVAILABLE,
"checkpoint_exists": checkpoint_exists,
"checkpoint_path": settings.sam_model_path,
"python_ok": True,
"torch_ok": TORCH_AVAILABLE,
"cuda_required": False,
}
# -----------------------------------------------------------------------
# Public API
# -----------------------------------------------------------------------

View File

@@ -0,0 +1,148 @@
"""SAM 3 engine adapter and runtime status.
The official facebookresearch/sam3 package currently targets Python 3.12+
and CUDA-capable PyTorch. This adapter reports those requirements honestly and
only performs inference when the local runtime can actually import and execute
the package.
"""
from __future__ import annotations
import importlib.util
import logging
import sys
from typing import Any
import numpy as np
from PIL import Image
from config import settings
from services.sam2_engine import SAM2Engine
logger = logging.getLogger(__name__)
try:
import torch
TORCH_AVAILABLE = True
except Exception as exc: # noqa: BLE001
TORCH_AVAILABLE = False
torch = None # type: ignore[assignment]
logger.warning("PyTorch import failed (%s). SAM3 will be unavailable.", exc)
SAM3_PACKAGE_AVAILABLE = importlib.util.find_spec("sam3") is not None
class SAM3Engine:
"""Lazy SAM 3 image inference adapter."""
def __init__(self) -> None:
self._model: Any | None = None
self._processor: Any | None = None
self._model_loaded = False
self._last_error: str | None = None
def _python_ok(self) -> bool:
return sys.version_info >= (3, 12)
def _gpu_ok(self) -> bool:
return bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available())
def _can_load(self) -> bool:
return bool(SAM3_PACKAGE_AVAILABLE and TORCH_AVAILABLE and self._python_ok() and self._gpu_ok())
def _load_model(self) -> None:
if self._model_loaded:
return
if not self._can_load():
self._last_error = self._status_message()
self._model_loaded = True
return
try:
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.model_builder import build_sam3_image_model
self._model = build_sam3_image_model()
self._processor = Sam3Processor(self._model)
self._model_loaded = True
self._last_error = None
logger.info("SAM 3 image model loaded with version setting %s", settings.sam3_model_version)
except Exception as exc: # noqa: BLE001
self._last_error = str(exc)
self._model_loaded = True
logger.error("Failed to load SAM 3 model: %s", exc)
def _ensure_ready(self) -> bool:
self._load_model()
return self._processor is not None
def _status_message(self) -> str:
missing = []
if not SAM3_PACKAGE_AVAILABLE:
missing.append("sam3 package")
if not self._python_ok():
missing.append("Python 3.12+ runtime")
if not TORCH_AVAILABLE:
missing.append("PyTorch")
if not self._gpu_ok():
missing.append("CUDA GPU")
if missing:
return f"SAM 3 unavailable: missing {', '.join(missing)}."
return "SAM 3 dependencies are present; model will load on first inference."
def status(self) -> dict:
available = self._can_load()
return {
"id": "sam3",
"label": "SAM 3",
"available": available,
"loaded": self._processor is not None,
"device": "cuda" if self._gpu_ok() else "unavailable",
"supports": ["semantic"],
"message": "SAM 3 model loaded and ready." if self._processor is not None else (self._last_error or self._status_message()),
"package_available": SAM3_PACKAGE_AVAILABLE,
"checkpoint_exists": SAM3_PACKAGE_AVAILABLE,
"checkpoint_path": f"official/HuggingFace ({settings.sam3_model_version})",
"python_ok": self._python_ok(),
"torch_ok": TORCH_AVAILABLE,
"cuda_required": True,
}
def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
if not text.strip():
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
if not self._ensure_ready():
raise RuntimeError(self.status()["message"])
pil_image = Image.fromarray(image)
with torch.inference_mode(): # type: ignore[union-attr]
state = self._processor.set_image(pil_image)
output = self._processor.set_text_prompt(state=state, prompt=text.strip())
masks = output.get("masks", [])
scores = output.get("scores", [])
polygons = []
for mask in masks:
if hasattr(mask, "detach"):
mask = mask.detach().cpu().numpy()
if mask.ndim == 3:
mask = mask[0]
poly = SAM2Engine._mask_to_polygon(mask)
if poly:
polygons.append(poly)
if hasattr(scores, "detach"):
scores = scores.detach().cpu().tolist()
elif hasattr(scores, "tolist"):
scores = scores.tolist()
return polygons, list(scores)
def predict_points(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for point prompts.")
def predict_box(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for box prompts.")
sam3_engine = SAM3Engine()

View File

@@ -0,0 +1,80 @@
"""Model registry for SAM runtimes and GPU status."""
from __future__ import annotations
from typing import Any
from config import settings
from services.sam2_engine import TORCH_AVAILABLE, sam_engine as sam2_engine
from services.sam3_engine import sam3_engine
try:
import torch
except Exception: # noqa: BLE001
torch = None # type: ignore[assignment]
class ModelUnavailableError(RuntimeError):
"""Raised when a selected model cannot run in this environment."""
class SAMRegistry:
"""Dispatch predictions to the selected SAM backend."""
def __init__(self) -> None:
self._engines = {
"sam2": sam2_engine,
"sam3": sam3_engine,
}
def normalize_model_id(self, model_id: str | None) -> str:
selected = (model_id or settings.sam_default_model or "sam2").lower()
if selected not in self._engines:
raise ValueError(f"Unsupported model: {model_id}")
return selected
def runtime_status(self, selected_model: str | None = None) -> dict[str, Any]:
return {
"selected_model": self.normalize_model_id(selected_model),
"gpu": self.gpu_status(),
"models": [engine.status() for engine in self._engines.values()],
}
def gpu_status(self) -> dict[str, Any]:
cuda_available = bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available())
return {
"available": cuda_available,
"device": "cuda" if cuda_available else "cpu",
"name": torch.cuda.get_device_name(0) if cuda_available else None,
"torch_available": bool(TORCH_AVAILABLE),
"torch_version": getattr(torch, "__version__", None) if torch is not None else None,
"cuda_version": getattr(torch.version, "cuda", None) if torch is not None else None,
}
def _engine(self, model_id: str | None) -> Any:
return self._engines[self.normalize_model_id(model_id)]
def _ensure_available(self, model_id: str | None) -> Any:
engine = self._engine(model_id)
status = engine.status()
if not status["available"]:
raise ModelUnavailableError(status["message"])
return engine
def predict_points(self, model_id: str | None, image: Any, points: list[list[float]], labels: list[int]):
return self._ensure_available(model_id).predict_points(image, points, labels)
def predict_box(self, model_id: str | None, image: Any, box: list[float]):
return self._ensure_available(model_id).predict_box(image, box)
def predict_auto(self, model_id: str | None, image: Any):
return self._ensure_available(model_id).predict_auto(image)
def predict_semantic(self, model_id: str | None, image: Any, text: str):
model = self.normalize_model_id(model_id)
if model == "sam3":
return self._ensure_available(model).predict_semantic(image, text)
return self._ensure_available(model).predict_auto(image)
sam_registry = SAMRegistry()