feat: 完善 AI 分割与工作区标注闭环
功能增加: - 将视频导入和生成帧拆成两个明确动作,项目库生成帧时选择 FPS,工作区不再自动触发拆帧。 - 为工作区新增调整多边形工具,支持选中 mask、拖动顶点、边中点插点、双击边界按位置插点,并保留多 polygon 子区域编辑。 - 打通 AI 页 SAM2/SAM3 结果到工作区的联动,生成 mask 后自动选中,可在右侧分类树换标签,并推送到工作区继续编辑。 - 增强 Dashboard WebSocket 连接状态与心跳,使用真实 onopen/onclose/onerror 状态驱动前端显示。 - 完善 SAM3 external worker 适配,支持 box prompt、semantic 请求级阈值和 video tracker 路径。 bugfix: - 修复 SAM2 文本语义误走自动分割的问题,改为提示使用点提示或切换 SAM3。 - 修复 SAM2 多候选重叠显示的问题,点提示和 auto fallback 默认只采用最高分候选。 - 修复 SAM2 反向点看起来无效的问题,带负点时启用背景过滤,过滤为空时移除旧候选。 - 修复 SAM3 单个 2D mask 结果无法转 polygon、低阈值 semantic 返回被默认阈值吞掉的问题。 - 修复 AI 页 mask 未选中导致分类树无法修改 SAM2 结果标签的问题。 测试和文档: - 补充 CanvasArea、AISegmentation、ProjectLibrary、VideoWorkspace、Dashboard、websocket 和 SAM engine/API 测试。 - 新增 backend/tests/test_sam2_engine.py,覆盖 SAM2 单候选请求和 auto fallback 行为。 - 更新 README、AGENTS 和 doc 需求/设计/接口/测试矩阵,按当前实现冻结功能状态。
This commit is contained in:
15
AGENTS.md
15
AGENTS.md
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
## 项目概述
|
## 项目概述
|
||||||
|
|
||||||
本项目是一个**语义分割系统**(Semantic Segmentation System),当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、服务器端拆帧、交互式 Canvas 标注、视频片段传播、GT mask 导入、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出。
|
本项目是一个**语义分割系统**(Semantic Segmentation System),当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、显式视频生成帧、交互式 Canvas 标注、视频片段传播、GT mask 导入、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出。
|
||||||
|
|
||||||
- **项目名称**: `react-example`(`package.json` 中的 `name`)
|
- **项目名称**: `react-example`(`package.json` 中的 `name`)
|
||||||
- **前端入口**: `src/main.tsx` → `src/App.tsx`
|
- **前端入口**: `src/main.tsx` → `src/App.tsx`
|
||||||
@@ -219,12 +219,12 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
|||||||
|
|
||||||
1. 登录:`Login.tsx` 调用 `POST /api/auth/login`,默认开发凭证为 `admin / 123456`。
|
1. 登录:`Login.tsx` 调用 `POST /api/auth/login`,默认开发凭证为 `admin / 123456`。
|
||||||
2. 项目管理:`ProjectLibrary.tsx` 调用项目 API 创建项目、拉取列表。
|
2. 项目管理:`ProjectLibrary.tsx` 调用项目 API 创建项目、拉取列表。
|
||||||
3. 上传资源:视频走 `/api/media/upload`;DICOM 批量走 `/api/media/upload/dicom`。
|
3. 上传资源:视频走 `/api/media/upload`,只上传源文件并关联项目,不自动拆帧;DICOM 批量走 `/api/media/upload/dicom`。
|
||||||
4. 拆帧入队:前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery,接口支持 `parse_fps`、`max_frames` 和 `target_width` 标准帧序列参数。
|
4. 生成帧入队:用户在项目库点击“生成帧”,选择目标 FPS 后前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery,接口支持 `parse_fps`、`max_frames` 和 `target_width` 标准帧序列参数。
|
||||||
5. worker 执行:Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallback,DICOM 使用 pydicom;视频帧按 `frame_%06d.jpg` 连续命名并记录 `timestamp_ms`、`source_frame_number` 和任务 `frame_sequence` 元数据。
|
5. worker 执行:Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallback,DICOM 使用 pydicom;视频帧按 `frame_%06d.jpg` 连续命名并记录 `timestamp_ms`、`source_frame_number` 和任务 `frame_sequence` 元数据。
|
||||||
6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames`,`CanvasArea.tsx` 和 `FrameTimeline.tsx` 显示当前帧与时间轴缩略图;前端 `Frame` 会保留后端返回的帧序列时间戳和源帧号。
|
6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames`,`CanvasArea.tsx` 和 `FrameTimeline.tsx` 显示当前帧与时间轴缩略图;前端 `Frame` 会保留后端返回的帧序列时间戳和源帧号。
|
||||||
7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask;多边形可按 Enter 或点击首节点闭合;绘制工具可在已有 mask 上继续落点;点击 mask 可拖动/删除 polygon 顶点、通过边中点插入新顶点,并能选择编辑多 polygon mask 的单个子区域;选中整块 mask 可用 Delete/Backspace 删除,已保存 mask 会同步后端删除;区域合并/去除会隐藏编辑手柄并显示已选数量,使用 `polygon-clipping` 做 union/difference,内含去除结果用 even-odd 规则渲染 hole;Zustand 维护 `maskHistory/maskFuture` 支持撤销/重做。
|
7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask;多边形可按 Enter 或点击首节点闭合;绘制工具可在已有 mask 上继续落点;工具栏有“调整多边形”入口,点击 mask 可拖动/删除 polygon 顶点、通过边中点或双击边界插入新顶点,并能选择编辑多 polygon mask 的单个子区域;选中整块 mask 可用 Delete/Backspace 删除,已保存 mask 会同步后端删除;区域合并/去除会隐藏编辑手柄并显示已选数量,使用 `polygon-clipping` 做 union/difference,内含去除结果用 even-odd 规则渲染 hole;Zustand 维护 `maskHistory/maskFuture` 支持撤销/重做。
|
||||||
8. AI 分割:前端工具包括正向点、反向点和框选;SAM 2 框选会建立候选 mask,后续正/反点通过 `interactive` prompt 携带原始框和累计点细化同一个候选 mask;后端 `ai.py` 期望按 `image_id`、`prompt_type`、`prompt_data`、`model` 和可选 `options` 调用 SAM registry。SAM 2 支持点/框/interactive/自动分割和 video predictor 传播;`options.crop_to_prompt` 可对点/框/interactive prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果;SAM 3 入口支持文本语义推理、框选提示和 external video tracker,主后端会通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境,并优先使用 `sam3_checkpoint_path` 指向的本地 `sam3权重/sam3.pt`;如果 Python/CUDA/包/本地 checkpoint 均满足,会在状态接口中标为可用。
|
8. AI 分割:前端工具包括正向点、反向点和框选;SAM 2 框选会建立候选 mask,后续正/反点通过 `interactive` prompt 携带原始框和累计点细化同一个候选 mask;包含反向点时工作区会传 `options.auto_filter_background=true` 和 `min_score=0.05`,如果后端过滤为空则移除旧候选 mask。后端 `ai.py` 期望按 `image_id`、`prompt_type`、`prompt_data`、`model` 和可选 `options` 调用 SAM registry。SAM 2 支持点/框/interactive/自动分割和 video predictor 传播,但不支持文本语义提示;AI 页面在 SAM 2 纯文本时提示改用点提示或切换 SAM 3,SAM 2 多候选默认只采用最高分区域,避免重叠候选同时显示;AI 页面生成的 mask 会写入全局 `masks` 并自动选中,右侧分类树可直接改标签,推送到工作区会切到“调整多边形”并保留选择。`options.crop_to_prompt` 可对点/框/interactive prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果;SAM 3 入口支持文本语义推理、框选提示和 external video tracker,semantic 请求会把正数 `options.min_score` 传给 external worker 作为置信度阈值,主后端会通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境,并优先使用 `sam3_checkpoint_path` 指向的本地 `sam3权重/sam3.pt`;如果 Python/CUDA/包/本地 checkpoint 均满足,会在状态接口中标为可用。
|
||||||
9. 视频片段传播:工作区“传播片段”把当前选中 mask 或当前帧第一个 mask 作为 seed,调用 `POST /api/ai/propagate`;后端按项目帧序列下载片段帧,SAM 2 用 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`,SAM 3 用独立 helper 的官方 `build_sam3_video_predictor()`,并把后续帧结果保存为 `Annotation`。
|
9. 视频片段传播:工作区“传播片段”把当前选中 mask 或当前帧第一个 mask 作为 seed,调用 `POST /api/ai/propagate`;后端按项目帧序列下载片段帧,SAM 2 用 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`,SAM 3 用独立 helper 的官方 `build_sam3_video_predictor()`,并把后续帧结果保存为 `Annotation`。
|
||||||
10. GT 导入:工作区“导入 GT Mask”调用 `/api/ai/import-gt-mask`;后端按非零像素值和连通域生成 polygon 标注,并用 distance transform 生成 seed point;前端回显 seed point,拖动后可归档更新。
|
10. GT 导入:工作区“导入 GT Mask”调用 `/api/ai/import-gt-mask`;后端按非零像素值和连通域生成 polygon 标注,并用 distance transform 生成 seed point;前端回显 seed point,拖动后可归档更新。
|
||||||
11. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index;`OntologyInspector.tsx` 在工作区显示当前模板分类树。
|
11. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index;`OntologyInspector.tsx` 在工作区显示当前模板分类树。
|
||||||
@@ -237,7 +237,7 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
|||||||
- `src/lib/config.ts` 会优先读取 `VITE_API_BASE_URL` 和 `VITE_WS_PROGRESS_URL`;未配置时按当前浏览器 hostname 推导后端 `:8000` 地址。
|
- `src/lib/config.ts` 会优先读取 `VITE_API_BASE_URL` 和 `VITE_WS_PROGRESS_URL`;未配置时按当前浏览器 hostname 推导后端 `:8000` 地址。
|
||||||
- 前端 `predictMask()` 已按后端 `PredictRequest` 发送 `image_id`、`prompt_type`、`prompt_data`、`model`,并将后端 `polygons` 转成 Konva 可渲染的 `pathData`。
|
- 前端 `predictMask()` 已按后端 `PredictRequest` 发送 `image_id`、`prompt_type`、`prompt_data`、`model`,并将后端 `polygons` 转成 Konva 可渲染的 `pathData`。
|
||||||
- 手工绘制工具会生成可保存的 `Mask.segmentation`;撤销/重做通过 `maskHistory/maskFuture` 工作。
|
- 手工绘制工具会生成可保存的 `Mask.segmentation`;撤销/重做通过 `maskHistory/maskFuture` 工作。
|
||||||
- Polygon 顶点编辑会重算 `pathData/segmentation/bbox/area`;已保存 mask 进入 dirty 状态后复用归档 PATCH 链路。
|
- Polygon 顶点编辑和新增顶点会重算 `pathData/segmentation/bbox/area`;已保存 mask 进入 dirty 状态后复用归档 PATCH 链路。
|
||||||
- 区域合并/去除会重算主 mask 的几何;合并已保存的次级 mask 时会通过工作区回调删除对应后端标注。
|
- 区域合并/去除会重算主 mask 的几何;合并已保存的次级 mask 时会通过工作区回调删除对应后端标注。
|
||||||
- 前端 `importGtMask()` 已对齐后端 `/api/ai/import-gt-mask`;工作区“导入 GT Mask”会导入后端生成的多类别标注和 seed point 并回显。
|
- 前端 `importGtMask()` 已对齐后端 `/api/ai/import-gt-mask`;工作区“导入 GT Mask”会导入后端生成的多类别标注和 seed point 并回显。
|
||||||
- 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;前端 `exportMasks()` 已对齐后端 `/api/export/{project_id}/masks`;工作区导出按钮会先保存当前待归档 mask。
|
- 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;前端 `exportMasks()` 已对齐后端 `/api/export/{project_id}/masks`;工作区导出按钮会先保存当前待归档 mask。
|
||||||
@@ -245,8 +245,9 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
|||||||
- 工作区“传播片段”按钮已接入 `POST /api/ai/propagate`;SAM 2 路径使用视频 predictor,SAM 3 路径使用独立 Python helper 的官方 video tracker,完成后刷新后端已保存标注。
|
- 工作区“传播片段”按钮已接入 `POST /api/ai/propagate`;SAM 2 路径使用视频 predictor,SAM 3 路径使用独立 Python helper 的官方 video tracker,完成后刷新后端已保存标注。
|
||||||
- 工作区“清空遮罩”会调用 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。
|
- 工作区“清空遮罩”会调用 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。
|
||||||
- 项目状态已统一为 `pending`、`parsing`、`ready`、`error`;前端 `src/lib/api.ts` 会兼容归一化旧库中可能存在的 `Ready`、`Parsing`、`Error`。
|
- 项目状态已统一为 `pending`、`parsing`、`ready`、`error`;前端 `src/lib/api.ts` 会兼容归一化旧库中可能存在的 `Ready`、`Parsing`、`Error`。
|
||||||
|
- 项目库的视频导入与生成帧是两个独立动作:导入视频只上传源文件,生成帧按钮才会带 `parse_fps` 调用 `/api/media/parse`;工作区不会再因“有视频但无帧”自动创建拆帧任务。
|
||||||
- `server.ts` 仍有旧版 `/api/login`、`/api/projects`、`/api/templates` mock;当前前端真实 API 调用主要走 FastAPI 的 `/api/auth/*`、`/api/projects`、`/api/templates` 等接口。
|
- `server.ts` 仍有旧版 `/api/login`、`/api/projects`、`/api/templates` mock;当前前端真实 API 调用主要走 FastAPI 的 `/api/auth/*`、`/api/projects`、`/api/templates` 等接口。
|
||||||
- `Dashboard.tsx` 初始统计、队列和活动日志来自 `GET /api/dashboard/overview`;解析队列来自 `processing_tasks`,支持取消 queued/running 任务、重试 failed/cancelled 任务和查看失败详情。Celery worker 通过 Redis pub/sub 的 `seg:progress` 频道推送细粒度进度,再由 FastAPI 广播到 `/ws/progress`。
|
- `Dashboard.tsx` 初始统计、队列和活动日志来自 `GET /api/dashboard/overview`;解析队列来自 `processing_tasks`,支持取消 queued/running 任务、重试 failed/cancelled 任务和查看失败详情。Celery worker 通过 Redis pub/sub 的 `seg:progress` 频道推送细粒度进度,再由 FastAPI 广播到 `/ws/progress`;前端 WebSocket 客户端通过 `onopen/onclose/onerror` 更新连接状态,并定时发送 `ping` 心跳。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
10
README.md
10
README.md
@@ -12,9 +12,9 @@
|
|||||||
|
|
||||||
## 核心功能
|
## 核心功能
|
||||||
|
|
||||||
- **多媒体资产管理** — 支持视频(MP4/AVI/MOV)和 DICOM 医学影像的上传、存储与解析
|
- **多媒体资产管理** — 支持视频(MP4/AVI/MOV)和 DICOM 医学影像上传;视频导入与生成帧分离,生成帧时选择目标 FPS
|
||||||
- **AI 智能分割引擎** — 后端提供 SAM 2 / SAM 3 模型选择;SAM 2 支持点分割(point)、框分割(box)、自动分割(auto)和 video predictor 传播,SAM 3 入口支持文本语义提示、框选提示和 external video tracker,并按真实运行环境显示可用性
|
- **AI 智能分割引擎** — 后端提供 SAM 2 / SAM 3 模型选择;SAM 2 支持点分割(point)、框分割(box)、自动分割(auto)和 video predictor 传播,前端默认只采用最高分候选避免重叠备选同时显示;SAM 3 入口支持文本语义提示、框选提示和 external video tracker,并按真实运行环境显示可用性
|
||||||
- **交互式画布标注** — 基于 Konva 的高性能 Canvas,支持缩放/平移/手工多边形/矩形/圆/点/线、polygon 顶点拖动/删除、区域合并/去除、选点/框选、撤销/重做,实时渲染 Mask 遮罩
|
- **交互式画布标注** — 基于 Konva 的高性能 Canvas,支持缩放/平移/手工多边形/矩形/圆/点/线、polygon 顶点拖动/删除、边中点插点、双击边界插点、区域合并/去除、选点/框选、撤销/重做,实时渲染 Mask 遮罩
|
||||||
- **GT Mask 导入** — 工作区可导入 GT mask 图片,后端按非零像素值和连通域生成 polygon 标注并用 distance transform 写入 seed point;前端可回显和拖动 seed point
|
- **GT Mask 导入** — 工作区可导入 GT mask 图片,后端按非零像素值和连通域生成 polygon 标注并用 distance transform 写入 seed point;前端可回显和拖动 seed point
|
||||||
- **本体字典管理** — 可配置的分类体系、颜色映射、图层优先级(z-index)
|
- **本体字典管理** — 可配置的分类体系、颜色映射、图层优先级(z-index)
|
||||||
- **项目工作区** — 项目创建、帧浏览、多图层标注、进度追踪
|
- **项目工作区** — 项目创建、帧浏览、多图层标注、进度追踪
|
||||||
@@ -325,7 +325,7 @@ celery -A celery_app:celery_app worker --loglevel=info --concurrency=1
|
|||||||
nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 &
|
nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 &
|
||||||
```
|
```
|
||||||
|
|
||||||
`POST /api/media/parse` 只创建 `processing_tasks` 记录并把任务投递给 Celery;真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。接口支持 `parse_fps`、`max_frames` 和 `target_width`,用于生成后续 SAM 2 / SAM 3 视频处理可复用的标准帧序列;视频帧按 `frame_%06d.jpg` 连续命名,帧表会记录 `timestamp_ms` 和 `source_frame_number`,任务完成结果会返回 `frame_sequence` 元数据。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道,FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 也可调用 `/api/tasks/{id}/cancel`、`/api/tasks/{id}/retry` 和 `/api/tasks/{id}` 完成任务取消、重试与失败详情查看。
|
视频导入只创建项目并把源视频保存到 MinIO,不会自动拆帧;用户在项目库点击“生成帧”后,再选择目标 FPS 并调用 `POST /api/media/parse`。该接口只创建 `processing_tasks` 记录并把任务投递给 Celery;真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。接口支持 `parse_fps`、`max_frames` 和 `target_width`,用于生成后续 SAM 2 / SAM 3 视频处理可复用的标准帧序列;视频帧按 `frame_%06d.jpg` 连续命名,帧表会记录 `timestamp_ms` 和 `source_frame_number`,任务完成结果会返回 `frame_sequence` 元数据。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道,FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 的 WebSocket 状态由浏览器 `onopen/onclose/onerror` 驱动,客户端会定时发送 `ping` 心跳,服务端返回 `status` 确认连接。Dashboard 也可调用 `/api/tasks/{id}/cancel`、`/api/tasks/{id}/retry` 和 `/api/tasks/{id}` 完成任务取消、重试与失败详情查看。
|
||||||
|
|
||||||
### 步骤 7: 安装前端依赖并构建
|
### 步骤 7: 安装前端依赖并构建
|
||||||
|
|
||||||
@@ -461,6 +461,8 @@ pip install -e . --no-build-isolation
|
|||||||
|
|
||||||
- 前端 `predictMask()` 已发送后端需要的 `image_id`、`prompt_type`、`prompt_data`,并把后端 `polygons` 转成 Konva `pathData`。
|
- 前端 `predictMask()` 已发送后端需要的 `image_id`、`prompt_type`、`prompt_data`,并把后端 `polygons` 转成 Konva `pathData`。
|
||||||
- 工作区点选/框选会使用当前帧的数据库 `frame.id` 调用 `/api/ai/predict`。
|
- 工作区点选/框选会使用当前帧的数据库 `frame.id` 调用 `/api/ai/predict`。
|
||||||
|
- 工作区 SAM 2 交互式细化包含反向点时会启用后端背景过滤;若反向点排除了当前候选区域并返回空结果,前端会移除旧候选 mask。
|
||||||
|
- AI 页面生成的 SAM 2/SAM 3 mask 会写入全局 `masks` 并自动选中;右侧分类树可直接给生成结果换标签,“推送至工作区编辑”会切回工作区的多边形调整工具并保留选择。
|
||||||
- 工作区“传播片段”会使用当前选中区域或当前帧第一个区域作为 seed,调用 `/api/ai/propagate`,并在完成后刷新已保存标注。
|
- 工作区“传播片段”会使用当前选中区域或当前帧第一个区域作为 seed,调用 `/api/ai/propagate`,并在完成后刷新已保存标注。
|
||||||
- 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco`。
|
- 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco`。
|
||||||
- 工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮已绑定下载流程;导出前会先保存当前待归档的前端 mask。
|
- 工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮已绑定下载流程;导出前会先保存当前待归档的前端 mask。
|
||||||
|
|||||||
@@ -340,7 +340,21 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
|||||||
|
|
||||||
elif prompt_type == "semantic":
|
elif prompt_type == "semantic":
|
||||||
text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
|
text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
|
||||||
polygons, scores = sam_registry.predict_semantic(payload.model, image, text)
|
min_score = options.get("min_score")
|
||||||
|
confidence_threshold = None
|
||||||
|
if min_score is not None:
|
||||||
|
try:
|
||||||
|
parsed_min_score = float(min_score)
|
||||||
|
if parsed_min_score > 0:
|
||||||
|
confidence_threshold = parsed_min_score
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
confidence_threshold = None
|
||||||
|
polygons, scores = sam_registry.predict_semantic(
|
||||||
|
payload.model,
|
||||||
|
image,
|
||||||
|
text,
|
||||||
|
confidence_threshold=confidence_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
|
raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
|
||||||
@@ -352,6 +366,13 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
|||||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||||
|
|
||||||
polygons, scores = _filter_predictions(polygons, scores, options, negative_points)
|
polygons, scores = _filter_predictions(polygons, scores, options, negative_points)
|
||||||
|
logger.info(
|
||||||
|
"AI predict completed model=%s prompt_type=%s frame_id=%s polygons=%d",
|
||||||
|
payload.model or "default",
|
||||||
|
prompt_type,
|
||||||
|
payload.image_id,
|
||||||
|
len(polygons),
|
||||||
|
)
|
||||||
return {"polygons": polygons, "scores": scores}
|
return {"polygons": polygons, "scores": scores}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -207,7 +207,7 @@ class SAM2Engine:
|
|||||||
masks, scores, _ = self._predictor.predict(
|
masks, scores, _ = self._predictor.predict(
|
||||||
point_coords=pts,
|
point_coords=pts,
|
||||||
point_labels=lbls,
|
point_labels=lbls,
|
||||||
multimask_output=True,
|
multimask_output=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
polygons = []
|
polygons = []
|
||||||
@@ -335,16 +335,16 @@ class SAM2Engine:
|
|||||||
masks, scores, _ = self._predictor.predict(
|
masks, scores, _ = self._predictor.predict(
|
||||||
point_coords=pts,
|
point_coords=pts,
|
||||||
point_labels=lbls,
|
point_labels=lbls,
|
||||||
multimask_output=True,
|
multimask_output=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
polygons = []
|
polygons = []
|
||||||
for m in masks[:3]: # Limit to top 3 masks
|
for m in masks[:1]:
|
||||||
poly = self._mask_to_polygon(m)
|
poly = self._mask_to_polygon(m)
|
||||||
if poly:
|
if poly:
|
||||||
polygons.append(poly)
|
polygons.append(poly)
|
||||||
|
|
||||||
return polygons, scores[:3].tolist()
|
return polygons, scores[:1].tolist()
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
logger.error("SAM2 auto prediction failed: %s", exc)
|
logger.error("SAM2 auto prediction failed: %s", exc)
|
||||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||||
|
|||||||
@@ -260,6 +260,7 @@ class SAM3Engine:
|
|||||||
*,
|
*,
|
||||||
text: str = "",
|
text: str = "",
|
||||||
box: list[float] | None = None,
|
box: list[float] | None = None,
|
||||||
|
confidence_threshold: float | None = None,
|
||||||
) -> tuple[list[list[list[float]]], list[float]]:
|
) -> tuple[list[list[list[float]]], list[float]]:
|
||||||
status = self._external_status(force=True)
|
status = self._external_status(force=True)
|
||||||
if not status.get("available"):
|
if not status.get("available"):
|
||||||
@@ -279,7 +280,11 @@ class SAM3Engine:
|
|||||||
"box": box,
|
"box": box,
|
||||||
"model_version": settings.sam3_model_version,
|
"model_version": settings.sam3_model_version,
|
||||||
"checkpoint_path": self._checkpoint_path(),
|
"checkpoint_path": self._checkpoint_path(),
|
||||||
"confidence_threshold": settings.sam3_confidence_threshold,
|
"confidence_threshold": (
|
||||||
|
confidence_threshold
|
||||||
|
if confidence_threshold is not None
|
||||||
|
else settings.sam3_confidence_threshold
|
||||||
|
),
|
||||||
},
|
},
|
||||||
ensure_ascii=False,
|
ensure_ascii=False,
|
||||||
),
|
),
|
||||||
@@ -312,8 +317,18 @@ class SAM3Engine:
|
|||||||
raise RuntimeError(str(payload["error"]))
|
raise RuntimeError(str(payload["error"]))
|
||||||
return payload.get("polygons", []), payload.get("scores", [])
|
return payload.get("polygons", []), payload.get("scores", [])
|
||||||
|
|
||||||
def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
|
def _predict_semantic_external(
|
||||||
return self._predict_external(image, "semantic", text=text)
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
text: str,
|
||||||
|
confidence_threshold: float | None = None,
|
||||||
|
) -> tuple[list[list[list[float]]], list[float]]:
|
||||||
|
return self._predict_external(
|
||||||
|
image,
|
||||||
|
"semantic",
|
||||||
|
text=text,
|
||||||
|
confidence_threshold=confidence_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
def _predict_box_external(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]:
|
def _predict_box_external(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]:
|
||||||
return self._predict_external(image, "box", box=box)
|
return self._predict_external(image, "box", box=box)
|
||||||
@@ -378,11 +393,16 @@ class SAM3Engine:
|
|||||||
raise RuntimeError(str(payload["error"]))
|
raise RuntimeError(str(payload["error"]))
|
||||||
return payload.get("frames", [])
|
return payload.get("frames", [])
|
||||||
|
|
||||||
def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
|
def predict_semantic(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
text: str,
|
||||||
|
confidence_threshold: float | None = None,
|
||||||
|
) -> tuple[list[list[list[float]]], list[float]]:
|
||||||
if not text.strip():
|
if not text.strip():
|
||||||
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
|
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
|
||||||
if not self._can_load() and self._external_status().get("available"):
|
if not self._can_load() and self._external_status().get("available"):
|
||||||
return self._predict_semantic_external(image, text)
|
return self._predict_semantic_external(image, text, confidence_threshold=confidence_threshold)
|
||||||
if not self._ensure_ready():
|
if not self._ensure_ready():
|
||||||
raise RuntimeError(self.status()["message"])
|
raise RuntimeError(self.status()["message"])
|
||||||
|
|
||||||
|
|||||||
@@ -190,7 +190,9 @@ def _video_outputs_to_response(outputs: dict[str, Any]) -> dict[str, Any]:
|
|||||||
def _prediction_to_response(output: dict[str, Any]) -> dict[str, Any]:
|
def _prediction_to_response(output: dict[str, Any]) -> dict[str, Any]:
|
||||||
masks = _to_numpy(output.get("masks", []))
|
masks = _to_numpy(output.get("masks", []))
|
||||||
scores = _to_numpy(output.get("scores", []))
|
scores = _to_numpy(output.get("scores", []))
|
||||||
if masks.ndim == 4:
|
if masks.ndim == 2:
|
||||||
|
masks = masks[None, :, :]
|
||||||
|
elif masks.ndim == 4:
|
||||||
masks = masks[:, 0]
|
masks = masks[:, 0]
|
||||||
elif masks.ndim == 3 and masks.shape[0] == 1:
|
elif masks.ndim == 3 and masks.shape[0] == 1:
|
||||||
masks = masks[None, 0]
|
masks = masks[None, 0]
|
||||||
|
|||||||
@@ -83,10 +83,20 @@ class SAMRegistry:
|
|||||||
def predict_auto(self, model_id: str | None, image: Any):
|
def predict_auto(self, model_id: str | None, image: Any):
|
||||||
return self._ensure_available(model_id).predict_auto(image)
|
return self._ensure_available(model_id).predict_auto(image)
|
||||||
|
|
||||||
def predict_semantic(self, model_id: str | None, image: Any, text: str):
|
def predict_semantic(
|
||||||
|
self,
|
||||||
|
model_id: str | None,
|
||||||
|
image: Any,
|
||||||
|
text: str,
|
||||||
|
confidence_threshold: float | None = None,
|
||||||
|
):
|
||||||
model = self.normalize_model_id(model_id)
|
model = self.normalize_model_id(model_id)
|
||||||
if model == "sam3":
|
if model == "sam3":
|
||||||
return self._ensure_available(model).predict_semantic(image, text)
|
return self._ensure_available(model).predict_semantic(
|
||||||
|
image,
|
||||||
|
text,
|
||||||
|
confidence_threshold=confidence_threshold,
|
||||||
|
)
|
||||||
return self._ensure_available(model).predict_auto(image)
|
return self._ensure_available(model).predict_auto(image)
|
||||||
|
|
||||||
def propagate_video(
|
def propagate_video(
|
||||||
|
|||||||
@@ -89,15 +89,25 @@ def test_predict_applies_crop_and_background_filter_options(client, monkeypatch)
|
|||||||
|
|
||||||
def test_predict_box_and_semantic_fallback(client, monkeypatch):
|
def test_predict_box_and_semantic_fallback(client, monkeypatch):
|
||||||
_, frame, _ = _create_project_and_frame(client)
|
_, frame, _ = _create_project_and_frame(client)
|
||||||
|
calls = {}
|
||||||
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
|
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
|
||||||
monkeypatch.setattr("routers.ai.sam_registry.predict_box", lambda model, image, box: (
|
monkeypatch.setattr("routers.ai.sam_registry.predict_box", lambda model, image, box: (
|
||||||
[[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
|
[[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
|
||||||
[0.8],
|
[0.8],
|
||||||
))
|
))
|
||||||
monkeypatch.setattr("routers.ai.sam_registry.predict_semantic", lambda model, image, text: (
|
|
||||||
[[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]],
|
def fake_predict_semantic(model, image, text, confidence_threshold=None):
|
||||||
[0.5],
|
calls["semantic"] = {
|
||||||
))
|
"model": model,
|
||||||
|
"text": text,
|
||||||
|
"confidence_threshold": confidence_threshold,
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
[[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]],
|
||||||
|
[0.5],
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr("routers.ai.sam_registry.predict_semantic", fake_predict_semantic)
|
||||||
|
|
||||||
box_response = client.post("/api/ai/predict", json={
|
box_response = client.post("/api/ai/predict", json={
|
||||||
"image_id": frame["id"],
|
"image_id": frame["id"],
|
||||||
@@ -108,12 +118,19 @@ def test_predict_box_and_semantic_fallback(client, monkeypatch):
|
|||||||
"image_id": frame["id"],
|
"image_id": frame["id"],
|
||||||
"prompt_type": "semantic",
|
"prompt_type": "semantic",
|
||||||
"prompt_data": "胆囊",
|
"prompt_data": "胆囊",
|
||||||
|
"model": "sam3",
|
||||||
|
"options": {"min_score": 0.05},
|
||||||
})
|
})
|
||||||
|
|
||||||
assert box_response.status_code == 200
|
assert box_response.status_code == 200
|
||||||
assert box_response.json()["scores"] == [0.8]
|
assert box_response.json()["scores"] == [0.8]
|
||||||
assert semantic_response.status_code == 200
|
assert semantic_response.status_code == 200
|
||||||
assert semantic_response.json()["scores"] == [0.5]
|
assert semantic_response.json()["scores"] == [0.5]
|
||||||
|
assert calls["semantic"] == {
|
||||||
|
"model": "sam3",
|
||||||
|
"text": "胆囊",
|
||||||
|
"confidence_threshold": 0.05,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_predict_interactive_combines_box_and_points(client, monkeypatch):
|
def test_predict_interactive_combines_box_and_points(client, monkeypatch):
|
||||||
|
|||||||
63
backend/tests/test_sam2_engine.py
Normal file
63
backend/tests/test_sam2_engine.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from services.sam2_engine import SAM2Engine
|
||||||
|
|
||||||
|
|
||||||
|
class _FakePredictor:
|
||||||
|
def __init__(self, masks, scores):
|
||||||
|
self.masks = masks
|
||||||
|
self.scores = scores
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
def set_image(self, _image):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def predict(self, **kwargs):
|
||||||
|
self.calls.append(kwargs)
|
||||||
|
return self.masks, self.scores, None
|
||||||
|
|
||||||
|
|
||||||
|
def _mask(offset=0):
|
||||||
|
mask = np.zeros((32, 32), dtype=np.uint8)
|
||||||
|
mask[4 + offset:20 + offset, 5 + offset:22 + offset] = 1
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def _ready_engine(monkeypatch, predictor):
|
||||||
|
monkeypatch.setattr("services.sam2_engine.SAM2_AVAILABLE", True)
|
||||||
|
engine = SAM2Engine()
|
||||||
|
engine._model_loaded = True
|
||||||
|
engine._predictor = predictor
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
def test_sam2_point_prediction_requests_single_best_mask(monkeypatch):
|
||||||
|
predictor = _FakePredictor(
|
||||||
|
np.array([_mask()], dtype=np.uint8),
|
||||||
|
np.array([0.92], dtype=np.float32),
|
||||||
|
)
|
||||||
|
engine = _ready_engine(monkeypatch, predictor)
|
||||||
|
|
||||||
|
polygons, scores = engine.predict_points(
|
||||||
|
np.zeros((32, 32, 3), dtype=np.uint8),
|
||||||
|
[[0.5, 0.5]],
|
||||||
|
[1],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert predictor.calls[0]["multimask_output"] is False
|
||||||
|
assert len(polygons) == 1
|
||||||
|
assert scores == [0.9200000166893005]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sam2_auto_prediction_keeps_single_best_mask(monkeypatch):
|
||||||
|
predictor = _FakePredictor(
|
||||||
|
np.array([_mask(0), _mask(2), _mask(4)], dtype=np.uint8),
|
||||||
|
np.array([0.8, 0.7, 0.6], dtype=np.float32),
|
||||||
|
)
|
||||||
|
engine = _ready_engine(monkeypatch, predictor)
|
||||||
|
|
||||||
|
polygons, scores = engine.predict_auto(np.zeros((32, 32, 3), dtype=np.uint8))
|
||||||
|
|
||||||
|
assert predictor.calls[0]["multimask_output"] is False
|
||||||
|
assert len(polygons) == 1
|
||||||
|
assert scores == [0.800000011920929]
|
||||||
@@ -4,7 +4,7 @@ from pathlib import Path
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from services.sam3_engine import SAM3Engine
|
from services.sam3_engine import SAM3Engine
|
||||||
from services.sam3_external_worker import _to_numpy
|
from services.sam3_external_worker import _prediction_to_response, _to_numpy
|
||||||
|
|
||||||
|
|
||||||
class _Completed:
|
class _Completed:
|
||||||
@@ -98,6 +98,41 @@ def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch):
|
|||||||
assert any("--request" in args for args in calls)
|
assert any("--request" in args for args in calls)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sam3_predict_semantic_allows_request_threshold_override(tmp_path, monkeypatch):
|
||||||
|
_external_settings(monkeypatch, tmp_path / "python")
|
||||||
|
|
||||||
|
def fake_run(args, **_kwargs):
|
||||||
|
if "--status" in args:
|
||||||
|
return _Completed(stdout=json.dumps({
|
||||||
|
"available": True,
|
||||||
|
"package_available": True,
|
||||||
|
"checkpoint_access": True,
|
||||||
|
"python_ok": True,
|
||||||
|
"torch_ok": True,
|
||||||
|
"cuda_available": True,
|
||||||
|
"device": "cuda",
|
||||||
|
"message": "ready",
|
||||||
|
}))
|
||||||
|
request_path = Path(args[-1])
|
||||||
|
request = json.loads(request_path.read_text(encoding="utf-8"))
|
||||||
|
assert request["confidence_threshold"] == 0.05
|
||||||
|
return _Completed(stdout=json.dumps({
|
||||||
|
"polygons": [[[0.2, 0.2], [0.6, 0.2], [0.6, 0.6]]],
|
||||||
|
"scores": [0.07],
|
||||||
|
}))
|
||||||
|
|
||||||
|
monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)
|
||||||
|
|
||||||
|
polygons, scores = SAM3Engine().predict_semantic(
|
||||||
|
np.zeros((8, 8, 3), dtype=np.uint8),
|
||||||
|
"surgical scene",
|
||||||
|
confidence_threshold=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(polygons) == 1
|
||||||
|
assert scores == [0.07]
|
||||||
|
|
||||||
|
|
||||||
def test_sam3_predict_box_uses_external_worker(tmp_path, monkeypatch):
|
def test_sam3_predict_box_uses_external_worker(tmp_path, monkeypatch):
|
||||||
_external_settings(monkeypatch, tmp_path / "python")
|
_external_settings(monkeypatch, tmp_path / "python")
|
||||||
|
|
||||||
@@ -243,3 +278,16 @@ def test_sam3_worker_casts_floating_tensors_before_numpy():
|
|||||||
|
|
||||||
assert tensor.float_called is True
|
assert tensor.float_called is True
|
||||||
assert result.tolist() == [1.0]
|
assert result.tolist() == [1.0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sam3_worker_converts_single_2d_mask_to_polygon():
|
||||||
|
mask = np.zeros((12, 12), dtype=np.uint8)
|
||||||
|
mask[2:10, 3:9] = 1
|
||||||
|
|
||||||
|
result = _prediction_to_response({
|
||||||
|
"masks": mask,
|
||||||
|
"scores": np.array([0.82], dtype=np.float32),
|
||||||
|
})
|
||||||
|
|
||||||
|
assert len(result["polygons"]) == 1
|
||||||
|
assert result["scores"] == [0.8199999928474426]
|
||||||
|
|||||||
@@ -65,12 +65,12 @@
|
|||||||
### 项目与拆帧
|
### 项目与拆帧
|
||||||
|
|
||||||
1. `ProjectLibrary.tsx` 调用 `getProjects()` 获取项目。
|
1. `ProjectLibrary.tsx` 调用 `getProjects()` 获取项目。
|
||||||
2. 上传视频时先 `createProject()`,再 `uploadMedia()`,再 `parseMedia()`。
|
2. 上传视频时先 `createProject()`,再 `uploadMedia()`;导入视频不自动调用 `parseMedia()`。
|
||||||
3. 后端 `media.py` 把原始文件上传到 MinIO。
|
3. 后端 `media.py` 把原始文件上传到 MinIO。
|
||||||
4. `parseMedia()` 创建 `processing_tasks` 记录并投递 Celery worker。
|
4. 用户在项目库点击“生成帧”并选择 FPS 后,`parseMedia()` 创建 `processing_tasks` 记录并投递 Celery worker。
|
||||||
5. Celery worker 下载 MinIO 文件,调用 `frame_parser.py` 拆帧。
|
5. Celery worker 下载 MinIO 文件,调用 `frame_parser.py` 拆帧。
|
||||||
6. worker 把拆出的帧重新上传 MinIO,写入 `frames` 表,并更新任务状态。
|
6. worker 把拆出的帧重新上传 MinIO,写入 `frames` 表,并更新任务状态。
|
||||||
7. 工作区通过 `GET /api/tasks/{id}` 等待任务完成,再通过 `GET /api/projects/{id}/frames` 获取预签名图片 URL。
|
7. 工作区只通过 `GET /api/projects/{id}/frames` 获取预签名图片 URL;若项目有源视频但无帧,会提示先回项目库生成帧。
|
||||||
8. Dashboard 可通过 `POST /api/tasks/{id}/cancel` 取消 queued/running 任务,通过 `POST /api/tasks/{id}/retry` 重试 failed/cancelled 任务,并用 `GET /api/tasks/{id}` 查看失败详情。
|
8. Dashboard 可通过 `POST /api/tasks/{id}/cancel` 取消 queued/running 任务,通过 `POST /api/tasks/{id}/retry` 重试 failed/cancelled 任务,并用 `GET /api/tasks/{id}` 查看失败详情。
|
||||||
|
|
||||||
### 工作区浏览
|
### 工作区浏览
|
||||||
|
|||||||
@@ -46,8 +46,9 @@
|
|||||||
| 项目卡片缩略图 | 真实可用 | 后端返回 MinIO 预签名 `thumbnail_url` 时显示 |
|
| 项目卡片缩略图 | 真实可用 | 后端返回 MinIO 预签名 `thumbnail_url` 时显示 |
|
||||||
| 点击项目进入工作区 | 真实可用 | 设置 `currentProject` 后切到 `workspace` |
|
| 点击项目进入工作区 | 真实可用 | 设置 `currentProject` 后切到 `workspace` |
|
||||||
| 新建项目 | 真实可用 | 调用 `POST /api/projects` |
|
| 新建项目 | 真实可用 | 调用 `POST /api/projects` |
|
||||||
| 导入视频文件 | 真实可用 | 创建项目、上传文件、触发拆帧、刷新项目列表 |
|
| 导入视频文件 | 真实可用 | 创建项目、上传源视频、刷新项目列表;不会自动拆帧 |
|
||||||
| 解析 FPS 滑块 | 真实可用 | 值传入 `createProject({ parse_fps })` |
|
| 生成帧按钮 | 真实可用 | 仅对已导入源视频且尚无帧、非 parsing 状态的项目显示,调用 `parseMedia(projectId, { parseFps })` |
|
||||||
|
| 生成帧 FPS 滑块 | 真实可用 | 值传入 `/api/media/parse?parse_fps=...`,决定后台拆帧目标 FPS |
|
||||||
| 导入 DICOM 序列 | 部分可用 | 可上传 `.dcm` 并触发解析;体验和错误反馈较粗 |
|
| 导入 DICOM 序列 | 部分可用 | 可上传 `.dcm` 并触发解析;体验和错误反馈较粗 |
|
||||||
| 项目状态徽标 | 真实可用 | 项目状态统一为 `pending/parsing/ready/error`,前端兼容归一化旧状态值 |
|
| 项目状态徽标 | 真实可用 | 项目状态统一为 `pending/parsing/ready/error`,前端兼容归一化旧状态值 |
|
||||||
| 更多按钮 | Mock / UI-only | 有图标,没有菜单或事件 |
|
| 更多按钮 | Mock / UI-only | 有图标,没有菜单或事件 |
|
||||||
@@ -59,7 +60,7 @@
|
|||||||
|------|------|------|
|
|------|------|------|
|
||||||
| 当前项目名 | 真实可用 | 读取 `currentProject.name` |
|
| 当前项目名 | 真实可用 | 读取 `currentProject.name` |
|
||||||
| 自动加载项目帧 | 真实可用 | 调用 `GET /api/projects/{id}/frames` |
|
| 自动加载项目帧 | 真实可用 | 调用 `GET /api/projects/{id}/frames` |
|
||||||
| 无帧时触发解析 | 真实可用 | 如果 `video_path` 存在会调用 `parseMedia()` 创建异步任务,并轮询 `GET /api/tasks/{id}` 等待完成 |
|
| 无帧项目提示 | 真实可用 | 如果 `video_path` 存在但无帧,只提示回到项目库生成帧,不自动创建拆帧任务 |
|
||||||
| SAM 模型状态徽标 | 真实可用 | 调用 `GET /api/ai/models/status`,显示当前选择的 SAM 2/SAM 3 是否可用 |
|
| SAM 模型状态徽标 | 真实可用 | 调用 `GET /api/ai/models/status`,显示当前选择的 SAM 2/SAM 3 是否可用 |
|
||||||
| 已保存标注回显 | 真实可用 | 加载工作区帧后调用 `GET /api/ai/annotations` 并渲染已保存 mask |
|
| 已保存标注回显 | 真实可用 | 加载工作区帧后调用 `GET /api/ai/annotations` 并渲染已保存 mask |
|
||||||
| “导出 JSON 标注集”按钮 | 真实可用 | 导出前会保存未归档 mask,然后调用 `exportCoco()` 下载 JSON |
|
| “导出 JSON 标注集”按钮 | 真实可用 | 导出前会保存未归档 mask,然后调用 `exportCoco()` 下载 JSON |
|
||||||
@@ -93,6 +94,7 @@
|
|||||||
| 元素 | 状态 | 说明 |
|
| 元素 | 状态 | 说明 |
|
||||||
|------|------|------|
|
|------|------|------|
|
||||||
| 拖拽/选择 | 真实可用 | 控制 Canvas 是否可拖拽 |
|
| 拖拽/选择 | 真实可用 | 控制 Canvas 是否可拖拽 |
|
||||||
|
| 调整多边形 | 真实可用 | 选中 polygon mask 后显示顶点和边中点;支持拖动顶点、点击边中点插点、双击边界按位置插点 |
|
||||||
| 多边形/矩形/圆/点/线 | 真实可用 | 切换 activeTool 后由 `CanvasArea` 生成可保存的 polygon mask |
|
| 多边形/矩形/圆/点/线 | 真实可用 | 切换 activeTool 后由 `CanvasArea` 生成可保存的 polygon mask |
|
||||||
| 区域合并/去除 | 真实可用 | 选择工具后点击多个 mask,右下角显示已选数量和操作按钮;合并/去除模式会隐藏 polygon 编辑手柄,避免手柄抢占多选点击;使用 `polygon-clipping` 做 union / difference;合并会保留主 mask 并移除被合并 mask,去除会从主 mask 扣除后续选中 mask;内含扣除会保留 hole ring 并用 even-odd 规则渲染 |
|
| 区域合并/去除 | 真实可用 | 选择工具后点击多个 mask,右下角显示已选数量和操作按钮;合并/去除模式会隐藏 polygon 编辑手柄,避免手柄抢占多选点击;使用 `polygon-clipping` 做 union / difference;合并会保留主 mask 并移除被合并 mask,去除会从主 mask 扣除后续选中 mask;内含扣除会保留 hole ring 并用 even-odd 规则渲染 |
|
||||||
| 正向选点/反向选点/框选 | 部分可用 | 会影响 Canvas 交互,并能触发已对齐的 AI 推理接口 |
|
| 正向选点/反向选点/框选 | 部分可用 | 会影响 Canvas 交互,并能触发已对齐的 AI 推理接口 |
|
||||||
@@ -130,7 +132,8 @@
|
|||||||
| SAM 3 框选 | 真实可用 | 工作区选择 SAM 3 后可使用框选工具;后端通过官方 `add_geometric_prompt()` 正框执行 SAM 3 几何提示推理 |
|
| SAM 3 框选 | 真实可用 | 工作区选择 SAM 3 后可使用框选工具;后端通过官方 `add_geometric_prompt()` 正框执行 SAM 3 几何提示推理 |
|
||||||
| 语义文本输入 | 部分可用 | 纯文本会以 `semantic` prompt 调用后端;选择 SAM 3 且独立 Python 3.12 环境、CUDA、官方包和本地 checkpoint 均满足时走 SAM 3 文本语义推理,否则状态接口会标明不可用;空文本、失败和 0 mask 返回会显示前端反馈 |
|
| 语义文本输入 | 部分可用 | 纯文本会以 `semantic` prompt 调用后端;选择 SAM 3 且独立 Python 3.12 环境、CUDA、官方包和本地 checkpoint 均满足时走 SAM 3 文本语义推理,否则状态接口会标明不可用;空文本、失败和 0 mask 返回会显示前端反馈 |
|
||||||
| 参数开关 | 真实可用 | `cropMode` 会随 `/api/ai/predict` 发送 `crop_to_prompt`,后端对点/框 prompt 裁剪推理区域并回映射 polygon;`autoDeleteBg` 会发送 `auto_filter_background` 和 `min_score`,后端过滤低分结果和覆盖负向点的结果 |
|
| 参数开关 | 真实可用 | `cropMode` 会随 `/api/ai/predict` 发送 `crop_to_prompt`,后端对点/框 prompt 裁剪推理区域并回映射 polygon;`autoDeleteBg` 会发送 `auto_filter_background` 和 `min_score`,后端过滤低分结果和覆盖负向点的结果 |
|
||||||
| 执行高精度语义分割 | 部分可用 | 使用当前项目帧调用 `/api/ai/predict`;没有当前帧时按钮禁用 |
|
| 执行高精度语义分割 | 真实可用 | 使用当前项目帧调用 `/api/ai/predict`;SAM 2 需要点提示且只采用最高分候选,SAM 3 需要文本语义提示;生成结果写入全局 masks 并自动选中,右侧分类树可立即换标签 |
|
||||||
|
| 推送至工作区编辑 | 真实可用 | 切回工作区并把工具切到“调整多边形”,保留 AI 页选中的 mask,便于继续调轮廓和归档 |
|
||||||
| 上传替换底图 | Mock / UI-only | 按钮无事件 |
|
| 上传替换底图 | Mock / UI-only | 按钮无事件 |
|
||||||
| 撤销/重做 | 真实可用 | 绑定全局 mask 历史栈 |
|
| 撤销/重做 | 真实可用 | 绑定全局 mask 历史栈 |
|
||||||
| 清空全体锚点 | 部分可用 | 清空前端 points 和 masks |
|
| 清空全体锚点 | 部分可用 | 清空前端 points 和 masks |
|
||||||
@@ -153,6 +156,6 @@
|
|||||||
|
|
||||||
## 总体结论
|
## 总体结论
|
||||||
|
|
||||||
当前前端真实可用的主链路是:登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、拆帧、浏览帧、播放帧、工作区手工绘制、点/框 AI 推理、视频片段传播、GT mask 导入、标注保存/回显、COCO 导出、PNG mask ZIP 导出、模板 CRUD。
|
当前前端真实可用的主链路是:登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、显式生成帧、浏览帧、播放帧、工作区手工绘制、点/框 AI 推理、视频片段传播、GT mask 导入、标注保存/回显、COCO 导出、PNG mask ZIP 导出、模板 CRUD。
|
||||||
|
|
||||||
当前最主要的 Mock 或未打通链路是:polygon 插点/边编辑增强、真正的文本语义分割、骨架/HDBSCAN 级别的 mask 降维增强、任务历史筛选、项目更多菜单和若干检查面板指标。
|
当前最主要的 Mock 或未打通链路是:polygon 插点/边编辑增强、真正的文本语义分割、骨架/HDBSCAN 级别的 mask 降维增强、任务历史筛选、项目更多菜单和若干检查面板指标。
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ Authorization: Bearer <token>
|
|||||||
| `deleteTemplate(id)` | `DELETE /api/templates/{id}` | 对齐 | 模板编辑页使用 |
|
| `deleteTemplate(id)` | `DELETE /api/templates/{id}` | 对齐 | 模板编辑页使用 |
|
||||||
| `uploadMedia(file, projectId)` | `POST /api/media/upload` | 对齐 | multipart form-data |
|
| `uploadMedia(file, projectId)` | `POST /api/media/upload` | 对齐 | multipart form-data |
|
||||||
| `uploadDicomBatch(files, projectId)` | `POST /api/media/upload/dicom` | 对齐 | multipart form-data |
|
| `uploadDicomBatch(files, projectId)` | `POST /api/media/upload/dicom` | 对齐 | multipart form-data |
|
||||||
| `parseMedia(projectId, options?)` | `POST /api/media/parse?project_id=...` | 对齐 | 创建异步拆帧任务并返回 task;支持 `parse_fps`、`max_frames`、`target_width` |
|
| `parseMedia(projectId, options?)` | `POST /api/media/parse?project_id=...` | 对齐 | 创建异步拆帧任务并返回 task;由项目库“生成帧”显式调用,支持 `parse_fps`、`max_frames`、`target_width` |
|
||||||
| `getTask(taskId)` | `GET /api/tasks/{task_id}` | 对齐 | 查询异步任务状态 |
|
| `getTask(taskId)` | `GET /api/tasks/{task_id}` | 对齐 | 查询异步任务状态 |
|
||||||
| `cancelTask(taskId)` | `POST /api/tasks/{task_id}/cancel` | 对齐 | 取消 queued/running 任务,后端写 cancelled 并尝试 revoke Celery |
|
| `cancelTask(taskId)` | `POST /api/tasks/{task_id}/cancel` | 对齐 | 取消 queued/running 任务,后端写 cancelled 并尝试 revoke Celery |
|
||||||
| `retryTask(taskId)` | `POST /api/tasks/{task_id}/retry` | 对齐 | 对 failed/cancelled 任务创建新的 queued 重试任务 |
|
| `retryTask(taskId)` | `POST /api/tasks/{task_id}/retry` | 对齐 | 对 failed/cancelled 任务创建新的 queued 重试任务 |
|
||||||
@@ -91,6 +91,21 @@ Authorization: Bearer <token>
|
|||||||
| GET | `/health` | 健康检查 |
|
| GET | `/health` | 健康检查 |
|
||||||
| WS | `/ws/progress` | WebSocket 进度通道,未出现在 OpenAPI paths 中 |
|
| WS | `/ws/progress` | WebSocket 进度通道,未出现在 OpenAPI paths 中 |
|
||||||
|
|
||||||
|
### WebSocket 进度通道
|
||||||
|
|
||||||
|
`/ws/progress` 用于 Dashboard 实时接收后台任务状态。前端连接成功后会定时发送 `ping` 作为心跳;后端收到任意文本心跳后返回:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "status",
|
||||||
|
"status": "connected",
|
||||||
|
"message": "Progress stream active",
|
||||||
|
"timestamp": "2026-05-01T00:00:00+00:00"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
后台任务进度由 Celery worker 写入 Redis `seg:progress` 频道,再由 FastAPI 转发到当前活跃 WebSocket 连接。Dashboard 的“WebSocket 已连接/断开”状态来自浏览器 WebSocket 的 `onopen/onclose/onerror`,不再依赖是否刚好收到任务进度事件。
|
||||||
|
|
||||||
## 关键请求体
|
## 关键请求体
|
||||||
|
|
||||||
### 登录
|
### 登录
|
||||||
@@ -172,7 +187,11 @@ POST /api/media/parse?project_id=1&parse_fps=15&max_frames=120&target_width=960
|
|||||||
- `point`
|
- `point`
|
||||||
- `box`
|
- `box`
|
||||||
- `interactive`,用于 SAM 2 交互式细化,`prompt_data` 同时携带 `box`、累计 `points` 和 `labels`。
|
- `interactive`,用于 SAM 2 交互式细化,`prompt_data` 同时携带 `box`、累计 `points` 和 `labels`。
|
||||||
- `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理;选择 `sam2` 时仍回退到 auto segmentation。SAM 3 真实可用性由 `/api/ai/models/status` 中的外部环境和本地 checkpoint 状态决定。
|
- `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理。前端 AI 页面不会再用 SAM 2 发送纯文本 semantic;SAM 2 的交互入口应使用点/框提示。SAM 3 真实可用性由 `/api/ai/models/status` 中的外部环境和本地 checkpoint 状态决定。
|
||||||
|
|
||||||
|
SAM 2 点提示和 auto fallback 当前只采用最高分候选 mask,避免同一提示下多个备选 mask 被前端叠加显示。
|
||||||
|
|
||||||
|
工作区 SAM 2 请求包含反向点时,`CanvasArea` 会发送 `options.auto_filter_background=true` 和 `options.min_score=0.05`;如果负向点过滤后没有可用 polygon,前端会移除当前旧候选 mask 并要求重新框选或添加正向点。
|
||||||
|
|
||||||
选择 `sam3` 且发送 `box` 时,前端仍传 normalized `[x1, y1, x2, y2]`,后端适配层会转换成官方几何 prompt 的 `[center_x, center_y, width, height]` 正框;当前 SAM 3 不接正/反点修正。
|
选择 `sam3` 且发送 `box` 时,前端仍传 normalized `[x1, y1, x2, y2]`,后端适配层会转换成官方几何 prompt 的 `[center_x, center_y, width, height]` 正框;当前 SAM 3 不接正/反点修正。
|
||||||
|
|
||||||
@@ -180,7 +199,7 @@ POST /api/media/parse?project_id=1&parse_fps=15&max_frames=120&target_width=960
|
|||||||
|
|
||||||
- `crop_to_prompt`:对 point/box/interactive prompt 按锚点或框附近区域裁剪后推理,再把 polygon 回映射到原图坐标。
|
- `crop_to_prompt`:对 point/box/interactive prompt 按锚点或框附近区域裁剪后推理,再把 polygon 回映射到原图坐标。
|
||||||
- `auto_filter_background`:过滤低分结果,并移除包含负向点的 polygon。
|
- `auto_filter_background`:过滤低分结果,并移除包含负向点的 polygon。
|
||||||
- `min_score`:配合 `auto_filter_background` 使用的最低置信度阈值。
|
- `min_score`:配合 `auto_filter_background` 使用的最低置信度阈值;对 SAM 3 semantic 请求也会作为 external worker 的 `confidence_threshold` 传入,避免本地 checkpoint 在默认高阈值下返回 0 个 mask。
|
||||||
|
|
||||||
后端响应:
|
后端响应:
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,8 @@
|
|||||||
- 前端展示项目库,并从 `GET /api/projects` 获取项目列表。
|
- 前端展示项目库,并从 `GET /api/projects` 获取项目列表。
|
||||||
- 用户可以新建项目,前端调用 `POST /api/projects`。
|
- 用户可以新建项目,前端调用 `POST /api/projects`。
|
||||||
- 用户可以选择项目,进入工作区。
|
- 用户可以选择项目,进入工作区。
|
||||||
- 用户可以导入视频文件,前端创建项目、上传文件、触发拆帧、刷新项目列表。
|
- 用户可以导入视频文件,前端创建项目、上传文件并刷新项目列表;导入视频不自动拆帧。
|
||||||
|
- 用户可以对已导入且尚未生成帧的视频项目点击“生成帧”,在弹窗中选择目标 FPS 后创建拆帧任务。
|
||||||
- 用户可以导入 DICOM 序列,前端上传 DICOM、触发拆帧、刷新项目列表。
|
- 用户可以导入 DICOM 序列,前端上传 DICOM、触发拆帧、刷新项目列表。
|
||||||
- 后端支持项目创建、列表、详情、局部更新和删除。
|
- 后端支持项目创建、列表、详情、局部更新和删除。
|
||||||
- 后端支持项目帧创建、列表和单帧查询。
|
- 后端支持项目帧创建、列表和单帧查询。
|
||||||
@@ -42,7 +43,7 @@
|
|||||||
## R4 工作区与帧浏览
|
## R4 工作区与帧浏览
|
||||||
|
|
||||||
- 工作区根据当前项目加载帧列表。
|
- 工作区根据当前项目加载帧列表。
|
||||||
- 若项目有媒体但无帧,工作区会尝试触发拆帧后重新加载。
|
- 若项目有媒体但无帧,工作区只提示需要先在项目库生成帧,不再自动触发拆帧。
|
||||||
- Canvas 显示当前帧图片。
|
- Canvas 显示当前帧图片。
|
||||||
- Canvas 支持滚轮缩放、移动工具拖拽、鼠标坐标显示。
|
- Canvas 支持滚轮缩放、移动工具拖拽、鼠标坐标显示。
|
||||||
- 时间轴支持缩略图点击切帧、range 拖动切帧、键盘左右方向键切帧、播放/暂停顺序推进帧。
|
- 时间轴支持缩略图点击切帧、range 拖动切帧、键盘左右方向键切帧、播放/暂停顺序推进帧。
|
||||||
@@ -57,7 +58,9 @@
|
|||||||
- 多边形、矩形、圆、点、线工具会在 Canvas 上生成可保存的 polygon mask。
|
- 多边形、矩形、圆、点、线工具会在 Canvas 上生成可保存的 polygon mask。
|
||||||
- 多边形通过点击取点并按 Enter 完成,也支持三点后点击首节点闭合;矩形、圆、线通过拖拽生成;点工具生成小点区域。
|
- 多边形通过点击取点并按 Enter 完成,也支持三点后点击首节点闭合;矩形、圆、线通过拖拽生成;点工具生成小点区域。
|
||||||
- 绘制工具点击已有 mask 时应继续执行当前绘制动作,不应被 mask 选择逻辑吞掉。
|
- 绘制工具点击已有 mask 时应继续执行当前绘制动作,不应被 mask 选择逻辑吞掉。
|
||||||
- Canvas 支持点击 mask 进入 polygon 顶点编辑态;拖动顶点会更新 mask 几何并把已保存 mask 标记为 dirty。
|
- 工具栏提供“调整多边形”工具,用户可以点击 mask 进入 polygon 顶点编辑态;拖动顶点会更新 mask 几何并把已保存 mask 标记为 dirty。
|
||||||
|
- 顶点编辑态显示边中点插入手柄;点击边中点会在该边中间新增顶点。
|
||||||
|
- “调整多边形”工具下双击 polygon 边界时,会在最接近的线段上按双击位置新增顶点。
|
||||||
- 顶点编辑态下选中顶点后可用 Delete/Backspace 删除顶点,但不会让 polygon 少于三点。
|
- 顶点编辑态下选中顶点后可用 Delete/Backspace 删除顶点,但不会让 polygon 少于三点。
|
||||||
- 选中整个 mask 且未选中具体顶点时,Delete/Backspace 删除该 mask;已保存 mask 同步调用后端删除接口。
|
- 选中整个 mask 且未选中具体顶点时,Delete/Backspace 删除该 mask;已保存 mask 同步调用后端删除接口。
|
||||||
- 撤销、重做绑定全局 `maskHistory/maskFuture`,支持工具栏按钮、AI 页按钮和 Canvas 快捷键。
|
- 撤销、重做绑定全局 `maskHistory/maskFuture`,支持工具栏按钮、AI 页按钮和 Canvas 快捷键。
|
||||||
@@ -75,14 +78,19 @@
|
|||||||
- 点提示传 `{ points, labels }`,正向点 label 为 1,反向点 label 为 0。
|
- 点提示传 `{ points, labels }`,正向点 label 为 1,反向点 label 为 0。
|
||||||
- 框选提示传归一化 `[x1, y1, x2, y2]`。
|
- 框选提示传归一化 `[x1, y1, x2, y2]`。
|
||||||
- 工作区 SAM 2 框选会建立一个候选 mask;后续正向点/反向点会携带原始框和累计点,以 `interactive` prompt 细化并替换同一个候选 mask。
|
- 工作区 SAM 2 框选会建立一个候选 mask;后续正向点/反向点会携带原始框和累计点,以 `interactive` prompt 细化并替换同一个候选 mask。
|
||||||
- 语义文本提示传 `semantic`;选择 `sam3` 且独立 Python 3.12 环境、CUDA、官方包和本地 checkpoint 均满足时走 SAM 3 文本语义推理,选择 `sam2` 时回退到自动分割。
|
- 工作区 SAM 2 一旦包含反向点,会随请求启用 `auto_filter_background` 和 `min_score=0.05`;若后端判定反向点排除了当前候选区域并返回空结果,前端会移除旧候选 mask,避免继续显示已被否定的区域。
|
||||||
|
- SAM 2 不支持文本语义提示;AI 页面在 SAM 2 下输入纯文本时会提示用户改用点提示或切换 SAM 3,不再回退到自动分割。
|
||||||
|
- SAM 2 点提示和 auto fallback 默认只采用一个最高分候选 mask,避免多个候选 mask 作为同一结果重叠显示。
|
||||||
|
- AI 页面生成的 SAM 2/SAM 3 mask 会写入全局 `masks`,自动同步到当前项目帧,并写入全局 `selectedMaskIds`;右侧语义分类树可以直接给新生成 mask 换标签。
|
||||||
|
- AI 页面“推送至工作区编辑”会切回工作区并把工具切到“调整多边形”,保留当前选中的 AI mask 以便继续编辑轮廓和归档保存。
|
||||||
|
- 语义文本提示传 `semantic`;选择 `sam3` 且独立 Python 3.12 环境、CUDA、官方包和本地 checkpoint 均满足时走 SAM 3 文本语义推理。
|
||||||
- SAM 3 支持工作区框选提示;后端把 normalized `[x1, y1, x2, y2]` 转成官方 `add_geometric_prompt()` 需要的 `[center_x, center_y, width, height]` 正框。
|
- SAM 3 支持工作区框选提示;后端把 normalized `[x1, y1, x2, y2]` 转成官方 `add_geometric_prompt()` 需要的 `[center_x, center_y, width, height]` 正框。
|
||||||
- 当前 SAM 3 前端路径不支持正/反点修正;在工作区用 SAM 3 进行点交互时,前端会提示切回 SAM 2。
|
- 当前 SAM 3 前端路径不支持正/反点修正;在工作区用 SAM 3 进行点交互时,前端会提示切回 SAM 2。
|
||||||
- 工作区“传播片段”会把当前选中区域或当前帧第一个区域作为 seed,调用 `POST /api/ai/propagate`,默认从当前帧向后传播 30 帧并保存结果标注。
|
- 工作区“传播片段”会把当前选中区域或当前帧第一个区域作为 seed,调用 `POST /api/ai/propagate`,默认从当前帧向后传播 30 帧并保存结果标注。
|
||||||
- `POST /api/ai/propagate` 支持 `model=sam2` 或 `model=sam3`;SAM 2 使用官方 `SAM2VideoPredictor.add_new_mask()` 和 `propagate_in_video()`,SAM 3 通过独立 Python 3.12 helper 调用官方 `build_sam3_video_predictor()` video tracker。
|
- `POST /api/ai/propagate` 支持 `model=sam2` 或 `model=sam3`;SAM 2 使用官方 `SAM2VideoPredictor.add_new_mask()` 和 `propagate_in_video()`,SAM 3 通过独立 Python 3.12 helper 调用官方 `build_sam3_video_predictor()` video tracker。
|
||||||
- 传播结果会写入后续帧 `annotations`,`mask_data.source` 分别标记为 `sam2_propagation` 或 `sam3_propagation`,并保留 label、color 和 class 元数据。
|
- 传播结果会写入后续帧 `annotations`,`mask_data.source` 分别标记为 `sam2_propagation` 或 `sam3_propagation`,并保留 label、color 和 class 元数据。
|
||||||
- AI 页面会对 SAM 3 空文本、推理失败和返回 0 个 mask 的情况显示明确反馈。
|
- AI 页面会对 SAM 3 空文本、推理失败和返回 0 个 mask 的情况显示明确反馈。
|
||||||
- AI 参数支持 `crop_to_prompt`、`auto_filter_background` 和 `min_score`;点/框 prompt 可以裁剪局部区域推理并回映射结果,背景过滤会移除低分结果和包含负向点的 polygon。
|
- AI 参数支持 `crop_to_prompt`、`auto_filter_background` 和 `min_score`;点/框 prompt 可以裁剪局部区域推理并回映射结果,背景过滤会移除低分结果和包含负向点的 polygon;SAM 3 semantic 会用 `min_score` 控制 external worker 的置信度阈值。
|
||||||
- 后端返回 `polygons` 和 `scores`。
|
- 后端返回 `polygons` 和 `scores`。
|
||||||
- 前端把后端 `polygons` 转成 Konva `pathData`、`segmentation`、`bbox`、`area`。
|
- 前端把后端 `polygons` 转成 Konva `pathData`、`segmentation`、`bbox`、`area`。
|
||||||
- AI 推理结果先存放在前端 store 的 `masks` 中,点击“结构化归档保存”后持久化到后端标注表。
|
- AI 推理结果先存放在前端 store 的 `masks` 中,点击“结构化归档保存”后持久化到后端标注表。
|
||||||
|
|||||||
@@ -22,11 +22,11 @@
|
|||||||
| 全局状态 | `src/store/useStore.ts` | Zustand store,保存项目、帧、模板、mask、当前选中 mask ids、工具状态和 mask 撤销/重做历史栈 |
|
| 全局状态 | `src/store/useStore.ts` | Zustand store,保存项目、帧、模板、mask、当前选中 mask ids、工具状态和 mask 撤销/重做历史栈 |
|
||||||
| API 封装 | `src/lib/api.ts` | Axios 客户端、字段映射、AI 响应转换 |
|
| API 封装 | `src/lib/api.ts` | Axios 客户端、字段映射、AI 响应转换 |
|
||||||
| 配置 | `src/lib/config.ts` | 推导 API 和 WebSocket 地址 |
|
| 配置 | `src/lib/config.ts` | 推导 API 和 WebSocket 地址 |
|
||||||
| WebSocket | `src/lib/websocket.ts` | 进度流连接、订阅和重连 |
|
| WebSocket | `src/lib/websocket.ts` | 进度流连接、订阅、连接状态通知、心跳和重连 |
|
||||||
| 模型状态 | `src/components/ModelStatusBadge.tsx` | 展示 GPU 与当前 SAM 模型真实可用状态 |
|
| 模型状态 | `src/components/ModelStatusBadge.tsx` | 展示 GPU 与当前 SAM 模型真实可用状态 |
|
||||||
| 登录页 | `src/components/Login.tsx` | 调用登录 API,写入 store |
|
| 登录页 | `src/components/Login.tsx` | 调用登录 API,写入 store |
|
||||||
| Dashboard | `src/components/Dashboard.tsx` | 展示统计、任务控制、失败详情和 WebSocket 进度消息 |
|
| Dashboard | `src/components/Dashboard.tsx` | 展示统计、任务控制、失败详情和 WebSocket 进度消息 |
|
||||||
| 项目库 | `src/components/ProjectLibrary.tsx` | 项目列表、新建、导入视频/DICOM |
|
| 项目库 | `src/components/ProjectLibrary.tsx` | 项目列表、新建、导入视频/DICOM、显式生成帧 |
|
||||||
| 工作区 | `src/components/VideoWorkspace.tsx` | 加载帧和模板,组织工具栏、Canvas、本体面板、时间轴 |
|
| 工作区 | `src/components/VideoWorkspace.tsx` | 加载帧和模板,组织工具栏、Canvas、本体面板、时间轴 |
|
||||||
| Canvas | `src/components/CanvasArea.tsx` | 显示帧、缩放平移、点/框提示、渲染 mask |
|
| Canvas | `src/components/CanvasArea.tsx` | 显示帧、缩放平移、点/框提示、渲染 mask |
|
||||||
| 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具、跳转 AI 页面、触发 mask 撤销/重做 |
|
| 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具、跳转 AI 页面、触发 mask 撤销/重做 |
|
||||||
@@ -76,15 +76,16 @@
|
|||||||
2. `login()` 调用 `POST /api/auth/login`。
|
2. `login()` 调用 `POST /api/auth/login`。
|
||||||
3. 成功后 store 写入 token,App 渲染主界面。
|
3. 成功后 store 写入 token,App 渲染主界面。
|
||||||
|
|
||||||
### 项目导入
|
### 项目导入与生成帧
|
||||||
|
|
||||||
1. `ProjectLibrary` 创建项目。
|
1. `ProjectLibrary` 创建项目。
|
||||||
2. 上传视频或 DICOM 到 `/api/media/upload` 或 `/api/media/upload/dicom`。
|
2. 导入视频时上传源视频到 `/api/media/upload` 并关联项目;该步骤不调用 `/api/media/parse`。
|
||||||
3. 调用 `/api/media/parse` 创建异步拆帧任务;可通过 `parse_fps`、`max_frames` 和 `target_width` 指定标准帧序列参数。
|
3. 用户在项目卡片点击“生成帧”,在弹窗中选择目标 FPS。
|
||||||
4. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,视频帧按 `frame_%06d.jpg` 从 `frame_000000.jpg` 连续命名,并按目标宽度缩放。
|
4. 前端调用 `/api/media/parse` 创建异步拆帧任务;可通过 `parse_fps`、`max_frames` 和 `target_width` 指定标准帧序列参数。
|
||||||
5. worker 写入 `frames.timestamp_ms` 和 `frames.source_frame_number`,并在任务 `result.frame_sequence` 中记录 FPS、帧数、时长、尺寸和对象存储前缀。
|
5. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,视频帧按 `frame_%06d.jpg` 从 `frame_000000.jpg` 连续命名,并按目标宽度缩放。
|
||||||
6. worker 持续更新 `processing_tasks`,并发布 Redis `seg:progress`。
|
6. worker 写入 `frames.timestamp_ms` 和 `frames.source_frame_number`,并在任务 `result.frame_sequence` 中记录 FPS、帧数、时长、尺寸和对象存储前缀。
|
||||||
7. 刷新项目列表。
|
7. worker 持续更新 `processing_tasks`,并发布 Redis `seg:progress`。
|
||||||
|
8. 刷新项目列表。
|
||||||
|
|
||||||
### 任务控制
|
### 任务控制
|
||||||
|
|
||||||
@@ -93,11 +94,12 @@
|
|||||||
3. worker 在下载、解析、上传、写帧等关键阶段刷新任务状态;如果发现 `cancelled`,停止后续写入并发布 cancelled 事件。
|
3. worker 在下载、解析、上传、写帧等关键阶段刷新任务状态;如果发现 `cancelled`,停止后续写入并发布 cancelled 事件。
|
||||||
4. 用户重试任务时,前端调用 `POST /api/tasks/{task_id}/retry`;后端基于原任务 `payload` 创建新任务,记录 `retry_of` 并重新投递 Celery。
|
4. 用户重试任务时,前端调用 `POST /api/tasks/{task_id}/retry`;后端基于原任务 `payload` 创建新任务,记录 `retry_of` 并重新投递 Celery。
|
||||||
5. 用户打开详情时,前端调用 `GET /api/tasks/{task_id}`,弹窗展示 error、payload、result、Celery ID 和时间。
|
5. 用户打开详情时,前端调用 `GET /api/tasks/{task_id}`,弹窗展示 error、payload、result、Celery ID 和时间。
|
||||||
|
6. Dashboard 通过 `/ws/progress` 接收 Redis `seg:progress` 转发事件;前端 WebSocket 客户端在 `onopen/onclose/onerror` 主动更新连接状态,并定时发送 `ping` 心跳,服务端返回 `status` 确认连接仍活跃。
|
||||||
|
|
||||||
### 工作区加载
|
### 工作区加载
|
||||||
|
|
||||||
1. `VideoWorkspace` 根据 `currentProject.id` 调用 `getProjectFrames()`。
|
1. `VideoWorkspace` 根据 `currentProject.id` 调用 `getProjectFrames()`。
|
||||||
2. 若无帧但项目有 `video_path`,触发 `parseMedia()`,通过 `getTask()` 轮询任务完成后重新取帧。
|
2. 若无帧但项目有 `video_path`,显示“尚未生成帧”的状态提示,不自动触发 `parseMedia()`。
|
||||||
3. 帧数据映射为 store `Frame[]`,包含 `timestampMs` 和 `sourceFrameNumber`,供时间轴和后续视频传播使用。
|
3. 帧数据映射为 store `Frame[]`,包含 `timestampMs` 和 `sourceFrameNumber`,供时间轴和后续视频传播使用。
|
||||||
4. 当前帧传入 `CanvasArea`。
|
4. 当前帧传入 `CanvasArea`。
|
||||||
|
|
||||||
@@ -107,13 +109,15 @@
|
|||||||
2. `CanvasArea` 读取当前帧 ID 和宽高。
|
2. `CanvasArea` 读取当前帧 ID 和宽高。
|
||||||
3. SAM 2 框选会创建一个候选 mask,并记录原始框;后续正向点/反向点会累计到同一候选上。
|
3. SAM 2 框选会创建一个候选 mask,并记录原始框;后续正向点/反向点会累计到同一候选上。
|
||||||
4. `predictMask()` 归一化坐标并携带当前 `model` 调用 `/api/ai/predict`;同时有框和点时发送 `interactive` prompt。
|
4. `predictMask()` 归一化坐标并携带当前 `model` 调用 `/api/ai/predict`;同时有框和点时发送 `interactive` prompt。
|
||||||
5. 后端加载帧图片并通过 SAM registry 分发到 SAM 2 或 SAM 3。
|
5. SAM 2 请求中只要存在反向点,`CanvasArea` 会额外发送 `options.auto_filter_background=true` 和 `options.min_score=0.05`,让后端移除低分结果和包含负向点的 polygon。
|
||||||
6. 前端把 `polygons` 转为 mask;交互式细化会替换同一个候选 mask,而不是新增多个 mask。
|
6. 后端加载帧图片并通过 SAM registry 分发到 SAM 2 或 SAM 3。
|
||||||
7. Canvas 按当前帧过滤并渲染 mask。
|
7. 前端把 `polygons` 转为 mask;交互式细化会替换同一个候选 mask,而不是新增多个 mask。
|
||||||
8. 新 mask 会带上当前选择的模板分类元数据,包括 `classId`、`className`、`classZIndex` 和保存状态 `draft`。
|
8. 若带反向点的 SAM 2 细化返回空结果,前端会删除当前旧候选 mask 并提示反向点已排除该区域。
|
||||||
9. 用户点击“结构化归档保存”后,前端将像素 `segmentation` 转成 normalized `mask_data.polygons`;未保存 mask 调用 `POST /api/ai/annotate`,dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`。
|
9. Canvas 按当前帧过滤并渲染 mask。
|
||||||
10. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。
|
10. 新 mask 会带上当前选择的模板分类元数据,包括 `classId`、`className`、`classZIndex` 和保存状态 `draft`。
|
||||||
11. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。
|
11. 用户点击“结构化归档保存”后,前端将像素 `segmentation` 转成 normalized `mask_data.polygons`;未保存 mask 调用 `POST /api/ai/annotate`,dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`。
|
||||||
|
12. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。
|
||||||
|
13. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。
|
||||||
|
|
||||||
### 视频片段传播
|
### 视频片段传播
|
||||||
|
|
||||||
@@ -131,19 +135,20 @@
|
|||||||
1. 用户在 `ToolsPalette` 选择多边形、矩形、圆、点或线工具。
|
1. 用户在 `ToolsPalette` 选择多边形、矩形、圆、点或线工具。
|
||||||
2. `CanvasArea` 将交互坐标转换成像素 polygon。
|
2. `CanvasArea` 将交互坐标转换成像素 polygon。
|
||||||
3. 多边形工具逐次记录节点,三点后点击首节点或按 Enter 时生成闭合 polygon。
|
3. 多边形工具逐次记录节点,三点后点击首节点或按 Enter 时生成闭合 polygon。
|
||||||
4. mask path 只在 `move`、`area_merge` 和 `area_remove` 工具下拦截点击;绘制和 AI prompt 工具点击已有 mask 时继续冒泡给 Stage。
|
4. mask path 只在 `move`、`edit_polygon`、`area_merge` 和 `area_remove` 工具下拦截点击;绘制和 AI prompt 工具点击已有 mask 时继续冒泡给 Stage。
|
||||||
5. 新 mask 写入 `pathData`、像素 `segmentation`、`bbox`、`area` 和当前模板分类元数据。
|
5. 新 mask 写入 `pathData`、像素 `segmentation`、`bbox`、`area` 和当前模板分类元数据。
|
||||||
6. `addMask()`、`setMasks()`、`updateMask()`、`clearMasks()` 会维护 `maskHistory/maskFuture`。
|
6. `addMask()`、`setMasks()`、`updateMask()`、`clearMasks()` 会维护 `maskHistory/maskFuture`。
|
||||||
7. 工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y 调用 `undoMasks()` / `redoMasks()`。
|
7. 工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y 调用 `undoMasks()` / `redoMasks()`。
|
||||||
|
|
||||||
### Polygon 逐点编辑
|
### Polygon 逐点编辑
|
||||||
|
|
||||||
1. 用户点击 Canvas 上的 mask path 后,`CanvasArea` 记录 `selectedMaskId` 并显示该 mask 第一条 polygon 的顶点控制点。
|
1. 用户选择“调整多边形”或“拖拽/选择”后点击 Canvas 上的 mask path,`CanvasArea` 记录 `selectedMaskId` 并显示该 mask 第一条 polygon 的顶点控制点和边中点插入手柄。
|
||||||
2. 拖动顶点后,前端重算 `pathData`、像素 `segmentation`、`bbox`、`area`。
|
2. 拖动顶点后,前端重算 `pathData`、像素 `segmentation`、`bbox`、`area`。
|
||||||
3. 如果 mask 已有 `annotationId`,编辑会把 `saveStatus` 标成 `dirty` 且 `saved=false`。
|
3. 点击边中点手柄会在该边中点插入新顶点;在“调整多边形”工具下双击 polygon path 会在最接近的线段上按双击位置插入新顶点。
|
||||||
4. 归档保存时复用现有 `PATCH /api/ai/annotations/{annotation_id}` 链路,把更新后的 normalized polygon 写回后端。
|
4. 如果 mask 已有 `annotationId`,编辑会把 `saveStatus` 标成 `dirty` 且 `saved=false`。
|
||||||
5. 选中顶点后 Delete/Backspace 可删除顶点;前端保持 polygon 至少三点。
|
5. 归档保存时复用现有 `PATCH /api/ai/annotations/{annotation_id}` 链路,把更新后的 normalized polygon 写回后端。
|
||||||
6. 未选中具体顶点但选中了 mask 时,Delete/Backspace 从前端 store 删除该 mask;如果包含 `annotationId`,通过工作区回调调用后端删除接口。
|
6. 选中顶点后 Delete/Backspace 可删除顶点;前端保持 polygon 至少三点。
|
||||||
|
7. 未选中具体顶点但选中了 mask 时,Delete/Backspace 从前端 store 删除该 mask;如果包含 `annotationId`,通过工作区回调调用后端删除接口。
|
||||||
|
|
||||||
### 区域合并与去除
|
### 区域合并与去除
|
||||||
|
|
||||||
@@ -173,9 +178,11 @@
|
|||||||
4. 后端把 `classes`、`rules` 打包进 `mapping_rules`。
|
4. 后端把 `classes`、`rules` 打包进 `mapping_rules`。
|
||||||
5. 返回时再解包给前端。
|
5. 返回时再解包给前端。
|
||||||
6. `CanvasArea` 把当前选中的 mask id 同步到全局 `selectedMaskIds`;切换工具、切换帧或卸载 Canvas 时会清空选择。
|
6. `CanvasArea` 把当前选中的 mask id 同步到全局 `selectedMaskIds`;切换工具、切换帧或卸载 Canvas 时会清空选择。
|
||||||
7. `OntologyInspector` 可以选择具体分类;选择结果进入全局 store,供 `CanvasArea` 和 `AISegmentation` 新建/更新 mask 时使用。
|
7. `AISegmentation` 生成 mask 后会写入全局 `masks` 并把生成的 mask id 写入 `selectedMaskIds`;点击 AI 页预览 mask 也会更新 `selectedMaskIds`。
|
||||||
8. 如果 `selectedMaskIds` 中存在当前 store 的 mask,点击分类时会立即更新这些 mask 的 `templateId`、`classId`、`className`、`classZIndex`、`label` 和 `color`。
|
8. AI 页“推送至工作区编辑”会切换到工作区并把 `activeTool` 设为 `edit_polygon`;`CanvasArea` 初始读取全局 `selectedMaskIds`,让 AI 页选中的 mask 在工作区继续保持选中。
|
||||||
9. 已保存 mask 被重新分类后进入 `dirty` 且 `saved=false`,继续复用工作区归档保存的 PATCH 链路。
|
9. `OntologyInspector` 可以选择具体分类;选择结果进入全局 store,供 `CanvasArea` 和 `AISegmentation` 新建/更新 mask 时使用。
|
||||||
|
10. 如果 `selectedMaskIds` 中存在当前 store 的 mask,点击分类时会立即更新这些 mask 的 `templateId`、`classId`、`className`、`classZIndex`、`label` 和 `color`。
|
||||||
|
11. 已保存 mask 被重新分类后进入 `dirty` 且 `saved=false`,继续复用工作区归档保存的 PATCH 链路。
|
||||||
|
|
||||||
### 导出
|
### 导出
|
||||||
|
|
||||||
@@ -204,10 +211,12 @@
|
|||||||
- `parseMedia()` 使用 `POST /api/media/parse?project_id=...`,可选 `parse_fps`、`max_frames`、`target_width`,用于生成标准帧序列。
|
- `parseMedia()` 使用 `POST /api/media/parse?project_id=...`,可选 `parse_fps`、`max_frames`、`target_width`,用于生成标准帧序列。
|
||||||
- `getProjectFrames()` 返回帧图像 URL、宽高、`timestamp_ms` 和 `source_frame_number`。
|
- `getProjectFrames()` 返回帧图像 URL、宽高、`timestamp_ms` 和 `source_frame_number`。
|
||||||
- 后端 `/api/ai/predict` 支持 point、box、interactive、semantic 四种 prompt_type,并通过 `model` 选择 SAM 2 或 SAM 3。
|
- 后端 `/api/ai/predict` 支持 point、box、interactive、semantic 四种 prompt_type,并通过 `model` 选择 SAM 2 或 SAM 3。
|
||||||
|
- SAM 2 是点/框交互式分割模型,不做文本语义分割;AI 页面在 SAM 2 + 纯文本时直接提示用户改用点提示或切换 SAM 3。
|
||||||
|
- SAM 2 点提示和 auto fallback 只返回一个最高分候选,避免同一提示产生多个重叠候选 mask。
|
||||||
- 当前 SAM 3 暴露 semantic 文本语义推理和 box 几何提示;工作区 Canvas 的点交互会在选择 SAM 3 时显示提示,不再静默失败。
|
- 当前 SAM 3 暴露 semantic 文本语义推理和 box 几何提示;工作区 Canvas 的点交互会在选择 SAM 3 时显示提示,不再静默失败。
|
||||||
- SAM 3 box prompt 复用后端 `/api/ai/predict` 的 `box` prompt_type,输入仍是 normalized `[x1, y1, x2, y2]`,引擎适配层会转换为官方 `add_geometric_prompt()` 使用的 `[center_x, center_y, width, height]` 正框。
|
- SAM 3 box prompt 复用后端 `/api/ai/predict` 的 `box` prompt_type,输入仍是 normalized `[x1, y1, x2, y2]`,引擎适配层会转换为官方 `add_geometric_prompt()` 使用的 `[center_x, center_y, width, height]` 正框。
|
||||||
- AI 页面选择 SAM 3 时优先发送文本 semantic prompt,不会把正/反点误发送为 SAM 3 point prompt;空文本、后端错误和空结果都会显示反馈消息。
|
- AI 页面选择 SAM 3 时优先发送文本 semantic prompt,不会把正/反点误发送为 SAM 3 point prompt;空文本、后端错误和空结果都会显示反馈消息。
|
||||||
- 后端 `/api/ai/predict` 支持可选 `options`:`crop_to_prompt` 会对 point/box/interactive prompt 做局部裁剪推理并回映射 polygon,`auto_filter_background` 会按 `min_score` 和负向点过滤结果。
|
- 后端 `/api/ai/predict` 支持可选 `options`:`crop_to_prompt` 会对 point/box/interactive prompt 做局部裁剪推理并回映射 polygon,`auto_filter_background` 会按 `min_score` 和负向点过滤结果;SAM 3 semantic 会把正数 `min_score` 传给 external worker 作为 `confidence_threshold`。
|
||||||
- 后端 `/api/ai/propagate` 支持 SAM 2 mask seed 视频传播和 SAM 3 external video tracker;当前前端默认向后传播 30 帧并保存结果标注。
|
- 后端 `/api/ai/propagate` 支持 SAM 2 mask seed 视频传播和 SAM 3 external video tracker;当前前端默认向后传播 30 帧并保存结果标注。
|
||||||
- 后端 `/api/ai/models/status` 返回 GPU、SAM 2、SAM 3 的真实运行状态;SAM 3 状态包含外部 Python 环境与 checkpoint access 的可用性。
|
- 后端 `/api/ai/models/status` 返回 GPU、SAM 2、SAM 3 的真实运行状态;SAM 3 状态包含外部 Python 环境与 checkpoint access 的可用性。
|
||||||
- point prompt 支持旧数组形式和 `{ points, labels }` 对象形式。
|
- point prompt 支持旧数组形式和 `{ points, labels }` 对象形式。
|
||||||
|
|||||||
@@ -16,14 +16,14 @@
|
|||||||
|------|----------|--------|
|
|------|----------|--------|
|
||||||
| R1 登录与会话 | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | 成功登录、失败提示、后端 401 |
|
| R1 登录与会话 | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | 成功登录、失败提示、后端 401 |
|
||||||
| R2 项目管理 | `src/lib/api.test.ts`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、后端 CRUD、帧列表 |
|
| R2 项目管理 | `src/lib/api.test.ts`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、后端 CRUD、帧列表 |
|
||||||
| R3 媒体上传与拆帧 | `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 扩展名校验、自动建项目、关联项目、创建异步任务、标准帧序列参数、帧时间戳/源帧号、任务序列元数据、worker 注册帧、取消任务、重试任务、取消后 worker 停止 |
|
| R3 媒体上传与拆帧 | `src/components/ProjectLibrary.test.tsx`, `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 视频导入不自动拆帧、显式生成帧 FPS 选择、扩展名校验、自动建项目、关联项目、创建异步任务、标准帧序列参数、帧时间戳/源帧号、任务序列元数据、worker 注册帧、取消任务、重试任务、取消后 worker 停止 |
|
||||||
| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧触发解析、缩略图/range/左右方向键切帧、播放、按项目 FPS 显示当前/总时长 |
|
| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧项目不自动解析并提示生成帧、缩略图/range/左右方向键切帧、播放、按项目 FPS 显示当前/总时长 |
|
||||||
| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 工具切换、AI 跳转、矩形/圆/线/点/多边形手工 mask 绘制、点工具在已有 mask 上落点、多边形 Enter/首节点闭合、polygon 顶点拖动/删除、整块 mask 删除、区域合并/去除、内含去除 hole 渲染、合并模式隐藏编辑手柄、撤销/重做历史栈 |
|
| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 工具切换、调整多边形工具、AI 跳转、矩形/圆/线/点/多边形手工 mask 绘制、点工具在已有 mask 上落点、多边形 Enter/首节点闭合、polygon 顶点拖动/删除、边中点插点、双击边界按位置插点、整块 mask 删除、区域合并/去除、内含去除 hole 渲染、合并模式隐藏编辑手柄、撤销/重做历史栈 |
|
||||||
| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam3_engine.py` | 点/框/interactive/semantic 契约、SAM 2 框选后正负点细化同一候选 mask、SAM 2 视频传播、SAM 3 语义文本前端执行路径、SAM 3 工作区框选、SAM 3 video tracker 外部桥接、SAM 3 点交互不支持提示、空文本/空结果反馈、模型选择、GPU/SAM 状态、SAM 3 外部 worker 桥接、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry |
|
| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam2_engine.py`, `backend/tests/test_sam3_engine.py` | 点/框/interactive/semantic 契约、SAM 2 纯文本提示拦截、SAM 2 最高分候选去重、SAM 2 框选后正负点细化同一候选 mask、SAM 2 反向点启用背景过滤且空结果移除旧候选、AI 页生成 mask 自动选中并可通过分类树换标签、AI 页推送到工作区编辑保留选择、SAM 2 视频传播、SAM 3 语义文本前端执行路径、SAM 3 semantic 请求级阈值、SAM 3 worker 单 2D mask 转 polygon、SAM 3 工作区框选、SAM 3 video tracker 外部桥接、SAM 3 点交互不支持提示、空文本/空结果反馈、模型选择、GPU/SAM 状态、SAM 3 外部 worker 桥接、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry |
|
||||||
| R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、加载回显、更新 dirty 标注、清空删除已保存标注、GT mask 多类别导入、seed point 回显/归一化、项目不存在、帧不存在 |
|
| R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、加载回显、更新 dirty 标注、清空删除已保存标注、GT mask 多类别导入、seed point 回显/归一化、项目不存在、帧不存在 |
|
||||||
| R8 模板库 | `src/components/TemplateRegistry.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_templates.py` | 前端模板加载/新建/编辑/删除、JSON 分类导入、mapping_rules 解包/打包、后端模板 CRUD |
|
| R8 模板库 | `src/components/TemplateRegistry.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_templates.py` | 前端模板加载/新建/编辑/删除、JSON 分类导入、mapping_rules 解包/打包、后端模板 CRUD |
|
||||||
| R9 本体检查面板 | `src/components/OntologyInspector.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 模板选择、分类展示、具体分类选择、Canvas 选区同步、点击分类给已选 mask 换标签、自定义分类本地添加 |
|
| R9 本体检查面板 | `src/components/OntologyInspector.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 模板选择、分类展示、具体分类选择、Canvas 选区同步、点击分类给已选 mask 换标签、自定义分类本地添加 |
|
||||||
| R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py`, `backend/tests/test_tasks.py` | 后端概览接口、任务表驱动队列、任务取消/重试/详情、cancelled 事件、Redis 进度事件 payload/发布、地址推导、消息订阅、队列更新、heartbeat |
|
| R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py`, `backend/tests/test_tasks.py` | 后端概览接口、任务表驱动队列、任务取消/重试/详情、cancelled 事件、Redis 进度事件 payload/发布、地址推导、消息订阅、连接状态回调、队列更新、heartbeat |
|
||||||
| R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO/PNG 按钮下载、导出前自动保存、导出路径、JSON 结构、mask ZIP、zIndex 语义融合 |
|
| R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO/PNG 按钮下载、导出前自动保存、导出路径、JSON 结构、mask ZIP、zIndex 语义融合 |
|
||||||
| R12 配置 | `src/lib/config.test.ts` | env 优先、hostname 推导、WS 推导 |
|
| R12 配置 | `src/lib/config.test.ts` | env 优先、hostname 推导、WS 推导 |
|
||||||
| R13 文档与测试 | `doc/09-test-plan.md` | 测试覆盖矩阵 |
|
| R13 文档与测试 | `doc/09-test-plan.md` | 测试覆盖矩阵 |
|
||||||
@@ -34,15 +34,15 @@
|
|||||||
|------|--------|----------|----------|
|
|------|--------|----------|----------|
|
||||||
| R1 | 登录页、默认开发凭证、token 写入、失败提示、后端 401 | `Login.test.tsx`, `test_auth.py` | 已覆盖 |
|
| R1 | 登录页、默认开发凭证、token 写入、失败提示、后端 401 | `Login.test.tsx`, `test_auth.py` | 已覆盖 |
|
||||||
| R2 | 项目列表/创建/选择、视频导入、DICOM 导入、后端项目和帧 CRUD | `ProjectLibrary.test.tsx`, `api.test.ts`, `test_projects.py` | 已覆盖 |
|
| R2 | 项目列表/创建/选择、视频导入、DICOM 导入、后端项目和帧 CRUD | `ProjectLibrary.test.tsx`, `api.test.ts`, `test_projects.py` | 已覆盖 |
|
||||||
| R3 | 文件类型校验、自动/指定项目上传、视频/DICOM 拆帧任务、`parse_fps/max_frames/target_width`、标准帧序列 metadata、任务查询、取消、重试、worker 取消停止 | `test_media.py`, `test_tasks.py` | 已覆盖 |
|
| R3 | 文件类型校验、自动/指定项目上传、视频导入与生成帧分离、显式 FPS 生成帧、视频/DICOM 拆帧任务、`parse_fps/max_frames/target_width`、标准帧序列 metadata、任务查询、取消、重试、worker 取消停止 | `ProjectLibrary.test.tsx`, `test_media.py`, `test_tasks.py` | 已覆盖 |
|
||||||
| R4 | 工作区加载帧、无帧自动解析、Canvas 底图、缩略图/range/左右方向键切帧、播放、按 FPS 显示时间 | `VideoWorkspace.test.tsx`, `FrameTimeline.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 |
|
| R4 | 工作区加载帧、无帧项目不自动解析、Canvas 底图、缩略图/range/左右方向键切帧、播放、按 FPS 显示时间 | `VideoWorkspace.test.tsx`, `FrameTimeline.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 |
|
||||||
| R5 | 工具切换、AI 跳转、矩形/圆/线/点/多边形绘制、已有 mask 上继续绘制 | `ToolsPalette.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 |
|
| R5 | 工具切换、调整多边形入口、AI 跳转、矩形/圆/线/点/多边形绘制、已有 mask 上继续绘制 | `ToolsPalette.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 |
|
||||||
| R5 | 顶点编辑、顶点删除、整块删除、撤销/重做、区域合并、区域去除、hole even-odd 渲染 | `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 |
|
| R5 | 顶点编辑、边中点插点、双击边界按位置插点、顶点删除、整块删除、撤销/重做、区域合并、区域去除、hole even-odd 渲染 | `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 |
|
||||||
| R6 | SAM 2 点/框/interactive、SAM 2 视频传播、SAM 3 semantic、SAM 3 box、SAM 3 video tracker、SAM 3 不支持点交互时的前端反馈、模型选择、GPU/模型状态、参数 options、polygons 转 mask | `api.test.ts`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx`, `VideoWorkspace.test.tsx`, `ModelStatusBadge.test.tsx`, `test_ai.py`, `test_sam3_engine.py` | 已覆盖 |
|
| R6 | SAM 2 点/框/interactive、SAM 2 纯文本提示拦截、SAM 2 最高分候选去重、AI 页生成 mask 自动选中并可换标签、AI 页推送到工作区编辑保留选择、SAM 2 视频传播、SAM 3 semantic、SAM 3 semantic 请求级阈值、SAM 3 worker 单 2D mask 转 polygon、SAM 3 box、SAM 3 video tracker、SAM 3 不支持点交互时的前端反馈、模型选择、GPU/模型状态、参数 options、polygons 转 mask | `api.test.ts`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx`, `VideoWorkspace.test.tsx`, `ModelStatusBadge.test.tsx`, `test_ai.py`, `test_sam2_engine.py`, `test_sam3_engine.py` | 已覆盖 |
|
||||||
| R7 | 保存、查询、更新、删除标注、工作区回显、清空已保存标注、GT mask 导入和 seed point 回写 | `VideoWorkspace.test.tsx`, `CanvasArea.test.tsx`, `api.test.ts`, `test_ai.py` | 已覆盖 |
|
| R7 | 保存、查询、更新、删除标注、工作区回显、清空已保存标注、GT mask 导入和 seed point 回写 | `VideoWorkspace.test.tsx`, `CanvasArea.test.tsx`, `api.test.ts`, `test_ai.py` | 已覆盖 |
|
||||||
| R8 | 模板加载、新建、编辑、删除、JSON 分类导入、mapping_rules 映射、后端 CRUD | `TemplateRegistry.test.tsx`, `api.test.ts`, `test_templates.py` | 已覆盖 |
|
| R8 | 模板加载、新建、编辑、删除、JSON 分类导入、mapping_rules 映射、后端 CRUD | `TemplateRegistry.test.tsx`, `api.test.ts`, `test_templates.py` | 已覆盖 |
|
||||||
| R9 | 模板选择、分类展示、分类选择、已选 mask 换标签、自定义本地分类、占位状态 | `OntologyInspector.test.tsx`, `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 |
|
| R9 | 模板选择、分类展示、分类选择、已选 mask 换标签、自定义本地分类、占位状态 | `OntologyInspector.test.tsx`, `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 |
|
||||||
| R10 | Dashboard 概览、队列、活动日志、WebSocket progress/complete/error/status/cancelled、取消/重试/详情、heartbeat | `Dashboard.test.tsx`, `websocket.test.ts`, `test_dashboard.py`, `test_main.py`, `test_progress_events.py`, `test_tasks.py` | 已覆盖 |
|
| R10 | Dashboard 概览、队列、活动日志、WebSocket progress/complete/error/status/cancelled、取消/重试/详情、连接状态回调、heartbeat | `Dashboard.test.tsx`, `websocket.test.ts`, `test_dashboard.py`, `test_main.py`, `test_progress_events.py`, `test_tasks.py` | 已覆盖 |
|
||||||
| R11 | COCO/PNG ZIP 导出、导出前保存、路径和 JSON/ZIP 结构、zIndex 融合 | `VideoWorkspace.test.tsx`, `api.test.ts`, `test_export.py` | 已覆盖 |
|
| R11 | COCO/PNG ZIP 导出、导出前保存、路径和 JSON/ZIP 结构、zIndex 融合 | `VideoWorkspace.test.tsx`, `api.test.ts`, `test_export.py` | 已覆盖 |
|
||||||
| R12 | API/WS 地址 env 优先和 hostname 推导 | `config.test.ts` | 已覆盖 |
|
| R12 | API/WS 地址 env 优先和 hostname 推导 | `config.test.ts` | 已覆盖 |
|
||||||
| R13 | 文档测试矩阵与功能点追踪 | `doc/09-test-plan.md` | 已覆盖 |
|
| R13 | 文档测试矩阵与功能点追踪 | `doc/09-test-plan.md` | 已覆盖 |
|
||||||
@@ -51,6 +51,8 @@
|
|||||||
|
|
||||||
- R5:补充 `CanvasArea.test.tsx` 中圆形和线段手工绘制测试,明确验证 metadata、segmentation、bbox/area 和草稿状态。
|
- R5:补充 `CanvasArea.test.tsx` 中圆形和线段手工绘制测试,明确验证 metadata、segmentation、bbox/area 和草稿状态。
|
||||||
- R6:补充 `AISegmentation.test.tsx` 中 SAM 3 semantic 文本推理测试,验证前端传参和返回 mask 绑定当前语义类别。
|
- R6:补充 `AISegmentation.test.tsx` 中 SAM 3 semantic 文本推理测试,验证前端传参和返回 mask 绑定当前语义类别。
|
||||||
|
- R6:补充 SAM 2 纯文本提示拦截、SAM 2 多候选只保留最高分、SAM 2 engine 单候选请求测试,避免多个重叠候选 mask 被同时叠加。
|
||||||
|
- R6:补充 Canvas 工作区 SAM 2 反向点背景过滤测试,覆盖请求 options 和过滤为空时清除旧候选 mask。
|
||||||
- R6:补充 SAM 3 空文本、空结果和工作区点交互不支持提示测试,避免前端静默失败。
|
- R6:补充 SAM 3 空文本、空结果和工作区点交互不支持提示测试,避免前端静默失败。
|
||||||
- R6:补充 SAM 3 工作区 box prompt 测试和外部 worker box prompt 测试,验证官方 `add_geometric_prompt()` 正框链路。
|
- R6:补充 SAM 3 工作区 box prompt 测试和外部 worker box prompt 测试,验证官方 `add_geometric_prompt()` 正框链路。
|
||||||
- R6:补充 `POST /api/ai/propagate` 后端测试,验证 seed mask 传播结果会保存为后续帧标注并保留 class 元数据。
|
- R6:补充 `POST /api/ai/propagate` 后端测试,验证 seed mask 传播结果会保存为后续帧标注并保留 class 元数据。
|
||||||
|
|||||||
@@ -63,6 +63,129 @@ describe('AISegmentation', () => {
|
|||||||
}));
|
}));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('does not run SAM2 text-only prompts as semantic segmentation', async () => {
|
||||||
|
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||||
|
|
||||||
|
fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), {
|
||||||
|
target: { value: '胆囊' },
|
||||||
|
});
|
||||||
|
fireEvent.click(await screen.findByText('执行高精度语义分割'));
|
||||||
|
|
||||||
|
expect(apiMock.predictMask).not.toHaveBeenCalled();
|
||||||
|
expect(await screen.findByText('SAM2 不支持文本语义提示;请先放置正/反向点,或切换到 SAM3 使用文本语义。')).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('keeps only the best SAM2 candidate when the backend returns overlapping alternatives', async () => {
|
||||||
|
apiMock.predictMask.mockResolvedValueOnce({
|
||||||
|
masks: [
|
||||||
|
{
|
||||||
|
id: 'sam2-best',
|
||||||
|
pathData: 'M 0 0 L 10 0 L 10 10 Z',
|
||||||
|
label: 'AI Mask',
|
||||||
|
color: '#06b6d4',
|
||||||
|
segmentation: [[0, 0, 10, 0, 10, 10]],
|
||||||
|
bbox: [0, 0, 10, 10],
|
||||||
|
area: 100,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'sam2-alt',
|
||||||
|
pathData: 'M 1 1 L 11 1 L 11 11 Z',
|
||||||
|
label: 'AI Mask',
|
||||||
|
color: '#06b6d4',
|
||||||
|
segmentation: [[1, 1, 11, 1, 11, 11]],
|
||||||
|
bbox: [1, 1, 10, 10],
|
||||||
|
area: 100,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||||
|
fireEvent.click(screen.getByText('正向选点'));
|
||||||
|
fireEvent.click(screen.getByTestId('konva-stage'));
|
||||||
|
fireEvent.click(await screen.findByText('执行高精度语义分割'));
|
||||||
|
|
||||||
|
await waitFor(() => expect(useStore.getState().masks).toHaveLength(1));
|
||||||
|
expect(useStore.getState().masks[0].id).toBe('sam2-best');
|
||||||
|
expect(useStore.getState().selectedMaskIds).toEqual(['sam2-best']);
|
||||||
|
expect(await screen.findByText('SAM2 返回 2 个候选,已采用最高分区域。')).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('lets a SAM2 result be selected and relabeled from the ontology panel', async () => {
|
||||||
|
useStore.setState({
|
||||||
|
templates: [
|
||||||
|
{
|
||||||
|
id: 'template-1',
|
||||||
|
name: '腹腔镜模板',
|
||||||
|
classes: [
|
||||||
|
{ id: 'class-1', name: '胆囊', color: '#ff0000', zIndex: 30 },
|
||||||
|
{ id: 'class-2', name: '肝脏', color: '#00ff00', zIndex: 20 },
|
||||||
|
],
|
||||||
|
rules: [],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
apiMock.predictMask.mockResolvedValueOnce({
|
||||||
|
masks: [
|
||||||
|
{
|
||||||
|
id: 'sam2-mask',
|
||||||
|
pathData: 'M 10 10 L 40 10 L 40 40 Z',
|
||||||
|
label: 'AI Mask',
|
||||||
|
color: '#06b6d4',
|
||||||
|
segmentation: [[10, 10, 40, 10, 40, 40]],
|
||||||
|
bbox: [10, 10, 30, 30],
|
||||||
|
area: 900,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||||
|
fireEvent.click(screen.getByText('正向选点'));
|
||||||
|
fireEvent.click(screen.getByTestId('konva-stage'));
|
||||||
|
fireEvent.click(await screen.findByText('执行高精度语义分割'));
|
||||||
|
|
||||||
|
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']));
|
||||||
|
fireEvent.click(screen.getByText('肝脏'));
|
||||||
|
|
||||||
|
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||||
|
templateId: 'template-1',
|
||||||
|
classId: 'class-2',
|
||||||
|
className: '肝脏',
|
||||||
|
classZIndex: 20,
|
||||||
|
label: '肝脏',
|
||||||
|
color: '#00ff00',
|
||||||
|
saveStatus: 'draft',
|
||||||
|
}));
|
||||||
|
});
|
||||||
|
|
||||||
|
it('keeps the generated SAM2 mask selected when sending it to the workspace editor', async () => {
|
||||||
|
const onSendToWorkspace = vi.fn();
|
||||||
|
apiMock.predictMask.mockResolvedValueOnce({
|
||||||
|
masks: [
|
||||||
|
{
|
||||||
|
id: 'sam2-mask',
|
||||||
|
pathData: 'M 10 10 L 40 10 L 40 40 Z',
|
||||||
|
label: 'AI Mask',
|
||||||
|
color: '#06b6d4',
|
||||||
|
segmentation: [[10, 10, 40, 10, 40, 40]],
|
||||||
|
bbox: [10, 10, 30, 30],
|
||||||
|
area: 900,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<AISegmentation onSendToWorkspace={onSendToWorkspace} />);
|
||||||
|
fireEvent.click(screen.getByText('正向选点'));
|
||||||
|
fireEvent.click(screen.getByTestId('konva-stage'));
|
||||||
|
fireEvent.click(await screen.findByText('执行高精度语义分割'));
|
||||||
|
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']));
|
||||||
|
|
||||||
|
fireEvent.click(screen.getByText('推送至工作区编辑'));
|
||||||
|
|
||||||
|
expect(useStore.getState().activeTool).toBe('edit_polygon');
|
||||||
|
expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']);
|
||||||
|
expect(onSendToWorkspace).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
it('prompts for semantic text before running SAM3 inference', async () => {
|
it('prompts for semantic text before running SAM3 inference', async () => {
|
||||||
apiMock.getAiModelStatus.mockResolvedValue({
|
apiMock.getAiModelStatus.mockResolvedValue({
|
||||||
selected_model: 'sam3',
|
selected_model: 'sam3',
|
||||||
@@ -106,7 +229,7 @@ describe('AISegmentation', () => {
|
|||||||
points: undefined,
|
points: undefined,
|
||||||
text: '胆囊',
|
text: '胆囊',
|
||||||
})));
|
})));
|
||||||
expect(await screen.findByText('模型没有返回可用区域,请换一个更具体的描述或调整提示。')).toBeInTheDocument();
|
expect(await screen.findByText('SAM3 已完成语义推理,但没有返回区域。请尝试英文目标描述,或换到包含该目标的帧。当前提示: 胆囊')).toBeInTheDocument();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('runs SAM3 semantic text inference and assigns the active class to returned masks', async () => {
|
it('runs SAM3 semantic text inference and assigns the active class to returned masks', async () => {
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
|||||||
const masks = useStore((state) => state.masks);
|
const masks = useStore((state) => state.masks);
|
||||||
const addMask = useStore((state) => state.addMask);
|
const addMask = useStore((state) => state.addMask);
|
||||||
const clearMasks = useStore((state) => state.clearMasks);
|
const clearMasks = useStore((state) => state.clearMasks);
|
||||||
|
const selectedMaskIds = useStore((state) => state.selectedMaskIds);
|
||||||
|
const setSelectedMaskIds = useStore((state) => state.setSelectedMaskIds);
|
||||||
const maskHistory = useStore((state) => state.maskHistory);
|
const maskHistory = useStore((state) => state.maskHistory);
|
||||||
const maskFuture = useStore((state) => state.maskFuture);
|
const maskFuture = useStore((state) => state.maskFuture);
|
||||||
const undoMasks = useStore((state) => state.undoMasks);
|
const undoMasks = useStore((state) => state.undoMasks);
|
||||||
@@ -97,6 +99,10 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
|||||||
setInferenceMessage('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。');
|
setInferenceMessage('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。');
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (aiModel === 'sam2' && textPrompt && points.length === 0) {
|
||||||
|
setInferenceMessage('SAM2 不支持文本语义提示;请先放置正/反向点,或切换到 SAM3 使用文本语义。');
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (points.length === 0 && !textPrompt) {
|
if (points.length === 0 && !textPrompt) {
|
||||||
setInferenceMessage('请先放置正/反向提示点,或输入语义描述。');
|
setInferenceMessage('请先放置正/反向提示点,或输入语义描述。');
|
||||||
return;
|
return;
|
||||||
@@ -132,14 +138,22 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
if (result.masks.length === 0) {
|
const masksToApply = aiModel === 'sam2' ? result.masks.slice(0, 1) : result.masks;
|
||||||
setInferenceMessage('模型没有返回可用区域,请换一个更具体的描述或调整提示。');
|
|
||||||
|
if (masksToApply.length === 0) {
|
||||||
|
setInferenceMessage(aiModel === 'sam3'
|
||||||
|
? `SAM3 已完成语义推理,但没有返回区域。请尝试英文目标描述,或换到包含该目标的帧。当前提示: ${textPrompt}`
|
||||||
|
: '模型没有返回可用区域,请换一个更具体的描述或调整提示。');
|
||||||
} else {
|
} else {
|
||||||
setInferenceMessage(`已生成 ${result.masks.length} 个候选区域。`);
|
setInferenceMessage(aiModel === 'sam2' && result.masks.length > 1
|
||||||
|
? `SAM2 返回 ${result.masks.length} 个候选,已采用最高分区域。`
|
||||||
|
: `已生成 ${masksToApply.length} 个候选区域。`);
|
||||||
}
|
}
|
||||||
result.masks.forEach((m) => {
|
const generatedMaskIds: string[] = [];
|
||||||
|
masksToApply.forEach((m) => {
|
||||||
const label = activeClass?.name || m.label;
|
const label = activeClass?.name || m.label;
|
||||||
const color = activeClass?.color || m.color;
|
const color = activeClass?.color || m.color;
|
||||||
|
generatedMaskIds.push(m.id);
|
||||||
addMask({
|
addMask({
|
||||||
id: m.id,
|
id: m.id,
|
||||||
frameId: currentFrame.id,
|
frameId: currentFrame.id,
|
||||||
@@ -157,6 +171,9 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
|||||||
area: m.area,
|
area: m.area,
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
if (generatedMaskIds.length > 0) {
|
||||||
|
setSelectedMaskIds(generatedMaskIds);
|
||||||
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('AI inference failed:', err);
|
console.error('AI inference failed:', err);
|
||||||
const detail = (err as any)?.response?.data?.detail;
|
const detail = (err as any)?.response?.data?.detail;
|
||||||
@@ -164,7 +181,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
|||||||
} finally {
|
} finally {
|
||||||
setIsInferencing(false);
|
setIsInferencing(false);
|
||||||
}
|
}
|
||||||
}, [activeClass, activeTemplateId, addMask, aiModel, autoDeleteBg, cropMode, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText]);
|
}, [activeClass, activeTemplateId, addMask, aiModel, autoDeleteBg, cropMode, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText, setSelectedMaskIds]);
|
||||||
|
|
||||||
const handleStageClick = (e: any) => {
|
const handleStageClick = (e: any) => {
|
||||||
if (effectiveTool === 'move') return;
|
if (effectiveTool === 'move') return;
|
||||||
@@ -307,10 +324,13 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
|||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
<button
|
<button
|
||||||
onClick={onSendToWorkspace}
|
onClick={() => {
|
||||||
|
setActiveTool('edit_polygon');
|
||||||
|
onSendToWorkspace();
|
||||||
|
}}
|
||||||
className="w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all font-medium tracking-wide text-xs uppercase bg-white/5 hover:bg-white/10 text-gray-300 border border-white/5 hover:border-white/10"
|
className="w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all font-medium tracking-wide text-xs uppercase bg-white/5 hover:bg-white/10 text-gray-300 border border-white/5 hover:border-white/10"
|
||||||
>
|
>
|
||||||
<SendToBack size={16} /> 退档推送至工作区重组
|
<SendToBack size={16} /> 推送至工作区编辑
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
</aside>
|
</aside>
|
||||||
@@ -376,12 +396,20 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
|||||||
|
|
||||||
{/* AI Returned Masks */}
|
{/* AI Returned Masks */}
|
||||||
{frameMasks.map((mask) => (
|
{frameMasks.map((mask) => (
|
||||||
<Group key={mask.id} opacity={0.45}>
|
<Group key={mask.id} opacity={selectedMaskIds.includes(mask.id) ? 0.72 : 0.45}>
|
||||||
<Path
|
<Path
|
||||||
data={mask.pathData}
|
data={mask.pathData}
|
||||||
fill={mask.color}
|
fill={mask.color}
|
||||||
stroke={mask.color}
|
stroke={mask.color}
|
||||||
strokeWidth={1 / scale}
|
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2.5 : 1) / scale}
|
||||||
|
onClick={(event: any) => {
|
||||||
|
event.cancelBubble = true;
|
||||||
|
setSelectedMaskIds([mask.id]);
|
||||||
|
}}
|
||||||
|
onTap={(event: any) => {
|
||||||
|
event.cancelBubble = true;
|
||||||
|
setSelectedMaskIds([mask.id]);
|
||||||
|
}}
|
||||||
/>
|
/>
|
||||||
</Group>
|
</Group>
|
||||||
))}
|
))}
|
||||||
|
|||||||
@@ -206,16 +206,58 @@ describe('CanvasArea', () => {
|
|||||||
{ x: 300, y: 150, type: 'neg' },
|
{ x: 300, y: 150, type: 'neg' },
|
||||||
],
|
],
|
||||||
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
|
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
|
||||||
|
options: { auto_filter_background: true, min_score: 0.05 },
|
||||||
}));
|
}));
|
||||||
expect(useStore.getState().masks).toHaveLength(1);
|
expect(useStore.getState().masks).toHaveLength(1);
|
||||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||||
id: 'mask-box',
|
id: 'mask-box',
|
||||||
segmentation: [[30, 30, 70, 30, 70, 70]],
|
segmentation: [[30, 30, 70, 30, 70, 70]],
|
||||||
points: [[150, 100]],
|
points: [[150, 100]],
|
||||||
metadata: expect.objectContaining({ promptPointCount: 2 }),
|
metadata: expect.objectContaining({ promptPointCount: 2, promptNegativePointCount: 1 }),
|
||||||
}));
|
}));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('removes the SAM2 candidate when a negative point filters it out', async () => {
|
||||||
|
apiMock.predictMask
|
||||||
|
.mockResolvedValueOnce({
|
||||||
|
masks: [
|
||||||
|
{
|
||||||
|
id: 'mask-box',
|
||||||
|
pathData: 'M 10 10 L 90 10 L 90 90 Z',
|
||||||
|
label: 'AI Mask',
|
||||||
|
color: '#06b6d4',
|
||||||
|
segmentation: [[10, 10, 90, 10, 90, 90]],
|
||||||
|
bbox: [10, 10, 80, 80],
|
||||||
|
area: 6400,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
})
|
||||||
|
.mockResolvedValueOnce({ masks: [] });
|
||||||
|
|
||||||
|
const { rerender } = render(<CanvasArea activeTool="box_select" frame={frame} />);
|
||||||
|
const stage = screen.getByTestId('konva-stage');
|
||||||
|
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
|
||||||
|
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
|
||||||
|
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
|
||||||
|
|
||||||
|
await waitFor(() => expect(useStore.getState().masks).toHaveLength(1));
|
||||||
|
|
||||||
|
rerender(<CanvasArea activeTool="point_neg" frame={frame} />);
|
||||||
|
fireEvent.click(stage, { clientX: 180, clientY: 120 });
|
||||||
|
|
||||||
|
await waitFor(() => expect(apiMock.predictMask).toHaveBeenNthCalledWith(2, {
|
||||||
|
imageId: 'frame-1',
|
||||||
|
imageWidth: 640,
|
||||||
|
imageHeight: 360,
|
||||||
|
model: 'sam2',
|
||||||
|
points: [{ x: 180, y: 120, type: 'neg' }],
|
||||||
|
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
|
||||||
|
options: { auto_filter_background: true, min_score: 0.05 },
|
||||||
|
}));
|
||||||
|
await waitFor(() => expect(useStore.getState().masks).toHaveLength(0));
|
||||||
|
expect(await screen.findByText(/反向点已排除当前候选区域/)).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
it('renders only masks that belong to the current frame', () => {
|
it('renders only masks that belong to the current frame', () => {
|
||||||
useStore.setState({
|
useStore.setState({
|
||||||
masks: [
|
masks: [
|
||||||
@@ -250,6 +292,28 @@ describe('CanvasArea', () => {
|
|||||||
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['m1']));
|
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['m1']));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('keeps a mask selected when opening the workspace polygon editor from AI results', () => {
|
||||||
|
useStore.setState({
|
||||||
|
selectedMaskIds: ['m1'],
|
||||||
|
masks: [
|
||||||
|
{
|
||||||
|
id: 'm1',
|
||||||
|
frameId: 'frame-1',
|
||||||
|
pathData: 'M 0 0 L 10 0 L 10 10 Z',
|
||||||
|
label: 'A',
|
||||||
|
color: '#fff',
|
||||||
|
segmentation: [[0, 0, 10, 0, 10, 10]],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<CanvasArea activeTool="edit_polygon" frame={frame} />);
|
||||||
|
|
||||||
|
expect(useStore.getState().selectedMaskIds).toEqual(['m1']);
|
||||||
|
expect(screen.getAllByTestId('konva-circle')
|
||||||
|
.filter((element) => element.getAttribute('data-fill') === '#ffffff')).toHaveLength(3);
|
||||||
|
});
|
||||||
|
|
||||||
it('renders imported GT seed points for editable point regions', () => {
|
it('renders imported GT seed points for editable point regions', () => {
|
||||||
useStore.setState({
|
useStore.setState({
|
||||||
masks: [
|
masks: [
|
||||||
@@ -415,6 +479,34 @@ describe('CanvasArea', () => {
|
|||||||
}));
|
}));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('selects a polygon with the edit tool and inserts a vertex by double-clicking an edge', () => {
|
||||||
|
useStore.setState({
|
||||||
|
masks: [
|
||||||
|
{
|
||||||
|
id: 'draft-1',
|
||||||
|
frameId: 'frame-1',
|
||||||
|
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||||
|
label: 'Draft',
|
||||||
|
color: '#06b6d4',
|
||||||
|
saveStatus: 'draft',
|
||||||
|
segmentation: [[10, 10, 90, 10, 90, 40]],
|
||||||
|
bbox: [10, 10, 80, 30],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<CanvasArea activeTool="edit_polygon" frame={frame} />);
|
||||||
|
const path = screen.getByTestId('konva-path');
|
||||||
|
fireEvent.click(path);
|
||||||
|
fireEvent.doubleClick(path, { clientX: 50, clientY: 10 });
|
||||||
|
|
||||||
|
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||||
|
segmentation: [[10, 10, 50, 10, 90, 10, 90, 40]],
|
||||||
|
pathData: 'M 10 10 L 50 10 L 90 10 L 90 40 Z',
|
||||||
|
saveStatus: 'draft',
|
||||||
|
}));
|
||||||
|
});
|
||||||
|
|
||||||
it('edits the selected polygon in a multi-polygon mask', () => {
|
it('edits the selected polygon in a multi-polygon mask', () => {
|
||||||
useStore.setState({
|
useStore.setState({
|
||||||
masks: [
|
masks: [
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ type PromptBox = { x1: number; y1: number; x2: number; y2: number };
|
|||||||
|
|
||||||
const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle', 'create_line']);
|
const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle', 'create_line']);
|
||||||
const POLYGON_TOOL = 'create_polygon';
|
const POLYGON_TOOL = 'create_polygon';
|
||||||
|
const EDIT_POLYGON_TOOL = 'edit_polygon';
|
||||||
const POINT_TOOL = 'create_point';
|
const POINT_TOOL = 'create_point';
|
||||||
const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']);
|
const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']);
|
||||||
const POLYGON_CLOSE_RADIUS = 8;
|
const POLYGON_CLOSE_RADIUS = 8;
|
||||||
@@ -95,6 +96,32 @@ function pointDistance(a: CanvasPoint, b: CanvasPoint): number {
|
|||||||
return Math.hypot(a.x - b.x, a.y - b.y);
|
return Math.hypot(a.x - b.x, a.y - b.y);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function distanceToSegmentSquared(point: CanvasPoint, start: CanvasPoint, end: CanvasPoint): number {
|
||||||
|
const dx = end.x - start.x;
|
||||||
|
const dy = end.y - start.y;
|
||||||
|
const lengthSquared = dx * dx + dy * dy;
|
||||||
|
if (lengthSquared === 0) {
|
||||||
|
return (point.x - start.x) ** 2 + (point.y - start.y) ** 2;
|
||||||
|
}
|
||||||
|
const t = clamp(((point.x - start.x) * dx + (point.y - start.y) * dy) / lengthSquared, 0, 1);
|
||||||
|
const projected = { x: start.x + t * dx, y: start.y + t * dy };
|
||||||
|
return (point.x - projected.x) ** 2 + (point.y - projected.y) ** 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
function nearestPolygonEdgeIndex(points: CanvasPoint[], point: CanvasPoint): number {
|
||||||
|
return points.reduce((bestIndex, start, index) => {
|
||||||
|
const end = points[(index + 1) % points.length];
|
||||||
|
if (!end) return bestIndex;
|
||||||
|
const bestStart = points[bestIndex];
|
||||||
|
const bestEnd = points[(bestIndex + 1) % points.length];
|
||||||
|
const currentDistance = distanceToSegmentSquared(point, start, end);
|
||||||
|
const bestDistance = bestStart && bestEnd
|
||||||
|
? distanceToSegmentSquared(point, bestStart, bestEnd)
|
||||||
|
: Number.POSITIVE_INFINITY;
|
||||||
|
return currentDistance < bestDistance ? index : bestIndex;
|
||||||
|
}, 0);
|
||||||
|
}
|
||||||
|
|
||||||
function segmentationArea(segmentation?: number[][]): number {
|
function segmentationArea(segmentation?: number[][]): number {
|
||||||
return (segmentation || []).reduce((sum, polygon) => sum + polygonArea(flatPolygonToPoints(polygon)), 0);
|
return (segmentation || []).reduce((sum, polygon) => sum + polygonArea(flatPolygonToPoints(polygon)), 0);
|
||||||
}
|
}
|
||||||
@@ -210,10 +237,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
|||||||
const [manualStart, setManualStart] = useState<CanvasPoint | null>(null);
|
const [manualStart, setManualStart] = useState<CanvasPoint | null>(null);
|
||||||
const [manualCurrent, setManualCurrent] = useState<CanvasPoint | null>(null);
|
const [manualCurrent, setManualCurrent] = useState<CanvasPoint | null>(null);
|
||||||
const [polygonPoints, setPolygonPoints] = useState<CanvasPoint[]>([]);
|
const [polygonPoints, setPolygonPoints] = useState<CanvasPoint[]>([]);
|
||||||
const [selectedMaskId, setSelectedMaskId] = useState<string | null>(null);
|
const [selectedMaskId, setSelectedMaskId] = useState<string | null>(() => useStore.getState().selectedMaskIds[0] || null);
|
||||||
const [selectedMaskIds, setSelectedMaskIds] = useState<string[]>([]);
|
const [selectedMaskIds, setSelectedMaskIds] = useState<string[]>(() => useStore.getState().selectedMaskIds);
|
||||||
const [selectedPolygonIndex, setSelectedPolygonIndex] = useState(0);
|
const [selectedPolygonIndex, setSelectedPolygonIndex] = useState(0);
|
||||||
const [selectedVertexIndex, setSelectedVertexIndex] = useState<number | null>(null);
|
const [selectedVertexIndex, setSelectedVertexIndex] = useState<number | null>(null);
|
||||||
|
const previousFrameIdRef = useRef<string | undefined>(frame?.id);
|
||||||
const [isInferencing, setIsInferencing] = useState(false);
|
const [isInferencing, setIsInferencing] = useState(false);
|
||||||
const [inferenceMessage, setInferenceMessage] = useState('');
|
const [inferenceMessage, setInferenceMessage] = useState('');
|
||||||
|
|
||||||
@@ -253,6 +281,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
|||||||
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
|
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
|
||||||
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
|
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
|
||||||
const isBooleanTool = BOOLEAN_TOOLS.has(effectiveTool);
|
const isBooleanTool = BOOLEAN_TOOLS.has(effectiveTool);
|
||||||
|
const isPolygonEditTool = effectiveTool === 'move' || effectiveTool === EDIT_POLYGON_TOOL;
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const handleResize = () => {
|
const handleResize = () => {
|
||||||
@@ -273,11 +302,22 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
|||||||
setManualStart(null);
|
setManualStart(null);
|
||||||
setManualCurrent(null);
|
setManualCurrent(null);
|
||||||
setPolygonPoints([]);
|
setPolygonPoints([]);
|
||||||
|
setSelectedVertexIndex(null);
|
||||||
|
if (!isPolygonEditTool && !isBooleanTool) {
|
||||||
|
setSelectedMaskId(null);
|
||||||
|
setSelectedMaskIds([]);
|
||||||
|
setSelectedPolygonIndex(0);
|
||||||
|
}
|
||||||
|
}, [effectiveTool, isBooleanTool, isPolygonEditTool]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (previousFrameIdRef.current === frame?.id) return;
|
||||||
|
previousFrameIdRef.current = frame?.id;
|
||||||
setSelectedMaskId(null);
|
setSelectedMaskId(null);
|
||||||
setSelectedMaskIds([]);
|
setSelectedMaskIds([]);
|
||||||
setSelectedPolygonIndex(0);
|
setSelectedPolygonIndex(0);
|
||||||
setSelectedVertexIndex(null);
|
setSelectedVertexIndex(null);
|
||||||
}, [effectiveTool, frame?.id]);
|
}, [frame?.id]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
setPoints([]);
|
setPoints([]);
|
||||||
@@ -420,6 +460,10 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
|||||||
setIsInferencing(true);
|
setIsInferencing(true);
|
||||||
setInferenceMessage('');
|
setInferenceMessage('');
|
||||||
try {
|
try {
|
||||||
|
const hasNegativePrompt = Boolean(promptPoints?.some((point) => point.type === 'neg'));
|
||||||
|
const existingCandidate = !options.resetCandidate && samCandidateMaskId
|
||||||
|
? masks.find((mask) => mask.id === samCandidateMaskId)
|
||||||
|
: null;
|
||||||
const result = await predictMask({
|
const result = await predictMask({
|
||||||
imageId: frame.id,
|
imageId: frame.id,
|
||||||
imageWidth,
|
imageWidth,
|
||||||
@@ -429,13 +473,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
|||||||
? promptPoints.map((p) => ({ x: p.x, y: p.y, type: p.type }))
|
? promptPoints.map((p) => ({ x: p.x, y: p.y, type: p.type }))
|
||||||
: undefined,
|
: undefined,
|
||||||
box: promptBox,
|
box: promptBox,
|
||||||
|
...(hasNegativePrompt ? { options: { auto_filter_background: true, min_score: 0.05 } } : {}),
|
||||||
});
|
});
|
||||||
|
|
||||||
const [m] = result.masks;
|
const [m] = result.masks;
|
||||||
if (m) {
|
if (m) {
|
||||||
const existingCandidate = !options.resetCandidate && samCandidateMaskId
|
|
||||||
? masks.find((mask) => mask.id === samCandidateMaskId)
|
|
||||||
: null;
|
|
||||||
const label = activeClass?.name || existingCandidate?.label || m.label;
|
const label = activeClass?.name || existingCandidate?.label || m.label;
|
||||||
const color = activeClass?.color || existingCandidate?.color || m.color;
|
const color = activeClass?.color || existingCandidate?.color || m.color;
|
||||||
const metadata = {
|
const metadata = {
|
||||||
@@ -443,6 +485,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
|||||||
source: aiModel === 'sam3' ? 'sam3_box' : 'sam2_interactive',
|
source: aiModel === 'sam3' ? 'sam3_box' : 'sam2_interactive',
|
||||||
promptBox: promptBox || null,
|
promptBox: promptBox || null,
|
||||||
promptPointCount: promptPoints?.length || 0,
|
promptPointCount: promptPoints?.length || 0,
|
||||||
|
promptNegativePointCount: promptPoints?.filter((point) => point.type === 'neg').length || 0,
|
||||||
};
|
};
|
||||||
const nextMask = {
|
const nextMask = {
|
||||||
frameId: frame.id,
|
frameId: frame.id,
|
||||||
@@ -476,7 +519,15 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
setInferenceMessage('模型没有返回可用区域,请调整点/框提示后重试。');
|
if (existingCandidate && hasNegativePrompt) {
|
||||||
|
setMasks(masks.filter((mask) => mask.id !== existingCandidate.id));
|
||||||
|
setSamCandidateMaskId(null);
|
||||||
|
setSelectedMaskId(null);
|
||||||
|
setSelectedMaskIds([]);
|
||||||
|
setInferenceMessage('反向点已排除当前候选区域,请重新框选或添加新的正向点。');
|
||||||
|
} else {
|
||||||
|
setInferenceMessage('模型没有返回可用区域,请调整点/框提示后重试。');
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('Inference failed:', err);
|
console.error('Inference failed:', err);
|
||||||
@@ -485,7 +536,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
|||||||
} finally {
|
} finally {
|
||||||
setIsInferencing(false);
|
setIsInferencing(false);
|
||||||
}
|
}
|
||||||
}, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, masks, samCandidateMaskId, updateMask]);
|
}, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, masks, samCandidateMaskId, setMasks, updateMask]);
|
||||||
|
|
||||||
const handleApplyActiveClass = () => {
|
const handleApplyActiveClass = () => {
|
||||||
if (!frame?.id || !activeClass) return;
|
if (!frame?.id || !activeClass) return;
|
||||||
@@ -598,7 +649,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
|||||||
};
|
};
|
||||||
|
|
||||||
const handleStageClick = (e: any) => {
|
const handleStageClick = (e: any) => {
|
||||||
if (effectiveTool === 'move') return;
|
if (isPolygonEditTool) return;
|
||||||
if (effectiveTool === 'box_select') return; // handled by mouseup
|
if (effectiveTool === 'box_select') return; // handled by mouseup
|
||||||
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) return;
|
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) return;
|
||||||
|
|
||||||
@@ -716,7 +767,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
|||||||
|
|
||||||
window.addEventListener('keydown', handleKeyDown);
|
window.addEventListener('keydown', handleKeyDown);
|
||||||
return () => window.removeEventListener('keydown', handleKeyDown);
|
return () => window.removeEventListener('keydown', handleKeyDown);
|
||||||
}, [deleteMasksById, effectiveTool, finishPolygon, polygonPoints, redoMasks, selectedMask, selectedMaskIds, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
|
}, [deleteMasksById, effectiveTool, finishPolygon, isPolygonEditTool, polygonPoints, redoMasks, selectedMask, selectedMaskIds, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
|
||||||
|
|
||||||
const boxRect = React.useMemo(() => {
|
const boxRect = React.useMemo(() => {
|
||||||
if (!boxStart || !boxCurrent) return null;
|
if (!boxStart || !boxCurrent) return null;
|
||||||
@@ -753,7 +804,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
|||||||
};
|
};
|
||||||
|
|
||||||
const handleMaskSelect = (mask: Mask, event: any, polygonIndex = 0) => {
|
const handleMaskSelect = (mask: Mask, event: any, polygonIndex = 0) => {
|
||||||
if (effectiveTool !== 'move' && !isBooleanTool) return;
|
if (!isPolygonEditTool && !isBooleanTool) return;
|
||||||
event.cancelBubble = true;
|
event.cancelBubble = true;
|
||||||
if (isBooleanTool) {
|
if (isBooleanTool) {
|
||||||
setSelectedMaskIds((current) => (
|
setSelectedMaskIds((current) => (
|
||||||
@@ -807,6 +858,25 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
|||||||
updatePolygonMask(mask, nextPoints, selectedPolygonIndex);
|
updatePolygonMask(mask, nextPoints, selectedPolygonIndex);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const handlePathDoubleClick = (mask: Mask, event: any, polygonIndex = 0) => {
|
||||||
|
if (effectiveTool !== EDIT_POLYGON_TOOL) return;
|
||||||
|
event.cancelBubble = true;
|
||||||
|
const point = stagePoint(event);
|
||||||
|
const currentPoints = segmentationToPoints(mask.segmentation, polygonIndex);
|
||||||
|
if (!point || currentPoints.length < 3) return;
|
||||||
|
const edgeIndex = nearestPolygonEdgeIndex(currentPoints, point);
|
||||||
|
const nextPoints = [
|
||||||
|
...currentPoints.slice(0, edgeIndex + 1),
|
||||||
|
point,
|
||||||
|
...currentPoints.slice(edgeIndex + 1),
|
||||||
|
];
|
||||||
|
setSelectedMaskId(mask.id);
|
||||||
|
setSelectedMaskIds([mask.id]);
|
||||||
|
setSelectedPolygonIndex(polygonIndex);
|
||||||
|
setSelectedVertexIndex(edgeIndex + 1);
|
||||||
|
updatePolygonMask(mask, nextPoints, polygonIndex);
|
||||||
|
};
|
||||||
|
|
||||||
const handleBooleanOperation = async () => {
|
const handleBooleanOperation = async () => {
|
||||||
if (!frame || booleanSelectedMasks.length < 2) return;
|
if (!frame || booleanSelectedMasks.length < 2) return;
|
||||||
const primary = booleanSelectedMasks[0];
|
const primary = booleanSelectedMasks[0];
|
||||||
@@ -918,6 +988,8 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
|||||||
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale}
|
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale}
|
||||||
onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
|
onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
|
||||||
onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
|
onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
|
||||||
|
onDblClick={(event: any) => handlePathDoubleClick(mask, event, polygonIndex)}
|
||||||
|
onDblTap={(event: any) => handlePathDoubleClick(mask, event, polygonIndex)}
|
||||||
/>
|
/>
|
||||||
))}
|
))}
|
||||||
</Group>
|
</Group>
|
||||||
@@ -987,7 +1059,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
|||||||
)))}
|
)))}
|
||||||
|
|
||||||
{/* Polygon edge insertion handles */}
|
{/* Polygon edge insertion handles */}
|
||||||
{!isBooleanTool && selectedMask && selectedMaskPoints.map((point, index) => {
|
{isPolygonEditTool && selectedMask && selectedMaskPoints.map((point, index) => {
|
||||||
const next = selectedMaskPoints[(index + 1) % selectedMaskPoints.length];
|
const next = selectedMaskPoints[(index + 1) % selectedMaskPoints.length];
|
||||||
if (!next) return null;
|
if (!next) return null;
|
||||||
return (
|
return (
|
||||||
@@ -1006,7 +1078,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
|||||||
})}
|
})}
|
||||||
|
|
||||||
{/* Polygon vertex editor */}
|
{/* Polygon vertex editor */}
|
||||||
{!isBooleanTool && selectedMask && selectedMaskPoints.map((point, index) => (
|
{isPolygonEditTool && selectedMask && selectedMaskPoints.map((point, index) => (
|
||||||
<Circle
|
<Circle
|
||||||
key={`${selectedMask.id}-vertex-${selectedPolygonIndex}-${index}`}
|
key={`${selectedMask.id}-vertex-${selectedPolygonIndex}-${index}`}
|
||||||
x={point.x}
|
x={point.x}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ const apiMock = vi.hoisted(() => ({
|
|||||||
const wsMock = vi.hoisted(() => {
|
const wsMock = vi.hoisted(() => {
|
||||||
const state = {
|
const state = {
|
||||||
callback: undefined as undefined | ((data: any) => void),
|
callback: undefined as undefined | ((data: any) => void),
|
||||||
|
statusCallback: undefined as undefined | ((status: any) => void),
|
||||||
connected: false,
|
connected: false,
|
||||||
};
|
};
|
||||||
return {
|
return {
|
||||||
@@ -24,6 +25,11 @@ const wsMock = vi.hoisted(() => {
|
|||||||
state.callback = cb;
|
state.callback = cb;
|
||||||
return vi.fn();
|
return vi.fn();
|
||||||
}),
|
}),
|
||||||
|
onStatus: vi.fn((cb: (status: any) => void) => {
|
||||||
|
state.statusCallback = cb;
|
||||||
|
cb(state.connected ? 'connected' : 'disconnected');
|
||||||
|
return vi.fn();
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
@@ -45,6 +51,7 @@ describe('Dashboard', () => {
|
|||||||
vi.clearAllMocks();
|
vi.clearAllMocks();
|
||||||
wsMock.state.connected = false;
|
wsMock.state.connected = false;
|
||||||
wsMock.state.callback = undefined;
|
wsMock.state.callback = undefined;
|
||||||
|
wsMock.state.statusCallback = undefined;
|
||||||
apiMock.getDashboardOverview.mockResolvedValue({
|
apiMock.getDashboardOverview.mockResolvedValue({
|
||||||
summary: {
|
summary: {
|
||||||
project_count: 2,
|
project_count: 2,
|
||||||
@@ -109,6 +116,20 @@ describe('Dashboard', () => {
|
|||||||
expect(screen.getByText('44%')).toBeInTheDocument();
|
expect(screen.getByText('44%')).toBeInTheDocument();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('updates the websocket badge from connection status callbacks', async () => {
|
||||||
|
render(<Dashboard />);
|
||||||
|
|
||||||
|
await waitFor(() => expect(wsMock.progressWS.onStatus).toHaveBeenCalled());
|
||||||
|
expect(screen.getByText('WebSocket 断开')).toBeInTheDocument();
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
wsMock.state.connected = true;
|
||||||
|
wsMock.state.statusCallback?.('connected');
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(screen.getByText('WebSocket 已连接')).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
it('adds activity logs for complete and status messages', async () => {
|
it('adds activity logs for complete and status messages', async () => {
|
||||||
render(<Dashboard />);
|
render(<Dashboard />);
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import React, { useState, useEffect } from 'react';
|
import React, { useState, useEffect } from 'react';
|
||||||
import { Activity, AlertTriangle, Clock, Folders, CheckCircle2, Info, Loader2, RotateCcw, XCircle } from 'lucide-react';
|
import { Activity, AlertTriangle, Clock, Folders, CheckCircle2, Info, Loader2, RotateCcw, XCircle } from 'lucide-react';
|
||||||
import { progressWS, type ProgressMessage } from '../lib/websocket';
|
import { progressWS, type ConnectionStatus, type ProgressMessage } from '../lib/websocket';
|
||||||
import { cn } from '../lib/utils';
|
import { cn } from '../lib/utils';
|
||||||
import {
|
import {
|
||||||
cancelTask,
|
cancelTask,
|
||||||
@@ -178,6 +178,9 @@ export function Dashboard() {
|
|||||||
]);
|
]);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
const unsubscribeStatus = progressWS.onStatus((status: ConnectionStatus) => {
|
||||||
|
if (mounted) setIsConnected(status === 'connected');
|
||||||
|
});
|
||||||
|
|
||||||
const checkConnection = setInterval(() => {
|
const checkConnection = setInterval(() => {
|
||||||
if (mounted) setIsConnected(progressWS.isConnected());
|
if (mounted) setIsConnected(progressWS.isConnected());
|
||||||
@@ -186,6 +189,7 @@ export function Dashboard() {
|
|||||||
return () => {
|
return () => {
|
||||||
mounted = false;
|
mounted = false;
|
||||||
unsubscribe();
|
unsubscribe();
|
||||||
|
unsubscribeStatus();
|
||||||
clearInterval(checkConnection);
|
clearInterval(checkConnection);
|
||||||
progressWS.disconnect();
|
progressWS.disconnect();
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -56,10 +56,9 @@ describe('ProjectLibrary', () => {
|
|||||||
expect(useStore.getState().projects[0]).toEqual(expect.objectContaining({ id: 'p2' }));
|
expect(useStore.getState().projects[0]).toEqual(expect.objectContaining({ id: 'p2' }));
|
||||||
});
|
});
|
||||||
|
|
||||||
it('imports video by creating a project, uploading media, parsing frames and refreshing projects', async () => {
|
it('imports video by creating a project and uploading media without parsing frames', async () => {
|
||||||
apiMock.createProject.mockResolvedValueOnce({ id: 'p3', name: 'clip.mp4', status: 'pending' });
|
apiMock.createProject.mockResolvedValueOnce({ id: 'p3', name: 'clip.mp4', status: 'pending' });
|
||||||
apiMock.uploadMedia.mockResolvedValueOnce({ url: 'http://file', id: 'object' });
|
apiMock.uploadMedia.mockResolvedValueOnce({ url: 'http://file', id: 'object' });
|
||||||
apiMock.parseMedia.mockResolvedValueOnce({ frames_extracted: 1 });
|
|
||||||
apiMock.getProjects.mockResolvedValue([]);
|
apiMock.getProjects.mockResolvedValue([]);
|
||||||
|
|
||||||
const { container } = render(<ProjectLibrary onProjectSelect={vi.fn()} />);
|
const { container } = render(<ProjectLibrary onProjectSelect={vi.fn()} />);
|
||||||
@@ -70,10 +69,24 @@ describe('ProjectLibrary', () => {
|
|||||||
|
|
||||||
await waitFor(() => expect(apiMock.createProject).toHaveBeenCalledWith(expect.objectContaining({
|
await waitFor(() => expect(apiMock.createProject).toHaveBeenCalledWith(expect.objectContaining({
|
||||||
name: 'clip.mp4',
|
name: 'clip.mp4',
|
||||||
parse_fps: 30,
|
|
||||||
})));
|
})));
|
||||||
expect(apiMock.uploadMedia).toHaveBeenCalledWith(file, 'p3');
|
expect(apiMock.uploadMedia).toHaveBeenCalledWith(file, 'p3');
|
||||||
expect(apiMock.parseMedia).toHaveBeenCalledWith('p3');
|
expect(apiMock.parseMedia).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('generates frames from an imported video with the selected FPS', async () => {
|
||||||
|
apiMock.getProjects
|
||||||
|
.mockResolvedValueOnce([{ id: 'p4', name: 'clip.mp4', status: 'pending', frames: 0, video_path: 'uploads/clip.mp4', parse_fps: 30 }])
|
||||||
|
.mockResolvedValueOnce([{ id: 'p4', name: 'clip.mp4', status: 'parsing', frames: 0, video_path: 'uploads/clip.mp4', parse_fps: 12 }]);
|
||||||
|
apiMock.parseMedia.mockResolvedValueOnce({ id: 22, status: 'queued', progress: 0 });
|
||||||
|
|
||||||
|
const { container } = render(<ProjectLibrary onProjectSelect={vi.fn()} />);
|
||||||
|
|
||||||
|
fireEvent.click(await screen.findByRole('button', { name: '生成帧' }));
|
||||||
|
fireEvent.change(container.querySelector('input[type="range"]') as HTMLInputElement, { target: { value: '12' } });
|
||||||
|
fireEvent.click(screen.getByRole('button', { name: '开始生成帧' }));
|
||||||
|
|
||||||
|
await waitFor(() => expect(apiMock.parseMedia).toHaveBeenCalledWith('p4', { parseFps: 12 }));
|
||||||
});
|
});
|
||||||
|
|
||||||
it('imports only valid DICOM files and parses the returned project', async () => {
|
it('imports only valid DICOM files and parses the returned project', async () => {
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import React, { useState, useEffect, useRef } from 'react';
|
import React, { useState, useEffect, useRef } from 'react';
|
||||||
import { UploadCloud, Film, Settings2, MoreHorizontal, Plus, Loader2, Activity } from 'lucide-react';
|
import { UploadCloud, Film, Settings2, MoreHorizontal, Plus, Loader2, Activity, Images } from 'lucide-react';
|
||||||
import { cn } from '../lib/utils';
|
import { cn } from '../lib/utils';
|
||||||
import { useStore } from '../store/useStore';
|
import { useStore } from '../store/useStore';
|
||||||
import { getProjects, createProject, uploadMedia, parseMedia, uploadDicomBatch } from '../lib/api';
|
import { getProjects, createProject, uploadMedia, parseMedia, uploadDicomBatch } from '../lib/api';
|
||||||
@@ -22,7 +22,10 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
|||||||
const [showImportMenu, setShowImportMenu] = useState(false);
|
const [showImportMenu, setShowImportMenu] = useState(false);
|
||||||
const [showVideoConfig, setShowVideoConfig] = useState(false);
|
const [showVideoConfig, setShowVideoConfig] = useState(false);
|
||||||
const [pendingFile, setPendingFile] = useState<File | null>(null);
|
const [pendingFile, setPendingFile] = useState<File | null>(null);
|
||||||
const [parseFps, setParseFps] = useState(30);
|
const [frameProject, setFrameProject] = useState<Project | null>(null);
|
||||||
|
const [showFrameConfig, setShowFrameConfig] = useState(false);
|
||||||
|
const [frameParseFps, setFrameParseFps] = useState(30);
|
||||||
|
const [isGeneratingFrames, setIsGeneratingFrames] = useState(false);
|
||||||
const videoInputRef = useRef<HTMLInputElement>(null);
|
const videoInputRef = useRef<HTMLInputElement>(null);
|
||||||
const dicomInputRef = useRef<HTMLInputElement>(null);
|
const dicomInputRef = useRef<HTMLInputElement>(null);
|
||||||
|
|
||||||
@@ -57,7 +60,6 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
|||||||
|
|
||||||
const handleVideoSelect = (file: File) => {
|
const handleVideoSelect = (file: File) => {
|
||||||
setPendingFile(file);
|
setPendingFile(file);
|
||||||
setParseFps(30);
|
|
||||||
setShowVideoConfig(true);
|
setShowVideoConfig(true);
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -69,11 +71,9 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
|||||||
const newProject = await createProject({
|
const newProject = await createProject({
|
||||||
name: pendingFile.name,
|
name: pendingFile.name,
|
||||||
description: `导入于 ${new Date().toLocaleString()}`,
|
description: `导入于 ${new Date().toLocaleString()}`,
|
||||||
parse_fps: parseFps,
|
|
||||||
});
|
});
|
||||||
const result = await uploadMedia(pendingFile, String(newProject.id));
|
const result = await uploadMedia(pendingFile, String(newProject.id));
|
||||||
await parseMedia(String(newProject.id));
|
alert(`视频导入成功: ${pendingFile.name}\n已保存至: ${result.url}\n需要生成帧时,请在项目卡片点击“生成帧”。`);
|
||||||
alert(`上传成功: ${pendingFile.name}\n已保存至: ${result.url}`);
|
|
||||||
const data = await getProjects();
|
const data = await getProjects();
|
||||||
setProjects(data);
|
setProjects(data);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
@@ -86,6 +86,31 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const openFrameConfig = (project: Project, event: React.MouseEvent) => {
|
||||||
|
event.stopPropagation();
|
||||||
|
setFrameProject(project);
|
||||||
|
setFrameParseFps(Math.round(project.parse_fps || 30));
|
||||||
|
setShowFrameConfig(true);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleGenerateFrames = async () => {
|
||||||
|
if (!frameProject?.id) return;
|
||||||
|
setIsGeneratingFrames(true);
|
||||||
|
try {
|
||||||
|
const task = await parseMedia(frameProject.id, { parseFps: frameParseFps });
|
||||||
|
alert(`生成帧任务已入队 #${task.id}\n帧率: ${frameParseFps} FPS\n可在 Dashboard 查看进度。`);
|
||||||
|
const data = await getProjects();
|
||||||
|
setProjects(data);
|
||||||
|
setShowFrameConfig(false);
|
||||||
|
setFrameProject(null);
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Frame generation failed:', err);
|
||||||
|
alert('生成帧失败,请检查后端服务或项目源文件');
|
||||||
|
} finally {
|
||||||
|
setIsGeneratingFrames(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const handleDicomUpload = async (files: FileList | null) => {
|
const handleDicomUpload = async (files: FileList | null) => {
|
||||||
if (!files || files.length === 0) return;
|
if (!files || files.length === 0) return;
|
||||||
const dcmFiles = Array.from(files).filter((f) => f.name.toLowerCase().endsWith('.dcm'));
|
const dcmFiles = Array.from(files).filter((f) => f.name.toLowerCase().endsWith('.dcm'));
|
||||||
@@ -209,7 +234,7 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
|||||||
)}
|
)}
|
||||||
<div className="absolute top-2 right-2 flex gap-2">
|
<div className="absolute top-2 right-2 flex gap-2">
|
||||||
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] font-mono px-2 py-1 rounded border border-white/10 uppercase tracking-widest">
|
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] font-mono px-2 py-1 rounded border border-white/10 uppercase tracking-widest">
|
||||||
{proj.source_type === 'dicom' ? 'DICOM' : (proj.fps || '30FPS')}
|
{proj.source_type === 'dicom' ? 'DICOM' : (proj.video_path && (proj.frames ?? 0) === 0 ? '待生成帧' : (proj.fps || '30FPS'))}
|
||||||
</span>
|
</span>
|
||||||
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] px-2 py-1 rounded border border-white/10 flex items-center gap-1 uppercase tracking-widest">
|
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] px-2 py-1 rounded border border-white/10 flex items-center gap-1 uppercase tracking-widest">
|
||||||
{proj.status === 'ready' ? (
|
{proj.status === 'ready' ? (
|
||||||
@@ -235,6 +260,15 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
|||||||
<span className="flex items-center gap-1.5 text-cyan-400/80"><Activity size={12} /> 原 {proj.original_fps.toFixed(1)}fps</span>
|
<span className="flex items-center gap-1.5 text-cyan-400/80"><Activity size={12} /> 原 {proj.original_fps.toFixed(1)}fps</span>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
{proj.video_path && (proj.frames ?? 0) === 0 && proj.status !== 'parsing' && (
|
||||||
|
<button
|
||||||
|
onClick={(event) => openFrameConfig(proj, event)}
|
||||||
|
className="mt-3 inline-flex items-center justify-center gap-2 rounded-md border border-cyan-500/30 bg-cyan-500/10 px-3 py-2 text-xs font-medium text-cyan-200 hover:bg-cyan-500/20 transition-colors"
|
||||||
|
>
|
||||||
|
<Images size={14} />
|
||||||
|
生成帧
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
))}
|
))}
|
||||||
@@ -245,24 +279,10 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
|||||||
{showVideoConfig && pendingFile && (
|
{showVideoConfig && pendingFile && (
|
||||||
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm">
|
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm">
|
||||||
<div className="bg-[#111] border border-white/10 rounded-2xl p-6 w-full max-w-md shadow-2xl">
|
<div className="bg-[#111] border border-white/10 rounded-2xl p-6 w-full max-w-md shadow-2xl">
|
||||||
<h2 className="text-lg font-semibold text-white mb-4">导入视频配置</h2>
|
<h2 className="text-lg font-semibold text-white mb-4">导入视频文件</h2>
|
||||||
<div className="space-y-4">
|
<div className="space-y-4">
|
||||||
<div className="text-sm text-gray-400">文件: <span className="text-gray-200">{pendingFile.name}</span></div>
|
<div className="text-sm text-gray-400">文件: <span className="text-gray-200">{pendingFile.name}</span></div>
|
||||||
<div>
|
<p className="text-xs leading-5 text-gray-500">此步骤只上传源视频并创建项目,不会立即拆帧。拆帧时再选择目标 FPS。</p>
|
||||||
<label className="block text-xs font-medium text-gray-400 uppercase tracking-widest mb-2">解析帧率 (FPS)</label>
|
|
||||||
<div className="flex items-center gap-3">
|
|
||||||
<input
|
|
||||||
type="range"
|
|
||||||
min="1"
|
|
||||||
max="60"
|
|
||||||
value={parseFps}
|
|
||||||
onChange={(e) => setParseFps(parseInt(e.target.value))}
|
|
||||||
className="flex-1 accent-cyan-500"
|
|
||||||
/>
|
|
||||||
<span className="text-sm font-mono text-cyan-400 w-12 text-right">{parseFps}</span>
|
|
||||||
</div>
|
|
||||||
<p className="text-[10px] text-gray-600 mt-1">帧率越低,提取的帧越少,处理速度越快</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
<div className="flex justify-end gap-3 mt-6">
|
<div className="flex justify-end gap-3 mt-6">
|
||||||
<button
|
<button
|
||||||
@@ -282,6 +302,49 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
|||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{/* Frame generation FPS config modal */}
|
||||||
|
{showFrameConfig && frameProject && (
|
||||||
|
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm">
|
||||||
|
<div className="bg-[#111] border border-white/10 rounded-2xl p-6 w-full max-w-md shadow-2xl">
|
||||||
|
<h2 className="text-lg font-semibold text-white mb-4">生成帧</h2>
|
||||||
|
<div className="space-y-4">
|
||||||
|
<div className="text-sm text-gray-400">项目: <span className="text-gray-200">{frameProject.name}</span></div>
|
||||||
|
<div>
|
||||||
|
<label className="block text-xs font-medium text-gray-400 uppercase tracking-widest mb-2">生成帧率 (FPS)</label>
|
||||||
|
<div className="flex items-center gap-3">
|
||||||
|
<input
|
||||||
|
type="range"
|
||||||
|
min="1"
|
||||||
|
max="60"
|
||||||
|
value={frameParseFps}
|
||||||
|
onChange={(e) => setFrameParseFps(parseInt(e.target.value))}
|
||||||
|
className="flex-1 accent-cyan-500"
|
||||||
|
/>
|
||||||
|
<span className="text-sm font-mono text-cyan-400 w-12 text-right">{frameParseFps}</span>
|
||||||
|
</div>
|
||||||
|
<p className="text-[10px] text-gray-600 mt-1">帧率越低,提取的帧越少,处理速度越快</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="flex justify-end gap-3 mt-6">
|
||||||
|
<button
|
||||||
|
onClick={() => { setShowFrameConfig(false); setFrameProject(null); }}
|
||||||
|
disabled={isGeneratingFrames}
|
||||||
|
className="px-4 py-2 rounded-lg text-sm text-gray-400 hover:text-white transition-colors disabled:opacity-50"
|
||||||
|
>
|
||||||
|
取消
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
onClick={handleGenerateFrames}
|
||||||
|
disabled={isGeneratingFrames}
|
||||||
|
className="px-4 py-2 rounded-lg text-sm font-medium bg-cyan-500 hover:bg-cyan-400 text-black transition-all disabled:opacity-60"
|
||||||
|
>
|
||||||
|
{isGeneratingFrames ? '入队中...' : '开始生成帧'}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
{/* New project modal */}
|
{/* New project modal */}
|
||||||
{showModal && (
|
{showModal && (
|
||||||
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm">
|
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm">
|
||||||
|
|||||||
@@ -20,12 +20,14 @@ describe('ToolsPalette', () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
fireEvent.click(screen.getByTitle('创建多边形 (P)'));
|
fireEvent.click(screen.getByTitle('创建多边形 (P)'));
|
||||||
|
fireEvent.click(screen.getByTitle('调整多边形 (E)'));
|
||||||
fireEvent.click(screen.getByTitle('正向选点 (SAM)'));
|
fireEvent.click(screen.getByTitle('正向选点 (SAM)'));
|
||||||
fireEvent.click(screen.getByTitle('撤销操作 (Ctrl+Z)'));
|
fireEvent.click(screen.getByTitle('撤销操作 (Ctrl+Z)'));
|
||||||
fireEvent.click(screen.getByTitle('重做操作 (Ctrl+Shift+Z)'));
|
fireEvent.click(screen.getByTitle('重做操作 (Ctrl+Shift+Z)'));
|
||||||
|
|
||||||
expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon');
|
expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon');
|
||||||
expect(setActiveTool).toHaveBeenNthCalledWith(2, 'point_pos');
|
expect(setActiveTool).toHaveBeenNthCalledWith(2, 'edit_polygon');
|
||||||
|
expect(setActiveTool).toHaveBeenNthCalledWith(3, 'point_pos');
|
||||||
expect(onUndo).toHaveBeenCalled();
|
expect(onUndo).toHaveBeenCalled();
|
||||||
expect(onRedo).toHaveBeenCalled();
|
expect(onRedo).toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import React from 'react';
|
import React from 'react';
|
||||||
import { MousePointer2, Hexagon, Square, Circle, Minus, Combine, Scissors, Wand2, Undo, Redo, Crosshair, PlusCircle, MinusCircle, SquareDashed } from 'lucide-react';
|
import { MousePointer2, PencilLine, Hexagon, Square, Circle, Minus, Combine, Scissors, Wand2, Undo, Redo, Crosshair, PlusCircle, MinusCircle, SquareDashed } from 'lucide-react';
|
||||||
import { cn } from '../lib/utils';
|
import { cn } from '../lib/utils';
|
||||||
|
|
||||||
interface ToolsPaletteProps {
|
interface ToolsPaletteProps {
|
||||||
@@ -23,6 +23,7 @@ export function ToolsPalette({
|
|||||||
}: ToolsPaletteProps) {
|
}: ToolsPaletteProps) {
|
||||||
const tools = [
|
const tools = [
|
||||||
{ id: 'move', icon: MousePointer2, label: '拖拽 / 选择 (V)' },
|
{ id: 'move', icon: MousePointer2, label: '拖拽 / 选择 (V)' },
|
||||||
|
{ id: 'edit_polygon', icon: PencilLine, label: '调整多边形 (E)' },
|
||||||
{ id: 'create_polygon', icon: Hexagon, label: '创建多边形 (P)' },
|
{ id: 'create_polygon', icon: Hexagon, label: '创建多边形 (P)' },
|
||||||
{ id: 'create_rectangle', icon: Square, label: '创建矩形 (R)' },
|
{ id: 'create_rectangle', icon: Square, label: '创建矩形 (R)' },
|
||||||
{ id: 'create_circle', icon: Circle, label: '创建圆 (O)' },
|
{ id: 'create_circle', icon: Circle, label: '创建圆 (O)' },
|
||||||
|
|||||||
@@ -82,23 +82,16 @@ describe('VideoWorkspace', () => {
|
|||||||
expect(apiMock.getProjectAnnotations).toHaveBeenCalledWith('1');
|
expect(apiMock.getProjectAnnotations).toHaveBeenCalledWith('1');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('triggers parsing when a media project has no frames yet', async () => {
|
it('does not auto-generate frames when a media project has no frames yet', async () => {
|
||||||
apiMock.getProjectFrames
|
apiMock.getProjectFrames.mockResolvedValueOnce([]);
|
||||||
.mockResolvedValueOnce([])
|
|
||||||
.mockResolvedValueOnce([
|
|
||||||
{ id: 11, project_id: 1, frame_index: 0, image_url: '/parsed.jpg', width: 320, height: 240 },
|
|
||||||
]);
|
|
||||||
apiMock.parseMedia.mockResolvedValueOnce({ id: 7, status: 'queued', progress: 0 });
|
|
||||||
apiMock.getTask.mockResolvedValueOnce({ id: 7, status: 'success', progress: 100, message: '解析完成' });
|
|
||||||
|
|
||||||
render(<VideoWorkspace />);
|
render(<VideoWorkspace />);
|
||||||
|
|
||||||
await waitFor(() => expect(apiMock.parseMedia).toHaveBeenCalledWith('1'));
|
await waitFor(() => expect(apiMock.getProjectFrames).toHaveBeenCalledWith('1'));
|
||||||
expect(apiMock.getTask).toHaveBeenCalledWith(7);
|
expect(apiMock.parseMedia).not.toHaveBeenCalled();
|
||||||
await waitFor(() => expect(useStore.getState().frames[0]).toEqual(expect.objectContaining({
|
expect(apiMock.getTask).not.toHaveBeenCalled();
|
||||||
id: '11',
|
expect(useStore.getState().frames).toEqual([]);
|
||||||
url: '/parsed.jpg',
|
expect(await screen.findByText('该项目已导入视频但尚未生成帧,请在项目库点击“生成帧”')).toBeInTheDocument();
|
||||||
})));
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('hydrates saved annotations after loading frames', async () => {
|
it('hydrates saved annotations after loading frames', async () => {
|
||||||
|
|||||||
@@ -8,10 +8,8 @@ import {
|
|||||||
exportMasks,
|
exportMasks,
|
||||||
getProjectAnnotations,
|
getProjectAnnotations,
|
||||||
getProjectFrames,
|
getProjectFrames,
|
||||||
getTask,
|
|
||||||
getTemplates,
|
getTemplates,
|
||||||
importGtMask,
|
importGtMask,
|
||||||
parseMedia,
|
|
||||||
propagateMasks,
|
propagateMasks,
|
||||||
saveAnnotation,
|
saveAnnotation,
|
||||||
updateAnnotation,
|
updateAnnotation,
|
||||||
@@ -23,10 +21,6 @@ import { FrameTimeline } from './FrameTimeline';
|
|||||||
import { ModelStatusBadge } from './ModelStatusBadge';
|
import { ModelStatusBadge } from './ModelStatusBadge';
|
||||||
import type { Frame } from '../store/useStore';
|
import type { Frame } from '../store/useStore';
|
||||||
|
|
||||||
function sleep(ms: number) {
|
|
||||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
|
||||||
}
|
|
||||||
|
|
||||||
export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) {
|
export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) {
|
||||||
const gtMaskInputRef = React.useRef<HTMLInputElement>(null);
|
const gtMaskInputRef = React.useRef<HTMLInputElement>(null);
|
||||||
const activeTool = useStore((state) => state.activeTool);
|
const activeTool = useStore((state) => state.activeTool);
|
||||||
@@ -72,64 +66,31 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
|||||||
const data = await getProjectFrames(String(currentProject.id));
|
const data = await getProjectFrames(String(currentProject.id));
|
||||||
if (cancelled) return;
|
if (cancelled) return;
|
||||||
|
|
||||||
if (data.length === 0 && currentProject.video_path) {
|
const mappedFrames = data.map((f) => ({
|
||||||
// No frames yet but video exists -> queue parsing and poll the task.
|
id: String(f.id),
|
||||||
try {
|
projectId: String(f.project_id),
|
||||||
const task = await parseMedia(String(currentProject.id));
|
index: f.frame_index,
|
||||||
if (cancelled) return;
|
url: f.image_url,
|
||||||
setStatusMessage(`解析任务已入队 #${task.id}`);
|
width: f.width ?? 0,
|
||||||
let completed = false;
|
height: f.height ?? 0,
|
||||||
for (let attempt = 0; attempt < 60; attempt += 1) {
|
timestampMs: f.timestamp_ms ?? undefined,
|
||||||
const freshTask = await getTask(task.id);
|
sourceFrameNumber: f.source_frame_number ?? undefined,
|
||||||
if (cancelled) return;
|
}));
|
||||||
setStatusMessage(freshTask.message || `解析进度 ${freshTask.progress}%`);
|
setFrames(mappedFrames);
|
||||||
if (freshTask.status === 'success') {
|
setCurrentFrame(0);
|
||||||
completed = true;
|
if (mappedFrames.length === 0) {
|
||||||
break;
|
setMasks([]);
|
||||||
}
|
if (currentProject.status === 'parsing') {
|
||||||
if (freshTask.status === 'failed') {
|
setStatusMessage('生成帧任务正在后台运行,可在 Dashboard 查看进度');
|
||||||
setStatusMessage(freshTask.error || '解析任务失败');
|
} else if (currentProject.video_path) {
|
||||||
return;
|
setStatusMessage('该项目已导入视频但尚未生成帧,请在项目库点击“生成帧”');
|
||||||
}
|
} else {
|
||||||
await sleep(2000);
|
setStatusMessage('当前项目没有可显示帧');
|
||||||
}
|
|
||||||
if (!completed) {
|
|
||||||
setStatusMessage('解析仍在后台运行,可稍后刷新工作区');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const fresh = await getProjectFrames(String(currentProject.id));
|
|
||||||
if (cancelled) return;
|
|
||||||
const mappedFrames = fresh.map((f) => ({
|
|
||||||
id: String(f.id),
|
|
||||||
projectId: String(f.project_id),
|
|
||||||
index: f.frame_index,
|
|
||||||
url: f.image_url,
|
|
||||||
width: f.width ?? 0,
|
|
||||||
height: f.height ?? 0,
|
|
||||||
timestampMs: f.timestamp_ms ?? undefined,
|
|
||||||
sourceFrameNumber: f.source_frame_number ?? undefined,
|
|
||||||
}));
|
|
||||||
setFrames(mappedFrames);
|
|
||||||
setCurrentFrame(0);
|
|
||||||
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
|
|
||||||
} catch (err) {
|
|
||||||
console.error('Parse failed:', err);
|
|
||||||
}
|
}
|
||||||
} else {
|
return;
|
||||||
const mappedFrames = data.map((f) => ({
|
|
||||||
id: String(f.id),
|
|
||||||
projectId: String(f.project_id),
|
|
||||||
index: f.frame_index,
|
|
||||||
url: f.image_url,
|
|
||||||
width: f.width ?? 0,
|
|
||||||
height: f.height ?? 0,
|
|
||||||
timestampMs: f.timestamp_ms ?? undefined,
|
|
||||||
sourceFrameNumber: f.source_frame_number ?? undefined,
|
|
||||||
}));
|
|
||||||
setFrames(mappedFrames);
|
|
||||||
setCurrentFrame(0);
|
|
||||||
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
|
|
||||||
}
|
}
|
||||||
|
setStatusMessage('');
|
||||||
|
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('Failed to load frames:', err);
|
console.error('Failed to load frames:', err);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,12 +2,14 @@ import { afterEach, describe, expect, it, vi } from 'vitest';
|
|||||||
|
|
||||||
describe('progress websocket client', () => {
|
describe('progress websocket client', () => {
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
|
vi.useRealTimers();
|
||||||
vi.restoreAllMocks();
|
vi.restoreAllMocks();
|
||||||
vi.resetModules();
|
vi.resetModules();
|
||||||
vi.unstubAllGlobals();
|
vi.unstubAllGlobals();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('connects using the configured URL and reports open state', async () => {
|
it('connects using the configured URL, reports open state, and sends heartbeat pings', async () => {
|
||||||
|
vi.useFakeTimers();
|
||||||
const instances: any[] = [];
|
const instances: any[] = [];
|
||||||
class FakeWebSocket {
|
class FakeWebSocket {
|
||||||
static CONNECTING = 0;
|
static CONNECTING = 0;
|
||||||
@@ -21,14 +23,26 @@ describe('progress websocket client', () => {
|
|||||||
instances.push(this);
|
instances.push(this);
|
||||||
}
|
}
|
||||||
close = vi.fn();
|
close = vi.fn();
|
||||||
|
send = vi.fn();
|
||||||
}
|
}
|
||||||
vi.stubGlobal('WebSocket', FakeWebSocket);
|
vi.stubGlobal('WebSocket', FakeWebSocket);
|
||||||
|
|
||||||
const { progressWS } = await import('./websocket');
|
const { progressWS } = await import('./websocket');
|
||||||
|
const statusCallback = vi.fn();
|
||||||
|
const unsubscribeStatus = progressWS.onStatus(statusCallback);
|
||||||
progressWS.connect();
|
progressWS.connect();
|
||||||
|
instances[0].onopen?.();
|
||||||
|
|
||||||
expect(instances[0].url).toContain('/ws/progress');
|
expect(instances[0].url).toContain('/ws/progress');
|
||||||
expect(progressWS.isConnected()).toBe(true);
|
expect(progressWS.isConnected()).toBe(true);
|
||||||
|
expect(statusCallback).toHaveBeenCalledWith('connected');
|
||||||
|
expect(instances[0].send).toHaveBeenCalledWith('ping');
|
||||||
|
|
||||||
|
vi.advanceTimersByTime(15000);
|
||||||
|
expect(instances[0].send).toHaveBeenCalledTimes(2);
|
||||||
|
|
||||||
|
unsubscribeStatus();
|
||||||
|
progressWS.disconnect();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('subscribes and unsubscribes progress callbacks', async () => {
|
it('subscribes and unsubscribes progress callbacks', async () => {
|
||||||
@@ -43,4 +57,41 @@ describe('progress websocket client', () => {
|
|||||||
expect(callback).toHaveBeenCalledTimes(1);
|
expect(callback).toHaveBeenCalledTimes(1);
|
||||||
expect(callback).toHaveBeenCalledWith({ type: 'status', message: 'ok' });
|
expect(callback).toHaveBeenCalledWith({ type: 'status', message: 'ok' });
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('notifies connection status changes and schedules reconnect on close', async () => {
|
||||||
|
vi.useFakeTimers();
|
||||||
|
const instances: any[] = [];
|
||||||
|
class FakeWebSocket {
|
||||||
|
static CONNECTING = 0;
|
||||||
|
static OPEN = 1;
|
||||||
|
readyState = FakeWebSocket.OPEN;
|
||||||
|
onopen?: () => void;
|
||||||
|
onmessage?: (event: MessageEvent) => void;
|
||||||
|
onclose?: () => void;
|
||||||
|
onerror?: () => void;
|
||||||
|
constructor(public url: string) {
|
||||||
|
instances.push(this);
|
||||||
|
}
|
||||||
|
close = vi.fn();
|
||||||
|
send = vi.fn();
|
||||||
|
}
|
||||||
|
vi.stubGlobal('WebSocket', FakeWebSocket);
|
||||||
|
|
||||||
|
const { progressWS } = await import('./websocket');
|
||||||
|
const statusCallback = vi.fn();
|
||||||
|
const unsubscribeStatus = progressWS.onStatus(statusCallback);
|
||||||
|
|
||||||
|
progressWS.connect();
|
||||||
|
instances[0].onopen?.();
|
||||||
|
instances[0].onclose?.();
|
||||||
|
|
||||||
|
expect(statusCallback).toHaveBeenCalledWith('disconnected');
|
||||||
|
expect(statusCallback).toHaveBeenCalledWith('reconnecting');
|
||||||
|
|
||||||
|
vi.advanceTimersByTime(3000);
|
||||||
|
expect(instances).toHaveLength(2);
|
||||||
|
|
||||||
|
unsubscribeStatus();
|
||||||
|
progressWS.disconnect();
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import { WS_PROGRESS_URL } from './config';
|
import { WS_PROGRESS_URL } from './config';
|
||||||
|
|
||||||
type ProgressCallback = (data: ProgressMessage) => void;
|
type ProgressCallback = (data: ProgressMessage) => void;
|
||||||
|
type ConnectionStatus = 'connecting' | 'connected' | 'reconnecting' | 'disconnected';
|
||||||
|
type StatusCallback = (status: ConnectionStatus) => void;
|
||||||
|
|
||||||
interface ProgressMessage {
|
interface ProgressMessage {
|
||||||
type: 'progress' | 'status' | 'error' | 'complete' | 'cancelled';
|
type: 'progress' | 'status' | 'error' | 'complete' | 'cancelled';
|
||||||
@@ -20,9 +22,12 @@ class ProgressWebSocket {
|
|||||||
private ws: WebSocket | null = null;
|
private ws: WebSocket | null = null;
|
||||||
private url: string;
|
private url: string;
|
||||||
private callbacks: Set<ProgressCallback> = new Set();
|
private callbacks: Set<ProgressCallback> = new Set();
|
||||||
|
private statusCallbacks: Set<StatusCallback> = new Set();
|
||||||
private reconnectTimer: ReturnType<typeof setTimeout> | null = null;
|
private reconnectTimer: ReturnType<typeof setTimeout> | null = null;
|
||||||
|
private heartbeatTimer: ReturnType<typeof setInterval> | null = null;
|
||||||
private reconnectInterval = 3000;
|
private reconnectInterval = 3000;
|
||||||
private maxReconnectInterval = 30000;
|
private maxReconnectInterval = 30000;
|
||||||
|
private heartbeatInterval = 15000;
|
||||||
private shouldReconnect = false;
|
private shouldReconnect = false;
|
||||||
private shouldCloseAfterOpen = false;
|
private shouldCloseAfterOpen = false;
|
||||||
private currentInterval = 3000;
|
private currentInterval = 3000;
|
||||||
@@ -38,6 +43,7 @@ class ProgressWebSocket {
|
|||||||
|
|
||||||
this.shouldReconnect = true;
|
this.shouldReconnect = true;
|
||||||
this.shouldCloseAfterOpen = false;
|
this.shouldCloseAfterOpen = false;
|
||||||
|
this.notifyStatus('connecting');
|
||||||
|
|
||||||
try {
|
try {
|
||||||
this.ws = new WebSocket(this.url);
|
this.ws = new WebSocket(this.url);
|
||||||
@@ -50,6 +56,8 @@ class ProgressWebSocket {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
this.currentInterval = this.reconnectInterval;
|
this.currentInterval = this.reconnectInterval;
|
||||||
|
this.startHeartbeat();
|
||||||
|
this.notifyStatus('connected');
|
||||||
console.log('[WebSocket] Connected to progress stream');
|
console.log('[WebSocket] Connected to progress stream');
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -64,7 +72,9 @@ class ProgressWebSocket {
|
|||||||
|
|
||||||
this.ws.onclose = () => {
|
this.ws.onclose = () => {
|
||||||
console.log('[WebSocket] Connection closed');
|
console.log('[WebSocket] Connection closed');
|
||||||
|
this.stopHeartbeat();
|
||||||
this.ws = null;
|
this.ws = null;
|
||||||
|
this.notifyStatus('disconnected');
|
||||||
if (this.shouldReconnect) {
|
if (this.shouldReconnect) {
|
||||||
this.scheduleReconnect();
|
this.scheduleReconnect();
|
||||||
}
|
}
|
||||||
@@ -72,7 +82,9 @@ class ProgressWebSocket {
|
|||||||
|
|
||||||
this.ws.onerror = () => {
|
this.ws.onerror = () => {
|
||||||
// 静默处理错误,避免在 CONNECTING 状态时 close 触发浏览器报错
|
// 静默处理错误,避免在 CONNECTING 状态时 close 触发浏览器报错
|
||||||
|
this.stopHeartbeat();
|
||||||
this.ws = null;
|
this.ws = null;
|
||||||
|
this.notifyStatus('disconnected');
|
||||||
if (this.shouldReconnect) {
|
if (this.shouldReconnect) {
|
||||||
this.scheduleReconnect();
|
this.scheduleReconnect();
|
||||||
}
|
}
|
||||||
@@ -85,6 +97,7 @@ class ProgressWebSocket {
|
|||||||
|
|
||||||
disconnect() {
|
disconnect() {
|
||||||
this.shouldReconnect = false;
|
this.shouldReconnect = false;
|
||||||
|
this.stopHeartbeat();
|
||||||
if (this.reconnectTimer) {
|
if (this.reconnectTimer) {
|
||||||
clearTimeout(this.reconnectTimer);
|
clearTimeout(this.reconnectTimer);
|
||||||
this.reconnectTimer = null;
|
this.reconnectTimer = null;
|
||||||
@@ -102,6 +115,7 @@ class ProgressWebSocket {
|
|||||||
this.ws.close();
|
this.ws.close();
|
||||||
}
|
}
|
||||||
this.ws = null;
|
this.ws = null;
|
||||||
|
this.notifyStatus('disconnected');
|
||||||
}
|
}
|
||||||
|
|
||||||
onProgress(callback: ProgressCallback) {
|
onProgress(callback: ProgressCallback) {
|
||||||
@@ -111,21 +125,53 @@ class ProgressWebSocket {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
onStatus(callback: StatusCallback) {
|
||||||
|
this.statusCallbacks.add(callback);
|
||||||
|
callback(this.isConnected() ? 'connected' : 'disconnected');
|
||||||
|
return () => {
|
||||||
|
this.statusCallbacks.delete(callback);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
private scheduleReconnect() {
|
private scheduleReconnect() {
|
||||||
if (this.reconnectTimer) {
|
if (this.reconnectTimer) {
|
||||||
clearTimeout(this.reconnectTimer);
|
clearTimeout(this.reconnectTimer);
|
||||||
}
|
}
|
||||||
|
this.notifyStatus('reconnecting');
|
||||||
this.reconnectTimer = setTimeout(() => {
|
this.reconnectTimer = setTimeout(() => {
|
||||||
console.log(`[WebSocket] Reconnecting in ${this.currentInterval}ms...`);
|
console.log('[WebSocket] Reconnecting to progress stream...');
|
||||||
this.connect();
|
this.connect();
|
||||||
this.currentInterval = Math.min(this.currentInterval * 1.5, this.maxReconnectInterval);
|
this.currentInterval = Math.min(this.currentInterval * 1.5, this.maxReconnectInterval);
|
||||||
}, this.currentInterval);
|
}, this.currentInterval);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private startHeartbeat() {
|
||||||
|
this.stopHeartbeat();
|
||||||
|
this.sendHeartbeat();
|
||||||
|
this.heartbeatTimer = setInterval(() => this.sendHeartbeat(), this.heartbeatInterval);
|
||||||
|
}
|
||||||
|
|
||||||
|
private stopHeartbeat() {
|
||||||
|
if (this.heartbeatTimer) {
|
||||||
|
clearInterval(this.heartbeatTimer);
|
||||||
|
this.heartbeatTimer = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private sendHeartbeat() {
|
||||||
|
if (this.ws?.readyState === WebSocket.OPEN) {
|
||||||
|
this.ws.send('ping');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private notifyStatus(status: ConnectionStatus) {
|
||||||
|
this.statusCallbacks.forEach((cb) => cb(status));
|
||||||
|
}
|
||||||
|
|
||||||
isConnected(): boolean {
|
isConnected(): boolean {
|
||||||
return this.ws !== null && this.ws.readyState === WebSocket.OPEN;
|
return this.ws !== null && this.ws.readyState === WebSocket.OPEN;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export const progressWS = new ProgressWebSocket();
|
export const progressWS = new ProgressWebSocket();
|
||||||
export type { ProgressMessage };
|
export type { ConnectionStatus, ProgressMessage };
|
||||||
|
|||||||
@@ -102,6 +102,15 @@ vi.mock('react-konva', () => ({
|
|||||||
props.onClick?.(konvaEvent);
|
props.onClick?.(konvaEvent);
|
||||||
if (konvaEvent.cancelBubble) event.stopPropagation();
|
if (konvaEvent.cancelBubble) event.stopPropagation();
|
||||||
}}
|
}}
|
||||||
|
onDoubleClick={(event) => {
|
||||||
|
const point = {
|
||||||
|
x: event.clientX || 120,
|
||||||
|
y: event.clientY || 80,
|
||||||
|
};
|
||||||
|
const konvaEvent = { ...makeStageEvent(point.x, point.y), cancelBubble: false };
|
||||||
|
props.onDblClick?.(konvaEvent);
|
||||||
|
if (konvaEvent.cancelBubble) event.stopPropagation();
|
||||||
|
}}
|
||||||
/>
|
/>
|
||||||
),
|
),
|
||||||
}));
|
}));
|
||||||
|
|||||||
Reference in New Issue
Block a user