feat: 完善 AI 分割与工作区标注闭环

功能增加:

- 将视频导入和生成帧拆成两个明确动作,项目库生成帧时选择 FPS,工作区不再自动触发拆帧。

- 为工作区新增调整多边形工具,支持选中 mask、拖动顶点、边中点插点、双击边界按位置插点,并保留多 polygon 子区域编辑。

- 打通 AI 页 SAM2/SAM3 结果到工作区的联动,生成 mask 后自动选中,可在右侧分类树换标签,并推送到工作区继续编辑。

- 增强 Dashboard WebSocket 连接状态与心跳,使用真实 onopen/onclose/onerror 状态驱动前端显示。

- 完善 SAM3 external worker 适配,支持 box prompt、semantic 请求级阈值和 video tracker 路径。

bugfix:

- 修复 SAM2 文本语义误走自动分割的问题,改为提示使用点提示或切换 SAM3。

- 修复 SAM2 多候选重叠显示的问题,点提示和 auto fallback 默认只采用最高分候选。

- 修复 SAM2 反向点看起来无效的问题,带负点时启用背景过滤,过滤为空时移除旧候选。

- 修复 SAM3 单个 2D mask 结果无法转 polygon、低阈值 semantic 返回被默认阈值吞掉的问题。

- 修复 AI 页 mask 未选中导致分类树无法修改 SAM2 结果标签的问题。

测试和文档:

- 补充 CanvasArea、AISegmentation、ProjectLibrary、VideoWorkspace、Dashboard、websocket 和 SAM engine/API 测试。

- 新增 backend/tests/test_sam2_engine.py,覆盖 SAM2 单候选请求和 auto fallback 行为。

- 更新 README、AGENTS 和 doc 需求/设计/接口/测试矩阵,按当前实现冻结功能状态。
This commit is contained in:
2026-05-01 21:50:17 +08:00
parent 5ab4602535
commit 8a9247075e
31 changed files with 920 additions and 216 deletions

View File

@@ -6,7 +6,7 @@
## 项目概述 ## 项目概述
本项目是一个**语义分割系统**Semantic Segmentation System当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、服务器端拆帧、交互式 Canvas 标注、视频片段传播、GT mask 导入、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出。 本项目是一个**语义分割系统**Semantic Segmentation System当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、显式视频生成帧、交互式 Canvas 标注、视频片段传播、GT mask 导入、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出。
- **项目名称**: `react-example``package.json` 中的 `name` - **项目名称**: `react-example``package.json` 中的 `name`
- **前端入口**: `src/main.tsx``src/App.tsx` - **前端入口**: `src/main.tsx``src/App.tsx`
@@ -219,12 +219,12 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
1. 登录:`Login.tsx` 调用 `POST /api/auth/login`,默认开发凭证为 `admin / 123456` 1. 登录:`Login.tsx` 调用 `POST /api/auth/login`,默认开发凭证为 `admin / 123456`
2. 项目管理:`ProjectLibrary.tsx` 调用项目 API 创建项目、拉取列表。 2. 项目管理:`ProjectLibrary.tsx` 调用项目 API 创建项目、拉取列表。
3. 上传资源:视频走 `/api/media/upload`DICOM 批量走 `/api/media/upload/dicom` 3. 上传资源:视频走 `/api/media/upload`,只上传源文件并关联项目,不自动拆帧DICOM 批量走 `/api/media/upload/dicom`
4. 帧入队:前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery接口支持 `parse_fps``max_frames``target_width` 标准帧序列参数。 4. 生成帧入队:用户在项目库点击“生成帧”,选择目标 FPS 后前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery接口支持 `parse_fps``max_frames``target_width` 标准帧序列参数。
5. worker 执行Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallbackDICOM 使用 pydicom视频帧按 `frame_%06d.jpg` 连续命名并记录 `timestamp_ms``source_frame_number` 和任务 `frame_sequence` 元数据。 5. worker 执行Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallbackDICOM 使用 pydicom视频帧按 `frame_%06d.jpg` 连续命名并记录 `timestamp_ms``source_frame_number` 和任务 `frame_sequence` 元数据。
6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames``CanvasArea.tsx``FrameTimeline.tsx` 显示当前帧与时间轴缩略图;前端 `Frame` 会保留后端返回的帧序列时间戳和源帧号。 6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames``CanvasArea.tsx``FrameTimeline.tsx` 显示当前帧与时间轴缩略图;前端 `Frame` 会保留后端返回的帧序列时间戳和源帧号。
7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask多边形可按 Enter 或点击首节点闭合;绘制工具可在已有 mask 上继续落点;点击 mask 可拖动/删除 polygon 顶点、通过边中点插入新顶点,并能选择编辑多 polygon mask 的单个子区域;选中整块 mask 可用 Delete/Backspace 删除,已保存 mask 会同步后端删除;区域合并/去除会隐藏编辑手柄并显示已选数量,使用 `polygon-clipping` 做 union/difference内含去除结果用 even-odd 规则渲染 holeZustand 维护 `maskHistory/maskFuture` 支持撤销/重做。 7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask多边形可按 Enter 或点击首节点闭合;绘制工具可在已有 mask 上继续落点;工具栏有“调整多边形”入口,点击 mask 可拖动/删除 polygon 顶点、通过边中点或双击边界插入新顶点,并能选择编辑多 polygon mask 的单个子区域;选中整块 mask 可用 Delete/Backspace 删除,已保存 mask 会同步后端删除;区域合并/去除会隐藏编辑手柄并显示已选数量,使用 `polygon-clipping` 做 union/difference内含去除结果用 even-odd 规则渲染 holeZustand 维护 `maskHistory/maskFuture` 支持撤销/重做。
8. AI 分割前端工具包括正向点、反向点和框选SAM 2 框选会建立候选 mask后续正/反点通过 `interactive` prompt 携带原始框和累计点细化同一个候选 mask后端 `ai.py` 期望按 `image_id``prompt_type``prompt_data``model` 和可选 `options` 调用 SAM registry。SAM 2 支持点/框/interactive/自动分割和 video predictor 传播`options.crop_to_prompt` 可对点/框/interactive prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果SAM 3 入口支持文本语义推理、框选提示和 external video tracker主后端会通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境,并优先使用 `sam3_checkpoint_path` 指向的本地 `sam3权重/sam3.pt`;如果 Python/CUDA/包/本地 checkpoint 均满足,会在状态接口中标为可用。 8. AI 分割前端工具包括正向点、反向点和框选SAM 2 框选会建立候选 mask后续正/反点通过 `interactive` prompt 携带原始框和累计点细化同一个候选 mask包含反向点时工作区会传 `options.auto_filter_background=true``min_score=0.05`,如果后端过滤为空则移除旧候选 mask。后端 `ai.py` 期望按 `image_id``prompt_type``prompt_data``model` 和可选 `options` 调用 SAM registry。SAM 2 支持点/框/interactive/自动分割和 video predictor 传播但不支持文本语义提示AI 页面在 SAM 2 纯文本时提示改用点提示或切换 SAM 3SAM 2 多候选默认只采用最高分区域避免重叠候选同时显示AI 页面生成的 mask 会写入全局 `masks` 并自动选中,右侧分类树可直接改标签,推送到工作区会切到“调整多边形”并保留选择。`options.crop_to_prompt` 可对点/框/interactive prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果SAM 3 入口支持文本语义推理、框选提示和 external video trackersemantic 请求会把正数 `options.min_score` 传给 external worker 作为置信度阈值,主后端会通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境,并优先使用 `sam3_checkpoint_path` 指向的本地 `sam3权重/sam3.pt`;如果 Python/CUDA/包/本地 checkpoint 均满足,会在状态接口中标为可用。
9. 视频片段传播:工作区“传播片段”把当前选中 mask 或当前帧第一个 mask 作为 seed调用 `POST /api/ai/propagate`后端按项目帧序列下载片段帧SAM 2 用 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`SAM 3 用独立 helper 的官方 `build_sam3_video_predictor()`,并把后续帧结果保存为 `Annotation` 9. 视频片段传播:工作区“传播片段”把当前选中 mask 或当前帧第一个 mask 作为 seed调用 `POST /api/ai/propagate`后端按项目帧序列下载片段帧SAM 2 用 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`SAM 3 用独立 helper 的官方 `build_sam3_video_predictor()`,并把后续帧结果保存为 `Annotation`
10. GT 导入:工作区“导入 GT Mask”调用 `/api/ai/import-gt-mask`;后端按非零像素值和连通域生成 polygon 标注,并用 distance transform 生成 seed point前端回显 seed point拖动后可归档更新。 10. GT 导入:工作区“导入 GT Mask”调用 `/api/ai/import-gt-mask`;后端按非零像素值和连通域生成 polygon 标注,并用 distance transform 生成 seed point前端回显 seed point拖动后可归档更新。
11. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index`OntologyInspector.tsx` 在工作区显示当前模板分类树。 11. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index`OntologyInspector.tsx` 在工作区显示当前模板分类树。
@@ -237,7 +237,7 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
- `src/lib/config.ts` 会优先读取 `VITE_API_BASE_URL``VITE_WS_PROGRESS_URL`;未配置时按当前浏览器 hostname 推导后端 `:8000` 地址。 - `src/lib/config.ts` 会优先读取 `VITE_API_BASE_URL``VITE_WS_PROGRESS_URL`;未配置时按当前浏览器 hostname 推导后端 `:8000` 地址。
- 前端 `predictMask()` 已按后端 `PredictRequest` 发送 `image_id``prompt_type``prompt_data``model`,并将后端 `polygons` 转成 Konva 可渲染的 `pathData` - 前端 `predictMask()` 已按后端 `PredictRequest` 发送 `image_id``prompt_type``prompt_data``model`,并将后端 `polygons` 转成 Konva 可渲染的 `pathData`
- 手工绘制工具会生成可保存的 `Mask.segmentation`;撤销/重做通过 `maskHistory/maskFuture` 工作。 - 手工绘制工具会生成可保存的 `Mask.segmentation`;撤销/重做通过 `maskHistory/maskFuture` 工作。
- Polygon 顶点编辑会重算 `pathData/segmentation/bbox/area`;已保存 mask 进入 dirty 状态后复用归档 PATCH 链路。 - Polygon 顶点编辑和新增顶点会重算 `pathData/segmentation/bbox/area`;已保存 mask 进入 dirty 状态后复用归档 PATCH 链路。
- 区域合并/去除会重算主 mask 的几何;合并已保存的次级 mask 时会通过工作区回调删除对应后端标注。 - 区域合并/去除会重算主 mask 的几何;合并已保存的次级 mask 时会通过工作区回调删除对应后端标注。
- 前端 `importGtMask()` 已对齐后端 `/api/ai/import-gt-mask`;工作区“导入 GT Mask”会导入后端生成的多类别标注和 seed point 并回显。 - 前端 `importGtMask()` 已对齐后端 `/api/ai/import-gt-mask`;工作区“导入 GT Mask”会导入后端生成的多类别标注和 seed point 并回显。
- 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;前端 `exportMasks()` 已对齐后端 `/api/export/{project_id}/masks`;工作区导出按钮会先保存当前待归档 mask。 - 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;前端 `exportMasks()` 已对齐后端 `/api/export/{project_id}/masks`;工作区导出按钮会先保存当前待归档 mask。
@@ -245,8 +245,9 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
- 工作区“传播片段”按钮已接入 `POST /api/ai/propagate`SAM 2 路径使用视频 predictorSAM 3 路径使用独立 Python helper 的官方 video tracker完成后刷新后端已保存标注。 - 工作区“传播片段”按钮已接入 `POST /api/ai/propagate`SAM 2 路径使用视频 predictorSAM 3 路径使用独立 Python helper 的官方 video tracker完成后刷新后端已保存标注。
- 工作区“清空遮罩”会调用 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。 - 工作区“清空遮罩”会调用 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。
- 项目状态已统一为 `pending``parsing``ready``error`;前端 `src/lib/api.ts` 会兼容归一化旧库中可能存在的 `Ready``Parsing``Error` - 项目状态已统一为 `pending``parsing``ready``error`;前端 `src/lib/api.ts` 会兼容归一化旧库中可能存在的 `Ready``Parsing``Error`
- 项目库的视频导入与生成帧是两个独立动作:导入视频只上传源文件,生成帧按钮才会带 `parse_fps` 调用 `/api/media/parse`;工作区不会再因“有视频但无帧”自动创建拆帧任务。
- `server.ts` 仍有旧版 `/api/login``/api/projects``/api/templates` mock当前前端真实 API 调用主要走 FastAPI 的 `/api/auth/*``/api/projects``/api/templates` 等接口。 - `server.ts` 仍有旧版 `/api/login``/api/projects``/api/templates` mock当前前端真实 API 调用主要走 FastAPI 的 `/api/auth/*``/api/projects``/api/templates` 等接口。
- `Dashboard.tsx` 初始统计、队列和活动日志来自 `GET /api/dashboard/overview`;解析队列来自 `processing_tasks`,支持取消 queued/running 任务、重试 failed/cancelled 任务和查看失败详情。Celery worker 通过 Redis pub/sub 的 `seg:progress` 频道推送细粒度进度,再由 FastAPI 广播到 `/ws/progress` - `Dashboard.tsx` 初始统计、队列和活动日志来自 `GET /api/dashboard/overview`;解析队列来自 `processing_tasks`,支持取消 queued/running 任务、重试 failed/cancelled 任务和查看失败详情。Celery worker 通过 Redis pub/sub 的 `seg:progress` 频道推送细粒度进度,再由 FastAPI 广播到 `/ws/progress`;前端 WebSocket 客户端通过 `onopen/onclose/onerror` 更新连接状态,并定时发送 `ping` 心跳
--- ---

View File

@@ -12,9 +12,9 @@
## 核心功能 ## 核心功能
- **多媒体资产管理** — 支持视频MP4/AVI/MOV和 DICOM 医学影像上传、存储与解析 - **多媒体资产管理** — 支持视频MP4/AVI/MOV和 DICOM 医学影像上传;视频导入与生成帧分离,生成帧时选择目标 FPS
- **AI 智能分割引擎** — 后端提供 SAM 2 / SAM 3 模型选择SAM 2 支持点分割point、框分割box、自动分割auto和 video predictor 传播SAM 3 入口支持文本语义提示、框选提示和 external video tracker并按真实运行环境显示可用性 - **AI 智能分割引擎** — 后端提供 SAM 2 / SAM 3 模型选择SAM 2 支持点分割point、框分割box、自动分割auto和 video predictor 传播,前端默认只采用最高分候选避免重叠备选同时显示;SAM 3 入口支持文本语义提示、框选提示和 external video tracker并按真实运行环境显示可用性
- **交互式画布标注** — 基于 Konva 的高性能 Canvas支持缩放/平移/手工多边形/矩形/圆/点/线、polygon 顶点拖动/删除、区域合并/去除、选点/框选、撤销/重做,实时渲染 Mask 遮罩 - **交互式画布标注** — 基于 Konva 的高性能 Canvas支持缩放/平移/手工多边形/矩形/圆/点/线、polygon 顶点拖动/删除、边中点插点、双击边界插点、区域合并/去除、选点/框选、撤销/重做,实时渲染 Mask 遮罩
- **GT Mask 导入** — 工作区可导入 GT mask 图片,后端按非零像素值和连通域生成 polygon 标注并用 distance transform 写入 seed point前端可回显和拖动 seed point - **GT Mask 导入** — 工作区可导入 GT mask 图片,后端按非零像素值和连通域生成 polygon 标注并用 distance transform 写入 seed point前端可回显和拖动 seed point
- **本体字典管理** — 可配置的分类体系、颜色映射、图层优先级z-index - **本体字典管理** — 可配置的分类体系、颜色映射、图层优先级z-index
- **项目工作区** — 项目创建、帧浏览、多图层标注、进度追踪 - **项目工作区** — 项目创建、帧浏览、多图层标注、进度追踪
@@ -325,7 +325,7 @@ celery -A celery_app:celery_app worker --loglevel=info --concurrency=1
nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 & nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 &
``` ```
`POST /api/media/parse` 只创建 `processing_tasks` 记录并把任务投递给 Celery真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。接口支持 `parse_fps``max_frames``target_width`,用于生成后续 SAM 2 / SAM 3 视频处理可复用的标准帧序列;视频帧按 `frame_%06d.jpg` 连续命名,帧表会记录 `timestamp_ms``source_frame_number`,任务完成结果会返回 `frame_sequence` 元数据。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 也可调用 `/api/tasks/{id}/cancel``/api/tasks/{id}/retry``/api/tasks/{id}` 完成任务取消、重试与失败详情查看。 视频导入只创建项目并把源视频保存到 MinIO不会自动拆帧用户在项目库点击“生成帧”后再选择目标 FPS 并调用 `POST /api/media/parse`。该接口只创建 `processing_tasks` 记录并把任务投递给 Celery真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。接口支持 `parse_fps``max_frames``target_width`,用于生成后续 SAM 2 / SAM 3 视频处理可复用的标准帧序列;视频帧按 `frame_%06d.jpg` 连续命名,帧表会记录 `timestamp_ms``source_frame_number`,任务完成结果会返回 `frame_sequence` 元数据。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 的 WebSocket 状态由浏览器 `onopen/onclose/onerror` 驱动,客户端会定时发送 `ping` 心跳,服务端返回 `status` 确认连接。Dashboard 也可调用 `/api/tasks/{id}/cancel``/api/tasks/{id}/retry``/api/tasks/{id}` 完成任务取消、重试与失败详情查看。
### 步骤 7: 安装前端依赖并构建 ### 步骤 7: 安装前端依赖并构建
@@ -461,6 +461,8 @@ pip install -e . --no-build-isolation
- 前端 `predictMask()` 已发送后端需要的 `image_id``prompt_type``prompt_data`,并把后端 `polygons` 转成 Konva `pathData` - 前端 `predictMask()` 已发送后端需要的 `image_id``prompt_type``prompt_data`,并把后端 `polygons` 转成 Konva `pathData`
- 工作区点选/框选会使用当前帧的数据库 `frame.id` 调用 `/api/ai/predict` - 工作区点选/框选会使用当前帧的数据库 `frame.id` 调用 `/api/ai/predict`
- 工作区 SAM 2 交互式细化包含反向点时会启用后端背景过滤;若反向点排除了当前候选区域并返回空结果,前端会移除旧候选 mask。
- AI 页面生成的 SAM 2/SAM 3 mask 会写入全局 `masks` 并自动选中;右侧分类树可直接给生成结果换标签,“推送至工作区编辑”会切回工作区的多边形调整工具并保留选择。
- 工作区“传播片段”会使用当前选中区域或当前帧第一个区域作为 seed调用 `/api/ai/propagate`,并在完成后刷新已保存标注。 - 工作区“传播片段”会使用当前选中区域或当前帧第一个区域作为 seed调用 `/api/ai/propagate`,并在完成后刷新已保存标注。
- 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco` - 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco`
- 工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮已绑定下载流程导出前会先保存当前待归档的前端 mask。 - 工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮已绑定下载流程导出前会先保存当前待归档的前端 mask。

