From 8a9247075edbb60a262b7ac18d22b24e15e3b94b Mon Sep 17 00:00:00 2001 From: admin <572701190@qq.com> Date: Fri, 1 May 2026 21:50:17 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=8C=E5=96=84=20AI=20=E5=88=86?= =?UTF-8?q?=E5=89=B2=E4=B8=8E=E5=B7=A5=E4=BD=9C=E5=8C=BA=E6=A0=87=E6=B3=A8?= =?UTF-8?q?=E9=97=AD=E7=8E=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 功能增加: - 将视频导入和生成帧拆成两个明确动作,项目库生成帧时选择 FPS,工作区不再自动触发拆帧。 - 为工作区新增调整多边形工具,支持选中 mask、拖动顶点、边中点插点、双击边界按位置插点,并保留多 polygon 子区域编辑。 - 打通 AI 页 SAM2/SAM3 结果到工作区的联动,生成 mask 后自动选中,可在右侧分类树换标签,并推送到工作区继续编辑。 - 增强 Dashboard WebSocket 连接状态与心跳,使用真实 onopen/onclose/onerror 状态驱动前端显示。 - 完善 SAM3 external worker 适配,支持 box prompt、semantic 请求级阈值和 video tracker 路径。 bugfix: - 修复 SAM2 文本语义误走自动分割的问题,改为提示使用点提示或切换 SAM3。 - 修复 SAM2 多候选重叠显示的问题,点提示和 auto fallback 默认只采用最高分候选。 - 修复 SAM2 反向点看起来无效的问题,带负点时启用背景过滤,过滤为空时移除旧候选。 - 修复 SAM3 单个 2D mask 结果无法转 polygon、低阈值 semantic 返回被默认阈值吞掉的问题。 - 修复 AI 页 mask 未选中导致分类树无法修改 SAM2 结果标签的问题。 测试和文档: - 补充 CanvasArea、AISegmentation、ProjectLibrary、VideoWorkspace、Dashboard、websocket 和 SAM engine/API 测试。 - 新增 backend/tests/test_sam2_engine.py,覆盖 SAM2 单候选请求和 auto fallback 行为。 - 更新 README、AGENTS 和 doc 需求/设计/接口/测试矩阵,按当前实现冻结功能状态。 --- AGENTS.md | 15 +-- README.md | 10 +- backend/routers/ai.py | 23 ++++- backend/services/sam2_engine.py | 8 +- backend/services/sam3_engine.py | 30 +++++- backend/services/sam3_external_worker.py | 4 +- backend/services/sam_registry.py | 14 ++- backend/tests/test_ai.py | 25 ++++- backend/tests/test_sam2_engine.py | 63 ++++++++++++ backend/tests/test_sam3_engine.py | 50 ++++++++- doc/02-current-implementation-map.md | 6 +- doc/03-frontend-element-audit.md | 13 ++- doc/04-api-contracts.md | 25 ++++- doc/07-current-requirements-freeze.md | 18 +++- doc/08-current-design-freeze.md | 63 +++++++----- doc/09-test-plan.md | 24 +++-- src/components/AISegmentation.test.tsx | 125 ++++++++++++++++++++++- src/components/AISegmentation.tsx | 46 +++++++-- src/components/CanvasArea.test.tsx | 94 ++++++++++++++++- src/components/CanvasArea.tsx | 98 +++++++++++++++--- src/components/Dashboard.test.tsx | 21 ++++ src/components/Dashboard.tsx | 6 +- src/components/ProjectLibrary.test.tsx | 21 +++- src/components/ProjectLibrary.tsx | 109 +++++++++++++++----- src/components/ToolsPalette.test.tsx | 4 +- src/components/ToolsPalette.tsx | 3 +- src/components/VideoWorkspace.test.tsx | 21 ++-- src/components/VideoWorkspace.tsx | 85 +++++---------- src/lib/websocket.test.ts | 53 +++++++++- src/lib/websocket.ts | 50 ++++++++- src/test/setup.tsx | 9 ++ 31 files changed, 920 insertions(+), 216 deletions(-) create mode 100644 backend/tests/test_sam2_engine.py diff --git a/AGENTS.md b/AGENTS.md index e7e8cc7..22b336f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -6,7 +6,7 @@ ## 项目概述 -本项目是一个**语义分割系统**(Semantic Segmentation System),当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、服务器端拆帧、交互式 Canvas 标注、视频片段传播、GT mask 导入、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出。 +本项目是一个**语义分割系统**(Semantic Segmentation System),当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、显式视频生成帧、交互式 Canvas 标注、视频片段传播、GT mask 导入、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出。 - **项目名称**: `react-example`(`package.json` 中的 `name`) - **前端入口**: `src/main.tsx` → `src/App.tsx` @@ -219,12 +219,12 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload 1. 登录:`Login.tsx` 调用 `POST /api/auth/login`,默认开发凭证为 `admin / 123456`。 2. 项目管理:`ProjectLibrary.tsx` 调用项目 API 创建项目、拉取列表。 -3. 上传资源:视频走 `/api/media/upload`;DICOM 批量走 `/api/media/upload/dicom`。 -4. 拆帧入队:前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery,接口支持 `parse_fps`、`max_frames` 和 `target_width` 标准帧序列参数。 +3. 上传资源:视频走 `/api/media/upload`,只上传源文件并关联项目,不自动拆帧;DICOM 批量走 `/api/media/upload/dicom`。 +4. 生成帧入队:用户在项目库点击“生成帧”,选择目标 FPS 后前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery,接口支持 `parse_fps`、`max_frames` 和 `target_width` 标准帧序列参数。 5. worker 执行:Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallback,DICOM 使用 pydicom;视频帧按 `frame_%06d.jpg` 连续命名并记录 `timestamp_ms`、`source_frame_number` 和任务 `frame_sequence` 元数据。 6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames`,`CanvasArea.tsx` 和 `FrameTimeline.tsx` 显示当前帧与时间轴缩略图;前端 `Frame` 会保留后端返回的帧序列时间戳和源帧号。 -7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask;多边形可按 Enter 或点击首节点闭合;绘制工具可在已有 mask 上继续落点;点击 mask 可拖动/删除 polygon 顶点、通过边中点插入新顶点,并能选择编辑多 polygon mask 的单个子区域;选中整块 mask 可用 Delete/Backspace 删除,已保存 mask 会同步后端删除;区域合并/去除会隐藏编辑手柄并显示已选数量,使用 `polygon-clipping` 做 union/difference,内含去除结果用 even-odd 规则渲染 hole;Zustand 维护 `maskHistory/maskFuture` 支持撤销/重做。 -8. AI 分割:前端工具包括正向点、反向点和框选;SAM 2 框选会建立候选 mask,后续正/反点通过 `interactive` prompt 携带原始框和累计点细化同一个候选 mask;后端 `ai.py` 期望按 `image_id`、`prompt_type`、`prompt_data`、`model` 和可选 `options` 调用 SAM registry。SAM 2 支持点/框/interactive/自动分割和 video predictor 传播;`options.crop_to_prompt` 可对点/框/interactive prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果;SAM 3 入口支持文本语义推理、框选提示和 external video tracker,主后端会通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境,并优先使用 `sam3_checkpoint_path` 指向的本地 `sam3权重/sam3.pt`;如果 Python/CUDA/包/本地 checkpoint 均满足,会在状态接口中标为可用。 +7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask;多边形可按 Enter 或点击首节点闭合;绘制工具可在已有 mask 上继续落点;工具栏有“调整多边形”入口,点击 mask 可拖动/删除 polygon 顶点、通过边中点或双击边界插入新顶点,并能选择编辑多 polygon mask 的单个子区域;选中整块 mask 可用 Delete/Backspace 删除,已保存 mask 会同步后端删除;区域合并/去除会隐藏编辑手柄并显示已选数量,使用 `polygon-clipping` 做 union/difference,内含去除结果用 even-odd 规则渲染 hole;Zustand 维护 `maskHistory/maskFuture` 支持撤销/重做。 +8. AI 分割:前端工具包括正向点、反向点和框选;SAM 2 框选会建立候选 mask,后续正/反点通过 `interactive` prompt 携带原始框和累计点细化同一个候选 mask;包含反向点时工作区会传 `options.auto_filter_background=true` 和 `min_score=0.05`,如果后端过滤为空则移除旧候选 mask。后端 `ai.py` 期望按 `image_id`、`prompt_type`、`prompt_data`、`model` 和可选 `options` 调用 SAM registry。SAM 2 支持点/框/interactive/自动分割和 video predictor 传播,但不支持文本语义提示;AI 页面在 SAM 2 纯文本时提示改用点提示或切换 SAM 3,SAM 2 多候选默认只采用最高分区域,避免重叠候选同时显示;AI 页面生成的 mask 会写入全局 `masks` 并自动选中,右侧分类树可直接改标签,推送到工作区会切到“调整多边形”并保留选择。`options.crop_to_prompt` 可对点/框/interactive prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果;SAM 3 入口支持文本语义推理、框选提示和 external video tracker,semantic 请求会把正数 `options.min_score` 传给 external worker 作为置信度阈值,主后端会通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境,并优先使用 `sam3_checkpoint_path` 指向的本地 `sam3权重/sam3.pt`;如果 Python/CUDA/包/本地 checkpoint 均满足,会在状态接口中标为可用。 9. 视频片段传播:工作区“传播片段”把当前选中 mask 或当前帧第一个 mask 作为 seed,调用 `POST /api/ai/propagate`;后端按项目帧序列下载片段帧,SAM 2 用 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`,SAM 3 用独立 helper 的官方 `build_sam3_video_predictor()`,并把后续帧结果保存为 `Annotation`。 10. GT 导入:工作区“导入 GT Mask”调用 `/api/ai/import-gt-mask`;后端按非零像素值和连通域生成 polygon 标注,并用 distance transform 生成 seed point;前端回显 seed point,拖动后可归档更新。 11. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index;`OntologyInspector.tsx` 在工作区显示当前模板分类树。 @@ -237,7 +237,7 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload - `src/lib/config.ts` 会优先读取 `VITE_API_BASE_URL` 和 `VITE_WS_PROGRESS_URL`;未配置时按当前浏览器 hostname 推导后端 `:8000` 地址。 - 前端 `predictMask()` 已按后端 `PredictRequest` 发送 `image_id`、`prompt_type`、`prompt_data`、`model`,并将后端 `polygons` 转成 Konva 可渲染的 `pathData`。 - 手工绘制工具会生成可保存的 `Mask.segmentation`;撤销/重做通过 `maskHistory/maskFuture` 工作。 -- Polygon 顶点编辑会重算 `pathData/segmentation/bbox/area`;已保存 mask 进入 dirty 状态后复用归档 PATCH 链路。 +- Polygon 顶点编辑和新增顶点会重算 `pathData/segmentation/bbox/area`;已保存 mask 进入 dirty 状态后复用归档 PATCH 链路。 - 区域合并/去除会重算主 mask 的几何;合并已保存的次级 mask 时会通过工作区回调删除对应后端标注。 - 前端 `importGtMask()` 已对齐后端 `/api/ai/import-gt-mask`;工作区“导入 GT Mask”会导入后端生成的多类别标注和 seed point 并回显。 - 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;前端 `exportMasks()` 已对齐后端 `/api/export/{project_id}/masks`;工作区导出按钮会先保存当前待归档 mask。 @@ -245,8 +245,9 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload - 工作区“传播片段”按钮已接入 `POST /api/ai/propagate`;SAM 2 路径使用视频 predictor,SAM 3 路径使用独立 Python helper 的官方 video tracker,完成后刷新后端已保存标注。 - 工作区“清空遮罩”会调用 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。 - 项目状态已统一为 `pending`、`parsing`、`ready`、`error`;前端 `src/lib/api.ts` 会兼容归一化旧库中可能存在的 `Ready`、`Parsing`、`Error`。 +- 项目库的视频导入与生成帧是两个独立动作:导入视频只上传源文件,生成帧按钮才会带 `parse_fps` 调用 `/api/media/parse`;工作区不会再因“有视频但无帧”自动创建拆帧任务。 - `server.ts` 仍有旧版 `/api/login`、`/api/projects`、`/api/templates` mock;当前前端真实 API 调用主要走 FastAPI 的 `/api/auth/*`、`/api/projects`、`/api/templates` 等接口。 -- `Dashboard.tsx` 初始统计、队列和活动日志来自 `GET /api/dashboard/overview`;解析队列来自 `processing_tasks`,支持取消 queued/running 任务、重试 failed/cancelled 任务和查看失败详情。Celery worker 通过 Redis pub/sub 的 `seg:progress` 频道推送细粒度进度,再由 FastAPI 广播到 `/ws/progress`。 +- `Dashboard.tsx` 初始统计、队列和活动日志来自 `GET /api/dashboard/overview`;解析队列来自 `processing_tasks`,支持取消 queued/running 任务、重试 failed/cancelled 任务和查看失败详情。Celery worker 通过 Redis pub/sub 的 `seg:progress` 频道推送细粒度进度,再由 FastAPI 广播到 `/ws/progress`;前端 WebSocket 客户端通过 `onopen/onclose/onerror` 更新连接状态,并定时发送 `ping` 心跳。 --- diff --git a/README.md b/README.md index ce1f5a9..b768bb6 100644 --- a/README.md +++ b/README.md @@ -12,9 +12,9 @@ ## 核心功能 -- **多媒体资产管理** — 支持视频(MP4/AVI/MOV)和 DICOM 医学影像的上传、存储与解析 -- **AI 智能分割引擎** — 后端提供 SAM 2 / SAM 3 模型选择;SAM 2 支持点分割(point)、框分割(box)、自动分割(auto)和 video predictor 传播,SAM 3 入口支持文本语义提示、框选提示和 external video tracker,并按真实运行环境显示可用性 -- **交互式画布标注** — 基于 Konva 的高性能 Canvas,支持缩放/平移/手工多边形/矩形/圆/点/线、polygon 顶点拖动/删除、区域合并/去除、选点/框选、撤销/重做,实时渲染 Mask 遮罩 +- **多媒体资产管理** — 支持视频(MP4/AVI/MOV)和 DICOM 医学影像上传;视频导入与生成帧分离,生成帧时选择目标 FPS +- **AI 智能分割引擎** — 后端提供 SAM 2 / SAM 3 模型选择;SAM 2 支持点分割(point)、框分割(box)、自动分割(auto)和 video predictor 传播,前端默认只采用最高分候选避免重叠备选同时显示;SAM 3 入口支持文本语义提示、框选提示和 external video tracker,并按真实运行环境显示可用性 +- **交互式画布标注** — 基于 Konva 的高性能 Canvas,支持缩放/平移/手工多边形/矩形/圆/点/线、polygon 顶点拖动/删除、边中点插点、双击边界插点、区域合并/去除、选点/框选、撤销/重做,实时渲染 Mask 遮罩 - **GT Mask 导入** — 工作区可导入 GT mask 图片,后端按非零像素值和连通域生成 polygon 标注并用 distance transform 写入 seed point;前端可回显和拖动 seed point - **本体字典管理** — 可配置的分类体系、颜色映射、图层优先级(z-index) - **项目工作区** — 项目创建、帧浏览、多图层标注、进度追踪 @@ -325,7 +325,7 @@ celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 & ``` -`POST /api/media/parse` 只创建 `processing_tasks` 记录并把任务投递给 Celery;真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。接口支持 `parse_fps`、`max_frames` 和 `target_width`,用于生成后续 SAM 2 / SAM 3 视频处理可复用的标准帧序列;视频帧按 `frame_%06d.jpg` 连续命名,帧表会记录 `timestamp_ms` 和 `source_frame_number`,任务完成结果会返回 `frame_sequence` 元数据。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道,FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 也可调用 `/api/tasks/{id}/cancel`、`/api/tasks/{id}/retry` 和 `/api/tasks/{id}` 完成任务取消、重试与失败详情查看。 +视频导入只创建项目并把源视频保存到 MinIO,不会自动拆帧;用户在项目库点击“生成帧”后,再选择目标 FPS 并调用 `POST /api/media/parse`。该接口只创建 `processing_tasks` 记录并把任务投递给 Celery;真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。接口支持 `parse_fps`、`max_frames` 和 `target_width`,用于生成后续 SAM 2 / SAM 3 视频处理可复用的标准帧序列;视频帧按 `frame_%06d.jpg` 连续命名,帧表会记录 `timestamp_ms` 和 `source_frame_number`,任务完成结果会返回 `frame_sequence` 元数据。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道,FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 的 WebSocket 状态由浏览器 `onopen/onclose/onerror` 驱动,客户端会定时发送 `ping` 心跳,服务端返回 `status` 确认连接。Dashboard 也可调用 `/api/tasks/{id}/cancel`、`/api/tasks/{id}/retry` 和 `/api/tasks/{id}` 完成任务取消、重试与失败详情查看。 ### 步骤 7: 安装前端依赖并构建 @@ -461,6 +461,8 @@ pip install -e . --no-build-isolation - 前端 `predictMask()` 已发送后端需要的 `image_id`、`prompt_type`、`prompt_data`,并把后端 `polygons` 转成 Konva `pathData`。 - 工作区点选/框选会使用当前帧的数据库 `frame.id` 调用 `/api/ai/predict`。 +- 工作区 SAM 2 交互式细化包含反向点时会启用后端背景过滤;若反向点排除了当前候选区域并返回空结果,前端会移除旧候选 mask。 +- AI 页面生成的 SAM 2/SAM 3 mask 会写入全局 `masks` 并自动选中;右侧分类树可直接给生成结果换标签,“推送至工作区编辑”会切回工作区的多边形调整工具并保留选择。 - 工作区“传播片段”会使用当前选中区域或当前帧第一个区域作为 seed,调用 `/api/ai/propagate`,并在完成后刷新已保存标注。 - 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco`。 - 工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮已绑定下载流程;导出前会先保存当前待归档的前端 mask。 diff --git a/backend/routers/ai.py b/backend/routers/ai.py index e7d51ae..013956d 100644 --- a/backend/routers/ai.py +++ b/backend/routers/ai.py @@ -340,7 +340,21 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict: elif prompt_type == "semantic": text = payload.prompt_data if isinstance(payload.prompt_data, str) else "" - polygons, scores = sam_registry.predict_semantic(payload.model, image, text) + min_score = options.get("min_score") + confidence_threshold = None + if min_score is not None: + try: + parsed_min_score = float(min_score) + if parsed_min_score > 0: + confidence_threshold = parsed_min_score + except (TypeError, ValueError): + confidence_threshold = None + polygons, scores = sam_registry.predict_semantic( + payload.model, + image, + text, + confidence_threshold=confidence_threshold, + ) else: raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}") @@ -352,6 +366,13 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict: raise HTTPException(status_code=400, detail=str(exc)) from exc polygons, scores = _filter_predictions(polygons, scores, options, negative_points) + logger.info( + "AI predict completed model=%s prompt_type=%s frame_id=%s polygons=%d", + payload.model or "default", + prompt_type, + payload.image_id, + len(polygons), + ) return {"polygons": polygons, "scores": scores} diff --git a/backend/services/sam2_engine.py b/backend/services/sam2_engine.py index a527b6e..3592065 100644 --- a/backend/services/sam2_engine.py +++ b/backend/services/sam2_engine.py @@ -207,7 +207,7 @@ class SAM2Engine: masks, scores, _ = self._predictor.predict( point_coords=pts, point_labels=lbls, - multimask_output=True, + multimask_output=False, ) polygons = [] @@ -335,16 +335,16 @@ class SAM2Engine: masks, scores, _ = self._predictor.predict( point_coords=pts, point_labels=lbls, - multimask_output=True, + multimask_output=False, ) polygons = [] - for m in masks[:3]: # Limit to top 3 masks + for m in masks[:1]: poly = self._mask_to_polygon(m) if poly: polygons.append(poly) - return polygons, scores[:3].tolist() + return polygons, scores[:1].tolist() except Exception as exc: # noqa: BLE001 logger.error("SAM2 auto prediction failed: %s", exc) return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5] diff --git a/backend/services/sam3_engine.py b/backend/services/sam3_engine.py index 25d7741..da90f4c 100644 --- a/backend/services/sam3_engine.py +++ b/backend/services/sam3_engine.py @@ -260,6 +260,7 @@ class SAM3Engine: *, text: str = "", box: list[float] | None = None, + confidence_threshold: float | None = None, ) -> tuple[list[list[list[float]]], list[float]]: status = self._external_status(force=True) if not status.get("available"): @@ -279,7 +280,11 @@ class SAM3Engine: "box": box, "model_version": settings.sam3_model_version, "checkpoint_path": self._checkpoint_path(), - "confidence_threshold": settings.sam3_confidence_threshold, + "confidence_threshold": ( + confidence_threshold + if confidence_threshold is not None + else settings.sam3_confidence_threshold + ), }, ensure_ascii=False, ), @@ -312,8 +317,18 @@ class SAM3Engine: raise RuntimeError(str(payload["error"])) return payload.get("polygons", []), payload.get("scores", []) - def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]: - return self._predict_external(image, "semantic", text=text) + def _predict_semantic_external( + self, + image: np.ndarray, + text: str, + confidence_threshold: float | None = None, + ) -> tuple[list[list[list[float]]], list[float]]: + return self._predict_external( + image, + "semantic", + text=text, + confidence_threshold=confidence_threshold, + ) def _predict_box_external(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]: return self._predict_external(image, "box", box=box) @@ -378,11 +393,16 @@ class SAM3Engine: raise RuntimeError(str(payload["error"])) return payload.get("frames", []) - def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]: + def predict_semantic( + self, + image: np.ndarray, + text: str, + confidence_threshold: float | None = None, + ) -> tuple[list[list[list[float]]], list[float]]: if not text.strip(): raise ValueError("SAM 3 semantic prompt requires non-empty text.") if not self._can_load() and self._external_status().get("available"): - return self._predict_semantic_external(image, text) + return self._predict_semantic_external(image, text, confidence_threshold=confidence_threshold) if not self._ensure_ready(): raise RuntimeError(self.status()["message"]) diff --git a/backend/services/sam3_external_worker.py b/backend/services/sam3_external_worker.py index 9e3a64d..ffceb07 100644 --- a/backend/services/sam3_external_worker.py +++ b/backend/services/sam3_external_worker.py @@ -190,7 +190,9 @@ def _video_outputs_to_response(outputs: dict[str, Any]) -> dict[str, Any]: def _prediction_to_response(output: dict[str, Any]) -> dict[str, Any]: masks = _to_numpy(output.get("masks", [])) scores = _to_numpy(output.get("scores", [])) - if masks.ndim == 4: + if masks.ndim == 2: + masks = masks[None, :, :] + elif masks.ndim == 4: masks = masks[:, 0] elif masks.ndim == 3 and masks.shape[0] == 1: masks = masks[None, 0] diff --git a/backend/services/sam_registry.py b/backend/services/sam_registry.py index 57e16ea..8d09591 100644 --- a/backend/services/sam_registry.py +++ b/backend/services/sam_registry.py @@ -83,10 +83,20 @@ class SAMRegistry: def predict_auto(self, model_id: str | None, image: Any): return self._ensure_available(model_id).predict_auto(image) - def predict_semantic(self, model_id: str | None, image: Any, text: str): + def predict_semantic( + self, + model_id: str | None, + image: Any, + text: str, + confidence_threshold: float | None = None, + ): model = self.normalize_model_id(model_id) if model == "sam3": - return self._ensure_available(model).predict_semantic(image, text) + return self._ensure_available(model).predict_semantic( + image, + text, + confidence_threshold=confidence_threshold, + ) return self._ensure_available(model).predict_auto(image) def propagate_video( diff --git a/backend/tests/test_ai.py b/backend/tests/test_ai.py index 5a02d15..45444cb 100644 --- a/backend/tests/test_ai.py +++ b/backend/tests/test_ai.py @@ -89,15 +89,25 @@ def test_predict_applies_crop_and_background_filter_options(client, monkeypatch) def test_predict_box_and_semantic_fallback(client, monkeypatch): _, frame, _ = _create_project_and_frame(client) + calls = {} monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8)) monkeypatch.setattr("routers.ai.sam_registry.predict_box", lambda model, image, box: ( [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]], [0.8], )) - monkeypatch.setattr("routers.ai.sam_registry.predict_semantic", lambda model, image, text: ( - [[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]], - [0.5], - )) + + def fake_predict_semantic(model, image, text, confidence_threshold=None): + calls["semantic"] = { + "model": model, + "text": text, + "confidence_threshold": confidence_threshold, + } + return ( + [[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]], + [0.5], + ) + + monkeypatch.setattr("routers.ai.sam_registry.predict_semantic", fake_predict_semantic) box_response = client.post("/api/ai/predict", json={ "image_id": frame["id"], @@ -108,12 +118,19 @@ def test_predict_box_and_semantic_fallback(client, monkeypatch): "image_id": frame["id"], "prompt_type": "semantic", "prompt_data": "胆囊", + "model": "sam3", + "options": {"min_score": 0.05}, }) assert box_response.status_code == 200 assert box_response.json()["scores"] == [0.8] assert semantic_response.status_code == 200 assert semantic_response.json()["scores"] == [0.5] + assert calls["semantic"] == { + "model": "sam3", + "text": "胆囊", + "confidence_threshold": 0.05, + } def test_predict_interactive_combines_box_and_points(client, monkeypatch): diff --git a/backend/tests/test_sam2_engine.py b/backend/tests/test_sam2_engine.py new file mode 100644 index 0000000..bcea310 --- /dev/null +++ b/backend/tests/test_sam2_engine.py @@ -0,0 +1,63 @@ +import numpy as np + +from services.sam2_engine import SAM2Engine + + +class _FakePredictor: + def __init__(self, masks, scores): + self.masks = masks + self.scores = scores + self.calls = [] + + def set_image(self, _image): + pass + + def predict(self, **kwargs): + self.calls.append(kwargs) + return self.masks, self.scores, None + + +def _mask(offset=0): + mask = np.zeros((32, 32), dtype=np.uint8) + mask[4 + offset:20 + offset, 5 + offset:22 + offset] = 1 + return mask + + +def _ready_engine(monkeypatch, predictor): + monkeypatch.setattr("services.sam2_engine.SAM2_AVAILABLE", True) + engine = SAM2Engine() + engine._model_loaded = True + engine._predictor = predictor + return engine + + +def test_sam2_point_prediction_requests_single_best_mask(monkeypatch): + predictor = _FakePredictor( + np.array([_mask()], dtype=np.uint8), + np.array([0.92], dtype=np.float32), + ) + engine = _ready_engine(monkeypatch, predictor) + + polygons, scores = engine.predict_points( + np.zeros((32, 32, 3), dtype=np.uint8), + [[0.5, 0.5]], + [1], + ) + + assert predictor.calls[0]["multimask_output"] is False + assert len(polygons) == 1 + assert scores == [0.9200000166893005] + + +def test_sam2_auto_prediction_keeps_single_best_mask(monkeypatch): + predictor = _FakePredictor( + np.array([_mask(0), _mask(2), _mask(4)], dtype=np.uint8), + np.array([0.8, 0.7, 0.6], dtype=np.float32), + ) + engine = _ready_engine(monkeypatch, predictor) + + polygons, scores = engine.predict_auto(np.zeros((32, 32, 3), dtype=np.uint8)) + + assert predictor.calls[0]["multimask_output"] is False + assert len(polygons) == 1 + assert scores == [0.800000011920929] diff --git a/backend/tests/test_sam3_engine.py b/backend/tests/test_sam3_engine.py index 3ea1303..f5296fe 100644 --- a/backend/tests/test_sam3_engine.py +++ b/backend/tests/test_sam3_engine.py @@ -4,7 +4,7 @@ from pathlib import Path import numpy as np from services.sam3_engine import SAM3Engine -from services.sam3_external_worker import _to_numpy +from services.sam3_external_worker import _prediction_to_response, _to_numpy class _Completed: @@ -98,6 +98,41 @@ def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch): assert any("--request" in args for args in calls) +def test_sam3_predict_semantic_allows_request_threshold_override(tmp_path, monkeypatch): + _external_settings(monkeypatch, tmp_path / "python") + + def fake_run(args, **_kwargs): + if "--status" in args: + return _Completed(stdout=json.dumps({ + "available": True, + "package_available": True, + "checkpoint_access": True, + "python_ok": True, + "torch_ok": True, + "cuda_available": True, + "device": "cuda", + "message": "ready", + })) + request_path = Path(args[-1]) + request = json.loads(request_path.read_text(encoding="utf-8")) + assert request["confidence_threshold"] == 0.05 + return _Completed(stdout=json.dumps({ + "polygons": [[[0.2, 0.2], [0.6, 0.2], [0.6, 0.6]]], + "scores": [0.07], + })) + + monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run) + + polygons, scores = SAM3Engine().predict_semantic( + np.zeros((8, 8, 3), dtype=np.uint8), + "surgical scene", + confidence_threshold=0.05, + ) + + assert len(polygons) == 1 + assert scores == [0.07] + + def test_sam3_predict_box_uses_external_worker(tmp_path, monkeypatch): _external_settings(monkeypatch, tmp_path / "python") @@ -243,3 +278,16 @@ def test_sam3_worker_casts_floating_tensors_before_numpy(): assert tensor.float_called is True assert result.tolist() == [1.0] + + +def test_sam3_worker_converts_single_2d_mask_to_polygon(): + mask = np.zeros((12, 12), dtype=np.uint8) + mask[2:10, 3:9] = 1 + + result = _prediction_to_response({ + "masks": mask, + "scores": np.array([0.82], dtype=np.float32), + }) + + assert len(result["polygons"]) == 1 + assert result["scores"] == [0.8199999928474426] diff --git a/doc/02-current-implementation-map.md b/doc/02-current-implementation-map.md index ecc453c..bd4246f 100644 --- a/doc/02-current-implementation-map.md +++ b/doc/02-current-implementation-map.md @@ -65,12 +65,12 @@ ### 项目与拆帧 1. `ProjectLibrary.tsx` 调用 `getProjects()` 获取项目。 -2. 上传视频时先 `createProject()`,再 `uploadMedia()`,再 `parseMedia()`。 +2. 上传视频时先 `createProject()`,再 `uploadMedia()`;导入视频不自动调用 `parseMedia()`。 3. 后端 `media.py` 把原始文件上传到 MinIO。 -4. `parseMedia()` 创建 `processing_tasks` 记录并投递 Celery worker。 +4. 用户在项目库点击“生成帧”并选择 FPS 后,`parseMedia()` 创建 `processing_tasks` 记录并投递 Celery worker。 5. Celery worker 下载 MinIO 文件,调用 `frame_parser.py` 拆帧。 6. worker 把拆出的帧重新上传 MinIO,写入 `frames` 表,并更新任务状态。 -7. 工作区通过 `GET /api/tasks/{id}` 等待任务完成,再通过 `GET /api/projects/{id}/frames` 获取预签名图片 URL。 +7. 工作区只通过 `GET /api/projects/{id}/frames` 获取预签名图片 URL;若项目有源视频但无帧,会提示先回项目库生成帧。 8. Dashboard 可通过 `POST /api/tasks/{id}/cancel` 取消 queued/running 任务,通过 `POST /api/tasks/{id}/retry` 重试 failed/cancelled 任务,并用 `GET /api/tasks/{id}` 查看失败详情。 ### 工作区浏览 diff --git a/doc/03-frontend-element-audit.md b/doc/03-frontend-element-audit.md index edbdcac..b1a11d4 100644 --- a/doc/03-frontend-element-audit.md +++ b/doc/03-frontend-element-audit.md @@ -46,8 +46,9 @@ | 项目卡片缩略图 | 真实可用 | 后端返回 MinIO 预签名 `thumbnail_url` 时显示 | | 点击项目进入工作区 | 真实可用 | 设置 `currentProject` 后切到 `workspace` | | 新建项目 | 真实可用 | 调用 `POST /api/projects` | -| 导入视频文件 | 真实可用 | 创建项目、上传文件、触发拆帧、刷新项目列表 | -| 解析 FPS 滑块 | 真实可用 | 值传入 `createProject({ parse_fps })` | +| 导入视频文件 | 真实可用 | 创建项目、上传源视频、刷新项目列表;不会自动拆帧 | +| 生成帧按钮 | 真实可用 | 仅对已导入源视频且尚无帧、非 parsing 状态的项目显示,调用 `parseMedia(projectId, { parseFps })` | +| 生成帧 FPS 滑块 | 真实可用 | 值传入 `/api/media/parse?parse_fps=...`,决定后台拆帧目标 FPS | | 导入 DICOM 序列 | 部分可用 | 可上传 `.dcm` 并触发解析;体验和错误反馈较粗 | | 项目状态徽标 | 真实可用 | 项目状态统一为 `pending/parsing/ready/error`,前端兼容归一化旧状态值 | | 更多按钮 | Mock / UI-only | 有图标,没有菜单或事件 | @@ -59,7 +60,7 @@ |------|------|------| | 当前项目名 | 真实可用 | 读取 `currentProject.name` | | 自动加载项目帧 | 真实可用 | 调用 `GET /api/projects/{id}/frames` | -| 无帧时触发解析 | 真实可用 | 如果 `video_path` 存在会调用 `parseMedia()` 创建异步任务,并轮询 `GET /api/tasks/{id}` 等待完成 | +| 无帧项目提示 | 真实可用 | 如果 `video_path` 存在但无帧,只提示回到项目库生成帧,不自动创建拆帧任务 | | SAM 模型状态徽标 | 真实可用 | 调用 `GET /api/ai/models/status`,显示当前选择的 SAM 2/SAM 3 是否可用 | | 已保存标注回显 | 真实可用 | 加载工作区帧后调用 `GET /api/ai/annotations` 并渲染已保存 mask | | “导出 JSON 标注集”按钮 | 真实可用 | 导出前会保存未归档 mask,然后调用 `exportCoco()` 下载 JSON | @@ -93,6 +94,7 @@ | 元素 | 状态 | 说明 | |------|------|------| | 拖拽/选择 | 真实可用 | 控制 Canvas 是否可拖拽 | +| 调整多边形 | 真实可用 | 选中 polygon mask 后显示顶点和边中点;支持拖动顶点、点击边中点插点、双击边界按位置插点 | | 多边形/矩形/圆/点/线 | 真实可用 | 切换 activeTool 后由 `CanvasArea` 生成可保存的 polygon mask | | 区域合并/去除 | 真实可用 | 选择工具后点击多个 mask,右下角显示已选数量和操作按钮;合并/去除模式会隐藏 polygon 编辑手柄,避免手柄抢占多选点击;使用 `polygon-clipping` 做 union / difference;合并会保留主 mask 并移除被合并 mask,去除会从主 mask 扣除后续选中 mask;内含扣除会保留 hole ring 并用 even-odd 规则渲染 | | 正向选点/反向选点/框选 | 部分可用 | 会影响 Canvas 交互,并能触发已对齐的 AI 推理接口 | @@ -130,7 +132,8 @@ | SAM 3 框选 | 真实可用 | 工作区选择 SAM 3 后可使用框选工具;后端通过官方 `add_geometric_prompt()` 正框执行 SAM 3 几何提示推理 | | 语义文本输入 | 部分可用 | 纯文本会以 `semantic` prompt 调用后端;选择 SAM 3 且独立 Python 3.12 环境、CUDA、官方包和本地 checkpoint 均满足时走 SAM 3 文本语义推理,否则状态接口会标明不可用;空文本、失败和 0 mask 返回会显示前端反馈 | | 参数开关 | 真实可用 | `cropMode` 会随 `/api/ai/predict` 发送 `crop_to_prompt`,后端对点/框 prompt 裁剪推理区域并回映射 polygon;`autoDeleteBg` 会发送 `auto_filter_background` 和 `min_score`,后端过滤低分结果和覆盖负向点的结果 | -| 执行高精度语义分割 | 部分可用 | 使用当前项目帧调用 `/api/ai/predict`;没有当前帧时按钮禁用 | +| 执行高精度语义分割 | 真实可用 | 使用当前项目帧调用 `/api/ai/predict`;SAM 2 需要点提示且只采用最高分候选,SAM 3 需要文本语义提示;生成结果写入全局 masks 并自动选中,右侧分类树可立即换标签 | +| 推送至工作区编辑 | 真实可用 | 切回工作区并把工具切到“调整多边形”,保留 AI 页选中的 mask,便于继续调轮廓和归档 | | 上传替换底图 | Mock / UI-only | 按钮无事件 | | 撤销/重做 | 真实可用 | 绑定全局 mask 历史栈 | | 清空全体锚点 | 部分可用 | 清空前端 points 和 masks | @@ -153,6 +156,6 @@ ## 总体结论 -当前前端真实可用的主链路是:登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、拆帧、浏览帧、播放帧、工作区手工绘制、点/框 AI 推理、视频片段传播、GT mask 导入、标注保存/回显、COCO 导出、PNG mask ZIP 导出、模板 CRUD。 +当前前端真实可用的主链路是:登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、显式生成帧、浏览帧、播放帧、工作区手工绘制、点/框 AI 推理、视频片段传播、GT mask 导入、标注保存/回显、COCO 导出、PNG mask ZIP 导出、模板 CRUD。 当前最主要的 Mock 或未打通链路是:polygon 插点/边编辑增强、真正的文本语义分割、骨架/HDBSCAN 级别的 mask 降维增强、任务历史筛选、项目更多菜单和若干检查面板指标。 diff --git a/doc/04-api-contracts.md b/doc/04-api-contracts.md index a45e853..8373b0b 100644 --- a/doc/04-api-contracts.md +++ b/doc/04-api-contracts.md @@ -32,7 +32,7 @@ Authorization: Bearer | `deleteTemplate(id)` | `DELETE /api/templates/{id}` | 对齐 | 模板编辑页使用 | | `uploadMedia(file, projectId)` | `POST /api/media/upload` | 对齐 | multipart form-data | | `uploadDicomBatch(files, projectId)` | `POST /api/media/upload/dicom` | 对齐 | multipart form-data | -| `parseMedia(projectId, options?)` | `POST /api/media/parse?project_id=...` | 对齐 | 创建异步拆帧任务并返回 task;支持 `parse_fps`、`max_frames`、`target_width` | +| `parseMedia(projectId, options?)` | `POST /api/media/parse?project_id=...` | 对齐 | 创建异步拆帧任务并返回 task;由项目库“生成帧”显式调用,支持 `parse_fps`、`max_frames`、`target_width` | | `getTask(taskId)` | `GET /api/tasks/{task_id}` | 对齐 | 查询异步任务状态 | | `cancelTask(taskId)` | `POST /api/tasks/{task_id}/cancel` | 对齐 | 取消 queued/running 任务,后端写 cancelled 并尝试 revoke Celery | | `retryTask(taskId)` | `POST /api/tasks/{task_id}/retry` | 对齐 | 对 failed/cancelled 任务创建新的 queued 重试任务 | @@ -91,6 +91,21 @@ Authorization: Bearer | GET | `/health` | 健康检查 | | WS | `/ws/progress` | WebSocket 进度通道,未出现在 OpenAPI paths 中 | +### WebSocket 进度通道 + +`/ws/progress` 用于 Dashboard 实时接收后台任务状态。前端连接成功后会定时发送 `ping` 作为心跳;后端收到任意文本心跳后返回: + +```json +{ + "type": "status", + "status": "connected", + "message": "Progress stream active", + "timestamp": "2026-05-01T00:00:00+00:00" +} +``` + +后台任务进度由 Celery worker 写入 Redis `seg:progress` 频道,再由 FastAPI 转发到当前活跃 WebSocket 连接。Dashboard 的“WebSocket 已连接/断开”状态来自浏览器 WebSocket 的 `onopen/onclose/onerror`,不再依赖是否刚好收到任务进度事件。 + ## 关键请求体 ### 登录 @@ -172,7 +187,11 @@ POST /api/media/parse?project_id=1&parse_fps=15&max_frames=120&target_width=960 - `point` - `box` - `interactive`,用于 SAM 2 交互式细化,`prompt_data` 同时携带 `box`、累计 `points` 和 `labels`。 -- `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理;选择 `sam2` 时仍回退到 auto segmentation。SAM 3 真实可用性由 `/api/ai/models/status` 中的外部环境和本地 checkpoint 状态决定。 +- `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理。前端 AI 页面不会再用 SAM 2 发送纯文本 semantic;SAM 2 的交互入口应使用点/框提示。SAM 3 真实可用性由 `/api/ai/models/status` 中的外部环境和本地 checkpoint 状态决定。 + +SAM 2 点提示和 auto fallback 当前只采用最高分候选 mask,避免同一提示下多个备选 mask 被前端叠加显示。 + +工作区 SAM 2 请求包含反向点时,`CanvasArea` 会发送 `options.auto_filter_background=true` 和 `options.min_score=0.05`;如果负向点过滤后没有可用 polygon,前端会移除当前旧候选 mask 并要求重新框选或添加正向点。 选择 `sam3` 且发送 `box` 时,前端仍传 normalized `[x1, y1, x2, y2]`,后端适配层会转换成官方几何 prompt 的 `[center_x, center_y, width, height]` 正框;当前 SAM 3 不接正/反点修正。 @@ -180,7 +199,7 @@ POST /api/media/parse?project_id=1&parse_fps=15&max_frames=120&target_width=960 - `crop_to_prompt`:对 point/box/interactive prompt 按锚点或框附近区域裁剪后推理,再把 polygon 回映射到原图坐标。 - `auto_filter_background`:过滤低分结果,并移除包含负向点的 polygon。 -- `min_score`:配合 `auto_filter_background` 使用的最低置信度阈值。 +- `min_score`:配合 `auto_filter_background` 使用的最低置信度阈值;对 SAM 3 semantic 请求也会作为 external worker 的 `confidence_threshold` 传入,避免本地 checkpoint 在默认高阈值下返回 0 个 mask。 后端响应: diff --git a/doc/07-current-requirements-freeze.md b/doc/07-current-requirements-freeze.md index b2ed627..6325be6 100644 --- a/doc/07-current-requirements-freeze.md +++ b/doc/07-current-requirements-freeze.md @@ -17,7 +17,8 @@ - 前端展示项目库,并从 `GET /api/projects` 获取项目列表。 - 用户可以新建项目,前端调用 `POST /api/projects`。 - 用户可以选择项目,进入工作区。 -- 用户可以导入视频文件,前端创建项目、上传文件、触发拆帧、刷新项目列表。 +- 用户可以导入视频文件,前端创建项目、上传文件并刷新项目列表;导入视频不自动拆帧。 +- 用户可以对已导入且尚未生成帧的视频项目点击“生成帧”,在弹窗中选择目标 FPS 后创建拆帧任务。 - 用户可以导入 DICOM 序列,前端上传 DICOM、触发拆帧、刷新项目列表。 - 后端支持项目创建、列表、详情、局部更新和删除。 - 后端支持项目帧创建、列表和单帧查询。 @@ -42,7 +43,7 @@ ## R4 工作区与帧浏览 - 工作区根据当前项目加载帧列表。 -- 若项目有媒体但无帧,工作区会尝试触发拆帧后重新加载。 +- 若项目有媒体但无帧,工作区只提示需要先在项目库生成帧,不再自动触发拆帧。 - Canvas 显示当前帧图片。 - Canvas 支持滚轮缩放、移动工具拖拽、鼠标坐标显示。 - 时间轴支持缩略图点击切帧、range 拖动切帧、键盘左右方向键切帧、播放/暂停顺序推进帧。 @@ -57,7 +58,9 @@ - 多边形、矩形、圆、点、线工具会在 Canvas 上生成可保存的 polygon mask。 - 多边形通过点击取点并按 Enter 完成,也支持三点后点击首节点闭合;矩形、圆、线通过拖拽生成;点工具生成小点区域。 - 绘制工具点击已有 mask 时应继续执行当前绘制动作,不应被 mask 选择逻辑吞掉。 -- Canvas 支持点击 mask 进入 polygon 顶点编辑态;拖动顶点会更新 mask 几何并把已保存 mask 标记为 dirty。 +- 工具栏提供“调整多边形”工具,用户可以点击 mask 进入 polygon 顶点编辑态;拖动顶点会更新 mask 几何并把已保存 mask 标记为 dirty。 +- 顶点编辑态显示边中点插入手柄;点击边中点会在该边中间新增顶点。 +- “调整多边形”工具下双击 polygon 边界时,会在最接近的线段上按双击位置新增顶点。 - 顶点编辑态下选中顶点后可用 Delete/Backspace 删除顶点,但不会让 polygon 少于三点。 - 选中整个 mask 且未选中具体顶点时,Delete/Backspace 删除该 mask;已保存 mask 同步调用后端删除接口。 - 撤销、重做绑定全局 `maskHistory/maskFuture`,支持工具栏按钮、AI 页按钮和 Canvas 快捷键。 @@ -75,14 +78,19 @@ - 点提示传 `{ points, labels }`,正向点 label 为 1,反向点 label 为 0。 - 框选提示传归一化 `[x1, y1, x2, y2]`。 - 工作区 SAM 2 框选会建立一个候选 mask;后续正向点/反向点会携带原始框和累计点,以 `interactive` prompt 细化并替换同一个候选 mask。 -- 语义文本提示传 `semantic`;选择 `sam3` 且独立 Python 3.12 环境、CUDA、官方包和本地 checkpoint 均满足时走 SAM 3 文本语义推理,选择 `sam2` 时回退到自动分割。 +- 工作区 SAM 2 一旦包含反向点,会随请求启用 `auto_filter_background` 和 `min_score=0.05`;若后端判定反向点排除了当前候选区域并返回空结果,前端会移除旧候选 mask,避免继续显示已被否定的区域。 +- SAM 2 不支持文本语义提示;AI 页面在 SAM 2 下输入纯文本时会提示用户改用点提示或切换 SAM 3,不再回退到自动分割。 +- SAM 2 点提示和 auto fallback 默认只采用一个最高分候选 mask,避免多个候选 mask 作为同一结果重叠显示。 +- AI 页面生成的 SAM 2/SAM 3 mask 会写入全局 `masks`,自动同步到当前项目帧,并写入全局 `selectedMaskIds`;右侧语义分类树可以直接给新生成 mask 换标签。 +- AI 页面“推送至工作区编辑”会切回工作区并把工具切到“调整多边形”,保留当前选中的 AI mask 以便继续编辑轮廓和归档保存。 +- 语义文本提示传 `semantic`;选择 `sam3` 且独立 Python 3.12 环境、CUDA、官方包和本地 checkpoint 均满足时走 SAM 3 文本语义推理。 - SAM 3 支持工作区框选提示;后端把 normalized `[x1, y1, x2, y2]` 转成官方 `add_geometric_prompt()` 需要的 `[center_x, center_y, width, height]` 正框。 - 当前 SAM 3 前端路径不支持正/反点修正;在工作区用 SAM 3 进行点交互时,前端会提示切回 SAM 2。 - 工作区“传播片段”会把当前选中区域或当前帧第一个区域作为 seed,调用 `POST /api/ai/propagate`,默认从当前帧向后传播 30 帧并保存结果标注。 - `POST /api/ai/propagate` 支持 `model=sam2` 或 `model=sam3`;SAM 2 使用官方 `SAM2VideoPredictor.add_new_mask()` 和 `propagate_in_video()`,SAM 3 通过独立 Python 3.12 helper 调用官方 `build_sam3_video_predictor()` video tracker。 - 传播结果会写入后续帧 `annotations`,`mask_data.source` 分别标记为 `sam2_propagation` 或 `sam3_propagation`,并保留 label、color 和 class 元数据。 - AI 页面会对 SAM 3 空文本、推理失败和返回 0 个 mask 的情况显示明确反馈。 -- AI 参数支持 `crop_to_prompt`、`auto_filter_background` 和 `min_score`;点/框 prompt 可以裁剪局部区域推理并回映射结果,背景过滤会移除低分结果和包含负向点的 polygon。 +- AI 参数支持 `crop_to_prompt`、`auto_filter_background` 和 `min_score`;点/框 prompt 可以裁剪局部区域推理并回映射结果,背景过滤会移除低分结果和包含负向点的 polygon;SAM 3 semantic 会用 `min_score` 控制 external worker 的置信度阈值。 - 后端返回 `polygons` 和 `scores`。 - 前端把后端 `polygons` 转成 Konva `pathData`、`segmentation`、`bbox`、`area`。 - AI 推理结果先存放在前端 store 的 `masks` 中,点击“结构化归档保存”后持久化到后端标注表。 diff --git a/doc/08-current-design-freeze.md b/doc/08-current-design-freeze.md index 2d3bb34..25fabb9 100644 --- a/doc/08-current-design-freeze.md +++ b/doc/08-current-design-freeze.md @@ -22,11 +22,11 @@ | 全局状态 | `src/store/useStore.ts` | Zustand store,保存项目、帧、模板、mask、当前选中 mask ids、工具状态和 mask 撤销/重做历史栈 | | API 封装 | `src/lib/api.ts` | Axios 客户端、字段映射、AI 响应转换 | | 配置 | `src/lib/config.ts` | 推导 API 和 WebSocket 地址 | -| WebSocket | `src/lib/websocket.ts` | 进度流连接、订阅和重连 | +| WebSocket | `src/lib/websocket.ts` | 进度流连接、订阅、连接状态通知、心跳和重连 | | 模型状态 | `src/components/ModelStatusBadge.tsx` | 展示 GPU 与当前 SAM 模型真实可用状态 | | 登录页 | `src/components/Login.tsx` | 调用登录 API,写入 store | | Dashboard | `src/components/Dashboard.tsx` | 展示统计、任务控制、失败详情和 WebSocket 进度消息 | -| 项目库 | `src/components/ProjectLibrary.tsx` | 项目列表、新建、导入视频/DICOM | +| 项目库 | `src/components/ProjectLibrary.tsx` | 项目列表、新建、导入视频/DICOM、显式生成帧 | | 工作区 | `src/components/VideoWorkspace.tsx` | 加载帧和模板,组织工具栏、Canvas、本体面板、时间轴 | | Canvas | `src/components/CanvasArea.tsx` | 显示帧、缩放平移、点/框提示、渲染 mask | | 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具、跳转 AI 页面、触发 mask 撤销/重做 | @@ -76,15 +76,16 @@ 2. `login()` 调用 `POST /api/auth/login`。 3. 成功后 store 写入 token,App 渲染主界面。 -### 项目导入 +### 项目导入与生成帧 1. `ProjectLibrary` 创建项目。 -2. 上传视频或 DICOM 到 `/api/media/upload` 或 `/api/media/upload/dicom`。 -3. 调用 `/api/media/parse` 创建异步拆帧任务;可通过 `parse_fps`、`max_frames` 和 `target_width` 指定标准帧序列参数。 -4. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,视频帧按 `frame_%06d.jpg` 从 `frame_000000.jpg` 连续命名,并按目标宽度缩放。 -5. worker 写入 `frames.timestamp_ms` 和 `frames.source_frame_number`,并在任务 `result.frame_sequence` 中记录 FPS、帧数、时长、尺寸和对象存储前缀。 -6. worker 持续更新 `processing_tasks`,并发布 Redis `seg:progress`。 -7. 刷新项目列表。 +2. 导入视频时上传源视频到 `/api/media/upload` 并关联项目;该步骤不调用 `/api/media/parse`。 +3. 用户在项目卡片点击“生成帧”,在弹窗中选择目标 FPS。 +4. 前端调用 `/api/media/parse` 创建异步拆帧任务;可通过 `parse_fps`、`max_frames` 和 `target_width` 指定标准帧序列参数。 +5. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,视频帧按 `frame_%06d.jpg` 从 `frame_000000.jpg` 连续命名,并按目标宽度缩放。 +6. worker 写入 `frames.timestamp_ms` 和 `frames.source_frame_number`,并在任务 `result.frame_sequence` 中记录 FPS、帧数、时长、尺寸和对象存储前缀。 +7. worker 持续更新 `processing_tasks`,并发布 Redis `seg:progress`。 +8. 刷新项目列表。 ### 任务控制 @@ -93,11 +94,12 @@ 3. worker 在下载、解析、上传、写帧等关键阶段刷新任务状态;如果发现 `cancelled`,停止后续写入并发布 cancelled 事件。 4. 用户重试任务时,前端调用 `POST /api/tasks/{task_id}/retry`;后端基于原任务 `payload` 创建新任务,记录 `retry_of` 并重新投递 Celery。 5. 用户打开详情时,前端调用 `GET /api/tasks/{task_id}`,弹窗展示 error、payload、result、Celery ID 和时间。 +6. Dashboard 通过 `/ws/progress` 接收 Redis `seg:progress` 转发事件;前端 WebSocket 客户端在 `onopen/onclose/onerror` 主动更新连接状态,并定时发送 `ping` 心跳,服务端返回 `status` 确认连接仍活跃。 ### 工作区加载 1. `VideoWorkspace` 根据 `currentProject.id` 调用 `getProjectFrames()`。 -2. 若无帧但项目有 `video_path`,触发 `parseMedia()`,通过 `getTask()` 轮询任务完成后重新取帧。 +2. 若无帧但项目有 `video_path`,显示“尚未生成帧”的状态提示,不自动触发 `parseMedia()`。 3. 帧数据映射为 store `Frame[]`,包含 `timestampMs` 和 `sourceFrameNumber`,供时间轴和后续视频传播使用。 4. 当前帧传入 `CanvasArea`。 @@ -107,13 +109,15 @@ 2. `CanvasArea` 读取当前帧 ID 和宽高。 3. SAM 2 框选会创建一个候选 mask,并记录原始框;后续正向点/反向点会累计到同一候选上。 4. `predictMask()` 归一化坐标并携带当前 `model` 调用 `/api/ai/predict`;同时有框和点时发送 `interactive` prompt。 -5. 后端加载帧图片并通过 SAM registry 分发到 SAM 2 或 SAM 3。 -6. 前端把 `polygons` 转为 mask;交互式细化会替换同一个候选 mask,而不是新增多个 mask。 -7. Canvas 按当前帧过滤并渲染 mask。 -8. 新 mask 会带上当前选择的模板分类元数据,包括 `classId`、`className`、`classZIndex` 和保存状态 `draft`。 -9. 用户点击“结构化归档保存”后,前端将像素 `segmentation` 转成 normalized `mask_data.polygons`;未保存 mask 调用 `POST /api/ai/annotate`,dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`。 -10. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。 -11. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。 +5. SAM 2 请求中只要存在反向点,`CanvasArea` 会额外发送 `options.auto_filter_background=true` 和 `options.min_score=0.05`,让后端移除低分结果和包含负向点的 polygon。 +6. 后端加载帧图片并通过 SAM registry 分发到 SAM 2 或 SAM 3。 +7. 前端把 `polygons` 转为 mask;交互式细化会替换同一个候选 mask,而不是新增多个 mask。 +8. 若带反向点的 SAM 2 细化返回空结果,前端会删除当前旧候选 mask 并提示反向点已排除该区域。 +9. Canvas 按当前帧过滤并渲染 mask。 +10. 新 mask 会带上当前选择的模板分类元数据,包括 `classId`、`className`、`classZIndex` 和保存状态 `draft`。 +11. 用户点击“结构化归档保存”后,前端将像素 `segmentation` 转成 normalized `mask_data.polygons`;未保存 mask 调用 `POST /api/ai/annotate`,dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`。 +12. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。 +13. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。 ### 视频片段传播 @@ -131,19 +135,20 @@ 1. 用户在 `ToolsPalette` 选择多边形、矩形、圆、点或线工具。 2. `CanvasArea` 将交互坐标转换成像素 polygon。 3. 多边形工具逐次记录节点,三点后点击首节点或按 Enter 时生成闭合 polygon。 -4. mask path 只在 `move`、`area_merge` 和 `area_remove` 工具下拦截点击;绘制和 AI prompt 工具点击已有 mask 时继续冒泡给 Stage。 +4. mask path 只在 `move`、`edit_polygon`、`area_merge` 和 `area_remove` 工具下拦截点击;绘制和 AI prompt 工具点击已有 mask 时继续冒泡给 Stage。 5. 新 mask 写入 `pathData`、像素 `segmentation`、`bbox`、`area` 和当前模板分类元数据。 6. `addMask()`、`setMasks()`、`updateMask()`、`clearMasks()` 会维护 `maskHistory/maskFuture`。 7. 工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y 调用 `undoMasks()` / `redoMasks()`。 ### Polygon 逐点编辑 -1. 用户点击 Canvas 上的 mask path 后,`CanvasArea` 记录 `selectedMaskId` 并显示该 mask 第一条 polygon 的顶点控制点。 +1. 用户选择“调整多边形”或“拖拽/选择”后点击 Canvas 上的 mask path,`CanvasArea` 记录 `selectedMaskId` 并显示该 mask 第一条 polygon 的顶点控制点和边中点插入手柄。 2. 拖动顶点后,前端重算 `pathData`、像素 `segmentation`、`bbox`、`area`。 -3. 如果 mask 已有 `annotationId`,编辑会把 `saveStatus` 标成 `dirty` 且 `saved=false`。 -4. 归档保存时复用现有 `PATCH /api/ai/annotations/{annotation_id}` 链路,把更新后的 normalized polygon 写回后端。 -5. 选中顶点后 Delete/Backspace 可删除顶点;前端保持 polygon 至少三点。 -6. 未选中具体顶点但选中了 mask 时,Delete/Backspace 从前端 store 删除该 mask;如果包含 `annotationId`,通过工作区回调调用后端删除接口。 +3. 点击边中点手柄会在该边中点插入新顶点;在“调整多边形”工具下双击 polygon path 会在最接近的线段上按双击位置插入新顶点。 +4. 如果 mask 已有 `annotationId`,编辑会把 `saveStatus` 标成 `dirty` 且 `saved=false`。 +5. 归档保存时复用现有 `PATCH /api/ai/annotations/{annotation_id}` 链路,把更新后的 normalized polygon 写回后端。 +6. 选中顶点后 Delete/Backspace 可删除顶点;前端保持 polygon 至少三点。 +7. 未选中具体顶点但选中了 mask 时,Delete/Backspace 从前端 store 删除该 mask;如果包含 `annotationId`,通过工作区回调调用后端删除接口。 ### 区域合并与去除 @@ -173,9 +178,11 @@ 4. 后端把 `classes`、`rules` 打包进 `mapping_rules`。 5. 返回时再解包给前端。 6. `CanvasArea` 把当前选中的 mask id 同步到全局 `selectedMaskIds`;切换工具、切换帧或卸载 Canvas 时会清空选择。 -7. `OntologyInspector` 可以选择具体分类;选择结果进入全局 store,供 `CanvasArea` 和 `AISegmentation` 新建/更新 mask 时使用。 -8. 如果 `selectedMaskIds` 中存在当前 store 的 mask,点击分类时会立即更新这些 mask 的 `templateId`、`classId`、`className`、`classZIndex`、`label` 和 `color`。 -9. 已保存 mask 被重新分类后进入 `dirty` 且 `saved=false`,继续复用工作区归档保存的 PATCH 链路。 +7. `AISegmentation` 生成 mask 后会写入全局 `masks` 并把生成的 mask id 写入 `selectedMaskIds`;点击 AI 页预览 mask 也会更新 `selectedMaskIds`。 +8. AI 页“推送至工作区编辑”会切换到工作区并把 `activeTool` 设为 `edit_polygon`;`CanvasArea` 初始读取全局 `selectedMaskIds`,让 AI 页选中的 mask 在工作区继续保持选中。 +9. `OntologyInspector` 可以选择具体分类;选择结果进入全局 store,供 `CanvasArea` 和 `AISegmentation` 新建/更新 mask 时使用。 +10. 如果 `selectedMaskIds` 中存在当前 store 的 mask,点击分类时会立即更新这些 mask 的 `templateId`、`classId`、`className`、`classZIndex`、`label` 和 `color`。 +11. 已保存 mask 被重新分类后进入 `dirty` 且 `saved=false`,继续复用工作区归档保存的 PATCH 链路。 ### 导出 @@ -204,10 +211,12 @@ - `parseMedia()` 使用 `POST /api/media/parse?project_id=...`,可选 `parse_fps`、`max_frames`、`target_width`,用于生成标准帧序列。 - `getProjectFrames()` 返回帧图像 URL、宽高、`timestamp_ms` 和 `source_frame_number`。 - 后端 `/api/ai/predict` 支持 point、box、interactive、semantic 四种 prompt_type,并通过 `model` 选择 SAM 2 或 SAM 3。 +- SAM 2 是点/框交互式分割模型,不做文本语义分割;AI 页面在 SAM 2 + 纯文本时直接提示用户改用点提示或切换 SAM 3。 +- SAM 2 点提示和 auto fallback 只返回一个最高分候选,避免同一提示产生多个重叠候选 mask。 - 当前 SAM 3 暴露 semantic 文本语义推理和 box 几何提示;工作区 Canvas 的点交互会在选择 SAM 3 时显示提示,不再静默失败。 - SAM 3 box prompt 复用后端 `/api/ai/predict` 的 `box` prompt_type,输入仍是 normalized `[x1, y1, x2, y2]`,引擎适配层会转换为官方 `add_geometric_prompt()` 使用的 `[center_x, center_y, width, height]` 正框。 - AI 页面选择 SAM 3 时优先发送文本 semantic prompt,不会把正/反点误发送为 SAM 3 point prompt;空文本、后端错误和空结果都会显示反馈消息。 -- 后端 `/api/ai/predict` 支持可选 `options`:`crop_to_prompt` 会对 point/box/interactive prompt 做局部裁剪推理并回映射 polygon,`auto_filter_background` 会按 `min_score` 和负向点过滤结果。 +- 后端 `/api/ai/predict` 支持可选 `options`:`crop_to_prompt` 会对 point/box/interactive prompt 做局部裁剪推理并回映射 polygon,`auto_filter_background` 会按 `min_score` 和负向点过滤结果;SAM 3 semantic 会把正数 `min_score` 传给 external worker 作为 `confidence_threshold`。 - 后端 `/api/ai/propagate` 支持 SAM 2 mask seed 视频传播和 SAM 3 external video tracker;当前前端默认向后传播 30 帧并保存结果标注。 - 后端 `/api/ai/models/status` 返回 GPU、SAM 2、SAM 3 的真实运行状态;SAM 3 状态包含外部 Python 环境与 checkpoint access 的可用性。 - point prompt 支持旧数组形式和 `{ points, labels }` 对象形式。 diff --git a/doc/09-test-plan.md b/doc/09-test-plan.md index 19c3764..202b5fc 100644 --- a/doc/09-test-plan.md +++ b/doc/09-test-plan.md @@ -16,14 +16,14 @@ |------|----------|--------| | R1 登录与会话 | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | 成功登录、失败提示、后端 401 | | R2 项目管理 | `src/lib/api.test.ts`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、后端 CRUD、帧列表 | -| R3 媒体上传与拆帧 | `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 扩展名校验、自动建项目、关联项目、创建异步任务、标准帧序列参数、帧时间戳/源帧号、任务序列元数据、worker 注册帧、取消任务、重试任务、取消后 worker 停止 | -| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧触发解析、缩略图/range/左右方向键切帧、播放、按项目 FPS 显示当前/总时长 | -| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 工具切换、AI 跳转、矩形/圆/线/点/多边形手工 mask 绘制、点工具在已有 mask 上落点、多边形 Enter/首节点闭合、polygon 顶点拖动/删除、整块 mask 删除、区域合并/去除、内含去除 hole 渲染、合并模式隐藏编辑手柄、撤销/重做历史栈 | -| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam3_engine.py` | 点/框/interactive/semantic 契约、SAM 2 框选后正负点细化同一候选 mask、SAM 2 视频传播、SAM 3 语义文本前端执行路径、SAM 3 工作区框选、SAM 3 video tracker 外部桥接、SAM 3 点交互不支持提示、空文本/空结果反馈、模型选择、GPU/SAM 状态、SAM 3 外部 worker 桥接、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry | +| R3 媒体上传与拆帧 | `src/components/ProjectLibrary.test.tsx`, `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 视频导入不自动拆帧、显式生成帧 FPS 选择、扩展名校验、自动建项目、关联项目、创建异步任务、标准帧序列参数、帧时间戳/源帧号、任务序列元数据、worker 注册帧、取消任务、重试任务、取消后 worker 停止 | +| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧项目不自动解析并提示生成帧、缩略图/range/左右方向键切帧、播放、按项目 FPS 显示当前/总时长 | +| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 工具切换、调整多边形工具、AI 跳转、矩形/圆/线/点/多边形手工 mask 绘制、点工具在已有 mask 上落点、多边形 Enter/首节点闭合、polygon 顶点拖动/删除、边中点插点、双击边界按位置插点、整块 mask 删除、区域合并/去除、内含去除 hole 渲染、合并模式隐藏编辑手柄、撤销/重做历史栈 | +| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam2_engine.py`, `backend/tests/test_sam3_engine.py` | 点/框/interactive/semantic 契约、SAM 2 纯文本提示拦截、SAM 2 最高分候选去重、SAM 2 框选后正负点细化同一候选 mask、SAM 2 反向点启用背景过滤且空结果移除旧候选、AI 页生成 mask 自动选中并可通过分类树换标签、AI 页推送到工作区编辑保留选择、SAM 2 视频传播、SAM 3 语义文本前端执行路径、SAM 3 semantic 请求级阈值、SAM 3 worker 单 2D mask 转 polygon、SAM 3 工作区框选、SAM 3 video tracker 外部桥接、SAM 3 点交互不支持提示、空文本/空结果反馈、模型选择、GPU/SAM 状态、SAM 3 外部 worker 桥接、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry | | R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、加载回显、更新 dirty 标注、清空删除已保存标注、GT mask 多类别导入、seed point 回显/归一化、项目不存在、帧不存在 | | R8 模板库 | `src/components/TemplateRegistry.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_templates.py` | 前端模板加载/新建/编辑/删除、JSON 分类导入、mapping_rules 解包/打包、后端模板 CRUD | | R9 本体检查面板 | `src/components/OntologyInspector.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 模板选择、分类展示、具体分类选择、Canvas 选区同步、点击分类给已选 mask 换标签、自定义分类本地添加 | -| R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py`, `backend/tests/test_tasks.py` | 后端概览接口、任务表驱动队列、任务取消/重试/详情、cancelled 事件、Redis 进度事件 payload/发布、地址推导、消息订阅、队列更新、heartbeat | +| R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py`, `backend/tests/test_tasks.py` | 后端概览接口、任务表驱动队列、任务取消/重试/详情、cancelled 事件、Redis 进度事件 payload/发布、地址推导、消息订阅、连接状态回调、队列更新、heartbeat | | R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO/PNG 按钮下载、导出前自动保存、导出路径、JSON 结构、mask ZIP、zIndex 语义融合 | | R12 配置 | `src/lib/config.test.ts` | env 优先、hostname 推导、WS 推导 | | R13 文档与测试 | `doc/09-test-plan.md` | 测试覆盖矩阵 | @@ -34,15 +34,15 @@ |------|--------|----------|----------| | R1 | 登录页、默认开发凭证、token 写入、失败提示、后端 401 | `Login.test.tsx`, `test_auth.py` | 已覆盖 | | R2 | 项目列表/创建/选择、视频导入、DICOM 导入、后端项目和帧 CRUD | `ProjectLibrary.test.tsx`, `api.test.ts`, `test_projects.py` | 已覆盖 | -| R3 | 文件类型校验、自动/指定项目上传、视频/DICOM 拆帧任务、`parse_fps/max_frames/target_width`、标准帧序列 metadata、任务查询、取消、重试、worker 取消停止 | `test_media.py`, `test_tasks.py` | 已覆盖 | -| R4 | 工作区加载帧、无帧自动解析、Canvas 底图、缩略图/range/左右方向键切帧、播放、按 FPS 显示时间 | `VideoWorkspace.test.tsx`, `FrameTimeline.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 | -| R5 | 工具切换、AI 跳转、矩形/圆/线/点/多边形绘制、已有 mask 上继续绘制 | `ToolsPalette.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 | -| R5 | 顶点编辑、顶点删除、整块删除、撤销/重做、区域合并、区域去除、hole even-odd 渲染 | `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 | -| R6 | SAM 2 点/框/interactive、SAM 2 视频传播、SAM 3 semantic、SAM 3 box、SAM 3 video tracker、SAM 3 不支持点交互时的前端反馈、模型选择、GPU/模型状态、参数 options、polygons 转 mask | `api.test.ts`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx`, `VideoWorkspace.test.tsx`, `ModelStatusBadge.test.tsx`, `test_ai.py`, `test_sam3_engine.py` | 已覆盖 | +| R3 | 文件类型校验、自动/指定项目上传、视频导入与生成帧分离、显式 FPS 生成帧、视频/DICOM 拆帧任务、`parse_fps/max_frames/target_width`、标准帧序列 metadata、任务查询、取消、重试、worker 取消停止 | `ProjectLibrary.test.tsx`, `test_media.py`, `test_tasks.py` | 已覆盖 | +| R4 | 工作区加载帧、无帧项目不自动解析、Canvas 底图、缩略图/range/左右方向键切帧、播放、按 FPS 显示时间 | `VideoWorkspace.test.tsx`, `FrameTimeline.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 | +| R5 | 工具切换、调整多边形入口、AI 跳转、矩形/圆/线/点/多边形绘制、已有 mask 上继续绘制 | `ToolsPalette.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 | +| R5 | 顶点编辑、边中点插点、双击边界按位置插点、顶点删除、整块删除、撤销/重做、区域合并、区域去除、hole even-odd 渲染 | `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 | +| R6 | SAM 2 点/框/interactive、SAM 2 纯文本提示拦截、SAM 2 最高分候选去重、AI 页生成 mask 自动选中并可换标签、AI 页推送到工作区编辑保留选择、SAM 2 视频传播、SAM 3 semantic、SAM 3 semantic 请求级阈值、SAM 3 worker 单 2D mask 转 polygon、SAM 3 box、SAM 3 video tracker、SAM 3 不支持点交互时的前端反馈、模型选择、GPU/模型状态、参数 options、polygons 转 mask | `api.test.ts`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx`, `VideoWorkspace.test.tsx`, `ModelStatusBadge.test.tsx`, `test_ai.py`, `test_sam2_engine.py`, `test_sam3_engine.py` | 已覆盖 | | R7 | 保存、查询、更新、删除标注、工作区回显、清空已保存标注、GT mask 导入和 seed point 回写 | `VideoWorkspace.test.tsx`, `CanvasArea.test.tsx`, `api.test.ts`, `test_ai.py` | 已覆盖 | | R8 | 模板加载、新建、编辑、删除、JSON 分类导入、mapping_rules 映射、后端 CRUD | `TemplateRegistry.test.tsx`, `api.test.ts`, `test_templates.py` | 已覆盖 | | R9 | 模板选择、分类展示、分类选择、已选 mask 换标签、自定义本地分类、占位状态 | `OntologyInspector.test.tsx`, `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 | -| R10 | Dashboard 概览、队列、活动日志、WebSocket progress/complete/error/status/cancelled、取消/重试/详情、heartbeat | `Dashboard.test.tsx`, `websocket.test.ts`, `test_dashboard.py`, `test_main.py`, `test_progress_events.py`, `test_tasks.py` | 已覆盖 | +| R10 | Dashboard 概览、队列、活动日志、WebSocket progress/complete/error/status/cancelled、取消/重试/详情、连接状态回调、heartbeat | `Dashboard.test.tsx`, `websocket.test.ts`, `test_dashboard.py`, `test_main.py`, `test_progress_events.py`, `test_tasks.py` | 已覆盖 | | R11 | COCO/PNG ZIP 导出、导出前保存、路径和 JSON/ZIP 结构、zIndex 融合 | `VideoWorkspace.test.tsx`, `api.test.ts`, `test_export.py` | 已覆盖 | | R12 | API/WS 地址 env 优先和 hostname 推导 | `config.test.ts` | 已覆盖 | | R13 | 文档测试矩阵与功能点追踪 | `doc/09-test-plan.md` | 已覆盖 | @@ -51,6 +51,8 @@ - R5:补充 `CanvasArea.test.tsx` 中圆形和线段手工绘制测试,明确验证 metadata、segmentation、bbox/area 和草稿状态。 - R6:补充 `AISegmentation.test.tsx` 中 SAM 3 semantic 文本推理测试,验证前端传参和返回 mask 绑定当前语义类别。 +- R6:补充 SAM 2 纯文本提示拦截、SAM 2 多候选只保留最高分、SAM 2 engine 单候选请求测试,避免多个重叠候选 mask 被同时叠加。 +- R6:补充 Canvas 工作区 SAM 2 反向点背景过滤测试,覆盖请求 options 和过滤为空时清除旧候选 mask。 - R6:补充 SAM 3 空文本、空结果和工作区点交互不支持提示测试,避免前端静默失败。 - R6:补充 SAM 3 工作区 box prompt 测试和外部 worker box prompt 测试,验证官方 `add_geometric_prompt()` 正框链路。 - R6:补充 `POST /api/ai/propagate` 后端测试,验证 seed mask 传播结果会保存为后续帧标注并保留 class 元数据。 diff --git a/src/components/AISegmentation.test.tsx b/src/components/AISegmentation.test.tsx index 25ee8e9..668a107 100644 --- a/src/components/AISegmentation.test.tsx +++ b/src/components/AISegmentation.test.tsx @@ -63,6 +63,129 @@ describe('AISegmentation', () => { })); }); + it('does not run SAM2 text-only prompts as semantic segmentation', async () => { + render(); + + fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), { + target: { value: '胆囊' }, + }); + fireEvent.click(await screen.findByText('执行高精度语义分割')); + + expect(apiMock.predictMask).not.toHaveBeenCalled(); + expect(await screen.findByText('SAM2 不支持文本语义提示;请先放置正/反向点,或切换到 SAM3 使用文本语义。')).toBeInTheDocument(); + }); + + it('keeps only the best SAM2 candidate when the backend returns overlapping alternatives', async () => { + apiMock.predictMask.mockResolvedValueOnce({ + masks: [ + { + id: 'sam2-best', + pathData: 'M 0 0 L 10 0 L 10 10 Z', + label: 'AI Mask', + color: '#06b6d4', + segmentation: [[0, 0, 10, 0, 10, 10]], + bbox: [0, 0, 10, 10], + area: 100, + }, + { + id: 'sam2-alt', + pathData: 'M 1 1 L 11 1 L 11 11 Z', + label: 'AI Mask', + color: '#06b6d4', + segmentation: [[1, 1, 11, 1, 11, 11]], + bbox: [1, 1, 10, 10], + area: 100, + }, + ], + }); + + render(); + fireEvent.click(screen.getByText('正向选点')); + fireEvent.click(screen.getByTestId('konva-stage')); + fireEvent.click(await screen.findByText('执行高精度语义分割')); + + await waitFor(() => expect(useStore.getState().masks).toHaveLength(1)); + expect(useStore.getState().masks[0].id).toBe('sam2-best'); + expect(useStore.getState().selectedMaskIds).toEqual(['sam2-best']); + expect(await screen.findByText('SAM2 返回 2 个候选,已采用最高分区域。')).toBeInTheDocument(); + }); + + it('lets a SAM2 result be selected and relabeled from the ontology panel', async () => { + useStore.setState({ + templates: [ + { + id: 'template-1', + name: '腹腔镜模板', + classes: [ + { id: 'class-1', name: '胆囊', color: '#ff0000', zIndex: 30 }, + { id: 'class-2', name: '肝脏', color: '#00ff00', zIndex: 20 }, + ], + rules: [], + }, + ], + }); + apiMock.predictMask.mockResolvedValueOnce({ + masks: [ + { + id: 'sam2-mask', + pathData: 'M 10 10 L 40 10 L 40 40 Z', + label: 'AI Mask', + color: '#06b6d4', + segmentation: [[10, 10, 40, 10, 40, 40]], + bbox: [10, 10, 30, 30], + area: 900, + }, + ], + }); + + render(); + fireEvent.click(screen.getByText('正向选点')); + fireEvent.click(screen.getByTestId('konva-stage')); + fireEvent.click(await screen.findByText('执行高精度语义分割')); + + await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask'])); + fireEvent.click(screen.getByText('肝脏')); + + expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ + templateId: 'template-1', + classId: 'class-2', + className: '肝脏', + classZIndex: 20, + label: '肝脏', + color: '#00ff00', + saveStatus: 'draft', + })); + }); + + it('keeps the generated SAM2 mask selected when sending it to the workspace editor', async () => { + const onSendToWorkspace = vi.fn(); + apiMock.predictMask.mockResolvedValueOnce({ + masks: [ + { + id: 'sam2-mask', + pathData: 'M 10 10 L 40 10 L 40 40 Z', + label: 'AI Mask', + color: '#06b6d4', + segmentation: [[10, 10, 40, 10, 40, 40]], + bbox: [10, 10, 30, 30], + area: 900, + }, + ], + }); + + render(); + fireEvent.click(screen.getByText('正向选点')); + fireEvent.click(screen.getByTestId('konva-stage')); + fireEvent.click(await screen.findByText('执行高精度语义分割')); + await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask'])); + + fireEvent.click(screen.getByText('推送至工作区编辑')); + + expect(useStore.getState().activeTool).toBe('edit_polygon'); + expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']); + expect(onSendToWorkspace).toHaveBeenCalled(); + }); + it('prompts for semantic text before running SAM3 inference', async () => { apiMock.getAiModelStatus.mockResolvedValue({ selected_model: 'sam3', @@ -106,7 +229,7 @@ describe('AISegmentation', () => { points: undefined, text: '胆囊', }))); - expect(await screen.findByText('模型没有返回可用区域,请换一个更具体的描述或调整提示。')).toBeInTheDocument(); + expect(await screen.findByText('SAM3 已完成语义推理,但没有返回区域。请尝试英文目标描述,或换到包含该目标的帧。当前提示: 胆囊')).toBeInTheDocument(); }); it('runs SAM3 semantic text inference and assigns the active class to returned masks', async () => { diff --git a/src/components/AISegmentation.tsx b/src/components/AISegmentation.tsx index 5fef8c7..67e36a0 100644 --- a/src/components/AISegmentation.tsx +++ b/src/components/AISegmentation.tsx @@ -17,6 +17,8 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { const masks = useStore((state) => state.masks); const addMask = useStore((state) => state.addMask); const clearMasks = useStore((state) => state.clearMasks); + const selectedMaskIds = useStore((state) => state.selectedMaskIds); + const setSelectedMaskIds = useStore((state) => state.setSelectedMaskIds); const maskHistory = useStore((state) => state.maskHistory); const maskFuture = useStore((state) => state.maskFuture); const undoMasks = useStore((state) => state.undoMasks); @@ -97,6 +99,10 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { setInferenceMessage('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。'); return; } + if (aiModel === 'sam2' && textPrompt && points.length === 0) { + setInferenceMessage('SAM2 不支持文本语义提示;请先放置正/反向点,或切换到 SAM3 使用文本语义。'); + return; + } if (points.length === 0 && !textPrompt) { setInferenceMessage('请先放置正/反向提示点,或输入语义描述。'); return; @@ -132,14 +138,22 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { }, }); - if (result.masks.length === 0) { - setInferenceMessage('模型没有返回可用区域,请换一个更具体的描述或调整提示。'); + const masksToApply = aiModel === 'sam2' ? result.masks.slice(0, 1) : result.masks; + + if (masksToApply.length === 0) { + setInferenceMessage(aiModel === 'sam3' + ? `SAM3 已完成语义推理,但没有返回区域。请尝试英文目标描述,或换到包含该目标的帧。当前提示: ${textPrompt}` + : '模型没有返回可用区域,请换一个更具体的描述或调整提示。'); } else { - setInferenceMessage(`已生成 ${result.masks.length} 个候选区域。`); + setInferenceMessage(aiModel === 'sam2' && result.masks.length > 1 + ? `SAM2 返回 ${result.masks.length} 个候选,已采用最高分区域。` + : `已生成 ${masksToApply.length} 个候选区域。`); } - result.masks.forEach((m) => { + const generatedMaskIds: string[] = []; + masksToApply.forEach((m) => { const label = activeClass?.name || m.label; const color = activeClass?.color || m.color; + generatedMaskIds.push(m.id); addMask({ id: m.id, frameId: currentFrame.id, @@ -157,6 +171,9 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { area: m.area, }); }); + if (generatedMaskIds.length > 0) { + setSelectedMaskIds(generatedMaskIds); + } } catch (err) { console.error('AI inference failed:', err); const detail = (err as any)?.response?.data?.detail; @@ -164,7 +181,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { } finally { setIsInferencing(false); } - }, [activeClass, activeTemplateId, addMask, aiModel, autoDeleteBg, cropMode, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText]); + }, [activeClass, activeTemplateId, addMask, aiModel, autoDeleteBg, cropMode, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText, setSelectedMaskIds]); const handleStageClick = (e: any) => { if (effectiveTool === 'move') return; @@ -307,10 +324,13 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { )} @@ -376,12 +396,20 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { {/* AI Returned Masks */} {frameMasks.map((mask) => ( - + { + event.cancelBubble = true; + setSelectedMaskIds([mask.id]); + }} + onTap={(event: any) => { + event.cancelBubble = true; + setSelectedMaskIds([mask.id]); + }} /> ))} diff --git a/src/components/CanvasArea.test.tsx b/src/components/CanvasArea.test.tsx index 74c341c..3569359 100644 --- a/src/components/CanvasArea.test.tsx +++ b/src/components/CanvasArea.test.tsx @@ -206,16 +206,58 @@ describe('CanvasArea', () => { { x: 300, y: 150, type: 'neg' }, ], box: { x1: 120, y1: 80, x2: 260, y2: 200 }, + options: { auto_filter_background: true, min_score: 0.05 }, })); expect(useStore.getState().masks).toHaveLength(1); expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ id: 'mask-box', segmentation: [[30, 30, 70, 30, 70, 70]], points: [[150, 100]], - metadata: expect.objectContaining({ promptPointCount: 2 }), + metadata: expect.objectContaining({ promptPointCount: 2, promptNegativePointCount: 1 }), })); }); + it('removes the SAM2 candidate when a negative point filters it out', async () => { + apiMock.predictMask + .mockResolvedValueOnce({ + masks: [ + { + id: 'mask-box', + pathData: 'M 10 10 L 90 10 L 90 90 Z', + label: 'AI Mask', + color: '#06b6d4', + segmentation: [[10, 10, 90, 10, 90, 90]], + bbox: [10, 10, 80, 80], + area: 6400, + }, + ], + }) + .mockResolvedValueOnce({ masks: [] }); + + const { rerender } = render(); + const stage = screen.getByTestId('konva-stage'); + fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 }); + fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 }); + fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 }); + + await waitFor(() => expect(useStore.getState().masks).toHaveLength(1)); + + rerender(); + fireEvent.click(stage, { clientX: 180, clientY: 120 }); + + await waitFor(() => expect(apiMock.predictMask).toHaveBeenNthCalledWith(2, { + imageId: 'frame-1', + imageWidth: 640, + imageHeight: 360, + model: 'sam2', + points: [{ x: 180, y: 120, type: 'neg' }], + box: { x1: 120, y1: 80, x2: 260, y2: 200 }, + options: { auto_filter_background: true, min_score: 0.05 }, + })); + await waitFor(() => expect(useStore.getState().masks).toHaveLength(0)); + expect(await screen.findByText(/反向点已排除当前候选区域/)).toBeInTheDocument(); + }); + it('renders only masks that belong to the current frame', () => { useStore.setState({ masks: [ @@ -250,6 +292,28 @@ describe('CanvasArea', () => { await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['m1'])); }); + it('keeps a mask selected when opening the workspace polygon editor from AI results', () => { + useStore.setState({ + selectedMaskIds: ['m1'], + masks: [ + { + id: 'm1', + frameId: 'frame-1', + pathData: 'M 0 0 L 10 0 L 10 10 Z', + label: 'A', + color: '#fff', + segmentation: [[0, 0, 10, 0, 10, 10]], + }, + ], + }); + + render(); + + expect(useStore.getState().selectedMaskIds).toEqual(['m1']); + expect(screen.getAllByTestId('konva-circle') + .filter((element) => element.getAttribute('data-fill') === '#ffffff')).toHaveLength(3); + }); + it('renders imported GT seed points for editable point regions', () => { useStore.setState({ masks: [ @@ -415,6 +479,34 @@ describe('CanvasArea', () => { })); }); + it('selects a polygon with the edit tool and inserts a vertex by double-clicking an edge', () => { + useStore.setState({ + masks: [ + { + id: 'draft-1', + frameId: 'frame-1', + pathData: 'M 10 10 L 90 10 L 90 40 Z', + label: 'Draft', + color: '#06b6d4', + saveStatus: 'draft', + segmentation: [[10, 10, 90, 10, 90, 40]], + bbox: [10, 10, 80, 30], + }, + ], + }); + + render(); + const path = screen.getByTestId('konva-path'); + fireEvent.click(path); + fireEvent.doubleClick(path, { clientX: 50, clientY: 10 }); + + expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ + segmentation: [[10, 10, 50, 10, 90, 10, 90, 40]], + pathData: 'M 10 10 L 50 10 L 90 10 L 90 40 Z', + saveStatus: 'draft', + })); + }); + it('edits the selected polygon in a multi-polygon mask', () => { useStore.setState({ masks: [ diff --git a/src/components/CanvasArea.tsx b/src/components/CanvasArea.tsx index beacecb..d9e2abe 100644 --- a/src/components/CanvasArea.tsx +++ b/src/components/CanvasArea.tsx @@ -19,6 +19,7 @@ type PromptBox = { x1: number; y1: number; x2: number; y2: number }; const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle', 'create_line']); const POLYGON_TOOL = 'create_polygon'; +const EDIT_POLYGON_TOOL = 'edit_polygon'; const POINT_TOOL = 'create_point'; const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']); const POLYGON_CLOSE_RADIUS = 8; @@ -95,6 +96,32 @@ function pointDistance(a: CanvasPoint, b: CanvasPoint): number { return Math.hypot(a.x - b.x, a.y - b.y); } +function distanceToSegmentSquared(point: CanvasPoint, start: CanvasPoint, end: CanvasPoint): number { + const dx = end.x - start.x; + const dy = end.y - start.y; + const lengthSquared = dx * dx + dy * dy; + if (lengthSquared === 0) { + return (point.x - start.x) ** 2 + (point.y - start.y) ** 2; + } + const t = clamp(((point.x - start.x) * dx + (point.y - start.y) * dy) / lengthSquared, 0, 1); + const projected = { x: start.x + t * dx, y: start.y + t * dy }; + return (point.x - projected.x) ** 2 + (point.y - projected.y) ** 2; +} + +function nearestPolygonEdgeIndex(points: CanvasPoint[], point: CanvasPoint): number { + return points.reduce((bestIndex, start, index) => { + const end = points[(index + 1) % points.length]; + if (!end) return bestIndex; + const bestStart = points[bestIndex]; + const bestEnd = points[(bestIndex + 1) % points.length]; + const currentDistance = distanceToSegmentSquared(point, start, end); + const bestDistance = bestStart && bestEnd + ? distanceToSegmentSquared(point, bestStart, bestEnd) + : Number.POSITIVE_INFINITY; + return currentDistance < bestDistance ? index : bestIndex; + }, 0); +} + function segmentationArea(segmentation?: number[][]): number { return (segmentation || []).reduce((sum, polygon) => sum + polygonArea(flatPolygonToPoints(polygon)), 0); } @@ -210,10 +237,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota const [manualStart, setManualStart] = useState(null); const [manualCurrent, setManualCurrent] = useState(null); const [polygonPoints, setPolygonPoints] = useState([]); - const [selectedMaskId, setSelectedMaskId] = useState(null); - const [selectedMaskIds, setSelectedMaskIds] = useState([]); + const [selectedMaskId, setSelectedMaskId] = useState(() => useStore.getState().selectedMaskIds[0] || null); + const [selectedMaskIds, setSelectedMaskIds] = useState(() => useStore.getState().selectedMaskIds); const [selectedPolygonIndex, setSelectedPolygonIndex] = useState(0); const [selectedVertexIndex, setSelectedVertexIndex] = useState(null); + const previousFrameIdRef = useRef(frame?.id); const [isInferencing, setIsInferencing] = useState(false); const [inferenceMessage, setInferenceMessage] = useState(''); @@ -253,6 +281,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length; const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length; const isBooleanTool = BOOLEAN_TOOLS.has(effectiveTool); + const isPolygonEditTool = effectiveTool === 'move' || effectiveTool === EDIT_POLYGON_TOOL; useEffect(() => { const handleResize = () => { @@ -273,11 +302,22 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota setManualStart(null); setManualCurrent(null); setPolygonPoints([]); + setSelectedVertexIndex(null); + if (!isPolygonEditTool && !isBooleanTool) { + setSelectedMaskId(null); + setSelectedMaskIds([]); + setSelectedPolygonIndex(0); + } + }, [effectiveTool, isBooleanTool, isPolygonEditTool]); + + useEffect(() => { + if (previousFrameIdRef.current === frame?.id) return; + previousFrameIdRef.current = frame?.id; setSelectedMaskId(null); setSelectedMaskIds([]); setSelectedPolygonIndex(0); setSelectedVertexIndex(null); - }, [effectiveTool, frame?.id]); + }, [frame?.id]); useEffect(() => { setPoints([]); @@ -420,6 +460,10 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota setIsInferencing(true); setInferenceMessage(''); try { + const hasNegativePrompt = Boolean(promptPoints?.some((point) => point.type === 'neg')); + const existingCandidate = !options.resetCandidate && samCandidateMaskId + ? masks.find((mask) => mask.id === samCandidateMaskId) + : null; const result = await predictMask({ imageId: frame.id, imageWidth, @@ -429,13 +473,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota ? promptPoints.map((p) => ({ x: p.x, y: p.y, type: p.type })) : undefined, box: promptBox, + ...(hasNegativePrompt ? { options: { auto_filter_background: true, min_score: 0.05 } } : {}), }); const [m] = result.masks; if (m) { - const existingCandidate = !options.resetCandidate && samCandidateMaskId - ? masks.find((mask) => mask.id === samCandidateMaskId) - : null; const label = activeClass?.name || existingCandidate?.label || m.label; const color = activeClass?.color || existingCandidate?.color || m.color; const metadata = { @@ -443,6 +485,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota source: aiModel === 'sam3' ? 'sam3_box' : 'sam2_interactive', promptBox: promptBox || null, promptPointCount: promptPoints?.length || 0, + promptNegativePointCount: promptPoints?.filter((point) => point.type === 'neg').length || 0, }; const nextMask = { frameId: frame.id, @@ -476,7 +519,15 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota }); } } else { - setInferenceMessage('模型没有返回可用区域,请调整点/框提示后重试。'); + if (existingCandidate && hasNegativePrompt) { + setMasks(masks.filter((mask) => mask.id !== existingCandidate.id)); + setSamCandidateMaskId(null); + setSelectedMaskId(null); + setSelectedMaskIds([]); + setInferenceMessage('反向点已排除当前候选区域,请重新框选或添加新的正向点。'); + } else { + setInferenceMessage('模型没有返回可用区域,请调整点/框提示后重试。'); + } } } catch (err) { console.error('Inference failed:', err); @@ -485,7 +536,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota } finally { setIsInferencing(false); } - }, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, masks, samCandidateMaskId, updateMask]); + }, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, masks, samCandidateMaskId, setMasks, updateMask]); const handleApplyActiveClass = () => { if (!frame?.id || !activeClass) return; @@ -598,7 +649,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota }; const handleStageClick = (e: any) => { - if (effectiveTool === 'move') return; + if (isPolygonEditTool) return; if (effectiveTool === 'box_select') return; // handled by mouseup if (DRAG_MANUAL_TOOLS.has(effectiveTool)) return; @@ -716,7 +767,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota window.addEventListener('keydown', handleKeyDown); return () => window.removeEventListener('keydown', handleKeyDown); - }, [deleteMasksById, effectiveTool, finishPolygon, polygonPoints, redoMasks, selectedMask, selectedMaskIds, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]); + }, [deleteMasksById, effectiveTool, finishPolygon, isPolygonEditTool, polygonPoints, redoMasks, selectedMask, selectedMaskIds, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]); const boxRect = React.useMemo(() => { if (!boxStart || !boxCurrent) return null; @@ -753,7 +804,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota }; const handleMaskSelect = (mask: Mask, event: any, polygonIndex = 0) => { - if (effectiveTool !== 'move' && !isBooleanTool) return; + if (!isPolygonEditTool && !isBooleanTool) return; event.cancelBubble = true; if (isBooleanTool) { setSelectedMaskIds((current) => ( @@ -807,6 +858,25 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota updatePolygonMask(mask, nextPoints, selectedPolygonIndex); }; + const handlePathDoubleClick = (mask: Mask, event: any, polygonIndex = 0) => { + if (effectiveTool !== EDIT_POLYGON_TOOL) return; + event.cancelBubble = true; + const point = stagePoint(event); + const currentPoints = segmentationToPoints(mask.segmentation, polygonIndex); + if (!point || currentPoints.length < 3) return; + const edgeIndex = nearestPolygonEdgeIndex(currentPoints, point); + const nextPoints = [ + ...currentPoints.slice(0, edgeIndex + 1), + point, + ...currentPoints.slice(edgeIndex + 1), + ]; + setSelectedMaskId(mask.id); + setSelectedMaskIds([mask.id]); + setSelectedPolygonIndex(polygonIndex); + setSelectedVertexIndex(edgeIndex + 1); + updatePolygonMask(mask, nextPoints, polygonIndex); + }; + const handleBooleanOperation = async () => { if (!frame || booleanSelectedMasks.length < 2) return; const primary = booleanSelectedMasks[0]; @@ -918,6 +988,8 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale} onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)} onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)} + onDblClick={(event: any) => handlePathDoubleClick(mask, event, polygonIndex)} + onDblTap={(event: any) => handlePathDoubleClick(mask, event, polygonIndex)} /> ))} @@ -987,7 +1059,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota )))} {/* Polygon edge insertion handles */} - {!isBooleanTool && selectedMask && selectedMaskPoints.map((point, index) => { + {isPolygonEditTool && selectedMask && selectedMaskPoints.map((point, index) => { const next = selectedMaskPoints[(index + 1) % selectedMaskPoints.length]; if (!next) return null; return ( @@ -1006,7 +1078,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota })} {/* Polygon vertex editor */} - {!isBooleanTool && selectedMask && selectedMaskPoints.map((point, index) => ( + {isPolygonEditTool && selectedMask && selectedMaskPoints.map((point, index) => ( ({ const wsMock = vi.hoisted(() => { const state = { callback: undefined as undefined | ((data: any) => void), + statusCallback: undefined as undefined | ((status: any) => void), connected: false, }; return { @@ -24,6 +25,11 @@ const wsMock = vi.hoisted(() => { state.callback = cb; return vi.fn(); }), + onStatus: vi.fn((cb: (status: any) => void) => { + state.statusCallback = cb; + cb(state.connected ? 'connected' : 'disconnected'); + return vi.fn(); + }), }, }; }); @@ -45,6 +51,7 @@ describe('Dashboard', () => { vi.clearAllMocks(); wsMock.state.connected = false; wsMock.state.callback = undefined; + wsMock.state.statusCallback = undefined; apiMock.getDashboardOverview.mockResolvedValue({ summary: { project_count: 2, @@ -109,6 +116,20 @@ describe('Dashboard', () => { expect(screen.getByText('44%')).toBeInTheDocument(); }); + it('updates the websocket badge from connection status callbacks', async () => { + render(); + + await waitFor(() => expect(wsMock.progressWS.onStatus).toHaveBeenCalled()); + expect(screen.getByText('WebSocket 断开')).toBeInTheDocument(); + + act(() => { + wsMock.state.connected = true; + wsMock.state.statusCallback?.('connected'); + }); + + expect(screen.getByText('WebSocket 已连接')).toBeInTheDocument(); + }); + it('adds activity logs for complete and status messages', async () => { render(); diff --git a/src/components/Dashboard.tsx b/src/components/Dashboard.tsx index 7e89f34..9735a53 100644 --- a/src/components/Dashboard.tsx +++ b/src/components/Dashboard.tsx @@ -1,6 +1,6 @@ import React, { useState, useEffect } from 'react'; import { Activity, AlertTriangle, Clock, Folders, CheckCircle2, Info, Loader2, RotateCcw, XCircle } from 'lucide-react'; -import { progressWS, type ProgressMessage } from '../lib/websocket'; +import { progressWS, type ConnectionStatus, type ProgressMessage } from '../lib/websocket'; import { cn } from '../lib/utils'; import { cancelTask, @@ -178,6 +178,9 @@ export function Dashboard() { ]); } }); + const unsubscribeStatus = progressWS.onStatus((status: ConnectionStatus) => { + if (mounted) setIsConnected(status === 'connected'); + }); const checkConnection = setInterval(() => { if (mounted) setIsConnected(progressWS.isConnected()); @@ -186,6 +189,7 @@ export function Dashboard() { return () => { mounted = false; unsubscribe(); + unsubscribeStatus(); clearInterval(checkConnection); progressWS.disconnect(); }; diff --git a/src/components/ProjectLibrary.test.tsx b/src/components/ProjectLibrary.test.tsx index d6fc398..97f3c59 100644 --- a/src/components/ProjectLibrary.test.tsx +++ b/src/components/ProjectLibrary.test.tsx @@ -56,10 +56,9 @@ describe('ProjectLibrary', () => { expect(useStore.getState().projects[0]).toEqual(expect.objectContaining({ id: 'p2' })); }); - it('imports video by creating a project, uploading media, parsing frames and refreshing projects', async () => { + it('imports video by creating a project and uploading media without parsing frames', async () => { apiMock.createProject.mockResolvedValueOnce({ id: 'p3', name: 'clip.mp4', status: 'pending' }); apiMock.uploadMedia.mockResolvedValueOnce({ url: 'http://file', id: 'object' }); - apiMock.parseMedia.mockResolvedValueOnce({ frames_extracted: 1 }); apiMock.getProjects.mockResolvedValue([]); const { container } = render(); @@ -70,10 +69,24 @@ describe('ProjectLibrary', () => { await waitFor(() => expect(apiMock.createProject).toHaveBeenCalledWith(expect.objectContaining({ name: 'clip.mp4', - parse_fps: 30, }))); expect(apiMock.uploadMedia).toHaveBeenCalledWith(file, 'p3'); - expect(apiMock.parseMedia).toHaveBeenCalledWith('p3'); + expect(apiMock.parseMedia).not.toHaveBeenCalled(); + }); + + it('generates frames from an imported video with the selected FPS', async () => { + apiMock.getProjects + .mockResolvedValueOnce([{ id: 'p4', name: 'clip.mp4', status: 'pending', frames: 0, video_path: 'uploads/clip.mp4', parse_fps: 30 }]) + .mockResolvedValueOnce([{ id: 'p4', name: 'clip.mp4', status: 'parsing', frames: 0, video_path: 'uploads/clip.mp4', parse_fps: 12 }]); + apiMock.parseMedia.mockResolvedValueOnce({ id: 22, status: 'queued', progress: 0 }); + + const { container } = render(); + + fireEvent.click(await screen.findByRole('button', { name: '生成帧' })); + fireEvent.change(container.querySelector('input[type="range"]') as HTMLInputElement, { target: { value: '12' } }); + fireEvent.click(screen.getByRole('button', { name: '开始生成帧' })); + + await waitFor(() => expect(apiMock.parseMedia).toHaveBeenCalledWith('p4', { parseFps: 12 })); }); it('imports only valid DICOM files and parses the returned project', async () => { diff --git a/src/components/ProjectLibrary.tsx b/src/components/ProjectLibrary.tsx index 45bc4c1..1120612 100644 --- a/src/components/ProjectLibrary.tsx +++ b/src/components/ProjectLibrary.tsx @@ -1,5 +1,5 @@ import React, { useState, useEffect, useRef } from 'react'; -import { UploadCloud, Film, Settings2, MoreHorizontal, Plus, Loader2, Activity } from 'lucide-react'; +import { UploadCloud, Film, Settings2, MoreHorizontal, Plus, Loader2, Activity, Images } from 'lucide-react'; import { cn } from '../lib/utils'; import { useStore } from '../store/useStore'; import { getProjects, createProject, uploadMedia, parseMedia, uploadDicomBatch } from '../lib/api'; @@ -22,7 +22,10 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) { const [showImportMenu, setShowImportMenu] = useState(false); const [showVideoConfig, setShowVideoConfig] = useState(false); const [pendingFile, setPendingFile] = useState(null); - const [parseFps, setParseFps] = useState(30); + const [frameProject, setFrameProject] = useState(null); + const [showFrameConfig, setShowFrameConfig] = useState(false); + const [frameParseFps, setFrameParseFps] = useState(30); + const [isGeneratingFrames, setIsGeneratingFrames] = useState(false); const videoInputRef = useRef(null); const dicomInputRef = useRef(null); @@ -57,7 +60,6 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) { const handleVideoSelect = (file: File) => { setPendingFile(file); - setParseFps(30); setShowVideoConfig(true); }; @@ -69,11 +71,9 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) { const newProject = await createProject({ name: pendingFile.name, description: `导入于 ${new Date().toLocaleString()}`, - parse_fps: parseFps, }); const result = await uploadMedia(pendingFile, String(newProject.id)); - await parseMedia(String(newProject.id)); - alert(`上传成功: ${pendingFile.name}\n已保存至: ${result.url}`); + alert(`视频导入成功: ${pendingFile.name}\n已保存至: ${result.url}\n需要生成帧时,请在项目卡片点击“生成帧”。`); const data = await getProjects(); setProjects(data); } catch (err) { @@ -86,6 +86,31 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) { } }; + const openFrameConfig = (project: Project, event: React.MouseEvent) => { + event.stopPropagation(); + setFrameProject(project); + setFrameParseFps(Math.round(project.parse_fps || 30)); + setShowFrameConfig(true); + }; + + const handleGenerateFrames = async () => { + if (!frameProject?.id) return; + setIsGeneratingFrames(true); + try { + const task = await parseMedia(frameProject.id, { parseFps: frameParseFps }); + alert(`生成帧任务已入队 #${task.id}\n帧率: ${frameParseFps} FPS\n可在 Dashboard 查看进度。`); + const data = await getProjects(); + setProjects(data); + setShowFrameConfig(false); + setFrameProject(null); + } catch (err) { + console.error('Frame generation failed:', err); + alert('生成帧失败,请检查后端服务或项目源文件'); + } finally { + setIsGeneratingFrames(false); + } + }; + const handleDicomUpload = async (files: FileList | null) => { if (!files || files.length === 0) return; const dcmFiles = Array.from(files).filter((f) => f.name.toLowerCase().endsWith('.dcm')); @@ -209,7 +234,7 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) { )}
- {proj.source_type === 'dicom' ? 'DICOM' : (proj.fps || '30FPS')} + {proj.source_type === 'dicom' ? 'DICOM' : (proj.video_path && (proj.frames ?? 0) === 0 ? '待生成帧' : (proj.fps || '30FPS'))} {proj.status === 'ready' ? ( @@ -235,6 +260,15 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) { 原 {proj.original_fps.toFixed(1)}fps )}
+ {proj.video_path && (proj.frames ?? 0) === 0 && proj.status !== 'parsing' && ( + + )} ))} @@ -245,24 +279,10 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) { {showVideoConfig && pendingFile && (
-

导入视频配置

+

导入视频文件

文件: {pendingFile.name}
-
- -
- setParseFps(parseInt(e.target.value))} - className="flex-1 accent-cyan-500" - /> - {parseFps} -
-

帧率越低,提取的帧越少,处理速度越快

-
+

此步骤只上传源视频并创建项目,不会立即拆帧。拆帧时再选择目标 FPS。

+ +
+
+
+ )} + {/* New project modal */} {showModal && (
diff --git a/src/components/ToolsPalette.test.tsx b/src/components/ToolsPalette.test.tsx index 0808664..1426399 100644 --- a/src/components/ToolsPalette.test.tsx +++ b/src/components/ToolsPalette.test.tsx @@ -20,12 +20,14 @@ describe('ToolsPalette', () => { ); fireEvent.click(screen.getByTitle('创建多边形 (P)')); + fireEvent.click(screen.getByTitle('调整多边形 (E)')); fireEvent.click(screen.getByTitle('正向选点 (SAM)')); fireEvent.click(screen.getByTitle('撤销操作 (Ctrl+Z)')); fireEvent.click(screen.getByTitle('重做操作 (Ctrl+Shift+Z)')); expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon'); - expect(setActiveTool).toHaveBeenNthCalledWith(2, 'point_pos'); + expect(setActiveTool).toHaveBeenNthCalledWith(2, 'edit_polygon'); + expect(setActiveTool).toHaveBeenNthCalledWith(3, 'point_pos'); expect(onUndo).toHaveBeenCalled(); expect(onRedo).toHaveBeenCalled(); }); diff --git a/src/components/ToolsPalette.tsx b/src/components/ToolsPalette.tsx index 4aa0e35..e7b4712 100644 --- a/src/components/ToolsPalette.tsx +++ b/src/components/ToolsPalette.tsx @@ -1,5 +1,5 @@ import React from 'react'; -import { MousePointer2, Hexagon, Square, Circle, Minus, Combine, Scissors, Wand2, Undo, Redo, Crosshair, PlusCircle, MinusCircle, SquareDashed } from 'lucide-react'; +import { MousePointer2, PencilLine, Hexagon, Square, Circle, Minus, Combine, Scissors, Wand2, Undo, Redo, Crosshair, PlusCircle, MinusCircle, SquareDashed } from 'lucide-react'; import { cn } from '../lib/utils'; interface ToolsPaletteProps { @@ -23,6 +23,7 @@ export function ToolsPalette({ }: ToolsPaletteProps) { const tools = [ { id: 'move', icon: MousePointer2, label: '拖拽 / 选择 (V)' }, + { id: 'edit_polygon', icon: PencilLine, label: '调整多边形 (E)' }, { id: 'create_polygon', icon: Hexagon, label: '创建多边形 (P)' }, { id: 'create_rectangle', icon: Square, label: '创建矩形 (R)' }, { id: 'create_circle', icon: Circle, label: '创建圆 (O)' }, diff --git a/src/components/VideoWorkspace.test.tsx b/src/components/VideoWorkspace.test.tsx index 9c266da..827159e 100644 --- a/src/components/VideoWorkspace.test.tsx +++ b/src/components/VideoWorkspace.test.tsx @@ -82,23 +82,16 @@ describe('VideoWorkspace', () => { expect(apiMock.getProjectAnnotations).toHaveBeenCalledWith('1'); }); - it('triggers parsing when a media project has no frames yet', async () => { - apiMock.getProjectFrames - .mockResolvedValueOnce([]) - .mockResolvedValueOnce([ - { id: 11, project_id: 1, frame_index: 0, image_url: '/parsed.jpg', width: 320, height: 240 }, - ]); - apiMock.parseMedia.mockResolvedValueOnce({ id: 7, status: 'queued', progress: 0 }); - apiMock.getTask.mockResolvedValueOnce({ id: 7, status: 'success', progress: 100, message: '解析完成' }); + it('does not auto-generate frames when a media project has no frames yet', async () => { + apiMock.getProjectFrames.mockResolvedValueOnce([]); render(); - await waitFor(() => expect(apiMock.parseMedia).toHaveBeenCalledWith('1')); - expect(apiMock.getTask).toHaveBeenCalledWith(7); - await waitFor(() => expect(useStore.getState().frames[0]).toEqual(expect.objectContaining({ - id: '11', - url: '/parsed.jpg', - }))); + await waitFor(() => expect(apiMock.getProjectFrames).toHaveBeenCalledWith('1')); + expect(apiMock.parseMedia).not.toHaveBeenCalled(); + expect(apiMock.getTask).not.toHaveBeenCalled(); + expect(useStore.getState().frames).toEqual([]); + expect(await screen.findByText('该项目已导入视频但尚未生成帧,请在项目库点击“生成帧”')).toBeInTheDocument(); }); it('hydrates saved annotations after loading frames', async () => { diff --git a/src/components/VideoWorkspace.tsx b/src/components/VideoWorkspace.tsx index 18c1fdb..a6f5cc6 100644 --- a/src/components/VideoWorkspace.tsx +++ b/src/components/VideoWorkspace.tsx @@ -8,10 +8,8 @@ import { exportMasks, getProjectAnnotations, getProjectFrames, - getTask, getTemplates, importGtMask, - parseMedia, propagateMasks, saveAnnotation, updateAnnotation, @@ -23,10 +21,6 @@ import { FrameTimeline } from './FrameTimeline'; import { ModelStatusBadge } from './ModelStatusBadge'; import type { Frame } from '../store/useStore'; -function sleep(ms: number) { - return new Promise((resolve) => setTimeout(resolve, ms)); -} - export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) { const gtMaskInputRef = React.useRef(null); const activeTool = useStore((state) => state.activeTool); @@ -72,64 +66,31 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void const data = await getProjectFrames(String(currentProject.id)); if (cancelled) return; - if (data.length === 0 && currentProject.video_path) { - // No frames yet but video exists -> queue parsing and poll the task. - try { - const task = await parseMedia(String(currentProject.id)); - if (cancelled) return; - setStatusMessage(`解析任务已入队 #${task.id}`); - let completed = false; - for (let attempt = 0; attempt < 60; attempt += 1) { - const freshTask = await getTask(task.id); - if (cancelled) return; - setStatusMessage(freshTask.message || `解析进度 ${freshTask.progress}%`); - if (freshTask.status === 'success') { - completed = true; - break; - } - if (freshTask.status === 'failed') { - setStatusMessage(freshTask.error || '解析任务失败'); - return; - } - await sleep(2000); - } - if (!completed) { - setStatusMessage('解析仍在后台运行,可稍后刷新工作区'); - return; - } - const fresh = await getProjectFrames(String(currentProject.id)); - if (cancelled) return; - const mappedFrames = fresh.map((f) => ({ - id: String(f.id), - projectId: String(f.project_id), - index: f.frame_index, - url: f.image_url, - width: f.width ?? 0, - height: f.height ?? 0, - timestampMs: f.timestamp_ms ?? undefined, - sourceFrameNumber: f.source_frame_number ?? undefined, - })); - setFrames(mappedFrames); - setCurrentFrame(0); - await hydrateSavedAnnotations(String(currentProject.id), mappedFrames); - } catch (err) { - console.error('Parse failed:', err); + const mappedFrames = data.map((f) => ({ + id: String(f.id), + projectId: String(f.project_id), + index: f.frame_index, + url: f.image_url, + width: f.width ?? 0, + height: f.height ?? 0, + timestampMs: f.timestamp_ms ?? undefined, + sourceFrameNumber: f.source_frame_number ?? undefined, + })); + setFrames(mappedFrames); + setCurrentFrame(0); + if (mappedFrames.length === 0) { + setMasks([]); + if (currentProject.status === 'parsing') { + setStatusMessage('生成帧任务正在后台运行,可在 Dashboard 查看进度'); + } else if (currentProject.video_path) { + setStatusMessage('该项目已导入视频但尚未生成帧,请在项目库点击“生成帧”'); + } else { + setStatusMessage('当前项目没有可显示帧'); } - } else { - const mappedFrames = data.map((f) => ({ - id: String(f.id), - projectId: String(f.project_id), - index: f.frame_index, - url: f.image_url, - width: f.width ?? 0, - height: f.height ?? 0, - timestampMs: f.timestamp_ms ?? undefined, - sourceFrameNumber: f.source_frame_number ?? undefined, - })); - setFrames(mappedFrames); - setCurrentFrame(0); - await hydrateSavedAnnotations(String(currentProject.id), mappedFrames); + return; } + setStatusMessage(''); + await hydrateSavedAnnotations(String(currentProject.id), mappedFrames); } catch (err) { console.error('Failed to load frames:', err); } diff --git a/src/lib/websocket.test.ts b/src/lib/websocket.test.ts index 1125d3e..2dfdd1b 100644 --- a/src/lib/websocket.test.ts +++ b/src/lib/websocket.test.ts @@ -2,12 +2,14 @@ import { afterEach, describe, expect, it, vi } from 'vitest'; describe('progress websocket client', () => { afterEach(() => { + vi.useRealTimers(); vi.restoreAllMocks(); vi.resetModules(); vi.unstubAllGlobals(); }); - it('connects using the configured URL and reports open state', async () => { + it('connects using the configured URL, reports open state, and sends heartbeat pings', async () => { + vi.useFakeTimers(); const instances: any[] = []; class FakeWebSocket { static CONNECTING = 0; @@ -21,14 +23,26 @@ describe('progress websocket client', () => { instances.push(this); } close = vi.fn(); + send = vi.fn(); } vi.stubGlobal('WebSocket', FakeWebSocket); const { progressWS } = await import('./websocket'); + const statusCallback = vi.fn(); + const unsubscribeStatus = progressWS.onStatus(statusCallback); progressWS.connect(); + instances[0].onopen?.(); expect(instances[0].url).toContain('/ws/progress'); expect(progressWS.isConnected()).toBe(true); + expect(statusCallback).toHaveBeenCalledWith('connected'); + expect(instances[0].send).toHaveBeenCalledWith('ping'); + + vi.advanceTimersByTime(15000); + expect(instances[0].send).toHaveBeenCalledTimes(2); + + unsubscribeStatus(); + progressWS.disconnect(); }); it('subscribes and unsubscribes progress callbacks', async () => { @@ -43,4 +57,41 @@ describe('progress websocket client', () => { expect(callback).toHaveBeenCalledTimes(1); expect(callback).toHaveBeenCalledWith({ type: 'status', message: 'ok' }); }); + + it('notifies connection status changes and schedules reconnect on close', async () => { + vi.useFakeTimers(); + const instances: any[] = []; + class FakeWebSocket { + static CONNECTING = 0; + static OPEN = 1; + readyState = FakeWebSocket.OPEN; + onopen?: () => void; + onmessage?: (event: MessageEvent) => void; + onclose?: () => void; + onerror?: () => void; + constructor(public url: string) { + instances.push(this); + } + close = vi.fn(); + send = vi.fn(); + } + vi.stubGlobal('WebSocket', FakeWebSocket); + + const { progressWS } = await import('./websocket'); + const statusCallback = vi.fn(); + const unsubscribeStatus = progressWS.onStatus(statusCallback); + + progressWS.connect(); + instances[0].onopen?.(); + instances[0].onclose?.(); + + expect(statusCallback).toHaveBeenCalledWith('disconnected'); + expect(statusCallback).toHaveBeenCalledWith('reconnecting'); + + vi.advanceTimersByTime(3000); + expect(instances).toHaveLength(2); + + unsubscribeStatus(); + progressWS.disconnect(); + }); }); diff --git a/src/lib/websocket.ts b/src/lib/websocket.ts index 3f3d0d9..b7d584b 100644 --- a/src/lib/websocket.ts +++ b/src/lib/websocket.ts @@ -1,6 +1,8 @@ import { WS_PROGRESS_URL } from './config'; type ProgressCallback = (data: ProgressMessage) => void; +type ConnectionStatus = 'connecting' | 'connected' | 'reconnecting' | 'disconnected'; +type StatusCallback = (status: ConnectionStatus) => void; interface ProgressMessage { type: 'progress' | 'status' | 'error' | 'complete' | 'cancelled'; @@ -20,9 +22,12 @@ class ProgressWebSocket { private ws: WebSocket | null = null; private url: string; private callbacks: Set = new Set(); + private statusCallbacks: Set = new Set(); private reconnectTimer: ReturnType | null = null; + private heartbeatTimer: ReturnType | null = null; private reconnectInterval = 3000; private maxReconnectInterval = 30000; + private heartbeatInterval = 15000; private shouldReconnect = false; private shouldCloseAfterOpen = false; private currentInterval = 3000; @@ -38,6 +43,7 @@ class ProgressWebSocket { this.shouldReconnect = true; this.shouldCloseAfterOpen = false; + this.notifyStatus('connecting'); try { this.ws = new WebSocket(this.url); @@ -50,6 +56,8 @@ class ProgressWebSocket { return; } this.currentInterval = this.reconnectInterval; + this.startHeartbeat(); + this.notifyStatus('connected'); console.log('[WebSocket] Connected to progress stream'); }; @@ -64,7 +72,9 @@ class ProgressWebSocket { this.ws.onclose = () => { console.log('[WebSocket] Connection closed'); + this.stopHeartbeat(); this.ws = null; + this.notifyStatus('disconnected'); if (this.shouldReconnect) { this.scheduleReconnect(); } @@ -72,7 +82,9 @@ class ProgressWebSocket { this.ws.onerror = () => { // 静默处理错误,避免在 CONNECTING 状态时 close 触发浏览器报错 + this.stopHeartbeat(); this.ws = null; + this.notifyStatus('disconnected'); if (this.shouldReconnect) { this.scheduleReconnect(); } @@ -85,6 +97,7 @@ class ProgressWebSocket { disconnect() { this.shouldReconnect = false; + this.stopHeartbeat(); if (this.reconnectTimer) { clearTimeout(this.reconnectTimer); this.reconnectTimer = null; @@ -102,6 +115,7 @@ class ProgressWebSocket { this.ws.close(); } this.ws = null; + this.notifyStatus('disconnected'); } onProgress(callback: ProgressCallback) { @@ -111,21 +125,53 @@ class ProgressWebSocket { }; } + onStatus(callback: StatusCallback) { + this.statusCallbacks.add(callback); + callback(this.isConnected() ? 'connected' : 'disconnected'); + return () => { + this.statusCallbacks.delete(callback); + }; + } + private scheduleReconnect() { if (this.reconnectTimer) { clearTimeout(this.reconnectTimer); } + this.notifyStatus('reconnecting'); this.reconnectTimer = setTimeout(() => { - console.log(`[WebSocket] Reconnecting in ${this.currentInterval}ms...`); + console.log('[WebSocket] Reconnecting to progress stream...'); this.connect(); this.currentInterval = Math.min(this.currentInterval * 1.5, this.maxReconnectInterval); }, this.currentInterval); } + private startHeartbeat() { + this.stopHeartbeat(); + this.sendHeartbeat(); + this.heartbeatTimer = setInterval(() => this.sendHeartbeat(), this.heartbeatInterval); + } + + private stopHeartbeat() { + if (this.heartbeatTimer) { + clearInterval(this.heartbeatTimer); + this.heartbeatTimer = null; + } + } + + private sendHeartbeat() { + if (this.ws?.readyState === WebSocket.OPEN) { + this.ws.send('ping'); + } + } + + private notifyStatus(status: ConnectionStatus) { + this.statusCallbacks.forEach((cb) => cb(status)); + } + isConnected(): boolean { return this.ws !== null && this.ws.readyState === WebSocket.OPEN; } } export const progressWS = new ProgressWebSocket(); -export type { ProgressMessage }; +export type { ConnectionStatus, ProgressMessage }; diff --git a/src/test/setup.tsx b/src/test/setup.tsx index b0b52cc..36f100b 100644 --- a/src/test/setup.tsx +++ b/src/test/setup.tsx @@ -102,6 +102,15 @@ vi.mock('react-konva', () => ({ props.onClick?.(konvaEvent); if (konvaEvent.cancelBubble) event.stopPropagation(); }} + onDoubleClick={(event) => { + const point = { + x: event.clientX || 120, + y: event.clientY || 80, + }; + const konvaEvent = { ...makeStageEvent(point.x, point.y), cancelBubble: false }; + props.onDblClick?.(konvaEvent); + if (konvaEvent.cancelBubble) event.stopPropagation(); + }} /> ), }));