View File

@@ -340,7 +340,21 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
elif prompt_type == "semantic": elif prompt_type == "semantic":
text = payload.prompt_data if isinstance(payload.prompt_data, str) else "" text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
polygons, scores = sam_registry.predict_semantic(payload.model, image, text) min_score = options.get("min_score")
confidence_threshold = None
if min_score is not None:
try:
parsed_min_score = float(min_score)
if parsed_min_score > 0:
confidence_threshold = parsed_min_score
except (TypeError, ValueError):
confidence_threshold = None
polygons, scores = sam_registry.predict_semantic(
payload.model,
image,
text,
confidence_threshold=confidence_threshold,
)
else: else:
raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}") raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
@@ -352,6 +366,13 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
raise HTTPException(status_code=400, detail=str(exc)) from exc raise HTTPException(status_code=400, detail=str(exc)) from exc
polygons, scores = _filter_predictions(polygons, scores, options, negative_points) polygons, scores = _filter_predictions(polygons, scores, options, negative_points)
logger.info(
"AI predict completed model=%s prompt_type=%s frame_id=%s polygons=%d",
payload.model or "default",
prompt_type,
payload.image_id,
len(polygons),
)
return {"polygons": polygons, "scores": scores} return {"polygons": polygons, "scores": scores}

View File

@@ -207,7 +207,7 @@ class SAM2Engine:
masks, scores, _ = self._predictor.predict( masks, scores, _ = self._predictor.predict(
point_coords=pts, point_coords=pts,
point_labels=lbls, point_labels=lbls,
multimask_output=True, multimask_output=False,
) )
polygons = [] polygons = []
@@ -335,16 +335,16 @@ class SAM2Engine:
masks, scores, _ = self._predictor.predict( masks, scores, _ = self._predictor.predict(
point_coords=pts, point_coords=pts,
point_labels=lbls, point_labels=lbls,
multimask_output=True, multimask_output=False,
) )
polygons = [] polygons = []
for m in masks[:3]: # Limit to top 3 masks for m in masks[:1]:
poly = self._mask_to_polygon(m) poly = self._mask_to_polygon(m)
if poly: if poly:
polygons.append(poly) polygons.append(poly)
return polygons, scores[:3].tolist() return polygons, scores[:1].tolist()
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
logger.error("SAM2 auto prediction failed: %s", exc) logger.error("SAM2 auto prediction failed: %s", exc)
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5] return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]

View File

@@ -260,6 +260,7 @@ class SAM3Engine:
*, *,
text: str = "", text: str = "",
box: list[float] | None = None, box: list[float] | None = None,
confidence_threshold: float | None = None,
) -> tuple[list[list[list[float]]], list[float]]: ) -> tuple[list[list[list[float]]], list[float]]:
status = self._external_status(force=True) status = self._external_status(force=True)
if not status.get("available"): if not status.get("available"):
@@ -279,7 +280,11 @@ class SAM3Engine:
"box": box, "box": box,
"model_version": settings.sam3_model_version, "model_version": settings.sam3_model_version,
"checkpoint_path": self._checkpoint_path(), "checkpoint_path": self._checkpoint_path(),
"confidence_threshold": settings.sam3_confidence_threshold, "confidence_threshold": (
confidence_threshold
if confidence_threshold is not None
else settings.sam3_confidence_threshold
),
}, },
ensure_ascii=False, ensure_ascii=False,
), ),
@@ -312,8 +317,18 @@ class SAM3Engine:
raise RuntimeError(str(payload["error"])) raise RuntimeError(str(payload["error"]))
return payload.get("polygons", []), payload.get("scores", []) return payload.get("polygons", []), payload.get("scores", [])
def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]: def _predict_semantic_external(
return self._predict_external(image, "semantic", text=text) self,
image: np.ndarray,
text: str,
confidence_threshold: float | None = None,
) -> tuple[list[list[list[float]]], list[float]]:
return self._predict_external(
image,
"semantic",
text=text,
confidence_threshold=confidence_threshold,
)
def _predict_box_external(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]: def _predict_box_external(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]:
return self._predict_external(image, "box", box=box) return self._predict_external(image, "box", box=box)
@@ -378,11 +393,16 @@ class SAM3Engine:
raise RuntimeError(str(payload["error"])) raise RuntimeError(str(payload["error"]))
return payload.get("frames", []) return payload.get("frames", [])
def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]: def predict_semantic(
self,
image: np.ndarray,
text: str,
confidence_threshold: float | None = None,
) -> tuple[list[list[list[float]]], list[float]]:
if not text.strip(): if not text.strip():
raise ValueError("SAM 3 semantic prompt requires non-empty text.") raise ValueError("SAM 3 semantic prompt requires non-empty text.")
if not self._can_load() and self._external_status().get("available"): if not self._can_load() and self._external_status().get("available"):
return self._predict_semantic_external(image, text) return self._predict_semantic_external(image, text, confidence_threshold=confidence_threshold)
if not self._ensure_ready(): if not self._ensure_ready():
raise RuntimeError(self.status()["message"]) raise RuntimeError(self.status()["message"])

View File

@@ -190,7 +190,9 @@ def _video_outputs_to_response(outputs: dict[str, Any]) -> dict[str, Any]:
def _prediction_to_response(output: dict[str, Any]) -> dict[str, Any]: def _prediction_to_response(output: dict[str, Any]) -> dict[str, Any]:
masks = _to_numpy(output.get("masks", [])) masks = _to_numpy(output.get("masks", []))
scores = _to_numpy(output.get("scores", [])) scores = _to_numpy(output.get("scores", []))
if masks.ndim == 4: if masks.ndim == 2:
masks = masks[None, :, :]
elif masks.ndim == 4:
masks = masks[:, 0] masks = masks[:, 0]
elif masks.ndim == 3 and masks.shape[0] == 1: elif masks.ndim == 3 and masks.shape[0] == 1:
masks = masks[None, 0] masks = masks[None, 0]

View File

@@ -83,10 +83,20 @@ class SAMRegistry:
def predict_auto(self, model_id: str | None, image: Any): def predict_auto(self, model_id: str | None, image: Any):
return self._ensure_available(model_id).predict_auto(image) return self._ensure_available(model_id).predict_auto(image)
def predict_semantic(self, model_id: str | None, image: Any, text: str): def predict_semantic(
self,
model_id: str | None,
image: Any,
text: str,
confidence_threshold: float | None = None,
):
model = self.normalize_model_id(model_id) model = self.normalize_model_id(model_id)
if model == "sam3": if model == "sam3":
return self._ensure_available(model).predict_semantic(image, text) return self._ensure_available(model).predict_semantic(
image,
text,
confidence_threshold=confidence_threshold,
)
return self._ensure_available(model).predict_auto(image) return self._ensure_available(model).predict_auto(image)
def propagate_video( def propagate_video(

View File

@@ -89,15 +89,25 @@ def test_predict_applies_crop_and_background_filter_options(client, monkeypatch)
def test_predict_box_and_semantic_fallback(client, monkeypatch): def test_predict_box_and_semantic_fallback(client, monkeypatch):
_, frame, _ = _create_project_and_frame(client) _, frame, _ = _create_project_and_frame(client)
calls = {}
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8)) monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
monkeypatch.setattr("routers.ai.sam_registry.predict_box", lambda model, image, box: ( monkeypatch.setattr("routers.ai.sam_registry.predict_box", lambda model, image, box: (
[[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]], [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
[0.8], [0.8],
)) ))
monkeypatch.setattr("routers.ai.sam_registry.predict_semantic", lambda model, image, text: (
[[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]], def fake_predict_semantic(model, image, text, confidence_threshold=None):
[0.5], calls["semantic"] = {
)) "model": model,
"text": text,
"confidence_threshold": confidence_threshold,
}
return (
[[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]],
[0.5],
)
monkeypatch.setattr("routers.ai.sam_registry.predict_semantic", fake_predict_semantic)
box_response = client.post("/api/ai/predict", json={ box_response = client.post("/api/ai/predict", json={
"image_id": frame["id"], "image_id": frame["id"],
@@ -108,12 +118,19 @@ def test_predict_box_and_semantic_fallback(client, monkeypatch):
"image_id": frame["id"], "image_id": frame["id"],
"prompt_type": "semantic", "prompt_type": "semantic",
"prompt_data": "胆囊", "prompt_data": "胆囊",
"model": "sam3",
"options": {"min_score": 0.05},
}) })
assert box_response.status_code == 200 assert box_response.status_code == 200
assert box_response.json()["scores"] == [0.8] assert box_response.json()["scores"] == [0.8]
assert semantic_response.status_code == 200 assert semantic_response.status_code == 200
assert semantic_response.json()["scores"] == [0.5] assert semantic_response.json()["scores"] == [0.5]
assert calls["semantic"] == {
"model": "sam3",
"text": "胆囊",
"confidence_threshold": 0.05,
}
def test_predict_interactive_combines_box_and_points(client, monkeypatch): def test_predict_interactive_combines_box_and_points(client, monkeypatch):

View File

@@ -0,0 +1,63 @@
import numpy as np
from services.sam2_engine import SAM2Engine
class _FakePredictor:
def __init__(self, masks, scores):
self.masks = masks
self.scores = scores
self.calls = []
def set_image(self, _image):
pass
def predict(self, **kwargs):
self.calls.append(kwargs)
return self.masks, self.scores, None
def _mask(offset=0):
mask = np.zeros((32, 32), dtype=np.uint8)
mask[4 + offset:20 + offset, 5 + offset:22 + offset] = 1
return mask
def _ready_engine(monkeypatch, predictor):
monkeypatch.setattr("services.sam2_engine.SAM2_AVAILABLE", True)
engine = SAM2Engine()
engine._model_loaded = True
engine._predictor = predictor
return engine
def test_sam2_point_prediction_requests_single_best_mask(monkeypatch):
predictor = _FakePredictor(
np.array([_mask()], dtype=np.uint8),
np.array([0.92], dtype=np.float32),
)
engine = _ready_engine(monkeypatch, predictor)
polygons, scores = engine.predict_points(
np.zeros((32, 32, 3), dtype=np.uint8),
[[0.5, 0.5]],
[1],
)
assert predictor.calls[0]["multimask_output"] is False
assert len(polygons) == 1
assert scores == [0.9200000166893005]
def test_sam2_auto_prediction_keeps_single_best_mask(monkeypatch):
predictor = _FakePredictor(
np.array([_mask(0), _mask(2), _mask(4)], dtype=np.uint8),
np.array([0.8, 0.7, 0.6], dtype=np.float32),
)
engine = _ready_engine(monkeypatch, predictor)
polygons, scores = engine.predict_auto(np.zeros((32, 32, 3), dtype=np.uint8))
assert predictor.calls[0]["multimask_output"] is False
assert len(polygons) == 1
assert scores == [0.800000011920929]

View File

@@ -4,7 +4,7 @@ from pathlib import Path
import numpy as np import numpy as np
from services.sam3_engine import SAM3Engine from services.sam3_engine import SAM3Engine
from services.sam3_external_worker import _to_numpy from services.sam3_external_worker import _prediction_to_response, _to_numpy
class _Completed: class _Completed:
@@ -98,6 +98,41 @@ def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch):
assert any("--request" in args for args in calls) assert any("--request" in args for args in calls)
def test_sam3_predict_semantic_allows_request_threshold_override(tmp_path, monkeypatch):
_external_settings(monkeypatch, tmp_path / "python")
def fake_run(args, **_kwargs):
if "--status" in args:
return _Completed(stdout=json.dumps({
"available": True,
"package_available": True,
"checkpoint_access": True,
"python_ok": True,
"torch_ok": True,
"cuda_available": True,
"device": "cuda",
"message": "ready",
}))
request_path = Path(args[-1])
request = json.loads(request_path.read_text(encoding="utf-8"))
assert request["confidence_threshold"] == 0.05
return _Completed(stdout=json.dumps({
"polygons": [[[0.2, 0.2], [0.6, 0.2], [0.6, 0.6]]],
"scores": [0.07],
}))
monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)
polygons, scores = SAM3Engine().predict_semantic(
np.zeros((8, 8, 3), dtype=np.uint8),
"surgical scene",
confidence_threshold=0.05,
)
assert len(polygons) == 1
assert scores == [0.07]
def test_sam3_predict_box_uses_external_worker(tmp_path, monkeypatch): def test_sam3_predict_box_uses_external_worker(tmp_path, monkeypatch):
_external_settings(monkeypatch, tmp_path / "python") _external_settings(monkeypatch, tmp_path / "python")
@@ -243,3 +278,16 @@ def test_sam3_worker_casts_floating_tensors_before_numpy():
assert tensor.float_called is True assert tensor.float_called is True
assert result.tolist() == [1.0] assert result.tolist() == [1.0]
def test_sam3_worker_converts_single_2d_mask_to_polygon():
mask = np.zeros((12, 12), dtype=np.uint8)
mask[2:10, 3:9] = 1
result = _prediction_to_response({
"masks": mask,
"scores": np.array([0.82], dtype=np.float32),
})
assert len(result["polygons"]) == 1
assert result["scores"] == [0.8199999928474426]

View File

@@ -65,12 +65,12 @@
### 项目与拆帧 ### 项目与拆帧
1. `ProjectLibrary.tsx` 调用 `getProjects()` 获取项目。 1. `ProjectLibrary.tsx` 调用 `getProjects()` 获取项目。
2. 上传视频时先 `createProject()`,再 `uploadMedia()`,再 `parseMedia()` 2. 上传视频时先 `createProject()`,再 `uploadMedia()`;导入视频不自动调用 `parseMedia()`
3. 后端 `media.py` 把原始文件上传到 MinIO。 3. 后端 `media.py` 把原始文件上传到 MinIO。
4. `parseMedia()` 创建 `processing_tasks` 记录并投递 Celery worker。 4. 用户在项目库点击“生成帧”并选择 FPS 后,`parseMedia()` 创建 `processing_tasks` 记录并投递 Celery worker。
5. Celery worker 下载 MinIO 文件,调用 `frame_parser.py` 拆帧。 5. Celery worker 下载 MinIO 文件,调用 `frame_parser.py` 拆帧。
6. worker 把拆出的帧重新上传 MinIO写入 `frames` 表,并更新任务状态。 6. worker 把拆出的帧重新上传 MinIO写入 `frames` 表,并更新任务状态。
7. 工作区通过 `GET /api/tasks/{id}` 等待任务完成,再通过 `GET /api/projects/{id}/frames` 获取预签名图片 URL。 7. 工作区通过 `GET /api/projects/{id}/frames` 获取预签名图片 URL;若项目有源视频但无帧,会提示先回项目库生成帧
8. Dashboard 可通过 `POST /api/tasks/{id}/cancel` 取消 queued/running 任务,通过 `POST /api/tasks/{id}/retry` 重试 failed/cancelled 任务,并用 `GET /api/tasks/{id}` 查看失败详情。 8. Dashboard 可通过 `POST /api/tasks/{id}/cancel` 取消 queued/running 任务,通过 `POST /api/tasks/{id}/retry` 重试 failed/cancelled 任务,并用 `GET /api/tasks/{id}` 查看失败详情。
### 工作区浏览 ### 工作区浏览

View File

@@ -46,8 +46,9 @@
| 项目卡片缩略图 | 真实可用 | 后端返回 MinIO 预签名 `thumbnail_url` 时显示 | | 项目卡片缩略图 | 真实可用 | 后端返回 MinIO 预签名 `thumbnail_url` 时显示 |
| 点击项目进入工作区 | 真实可用 | 设置 `currentProject` 后切到 `workspace` | | 点击项目进入工作区 | 真实可用 | 设置 `currentProject` 后切到 `workspace` |
| 新建项目 | 真实可用 | 调用 `POST /api/projects` | | 新建项目 | 真实可用 | 调用 `POST /api/projects` |
| 导入视频文件 | 真实可用 | 创建项目、上传文件、触发拆帧、刷新项目列表 | | 导入视频文件 | 真实可用 | 创建项目、上传源视频、刷新项目列表;不会自动拆帧 |
| 解析 FPS 滑块 | 真实可用 | 值传入 `createProject({ parse_fps })` | | 生成帧按钮 | 真实可用 | 仅对已导入源视频且尚无帧、非 parsing 状态的项目显示,调用 `parseMedia(projectId, { parseFps })` |
| 生成帧 FPS 滑块 | 真实可用 | 值传入 `/api/media/parse?parse_fps=...`,决定后台拆帧目标 FPS |
| 导入 DICOM 序列 | 部分可用 | 可上传 `.dcm` 并触发解析;体验和错误反馈较粗 | | 导入 DICOM 序列 | 部分可用 | 可上传 `.dcm` 并触发解析;体验和错误反馈较粗 |
| 项目状态徽标 | 真实可用 | 项目状态统一为 `pending/parsing/ready/error`,前端兼容归一化旧状态值 | | 项目状态徽标 | 真实可用 | 项目状态统一为 `pending/parsing/ready/error`,前端兼容归一化旧状态值 |
| 更多按钮 | Mock / UI-only | 有图标,没有菜单或事件 | | 更多按钮 | Mock / UI-only | 有图标,没有菜单或事件 |
@@ -59,7 +60,7 @@
|------|------|------| |------|------|------|
| 当前项目名 | 真实可用 | 读取 `currentProject.name` | | 当前项目名 | 真实可用 | 读取 `currentProject.name` |
| 自动加载项目帧 | 真实可用 | 调用 `GET /api/projects/{id}/frames` | | 自动加载项目帧 | 真实可用 | 调用 `GET /api/projects/{id}/frames` |
| 无帧时触发解析 | 真实可用 | 如果 `video_path` 存在会调用 `parseMedia()` 创建异步任务,并轮询 `GET /api/tasks/{id}` 等待完成 | | 无帧项目提示 | 真实可用 | 如果 `video_path` 存在但无帧,只提示回到项目库生成帧,不自动创建拆帧任务 |
| SAM 模型状态徽标 | 真实可用 | 调用 `GET /api/ai/models/status`,显示当前选择的 SAM 2/SAM 3 是否可用 | | SAM 模型状态徽标 | 真实可用 | 调用 `GET /api/ai/models/status`,显示当前选择的 SAM 2/SAM 3 是否可用 |
| 已保存标注回显 | 真实可用 | 加载工作区帧后调用 `GET /api/ai/annotations` 并渲染已保存 mask | | 已保存标注回显 | 真实可用 | 加载工作区帧后调用 `GET /api/ai/annotations` 并渲染已保存 mask |
| “导出 JSON 标注集”按钮 | 真实可用 | 导出前会保存未归档 mask然后调用 `exportCoco()` 下载 JSON | | “导出 JSON 标注集”按钮 | 真实可用 | 导出前会保存未归档 mask然后调用 `exportCoco()` 下载 JSON |
@@ -93,6 +94,7 @@
| 元素 | 状态 | 说明 | | 元素 | 状态 | 说明 |
|------|------|------| |------|------|------|
| 拖拽/选择 | 真实可用 | 控制 Canvas 是否可拖拽 | | 拖拽/选择 | 真实可用 | 控制 Canvas 是否可拖拽 |
| 调整多边形 | 真实可用 | 选中 polygon mask 后显示顶点和边中点;支持拖动顶点、点击边中点插点、双击边界按位置插点 |
| 多边形/矩形/圆/点/线 | 真实可用 | 切换 activeTool 后由 `CanvasArea` 生成可保存的 polygon mask | | 多边形/矩形/圆/点/线 | 真实可用 | 切换 activeTool 后由 `CanvasArea` 生成可保存的 polygon mask |
| 区域合并/去除 | 真实可用 | 选择工具后点击多个 mask右下角显示已选数量和操作按钮合并/去除模式会隐藏 polygon 编辑手柄,避免手柄抢占多选点击;使用 `polygon-clipping` 做 union / difference合并会保留主 mask 并移除被合并 mask去除会从主 mask 扣除后续选中 mask内含扣除会保留 hole ring 并用 even-odd 规则渲染 | | 区域合并/去除 | 真实可用 | 选择工具后点击多个 mask右下角显示已选数量和操作按钮合并/去除模式会隐藏 polygon 编辑手柄,避免手柄抢占多选点击;使用 `polygon-clipping` 做 union / difference合并会保留主 mask 并移除被合并 mask去除会从主 mask 扣除后续选中 mask内含扣除会保留 hole ring 并用 even-odd 规则渲染 |
| 正向选点/反向选点/框选 | 部分可用 | 会影响 Canvas 交互,并能触发已对齐的 AI 推理接口 | | 正向选点/反向选点/框选 | 部分可用 | 会影响 Canvas 交互,并能触发已对齐的 AI 推理接口 |
@@ -130,7 +132,8 @@
| SAM 3 框选 | 真实可用 | 工作区选择 SAM 3 后可使用框选工具;后端通过官方 `add_geometric_prompt()` 正框执行 SAM 3 几何提示推理 | | SAM 3 框选 | 真实可用 | 工作区选择 SAM 3 后可使用框选工具;后端通过官方 `add_geometric_prompt()` 正框执行 SAM 3 几何提示推理 |
| 语义文本输入 | 部分可用 | 纯文本会以 `semantic` prompt 调用后端;选择 SAM 3 且独立 Python 3.12 环境、CUDA、官方包和本地 checkpoint 均满足时走 SAM 3 文本语义推理,否则状态接口会标明不可用;空文本、失败和 0 mask 返回会显示前端反馈 | | 语义文本输入 | 部分可用 | 纯文本会以 `semantic` prompt 调用后端;选择 SAM 3 且独立 Python 3.12 环境、CUDA、官方包和本地 checkpoint 均满足时走 SAM 3 文本语义推理,否则状态接口会标明不可用;空文本、失败和 0 mask 返回会显示前端反馈 |
| 参数开关 | 真实可用 | `cropMode` 会随 `/api/ai/predict` 发送 `crop_to_prompt`,后端对点/框 prompt 裁剪推理区域并回映射 polygon`autoDeleteBg` 会发送 `auto_filter_background``min_score`,后端过滤低分结果和覆盖负向点的结果 | | 参数开关 | 真实可用 | `cropMode` 会随 `/api/ai/predict` 发送 `crop_to_prompt`,后端对点/框 prompt 裁剪推理区域并回映射 polygon`autoDeleteBg` 会发送 `auto_filter_background``min_score`,后端过滤低分结果和覆盖负向点的结果 |
| 执行高精度语义分割 | 部分可用 | 使用当前项目帧调用 `/api/ai/predict`没有当前帧时按钮禁用 | | 执行高精度语义分割 | 真实可用 | 使用当前项目帧调用 `/api/ai/predict`SAM 2 需要点提示且只采用最高分候选SAM 3 需要文本语义提示;生成结果写入全局 masks 并自动选中,右侧分类树可立即换标签 |
| 推送至工作区编辑 | 真实可用 | 切回工作区并把工具切到“调整多边形”,保留 AI 页选中的 mask便于继续调轮廓和归档 |
| 上传替换底图 | Mock / UI-only | 按钮无事件 | | 上传替换底图 | Mock / UI-only | 按钮无事件 |
| 撤销/重做 | 真实可用 | 绑定全局 mask 历史栈 | | 撤销/重做 | 真实可用 | 绑定全局 mask 历史栈 |
| 清空全体锚点 | 部分可用 | 清空前端 points 和 masks | | 清空全体锚点 | 部分可用 | 清空前端 points 和 masks |
@@ -153,6 +156,6 @@
## 总体结论 ## 总体结论
当前前端真实可用的主链路是登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、帧、浏览帧、播放帧、工作区手工绘制、点/框 AI 推理、视频片段传播、GT mask 导入、标注保存/回显、COCO 导出、PNG mask ZIP 导出、模板 CRUD。 当前前端真实可用的主链路是登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、显式生成帧、浏览帧、播放帧、工作区手工绘制、点/框 AI 推理、视频片段传播、GT mask 导入、标注保存/回显、COCO 导出、PNG mask ZIP 导出、模板 CRUD。
当前最主要的 Mock 或未打通链路是polygon 插点/边编辑增强、真正的文本语义分割、骨架/HDBSCAN 级别的 mask 降维增强、任务历史筛选、项目更多菜单和若干检查面板指标。 当前最主要的 Mock 或未打通链路是polygon 插点/边编辑增强、真正的文本语义分割、骨架/HDBSCAN 级别的 mask 降维增强、任务历史筛选、项目更多菜单和若干检查面板指标。

View File

@@ -32,7 +32,7 @@ Authorization: Bearer <token>
| `deleteTemplate(id)` | `DELETE /api/templates/{id}` | 对齐 | 模板编辑页使用 | | `deleteTemplate(id)` | `DELETE /api/templates/{id}` | 对齐 | 模板编辑页使用 |
| `uploadMedia(file, projectId)` | `POST /api/media/upload` | 对齐 | multipart form-data | | `uploadMedia(file, projectId)` | `POST /api/media/upload` | 对齐 | multipart form-data |
| `uploadDicomBatch(files, projectId)` | `POST /api/media/upload/dicom` | 对齐 | multipart form-data | | `uploadDicomBatch(files, projectId)` | `POST /api/media/upload/dicom` | 对齐 | multipart form-data |
| `parseMedia(projectId, options?)` | `POST /api/media/parse?project_id=...` | 对齐 | 创建异步拆帧任务并返回 task支持 `parse_fps``max_frames``target_width` | | `parseMedia(projectId, options?)` | `POST /api/media/parse?project_id=...` | 对齐 | 创建异步拆帧任务并返回 task由项目库“生成帧”显式调用,支持 `parse_fps``max_frames``target_width` |
| `getTask(taskId)` | `GET /api/tasks/{task_id}` | 对齐 | 查询异步任务状态 | | `getTask(taskId)` | `GET /api/tasks/{task_id}` | 对齐 | 查询异步任务状态 |
| `cancelTask(taskId)` | `POST /api/tasks/{task_id}/cancel` | 对齐 | 取消 queued/running 任务,后端写 cancelled 并尝试 revoke Celery | | `cancelTask(taskId)` | `POST /api/tasks/{task_id}/cancel` | 对齐 | 取消 queued/running 任务,后端写 cancelled 并尝试 revoke Celery |
| `retryTask(taskId)` | `POST /api/tasks/{task_id}/retry` | 对齐 | 对 failed/cancelled 任务创建新的 queued 重试任务 | | `retryTask(taskId)` | `POST /api/tasks/{task_id}/retry` | 对齐 | 对 failed/cancelled 任务创建新的 queued 重试任务 |
@@ -91,6 +91,21 @@ Authorization: Bearer <token>
| GET | `/health` | 健康检查 | | GET | `/health` | 健康检查 |
| WS | `/ws/progress` | WebSocket 进度通道,未出现在 OpenAPI paths 中 | | WS | `/ws/progress` | WebSocket 进度通道,未出现在 OpenAPI paths 中 |
### WebSocket 进度通道
`/ws/progress` 用于 Dashboard 实时接收后台任务状态。前端连接成功后会定时发送 `ping` 作为心跳;后端收到任意文本心跳后返回:
```json
{
"type": "status",
"status": "connected",
"message": "Progress stream active",
"timestamp": "2026-05-01T00:00:00+00:00"
}
```
后台任务进度由 Celery worker 写入 Redis `seg:progress` 频道,再由 FastAPI 转发到当前活跃 WebSocket 连接。Dashboard 的“WebSocket 已连接/断开”状态来自浏览器 WebSocket 的 `onopen/onclose/onerror`,不再依赖是否刚好收到任务进度事件。
## 关键请求体 ## 关键请求体
### 登录 ### 登录
@@ -172,7 +187,11 @@ POST /api/media/parse?project_id=1&parse_fps=15&max_frames=120&target_width=960
- `point` - `point`
- `box` - `box`
- `interactive`,用于 SAM 2 交互式细化,`prompt_data` 同时携带 `box`、累计 `points``labels` - `interactive`,用于 SAM 2 交互式细化,`prompt_data` 同时携带 `box`、累计 `points``labels`
- `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理;选择 `sam2` 时仍回退到 auto segmentation。SAM 3 真实可用性由 `/api/ai/models/status` 中的外部环境和本地 checkpoint 状态决定。 - `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理。前端 AI 页面不会再用 SAM 2 发送纯文本 semanticSAM 2 的交互入口应使用点/框提示。SAM 3 真实可用性由 `/api/ai/models/status` 中的外部环境和本地 checkpoint 状态决定。
SAM 2 点提示和 auto fallback 当前只采用最高分候选 mask避免同一提示下多个备选 mask 被前端叠加显示。
工作区 SAM 2 请求包含反向点时,`CanvasArea` 会发送 `options.auto_filter_background=true``options.min_score=0.05`;如果负向点过滤后没有可用 polygon前端会移除当前旧候选 mask 并要求重新框选或添加正向点。
选择 `sam3` 且发送 `box` 时,前端仍传 normalized `[x1, y1, x2, y2]`,后端适配层会转换成官方几何 prompt 的 `[center_x, center_y, width, height]` 正框;当前 SAM 3 不接正/反点修正。 选择 `sam3` 且发送 `box` 时,前端仍传 normalized `[x1, y1, x2, y2]`,后端适配层会转换成官方几何 prompt 的 `[center_x, center_y, width, height]` 正框;当前 SAM 3 不接正/反点修正。
@@ -180,7 +199,7 @@ POST /api/media/parse?project_id=1&parse_fps=15&max_frames=120&target_width=960
- `crop_to_prompt`:对 point/box/interactive prompt 按锚点或框附近区域裁剪后推理,再把 polygon 回映射到原图坐标。 - `crop_to_prompt`:对 point/box/interactive prompt 按锚点或框附近区域裁剪后推理,再把 polygon 回映射到原图坐标。
- `auto_filter_background`:过滤低分结果,并移除包含负向点的 polygon。 - `auto_filter_background`:过滤低分结果,并移除包含负向点的 polygon。
- `min_score`:配合 `auto_filter_background` 使用的最低置信度阈值。 - `min_score`:配合 `auto_filter_background` 使用的最低置信度阈值;对 SAM 3 semantic 请求也会作为 external worker 的 `confidence_threshold` 传入,避免本地 checkpoint 在默认高阈值下返回 0 个 mask
后端响应: 后端响应:

View File

@@ -17,7 +17,8 @@
- 前端展示项目库,并从 `GET /api/projects` 获取项目列表。 - 前端展示项目库,并从 `GET /api/projects` 获取项目列表。
- 用户可以新建项目,前端调用 `POST /api/projects` - 用户可以新建项目,前端调用 `POST /api/projects`
- 用户可以选择项目,进入工作区。 - 用户可以选择项目,进入工作区。
- 用户可以导入视频文件,前端创建项目、上传文件、触发拆帧、刷新项目列表 - 用户可以导入视频文件,前端创建项目、上传文件并刷新项目列表;导入视频不自动拆帧
- 用户可以对已导入且尚未生成帧的视频项目点击“生成帧”,在弹窗中选择目标 FPS 后创建拆帧任务。
- 用户可以导入 DICOM 序列,前端上传 DICOM、触发拆帧、刷新项目列表。 - 用户可以导入 DICOM 序列,前端上传 DICOM、触发拆帧、刷新项目列表。
- 后端支持项目创建、列表、详情、局部更新和删除。 - 后端支持项目创建、列表、详情、局部更新和删除。
- 后端支持项目帧创建、列表和单帧查询。 - 后端支持项目帧创建、列表和单帧查询。
@@ -42,7 +43,7 @@
## R4 工作区与帧浏览 ## R4 工作区与帧浏览
- 工作区根据当前项目加载帧列表。 - 工作区根据当前项目加载帧列表。
- 若项目有媒体但无帧,工作区会尝试触发拆帧后重新加载 - 若项目有媒体但无帧,工作区只提示需要先在项目库生成帧,不再自动触发拆帧
- Canvas 显示当前帧图片。 - Canvas 显示当前帧图片。
- Canvas 支持滚轮缩放、移动工具拖拽、鼠标坐标显示。 - Canvas 支持滚轮缩放、移动工具拖拽、鼠标坐标显示。
- 时间轴支持缩略图点击切帧、range 拖动切帧、键盘左右方向键切帧、播放/暂停顺序推进帧。 - 时间轴支持缩略图点击切帧、range 拖动切帧、键盘左右方向键切帧、播放/暂停顺序推进帧。
@@ -57,7 +58,9 @@
- 多边形、矩形、圆、点、线工具会在 Canvas 上生成可保存的 polygon mask。 - 多边形、矩形、圆、点、线工具会在 Canvas 上生成可保存的 polygon mask。
- 多边形通过点击取点并按 Enter 完成,也支持三点后点击首节点闭合;矩形、圆、线通过拖拽生成;点工具生成小点区域。 - 多边形通过点击取点并按 Enter 完成,也支持三点后点击首节点闭合;矩形、圆、线通过拖拽生成;点工具生成小点区域。
- 绘制工具点击已有 mask 时应继续执行当前绘制动作,不应被 mask 选择逻辑吞掉。 - 绘制工具点击已有 mask 时应继续执行当前绘制动作,不应被 mask 选择逻辑吞掉。
- Canvas 支持点击 mask 进入 polygon 顶点编辑态;拖动顶点会更新 mask 几何并把已保存 mask 标记为 dirty。 - 工具栏提供“调整多边形”工具,用户可以点击 mask 进入 polygon 顶点编辑态;拖动顶点会更新 mask 几何并把已保存 mask 标记为 dirty。
- 顶点编辑态显示边中点插入手柄;点击边中点会在该边中间新增顶点。
- “调整多边形”工具下双击 polygon 边界时,会在最接近的线段上按双击位置新增顶点。
- 顶点编辑态下选中顶点后可用 Delete/Backspace 删除顶点,但不会让 polygon 少于三点。 - 顶点编辑态下选中顶点后可用 Delete/Backspace 删除顶点,但不会让 polygon 少于三点。
- 选中整个 mask 且未选中具体顶点时Delete/Backspace 删除该 mask已保存 mask 同步调用后端删除接口。 - 选中整个 mask 且未选中具体顶点时Delete/Backspace 删除该 mask已保存 mask 同步调用后端删除接口。
- 撤销、重做绑定全局 `maskHistory/maskFuture`支持工具栏按钮、AI 页按钮和 Canvas 快捷键。 - 撤销、重做绑定全局 `maskHistory/maskFuture`支持工具栏按钮、AI 页按钮和 Canvas 快捷键。
@@ -75,14 +78,19 @@
- 点提示传 `{ points, labels }`,正向点 label 为 1反向点 label 为 0。 - 点提示传 `{ points, labels }`,正向点 label 为 1反向点 label 为 0。
- 框选提示传归一化 `[x1, y1, x2, y2]` - 框选提示传归一化 `[x1, y1, x2, y2]`
- 工作区 SAM 2 框选会建立一个候选 mask后续正向点/反向点会携带原始框和累计点,以 `interactive` prompt 细化并替换同一个候选 mask。 - 工作区 SAM 2 框选会建立一个候选 mask后续正向点/反向点会携带原始框和累计点,以 `interactive` prompt 细化并替换同一个候选 mask。
- 语义文本提示传 `semantic`;选择 `sam3` 且独立 Python 3.12 环境、CUDA、官方包和本地 checkpoint 均满足时走 SAM 3 文本语义推理,选择 `sam2` 时回退到自动分割 - 工作区 SAM 2 一旦包含反向点,会随请求启用 `auto_filter_background``min_score=0.05`;若后端判定反向点排除了当前候选区域并返回空结果,前端会移除旧候选 mask避免继续显示已被否定的区域
- SAM 2 不支持文本语义提示AI 页面在 SAM 2 下输入纯文本时会提示用户改用点提示或切换 SAM 3不再回退到自动分割。
- SAM 2 点提示和 auto fallback 默认只采用一个最高分候选 mask避免多个候选 mask 作为同一结果重叠显示。
- AI 页面生成的 SAM 2/SAM 3 mask 会写入全局 `masks`,自动同步到当前项目帧,并写入全局 `selectedMaskIds`;右侧语义分类树可以直接给新生成 mask 换标签。
- AI 页面“推送至工作区编辑”会切回工作区并把工具切到“调整多边形”,保留当前选中的 AI mask 以便继续编辑轮廓和归档保存。
- 语义文本提示传 `semantic`;选择 `sam3` 且独立 Python 3.12 环境、CUDA、官方包和本地 checkpoint 均满足时走 SAM 3 文本语义推理。
- SAM 3 支持工作区框选提示;后端把 normalized `[x1, y1, x2, y2]` 转成官方 `add_geometric_prompt()` 需要的 `[center_x, center_y, width, height]` 正框。 - SAM 3 支持工作区框选提示;后端把 normalized `[x1, y1, x2, y2]` 转成官方 `add_geometric_prompt()` 需要的 `[center_x, center_y, width, height]` 正框。
- 当前 SAM 3 前端路径不支持正/反点修正;在工作区用 SAM 3 进行点交互时,前端会提示切回 SAM 2。 - 当前 SAM 3 前端路径不支持正/反点修正;在工作区用 SAM 3 进行点交互时,前端会提示切回 SAM 2。
- 工作区“传播片段”会把当前选中区域或当前帧第一个区域作为 seed调用 `POST /api/ai/propagate`,默认从当前帧向后传播 30 帧并保存结果标注。 - 工作区“传播片段”会把当前选中区域或当前帧第一个区域作为 seed调用 `POST /api/ai/propagate`,默认从当前帧向后传播 30 帧并保存结果标注。
- `POST /api/ai/propagate` 支持 `model=sam2``model=sam3`SAM 2 使用官方 `SAM2VideoPredictor.add_new_mask()``propagate_in_video()`SAM 3 通过独立 Python 3.12 helper 调用官方 `build_sam3_video_predictor()` video tracker。 - `POST /api/ai/propagate` 支持 `model=sam2``model=sam3`SAM 2 使用官方 `SAM2VideoPredictor.add_new_mask()``propagate_in_video()`SAM 3 通过独立 Python 3.12 helper 调用官方 `build_sam3_video_predictor()` video tracker。
- 传播结果会写入后续帧 `annotations``mask_data.source` 分别标记为 `sam2_propagation``sam3_propagation`,并保留 label、color 和 class 元数据。 - 传播结果会写入后续帧 `annotations``mask_data.source` 分别标记为 `sam2_propagation``sam3_propagation`,并保留 label、color 和 class 元数据。
- AI 页面会对 SAM 3 空文本、推理失败和返回 0 个 mask 的情况显示明确反馈。 - AI 页面会对 SAM 3 空文本、推理失败和返回 0 个 mask 的情况显示明确反馈。
- AI 参数支持 `crop_to_prompt``auto_filter_background``min_score`;点/框 prompt 可以裁剪局部区域推理并回映射结果,背景过滤会移除低分结果和包含负向点的 polygon。 - AI 参数支持 `crop_to_prompt``auto_filter_background``min_score`;点/框 prompt 可以裁剪局部区域推理并回映射结果,背景过滤会移除低分结果和包含负向点的 polygonSAM 3 semantic 会用 `min_score` 控制 external worker 的置信度阈值
- 后端返回 `polygons``scores` - 后端返回 `polygons``scores`
- 前端把后端 `polygons` 转成 Konva `pathData``segmentation``bbox``area` - 前端把后端 `polygons` 转成 Konva `pathData``segmentation``bbox``area`
- AI 推理结果先存放在前端 store 的 `masks` 中,点击“结构化归档保存”后持久化到后端标注表。 - AI 推理结果先存放在前端 store 的 `masks` 中,点击“结构化归档保存”后持久化到后端标注表。

View File

@@ -22,11 +22,11 @@
| 全局状态 | `src/store/useStore.ts` | Zustand store保存项目、帧、模板、mask、当前选中 mask ids、工具状态和 mask 撤销/重做历史栈 | | 全局状态 | `src/store/useStore.ts` | Zustand store保存项目、帧、模板、mask、当前选中 mask ids、工具状态和 mask 撤销/重做历史栈 |
| API 封装 | `src/lib/api.ts` | Axios 客户端、字段映射、AI 响应转换 | | API 封装 | `src/lib/api.ts` | Axios 客户端、字段映射、AI 响应转换 |
| 配置 | `src/lib/config.ts` | 推导 API 和 WebSocket 地址 | | 配置 | `src/lib/config.ts` | 推导 API 和 WebSocket 地址 |
| WebSocket | `src/lib/websocket.ts` | 进度流连接、订阅和重连 | | WebSocket | `src/lib/websocket.ts` | 进度流连接、订阅、连接状态通知、心跳和重连 |
| 模型状态 | `src/components/ModelStatusBadge.tsx` | 展示 GPU 与当前 SAM 模型真实可用状态 | | 模型状态 | `src/components/ModelStatusBadge.tsx` | 展示 GPU 与当前 SAM 模型真实可用状态 |
| 登录页 | `src/components/Login.tsx` | 调用登录 API写入 store | | 登录页 | `src/components/Login.tsx` | 调用登录 API写入 store |
| Dashboard | `src/components/Dashboard.tsx` | 展示统计、任务控制、失败详情和 WebSocket 进度消息 | | Dashboard | `src/components/Dashboard.tsx` | 展示统计、任务控制、失败详情和 WebSocket 进度消息 |
| 项目库 | `src/components/ProjectLibrary.tsx` | 项目列表、新建、导入视频/DICOM | | 项目库 | `src/components/ProjectLibrary.tsx` | 项目列表、新建、导入视频/DICOM、显式生成帧 |
| 工作区 | `src/components/VideoWorkspace.tsx` | 加载帧和模板组织工具栏、Canvas、本体面板、时间轴 | | 工作区 | `src/components/VideoWorkspace.tsx` | 加载帧和模板组织工具栏、Canvas、本体面板、时间轴 |
| Canvas | `src/components/CanvasArea.tsx` | 显示帧、缩放平移、点/框提示、渲染 mask | | Canvas | `src/components/CanvasArea.tsx` | 显示帧、缩放平移、点/框提示、渲染 mask |
| 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具、跳转 AI 页面、触发 mask 撤销/重做 | | 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具、跳转 AI 页面、触发 mask 撤销/重做 |
@@ -76,15 +76,16 @@
2. `login()` 调用 `POST /api/auth/login` 2. `login()` 调用 `POST /api/auth/login`
3. 成功后 store 写入 tokenApp 渲染主界面。 3. 成功后 store 写入 tokenApp 渲染主界面。
### 项目导入 ### 项目导入与生成帧
1. `ProjectLibrary` 创建项目。 1. `ProjectLibrary` 创建项目。
2. 上传视频或 DICOM `/api/media/upload` `/api/media/upload/dicom` 2. 导入视频时上传视频到 `/api/media/upload` 并关联项目;该步骤不调用 `/api/media/parse`
3. 调用 `/api/media/parse` 创建异步拆帧任务;可通过 `parse_fps``max_frames``target_width` 指定标准帧序列参数 3. 用户在项目卡片点击“生成帧”,在弹窗中选择目标 FPS
4. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,视频帧按 `frame_%06d.jpg``frame_000000.jpg` 连续命名,并按目标宽度缩放 4. 前端调用 `/api/media/parse` 创建异步拆帧任务;可通过 `parse_fps``max_frames``target_width` 指定标准帧序列参数
5. worker 写入 `frames.timestamp_ms``frames.source_frame_number`,并在任务 `result.frame_sequence` 中记录 FPS、帧数、时长、尺寸和对象存储前缀 5. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,视频帧按 `frame_%06d.jpg``frame_000000.jpg` 连续命名,并按目标宽度缩放
6. worker 持续更新 `processing_tasks`,并发布 Redis `seg:progress` 6. worker 写入 `frames.timestamp_ms``frames.source_frame_number`,并在任务 `result.frame_sequence` 中记录 FPS、帧数、时长、尺寸和对象存储前缀
7. 刷新项目列表 7. worker 持续更新 `processing_tasks`,并发布 Redis `seg:progress`
8. 刷新项目列表。
### 任务控制 ### 任务控制
@@ -93,11 +94,12 @@
3. worker 在下载、解析、上传、写帧等关键阶段刷新任务状态;如果发现 `cancelled`,停止后续写入并发布 cancelled 事件。 3. worker 在下载、解析、上传、写帧等关键阶段刷新任务状态;如果发现 `cancelled`,停止后续写入并发布 cancelled 事件。
4. 用户重试任务时,前端调用 `POST /api/tasks/{task_id}/retry`;后端基于原任务 `payload` 创建新任务,记录 `retry_of` 并重新投递 Celery。 4. 用户重试任务时,前端调用 `POST /api/tasks/{task_id}/retry`;后端基于原任务 `payload` 创建新任务,记录 `retry_of` 并重新投递 Celery。
5. 用户打开详情时,前端调用 `GET /api/tasks/{task_id}`,弹窗展示 error、payload、result、Celery ID 和时间。 5. 用户打开详情时,前端调用 `GET /api/tasks/{task_id}`,弹窗展示 error、payload、result、Celery ID 和时间。
6. Dashboard 通过 `/ws/progress` 接收 Redis `seg:progress` 转发事件;前端 WebSocket 客户端在 `onopen/onclose/onerror` 主动更新连接状态,并定时发送 `ping` 心跳,服务端返回 `status` 确认连接仍活跃。
### 工作区加载 ### 工作区加载
1. `VideoWorkspace` 根据 `currentProject.id` 调用 `getProjectFrames()` 1. `VideoWorkspace` 根据 `currentProject.id` 调用 `getProjectFrames()`
2. 若无帧但项目有 `video_path`,触发 `parseMedia()`,通过 `getTask()` 轮询任务完成后重新取帧 2. 若无帧但项目有 `video_path`显示“尚未生成帧”的状态提示,不自动触发 `parseMedia()`
3. 帧数据映射为 store `Frame[]`,包含 `timestampMs``sourceFrameNumber`,供时间轴和后续视频传播使用。 3. 帧数据映射为 store `Frame[]`,包含 `timestampMs``sourceFrameNumber`,供时间轴和后续视频传播使用。
4. 当前帧传入 `CanvasArea` 4. 当前帧传入 `CanvasArea`
@@ -107,13 +109,15 @@
2. `CanvasArea` 读取当前帧 ID 和宽高。 2. `CanvasArea` 读取当前帧 ID 和宽高。
3. SAM 2 框选会创建一个候选 mask并记录原始框后续正向点/反向点会累计到同一候选上。 3. SAM 2 框选会创建一个候选 mask并记录原始框后续正向点/反向点会累计到同一候选上。
4. `predictMask()` 归一化坐标并携带当前 `model` 调用 `/api/ai/predict`;同时有框和点时发送 `interactive` prompt。 4. `predictMask()` 归一化坐标并携带当前 `model` 调用 `/api/ai/predict`;同时有框和点时发送 `interactive` prompt。
5. 后端加载帧图片并通过 SAM registry 分发到 SAM 2 或 SAM 3 5. SAM 2 请求中只要存在反向点,`CanvasArea` 会额外发送 `options.auto_filter_background=true``options.min_score=0.05`,让后端移除低分结果和包含负向点的 polygon
6. 前端把 `polygons` 转为 mask交互式细化会替换同一个候选 mask而不是新增多个 mask 6. 后端加载帧图片并通过 SAM registry 分发到 SAM 2 或 SAM 3
7. Canvas 按当前帧过滤并渲染 mask。 7. 前端把 `polygons` 转为 mask交互式细化会替换同一个候选 mask而不是新增多个 mask。
8. 新 mask 会带上当前选择的模板分类元数据,包括 `classId``className``classZIndex` 和保存状态 `draft` 8. 若带反向点的 SAM 2 细化返回空结果,前端会删除当前旧候选 mask 并提示反向点已排除该区域
9. 用户点击“结构化归档保存”后,前端将像素 `segmentation` 转成 normalized `mask_data.polygons`;未保存 mask 调用 `POST /api/ai/annotate`dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}` 9. Canvas 按当前帧过滤并渲染 mask
10. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask 10. 新 mask 会带上当前选择的模板分类元数据,包括 `classId``className``classZIndex` 和保存状态 `draft`
11. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask 11. 用户点击“结构化归档保存”后,前端将像素 `segmentation` 转成 normalized `mask_data.polygons`;未保存 mask 调用 `POST /api/ai/annotate`dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`
12. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。
13. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。
### 视频片段传播 ### 视频片段传播
@@ -131,19 +135,20 @@
1. 用户在 `ToolsPalette` 选择多边形、矩形、圆、点或线工具。 1. 用户在 `ToolsPalette` 选择多边形、矩形、圆、点或线工具。
2. `CanvasArea` 将交互坐标转换成像素 polygon。 2. `CanvasArea` 将交互坐标转换成像素 polygon。
3. 多边形工具逐次记录节点,三点后点击首节点或按 Enter 时生成闭合 polygon。 3. 多边形工具逐次记录节点,三点后点击首节点或按 Enter 时生成闭合 polygon。
4. mask path 只在 `move``area_merge``area_remove` 工具下拦截点击;绘制和 AI prompt 工具点击已有 mask 时继续冒泡给 Stage。 4. mask path 只在 `move``edit_polygon``area_merge``area_remove` 工具下拦截点击;绘制和 AI prompt 工具点击已有 mask 时继续冒泡给 Stage。
5. 新 mask 写入 `pathData`、像素 `segmentation``bbox``area` 和当前模板分类元数据。 5. 新 mask 写入 `pathData`、像素 `segmentation``bbox``area` 和当前模板分类元数据。
6. `addMask()``setMasks()``updateMask()``clearMasks()` 会维护 `maskHistory/maskFuture` 6. `addMask()``setMasks()``updateMask()``clearMasks()` 会维护 `maskHistory/maskFuture`
7. 工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y 调用 `undoMasks()` / `redoMasks()` 7. 工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y 调用 `undoMasks()` / `redoMasks()`
### Polygon 逐点编辑 ### Polygon 逐点编辑
1. 用户点击 Canvas 上的 mask path`CanvasArea` 记录 `selectedMaskId` 并显示该 mask 第一条 polygon 的顶点控制点。 1. 用户选择“调整多边形”或“拖拽/选择”后点击 Canvas 上的 mask path`CanvasArea` 记录 `selectedMaskId` 并显示该 mask 第一条 polygon 的顶点控制点和边中点插入手柄
2. 拖动顶点后,前端重算 `pathData`、像素 `segmentation``bbox``area` 2. 拖动顶点后,前端重算 `pathData`、像素 `segmentation``bbox``area`
3. 如果 mask 已有 `annotationId`,编辑会把 `saveStatus` 标成 `dirty``saved=false` 3. 点击边中点手柄会在该边中点插入新顶点;在“调整多边形”工具下双击 polygon path 会在最接近的线段上按双击位置插入新顶点
4. 归档保存时复用现有 `PATCH /api/ai/annotations/{annotation_id}` 链路,把更新后的 normalized polygon 写回后端 4. 如果 mask 已有 `annotationId`,编辑会把 `saveStatus` 标成 `dirty``saved=false`
5. 选中顶点后 Delete/Backspace 可删除顶点;前端保持 polygon 至少三点 5. 归档保存时复用现有 `PATCH /api/ai/annotations/{annotation_id}` 链路,把更新后的 normalized polygon 写回后端
6. 选中具体顶点但选中了 mask 时,Delete/Backspace 从前端 store 删除该 mask如果包含 `annotationId`,通过工作区回调调用后端删除接口 6. 选中顶点后 Delete/Backspace 可删除顶点;前端保持 polygon 至少三点
7. 未选中具体顶点但选中了 mask 时Delete/Backspace 从前端 store 删除该 mask如果包含 `annotationId`,通过工作区回调调用后端删除接口。
### 区域合并与去除 ### 区域合并与去除
@@ -173,9 +178,11 @@
4. 后端把 `classes``rules` 打包进 `mapping_rules` 4. 后端把 `classes``rules` 打包进 `mapping_rules`
5. 返回时再解包给前端。 5. 返回时再解包给前端。
6. `CanvasArea` 把当前选中的 mask id 同步到全局 `selectedMaskIds`;切换工具、切换帧或卸载 Canvas 时会清空选择。 6. `CanvasArea` 把当前选中的 mask id 同步到全局 `selectedMaskIds`;切换工具、切换帧或卸载 Canvas 时会清空选择。
7. `OntologyInspector` 可以选择具体分类;选择结果进入全局 store`CanvasArea``AISegmentation` 新建/更新 mask 时使用 7. `AISegmentation` 生成 mask 后会写入全局 `masks` 并把生成的 mask id 写入 `selectedMaskIds`;点击 AI 页预览 mask 也会更新 `selectedMaskIds`
8. 如果 `selectedMaskIds` 中存在当前 store 的 mask点击分类时会立即更新这些 mask 的 `templateId``classId``className``classZIndex``label``color` 8. AI 页“推送至工作区编辑”会切换到工作区并把 `activeTool` 设为 `edit_polygon``CanvasArea` 初始读取全局 `selectedMaskIds`,让 AI 页选中的 mask 在工作区继续保持选中
9. 已保存 mask 被重新分类后进入 `dirty` `saved=false`,继续复用工作区归档保存的 PATCH 链路 9. `OntologyInspector` 可以选择具体分类;选择结果进入全局 store`CanvasArea` `AISegmentation` 新建/更新 mask 时使用
10. 如果 `selectedMaskIds` 中存在当前 store 的 mask点击分类时会立即更新这些 mask 的 `templateId``classId``className``classZIndex``label``color`
11. 已保存 mask 被重新分类后进入 `dirty``saved=false`,继续复用工作区归档保存的 PATCH 链路。
### 导出 ### 导出
@@ -204,10 +211,12 @@
- `parseMedia()` 使用 `POST /api/media/parse?project_id=...`,可选 `parse_fps``max_frames``target_width`,用于生成标准帧序列。 - `parseMedia()` 使用 `POST /api/media/parse?project_id=...`,可选 `parse_fps``max_frames``target_width`,用于生成标准帧序列。
- `getProjectFrames()` 返回帧图像 URL、宽高、`timestamp_ms``source_frame_number` - `getProjectFrames()` 返回帧图像 URL、宽高、`timestamp_ms``source_frame_number`
- 后端 `/api/ai/predict` 支持 point、box、interactive、semantic 四种 prompt_type并通过 `model` 选择 SAM 2 或 SAM 3。 - 后端 `/api/ai/predict` 支持 point、box、interactive、semantic 四种 prompt_type并通过 `model` 选择 SAM 2 或 SAM 3。
- SAM 2 是点/框交互式分割模型不做文本语义分割AI 页面在 SAM 2 + 纯文本时直接提示用户改用点提示或切换 SAM 3。
- SAM 2 点提示和 auto fallback 只返回一个最高分候选,避免同一提示产生多个重叠候选 mask。
- 当前 SAM 3 暴露 semantic 文本语义推理和 box 几何提示;工作区 Canvas 的点交互会在选择 SAM 3 时显示提示,不再静默失败。 - 当前 SAM 3 暴露 semantic 文本语义推理和 box 几何提示;工作区 Canvas 的点交互会在选择 SAM 3 时显示提示,不再静默失败。
- SAM 3 box prompt 复用后端 `/api/ai/predict``box` prompt_type输入仍是 normalized `[x1, y1, x2, y2]`,引擎适配层会转换为官方 `add_geometric_prompt()` 使用的 `[center_x, center_y, width, height]` 正框。 - SAM 3 box prompt 复用后端 `/api/ai/predict``box` prompt_type输入仍是 normalized `[x1, y1, x2, y2]`,引擎适配层会转换为官方 `add_geometric_prompt()` 使用的 `[center_x, center_y, width, height]` 正框。
- AI 页面选择 SAM 3 时优先发送文本 semantic prompt不会把正/反点误发送为 SAM 3 point prompt空文本、后端错误和空结果都会显示反馈消息。 - AI 页面选择 SAM 3 时优先发送文本 semantic prompt不会把正/反点误发送为 SAM 3 point prompt空文本、后端错误和空结果都会显示反馈消息。
- 后端 `/api/ai/predict` 支持可选 `options``crop_to_prompt` 会对 point/box/interactive prompt 做局部裁剪推理并回映射 polygon`auto_filter_background` 会按 `min_score` 和负向点过滤结果。 - 后端 `/api/ai/predict` 支持可选 `options``crop_to_prompt` 会对 point/box/interactive prompt 做局部裁剪推理并回映射 polygon`auto_filter_background` 会按 `min_score` 和负向点过滤结果SAM 3 semantic 会把正数 `min_score` 传给 external worker 作为 `confidence_threshold`
- 后端 `/api/ai/propagate` 支持 SAM 2 mask seed 视频传播和 SAM 3 external video tracker当前前端默认向后传播 30 帧并保存结果标注。 - 后端 `/api/ai/propagate` 支持 SAM 2 mask seed 视频传播和 SAM 3 external video tracker当前前端默认向后传播 30 帧并保存结果标注。
- 后端 `/api/ai/models/status` 返回 GPU、SAM 2、SAM 3 的真实运行状态SAM 3 状态包含外部 Python 环境与 checkpoint access 的可用性。 - 后端 `/api/ai/models/status` 返回 GPU、SAM 2、SAM 3 的真实运行状态SAM 3 状态包含外部 Python 环境与 checkpoint access 的可用性。
- point prompt 支持旧数组形式和 `{ points, labels }` 对象形式。 - point prompt 支持旧数组形式和 `{ points, labels }` 对象形式。

View File

@@ -16,14 +16,14 @@
|------|----------|--------| |------|----------|--------|
| R1 登录与会话 | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | 成功登录、失败提示、后端 401 | | R1 登录与会话 | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | 成功登录、失败提示、后端 401 |
| R2 项目管理 | `src/lib/api.test.ts`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、后端 CRUD、帧列表 | | R2 项目管理 | `src/lib/api.test.ts`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、后端 CRUD、帧列表 |
| R3 媒体上传与拆帧 | `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 扩展名校验、自动建项目、关联项目、创建异步任务、标准帧序列参数、帧时间戳/源帧号、任务序列元数据、worker 注册帧、取消任务、重试任务、取消后 worker 停止 | | R3 媒体上传与拆帧 | `src/components/ProjectLibrary.test.tsx`, `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 视频导入不自动拆帧、显式生成帧 FPS 选择、扩展名校验、自动建项目、关联项目、创建异步任务、标准帧序列参数、帧时间戳/源帧号、任务序列元数据、worker 注册帧、取消任务、重试任务、取消后 worker 停止 |
| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧触发解析、缩略图/range/左右方向键切帧、播放、按项目 FPS 显示当前/总时长 | | R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧项目不自动解析并提示生成帧、缩略图/range/左右方向键切帧、播放、按项目 FPS 显示当前/总时长 |
| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 工具切换、AI 跳转、矩形/圆/线/点/多边形手工 mask 绘制、点工具在已有 mask 上落点、多边形 Enter/首节点闭合、polygon 顶点拖动/删除、整块 mask 删除、区域合并/去除、内含去除 hole 渲染、合并模式隐藏编辑手柄、撤销/重做历史栈 | | R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 工具切换、调整多边形工具、AI 跳转、矩形/圆/线/点/多边形手工 mask 绘制、点工具在已有 mask 上落点、多边形 Enter/首节点闭合、polygon 顶点拖动/删除、边中点插点、双击边界按位置插点、整块 mask 删除、区域合并/去除、内含去除 hole 渲染、合并模式隐藏编辑手柄、撤销/重做历史栈 |
| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam3_engine.py` | 点/框/interactive/semantic 契约、SAM 2 框选后正负点细化同一候选 mask、SAM 2 视频传播、SAM 3 语义文本前端执行路径、SAM 3 工作区框选、SAM 3 video tracker 外部桥接、SAM 3 点交互不支持提示、空文本/空结果反馈、模型选择、GPU/SAM 状态、SAM 3 外部 worker 桥接、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry | | R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam2_engine.py`, `backend/tests/test_sam3_engine.py` | 点/框/interactive/semantic 契约、SAM 2 纯文本提示拦截、SAM 2 最高分候选去重、SAM 2 框选后正负点细化同一候选 mask、SAM 2 反向点启用背景过滤且空结果移除旧候选、AI 页生成 mask 自动选中并可通过分类树换标签、AI 页推送到工作区编辑保留选择、SAM 2 视频传播、SAM 3 语义文本前端执行路径、SAM 3 semantic 请求级阈值、SAM 3 worker 单 2D mask 转 polygon、SAM 3 工作区框选、SAM 3 video tracker 外部桥接、SAM 3 点交互不支持提示、空文本/空结果反馈、模型选择、GPU/SAM 状态、SAM 3 外部 worker 桥接、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry |
| R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、加载回显、更新 dirty 标注、清空删除已保存标注、GT mask 多类别导入、seed point 回显/归一化、项目不存在、帧不存在 | | R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、加载回显、更新 dirty 标注、清空删除已保存标注、GT mask 多类别导入、seed point 回显/归一化、项目不存在、帧不存在 |
| R8 模板库 | `src/components/TemplateRegistry.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_templates.py` | 前端模板加载/新建/编辑/删除、JSON 分类导入、mapping_rules 解包/打包、后端模板 CRUD | | R8 模板库 | `src/components/TemplateRegistry.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_templates.py` | 前端模板加载/新建/编辑/删除、JSON 分类导入、mapping_rules 解包/打包、后端模板 CRUD |
| R9 本体检查面板 | `src/components/OntologyInspector.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 模板选择、分类展示、具体分类选择、Canvas 选区同步、点击分类给已选 mask 换标签、自定义分类本地添加 | | R9 本体检查面板 | `src/components/OntologyInspector.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 模板选择、分类展示、具体分类选择、Canvas 选区同步、点击分类给已选 mask 换标签、自定义分类本地添加 |
| R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py`, `backend/tests/test_tasks.py` | 后端概览接口、任务表驱动队列、任务取消/重试/详情、cancelled 事件、Redis 进度事件 payload/发布、地址推导、消息订阅、队列更新、heartbeat | | R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py`, `backend/tests/test_tasks.py` | 后端概览接口、任务表驱动队列、任务取消/重试/详情、cancelled 事件、Redis 进度事件 payload/发布、地址推导、消息订阅、连接状态回调、队列更新、heartbeat |
| R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO/PNG 按钮下载、导出前自动保存、导出路径、JSON 结构、mask ZIP、zIndex 语义融合 | | R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO/PNG 按钮下载、导出前自动保存、导出路径、JSON 结构、mask ZIP、zIndex 语义融合 |
| R12 配置 | `src/lib/config.test.ts` | env 优先、hostname 推导、WS 推导 | | R12 配置 | `src/lib/config.test.ts` | env 优先、hostname 推导、WS 推导 |
| R13 文档与测试 | `doc/09-test-plan.md` | 测试覆盖矩阵 | | R13 文档与测试 | `doc/09-test-plan.md` | 测试覆盖矩阵 |
@@ -34,15 +34,15 @@
|------|--------|----------|----------| |------|--------|----------|----------|
| R1 | 登录页、默认开发凭证、token 写入、失败提示、后端 401 | `Login.test.tsx`, `test_auth.py` | 已覆盖 | | R1 | 登录页、默认开发凭证、token 写入、失败提示、后端 401 | `Login.test.tsx`, `test_auth.py` | 已覆盖 |
| R2 | 项目列表/创建/选择、视频导入、DICOM 导入、后端项目和帧 CRUD | `ProjectLibrary.test.tsx`, `api.test.ts`, `test_projects.py` | 已覆盖 | | R2 | 项目列表/创建/选择、视频导入、DICOM 导入、后端项目和帧 CRUD | `ProjectLibrary.test.tsx`, `api.test.ts`, `test_projects.py` | 已覆盖 |
| R3 | 文件类型校验、自动/指定项目上传、视频/DICOM 拆帧任务、`parse_fps/max_frames/target_width`、标准帧序列 metadata、任务查询、取消、重试、worker 取消停止 | `test_media.py`, `test_tasks.py` | 已覆盖 | | R3 | 文件类型校验、自动/指定项目上传、视频导入与生成帧分离、显式 FPS 生成帧、视频/DICOM 拆帧任务、`parse_fps/max_frames/target_width`、标准帧序列 metadata、任务查询、取消、重试、worker 取消停止 | `ProjectLibrary.test.tsx`, `test_media.py`, `test_tasks.py` | 已覆盖 |
| R4 | 工作区加载帧、无帧自动解析、Canvas 底图、缩略图/range/左右方向键切帧、播放、按 FPS 显示时间 | `VideoWorkspace.test.tsx`, `FrameTimeline.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 | | R4 | 工作区加载帧、无帧项目不自动解析、Canvas 底图、缩略图/range/左右方向键切帧、播放、按 FPS 显示时间 | `VideoWorkspace.test.tsx`, `FrameTimeline.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 |
| R5 | 工具切换、AI 跳转、矩形/圆/线/点/多边形绘制、已有 mask 上继续绘制 | `ToolsPalette.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 | | R5 | 工具切换、调整多边形入口、AI 跳转、矩形/圆/线/点/多边形绘制、已有 mask 上继续绘制 | `ToolsPalette.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 |
| R5 | 顶点编辑、顶点删除、整块删除、撤销/重做、区域合并、区域去除、hole even-odd 渲染 | `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 | | R5 | 顶点编辑、边中点插点、双击边界按位置插点、顶点删除、整块删除、撤销/重做、区域合并、区域去除、hole even-odd 渲染 | `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 |
| R6 | SAM 2 点/框/interactive、SAM 2 视频传播、SAM 3 semantic、SAM 3 box、SAM 3 video tracker、SAM 3 不支持点交互时的前端反馈、模型选择、GPU/模型状态、参数 options、polygons 转 mask | `api.test.ts`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx`, `VideoWorkspace.test.tsx`, `ModelStatusBadge.test.tsx`, `test_ai.py`, `test_sam3_engine.py` | 已覆盖 | | R6 | SAM 2 点/框/interactive、SAM 2 纯文本提示拦截、SAM 2 最高分候选去重、AI 页生成 mask 自动选中并可换标签、AI 页推送到工作区编辑保留选择、SAM 2 视频传播、SAM 3 semantic、SAM 3 semantic 请求级阈值、SAM 3 worker 单 2D mask 转 polygon、SAM 3 box、SAM 3 video tracker、SAM 3 不支持点交互时的前端反馈、模型选择、GPU/模型状态、参数 options、polygons 转 mask | `api.test.ts`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx`, `VideoWorkspace.test.tsx`, `ModelStatusBadge.test.tsx`, `test_ai.py`, `test_sam2_engine.py`, `test_sam3_engine.py` | 已覆盖 |
| R7 | 保存、查询、更新、删除标注、工作区回显、清空已保存标注、GT mask 导入和 seed point 回写 | `VideoWorkspace.test.tsx`, `CanvasArea.test.tsx`, `api.test.ts`, `test_ai.py` | 已覆盖 | | R7 | 保存、查询、更新、删除标注、工作区回显、清空已保存标注、GT mask 导入和 seed point 回写 | `VideoWorkspace.test.tsx`, `CanvasArea.test.tsx`, `api.test.ts`, `test_ai.py` | 已覆盖 |
| R8 | 模板加载、新建、编辑、删除、JSON 分类导入、mapping_rules 映射、后端 CRUD | `TemplateRegistry.test.tsx`, `api.test.ts`, `test_templates.py` | 已覆盖 | | R8 | 模板加载、新建、编辑、删除、JSON 分类导入、mapping_rules 映射、后端 CRUD | `TemplateRegistry.test.tsx`, `api.test.ts`, `test_templates.py` | 已覆盖 |
| R9 | 模板选择、分类展示、分类选择、已选 mask 换标签、自定义本地分类、占位状态 | `OntologyInspector.test.tsx`, `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 | | R9 | 模板选择、分类展示、分类选择、已选 mask 换标签、自定义本地分类、占位状态 | `OntologyInspector.test.tsx`, `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 |
| R10 | Dashboard 概览、队列、活动日志、WebSocket progress/complete/error/status/cancelled、取消/重试/详情、heartbeat | `Dashboard.test.tsx`, `websocket.test.ts`, `test_dashboard.py`, `test_main.py`, `test_progress_events.py`, `test_tasks.py` | 已覆盖 | | R10 | Dashboard 概览、队列、活动日志、WebSocket progress/complete/error/status/cancelled、取消/重试/详情、连接状态回调、heartbeat | `Dashboard.test.tsx`, `websocket.test.ts`, `test_dashboard.py`, `test_main.py`, `test_progress_events.py`, `test_tasks.py` | 已覆盖 |
| R11 | COCO/PNG ZIP 导出、导出前保存、路径和 JSON/ZIP 结构、zIndex 融合 | `VideoWorkspace.test.tsx`, `api.test.ts`, `test_export.py` | 已覆盖 | | R11 | COCO/PNG ZIP 导出、导出前保存、路径和 JSON/ZIP 结构、zIndex 融合 | `VideoWorkspace.test.tsx`, `api.test.ts`, `test_export.py` | 已覆盖 |
| R12 | API/WS 地址 env 优先和 hostname 推导 | `config.test.ts` | 已覆盖 | | R12 | API/WS 地址 env 优先和 hostname 推导 | `config.test.ts` | 已覆盖 |
| R13 | 文档测试矩阵与功能点追踪 | `doc/09-test-plan.md` | 已覆盖 | | R13 | 文档测试矩阵与功能点追踪 | `doc/09-test-plan.md` | 已覆盖 |
@@ -51,6 +51,8 @@
- R5补充 `CanvasArea.test.tsx` 中圆形和线段手工绘制测试,明确验证 metadata、segmentation、bbox/area 和草稿状态。 - R5补充 `CanvasArea.test.tsx` 中圆形和线段手工绘制测试,明确验证 metadata、segmentation、bbox/area 和草稿状态。
- R6补充 `AISegmentation.test.tsx` 中 SAM 3 semantic 文本推理测试,验证前端传参和返回 mask 绑定当前语义类别。 - R6补充 `AISegmentation.test.tsx` 中 SAM 3 semantic 文本推理测试,验证前端传参和返回 mask 绑定当前语义类别。
- R6补充 SAM 2 纯文本提示拦截、SAM 2 多候选只保留最高分、SAM 2 engine 单候选请求测试,避免多个重叠候选 mask 被同时叠加。
- R6补充 Canvas 工作区 SAM 2 反向点背景过滤测试,覆盖请求 options 和过滤为空时清除旧候选 mask。
- R6补充 SAM 3 空文本、空结果和工作区点交互不支持提示测试,避免前端静默失败。 - R6补充 SAM 3 空文本、空结果和工作区点交互不支持提示测试,避免前端静默失败。
- R6补充 SAM 3 工作区 box prompt 测试和外部 worker box prompt 测试,验证官方 `add_geometric_prompt()` 正框链路。 - R6补充 SAM 3 工作区 box prompt 测试和外部 worker box prompt 测试,验证官方 `add_geometric_prompt()` 正框链路。
- R6补充 `POST /api/ai/propagate` 后端测试,验证 seed mask 传播结果会保存为后续帧标注并保留 class 元数据。 - R6补充 `POST /api/ai/propagate` 后端测试,验证 seed mask 传播结果会保存为后续帧标注并保留 class 元数据。

View File

@@ -63,6 +63,129 @@ describe('AISegmentation', () => {
})); }));
}); });
it('does not run SAM2 text-only prompts as semantic segmentation', async () => {
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), {
target: { value: '胆囊' },
});
fireEvent.click(await screen.findByText('执行高精度语义分割'));
expect(apiMock.predictMask).not.toHaveBeenCalled();
expect(await screen.findByText('SAM2 不支持文本语义提示;请先放置正/反向点,或切换到 SAM3 使用文本语义。')).toBeInTheDocument();
});
it('keeps only the best SAM2 candidate when the backend returns overlapping alternatives', async () => {
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'sam2-best',
pathData: 'M 0 0 L 10 0 L 10 10 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[0, 0, 10, 0, 10, 10]],
bbox: [0, 0, 10, 10],
area: 100,
},
{
id: 'sam2-alt',
pathData: 'M 1 1 L 11 1 L 11 11 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[1, 1, 11, 1, 11, 11]],
bbox: [1, 1, 10, 10],
area: 100,
},
],
});
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'));
fireEvent.click(await screen.findByText('执行高精度语义分割'));
await waitFor(() => expect(useStore.getState().masks).toHaveLength(1));
expect(useStore.getState().masks[0].id).toBe('sam2-best');
expect(useStore.getState().selectedMaskIds).toEqual(['sam2-best']);
expect(await screen.findByText('SAM2 返回 2 个候选,已采用最高分区域。')).toBeInTheDocument();
});
it('lets a SAM2 result be selected and relabeled from the ontology panel', async () => {
useStore.setState({
templates: [
{
id: 'template-1',
name: '腹腔镜模板',
classes: [
{ id: 'class-1', name: '胆囊', color: '#ff0000', zIndex: 30 },
{ id: 'class-2', name: '肝脏', color: '#00ff00', zIndex: 20 },
],
rules: [],
},
],
});
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'sam2-mask',
pathData: 'M 10 10 L 40 10 L 40 40 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[10, 10, 40, 10, 40, 40]],
bbox: [10, 10, 30, 30],
area: 900,
},
],
});
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'));
fireEvent.click(await screen.findByText('执行高精度语义分割'));
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']));
fireEvent.click(screen.getByText('肝脏'));
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
templateId: 'template-1',
classId: 'class-2',
className: '肝脏',
classZIndex: 20,
label: '肝脏',
color: '#00ff00',
saveStatus: 'draft',
}));
});
it('keeps the generated SAM2 mask selected when sending it to the workspace editor', async () => {
const onSendToWorkspace = vi.fn();
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'sam2-mask',
pathData: 'M 10 10 L 40 10 L 40 40 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[10, 10, 40, 10, 40, 40]],
bbox: [10, 10, 30, 30],
area: 900,
},
],
});
render(<AISegmentation onSendToWorkspace={onSendToWorkspace} />);
fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'));
fireEvent.click(await screen.findByText('执行高精度语义分割'));
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']));
fireEvent.click(screen.getByText('推送至工作区编辑'));
expect(useStore.getState().activeTool).toBe('edit_polygon');
expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']);
expect(onSendToWorkspace).toHaveBeenCalled();
});
it('prompts for semantic text before running SAM3 inference', async () => { it('prompts for semantic text before running SAM3 inference', async () => {
apiMock.getAiModelStatus.mockResolvedValue({ apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam3', selected_model: 'sam3',
@@ -106,7 +229,7 @@ describe('AISegmentation', () => {
points: undefined, points: undefined,
text: '胆囊', text: '胆囊',
}))); })));
expect(await screen.findByText('模型没有返回可用区域,请换一个更具体的描述或调整提示。')).toBeInTheDocument(); expect(await screen.findByText('SAM3 已完成语义推理,但没有返回区域。请尝试英文目标描述,或换到包含该目标的帧。当前提示: 胆囊')).toBeInTheDocument();
}); });
it('runs SAM3 semantic text inference and assigns the active class to returned masks', async () => { it('runs SAM3 semantic text inference and assigns the active class to returned masks', async () => {

View File

@@ -17,6 +17,8 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const masks = useStore((state) => state.masks); const masks = useStore((state) => state.masks);
const addMask = useStore((state) => state.addMask); const addMask = useStore((state) => state.addMask);
const clearMasks = useStore((state) => state.clearMasks); const clearMasks = useStore((state) => state.clearMasks);
const selectedMaskIds = useStore((state) => state.selectedMaskIds);
const setSelectedMaskIds = useStore((state) => state.setSelectedMaskIds);
const maskHistory = useStore((state) => state.maskHistory); const maskHistory = useStore((state) => state.maskHistory);
const maskFuture = useStore((state) => state.maskFuture); const maskFuture = useStore((state) => state.maskFuture);
const undoMasks = useStore((state) => state.undoMasks); const undoMasks = useStore((state) => state.undoMasks);
@@ -97,6 +99,10 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
setInferenceMessage('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。'); setInferenceMessage('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。');
return; return;
} }
if (aiModel === 'sam2' && textPrompt && points.length === 0) {
setInferenceMessage('SAM2 不支持文本语义提示;请先放置正/反向点,或切换到 SAM3 使用文本语义。');
return;
}
if (points.length === 0 && !textPrompt) { if (points.length === 0 && !textPrompt) {
setInferenceMessage('请先放置正/反向提示点,或输入语义描述。'); setInferenceMessage('请先放置正/反向提示点,或输入语义描述。');
return; return;
@@ -132,14 +138,22 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
}, },
}); });
if (result.masks.length === 0) { const masksToApply = aiModel === 'sam2' ? result.masks.slice(0, 1) : result.masks;
setInferenceMessage('模型没有返回可用区域,请换一个更具体的描述或调整提示。');
if (masksToApply.length === 0) {
setInferenceMessage(aiModel === 'sam3'
? `SAM3 已完成语义推理,但没有返回区域。请尝试英文目标描述,或换到包含该目标的帧。当前提示: ${textPrompt}`
: '模型没有返回可用区域,请换一个更具体的描述或调整提示。');
} else { } else {
setInferenceMessage(`已生成 ${result.masks.length} 个候选区域。`); setInferenceMessage(aiModel === 'sam2' && result.masks.length > 1
? `SAM2 返回 ${result.masks.length} 个候选,已采用最高分区域。`
: `已生成 ${masksToApply.length} 个候选区域。`);
} }
result.masks.forEach((m) => { const generatedMaskIds: string[] = [];
masksToApply.forEach((m) => {
const label = activeClass?.name || m.label; const label = activeClass?.name || m.label;
const color = activeClass?.color || m.color; const color = activeClass?.color || m.color;
generatedMaskIds.push(m.id);
addMask({ addMask({
id: m.id, id: m.id,
frameId: currentFrame.id, frameId: currentFrame.id,
@@ -157,6 +171,9 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
area: m.area, area: m.area,
}); });
}); });
if (generatedMaskIds.length > 0) {
setSelectedMaskIds(generatedMaskIds);
}
} catch (err) { } catch (err) {
console.error('AI inference failed:', err); console.error('AI inference failed:', err);
const detail = (err as any)?.response?.data?.detail; const detail = (err as any)?.response?.data?.detail;
@@ -164,7 +181,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
} finally { } finally {
setIsInferencing(false); setIsInferencing(false);
} }
}, [activeClass, activeTemplateId, addMask, aiModel, autoDeleteBg, cropMode, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText]); }, [activeClass, activeTemplateId, addMask, aiModel, autoDeleteBg, cropMode, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText, setSelectedMaskIds]);
const handleStageClick = (e: any) => { const handleStageClick = (e: any) => {
if (effectiveTool === 'move') return; if (effectiveTool === 'move') return;
@@ -307,10 +324,13 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
</div> </div>
)} )}
<button <button
onClick={onSendToWorkspace} onClick={() => {
setActiveTool('edit_polygon');
onSendToWorkspace();
}}
className="w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all font-medium tracking-wide text-xs uppercase bg-white/5 hover:bg-white/10 text-gray-300 border border-white/5 hover:border-white/10" className="w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all font-medium tracking-wide text-xs uppercase bg-white/5 hover:bg-white/10 text-gray-300 border border-white/5 hover:border-white/10"
> >
<SendToBack size={16} /> 退 <SendToBack size={16} />
</button> </button>
</div> </div>
</aside> </aside>
@@ -376,12 +396,20 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
{/* AI Returned Masks */} {/* AI Returned Masks */}
{frameMasks.map((mask) => ( {frameMasks.map((mask) => (
<Group key={mask.id} opacity={0.45}> <Group key={mask.id} opacity={selectedMaskIds.includes(mask.id) ? 0.72 : 0.45}>
<Path <Path
data={mask.pathData} data={mask.pathData}
fill={mask.color} fill={mask.color}
stroke={mask.color} stroke={mask.color}
strokeWidth={1 / scale} strokeWidth={(selectedMaskIds.includes(mask.id) ? 2.5 : 1) / scale}
onClick={(event: any) => {
event.cancelBubble = true;
setSelectedMaskIds([mask.id]);
}}
onTap={(event: any) => {
event.cancelBubble = true;
setSelectedMaskIds([mask.id]);
}}
/> />
</Group> </Group>
))} ))}

View File

@@ -206,16 +206,58 @@ describe('CanvasArea', () => {
{ x: 300, y: 150, type: 'neg' }, { x: 300, y: 150, type: 'neg' },
], ],
box: { x1: 120, y1: 80, x2: 260, y2: 200 }, box: { x1: 120, y1: 80, x2: 260, y2: 200 },
options: { auto_filter_background: true, min_score: 0.05 },
})); }));
expect(useStore.getState().masks).toHaveLength(1); expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'mask-box', id: 'mask-box',
segmentation: [[30, 30, 70, 30, 70, 70]], segmentation: [[30, 30, 70, 30, 70, 70]],
points: [[150, 100]], points: [[150, 100]],
metadata: expect.objectContaining({ promptPointCount: 2 }), metadata: expect.objectContaining({ promptPointCount: 2, promptNegativePointCount: 1 }),
})); }));
}); });
it('removes the SAM2 candidate when a negative point filters it out', async () => {
apiMock.predictMask
.mockResolvedValueOnce({
masks: [
{
id: 'mask-box',
pathData: 'M 10 10 L 90 10 L 90 90 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[10, 10, 90, 10, 90, 90]],
bbox: [10, 10, 80, 80],
area: 6400,
},
],
})
.mockResolvedValueOnce({ masks: [] });
const { rerender } = render(<CanvasArea activeTool="box_select" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
await waitFor(() => expect(useStore.getState().masks).toHaveLength(1));
rerender(<CanvasArea activeTool="point_neg" frame={frame} />);
fireEvent.click(stage, { clientX: 180, clientY: 120 });
await waitFor(() => expect(apiMock.predictMask).toHaveBeenNthCalledWith(2, {
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam2',
points: [{ x: 180, y: 120, type: 'neg' }],
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
options: { auto_filter_background: true, min_score: 0.05 },
}));
await waitFor(() => expect(useStore.getState().masks).toHaveLength(0));
expect(await screen.findByText(/反向点已排除当前候选区域/)).toBeInTheDocument();
});
it('renders only masks that belong to the current frame', () => { it('renders only masks that belong to the current frame', () => {
useStore.setState({ useStore.setState({
masks: [ masks: [
@@ -250,6 +292,28 @@ describe('CanvasArea', () => {
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['m1'])); await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['m1']));
}); });
it('keeps a mask selected when opening the workspace polygon editor from AI results', () => {
useStore.setState({
selectedMaskIds: ['m1'],
masks: [
{
id: 'm1',
frameId: 'frame-1',
pathData: 'M 0 0 L 10 0 L 10 10 Z',
label: 'A',
color: '#fff',
segmentation: [[0, 0, 10, 0, 10, 10]],
},
],
});
render(<CanvasArea activeTool="edit_polygon" frame={frame} />);
expect(useStore.getState().selectedMaskIds).toEqual(['m1']);
expect(screen.getAllByTestId('konva-circle')
.filter((element) => element.getAttribute('data-fill') === '#ffffff')).toHaveLength(3);
});
it('renders imported GT seed points for editable point regions', () => { it('renders imported GT seed points for editable point regions', () => {
useStore.setState({ useStore.setState({
masks: [ masks: [
@@ -415,6 +479,34 @@ describe('CanvasArea', () => {
})); }));
}); });
it('selects a polygon with the edit tool and inserts a vertex by double-clicking an edge', () => {
useStore.setState({
masks: [
{
id: 'draft-1',
frameId: 'frame-1',
pathData: 'M 10 10 L 90 10 L 90 40 Z',
label: 'Draft',
color: '#06b6d4',
saveStatus: 'draft',
segmentation: [[10, 10, 90, 10, 90, 40]],
bbox: [10, 10, 80, 30],
},
],
});
render(<CanvasArea activeTool="edit_polygon" frame={frame} />);
const path = screen.getByTestId('konva-path');
fireEvent.click(path);
fireEvent.doubleClick(path, { clientX: 50, clientY: 10 });
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
segmentation: [[10, 10, 50, 10, 90, 10, 90, 40]],
pathData: 'M 10 10 L 50 10 L 90 10 L 90 40 Z',
saveStatus: 'draft',
}));
});
it('edits the selected polygon in a multi-polygon mask', () => { it('edits the selected polygon in a multi-polygon mask', () => {
useStore.setState({ useStore.setState({
masks: [ masks: [

View File

@@ -19,6 +19,7 @@ type PromptBox = { x1: number; y1: number; x2: number; y2: number };
const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle', 'create_line']); const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle', 'create_line']);
const POLYGON_TOOL = 'create_polygon'; const POLYGON_TOOL = 'create_polygon';
const EDIT_POLYGON_TOOL = 'edit_polygon';
const POINT_TOOL = 'create_point'; const POINT_TOOL = 'create_point';
const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']); const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']);
const POLYGON_CLOSE_RADIUS = 8; const POLYGON_CLOSE_RADIUS = 8;
@@ -95,6 +96,32 @@ function pointDistance(a: CanvasPoint, b: CanvasPoint): number {
return Math.hypot(a.x - b.x, a.y - b.y); return Math.hypot(a.x - b.x, a.y - b.y);
} }
function distanceToSegmentSquared(point: CanvasPoint, start: CanvasPoint, end: CanvasPoint): number {
const dx = end.x - start.x;
const dy = end.y - start.y;
const lengthSquared = dx * dx + dy * dy;
if (lengthSquared === 0) {
return (point.x - start.x) ** 2 + (point.y - start.y) ** 2;
}
const t = clamp(((point.x - start.x) * dx + (point.y - start.y) * dy) / lengthSquared, 0, 1);
const projected = { x: start.x + t * dx, y: start.y + t * dy };
return (point.x - projected.x) ** 2 + (point.y - projected.y) ** 2;
}
function nearestPolygonEdgeIndex(points: CanvasPoint[], point: CanvasPoint): number {
return points.reduce((bestIndex, start, index) => {
const end = points[(index + 1) % points.length];
if (!end) return bestIndex;
const bestStart = points[bestIndex];
const bestEnd = points[(bestIndex + 1) % points.length];
const currentDistance = distanceToSegmentSquared(point, start, end);
const bestDistance = bestStart && bestEnd
? distanceToSegmentSquared(point, bestStart, bestEnd)
: Number.POSITIVE_INFINITY;
return currentDistance < bestDistance ? index : bestIndex;
}, 0);
}
function segmentationArea(segmentation?: number[][]): number { function segmentationArea(segmentation?: number[][]): number {
return (segmentation || []).reduce((sum, polygon) => sum + polygonArea(flatPolygonToPoints(polygon)), 0); return (segmentation || []).reduce((sum, polygon) => sum + polygonArea(flatPolygonToPoints(polygon)), 0);
} }
@@ -210,10 +237,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const [manualStart, setManualStart] = useState<CanvasPoint | null>(null); const [manualStart, setManualStart] = useState<CanvasPoint | null>(null);
const [manualCurrent, setManualCurrent] = useState<CanvasPoint | null>(null); const [manualCurrent, setManualCurrent] = useState<CanvasPoint | null>(null);
const [polygonPoints, setPolygonPoints] = useState<CanvasPoint[]>([]); const [polygonPoints, setPolygonPoints] = useState<CanvasPoint[]>([]);
const [selectedMaskId, setSelectedMaskId] = useState<string | null>(null); const [selectedMaskId, setSelectedMaskId] = useState<string | null>(() => useStore.getState().selectedMaskIds[0] || null);
const [selectedMaskIds, setSelectedMaskIds] = useState<string[]>([]); const [selectedMaskIds, setSelectedMaskIds] = useState<string[]>(() => useStore.getState().selectedMaskIds);
const [selectedPolygonIndex, setSelectedPolygonIndex] = useState(0); const [selectedPolygonIndex, setSelectedPolygonIndex] = useState(0);
const [selectedVertexIndex, setSelectedVertexIndex] = useState<number | null>(null); const [selectedVertexIndex, setSelectedVertexIndex] = useState<number | null>(null);
const previousFrameIdRef = useRef<string | undefined>(frame?.id);
const [isInferencing, setIsInferencing] = useState(false); const [isInferencing, setIsInferencing] = useState(false);
const [inferenceMessage, setInferenceMessage] = useState(''); const [inferenceMessage, setInferenceMessage] = useState('');
@@ -253,6 +281,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length; const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length; const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
const isBooleanTool = BOOLEAN_TOOLS.has(effectiveTool); const isBooleanTool = BOOLEAN_TOOLS.has(effectiveTool);
const isPolygonEditTool = effectiveTool === 'move' || effectiveTool === EDIT_POLYGON_TOOL;
useEffect(() => { useEffect(() => {
const handleResize = () => { const handleResize = () => {
@@ -273,11 +302,22 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
setManualStart(null); setManualStart(null);
setManualCurrent(null); setManualCurrent(null);
setPolygonPoints([]); setPolygonPoints([]);
setSelectedVertexIndex(null);
if (!isPolygonEditTool && !isBooleanTool) {
setSelectedMaskId(null);
setSelectedMaskIds([]);
setSelectedPolygonIndex(0);
}
}, [effectiveTool, isBooleanTool, isPolygonEditTool]);
useEffect(() => {
if (previousFrameIdRef.current === frame?.id) return;
previousFrameIdRef.current = frame?.id;
setSelectedMaskId(null); setSelectedMaskId(null);
setSelectedMaskIds([]); setSelectedMaskIds([]);
setSelectedPolygonIndex(0); setSelectedPolygonIndex(0);
setSelectedVertexIndex(null); setSelectedVertexIndex(null);
}, [effectiveTool, frame?.id]); }, [frame?.id]);
useEffect(() => { useEffect(() => {
setPoints([]); setPoints([]);
@@ -420,6 +460,10 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
setIsInferencing(true); setIsInferencing(true);
setInferenceMessage(''); setInferenceMessage('');
try { try {
const hasNegativePrompt = Boolean(promptPoints?.some((point) => point.type === 'neg'));
const existingCandidate = !options.resetCandidate && samCandidateMaskId
? masks.find((mask) => mask.id === samCandidateMaskId)
: null;
const result = await predictMask({ const result = await predictMask({
imageId: frame.id, imageId: frame.id,
imageWidth, imageWidth,
@@ -429,13 +473,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
? promptPoints.map((p) => ({ x: p.x, y: p.y, type: p.type })) ? promptPoints.map((p) => ({ x: p.x, y: p.y, type: p.type }))
: undefined, : undefined,
box: promptBox, box: promptBox,
...(hasNegativePrompt ? { options: { auto_filter_background: true, min_score: 0.05 } } : {}),
}); });
const [m] = result.masks; const [m] = result.masks;
if (m) { if (m) {
const existingCandidate = !options.resetCandidate && samCandidateMaskId
? masks.find((mask) => mask.id === samCandidateMaskId)
: null;
const label = activeClass?.name || existingCandidate?.label || m.label; const label = activeClass?.name || existingCandidate?.label || m.label;
const color = activeClass?.color || existingCandidate?.color || m.color; const color = activeClass?.color || existingCandidate?.color || m.color;
const metadata = { const metadata = {
@@ -443,6 +485,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
source: aiModel === 'sam3' ? 'sam3_box' : 'sam2_interactive', source: aiModel === 'sam3' ? 'sam3_box' : 'sam2_interactive',
promptBox: promptBox || null, promptBox: promptBox || null,
promptPointCount: promptPoints?.length || 0, promptPointCount: promptPoints?.length || 0,
promptNegativePointCount: promptPoints?.filter((point) => point.type === 'neg').length || 0,
}; };
const nextMask = { const nextMask = {
frameId: frame.id, frameId: frame.id,
@@ -476,7 +519,15 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
}); });
} }
} else { } else {
setInferenceMessage('模型没有返回可用区域,请调整点/框提示后重试。'); if (existingCandidate && hasNegativePrompt) {
setMasks(masks.filter((mask) => mask.id !== existingCandidate.id));
setSamCandidateMaskId(null);
setSelectedMaskId(null);
setSelectedMaskIds([]);
setInferenceMessage('反向点已排除当前候选区域,请重新框选或添加新的正向点。');
} else {
setInferenceMessage('模型没有返回可用区域,请调整点/框提示后重试。');
}
} }
} catch (err) { } catch (err) {
console.error('Inference failed:', err); console.error('Inference failed:', err);
@@ -485,7 +536,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
} finally { } finally {
setIsInferencing(false); setIsInferencing(false);
} }
}, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, masks, samCandidateMaskId, updateMask]); }, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, masks, samCandidateMaskId, setMasks, updateMask]);
const handleApplyActiveClass = () => { const handleApplyActiveClass = () => {
if (!frame?.id || !activeClass) return; if (!frame?.id || !activeClass) return;
@@ -598,7 +649,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
}; };
const handleStageClick = (e: any) => { const handleStageClick = (e: any) => {
if (effectiveTool === 'move') return; if (isPolygonEditTool) return;
if (effectiveTool === 'box_select') return; // handled by mouseup if (effectiveTool === 'box_select') return; // handled by mouseup
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) return; if (DRAG_MANUAL_TOOLS.has(effectiveTool)) return;
@@ -716,7 +767,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
window.addEventListener('keydown', handleKeyDown); window.addEventListener('keydown', handleKeyDown);
return () => window.removeEventListener('keydown', handleKeyDown); return () => window.removeEventListener('keydown', handleKeyDown);
}, [deleteMasksById, effectiveTool, finishPolygon, polygonPoints, redoMasks, selectedMask, selectedMaskIds, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]); }, [deleteMasksById, effectiveTool, finishPolygon, isPolygonEditTool, polygonPoints, redoMasks, selectedMask, selectedMaskIds, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
const boxRect = React.useMemo(() => { const boxRect = React.useMemo(() => {
if (!boxStart || !boxCurrent) return null; if (!boxStart || !boxCurrent) return null;
@@ -753,7 +804,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
}; };
const handleMaskSelect = (mask: Mask, event: any, polygonIndex = 0) => { const handleMaskSelect = (mask: Mask, event: any, polygonIndex = 0) => {
if (effectiveTool !== 'move' && !isBooleanTool) return; if (!isPolygonEditTool && !isBooleanTool) return;
event.cancelBubble = true; event.cancelBubble = true;
if (isBooleanTool) { if (isBooleanTool) {
setSelectedMaskIds((current) => ( setSelectedMaskIds((current) => (
@@ -807,6 +858,25 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
updatePolygonMask(mask, nextPoints, selectedPolygonIndex); updatePolygonMask(mask, nextPoints, selectedPolygonIndex);
}; };
const handlePathDoubleClick = (mask: Mask, event: any, polygonIndex = 0) => {
if (effectiveTool !== EDIT_POLYGON_TOOL) return;
event.cancelBubble = true;
const point = stagePoint(event);
const currentPoints = segmentationToPoints(mask.segmentation, polygonIndex);
if (!point || currentPoints.length < 3) return;
const edgeIndex = nearestPolygonEdgeIndex(currentPoints, point);
const nextPoints = [
...currentPoints.slice(0, edgeIndex + 1),
point,
...currentPoints.slice(edgeIndex + 1),
];
setSelectedMaskId(mask.id);
setSelectedMaskIds([mask.id]);
setSelectedPolygonIndex(polygonIndex);
setSelectedVertexIndex(edgeIndex + 1);
updatePolygonMask(mask, nextPoints, polygonIndex);
};
const handleBooleanOperation = async () => { const handleBooleanOperation = async () => {
if (!frame || booleanSelectedMasks.length < 2) return; if (!frame || booleanSelectedMasks.length < 2) return;
const primary = booleanSelectedMasks[0]; const primary = booleanSelectedMasks[0];
@@ -918,6 +988,8 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale} strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale}
onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)} onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)} onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
onDblClick={(event: any) => handlePathDoubleClick(mask, event, polygonIndex)}
onDblTap={(event: any) => handlePathDoubleClick(mask, event, polygonIndex)}
/> />
))} ))}
</Group> </Group>
@@ -987,7 +1059,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
)))} )))}
{/* Polygon edge insertion handles */} {/* Polygon edge insertion handles */}
{!isBooleanTool && selectedMask && selectedMaskPoints.map((point, index) => { {isPolygonEditTool && selectedMask && selectedMaskPoints.map((point, index) => {
const next = selectedMaskPoints[(index + 1) % selectedMaskPoints.length]; const next = selectedMaskPoints[(index + 1) % selectedMaskPoints.length];
if (!next) return null; if (!next) return null;
return ( return (
@@ -1006,7 +1078,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
})} })}
{/* Polygon vertex editor */} {/* Polygon vertex editor */}
{!isBooleanTool && selectedMask && selectedMaskPoints.map((point, index) => ( {isPolygonEditTool && selectedMask && selectedMaskPoints.map((point, index) => (
<Circle <Circle
key={`${selectedMask.id}-vertex-${selectedPolygonIndex}-${index}`} key={`${selectedMask.id}-vertex-${selectedPolygonIndex}-${index}`}
x={point.x} x={point.x}

View File

@@ -12,6 +12,7 @@ const apiMock = vi.hoisted(() => ({
const wsMock = vi.hoisted(() => { const wsMock = vi.hoisted(() => {
const state = { const state = {
callback: undefined as undefined | ((data: any) => void), callback: undefined as undefined | ((data: any) => void),
statusCallback: undefined as undefined | ((status: any) => void),
connected: false, connected: false,
}; };
return { return {
@@ -24,6 +25,11 @@ const wsMock = vi.hoisted(() => {
state.callback = cb; state.callback = cb;
return vi.fn(); return vi.fn();
}), }),
onStatus: vi.fn((cb: (status: any) => void) => {
state.statusCallback = cb;
cb(state.connected ? 'connected' : 'disconnected');
return vi.fn();
}),
}, },
}; };
}); });
@@ -45,6 +51,7 @@ describe('Dashboard', () => {
vi.clearAllMocks(); vi.clearAllMocks();
wsMock.state.connected = false; wsMock.state.connected = false;
wsMock.state.callback = undefined; wsMock.state.callback = undefined;
wsMock.state.statusCallback = undefined;
apiMock.getDashboardOverview.mockResolvedValue({ apiMock.getDashboardOverview.mockResolvedValue({
summary: { summary: {
project_count: 2, project_count: 2,
@@ -109,6 +116,20 @@ describe('Dashboard', () => {
expect(screen.getByText('44%')).toBeInTheDocument(); expect(screen.getByText('44%')).toBeInTheDocument();
}); });
it('updates the websocket badge from connection status callbacks', async () => {
render(<Dashboard />);
await waitFor(() => expect(wsMock.progressWS.onStatus).toHaveBeenCalled());
expect(screen.getByText('WebSocket 断开')).toBeInTheDocument();
act(() => {
wsMock.state.connected = true;
wsMock.state.statusCallback?.('connected');
});
expect(screen.getByText('WebSocket 已连接')).toBeInTheDocument();
});
it('adds activity logs for complete and status messages', async () => { it('adds activity logs for complete and status messages', async () => {
render(<Dashboard />); render(<Dashboard />);

View File

@@ -1,6 +1,6 @@
import React, { useState, useEffect } from 'react'; import React, { useState, useEffect } from 'react';
import { Activity, AlertTriangle, Clock, Folders, CheckCircle2, Info, Loader2, RotateCcw, XCircle } from 'lucide-react'; import { Activity, AlertTriangle, Clock, Folders, CheckCircle2, Info, Loader2, RotateCcw, XCircle } from 'lucide-react';
import { progressWS, type ProgressMessage } from '../lib/websocket'; import { progressWS, type ConnectionStatus, type ProgressMessage } from '../lib/websocket';
import { cn } from '../lib/utils'; import { cn } from '../lib/utils';
import { import {
cancelTask, cancelTask,
@@ -178,6 +178,9 @@ export function Dashboard() {
]); ]);
} }
}); });
const unsubscribeStatus = progressWS.onStatus((status: ConnectionStatus) => {
if (mounted) setIsConnected(status === 'connected');
});
const checkConnection = setInterval(() => { const checkConnection = setInterval(() => {
if (mounted) setIsConnected(progressWS.isConnected()); if (mounted) setIsConnected(progressWS.isConnected());
@@ -186,6 +189,7 @@ export function Dashboard() {
return () => { return () => {
mounted = false; mounted = false;
unsubscribe(); unsubscribe();
unsubscribeStatus();
clearInterval(checkConnection); clearInterval(checkConnection);
progressWS.disconnect(); progressWS.disconnect();
}; };

View File

@@ -56,10 +56,9 @@ describe('ProjectLibrary', () => {
expect(useStore.getState().projects[0]).toEqual(expect.objectContaining({ id: 'p2' })); expect(useStore.getState().projects[0]).toEqual(expect.objectContaining({ id: 'p2' }));
}); });
it('imports video by creating a project, uploading media, parsing frames and refreshing projects', async () => { it('imports video by creating a project and uploading media without parsing frames', async () => {
apiMock.createProject.mockResolvedValueOnce({ id: 'p3', name: 'clip.mp4', status: 'pending' }); apiMock.createProject.mockResolvedValueOnce({ id: 'p3', name: 'clip.mp4', status: 'pending' });
apiMock.uploadMedia.mockResolvedValueOnce({ url: 'http://file', id: 'object' }); apiMock.uploadMedia.mockResolvedValueOnce({ url: 'http://file', id: 'object' });
apiMock.parseMedia.mockResolvedValueOnce({ frames_extracted: 1 });
apiMock.getProjects.mockResolvedValue([]); apiMock.getProjects.mockResolvedValue([]);
const { container } = render(<ProjectLibrary onProjectSelect={vi.fn()} />); const { container } = render(<ProjectLibrary onProjectSelect={vi.fn()} />);
@@ -70,10 +69,24 @@ describe('ProjectLibrary', () => {
await waitFor(() => expect(apiMock.createProject).toHaveBeenCalledWith(expect.objectContaining({ await waitFor(() => expect(apiMock.createProject).toHaveBeenCalledWith(expect.objectContaining({
name: 'clip.mp4', name: 'clip.mp4',
parse_fps: 30,
}))); })));
expect(apiMock.uploadMedia).toHaveBeenCalledWith(file, 'p3'); expect(apiMock.uploadMedia).toHaveBeenCalledWith(file, 'p3');
expect(apiMock.parseMedia).toHaveBeenCalledWith('p3'); expect(apiMock.parseMedia).not.toHaveBeenCalled();
});
it('generates frames from an imported video with the selected FPS', async () => {
apiMock.getProjects
.mockResolvedValueOnce([{ id: 'p4', name: 'clip.mp4', status: 'pending', frames: 0, video_path: 'uploads/clip.mp4', parse_fps: 30 }])
.mockResolvedValueOnce([{ id: 'p4', name: 'clip.mp4', status: 'parsing', frames: 0, video_path: 'uploads/clip.mp4', parse_fps: 12 }]);
apiMock.parseMedia.mockResolvedValueOnce({ id: 22, status: 'queued', progress: 0 });
const { container } = render(<ProjectLibrary onProjectSelect={vi.fn()} />);
fireEvent.click(await screen.findByRole('button', { name: '生成帧' }));
fireEvent.change(container.querySelector('input[type="range"]') as HTMLInputElement, { target: { value: '12' } });
fireEvent.click(screen.getByRole('button', { name: '开始生成帧' }));
await waitFor(() => expect(apiMock.parseMedia).toHaveBeenCalledWith('p4', { parseFps: 12 }));
}); });
it('imports only valid DICOM files and parses the returned project', async () => { it('imports only valid DICOM files and parses the returned project', async () => {

View File

@@ -1,5 +1,5 @@
import React, { useState, useEffect, useRef } from 'react'; import React, { useState, useEffect, useRef } from 'react';
import { UploadCloud, Film, Settings2, MoreHorizontal, Plus, Loader2, Activity } from 'lucide-react'; import { UploadCloud, Film, Settings2, MoreHorizontal, Plus, Loader2, Activity, Images } from 'lucide-react';
import { cn } from '../lib/utils'; import { cn } from '../lib/utils';
import { useStore } from '../store/useStore'; import { useStore } from '../store/useStore';
import { getProjects, createProject, uploadMedia, parseMedia, uploadDicomBatch } from '../lib/api'; import { getProjects, createProject, uploadMedia, parseMedia, uploadDicomBatch } from '../lib/api';
@@ -22,7 +22,10 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
const [showImportMenu, setShowImportMenu] = useState(false); const [showImportMenu, setShowImportMenu] = useState(false);
const [showVideoConfig, setShowVideoConfig] = useState(false); const [showVideoConfig, setShowVideoConfig] = useState(false);
const [pendingFile, setPendingFile] = useState<File | null>(null); const [pendingFile, setPendingFile] = useState<File | null>(null);
const [parseFps, setParseFps] = useState(30); const [frameProject, setFrameProject] = useState<Project | null>(null);
const [showFrameConfig, setShowFrameConfig] = useState(false);
const [frameParseFps, setFrameParseFps] = useState(30);
const [isGeneratingFrames, setIsGeneratingFrames] = useState(false);
const videoInputRef = useRef<HTMLInputElement>(null); const videoInputRef = useRef<HTMLInputElement>(null);
const dicomInputRef = useRef<HTMLInputElement>(null); const dicomInputRef = useRef<HTMLInputElement>(null);
@@ -57,7 +60,6 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
const handleVideoSelect = (file: File) => { const handleVideoSelect = (file: File) => {
setPendingFile(file); setPendingFile(file);
setParseFps(30);
setShowVideoConfig(true); setShowVideoConfig(true);
}; };
@@ -69,11 +71,9 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
const newProject = await createProject({ const newProject = await createProject({
name: pendingFile.name, name: pendingFile.name,
description: `导入于 ${new Date().toLocaleString()}`, description: `导入于 ${new Date().toLocaleString()}`,
parse_fps: parseFps,
}); });
const result = await uploadMedia(pendingFile, String(newProject.id)); const result = await uploadMedia(pendingFile, String(newProject.id));
await parseMedia(String(newProject.id)); alert(`视频导入成功: ${pendingFile.name}\n已保存至: ${result.url}\n需要生成帧时请在项目卡片点击“生成帧”。`);
alert(`上传成功: ${pendingFile.name}\n已保存至: ${result.url}`);
const data = await getProjects(); const data = await getProjects();
setProjects(data); setProjects(data);
} catch (err) { } catch (err) {
@@ -86,6 +86,31 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
} }
}; };
const openFrameConfig = (project: Project, event: React.MouseEvent) => {
event.stopPropagation();
setFrameProject(project);
setFrameParseFps(Math.round(project.parse_fps || 30));
setShowFrameConfig(true);
};
const handleGenerateFrames = async () => {
if (!frameProject?.id) return;
setIsGeneratingFrames(true);
try {
const task = await parseMedia(frameProject.id, { parseFps: frameParseFps });
alert(`生成帧任务已入队 #${task.id}\n帧率: ${frameParseFps} FPS\n可在 Dashboard 查看进度。`);
const data = await getProjects();
setProjects(data);
setShowFrameConfig(false);
setFrameProject(null);
} catch (err) {
console.error('Frame generation failed:', err);
alert('生成帧失败,请检查后端服务或项目源文件');
} finally {
setIsGeneratingFrames(false);
}
};
const handleDicomUpload = async (files: FileList | null) => { const handleDicomUpload = async (files: FileList | null) => {
if (!files || files.length === 0) return; if (!files || files.length === 0) return;
const dcmFiles = Array.from(files).filter((f) => f.name.toLowerCase().endsWith('.dcm')); const dcmFiles = Array.from(files).filter((f) => f.name.toLowerCase().endsWith('.dcm'));
@@ -209,7 +234,7 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
)} )}
<div className="absolute top-2 right-2 flex gap-2"> <div className="absolute top-2 right-2 flex gap-2">
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] font-mono px-2 py-1 rounded border border-white/10 uppercase tracking-widest"> <span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] font-mono px-2 py-1 rounded border border-white/10 uppercase tracking-widest">
{proj.source_type === 'dicom' ? 'DICOM' : (proj.fps || '30FPS')} {proj.source_type === 'dicom' ? 'DICOM' : (proj.video_path && (proj.frames ?? 0) === 0 ? '待生成帧' : (proj.fps || '30FPS'))}
</span> </span>
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] px-2 py-1 rounded border border-white/10 flex items-center gap-1 uppercase tracking-widest"> <span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] px-2 py-1 rounded border border-white/10 flex items-center gap-1 uppercase tracking-widest">
{proj.status === 'ready' ? ( {proj.status === 'ready' ? (
@@ -235,6 +260,15 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
<span className="flex items-center gap-1.5 text-cyan-400/80"><Activity size={12} /> {proj.original_fps.toFixed(1)}fps</span> <span className="flex items-center gap-1.5 text-cyan-400/80"><Activity size={12} /> {proj.original_fps.toFixed(1)}fps</span>
)} )}
</div> </div>
{proj.video_path && (proj.frames ?? 0) === 0 && proj.status !== 'parsing' && (
<button
onClick={(event) => openFrameConfig(proj, event)}
className="mt-3 inline-flex items-center justify-center gap-2 rounded-md border border-cyan-500/30 bg-cyan-500/10 px-3 py-2 text-xs font-medium text-cyan-200 hover:bg-cyan-500/20 transition-colors"
>
<Images size={14} />
</button>
)}
</div> </div>
</div> </div>
))} ))}
@@ -245,24 +279,10 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
{showVideoConfig && pendingFile && ( {showVideoConfig && pendingFile && (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm"> <div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm">
<div className="bg-[#111] border border-white/10 rounded-2xl p-6 w-full max-w-md shadow-2xl"> <div className="bg-[#111] border border-white/10 rounded-2xl p-6 w-full max-w-md shadow-2xl">
<h2 className="text-lg font-semibold text-white mb-4"></h2> <h2 className="text-lg font-semibold text-white mb-4"></h2>
<div className="space-y-4"> <div className="space-y-4">
<div className="text-sm text-gray-400">: <span className="text-gray-200">{pendingFile.name}</span></div> <div className="text-sm text-gray-400">: <span className="text-gray-200">{pendingFile.name}</span></div>
<div> <p className="text-xs leading-5 text-gray-500"> FPS</p>
<label className="block text-xs font-medium text-gray-400 uppercase tracking-widest mb-2"> (FPS)</label>
<div className="flex items-center gap-3">
<input
type="range"
min="1"
max="60"
value={parseFps}
onChange={(e) => setParseFps(parseInt(e.target.value))}
className="flex-1 accent-cyan-500"
/>
<span className="text-sm font-mono text-cyan-400 w-12 text-right">{parseFps}</span>
</div>
<p className="text-[10px] text-gray-600 mt-1"></p>
</div>
</div> </div>
<div className="flex justify-end gap-3 mt-6"> <div className="flex justify-end gap-3 mt-6">
<button <button
@@ -282,6 +302,49 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
</div> </div>
)} )}
{/* Frame generation FPS config modal */}
{showFrameConfig && frameProject && (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm">
<div className="bg-[#111] border border-white/10 rounded-2xl p-6 w-full max-w-md shadow-2xl">
<h2 className="text-lg font-semibold text-white mb-4"></h2>
<div className="space-y-4">
<div className="text-sm text-gray-400">: <span className="text-gray-200">{frameProject.name}</span></div>
<div>
<label className="block text-xs font-medium text-gray-400 uppercase tracking-widest mb-2"> (FPS)</label>
<div className="flex items-center gap-3">
<input
type="range"
min="1"
max="60"
value={frameParseFps}
onChange={(e) => setFrameParseFps(parseInt(e.target.value))}
className="flex-1 accent-cyan-500"
/>
<span className="text-sm font-mono text-cyan-400 w-12 text-right">{frameParseFps}</span>
</div>
<p className="text-[10px] text-gray-600 mt-1"></p>
</div>
</div>
<div className="flex justify-end gap-3 mt-6">
<button
onClick={() => { setShowFrameConfig(false); setFrameProject(null); }}
disabled={isGeneratingFrames}
className="px-4 py-2 rounded-lg text-sm text-gray-400 hover:text-white transition-colors disabled:opacity-50"
>
</button>
<button
onClick={handleGenerateFrames}
disabled={isGeneratingFrames}
className="px-4 py-2 rounded-lg text-sm font-medium bg-cyan-500 hover:bg-cyan-400 text-black transition-all disabled:opacity-60"
>
{isGeneratingFrames ? '入队中...' : '开始生成帧'}
</button>
</div>
</div>
</div>
)}
{/* New project modal */} {/* New project modal */}
{showModal && ( {showModal && (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm"> <div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm">

View File

@@ -20,12 +20,14 @@ describe('ToolsPalette', () => {
); );
fireEvent.click(screen.getByTitle('创建多边形 (P)')); fireEvent.click(screen.getByTitle('创建多边形 (P)'));
fireEvent.click(screen.getByTitle('调整多边形 (E)'));
fireEvent.click(screen.getByTitle('正向选点 (SAM)')); fireEvent.click(screen.getByTitle('正向选点 (SAM)'));
fireEvent.click(screen.getByTitle('撤销操作 (Ctrl+Z)')); fireEvent.click(screen.getByTitle('撤销操作 (Ctrl+Z)'));
fireEvent.click(screen.getByTitle('重做操作 (Ctrl+Shift+Z)')); fireEvent.click(screen.getByTitle('重做操作 (Ctrl+Shift+Z)'));
expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon'); expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon');
expect(setActiveTool).toHaveBeenNthCalledWith(2, 'point_pos'); expect(setActiveTool).toHaveBeenNthCalledWith(2, 'edit_polygon');
expect(setActiveTool).toHaveBeenNthCalledWith(3, 'point_pos');
expect(onUndo).toHaveBeenCalled(); expect(onUndo).toHaveBeenCalled();
expect(onRedo).toHaveBeenCalled(); expect(onRedo).toHaveBeenCalled();
}); });

View File

@@ -1,5 +1,5 @@
import React from 'react'; import React from 'react';
import { MousePointer2, Hexagon, Square, Circle, Minus, Combine, Scissors, Wand2, Undo, Redo, Crosshair, PlusCircle, MinusCircle, SquareDashed } from 'lucide-react'; import { MousePointer2, PencilLine, Hexagon, Square, Circle, Minus, Combine, Scissors, Wand2, Undo, Redo, Crosshair, PlusCircle, MinusCircle, SquareDashed } from 'lucide-react';
import { cn } from '../lib/utils'; import { cn } from '../lib/utils';
interface ToolsPaletteProps { interface ToolsPaletteProps {
@@ -23,6 +23,7 @@ export function ToolsPalette({
}: ToolsPaletteProps) { }: ToolsPaletteProps) {
const tools = [ const tools = [
{ id: 'move', icon: MousePointer2, label: '拖拽 / 选择 (V)' }, { id: 'move', icon: MousePointer2, label: '拖拽 / 选择 (V)' },
{ id: 'edit_polygon', icon: PencilLine, label: '调整多边形 (E)' },
{ id: 'create_polygon', icon: Hexagon, label: '创建多边形 (P)' }, { id: 'create_polygon', icon: Hexagon, label: '创建多边形 (P)' },
{ id: 'create_rectangle', icon: Square, label: '创建矩形 (R)' }, { id: 'create_rectangle', icon: Square, label: '创建矩形 (R)' },
{ id: 'create_circle', icon: Circle, label: '创建圆 (O)' }, { id: 'create_circle', icon: Circle, label: '创建圆 (O)' },

View File

@@ -82,23 +82,16 @@ describe('VideoWorkspace', () => {
expect(apiMock.getProjectAnnotations).toHaveBeenCalledWith('1'); expect(apiMock.getProjectAnnotations).toHaveBeenCalledWith('1');
}); });
it('triggers parsing when a media project has no frames yet', async () => { it('does not auto-generate frames when a media project has no frames yet', async () => {
apiMock.getProjectFrames apiMock.getProjectFrames.mockResolvedValueOnce([]);
.mockResolvedValueOnce([])
.mockResolvedValueOnce([
{ id: 11, project_id: 1, frame_index: 0, image_url: '/parsed.jpg', width: 320, height: 240 },
]);
apiMock.parseMedia.mockResolvedValueOnce({ id: 7, status: 'queued', progress: 0 });
apiMock.getTask.mockResolvedValueOnce({ id: 7, status: 'success', progress: 100, message: '解析完成' });
render(<VideoWorkspace />); render(<VideoWorkspace />);
await waitFor(() => expect(apiMock.parseMedia).toHaveBeenCalledWith('1')); await waitFor(() => expect(apiMock.getProjectFrames).toHaveBeenCalledWith('1'));
expect(apiMock.getTask).toHaveBeenCalledWith(7); expect(apiMock.parseMedia).not.toHaveBeenCalled();
await waitFor(() => expect(useStore.getState().frames[0]).toEqual(expect.objectContaining({ expect(apiMock.getTask).not.toHaveBeenCalled();
id: '11', expect(useStore.getState().frames).toEqual([]);
url: '/parsed.jpg', expect(await screen.findByText('该项目已导入视频但尚未生成帧,请在项目库点击“生成帧”')).toBeInTheDocument();
})));
}); });
it('hydrates saved annotations after loading frames', async () => { it('hydrates saved annotations after loading frames', async () => {

View File

@@ -8,10 +8,8 @@ import {
exportMasks, exportMasks,
getProjectAnnotations, getProjectAnnotations,
getProjectFrames, getProjectFrames,
getTask,
getTemplates, getTemplates,
importGtMask, importGtMask,
parseMedia,
propagateMasks, propagateMasks,
saveAnnotation, saveAnnotation,
updateAnnotation, updateAnnotation,
@@ -23,10 +21,6 @@ import { FrameTimeline } from './FrameTimeline';
import { ModelStatusBadge } from './ModelStatusBadge'; import { ModelStatusBadge } from './ModelStatusBadge';
import type { Frame } from '../store/useStore'; import type { Frame } from '../store/useStore';
function sleep(ms: number) {
return new Promise((resolve) => setTimeout(resolve, ms));
}
export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) { export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) {
const gtMaskInputRef = React.useRef<HTMLInputElement>(null); const gtMaskInputRef = React.useRef<HTMLInputElement>(null);
const activeTool = useStore((state) => state.activeTool); const activeTool = useStore((state) => state.activeTool);
@@ -72,64 +66,31 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
const data = await getProjectFrames(String(currentProject.id)); const data = await getProjectFrames(String(currentProject.id));
if (cancelled) return; if (cancelled) return;
if (data.length === 0 && currentProject.video_path) { const mappedFrames = data.map((f) => ({
// No frames yet but video exists -> queue parsing and poll the task. id: String(f.id),
try { projectId: String(f.project_id),
const task = await parseMedia(String(currentProject.id)); index: f.frame_index,
if (cancelled) return; url: f.image_url,
setStatusMessage(`解析任务已入队 #${task.id}`); width: f.width ?? 0,
let completed = false; height: f.height ?? 0,
for (let attempt = 0; attempt < 60; attempt += 1) { timestampMs: f.timestamp_ms ?? undefined,
const freshTask = await getTask(task.id); sourceFrameNumber: f.source_frame_number ?? undefined,
if (cancelled) return; }));
setStatusMessage(freshTask.message || `解析进度 ${freshTask.progress}%`); setFrames(mappedFrames);
if (freshTask.status === 'success') { setCurrentFrame(0);
completed = true; if (mappedFrames.length === 0) {
break; setMasks([]);
} if (currentProject.status === 'parsing') {
if (freshTask.status === 'failed') { setStatusMessage('生成帧任务正在后台运行,可在 Dashboard 查看进度');
setStatusMessage(freshTask.error || '解析任务失败'); } else if (currentProject.video_path) {
return; setStatusMessage('该项目已导入视频但尚未生成帧,请在项目库点击“生成帧”');
} } else {
await sleep(2000); setStatusMessage('当前项目没有可显示帧');
}
if (!completed) {
setStatusMessage('解析仍在后台运行,可稍后刷新工作区');
return;
}
const fresh = await getProjectFrames(String(currentProject.id));
if (cancelled) return;
const mappedFrames = fresh.map((f) => ({
id: String(f.id),
projectId: String(f.project_id),
index: f.frame_index,
url: f.image_url,
width: f.width ?? 0,
height: f.height ?? 0,
timestampMs: f.timestamp_ms ?? undefined,
sourceFrameNumber: f.source_frame_number ?? undefined,
}));
setFrames(mappedFrames);
setCurrentFrame(0);
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
} catch (err) {
console.error('Parse failed:', err);
} }
} else { return;
const mappedFrames = data.map((f) => ({
id: String(f.id),
projectId: String(f.project_id),
index: f.frame_index,
url: f.image_url,
width: f.width ?? 0,
height: f.height ?? 0,
timestampMs: f.timestamp_ms ?? undefined,
sourceFrameNumber: f.source_frame_number ?? undefined,
}));
setFrames(mappedFrames);
setCurrentFrame(0);
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
} }
setStatusMessage('');
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
} catch (err) { } catch (err) {
console.error('Failed to load frames:', err); console.error('Failed to load frames:', err);
} }

View File

@@ -2,12 +2,14 @@ import { afterEach, describe, expect, it, vi } from 'vitest';
describe('progress websocket client', () => { describe('progress websocket client', () => {
afterEach(() => { afterEach(() => {
vi.useRealTimers();
vi.restoreAllMocks(); vi.restoreAllMocks();
vi.resetModules(); vi.resetModules();
vi.unstubAllGlobals(); vi.unstubAllGlobals();
}); });
it('connects using the configured URL and reports open state', async () => { it('connects using the configured URL, reports open state, and sends heartbeat pings', async () => {
vi.useFakeTimers();
const instances: any[] = []; const instances: any[] = [];
class FakeWebSocket { class FakeWebSocket {
static CONNECTING = 0; static CONNECTING = 0;
@@ -21,14 +23,26 @@ describe('progress websocket client', () => {
instances.push(this); instances.push(this);
} }
close = vi.fn(); close = vi.fn();
send = vi.fn();
} }
vi.stubGlobal('WebSocket', FakeWebSocket); vi.stubGlobal('WebSocket', FakeWebSocket);
const { progressWS } = await import('./websocket'); const { progressWS } = await import('./websocket');
const statusCallback = vi.fn();
const unsubscribeStatus = progressWS.onStatus(statusCallback);
progressWS.connect(); progressWS.connect();
instances[0].onopen?.();
expect(instances[0].url).toContain('/ws/progress'); expect(instances[0].url).toContain('/ws/progress');
expect(progressWS.isConnected()).toBe(true); expect(progressWS.isConnected()).toBe(true);
expect(statusCallback).toHaveBeenCalledWith('connected');
expect(instances[0].send).toHaveBeenCalledWith('ping');
vi.advanceTimersByTime(15000);
expect(instances[0].send).toHaveBeenCalledTimes(2);
unsubscribeStatus();
progressWS.disconnect();
}); });
it('subscribes and unsubscribes progress callbacks', async () => { it('subscribes and unsubscribes progress callbacks', async () => {
@@ -43,4 +57,41 @@ describe('progress websocket client', () => {
expect(callback).toHaveBeenCalledTimes(1); expect(callback).toHaveBeenCalledTimes(1);
expect(callback).toHaveBeenCalledWith({ type: 'status', message: 'ok' }); expect(callback).toHaveBeenCalledWith({ type: 'status', message: 'ok' });
}); });
it('notifies connection status changes and schedules reconnect on close', async () => {
vi.useFakeTimers();
const instances: any[] = [];
class FakeWebSocket {
static CONNECTING = 0;
static OPEN = 1;
readyState = FakeWebSocket.OPEN;
onopen?: () => void;
onmessage?: (event: MessageEvent) => void;
onclose?: () => void;
onerror?: () => void;
constructor(public url: string) {
instances.push(this);
}
close = vi.fn();
send = vi.fn();
}
vi.stubGlobal('WebSocket', FakeWebSocket);
const { progressWS } = await import('./websocket');
const statusCallback = vi.fn();
const unsubscribeStatus = progressWS.onStatus(statusCallback);
progressWS.connect();
instances[0].onopen?.();
instances[0].onclose?.();
expect(statusCallback).toHaveBeenCalledWith('disconnected');
expect(statusCallback).toHaveBeenCalledWith('reconnecting');
vi.advanceTimersByTime(3000);
expect(instances).toHaveLength(2);
unsubscribeStatus();
progressWS.disconnect();
});
}); });

View File

@@ -1,6 +1,8 @@
import { WS_PROGRESS_URL } from './config'; import { WS_PROGRESS_URL } from './config';
type ProgressCallback = (data: ProgressMessage) => void; type ProgressCallback = (data: ProgressMessage) => void;
type ConnectionStatus = 'connecting' | 'connected' | 'reconnecting' | 'disconnected';
type StatusCallback = (status: ConnectionStatus) => void;
interface ProgressMessage { interface ProgressMessage {
type: 'progress' | 'status' | 'error' | 'complete' | 'cancelled'; type: 'progress' | 'status' | 'error' | 'complete' | 'cancelled';
@@ -20,9 +22,12 @@ class ProgressWebSocket {
private ws: WebSocket | null = null; private ws: WebSocket | null = null;
private url: string; private url: string;
private callbacks: Set<ProgressCallback> = new Set(); private callbacks: Set<ProgressCallback> = new Set();
private statusCallbacks: Set<StatusCallback> = new Set();
private reconnectTimer: ReturnType<typeof setTimeout> | null = null; private reconnectTimer: ReturnType<typeof setTimeout> | null = null;
private heartbeatTimer: ReturnType<typeof setInterval> | null = null;
private reconnectInterval = 3000; private reconnectInterval = 3000;
private maxReconnectInterval = 30000; private maxReconnectInterval = 30000;
private heartbeatInterval = 15000;
private shouldReconnect = false; private shouldReconnect = false;
private shouldCloseAfterOpen = false; private shouldCloseAfterOpen = false;
private currentInterval = 3000; private currentInterval = 3000;
@@ -38,6 +43,7 @@ class ProgressWebSocket {
this.shouldReconnect = true; this.shouldReconnect = true;
this.shouldCloseAfterOpen = false; this.shouldCloseAfterOpen = false;
this.notifyStatus('connecting');
try { try {
this.ws = new WebSocket(this.url); this.ws = new WebSocket(this.url);
@@ -50,6 +56,8 @@ class ProgressWebSocket {
return; return;
} }
this.currentInterval = this.reconnectInterval; this.currentInterval = this.reconnectInterval;
this.startHeartbeat();
this.notifyStatus('connected');
console.log('[WebSocket] Connected to progress stream'); console.log('[WebSocket] Connected to progress stream');
}; };
@@ -64,7 +72,9 @@ class ProgressWebSocket {
this.ws.onclose = () => { this.ws.onclose = () => {
console.log('[WebSocket] Connection closed'); console.log('[WebSocket] Connection closed');
this.stopHeartbeat();
this.ws = null; this.ws = null;
this.notifyStatus('disconnected');
if (this.shouldReconnect) { if (this.shouldReconnect) {
this.scheduleReconnect(); this.scheduleReconnect();
} }
@@ -72,7 +82,9 @@ class ProgressWebSocket {
this.ws.onerror = () => { this.ws.onerror = () => {
// 静默处理错误,避免在 CONNECTING 状态时 close 触发浏览器报错 // 静默处理错误,避免在 CONNECTING 状态时 close 触发浏览器报错
this.stopHeartbeat();
this.ws = null; this.ws = null;
this.notifyStatus('disconnected');
if (this.shouldReconnect) { if (this.shouldReconnect) {
this.scheduleReconnect(); this.scheduleReconnect();
} }
@@ -85,6 +97,7 @@ class ProgressWebSocket {
disconnect() { disconnect() {
this.shouldReconnect = false; this.shouldReconnect = false;
this.stopHeartbeat();
if (this.reconnectTimer) { if (this.reconnectTimer) {
clearTimeout(this.reconnectTimer); clearTimeout(this.reconnectTimer);
this.reconnectTimer = null; this.reconnectTimer = null;
@@ -102,6 +115,7 @@ class ProgressWebSocket {
this.ws.close(); this.ws.close();
} }
this.ws = null; this.ws = null;
this.notifyStatus('disconnected');
} }
onProgress(callback: ProgressCallback) { onProgress(callback: ProgressCallback) {
@@ -111,21 +125,53 @@ class ProgressWebSocket {
}; };
} }
onStatus(callback: StatusCallback) {
this.statusCallbacks.add(callback);
callback(this.isConnected() ? 'connected' : 'disconnected');
return () => {
this.statusCallbacks.delete(callback);
};
}
private scheduleReconnect() { private scheduleReconnect() {
if (this.reconnectTimer) { if (this.reconnectTimer) {
clearTimeout(this.reconnectTimer); clearTimeout(this.reconnectTimer);
} }
this.notifyStatus('reconnecting');
this.reconnectTimer = setTimeout(() => { this.reconnectTimer = setTimeout(() => {
console.log(`[WebSocket] Reconnecting in ${this.currentInterval}ms...`); console.log('[WebSocket] Reconnecting to progress stream...');
this.connect(); this.connect();
this.currentInterval = Math.min(this.currentInterval * 1.5, this.maxReconnectInterval); this.currentInterval = Math.min(this.currentInterval * 1.5, this.maxReconnectInterval);
}, this.currentInterval); }, this.currentInterval);
} }
private startHeartbeat() {
this.stopHeartbeat();
this.sendHeartbeat();
this.heartbeatTimer = setInterval(() => this.sendHeartbeat(), this.heartbeatInterval);
}
private stopHeartbeat() {
if (this.heartbeatTimer) {
clearInterval(this.heartbeatTimer);
this.heartbeatTimer = null;
}
}
private sendHeartbeat() {
if (this.ws?.readyState === WebSocket.OPEN) {
this.ws.send('ping');
}
}
private notifyStatus(status: ConnectionStatus) {
this.statusCallbacks.forEach((cb) => cb(status));
}
isConnected(): boolean { isConnected(): boolean {
return this.ws !== null && this.ws.readyState === WebSocket.OPEN; return this.ws !== null && this.ws.readyState === WebSocket.OPEN;
} }
} }
export const progressWS = new ProgressWebSocket(); export const progressWS = new ProgressWebSocket();
export type { ProgressMessage }; export type { ConnectionStatus, ProgressMessage };

View File

@@ -102,6 +102,15 @@ vi.mock('react-konva', () => ({
props.onClick?.(konvaEvent); props.onClick?.(konvaEvent);
if (konvaEvent.cancelBubble) event.stopPropagation(); if (konvaEvent.cancelBubble) event.stopPropagation();
}} }}
onDoubleClick={(event) => {
const point = {
x: event.clientX || 120,
y: event.clientY || 80,
};
const konvaEvent = { ...makeStageEvent(point.x, point.y), cancelBubble: false };
props.onDblClick?.(konvaEvent);
if (konvaEvent.cancelBubble) event.stopPropagation();
}}
/> />
), ),
})); }));