diff --git a/AGENTS.md b/AGENTS.md index 2745a99..c3f23bf 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -78,10 +78,11 @@ Seg_Server/ │ │ ├── projects.py # /api/projects 与 /api/projects/{id}/frames │ │ ├── templates.py # /api/templates │ │ ├── media.py # /api/media/upload、/upload/dicom、/parse -│ │ ├── ai.py # /api/ai/predict、/propagate、/models/status、/auto、/annotate +│ │ ├── ai.py # /api/ai/predict、/propagate、/propagate/task、/models/status、/auto、/annotate │ │ └── export.py # /api/export/{project_id}/coco、/masks │ └── services/ │ ├── frame_parser.py # FFmpeg/OpenCV 拆帧、pydicom 读片、帧上传 +│ ├── propagation_task_runner.py # Celery 自动传播任务 runner │ ├── sam2_engine.py # SAM 2.1 变体选择、单帧推理和 video predictor 传播封装 │ ├── sam3_engine.py # 历史保留的 SAM 3 桥接实现;当前未接入 registry │ ├── sam3_external_worker.py # 历史保留的独立 sam3 helper;当前未被产品入口调用 @@ -105,6 +106,7 @@ Seg_Server/ ├── OntologyInspector.tsx ├── FrameTimeline.tsx ├── AISegmentation.tsx + ├── TransientNotice.tsx └── TemplateRegistry.tsx ``` @@ -115,6 +117,7 @@ Seg_Server/ - `doc/03-frontend-element-audit.md`:哪些前端元素是真功能,哪些是 Mock/UI-only。 - `doc/04-api-contracts.md`:前后端接口契约,以及当前不一致点。 - `doc/05-implementation-plan.md`:建议的后续实施顺序。 +- `doc/10-installation.md`:完整安装部署流程,覆盖 PostgreSQL、Redis、MinIO、FastAPI、Celery、前端和 SAM 2.1 权重。 --- @@ -195,6 +198,7 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload - `POST /api/tasks/{task_id}/retry` - `POST /api/ai/predict` - `POST /api/ai/propagate` + - `POST /api/ai/propagate/task` - `GET /api/ai/models/status` - `POST /api/ai/auto` - `POST /api/ai/annotate` @@ -220,12 +224,12 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload 1. 登录:`Login.tsx` 调用 `POST /api/auth/login`,默认开发凭证为 `admin / 123456`。 2. 项目管理:`ProjectLibrary.tsx` 调用项目 API 创建项目、拉取列表、删除项目;删除当前项目后会清空工作区当前项目、帧、mask 和选区。 3. 上传资源:视频走 `/api/media/upload`,只上传源文件并关联项目,不自动拆帧;DICOM 批量走 `/api/media/upload/dicom`。 -4. 
生成帧入队:用户在项目库点击“生成帧”,选择目标 FPS 后前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery,接口支持 `parse_fps`、`max_frames` 和 `target_width` 标准帧序列参数。 +4. 生成帧入队:用户在项目库点击“生成帧”,选择目标 FPS 后前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery,接口支持 `parse_fps`、`max_frames` 和 `target_width` 标准帧序列参数;项目库和模板库的成功/失败短反馈使用非阻塞 `TransientNotice`,会自动消失。 5. worker 执行:Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallback,DICOM 使用 pydicom;视频帧按 `frame_%06d.jpg` 连续命名并记录 `timestamp_ms`、`source_frame_number` 和任务 `frame_sequence` 元数据。 -6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames`,`CanvasArea.tsx` 和 `FrameTimeline.tsx` 显示当前帧与时间轴缩略图;`FrameTimeline` 会根据已保存标注回显到 `Mask.metadata` 的传播来源,把自动传播生成的帧在顶部进度条对应区段标为浅蓝色,当前帧位置由播放进度条末端、时间提示和缩略图高亮表达;前端 `Frame` 会保留后端返回的帧序列时间戳和源帧号。 -7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask;多边形可按 Enter 或点击首节点闭合;绘制工具可在已有 mask 上继续落点;工具栏有“调整多边形”入口,点击 mask 后可按住顶点直接拖动并实时更新 polygon,也可删除 polygon 顶点、通过边中点或双击边界插入新顶点,并能选择编辑多 polygon mask 的单个子区域;选中整块 mask 可用 Delete/Backspace 删除,已保存 mask 会同步后端删除;区域合并/去除会隐藏编辑手柄并显示已选数量,第一个选中的主区域用黄色实线轮廓,后续参与合并/扣除的区域用红色虚线轮廓,使用 `polygon-clipping` 做 union/difference,内含去除结果用 even-odd 规则渲染 hole;Zustand 维护 `maskHistory/maskFuture` 支持撤销/重做。 -8. 
AI 分割:前端工具包括 SAM 2.1 变体选择、正向点、反向点和框选;工作区和 AI 页面都可点击已有提示点删除单点,AI 页面也可删除最近锚点、删除选中候选或清空本页锚点;这些删除入口会限制在当前提示点/本页 AI 候选范围内,避免误删工作区已有 mask。SAM 2.1 框选会建立候选 mask,后续正/反点通过 `interactive` prompt 携带原始框和累计点细化同一个候选 mask;AI 页面框选会先固化 `promptBox`,执行分割时只框选发送 `box` prompt,框选后继续加正/反点发送 `interactive` prompt;重复执行高精度分割会替换上一次 AI 页候选,只保留最新一个候选。包含反向点时工作区会传 `options.auto_filter_background=true` 和 `min_score=0.05`,如果后端过滤为空则移除旧候选 mask。后端 `ai.py` 期望按 `image_id`、`prompt_type`、`prompt_data`、`model` 和可选 `options` 调用 SAM registry。当前 registry 暴露 `sam2.1_hiera_tiny`、`sam2.1_hiera_small`、`sam2.1_hiera_base_plus`、`sam2.1_hiera_large`,并兼容 `sam2` 作为 tiny 别名;`model=sam3` 会被拒绝,`semantic` 文本提示也被禁用。SAM 2.1 支持点/框/interactive/自动分割和 video predictor 传播;多候选默认只采用最高分区域,避免重叠候选同时显示;AI 页面只渲染本页最新生成的候选 mask,不会把工作区已有 mask 带入 AI 画布;AI 页面生成的 mask 会写入全局 `masks` 并自动选中,右侧分类树可直接改标签,推送到工作区会切到“调整多边形”并保留选择。`options.crop_to_prompt` 可对点/框/interactive prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果。 -9. 视频片段传播:工作区以当前打开帧作为参考帧,使用该帧全部 mask 作为 seed,并用传播起始帧和传播结束帧指定追踪范围;前端只保留一个“自动传播”按钮,会按 seed mask 和前/后方向顺序调用单 seed `POST /api/ai/propagate`,避免多个视频 tracker 并发抢占 GPU;后端按项目帧序列下载片段帧,当前使用所选 SAM 2.1 变体的 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`,并把后续帧结果保存为 `Annotation`。 +6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames`,`CanvasArea.tsx` 和 `FrameTimeline.tsx` 显示当前帧与时间轴缩略图;`CanvasArea` 会按容器和帧尺寸默认居中放大底图并保留边距;`FrameTimeline` 会根据已保存标注回显到 `Mask.metadata` 的传播来源,把自动传播生成的帧在视频处理进度条显示为蓝色区段,人工/AI 标注帧显示红色竖线;视频处理进度条和红/蓝标识可点击跳转到对应帧;底部缩略图中人工/AI 标注帧用红色边框、自动传播/推理帧用蓝色边框,当前帧仍以青色外框高亮优先;若当前帧同时是人工/AI 标注帧,则在青色外框内增加红色内描边;只有进入自动传播范围选择模式时,播放进度条和视频处理进度条才显示黄色范围框,并可点击/拖拽选择传播起止帧;前端 `Frame` 会保留后端返回的帧序列时间戳和源帧号。 +7. 
手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask;多边形可按 Enter 或点击首节点闭合;绘制工具可在已有 mask 上继续落点;工具栏有“调整多边形”入口,左侧 `ToolsPalette` 使用紧凑垂直布局并在高度不足时自身滚动;点击 mask 后可按住顶点直接拖动并实时更新 polygon,也可删除 polygon 顶点、通过边中点或双击边界插入新顶点,并能选择编辑多 polygon mask 的单个子区域;选中整块 mask 可用 Delete/Backspace 删除,已保存 mask 会同步后端删除;区域合并/去除会隐藏编辑手柄并显示已选数量,第一个选中的主区域用黄色实线轮廓,后续参与合并/扣除的区域用红色虚线轮廓,使用 `polygon-clipping` 做 union/difference,内含去除结果用 even-odd 规则渲染 hole;Zustand 维护 `maskHistory/maskFuture` 支持撤销/重做。 +8. AI 分割:前端工具包括 SAM 2.1 变体选择、正向点、反向点和框选;AI 画布会按容器和当前帧尺寸默认居中放大底图并保留边距;工作区和 AI 页面都可点击已有提示点删除单点,AI 页面也可删除最近锚点、删除选中候选或清空本页锚点;这些删除入口会限制在当前提示点/本页 AI 候选范围内,避免误删工作区已有 mask。SAM 2.1 框选会建立候选 mask,后续正/反点通过 `interactive` prompt 携带原始框和累计点细化同一个候选 mask;AI 页面框选会先固化 `promptBox`,执行分割时只框选发送 `box` prompt,框选后继续加正/反点发送 `interactive` prompt;重复执行高精度分割会替换上一次 AI 页候选,只保留最新一个候选。包含反向点时工作区会传 `options.auto_filter_background=true` 和 `min_score=0.05`,如果后端过滤为空则移除旧候选 mask。后端 `ai.py` 期望按 `image_id`、`prompt_type`、`prompt_data`、`model` 和可选 `options` 调用 SAM registry。当前 registry 暴露 `sam2.1_hiera_tiny`、`sam2.1_hiera_small`、`sam2.1_hiera_base_plus`、`sam2.1_hiera_large`,并兼容 `sam2` 作为 tiny 别名;`model=sam3` 会被拒绝,`semantic` 文本提示也被禁用。SAM 2.1 支持点/框/interactive/自动分割和 video predictor 传播;多候选默认只采用最高分区域,避免重叠候选同时显示;AI 页面只渲染本页最新生成的候选 mask,不会把工作区已有 mask 带入 AI 画布;AI 页面生成的 mask 会写入全局 `masks` 并自动选中,右侧分类树可直接改标签,推送到工作区会切到“调整多边形”并保留选择。`options.crop_to_prompt` 可对点/框/interactive prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果。 +9. 
视频片段传播:工作区以当前打开帧作为参考帧,使用该帧全部 mask 作为 seed,并用传播起始帧和传播结束帧指定追踪范围;用户可直接修改数字框,也可点击“自动传播”进入时间轴范围选择模式,在播放进度条或视频处理进度条上点击/拖拽选择范围,再点击“开始传播”。工作区顶栏有独立“传播权重”选择器,可为本次传播二次选择 SAM 2.1 tiny/small/base+/large 权重,不提供 SAM2/SAM3 家族切换,也不影响 AI 单帧分割权重;前端会按传播权重 id、seed mask、seed 来源 id 和前/后方向组装 `steps` 并调用 `POST /api/ai/propagate/task` 创建 `propagate_masks` 后台任务;后端入队时会规范化/校验权重 id 并把规范化后的 id 写入任务 payload/result;Celery worker 顺序执行各 step,避免多个视频 tracker 并发抢占 GPU;每个 step 会根据 seed 来源 id、权重、方向和 seed 签名做幂等判断,未改变的 seed 直接跳过,已改变的 seed 会先删除同源旧自动传播标注再重传;后端按项目帧序列下载片段帧,当前使用所选 SAM 2.1 权重变体的 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`,并把后续帧结果保存为 `Annotation`。工作区轮询 `GET /api/tasks/{task_id}` 展示进度并刷新标注,Dashboard 也能显示/取消/重试传播任务。 10. GT 导入:工作区“导入 GT Mask”调用 `/api/ai/import-gt-mask`;后端按非零像素值和连通域生成 polygon 标注,并用 distance transform 生成 seed point;前端回显 seed point,拖动后可归档更新。 11. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index;`OntologyInspector.tsx` 在工作区显示当前模板分类树。 12. 导出:后端支持 COCO JSON 和 PNG mask ZIP 导出;PNG ZIP 包含单标注 mask、按 zIndex 融合的语义 mask 和 `semantic_classes.json`。 @@ -242,7 +246,8 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload - 前端 `importGtMask()` 已对齐后端 `/api/ai/import-gt-mask`;工作区“导入 GT Mask”会导入后端生成的多类别标注和 seed point 并回显。 - 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;前端 `exportMasks()` 已对齐后端 `/api/export/{project_id}/masks`;工作区导出按钮会先保存当前待归档 mask。 - 工作区“结构化归档保存”按钮已接入 `POST /api/ai/annotate` 和 `PATCH /api/ai/annotations/{id}`;加载工作区时会通过 `GET /api/ai/annotations` 回显已保存标注。 -- 工作区“自动传播”按钮已接入 `POST /api/ai/propagate`;当前启用所选 SAM 2.1 变体的视频 predictor,完成后刷新后端已保存标注。 +- 工作区“自动传播”按钮已接入 `POST /api/ai/propagate/task`;若用户尚未显式设置范围,第一次点击会进入时间轴范围选择模式,第二次点击“开始传播”才提交后台任务;当前启用所选 SAM 2.1 变体的视频 predictor 后台任务,运行中轮询任务进度,完成后刷新后端已保存标注;同步 `POST /api/ai/propagate` 仍作为单 seed 兼容接口保留。 +- 工作区顶栏短状态会自动消失;保存、导出、导入 GT、传播进行中和无帧项目提示会保留到状态变化。 - 工作区“清空遮罩”会调用 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。 - 项目状态已统一为 `pending`、`parsing`、`ready`、`error`;前端 `src/lib/api.ts` 会兼容归一化旧库中可能存在的 
`Ready`、`Parsing`、`Error`。 - 项目库的视频导入与生成帧是两个独立动作:导入视频只上传源文件,生成帧按钮才会带 `parse_fps` 调用 `/api/media/parse`;工作区不会再因“有视频但无帧”自动创建拆帧任务。 diff --git a/README.md b/README.md index c9f8a55..4bdf84d 100644 --- a/README.md +++ b/README.md @@ -13,8 +13,8 @@ ## 核心功能 - **多媒体资产管理** — 支持视频(MP4/AVI/MOV)和 DICOM 医学影像上传;视频导入与生成帧分离,生成帧时选择目标 FPS,项目卡片可删除项目及其关联帧、标注和任务记录 -- **AI 智能分割引擎** — 当前产品入口启用 SAM 2.1 四个变体(tiny/small/base+/large)选择;支持点分割(point)、框分割(box)、交互式正/反点细化、提示点单点删除、AI 候选单独删除、自动分割(auto)和 video predictor 传播,前端默认只采用最高分候选避免重叠备选同时显示 -- **交互式画布标注** — 基于 Konva 的高性能 Canvas,支持缩放/平移/手工多边形/矩形/圆/点/线、polygon 顶点直接拖动/删除、边中点插点、双击边界插点、区域合并/去除、选点/框选、撤销/重做,实时渲染 Mask 遮罩 +- **AI 智能分割引擎** — 当前产品入口启用 SAM 2.1 四个变体(tiny/small/base+/large)选择;支持点分割(point)、框分割(box)、交互式正/反点细化、提示点单点删除、AI 候选单独删除、自动分割(auto)和 Celery 后台 video predictor 传播,前端默认只采用最高分候选避免重叠备选同时显示 +- **交互式画布标注** — 基于 Konva 的高性能 Canvas,工作区和 AI 画布会默认居中放大底图并保留边距;支持缩放/平移/手工多边形/矩形/圆/点/线、polygon 顶点直接拖动/删除、边中点插点、双击边界插点、区域合并/去除、选点/框选、撤销/重做,实时渲染 Mask 遮罩 - **GT Mask 导入** — 工作区可导入 GT mask 图片,后端按非零像素值和连通域生成 polygon 标注并用 distance transform 写入 seed point;前端可回显和拖动 seed point - **本体字典管理** — 可配置的分类体系、颜色映射、图层优先级(z-index) - **项目工作区** — 项目创建、帧浏览、多图层标注、自动传播帧提示、进度追踪 @@ -104,6 +104,7 @@ Seg_Server/ │ │ ├── ai.py # SAM 推理与模型状态接口 │ │ └── export.py # 数据导出 │ └── services/ # 业务服务 +│ ├── propagation_task_runner.py # Celery 自动传播任务 runner │ ├── sam2_engine.py # SAM 2.1 变体选择、单帧推理 + video predictor 传播 │ ├── sam3_engine.py # 历史保留的 SAM 3 桥接实现;当前未接入 registry │ ├── sam3_external_worker.py # 历史保留的独立 sam3 helper;当前未被产品入口调用 @@ -125,10 +126,11 @@ Seg_Server/ │ ├── ProjectLibrary.tsx # 项目库列表 │ ├── VideoWorkspace.tsx # 核心分割工作区布局 │ ├── CanvasArea.tsx # Konva 画布(缩放/平移/手工绘制/选点/Mask渲染) -│ ├── ToolsPalette.tsx # 左侧工具栏 +│ ├── ToolsPalette.tsx # 左侧紧凑工具栏(高度不足时滚动) │ ├── OntologyInspector.tsx # 右侧本体/属性检查面板 │ ├── FrameTimeline.tsx # 底部时间轴 │ ├── AISegmentation.tsx # AI 智能分割引擎界面 +│ ├── TransientNotice.tsx # 非阻塞自动消失短提示 │ └── TemplateRegistry.tsx # 模板库管理 ├── models/ # SAM 2 模型权重(.pt 文件) ├── uploads/ # 临时上传目录 
@@ -154,6 +156,7 @@ Seg_Server/ - `doc/03-frontend-element-audit.md` — 前端逐元素功能审计,标注真实可用、部分可用、Mock/UI-only、接口不通 - `doc/04-api-contracts.md` — 前后端接口契约和已知不一致 - `doc/06-fastapi-docs-explained.md` — `http://192.168.3.11:8000/docs` 的作用说明 +- `doc/10-installation.md` — 独立安装部署指南,覆盖 PostgreSQL、Redis、MinIO、FastAPI、Celery、前端和 SAM 2.1 权重 --- @@ -315,7 +318,7 @@ nohup uvicorn main:app --host 0.0.0.0 --port 8000 > /tmp/fastapi.log 2>&1 & - 测试 Redis 连接 - 懒加载所选 SAM 2.1 模型;`GET /api/ai/models/status` 会返回 tiny/small/base+/large 和 GPU 的真实可用状态,`selected_model=sam3` 会返回不支持 - `/api/ai/predict` 支持 AI 参数 `crop_to_prompt`、`auto_filter_background` 和 `min_score`,用于点/框 prompt 的局部裁剪推理、回映射和背景过滤 -- `/api/ai/propagate` 支持从当前帧 seed 区域向视频片段传播:当前使用所选 SAM 2.1 变体的 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()` +- `/api/ai/propagate/task` 支持从当前帧 seed 区域向视频片段创建后台传播任务:当前使用所选 SAM 2.1 变体的 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`;同步 `/api/ai/propagate` 仍作为单 seed 兼容接口保留 ### 步骤 6.1: 启动 Celery Worker @@ -329,7 +332,7 @@ celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 & ``` -视频导入只创建项目并把源视频保存到 MinIO,不会自动拆帧;用户在项目库点击“生成帧”后,再选择目标 FPS 并调用 `POST /api/media/parse`。该接口只创建 `processing_tasks` 记录并把任务投递给 Celery;真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。接口支持 `parse_fps`、`max_frames` 和 `target_width`,用于生成后续 SAM 2 视频处理可复用的标准帧序列;视频帧按 `frame_%06d.jpg` 连续命名,帧表会记录 `timestamp_ms` 和 `source_frame_number`,任务完成结果会返回 `frame_sequence` 元数据。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道,FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 的任务进度区展示 queued/running/success/failed/cancelled 最近任务,处理中统计只计算 queued/running;WebSocket 状态由浏览器 `onopen/onclose/onerror` 驱动,客户端会定时发送 `ping` 心跳,服务端返回 `status` 确认连接。Dashboard 也可调用 `/api/tasks/{id}/cancel`、`/api/tasks/{id}/retry` 和 `/api/tasks/{id}` 完成任务取消、重试与失败详情查看。 +视频导入只创建项目并把源视频保存到 MinIO,不会自动拆帧;用户在项目库点击“生成帧”后,再选择目标 FPS 并调用 `POST 
/api/media/parse`。项目库和模板库的成功/失败反馈使用非阻塞短提示,会自动消失,不再用浏览器 `alert()` 阻塞后续操作。该接口只创建 `processing_tasks` 记录并把任务投递给 Celery;真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。接口支持 `parse_fps`、`max_frames` 和 `target_width`,用于生成后续 SAM 2 视频处理可复用的标准帧序列;视频帧按 `frame_%06d.jpg` 连续命名,帧表会记录 `timestamp_ms` 和 `source_frame_number`,任务完成结果会返回 `frame_sequence` 元数据。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道,FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 的任务进度区展示 queued/running/success/failed/cancelled 最近任务,处理中统计只计算 queued/running;WebSocket 状态由浏览器 `onopen/onclose/onerror` 驱动,客户端会定时发送 `ping` 心跳,服务端返回 `status` 确认连接。Dashboard 也可调用 `/api/tasks/{id}/cancel`、`/api/tasks/{id}/retry` 和 `/api/tasks/{id}` 完成任务取消、重试与失败详情查看。 ### 步骤 7: 安装前端依赖并构建 @@ -467,7 +470,7 @@ pip install -e . --no-build-isolation - 工作区点选/框选会使用当前帧的数据库 `frame.id` 调用 `/api/ai/predict`。 - 工作区 SAM 2.1 交互式细化包含反向点时会启用后端背景过滤;若反向点排除了当前候选区域并返回空结果,前端会移除旧候选 mask。 - AI 页面只显示本页最新生成的 SAM 2.1 候选,不会把工作区已有 mask 带入 AI 画布;重复执行高精度分割会替换上一次 AI 页候选;新生成 mask 会写入全局 `masks` 并自动选中,右侧分类树可直接给生成结果换标签,“推送至工作区编辑”会切回工作区的多边形调整工具并保留选择。 -- 工作区传播功能会使用当前打开参考帧的全部 mask 作为 seed,按用户设置的传播起始帧和传播结束帧向前/向后追踪;前端只保留一个“自动传播”按钮,会按 seed 和方向顺序调用 `/api/ai/propagate`,并在完成后刷新已保存标注。传播结果回显后,时间进度条会把自动传播生成的帧区段标为浅蓝色。 +- 工作区传播功能会使用当前打开参考帧的全部 mask 作为 seed,按用户设置的传播起始帧和传播结束帧向前/向后追踪;用户可直接修改数字框,也可先点击“自动传播”进入时间轴范围选择模式,在播放进度条或视频处理进度条上点击/拖拽选择范围,再点击“开始传播”。工作区顶栏可单独选择本次传播使用的 SAM 2.1 tiny/small/base+/large 权重,不提供 SAM2/SAM3 家族切换;前端会把传播权重 id、seed、seed 来源 id 和方向组装为 `/api/ai/propagate/task` 后台任务。后端入队时会规范化/校验权重 id,并把规范化后的 id 写入任务 payload/result;worker 会按 seed 来源、权重、方向和 seed 签名去重,未改变的 mask 二次传播时直接跳过,已改变的 mask 会先删除同源旧自动传播标注再重传,避免同一个 mask 传播两次产生重叠。任务进度写入 `processing_tasks` 并可在 Dashboard 查看/取消/重试,工作区轮询任务状态并刷新已保存标注。传播结果回显后,视频处理进度条会把自动传播生成的帧区段标为蓝色,人工/AI 标注帧显示为红色竖线;普通状态下点击视频处理进度条或红/蓝帧标识可跳转到对应帧,底部缩略图也会用红色边框标识人工/AI 标注帧、蓝色边框标识传播/推理帧;当前帧如果同时是人工/AI 标注帧,会显示青色外框加红色内描边。 - 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco`。 - 工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮已绑定下载流程;导出前会先保存当前待归档的前端 mask。 - 工作区“导入 GT Mask”按钮已绑定 
`/api/ai/import-gt-mask`,导入后会刷新并回显已保存标注和 seed point。 diff --git a/backend/routers/ai.py b/backend/routers/ai.py index 413716d..ea37ab5 100644 --- a/backend/routers/ai.py +++ b/backend/routers/ai.py @@ -12,7 +12,7 @@ from sqlalchemy.orm import Session from database import get_db from minio_client import download_file -from models import Project, Frame, Template, Annotation +from models import Project, Frame, Template, Annotation, ProcessingTask from schemas import ( AiRuntimeStatus, MaskAnalysisRequest, @@ -21,10 +21,15 @@ from schemas import ( PredictResponse, PropagateRequest, PropagateResponse, + PropagateTaskRequest, + ProcessingTaskOut, AnnotationOut, AnnotationCreate, AnnotationUpdate, ) +from progress_events import publish_task_progress_event +from statuses import TASK_STATUS_QUEUED +from worker_tasks import propagate_project_masks from services.sam_registry import ModelUnavailableError, sam_registry logger = logging.getLogger(__name__) @@ -586,6 +591,66 @@ def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict: } +@router.post( + "/propagate/task", + status_code=status.HTTP_202_ACCEPTED, + response_model=ProcessingTaskOut, + summary="Queue a background video propagation task", +) +def queue_propagate_task(payload: PropagateTaskRequest, db: Session = Depends(get_db)) -> ProcessingTaskOut: + """Queue multiple seed/direction propagation steps as one background task.""" + project = db.query(Project).filter(Project.id == payload.project_id).first() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + source_frame = db.query(Frame).filter( + Frame.id == payload.frame_id, + Frame.project_id == payload.project_id, + ).first() + if not source_frame: + raise HTTPException(status_code=404, detail="Frame not found") + + if not payload.steps: + raise HTTPException(status_code=400, detail="Propagation task requires at least one step") + + try: + model_id = sam_registry.normalize_model_id(payload.model) + except 
ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + + for step in payload.steps: + direction = step.direction.lower() + if direction not in {"forward", "backward"}: + raise HTTPException(status_code=400, detail="direction must be forward or backward") + seed = step.seed.model_dump(exclude_none=True) + if not (seed.get("polygons") or seed.get("bbox") or seed.get("points")): + raise HTTPException(status_code=400, detail="Propagation requires seed polygons, bbox, or points") + + task_payload = payload.model_dump(exclude_none=True) + task_payload["model"] = model_id + task = ProcessingTask( + task_type="propagate_masks", + status=TASK_STATUS_QUEUED, + progress=0, + message="自动传播任务已入队", + project_id=payload.project_id, + payload=task_payload, + ) + db.add(task) + db.commit() + db.refresh(task) + publish_task_progress_event(task) + + async_result = propagate_project_masks.delay(task.id) + task.celery_task_id = async_result.id + db.commit() + db.refresh(task) + publish_task_progress_event(task) + + logger.info("Queued propagation task id=%s project_id=%s celery_id=%s", task.id, payload.project_id, async_result.id) + return task + + @router.post( "/auto", response_model=PredictResponse, diff --git a/backend/routers/dashboard.py b/backend/routers/dashboard.py index 5bb1dc2..62505a4 100644 --- a/backend/routers/dashboard.py +++ b/backend/routers/dashboard.py @@ -36,6 +36,7 @@ def _iso_or_none(value: datetime | None) -> str | None: def _task_payload(task: ProcessingTask) -> dict[str, Any]: + result = task.result or {} return { "id": f"task-{task.id}", "task_id": task.id, @@ -44,7 +45,7 @@ def _task_payload(task: ProcessingTask) -> dict[str, Any]: "progress": task.progress, "status": task.message or task.status, "raw_status": task.status, - "frame_count": (task.result or {}).get("frames_extracted", 0), + "frame_count": result.get("frames_extracted", result.get("processed_frame_count", 0)), "error": task.error, "updated_at": 
_iso_or_none(task.updated_at), } diff --git a/backend/routers/tasks.py b/backend/routers/tasks.py index 385bdc7..ff0e367 100644 --- a/backend/routers/tasks.py +++ b/backend/routers/tasks.py @@ -21,7 +21,7 @@ from statuses import ( TASK_STATUS_FAILED, TASK_STATUS_QUEUED, ) -from worker_tasks import parse_project_media +from worker_tasks import parse_project_media, propagate_project_masks router = APIRouter(prefix="/api/tasks", tags=["Tasks"]) logger = logging.getLogger(__name__) @@ -109,7 +109,8 @@ def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask: project = db.query(Project).filter(Project.id == previous.project_id).first() if not project: raise HTTPException(status_code=404, detail="Project not found") - if not project.video_path: + is_propagation_task = previous.task_type == "propagate_masks" + if not is_propagation_task and not project.video_path: raise HTTPException(status_code=400, detail="Project has no media uploaded") payload = dict(previous.payload or {}) @@ -124,13 +125,14 @@ def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask: project_id=project.id, payload=payload, ) - project.status = PROJECT_STATUS_PARSING + if not is_propagation_task: + project.status = PROJECT_STATUS_PARSING db.add(task) db.commit() db.refresh(task) publish_task_progress_event(task) - async_result = parse_project_media.delay(task.id) + async_result = propagate_project_masks.delay(task.id) if is_propagation_task else parse_project_media.delay(task.id) task.celery_task_id = async_result.id db.commit() db.refresh(task) diff --git a/backend/schemas.py b/backend/schemas.py index 0e7b606..01b8212 100644 --- a/backend/schemas.py +++ b/backend/schemas.py @@ -218,6 +218,8 @@ class PropagationSeed(BaseModel): color: Optional[str] = None class_metadata: Optional[dict[str, Any]] = None template_id: Optional[int] = None + source_mask_id: Optional[str] = None + source_annotation_id: Optional[int] = None class PropagateRequest(BaseModel): @@ 
-240,6 +242,21 @@ class PropagateResponse(BaseModel): annotations: list[AnnotationOut] +class PropagateTaskStep(BaseModel): + seed: PropagationSeed + direction: str = "forward" + max_frames: int = 30 + + +class PropagateTaskRequest(BaseModel): + project_id: int + frame_id: int + model: Optional[str] = "sam2.1_hiera_tiny" + steps: list[PropagateTaskStep] + include_source: bool = False + save_annotations: bool = True + + class AiModelStatus(BaseModel): id: str label: str diff --git a/backend/services/propagation_task_runner.py b/backend/services/propagation_task_runner.py new file mode 100644 index 0000000..2d52d18 --- /dev/null +++ b/backend/services/propagation_task_runner.py @@ -0,0 +1,512 @@ +"""Background SAM video propagation runner used by Celery workers.""" + +import hashlib +import json +import logging +import tempfile +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from sqlalchemy.orm import Session + +from minio_client import download_file +from models import Annotation, Frame, ProcessingTask, Project +from progress_events import publish_task_progress_event +from services.sam_registry import ModelUnavailableError, sam_registry +from statuses import ( + TASK_STATUS_CANCELLED, + TASK_STATUS_FAILED, + TASK_STATUS_RUNNING, + TASK_STATUS_SUCCESS, +) + +logger = logging.getLogger(__name__) + + +class PropagationTaskCancelled(RuntimeError): + """Raised internally when a persisted propagation task has been cancelled.""" + + +def _now() -> datetime: + return datetime.now(timezone.utc) + + +def _set_task_state( + db: Session, + task: ProcessingTask, + *, + status: str | None = None, + progress: int | None = None, + message: str | None = None, + result: dict[str, Any] | None = None, + error: str | None = None, + started: bool = False, + finished: bool = False, +) -> None: + if status is not None: + task.status = status + if progress is not None: + task.progress = max(0, min(100, progress)) + if message is not None: + 
task.message = message + if result is not None: + task.result = result + if error is not None: + task.error = error + if started: + task.started_at = _now() + if finished: + task.finished_at = _now() + db.commit() + db.refresh(task) + publish_task_progress_event(task) + + +def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None: + db.refresh(task) + if task.status == TASK_STATUS_CANCELLED: + raise PropagationTaskCancelled("Task was cancelled") + + +def _clamp01(value: float) -> float: + return min(max(float(value), 0.0), 1.0) + + +def _polygon_bbox(polygon: list[list[float]]) -> list[float]: + xs = [_clamp01(point[0]) for point in polygon] + ys = [_clamp01(point[1]) for point in polygon] + left, right = min(xs), max(xs) + top, bottom = min(ys), max(ys) + return [left, top, max(right - left, 0.0), max(bottom - top, 0.0)] + + +def _stable_json(value: Any) -> str: + return json.dumps(value, ensure_ascii=False, sort_keys=True, separators=(",", ":")) + + +def _seed_signature(seed: dict[str, Any]) -> str: + """Return a stable signature for seed geometry and semantic attrs.""" + signature_payload = { + "polygons": seed.get("polygons") or [], + "bbox": seed.get("bbox") or [], + "points": seed.get("points") or [], + "labels": seed.get("labels") or [], + "label": seed.get("label"), + "color": seed.get("color"), + "class_metadata": seed.get("class_metadata") or {}, + "template_id": seed.get("template_id"), + } + return hashlib.sha256(_stable_json(signature_payload).encode("utf-8")).hexdigest() + + +def _seed_key(seed: dict[str, Any]) -> str: + """Prefer stable persisted ids; fall back to semantic attrs for legacy callers.""" + source_annotation_id = seed.get("source_annotation_id") + if source_annotation_id is not None: + return f"annotation:{source_annotation_id}" + source_mask_id = seed.get("source_mask_id") + if source_mask_id: + return f"mask:{source_mask_id}" + class_metadata = seed.get("class_metadata") or {} + class_id = class_metadata.get("id") or 
class_metadata.get("name") + return _stable_json({ + "template_id": seed.get("template_id"), + "class_id": class_id, + "label": seed.get("label"), + "color": seed.get("color"), + }) + + +def _legacy_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool: + """Best-effort match for propagation annotations created before seed keys.""" + class_metadata = seed.get("class_metadata") or {} + previous_class = mask_data.get("class") or {} + previous_class_id = previous_class.get("id") or previous_class.get("name") + class_id = class_metadata.get("id") or class_metadata.get("name") + return ( + mask_data.get("label") == seed.get("label") + and mask_data.get("color") == seed.get("color") + and previous_class_id == class_id + ) + + +def _is_propagation_annotation( + annotation: Annotation, + model_id: str, + source_frame: Frame, + seed_key: str, + seed: dict[str, Any], +) -> bool: + mask_data = annotation.mask_data or {} + source = str(mask_data.get("source") or "") + if source != f"{model_id}_propagation": + return False + if int(mask_data.get("propagated_from_frame_id") or 0) != int(source_frame.id): + return False + previous_seed_key = mask_data.get("propagation_seed_key") + if previous_seed_key is not None: + return previous_seed_key == seed_key + return _legacy_seed_matches(mask_data, seed) + + +def _direction_matches(mask_data: dict[str, Any], direction: str) -> bool: + previous_direction = mask_data.get("propagation_direction") + return previous_direction in {None, direction} + + +def _prepare_seed_propagation( + db: Session, + *, + payload: dict[str, Any], + model_id: str, + source_frame: Frame, + seed: dict[str, Any], + direction: str, +) -> dict[str, Any]: + seed_key = _seed_key(seed) + seed_signature = _seed_signature(seed) + previous_annotations = ( + db.query(Annotation) + .filter(Annotation.project_id == int(payload["project_id"])) + .all() + ) + matching = [ + annotation for annotation in previous_annotations + if 
_is_propagation_annotation(annotation, model_id, source_frame, seed_key, seed) + and _direction_matches(annotation.mask_data or {}, direction) + ] + if matching and all((annotation.mask_data or {}).get("propagation_seed_signature") == seed_signature for annotation in matching): + return { + "skip": True, + "seed_key": seed_key, + "seed_signature": seed_signature, + "deleted_annotation_count": 0, + } + + deleted_count = 0 + if matching: + for annotation in matching: + db.delete(annotation) + deleted_count += 1 + db.commit() + + return { + "skip": False, + "seed_key": seed_key, + "seed_signature": seed_signature, + "deleted_annotation_count": deleted_count, + } + + +def _frame_window( + frames: list[Frame], + source_position: int, + direction: str, + max_frames: int, +) -> tuple[list[Frame], int]: + count = max(1, min(max_frames, len(frames))) + if direction == "backward": + start = max(0, source_position - count + 1) + return frames[start:source_position + 1], source_position - start + end = min(len(frames), source_position + count) + return frames[source_position:end], 0 + + +def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]: + paths = [] + for index, frame in enumerate(frames): + data = download_file(frame.image_url) + # SAM2VideoPredictor sorts frames by converting the filename stem to int. 
+ path = directory / f"{index:06d}.jpg" + path.write_bytes(data) + paths.append(str(path)) + return paths + + +def _save_propagated_annotations( + db: Session, + *, + payload: dict[str, Any], + selected_frames: list[Frame], + source_frame: Frame, + propagated: list[dict[str, Any]], + seed: dict[str, Any], +) -> list[Annotation]: + created: list[Annotation] = [] + if payload.get("save_annotations", True) is False: + return created + + class_metadata = seed.get("class_metadata") + template_id = seed.get("template_id") + label = seed.get("label") or "Propagated Mask" + color = seed.get("color") or "#06b6d4" + model_id = sam_registry.normalize_model_id(payload.get("model")) + include_source = bool(payload.get("include_source", False)) + seed_key = _seed_key(seed) + seed_signature = _seed_signature(seed) + source_annotation_id = seed.get("source_annotation_id") + source_mask_id = seed.get("source_mask_id") + direction = str(payload.get("current_direction") or "") + + for frame_result in propagated: + relative_index = int(frame_result.get("frame_index", -1)) + if relative_index < 0 or relative_index >= len(selected_frames): + continue + frame = selected_frames[relative_index] + if not include_source and frame.id == source_frame.id: + continue + result_polygons = frame_result.get("polygons") or [] + scores = frame_result.get("scores") or [] + for polygon_index, polygon in enumerate(result_polygons): + if len(polygon) < 3: + continue + annotation = Annotation( + project_id=int(payload["project_id"]), + frame_id=frame.id, + template_id=template_id, + mask_data={ + "polygons": [polygon], + "label": label, + "color": color, + "source": f"{model_id}_propagation", + "propagated_from_frame_id": source_frame.id, + "propagated_from_frame_index": source_frame.frame_index, + "propagation_seed_key": seed_key, + "propagation_seed_signature": seed_signature, + "propagation_direction": direction, + "source_annotation_id": source_annotation_id, + "source_mask_id": source_mask_id, + 
"score": scores[polygon_index] if polygon_index < len(scores) else None, + **({"class": class_metadata} if class_metadata else {}), + }, + points=None, + bbox=_polygon_bbox(polygon), + ) + db.add(annotation) + created.append(annotation) + + db.commit() + for annotation in created: + db.refresh(annotation) + return created + + +def _run_one_step( + db: Session, + *, + payload: dict[str, Any], + frames: list[Frame], + source_frame: Frame, + source_position: int, + step: dict[str, Any], +) -> dict[str, Any]: + direction = str(step.get("direction") or "forward").lower() + if direction not in {"forward", "backward"}: + raise ValueError("direction must be forward or backward") + max_frames = max(1, min(int(step.get("max_frames") or payload.get("max_frames") or 30), 500)) + seed = step.get("seed") or {} + if not (seed.get("polygons") or seed.get("bbox") or seed.get("points")): + raise ValueError("Propagation requires seed polygons, bbox, or points") + + model_id = sam_registry.normalize_model_id(payload.get("model")) + seed_state = _prepare_seed_propagation( + db, + payload=payload, + model_id=model_id, + source_frame=source_frame, + seed=seed, + direction=direction, + ) + if seed_state["skip"]: + return { + "model": model_id, + "direction": direction, + "processed_frame_count": 0, + "created_annotation_count": 0, + "deleted_annotation_count": 0, + "skipped_seed_count": 1, + "seed_label": seed.get("label"), + "seed_key": seed_state["seed_key"], + } + + selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames) + with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload['project_id']}_") as tmpdir: + frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir)) + propagated = sam_registry.propagate_video( + model_id, + frame_paths, + source_relative_index, + seed, + direction, + len(selected_frames), + ) + + save_payload = {**payload, "current_direction": direction} + created = _save_propagated_annotations( + db, + 
payload=save_payload, + selected_frames=selected_frames, + source_frame=source_frame, + propagated=propagated, + seed=seed, + ) + return { + "model": model_id, + "direction": direction, + "processed_frame_count": len(selected_frames), + "created_annotation_count": len(created), + "deleted_annotation_count": int(seed_state["deleted_annotation_count"]), + "skipped_seed_count": 0, + "seed_label": seed.get("label"), + "seed_key": seed_state["seed_key"], + } + + +def run_propagate_project_task(db: Session, task_id: int) -> dict[str, Any]: + """Run one queued SAM propagation task and update persisted progress.""" + task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first() + if not task: + raise ValueError(f"Task not found: {task_id}") + + if task.status == TASK_STATUS_CANCELLED: + return {"task_id": task.id, "status": TASK_STATUS_CANCELLED, "message": task.message or "任务已取消"} + + payload = task.payload or {} + project_id = int(payload.get("project_id") or task.project_id or 0) + source_frame_id = int(payload.get("frame_id") or 0) + try: + model_id = sam_registry.normalize_model_id(payload.get("model")) + except ValueError as exc: + _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True) + raise + + project = db.query(Project).filter(Project.id == project_id).first() + if not project: + _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="项目不存在", error="Project not found", finished=True) + raise ValueError(f"Project not found: {project_id}") + + source_frame = db.query(Frame).filter(Frame.id == source_frame_id, Frame.project_id == project_id).first() + if not source_frame: + _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="参考帧不存在", error="Frame not found", finished=True) + raise ValueError(f"Frame not found: {source_frame_id}") + + frames = db.query(Frame).filter(Frame.project_id == project_id).order_by(Frame.frame_index).all() + source_position 
= next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None) + if source_position is None: + _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="参考帧不在项目帧序列中", error="Source frame is not in project frame sequence", finished=True) + raise ValueError("Source frame is not in project frame sequence") + + steps = payload.get("steps") or [] + if not steps: + _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="传播任务缺少步骤", error="Propagation task has no steps", finished=True) + raise ValueError("Propagation task has no steps") + + _ensure_not_cancelled(db, task) + _set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="自动传播任务已启动", started=True) + + step_results: list[dict[str, Any]] = [] + created_count = 0 + processed_count = 0 + deleted_count = 0 + skipped_count = 0 + total_steps = len(steps) + + try: + for index, step in enumerate(steps, start=1): + _ensure_not_cancelled(db, task) + seed_label = (step.get("seed") or {}).get("label") or "mask" + direction_label = "向前传播" if step.get("direction") == "backward" else "向后传播" + progress_before = 5 + int(((index - 1) / total_steps) * 90) + _set_task_state( + db, + task, + progress=progress_before, + message=f"{direction_label} {seed_label} ({index}/{total_steps})", + result={ + "project_id": project_id, + "source_frame_id": source_frame_id, + "model": model_id, + "total_steps": total_steps, + "completed_steps": index - 1, + "processed_frame_count": processed_count, + "created_annotation_count": created_count, + "deleted_annotation_count": deleted_count, + "skipped_seed_count": skipped_count, + "steps": step_results, + }, + ) + + result = _run_one_step( + db, + payload=payload, + frames=frames, + source_frame=source_frame, + source_position=source_position, + step=step, + ) + step_results.append(result) + created_count += int(result["created_annotation_count"]) + processed_count += int(result["processed_frame_count"]) + deleted_count 
+= int(result.get("deleted_annotation_count") or 0) + skipped_count += int(result.get("skipped_seed_count") or 0) + _set_task_state( + db, + task, + progress=5 + int((index / total_steps) * 90), + message=f"{direction_label} {seed_label} 完成 ({index}/{total_steps})", + result={ + "project_id": project_id, + "source_frame_id": source_frame_id, + "model": model_id, + "total_steps": total_steps, + "completed_steps": index, + "processed_frame_count": processed_count, + "created_annotation_count": created_count, + "deleted_annotation_count": deleted_count, + "skipped_seed_count": skipped_count, + "steps": step_results, + }, + ) + + result = { + "project_id": project_id, + "source_frame_id": source_frame_id, + "model": model_id, + "total_steps": total_steps, + "completed_steps": total_steps, + "processed_frame_count": processed_count, + "created_annotation_count": created_count, + "deleted_annotation_count": deleted_count, + "skipped_seed_count": skipped_count, + "steps": step_results, + } + _set_task_state( + db, + task, + status=TASK_STATUS_SUCCESS, + progress=100, + message="自动传播完成" if created_count > 0 else ( + "自动传播完成,未改变的 mask 已跳过" if skipped_count > 0 else "自动传播完成,但没有生成新的 mask" + ), + result=result, + finished=True, + ) + return result + except PropagationTaskCancelled: + task.status = TASK_STATUS_CANCELLED + task.progress = 100 + task.message = task.message or "任务已取消" + task.error = task.error or "Cancelled by user" + task.finished_at = task.finished_at or _now() + db.commit() + db.refresh(task) + publish_task_progress_event(task) + return {"task_id": task.id, "project_id": project_id, "status": TASK_STATUS_CANCELLED, "message": task.message} + except (ModelUnavailableError, NotImplementedError, ValueError) as exc: + _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True) + raise + except Exception as exc: # noqa: BLE001 + logger.exception("Propagation task failed: task_id=%s", task.id) + 
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True) + raise diff --git a/backend/tests/test_ai.py b/backend/tests/test_ai.py index ae7cfff..485b108 100644 --- a/backend/tests/test_ai.py +++ b/backend/tests/test_ai.py @@ -1,5 +1,8 @@ import numpy as np import cv2 +from pathlib import Path +from models import Annotation, ProcessingTask +from services.propagation_task_runner import run_propagate_project_task def _create_project_and_frame(client): @@ -294,6 +297,245 @@ def test_propagate_saves_tracked_annotations(client, monkeypatch): assert len(listing.json()) == 1 +def test_queue_propagation_task_creates_processing_task(client, monkeypatch): + project = client.post("/api/projects", json={"name": "Queued Propagation"}).json() + frame = client.post(f"/api/projects/{project['id']}/frames", json={ + "project_id": project["id"], + "frame_index": 0, + "image_url": "frames/0.jpg", + "width": 640, + "height": 360, + }).json() + + class FakeAsyncResult: + id = "celery-propagate-1" + + queued = [] + monkeypatch.setattr("routers.ai.propagate_project_masks.delay", lambda task_id: queued.append(task_id) or FakeAsyncResult()) + monkeypatch.setattr("routers.ai.publish_task_progress_event", lambda task: None) + + response = client.post("/api/ai/propagate/task", json={ + "project_id": project["id"], + "frame_id": frame["id"], + "model": "sam2.1_hiera_tiny", + "steps": [{ + "direction": "forward", + "max_frames": 2, + "seed": { + "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]], + "label": "胆囊", + }, + }], + }) + + assert response.status_code == 202 + body = response.json() + assert body["task_type"] == "propagate_masks" + assert body["status"] == "queued" + assert body["celery_task_id"] == "celery-propagate-1" + assert body["payload"]["model"] == "sam2.1_hiera_tiny" + assert body["payload"]["steps"][0]["seed"]["label"] == "胆囊" + assert queued == [body["id"]] + + +def 
test_queue_propagation_task_normalizes_model_and_rejects_unsupported(client, monkeypatch): + project = client.post("/api/projects", json={"name": "Propagation Model"}).json() + frame = client.post(f"/api/projects/{project['id']}/frames", json={ + "project_id": project["id"], + "frame_index": 0, + "image_url": "frames/0.jpg", + "width": 640, + "height": 360, + }).json() + + class FakeAsyncResult: + id = "celery-propagate-model" + + monkeypatch.setattr("routers.ai.propagate_project_masks.delay", lambda task_id: FakeAsyncResult()) + monkeypatch.setattr("routers.ai.publish_task_progress_event", lambda task: None) + + response = client.post("/api/ai/propagate/task", json={ + "project_id": project["id"], + "frame_id": frame["id"], + "model": "sam2", + "steps": [{ + "direction": "forward", + "max_frames": 2, + "seed": { + "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]], + }, + }], + }) + + assert response.status_code == 202 + assert response.json()["payload"]["model"] == "sam2.1_hiera_tiny" + + unsupported = client.post("/api/ai/propagate/task", json={ + "project_id": project["id"], + "frame_id": frame["id"], + "model": "sam3", + "steps": [{ + "direction": "forward", + "max_frames": 2, + "seed": { + "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]], + }, + }], + }) + + assert unsupported.status_code == 400 + assert "Unsupported model" in unsupported.json()["detail"] + + +def test_propagation_task_runner_saves_annotations_and_progress(client, db_session, monkeypatch): + project = client.post("/api/projects", json={"name": "Propagation Worker"}).json() + frames = [ + client.post(f"/api/projects/{project['id']}/frames", json={ + "project_id": project["id"], + "frame_index": idx, + "image_url": f"frames/{idx}.jpg", + "width": 640, + "height": 360, + }).json() + for idx in range(2) + ] + task = ProcessingTask( + task_type="propagate_masks", + status="queued", + progress=0, + project_id=project["id"], + payload={ + "project_id": project["id"], + "frame_id": frames[0]["id"], 
+ "model": "sam2.1_hiera_tiny", + "include_source": False, + "save_annotations": True, + "steps": [{ + "direction": "forward", + "max_frames": 2, + "seed": { + "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]], + "label": "胆囊", + "color": "#ff0000", + "class_metadata": {"id": "c1", "name": "胆囊"}, + }, + }], + }, + ) + db_session.add(task) + db_session.commit() + db_session.refresh(task) + + published = [] + monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg") + monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: published.append((event_task.status, event_task.progress))) + def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames): + assert [Path(path).name for path in frame_paths] == ["000000.jpg", "000001.jpg"] + return [ + {"frame_index": 0, "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]], "scores": [0.9]}, + {"frame_index": 1, "polygons": [[[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]], "scores": [0.8]}, + ] + + monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video) + + result = run_propagate_project_task(db_session, task.id) + + db_session.refresh(task) + assert task.status == "success" + assert task.progress == 100 + assert task.result["model"] == "sam2.1_hiera_tiny" + assert task.result["steps"][0]["model"] == "sam2.1_hiera_tiny" + assert result["created_annotation_count"] == 1 + assert result["processed_frame_count"] == 2 + assert published[0][0] == "running" + assert published[-1] == ("success", 100) + listing = client.get(f"/api/ai/annotations?project_id={project['id']}") + assert listing.json()[0]["frame_id"] == frames[1]["id"] + assert listing.json()[0]["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation" + + +def test_propagation_task_runner_skips_unchanged_seed_and_replaces_changed_seed(client, db_session, monkeypatch): + project = client.post("/api/projects", 
json={"name": "Propagation Dedupe"}).json() + frames = [ + client.post(f"/api/projects/{project['id']}/frames", json={ + "project_id": project["id"], + "frame_index": idx, + "image_url": f"frames/{idx}.jpg", + "width": 640, + "height": 360, + }).json() + for idx in range(2) + ] + + def make_task(seed_polygon): + task = ProcessingTask( + task_type="propagate_masks", + status="queued", + progress=0, + project_id=project["id"], + payload={ + "project_id": project["id"], + "frame_id": frames[0]["id"], + "model": "sam2.1_hiera_tiny", + "include_source": False, + "save_annotations": True, + "steps": [{ + "direction": "forward", + "max_frames": 2, + "seed": { + "polygons": [seed_polygon], + "label": "胆囊", + "color": "#ff0000", + "source_annotation_id": 7, + "source_mask_id": "annotation-7", + }, + }], + }, + ) + db_session.add(task) + db_session.commit() + db_session.refresh(task) + return task + + seed_polygon = [[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]] + first_output_polygon = [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]] + changed_seed_polygon = [[0.2, 0.2], [0.3, 0.2], [0.3, 0.3]] + replacement_output_polygon = [[0.22, 0.22], [0.32, 0.22], [0.32, 0.32]] + + monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg") + monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: None) + propagate_calls = [] + + def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames): + propagate_calls.append(seed["polygons"][0]) + output_polygon = replacement_output_polygon if seed["polygons"][0] == changed_seed_polygon else first_output_polygon + return [ + {"frame_index": 0, "polygons": [seed["polygons"][0]], "scores": [0.9]}, + {"frame_index": 1, "polygons": [output_polygon], "scores": [0.8]}, + ] + + monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video) + + first_result = run_propagate_project_task(db_session, 
make_task(seed_polygon).id) + assert first_result["created_annotation_count"] == 1 + assert len(propagate_calls) == 1 + + unchanged_result = run_propagate_project_task(db_session, make_task(seed_polygon).id) + assert unchanged_result["created_annotation_count"] == 0 + assert unchanged_result["skipped_seed_count"] == 1 + assert len(propagate_calls) == 1 + assert db_session.query(Annotation).filter(Annotation.project_id == project["id"]).count() == 1 + + changed_result = run_propagate_project_task(db_session, make_task(changed_seed_polygon).id) + assert changed_result["created_annotation_count"] == 1 + assert changed_result["deleted_annotation_count"] == 1 + assert len(propagate_calls) == 2 + annotations = db_session.query(Annotation).filter(Annotation.project_id == project["id"]).all() + assert len(annotations) == 1 + assert annotations[0].mask_data["polygons"] == [replacement_output_polygon] + assert annotations[0].mask_data["source_annotation_id"] == 7 + + def test_predict_validation_errors(client, monkeypatch): project, _, _ = _create_project_and_frame(client) diff --git a/backend/tests/test_dashboard.py b/backend/tests/test_dashboard.py index e04c236..1bcf7c7 100644 --- a/backend/tests/test_dashboard.py +++ b/backend/tests/test_dashboard.py @@ -110,3 +110,31 @@ def test_dashboard_overview_keeps_recent_success_tasks_in_progress_list(client, "updated_at": body["tasks"][0]["updated_at"], }, ] + + +def test_dashboard_overview_uses_processed_frame_count_for_propagation_tasks(client, db_session): + from models import ProcessingTask + + project = client.post("/api/projects", json={ + "name": "Propagation Project", + "status": "ready", + }).json() + task = ProcessingTask( + task_type="propagate_masks", + status="running", + progress=45, + message="向后传播 胆囊 (1/2)", + project_id=project["id"], + payload={"project_id": project["id"]}, + result={"processed_frame_count": 8, "created_annotation_count": 3}, + ) + db_session.add(task) + db_session.commit() + 
db_session.refresh(task) + + response = client.get("/api/dashboard/overview") + + assert response.status_code == 200 + body = response.json() + assert body["tasks"][0]["task_id"] == task.id + assert body["tasks"][0]["frame_count"] == 8 diff --git a/backend/tests/test_tasks.py b/backend/tests/test_tasks.py index 482cd1f..e9b9d9f 100644 --- a/backend/tests/test_tasks.py +++ b/backend/tests/test_tasks.py @@ -84,6 +84,42 @@ def test_retry_task_creates_fresh_parse_task(client, db_session, monkeypatch): assert client.get(f"/api/projects/{project['id']}").json()["status"] == "parsing" +def test_retry_task_dispatches_propagation_worker_without_media_requirement(client, db_session, monkeypatch): + project = client.post("/api/projects", json={"name": "Retry Propagation"}).json() + task = ProcessingTask( + task_type="propagate_masks", + status="failed", + progress=100, + message="自动传播失败", + error="model unavailable", + project_id=project["id"], + payload={ + "project_id": project["id"], + "frame_id": 1, + "steps": [], + }, + ) + db_session.add(task) + db_session.commit() + db_session.refresh(task) + + class FakeAsyncResult: + id = "celery-propagation-retry" + + queued = [] + monkeypatch.setattr("routers.tasks.propagate_project_masks.delay", lambda task_id: queued.append(task_id) or FakeAsyncResult()) + monkeypatch.setattr("routers.tasks.publish_task_progress_event", lambda event_task: None) + + response = client.post(f"/api/tasks/{task.id}/retry") + + assert response.status_code == 202 + body = response.json() + assert body["task_type"] == "propagate_masks" + assert body["celery_task_id"] == "celery-propagation-retry" + assert queued == [body["id"]] + assert client.get(f"/api/projects/{project['id']}").json()["status"] == "pending" + + def test_task_actions_reject_invalid_states(client, db_session): project = client.post("/api/projects", json={ "name": "Done", diff --git a/backend/worker_tasks.py b/backend/worker_tasks.py index 2f06b4c..5ef3319 100644 --- 
a/backend/worker_tasks.py +++ b/backend/worker_tasks.py @@ -5,6 +5,7 @@ import logging from celery_app import celery_app from database import SessionLocal from services.media_task_runner import run_parse_media_task +from services.propagation_task_runner import run_propagate_project_task logger = logging.getLogger(__name__) @@ -20,3 +21,16 @@ def parse_project_media(task_id: int) -> dict: raise exc finally: db.close() + + +@celery_app.task(name="ai.propagate_project") +def propagate_project_masks(task_id: int) -> dict: + """Run SAM video propagation for one queued task.""" + db = SessionLocal() + try: + return run_propagate_project_task(db, task_id) + except Exception as exc: # noqa: BLE001 + logger.exception("Propagation task failed: task_id=%s", task_id) + raise exc + finally: + db.close() diff --git a/doc/03-frontend-element-audit.md b/doc/03-frontend-element-audit.md index ac43b4a..b17ffd9 100644 --- a/doc/03-frontend-element-audit.md +++ b/doc/03-frontend-element-audit.md @@ -49,16 +49,18 @@ | 导入视频文件 | 真实可用 | 创建项目、上传源视频、刷新项目列表;不会自动拆帧 | | 生成帧按钮 | 真实可用 | 仅对已导入源视频且尚无帧、非 parsing 状态的项目显示,调用 `parseMedia(projectId, { parseFps })` | | 生成帧 FPS 滑块 | 真实可用 | 值传入 `/api/media/parse?parse_fps=...`,决定后台拆帧目标 FPS | +| 项目卡片 FPS 徽标 | 真实可用 | 右上角显示关键帧序列目标 `parse_fps`;原始视频帧率只在卡片底部以“原 xx fps”显示 | | 导入 DICOM 序列 | 部分可用 | 可上传 `.dcm` 并触发解析;体验和错误反馈较粗 | | 项目状态徽标 | 真实可用 | 项目状态统一为 `pending/parsing/ready/error`,前端兼容归一化旧状态值 | | 删除项目按钮 | 真实可用 | 点击垃圾桶按钮会确认删除,调用 `DELETE /api/projects/{id}`,成功后从项目库移除;若删除的是当前项目,会清空工作区当前项目、帧、mask 和选区 | -| alert 成功/失败提示 | 真实可用但粗糙 | 使用浏览器 `alert` | +| 操作成功/失败提示 | 真实可用 | 使用非阻塞 `TransientNotice` 浮层,自动消失,不会拦截后续按钮、输入框或画布操作 | ## 工作区 VideoWorkspace | 元素 | 状态 | 说明 | |------|------|------| | 当前项目名 | 真实可用 | 读取 `currentProject.name` | +| 顶栏操作提示 | 真实可用 | 保存、导出、传播范围选择等短反馈会自动消失;保存/导出/传播进行中和无帧项目提示会保留到状态变化 | | 自动加载项目帧 | 真实可用 | 调用 `GET /api/projects/{id}/frames` | | 无帧项目提示 | 真实可用 | 如果 `video_path` 存在但无帧,只提示回到项目库生成帧,不自动创建拆帧任务 | | SAM 模型状态徽标 | 真实可用 | 调用 `GET 
/api/ai/models/status`,显示当前启用的 SAM 2 与 GPU 状态 | @@ -66,16 +68,16 @@ | “导出 JSON 标注集”按钮 | 真实可用 | 导出前会保存未归档 mask,然后调用 `exportCoco()` 下载 JSON | | “导出 PNG Mask ZIP”按钮 | 真实可用 | 导出前会保存未归档 mask,然后调用 `GET /api/export/{project_id}/masks` 下载 ZIP;后端同时包含单标注 mask、每帧语义融合 mask 和 `semantic_classes.json` | | “导入 GT Mask”按钮 | 真实可用 | 选择图片后调用 `POST /api/ai/import-gt-mask`,后端按非零像素值和连通域生成 polygon 标注与距离变换 seed point,再回显到工作区 | -| 参考帧/起止帧/自动传播 | 真实可用 | 当前打开帧即参考帧,前端会使用该帧全部 mask 作为 seed;用户设置传播起始帧和传播结束帧后,单个“自动传播”按钮会按 seed mask 和前/后方向顺序调用 `POST /api/ai/propagate`,当前启用 SAM 2 video predictor,完成后刷新已保存标注 | +| 参考帧/起止帧/传播权重/自动传播 | 真实可用 | 当前打开帧即参考帧,前端会使用该帧全部 mask 作为 seed;工作区顶栏有独立“传播权重”下拉,可在传播前二次选择 SAM 2.1 tiny/small/base+/large 权重,不提供 SAM2/SAM3 家族切换,不影响 AI 智能分割页的单帧推理权重选择;如果用户尚未显式设置范围,点击“自动传播”会先进入时间轴范围选择模式,播放进度条和视频处理进度条都可点击/拖拽回填传播起始帧和传播结束帧,再点击“开始传播”提交;用户也可直接改数字框后点击按钮传播。提交后前端把传播权重 id、seed mask、seed 来源 id 和前/后方向步骤提交到 `POST /api/ai/propagate/task`,后端先规范化/校验权重 id,再创建 `processing_tasks` 并由 Celery 执行对应 SAM 2.1 video predictor;worker 会按 seed 来源和几何/语义签名做幂等判断,未改变的 seed 直接跳过,已改变的 seed 会先删除同源旧自动传播标注再重新传播,避免重复传播产生重叠 mask;传播中顶栏显示任务进度、已处理帧次、删除旧区域数和已保存区域数,前端轮询 `GET /api/tasks/{task_id}` 并刷新已保存标注;任务可取消,若完成后 0 个新区域会明确提示没有生成新 mask 或已跳过未改变 mask | | “结构化归档保存”按钮 | 真实可用 | 未保存 mask 写入 `POST /api/ai/annotate`;dirty mask 写入 `PATCH /api/ai/annotations/{id}`;保存成功后会重新拉取后端标注,并用 saved annotation 替换本次提交的 draft mask,避免仍显示未保存 | ## CanvasArea 画布 | 元素 | 状态 | 说明 | |------|------|------| -| 当前帧底图显示 | 真实可用 | `useImage(frameUrl)` 加载当前帧 URL | +| 当前帧底图显示 | 真实可用 | `useImage(frameUrl)` 加载当前帧 URL;切换帧或容器尺寸变化时会按 86% 适配比例居中放大显示,默认留出画布边距,不铺满整个画布 | | 滚轮缩放 | 真实可用 | 改变 Konva Stage scale | -| 拖拽平移 | 真实可用 | activeTool 为 `move` 时 Stage draggable | +| 拖拽平移 | 真实可用 | activeTool 为 `move` 时 Stage draggable,拖拽结束会回写 React position state,避免 Konva 节点位置和前端状态脱节 | | 光标坐标显示 | 真实可用 | 根据 pointer position 计算 | | 正向/反向选点 | 真实可用 | UI 能加点,并按当前帧 `frame.id` 调用 `/api/ai/predict`;结果需点击归档保存才持久化 | | 框选 | 真实可用 | UI 能画框,并把框坐标归一化后调用后端推理;结果需点击归档保存才持久化 | @@ -100,17 +102,18 @@ | 
区域合并/去除 | 真实可用 | 选择工具后点击多个 mask,右下角显示已选数量和操作按钮;合并/去除模式会隐藏 polygon 编辑手柄,避免手柄抢占多选点击;布尔选择态中第一个选中的主区域用黄色实线轮廓,后续参与合并/扣除的区域用红色虚线轮廓,避免主区域和扣除区域看起来像随机阴影差异;使用 `polygon-clipping` 做 union / difference;合并会保留主 mask 并移除被合并 mask,去除会从主 mask 扣除后续选中 mask;内含扣除会保留 hole ring 并用 even-odd 规则渲染 | | 正向选点/反向选点/框选 | 部分可用 | 会影响 Canvas 交互,并能触发已对齐的 AI 推理接口;点击工作区内已有 SAM 提示点会优先删除该提示点并重新推理,不会冒泡成新增提示点或 mask 选择 | | 魔法棒 SAM 触发 | 部分可用 | 切到 AI 页面;不是直接执行推理 | -| 撤销/重做 | 真实可用 | 绑定 Zustand `maskHistory/maskFuture`,支持工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y | +| 撤销/重做 | 真实可用 | 绑定 Zustand `maskHistory/maskFuture`,支持工作区顶栏按钮、工具栏按钮、AI 页按钮和快捷键 `Ctrl/Cmd+Z`、`Ctrl/Cmd+Shift+Z`、`Ctrl/Cmd+Y`;输入框聚焦时不拦截快捷键 | +| 紧凑/滚动布局 | 真实可用 | 工具按钮使用较紧凑的垂直间距;左侧高度不足时工具栏自身出现纵向滚动,不挤压画布 | ## FrameTimeline 时间轴 | 元素 | 状态 | 说明 | |------|------|------| | 帧缩略图 | 真实可用 | 使用 `frames[].url` | -| 点击缩略图跳帧 | 真实可用 | 调用 `setCurrentFrame(idx)` | +| 点击缩略图跳帧 | 真实可用 | 调用 `setCurrentFrame(idx)`;非当前帧中,人工/AI 标注帧使用红色边框,自动传播/推理帧使用蓝色边框;当前帧仍用青色外框高亮优先,若当前帧同时是人工/AI 标注帧,则在青色外框内增加红色内描边,避免状态颜色互相覆盖 | | 顶部 range 拖动 | 真实可用 | 改变当前帧 | | 具体时间显示 | 真实可用 | 根据项目 `parse_fps/original_fps` 显示当前时间和总时长,格式为 `mm:ss.cc` | -| 自动传播帧进度条标记 | 真实可用 | 根据已保存标注回显的 `mask_data.source` / `propagated_from_frame_id` 识别自动传播生成的帧,并在顶部进度条对应帧区段覆盖浅蓝色;当前帧位置由播放进度条末端、时间提示和缩略图高亮表达 | +| 播放进度条 / 视频处理进度条 | 真实可用 | 播放进度条位于上方,视频处理进度条位于下方;视频处理进度条普通状态下可点击跳转到对应帧;根据已保存标注回显的 `mask_data.source` / `propagated_from_frame_id` 识别自动传播生成的帧并显示蓝色区段,人工绘制或 AI 智能分割生成的帧显示红色竖线,红/蓝标识也可点击跳转到对应帧;未处理背景使用中性灰以和红/蓝标记区分;只有工作区进入自动传播范围选择模式时,两条进度条才显示 amber 选区,并可点击/拖拽选择起止帧 | | 播放/暂停 | 真实可用 | 当前代码按 `parse_fps/original_fps` 推进帧,最多 30fps | | 方向键切帧 | 真实可用 | 全局监听左右方向键切到上一帧/下一帧;焦点在 input、textarea、select 或 contentEditable 内时不会拦截 | @@ -144,6 +147,7 @@ | 删除选中候选 | 真实可用 | 删除 AI 页当前选中的本页候选 mask;不会删除工作区已有 mask,Delete/Backspace 也遵循同一范围 | | 清空全体锚点 | 真实可用 | 清空 AI 页提示点和本页生成的候选 mask,不删除工作区已有 mask | | 背景图 / 空状态 | 真实可用 | 优先显示当前项目帧;没有项目帧时显示空状态提示,不再回退到外部演示图片 | +| AI 画布初始视图 | 真实可用 | 当前帧在 AI 画布中默认居中,并按 86% 适配比例尽量放大但保留边距 | ## TemplateRegistry 模板库 diff --git 
a/doc/04-api-contracts.md b/doc/04-api-contracts.md index 51e3642..4dc29fc 100644 --- a/doc/04-api-contracts.md +++ b/doc/04-api-contracts.md @@ -38,7 +38,8 @@ Authorization: Bearer | `retryTask(taskId)` | `POST /api/tasks/{task_id}/retry` | 对齐 | 对 failed/cancelled 任务创建新的 queued 重试任务 | | `getProjectFrames(projectId)` | `GET /api/projects/{id}/frames` | 对齐 | 后端返回预签名 image_url,以及 `timestamp_ms`、`source_frame_number` | | `predictMask(payload)` | `POST /api/ai/predict` | 对齐 | 前端发送 `image_id/prompt_type/prompt_data/model`,并把后端 `polygons` 转为 `masks[].pathData` | -| `propagateMasks(payload)` | `POST /api/ai/propagate` | 对齐 | 当前帧 seed mask 向视频片段传播,并保存后续帧标注 | +| `propagateMasks(payload)` | `POST /api/ai/propagate` | 对齐 | 单 seed 同步传播接口,供后端兼容和测试使用 | +| `queuePropagationTask(payload)` | `POST /api/ai/propagate/task` | 对齐 | 工作区自动传播入口;创建 Celery 后台任务并由任务表/进度流追踪 | | `getAiModelStatus(selectedModel?)` | `GET /api/ai/models/status` | 对齐 | 返回 GPU 和四个 SAM 2.1 变体状态;`selected_model=sam3` 返回不支持 | | `analyzeMask(mask, frame, options?)` | `POST /api/ai/analyze-mask` | 对齐 | 后端计算选中 mask 的置信度来源、拓扑锚点数量、面积和 bbox | | `getProjectAnnotations(projectId, frameId?)` | `GET /api/ai/annotations` | 对齐 | 前端加载工作区时用于回显已保存标注 | @@ -78,7 +79,8 @@ Authorization: Bearer | POST | `/api/tasks/{task_id}/cancel` | 取消后台任务 | | POST | `/api/tasks/{task_id}/retry` | 重试失败或取消的后台任务 | | POST | `/api/ai/predict` | 当前启用 SAM 2 点/框/interactive 推理 | -| POST | `/api/ai/propagate` | 当前启用 SAM 2 视频片段传播并保存标注 | +| POST | `/api/ai/propagate` | 当前启用 SAM 2 单 seed 同步视频片段传播并保存标注 | +| POST | `/api/ai/propagate/task` | 创建 SAM 2 自动传播后台任务;payload 可包含多个 seed/direction step | | POST | `/api/ai/analyze-mask` | 分析前端选中 mask 的后端几何属性和拓扑锚点 | | GET | `/api/ai/models/status` | GPU 和 SAM 模型状态 | | POST | `/api/ai/auto` | 自动分割 | @@ -230,7 +232,7 @@ SAM 2 点提示和 auto fallback 当前只采用最高分候选 mask,避免同 ### 视频片段传播请求体 -后端接口仍以单个 seed 为单位。工作区前端当前只提供一个“自动传播”按钮:当前打开帧作为参考帧,该帧全部 mask 作为 seed;用户设置传播起始帧和传播结束帧后,前端会在本地把多个 seed 或前后双向范围拆成多次顺序调用,避免同时启动多个视频 tracker。 +`POST 
/api/ai/propagate` 仍是单 seed 同步接口。工作区实际使用 `POST /api/ai/propagate/task`:当前打开帧作为参考帧,该帧全部 mask 作为 seed;用户设置传播起始帧和传播结束帧后,前端会在本地把多个 seed 或前后双向范围拆成 `steps`,一次提交为 `propagate_masks` 后台任务,避免长 HTTP 请求和多个视频 tracker 并发抢占 GPU。 单次调用示例: @@ -254,7 +256,31 @@ SAM 2 点提示和 auto fallback 当前只采用最高分候选 mask,避免同 } ``` -SAM 2.1 变体使用对应 video predictor 的 mask seed 传播;`model=sam2` 会兼容归一化为 tiny,`model=sam3` 当前不支持。响应会返回已创建的 `annotations`,保存的 `mask_data.source` 为 `_propagation`,前端回显时会把该字段保留到 `Mask.metadata`,用于把自动传播帧在时间进度条上标为浅蓝色。 +后台任务调用示例: + +```json +{ + "project_id": 1, + "frame_id": 123, + "model": "sam2.1_hiera_tiny", + "include_source": false, + "save_annotations": true, + "steps": [ + { + "direction": "forward", + "max_frames": 30, + "seed": { + "polygons": [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]], + "label": "胆囊", + "color": "#ff0000" + } + } + ] +} +``` + +SAM 2.1 变体使用对应 video predictor 的 mask seed 传播;`model=sam2` 会兼容归一化为 tiny,`model=sam3` 当前不支持。同步接口的响应会返回已创建的 `annotations`,保存的 `mask_data.source` 为 `{model_id}_propagation`(例如 `sam2.1_hiera_tiny_propagation`),前端回显时会把该字段保留到 `Mask.metadata`,用于在视频处理进度条上把自动传播帧显示为蓝色区段。 +后台任务入队接口会先规范化/校验 `model` 字段中的 SAM 2.1 权重 id,再把规范化后的权重 id 写入 `processing_tasks.payload.model`;前端 seed 会携带 `source_mask_id` 和可用时的 `source_annotation_id`,worker 保存传播结果时会写入 `propagation_seed_key`、`propagation_seed_signature` 和 `propagation_direction`。同一 seed、同一权重、同一方向再次传播时,如果签名未变化,worker 会跳过该 seed;如果签名变化,worker 会先删除旧自动传播标注再保存新结果。任务运行中/完成后会写入 `processing_tasks.result.model`、`completed_steps`、`processed_frame_count`、`created_annotation_count`、`deleted_annotation_count`、`skipped_seed_count` 和每个 step 的权重/方向/数量结果;前端通过 `GET /api/tasks/{task_id}` 轮询,Dashboard 同时可通过 Redis/WebSocket 进度流显示该任务。 ## 已完成的接口对齐 @@ -270,6 +296,7 @@ SAM 2.1 变体使用对应 video predictor 的 mask seed 传播;`model=sam2` - `importGtMask()` 已接入 `POST /api/ai/import-gt-mask`,导入后端生成的 polygon 标注、原始 `gt_label_value` 和 seed point。 - `exportMasks()` 已接入 `GET /api/export/{projectId}/masks`。 - `parseMedia()` 已改为创建 Celery 后台任务,并返回 `ProcessingTask`。 +- `queuePropagationTask()` 已接入 
`/api/ai/propagate/task`,自动传播不再依赖长时间同步 HTTP 请求。 - `getTask()` 已接入 `GET /api/tasks/{taskId}`。 - `cancelTask()` 已接入 `POST /api/tasks/{taskId}/cancel`。 - `retryTask()` 已接入 `POST /api/tasks/{taskId}/retry`。 diff --git a/doc/07-current-requirements-freeze.md b/doc/07-current-requirements-freeze.md index 96033e7..a4577cb 100644 --- a/doc/07-current-requirements-freeze.md +++ b/doc/07-current-requirements-freeze.md @@ -48,10 +48,11 @@ - 若项目有媒体但无帧,工作区只提示需要先在项目库生成帧,不再自动触发拆帧。 - Canvas 显示当前帧图片。 - Canvas 支持滚轮缩放、移动工具拖拽、鼠标坐标显示。 -- 时间轴支持缩略图点击切帧、range 拖动切帧、键盘左右方向键切帧、播放/暂停顺序推进帧。 +- 时间轴支持缩略图点击切帧、range 拖动切帧、视频处理进度条点击切帧、人工/AI 标注帧和自动传播帧标识点击切帧、键盘左右方向键切帧、播放/暂停顺序推进帧。 - 播放帧率使用项目 `parse_fps` 或 `original_fps`,限制在 1 到 30 FPS。 - 时间轴显示当前帧时间和总时长,时间基准使用项目 `parse_fps` 或 `original_fps`,格式为 `mm:ss.cc`。 -- 时间轴根据已保存标注回显的传播来源字段,把自动传播生成的帧在顶部进度条对应区段标为浅蓝色;不再使用竖线模式标记已编辑帧,当前帧位置由播放进度条末端、时间提示和缩略图高亮表达。 +- 时间轴顶部播放进度条只表达当前播放位置;其下方的视频处理进度条表达处理状态:人工绘制或 AI 智能分割生成的帧显示红色竖线,自动传播生成的帧显示蓝色区段,未处理背景使用中性灰以和标记保持明显区分。底部帧可视化栏中,人工/AI 标注帧缩略图边框为红色,自动传播/推理帧缩略图边框为蓝色,当前帧仍用青色外框高亮优先;如果当前帧同时是人工/AI 标注帧,则显示青色外框加红色内描边。 +- 自动传播提交前支持独立选择传播权重,范围限定为 SAM 2.1 tiny/small/base+/large 四个权重变体;该选择只影响传播任务,不提供 SAM2/SAM3 家族切换,也不改变 AI 智能分割页的单帧推理权重。 ## R5 工具栏 @@ -101,9 +102,11 @@ - 工作区传播功能以当前打开帧作为参考帧,并使用该帧全部 mask 作为 seed;用户不再选择“选中区域/当前帧全部”传播对象。 - 工作区传播功能允许设置传播起始帧和传播结束帧;前端以当前参考帧为 seed,只向起止范围内位于参考帧之前和之后的帧传播,源帧不重复保存。 - 工作区只保留一个“自动传播”按钮,点击后在指定范围内按前向/后向自动生成 mask。 -- 前端复用单 seed 后端接口;多个 seed 或双向范围会被拆成多次顺序调用 `POST /api/ai/propagate`,避免并发抢占 GPU。 -- `POST /api/ai/propagate` 当前支持四个 SAM 2.1 变体;兼容 `model=sam2` 并归一化为 tiny。SAM 2.1 使用官方 `SAM2VideoPredictor.add_new_mask()` 和 `propagate_in_video()`。 -- 传播结果会写入后续帧 `annotations`,`mask_data.source` 标记为 `_propagation`,并保留 label、color 和 class 元数据。 +- 前端会把多个 seed 或双向范围拆成 `steps`,通过 `POST /api/ai/propagate/task` 创建 `propagate_masks` 后台任务,避免长 HTTP 请求卡在浏览器侧,同时避免并发抢占 GPU。 +- `POST /api/ai/propagate` 作为单 seed 同步兼容接口保留;`POST /api/ai/propagate/task` 是工作区自动传播使用的任务接口。两者当前支持四个 SAM 2.1 变体;兼容 `model=sam2` 并归一化为 tiny。SAM 2.1 使用官方 
`SAM2VideoPredictor.add_new_mask()` 和 `propagate_in_video()`。 +- 自动传播任务写入 `processing_tasks`,前端轮询 `GET /api/tasks/{task_id}` 显示进度并刷新标注;Dashboard 也能看到该任务,任务可取消和重试。 +- 传播结果会写入后续帧 `annotations`,`mask_data.source` 标记为 `{model_id}_propagation`(例如 `sam2.1_hiera_tiny_propagation`),并保留 label、color、class 元数据、seed 来源 id、seed 签名和传播方向。 +- 自动传播任务必须避免重复叠加:同一参考 seed、同一权重、同一方向且 seed 签名未变化时,worker 直接跳过;同一参考 seed 已变化时,worker 先删除对应旧自动传播标注,再保存新传播结果。 - AI 页面会对未放置点提示、后端错误和返回 0 个 mask 的情况显示明确反馈。 - AI 参数支持 `crop_to_prompt`、`auto_filter_background` 和 `min_score`;点/框 prompt 可以裁剪局部区域推理并回映射结果,背景过滤会移除低分结果和包含负向点的 polygon。 - 后端返回 `polygons` 和 `scores`。 diff --git a/doc/08-current-design-freeze.md b/doc/08-current-design-freeze.md index 2ec2741..7b3c2c7 100644 --- a/doc/08-current-design-freeze.md +++ b/doc/08-current-design-freeze.md @@ -29,11 +29,13 @@ | 项目库 | `src/components/ProjectLibrary.tsx` | 项目列表、新建、删除、导入视频/DICOM、显式生成帧 | | 工作区 | `src/components/VideoWorkspace.tsx` | 加载帧和模板,组织工具栏、Canvas、本体面板、时间轴 | | Canvas | `src/components/CanvasArea.tsx` | 显示帧、缩放平移、点/框提示、渲染 mask | -| 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具、跳转 AI 页面、触发 mask 撤销/重做 | -| 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航、自动传播帧浅蓝区段标记、左右方向键切帧、播放和当前/总时长显示 | +| 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具、跳转 AI 页面、触发 mask 撤销/重做;紧凑垂直布局,高度不足时自身滚动 | +| 工作区顶栏 | `src/components/VideoWorkspace.tsx` | 保存/导出/传播/导入 GT、显式撤销/重做按钮和工作区快捷键 | +| 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航、播放进度、视频处理进度条、自动传播范围选择、左右方向键切帧、播放和当前/总时长显示 | | 本体面板 | `src/components/OntologyInspector.tsx` | 模板选择、分类树、后端自定义分类、mask 后端属性分析 | | AI 页面 | `src/components/AISegmentation.tsx` | 独立 AI 推理视图,使用当前项目帧 | | 模板库 | `src/components/TemplateRegistry.tsx` | 模板 CRUD、分类编辑、导入、排序 | +| 短提示浮层 | `src/components/TransientNotice.tsx` | 项目库和模板库的非阻塞成功/失败提示,自动消失 | ## 后端模块 @@ -49,6 +51,7 @@ | Templates | `backend/routers/templates.py` | 模板 CRUD 和 mapping_rules 打包/解包 | | Media | `backend/routers/media.py` | 上传媒体和拆帧 | | AI | `backend/routers/ai.py` | 当前启用 SAM 2 推理、视频传播、模型状态和标注保存 | +| 传播任务 | 
`backend/services/propagation_task_runner.py` | Celery 中执行自动传播 steps,写任务进度并保存传播标注 | | Export | `backend/routers/export.py` | COCO 和 PNG mask 导出 | | SAM 2 | `backend/services/sam2_engine.py` | SAM 2 懒加载、状态检测、点/框/自动推理和视频 mask 传播 | | SAM 3 | `backend/services/sam3_engine.py`, `backend/services/sam3_external_worker.py`, `backend/setup_sam3_env.sh` | 历史保留的 SAM 3 桥接源码和脚本;当前未接入 registry | @@ -85,7 +88,8 @@ 5. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,视频帧按 `frame_%06d.jpg` 从 `frame_000000.jpg` 连续命名,并按目标宽度缩放。 6. worker 写入 `frames.timestamp_ms` 和 `frames.source_frame_number`,并在任务 `result.frame_sequence` 中记录 FPS、帧数、时长、尺寸和对象存储前缀。 7. worker 持续更新 `processing_tasks`,并发布 Redis `seg:progress`。 -8. 刷新项目列表。 +8. 刷新项目列表;项目卡片右上角 FPS 徽标显示生成关键帧序列时选择的 `parse_fps`,原始视频 FPS 仅作为底部“原 xx fps”辅助信息显示。 +9. 导入视频、生成帧、上传 DICOM 和失败反馈使用 `TransientNotice`,不再使用浏览器 `alert()` 阻塞操作;提示默认数秒后自动消失。 ### 任务控制 @@ -103,8 +107,10 @@ 3. 帧数据映射为 store `Frame[]`,包含 `timestampMs` 和 `sourceFrameNumber`,供时间轴和后续视频传播使用。 4. 工作区调用 `GET /api/ai/annotations` 回显已保存标注时,会替换当前项目帧中的已保存 mask,但保留没有 `annotationId` 的未保存 draft mask;这保证 AI 页推送到工作区的候选 mask 不会被异步回显覆盖,并会在合并完成后恢复仍然存在的已选 mask id。 5. `CanvasArea` 会把全局 `selectedMaskIds` 中仍存在于当前帧的 id 同步回本地选区,避免帧初始化时的临时清空覆盖 AI 页推送过来的选中态。 -6. `FrameTimeline` 根据已保存标注回显到 `Mask.metadata` 的 `source` / `propagated_from_frame_id` 计算自动传播生成的帧,并在顶部时间进度条对应帧区段覆盖浅蓝色;当前帧不额外渲染竖线,由播放进度条末端、时间提示和缩略图高亮表达。 -7. 当前帧传入 `CanvasArea`。 +6. `CanvasArea` 根据容器和帧尺寸按 86% 适配比例计算初始 scale/position,使底图默认居中且尽量大,但保留画布边距;滚轮缩放和拖拽平移仍由用户后续控制。 +7. `FrameTimeline` 顶部播放进度条显示当前播放位置;其下方视频处理进度条根据 `Mask.metadata.source` / `propagated_from_frame_id` 计算自动传播帧并显示蓝色区段,对人工绘制或 AI 智能分割等非传播 mask 帧显示红色竖线。普通状态下,视频处理进度条可点击跳转到对应帧,红色人工/AI 标注帧和蓝色自动传播帧标识本身也可点击跳转。处理条未处理背景使用中性灰,和红色/蓝色标记保持明显区分。底部缩略图导航轴对非当前帧使用红色边框标识人工/AI 标注帧,使用蓝色边框标识自动传播/推理帧;当前帧使用青色外框高亮优先,若当前帧同时是人工/AI 标注帧,则以青色外框加红色内描边同时表达两个状态。工作区只有进入自动传播范围选择模式时,播放进度条和视频处理进度条才显示 amber 覆盖层,并可点击/拖拽设置传播起止帧。 +8. 当前帧传入 `CanvasArea`。 +9. 
工作区顶栏短状态文本会在空闲状态下自动消失;保存、导出、导入 GT 和传播任务运行中仍保留进度状态,无帧项目提示也会保留。 ### AI 点/框推理 @@ -126,8 +132,9 @@ 16. AI 页面参数开关文案只做展示增强:“局部专注模式(自动裁剪无锚区域)”仍控制 `cropMode/crop_to_prompt`,“严格除杂模式(自动清理干涉点)”仍控制 `autoDeleteBg/auto_filter_background/min_score`。 17. AI 页面“遮罩清晰度”滑杆只调节候选 mask 的 Konva preview opacity,不写入 `Mask.segmentation`、分类元数据或后端 payload。 18. AI 画布左上角根据正向点、反向点、边界框选和视口控制显示上下文提示,说明点击/拖拽、删除提示点和执行推理的操作方式。 -19. Canvas 按当前帧过滤并渲染 mask。 -19. 新 mask 会带上当前选择的模板分类元数据,包括 `classId`、`className`、`classZIndex`、`metadata.source=ai_segmentation` 和保存状态 `draft`。 +19. AI 画布根据容器和当前帧尺寸按 86% 适配比例计算初始 scale/position,使底图默认居中且尽量大,但保留画布边距。 +20. Canvas 按当前帧过滤并渲染 mask。 +21. 新 mask 会带上当前选择的模板分类元数据,包括 `classId`、`className`、`classZIndex`、`metadata.source=ai_segmentation` 和保存状态 `draft`。 20. 用户点击“结构化归档保存”后,前端将像素 `segmentation` 转成 normalized `mask_data.polygons`;未保存 mask 调用 `POST /api/ai/annotate`,dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`;保存成功后本次提交的 draft mask id 会从本地保留列表中排除,并由后端 saved annotation 回显替换。 21. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。 22. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。 @@ -135,15 +142,17 @@ ### 视频片段传播 1. 用户在工作区打开一帧作为参考帧;该帧全部 mask 都会作为传播 seed,不再提供传播对象下拉。 -2. 用户设置传播起始帧和传播结束帧,并点击唯一的“自动传播”按钮。 +2. 用户可以直接修改传播起始帧/结束帧数字框,并可通过工作区顶栏“传播权重”下拉独立选择本次传播使用的 SAM 2.1 tiny/small/base+/large 权重;该入口不提供 SAM2/SAM3 家族切换,默认跟随全局 AI 权重,用户手动选择后不再被 AI 页权重切换覆盖。 3. `VideoWorkspace` 以当前参考帧为 seed,将起止帧拆成 `backward` 和/或 `forward` 两段;只包含当前帧时不传播。 -4. `VideoWorkspace` 用 `buildAnnotationPayload()` 把每个 seed mask 转成 normalized polygon、bbox、label、color 和 class 元数据。 -5. 前端对每个 seed、每个方向顺序调用 `POST /api/ai/propagate`,`include_source=false`、`save_annotations=true`;顺序调用是为了避免多个视频 tracker 并发抢占 GPU。 -6. 后端按项目帧序列截取片段,下载对应帧到临时 `frame_%06d.jpg` 目录,保持当前帧在片段中的相对索引。 -7. `model` 为任一 SAM 2.1 变体时,`sam2_engine` 使用对应 checkpoint/config 加载 `SAM2VideoPredictor.add_new_mask()` 注入 seed mask,再用 `propagate_in_video()` 传播。 -8. `model=sam3` 当前不支持;SAM 3 video tracker 代码保留但没有接入产品路径。 -9. 
后端把传播返回的 normalized polygon 保存为后续帧 `Annotation`,跳过源帧,`mask_data.source` 记录模型传播来源。 -10. 前端传播完成后重新调用 `GET /api/ai/annotations` 并回显新标注;`annotationToMask()` 会保留传播来源 metadata,供时间轴浅蓝色进度条区段显示。 +4. `VideoWorkspace` 用 `buildAnnotationPayload()` 把每个 seed mask 转成 normalized polygon、bbox、label、color、class 元数据、`source_mask_id` 和可用时的 `source_annotation_id`。 +5. 前端把传播权重 id、每个 seed、每个方向组装成 `steps`,一次调用 `POST /api/ai/propagate/task`,`include_source=false`、`save_annotations=true`;接口先规范化/校验 `model` 字段中的权重 id,再创建 `processing_tasks.task_type=propagate_masks` 并投递 Celery,避免长 HTTP 请求阻塞前端等待。 +6. `VideoWorkspace` 记录返回的 `task_id`,轮询 `GET /api/tasks/{task_id}` 显示任务 message、步骤进度、已处理帧次和已保存区域数;任务运行期间提供取消传播按钮,调用通用 `POST /api/tasks/{task_id}/cancel`。 +7. Celery worker 逐 step 顺序执行传播,避免多个视频 tracker 并发抢占 GPU;每个 step 开始/完成都会写入 `processing_tasks.progress/result/message` 并发布 Redis `seg:progress`,Dashboard 可同步显示。每个 step 开始前,worker 会用 seed 来源 id、规范化权重 id、传播方向和 seed 签名查找旧传播标注:签名相同则跳过该 seed;签名不同则先删除对应方向的旧自动传播标注,再执行新的 video predictor 传播。 +8. 后端按项目帧序列截取片段,下载对应帧到临时目录,并写成 `000000.jpg` 这类纯数字文件名;这是 `SAM2VideoPredictor` 对视频帧排序的要求,和项目库中持久化的 `frame_%06d.jpg` 对象名无关。 +9. `model` 为任一 SAM 2.1 权重变体时,`sam2_engine` 使用对应 checkpoint/config 加载 `SAM2VideoPredictor.add_new_mask()` 注入 seed mask,再用 `propagate_in_video()` 传播;`model=sam2` 会在入队时规范化为 tiny,任务 payload/result 会保留规范化后的权重 id;单个 SAM2 video predictor 调用内部暂不提供逐帧流式进度。 +10. `model=sam3` 当前不支持;SAM 3 video tracker 代码保留但没有接入产品路径。 +11. 后端把传播返回的 normalized polygon 保存为后续帧 `Annotation`,跳过源帧,`mask_data.source` 记录权重传播来源,同时写入 `propagation_seed_key`、`propagation_seed_signature`、`propagation_direction`、`source_annotation_id` 和 `source_mask_id` 供后续幂等传播判断。 +12. 前端轮询到已创建区域后刷新 `GET /api/ai/annotations` 并回显新标注;任务结束后如果后端返回 0 个新区域,工作区会明确提示没有生成新的 mask,若是未改变 seed 被跳过则提示未改变 mask 已跳过。`annotationToMask()` 会保留传播来源 metadata,供时间轴视频处理进度条显示蓝色传播区段。 ### 手工绘制与历史栈 @@ -154,7 +163,7 @@ 5. mask path 只在 `move`、`edit_polygon`、`area_merge` 和 `area_remove` 工具下拦截点击;绘制和 AI prompt 工具点击已有 mask 时继续冒泡给 Stage。 6. 
新 mask 写入 `pathData`、像素 `segmentation`、`bbox`、`area` 和当前模板分类元数据。 7. `addMask()`、`setMasks()`、`updateMask()`、`clearMasks()` 会维护 `maskHistory/maskFuture`。 -8. 工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y 调用 `undoMasks()` / `redoMasks()`。 +8. 工具栏按钮、工作区顶栏按钮和 AI 页按钮调用 `undoMasks()` / `redoMasks()`;工作区由 `VideoWorkspace` 统一处理 `Ctrl/Cmd+Z`、`Ctrl/Cmd+Shift+Z` 和 `Ctrl/Cmd+Y`,并在输入框、下拉框和可编辑文本聚焦时跳过快捷键,避免影响帧范围输入。 ### Polygon 逐点编辑 @@ -203,6 +212,7 @@ 11. 如果 `selectedMaskIds` 中存在当前 store 的 mask,点击分类时会立即更新这些 mask 的 `templateId`、`classId`、`className`、`classZIndex`、`label` 和 `color`。 12. 同一次点击会把这些已选 mask 移动到前端 `masks` 数组末尾;`CanvasArea` 按数组顺序渲染,后渲染的 Path 显示在最上层,方便用户继续编辑刚换标签的区域。该显示置顶不改变模板 `zIndex` 或后端导出语义覆盖规则。 13. 已保存 mask 被重新分类后进入 `dirty` 且 `saved=false`,继续复用工作区归档保存的 PATCH 链路。 +14. 模板保存、删除和 JSON 导入失败使用 `TransientNotice` 非阻塞提示,默认数秒后自动消失。 ### 导出 @@ -222,7 +232,8 @@ - `cancelTask()` 使用 `POST /api/tasks/{taskId}/cancel`。 - `retryTask()` 使用 `POST /api/tasks/{taskId}/retry`。 - `predictMask()` 使用 `POST /api/ai/predict`,请求体为 `image_id`、`prompt_type`、`prompt_data`、`model`。 -- `propagateMasks()` 使用 `POST /api/ai/propagate`,请求体为 `project_id`、`frame_id`、`model`、`seed`、`direction`、`max_frames`。 +- `propagateMasks()` 使用 `POST /api/ai/propagate`,请求体为 `project_id`、`frame_id`、`model`、`seed`、`direction`、`max_frames`,作为单 seed 同步兼容接口保留。 +- `queuePropagationTask()` 使用 `POST /api/ai/propagate/task`,请求体为 `project_id`、`frame_id`、`model`、`steps`、`include_source`、`save_annotations`,返回 `ProcessingTask`。 - `saveAnnotation()` 使用 `POST /api/ai/annotate`。 - `importGtMask()` 使用 `POST /api/ai/import-gt-mask` multipart form-data。 - `getProjectAnnotations()` 使用 `GET /api/ai/annotations`。 @@ -235,7 +246,7 @@ - SAM 2.1 点提示和 auto fallback 只返回一个最高分候选,避免同一提示产生多个重叠候选 mask。 - SAM 3 前端入口、后端 registry 入口和状态展示均已禁用;`model=sam3` 会返回不支持。 - 后端 `/api/ai/predict` 支持可选 `options`:`crop_to_prompt` 会对 point/box/interactive prompt 做局部裁剪推理并回映射 polygon,`auto_filter_background` 会按 `min_score` 和负向点过滤结果。 -- 后端 `/api/ai/propagate` 当前支持所选 SAM 2.1 mask seed 
视频传播;当前前端默认向后传播 30 帧并保存结果标注。 +- 后端 `/api/ai/propagate/task` 当前支持所选 SAM 2.1 mask seed 视频传播后台任务;同步 `/api/ai/propagate` 仍保留为单 seed 兼容接口。 - 后端 `/api/ai/models/status` 返回 GPU 和四个 SAM 2.1 变体的真实运行状态。 - point prompt 支持旧数组形式和 `{ points, labels }` 对象形式。 diff --git a/doc/09-test-plan.md b/doc/09-test-plan.md index 35ac5ef..c76eba5 100644 --- a/doc/09-test-plan.md +++ b/doc/09-test-plan.md @@ -16,12 +16,12 @@ |------|----------|--------| | R1 登录与会话 | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | 成功登录、失败提示、后端 401 | | R2 项目管理 | `src/lib/api.test.ts`, `src/components/ProjectLibrary.test.tsx`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、项目卡片删除、DELETE 契约、后端 CRUD、删除级联、帧列表 | -| R3 媒体上传与拆帧 | `src/components/ProjectLibrary.test.tsx`, `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 视频导入不自动拆帧、显式生成帧 FPS 选择、扩展名校验、自动建项目、关联项目、创建异步任务、标准帧序列参数、帧时间戳/源帧号、任务序列元数据、worker 注册帧、取消任务、重试任务、取消后 worker 停止 | -| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧项目不自动解析并提示生成帧、回显已保存标注时保留本地未保存 draft mask、缩略图/range/自动传播帧浅蓝进度条区段标记、当前帧由进度条末端和缩略图高亮表达/左右方向键切帧、播放、按项目 FPS 显示当前/总时长 | -| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 工具切换、调整多边形工具、AI 跳转、矩形/圆/线/点/多边形手工 mask 绘制、点工具在已有 mask 上落点、多边形 Enter/首节点闭合、上下文提示提示 Enter/Esc/首节点闭合、polygon 顶点直接拖动/删除、边中点插点、双击边界按位置插点、整块 mask 删除、区域合并/去除、布尔选择主区域/扣除区域视觉区分和选择顺序提示、内含去除 hole 渲染、合并模式隐藏编辑手柄、工作区 SAM 提示点点击删除且不冒泡新增点、撤销/重做历史栈 | -| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam2_engine.py` | SAM 2.1 变体选择、点/框/interactive 契约、semantic 禁用、SAM 3 入口隐藏和后端拒绝、SAM 2.1 最高分候选去重、SAM 2.1 框选后正负点细化同一候选 mask、AI 页框选发送 box prompt、AI 页框选后加点发送 interactive prompt、AI 页提示工具上下文提示、AI 页重复执行替换旧候选、SAM 2.1 反向点启用背景过滤且空结果移除旧候选、AI 页不渲染工作区已有 mask、AI 页可在候选 
mask 上继续添加正/反点、AI 页可单点删除提示点并删除最近锚点、AI 页可删除选中候选且不删除工作区 mask、AI 页清空只移除本页候选、AI 页参数开关可读性文案且 options 字段不变、AI 页遮罩清晰度只改预览 opacity、AI 页生成 mask 自动选中并可通过分类树换标签、AI 页推送到工作区编辑保留选择、SAM 2.1 视频以当前参考帧全部 mask 和起止帧范围自动传播、传播来源 metadata 回显、空提示/空结果反馈、GPU/SAM2.1 状态、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry | +| R3 媒体上传与拆帧 | `src/components/ProjectLibrary.test.tsx`, `src/components/TransientNotice.test.tsx`, `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 视频导入不自动拆帧、显式生成帧 FPS 选择、项目卡片显示目标 parse_fps 而非原视频 FPS、扩展名校验、自动建项目、关联项目、创建异步任务、非阻塞自动消失操作提示、标准帧序列参数、帧时间戳/源帧号、任务序列元数据、worker 注册帧、取消任务、重试任务、取消后 worker 停止 | +| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧项目不自动解析并提示生成帧、工作区短状态自动消失、工作区/AI 画布底图默认居中且保留边距、回显已保存标注时保留本地未保存 draft mask、缩略图/range/视频处理进度条、视频处理进度条点击跳帧、人工/AI 标注帧红色竖线和标识点击跳帧、自动传播帧蓝色区段和标识点击跳帧、缩略图红/蓝边框、当前人工/AI 标注帧青色外框加红色内描边、普通状态不显示传播范围黄色选区、播放进度条和视频处理进度条选择传播范围、当前帧由播放进度条末端和缩略图青色高亮表达/左右方向键切帧、播放、按项目 FPS 显示当前/总时长 | +| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/store/useStore.test.ts` | 工具切换、工具栏紧凑垂直布局和高度不足时滚动、调整多边形工具、AI 跳转、矩形/圆/线/点/多边形手工 mask 绘制、点工具在已有 mask 上落点、多边形 Enter/首节点闭合、上下文提示提示 Enter/Esc/首节点闭合、polygon 顶点直接拖动/删除、边中点插点、双击边界按位置插点、整块 mask 删除、区域合并/去除、布尔选择主区域/扣除区域视觉区分和选择顺序提示、内含去除 hole 渲染、合并模式隐藏编辑手柄、工作区 SAM 提示点点击删除且不冒泡新增点、工作区顶栏撤销/重做按钮、撤销/重做快捷键和输入框快捷键跳过、撤销/重做历史栈 | +| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam2_engine.py` | SAM 2.1 变体选择、点/框/interactive 契约、semantic 禁用、SAM 3 入口隐藏和后端拒绝、SAM 2.1 最高分候选去重、SAM 2.1 框选后正负点细化同一候选 mask、AI 页框选发送 box prompt、AI 页框选后加点发送 interactive prompt、AI 页提示工具上下文提示、AI 页重复执行替换旧候选、SAM 2.1 反向点启用背景过滤且空结果移除旧候选、AI 页不渲染工作区已有 mask、AI 页可在候选 mask 上继续添加正/反点、AI 页可单点删除提示点并删除最近锚点、AI 
页可删除选中候选且不删除工作区 mask、AI 页清空只移除本页候选、AI 页参数开关可读性文案且 options 字段不变、AI 页遮罩清晰度只改预览 opacity、AI 页生成 mask 自动选中并可通过分类树换标签、AI 页推送到工作区编辑保留选择、SAM 2.1 视频以当前参考帧全部 mask 和起止帧范围自动传播、传播前独立选择 SAM 2.1 tiny/small/base+/large 权重、自动传播创建 Celery 任务、传播入队权重 id 规范化/拒绝不支持 id、传播 seed 来源 id/签名 metadata、未改变 seed 跳过、已改变 seed 先删旧自动传播标注再重传、传播中轮询任务进度、传播任务取消/重试、传播来源 metadata 回显、空提示/空结果反馈、GPU/SAM2.1 状态、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry | | R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、保存后用后端 saved annotation 替换已提交 draft、加载回显、更新 dirty 标注、清空删除已保存标注、GT mask 多类别导入、seed point 回显/归一化、项目不存在、帧不存在 | -| R8 模板库 | `src/components/TemplateRegistry.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_templates.py` | 前端模板加载/新建/编辑/删除、JSON 分类导入、mapping_rules 解包/打包、后端模板 CRUD | +| R8 模板库 | `src/components/TemplateRegistry.test.tsx`, `src/components/TransientNotice.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_templates.py` | 前端模板加载/新建/编辑/删除、JSON 分类导入、JSON/保存错误非阻塞提示、mapping_rules 解包/打包、后端模板 CRUD | | R9 本体检查面板 | `src/components/OntologyInspector.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts`, `backend/tests/test_ai.py` | 模板选择、分类展示、具体分类选择、Canvas 选区同步、点击分类给已选 mask 换标签并移动到前端渲染最上层、自定义分类 PATCH 后端模板、选中 mask 后端属性分析、重新提取拓扑锚点 | | R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py`, `backend/tests/test_tasks.py` | 后端概览接口、任务表驱动进度区、最近完成任务保留显示、任务取消/重试/详情、cancelled 事件、Redis 进度事件 payload/发布、地址推导、消息订阅、连接状态回调、队列更新、heartbeat | | R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO/PNG 按钮下载、导出前自动保存、导出路径、JSON 结构、mask ZIP、zIndex 语义融合 | @@ -34,13 +34,13 @@ |------|--------|----------|----------| | R1 | 登录页、默认开发凭证、token 写入、失败提示、后端 401 | 
`Login.test.tsx`, `test_auth.py` | 已覆盖 | | R2 | 项目列表/创建/选择、视频导入、DICOM 导入、后端项目和帧 CRUD | `ProjectLibrary.test.tsx`, `api.test.ts`, `test_projects.py` | 已覆盖 | -| R3 | 文件类型校验、自动/指定项目上传、视频导入与生成帧分离、显式 FPS 生成帧、视频/DICOM 拆帧任务、`parse_fps/max_frames/target_width`、标准帧序列 metadata、任务查询、取消、重试、worker 取消停止 | `ProjectLibrary.test.tsx`, `test_media.py`, `test_tasks.py` | 已覆盖 | -| R4 | 工作区加载帧、无帧项目不自动解析、后端标注回显保留本地未保存 draft mask、Canvas 底图、缩略图/range/自动传播帧浅蓝进度条区段标记、当前帧由进度条末端和缩略图高亮表达/左右方向键切帧、播放、按 FPS 显示时间 | `VideoWorkspace.test.tsx`, `FrameTimeline.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 | -| R5 | 工具切换、调整多边形入口、AI 跳转、矩形/圆/线/点/多边形绘制、已有 mask 上继续绘制、多边形和布尔工具上下文提示 | `ToolsPalette.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 | -| R5 | 顶点直接拖动编辑、边中点插点、双击边界按位置插点、顶点删除、整块删除、工作区 SAM 提示点删除优先级、撤销/重做、区域合并、区域去除、布尔选择主区域黄色实线/扣除区域红色虚线、布尔选择顺序提示、hole even-odd 渲染 | `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 | -| R6 | SAM 2.1 变体选择、点/框/interactive、semantic 禁用、SAM 3 入口隐藏和后端拒绝、SAM 2.1 最高分候选去重、AI 页框选/框选后加点、AI 页提示工具上下文提示、AI 页重复执行替换旧候选、AI 页不渲染工作区已有 mask、AI 页可在候选 mask 上继续添加正/反点、AI 页可删除提示点、AI 页可删除选中候选、AI 页清空只移除本页候选、AI 页遮罩清晰度只改预览 opacity、AI 页生成 mask 自动选中并可换标签、AI 页推送到工作区编辑保留选择、SAM 2.1 视频按参考帧全部 mask 和范围自动传播、GPU/模型状态、参数 options、polygons 转 mask | `api.test.ts`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx`, `VideoWorkspace.test.tsx`, `ModelStatusBadge.test.tsx`, `test_ai.py`, `test_sam2_engine.py` | 已覆盖 | +| R3 | 文件类型校验、自动/指定项目上传、视频导入与生成帧分离、显式 FPS 生成帧、项目卡片 FPS 徽标显示 `parse_fps`、视频/DICOM 拆帧任务、非阻塞自动消失操作提示、`parse_fps/max_frames/target_width`、标准帧序列 metadata、任务查询、取消、重试、worker 取消停止 | `ProjectLibrary.test.tsx`, `TransientNotice.test.tsx`, `api.test.ts`, `test_media.py`, `test_tasks.py` | 已覆盖 | +| R4 | 工作区加载帧、无帧项目不自动解析、工作区短状态自动消失、后端标注回显保留本地未保存 draft mask、Canvas/AI 底图居中适配且保留边距、缩略图/range/视频处理进度条、视频处理进度条点击跳帧、人工/AI 标注帧红色竖线和标识点击跳帧、自动传播帧蓝色区段和标识点击跳帧、缩略图红/蓝边框、当前人工/AI 标注帧青色外框加红色内描边、普通状态不显示传播范围黄色选区、播放进度条/视频处理进度条拖拽选择传播范围、Canvas/AI 画布拖拽平移回写 position state、当前帧由播放进度条末端和缩略图青色高亮表达/左右方向键切帧、播放、按 FPS 显示时间 | `VideoWorkspace.test.tsx`, 
`FrameTimeline.test.tsx`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx` | 已覆盖 | +| R5 | 工具切换、工具栏紧凑滚动布局、调整多边形入口、AI 跳转、矩形/圆/线/点/多边形绘制、已有 mask 上继续绘制、多边形和布尔工具上下文提示 | `ToolsPalette.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 | +| R5 | 顶点直接拖动编辑、边中点插点、双击边界按位置插点、顶点删除、整块删除、工作区 SAM 提示点删除优先级、工作区顶栏撤销/重做按钮、撤销/重做快捷键、区域合并、区域去除、布尔选择主区域黄色实线/扣除区域红色虚线、布尔选择顺序提示、hole even-odd 渲染 | `CanvasArea.test.tsx`, `VideoWorkspace.test.tsx`, `useStore.test.ts` | 已覆盖 | +| R6 | SAM 2.1 变体选择、点/框/interactive、semantic 禁用、SAM 3 入口隐藏和后端拒绝、SAM 2.1 最高分候选去重、AI 页框选/框选后加点、AI 页提示工具上下文提示、AI 页重复执行替换旧候选、AI 页不渲染工作区已有 mask、AI 页可在候选 mask 上继续添加正/反点、AI 页可删除提示点、AI 页可删除选中候选、AI 页清空只移除本页候选、AI 页遮罩清晰度只改预览 opacity、AI 页生成 mask 自动选中并可换标签、AI 页推送到工作区编辑保留选择、SAM 2.1 视频按参考帧全部 mask 和范围自动传播、传播前独立选择 SAM 2.1 tiny/small/base+/large 权重、自动传播 Celery 任务入队、传播入队权重 id 规范化/拒绝不支持 id、传播 seed 来源 id/签名 metadata、未改变 seed 跳过、已改变 seed 先删旧自动传播标注再重传、前端任务轮询进度、传播任务 runner 保存标注和结果权重 id、传播任务重试、传播空结果提示、GPU/模型状态、参数 options、polygons 转 mask | `api.test.ts`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx`, `VideoWorkspace.test.tsx`, `ModelStatusBadge.test.tsx`, `test_ai.py`, `test_tasks.py`, `test_sam2_engine.py` | 已覆盖 | | R7 | 保存、保存后替换已提交 draft、查询、更新、删除标注、工作区回显、清空已保存标注、GT mask 导入和 seed point 回写 | `VideoWorkspace.test.tsx`, `CanvasArea.test.tsx`, `api.test.ts`, `test_ai.py` | 已覆盖 | -| R8 | 模板加载、新建、编辑、删除、JSON 分类导入、mapping_rules 映射、后端 CRUD | `TemplateRegistry.test.tsx`, `api.test.ts`, `test_templates.py` | 已覆盖 | +| R8 | 模板加载、新建、编辑、删除、JSON 分类导入、JSON/保存错误非阻塞提示、mapping_rules 映射、后端 CRUD | `TemplateRegistry.test.tsx`, `TransientNotice.test.tsx`, `api.test.ts`, `test_templates.py` | 已覆盖 | | R9 | 模板选择、分类展示、分类选择、已选 mask 换标签并置顶显示、自定义分类写入后端模板、后端属性分析、占位状态 | `OntologyInspector.test.tsx`, `CanvasArea.test.tsx`, `useStore.test.ts`, `test_ai.py` | 已覆盖 | | R10 | Dashboard 概览、任务进度区、最近完成任务保留显示、活动日志、WebSocket progress/complete/error/status/cancelled、取消/重试/详情、连接状态回调、heartbeat | `Dashboard.test.tsx`, `websocket.test.ts`, `test_dashboard.py`, `test_main.py`, `test_progress_events.py`, 
`test_tasks.py` | 已覆盖 | | R11 | COCO/PNG ZIP 导出、导出前保存、路径和 JSON/ZIP 结构、zIndex 融合 | `VideoWorkspace.test.tsx`, `api.test.ts`, `test_export.py` | 已覆盖 | @@ -56,7 +56,10 @@ - R6:补充 `ModelStatusBadge.test.tsx` 中 SAM 3 不展示测试,避免禁用入口重新出现在前端。 - R6:补充后端 `selected_model=sam3` 拒绝测试和 semantic 禁用测试,避免后端继续暴露 SAM 3 产品能力。 - R6:补充 `POST /api/ai/propagate` 后端测试,验证 seed mask 传播结果会保存为后续帧标注并保留 class 元数据。 -- R6:补充 `propagateMasks()` API 封装和 `VideoWorkspace` 自动传播按钮测试,验证当前参考帧全部 mask 会按范围发送到后端视频传播接口。 +- R6:补充 `propagateMasks()` 同步兼容接口和 `queuePropagationTask()` 任务接口测试,验证当前参考帧全部 mask 会按范围组装为后台传播 steps。 +- R6:补充 `VideoWorkspace` 自动传播进度测试,验证传播任务运行中显示进度,后端返回 0 个新区域时给出明确反馈。 +- R4/R6:补充时间轴传播范围选择测试,验证点击“自动传播”后可在播放进度条或视频处理进度条上拖拽回填起止帧,再提交后台传播任务。 +- R6/R10:补充 `queuePropagationTask()`、`POST /api/ai/propagate/task`、传播 Celery runner 和传播任务重试测试,验证工作区自动传播不再依赖长 HTTP 请求,并验证传给 `SAM2VideoPredictor` 的临时帧文件名是纯数字序列。 - R6:`backend/tests/test_sam3_engine.py` 已标记跳过,仅作为历史保留实现的参考测试,不计入当前产品功能覆盖。 - R3:补充 `parseMedia()` 查询参数和后端拆帧任务 payload 测试,验证 `parse_fps`、`max_frames`、`target_width` 会进入任务。 - R3:补充 worker 注册标准帧序列测试,验证帧 `timestamp_ms`、`source_frame_number` 和 `result.frame_sequence` 元数据。 diff --git a/doc/10-installation.md b/doc/10-installation.md new file mode 100644 index 0000000..de213f0 --- /dev/null +++ b/doc/10-installation.md @@ -0,0 +1,490 @@ +# Installation / 部署安装指南 + +本文件记录当前仓库的真实安装和部署方式。它面向一台新的 Linux 机器,目标是跑起完整系统: + +- React 前端:默认 `http://localhost:3000` +- FastAPI 后端:默认 `http://localhost:8000` +- PostgreSQL:项目、帧、模板、标注、任务元数据 +- Redis:Celery broker/result backend 与进度 pub/sub +- MinIO:视频、DICOM、拆帧图片等对象存储 +- Celery worker:执行视频/DICOM 拆帧等后台任务 +- SAM 2.1:当前产品启用 tiny/small/base+/large;SAM 3 源码保留但产品入口禁用,正常部署不需要安装 SAM 3 + +--- + +## 1. 
前置条件 + +推荐环境: + +| 项 | 建议 | +|----|------| +| OS | Ubuntu 22.04 LTS 或相近 Linux | +| Python | 3.11 | +| Node.js | 22.x | +| 数据库 | PostgreSQL 14+ | +| 缓存/队列 | Redis 6+ | +| 对象存储 | MinIO | +| 视频处理 | FFmpeg | +| GPU | NVIDIA GPU + CUDA,用于 SAM 2.1 推理;无 GPU 时可 CPU 运行但会很慢 | + +安装系统依赖: + +```bash +sudo apt update +sudo apt install -y \ + postgresql postgresql-contrib \ + redis-server \ + ffmpeg \ + libpq-dev build-essential curl ca-certificates gnupg wget +``` + +安装 MinIO: + +```bash +cd /tmp +wget https://dl.min.io/server/minio/release/linux-amd64/minio +chmod +x minio +sudo mv minio /usr/local/bin/ +mkdir -p ~/minio_data +``` + +--- + +## 2. 获取代码 + +```bash +cd ~/Desktop +git clone <repo-url> Seg_Server +cd Seg_Server +``` + +如果已经有仓库,进入项目根目录即可: + +```bash +cd /home/wkmgc/Desktop/Seg_Server +``` + +后续命令默认在项目根目录执行,除非特别说明。 + +--- + +## 3. 配置 PostgreSQL + +默认后端配置来自 `backend/config.py`: + +```text +postgresql://seguser:segpass123@localhost:5432/segserver +``` + +创建数据库和用户: + +```bash +sudo systemctl start postgresql + +sudo -u postgres psql -c "CREATE DATABASE segserver;" +sudo -u postgres psql -c "CREATE USER seguser WITH PASSWORD 'segpass123';" +sudo -u postgres psql -c "GRANT ALL PRIVILEGES ON DATABASE segserver TO seguser;" +sudo -u postgres psql -d segserver -c "GRANT ALL ON SCHEMA public TO seguser;" +sudo -u postgres psql -c "ALTER DATABASE segserver OWNER TO seguser;" +``` + +验收: + +```bash +pg_isready +psql "postgresql://seguser:segpass123@localhost:5432/segserver" -c "select 1;" +``` + +--- + +## 4. 启动 Redis 和 MinIO + +Redis: + +```bash +sudo systemctl start redis-server +redis-cli ping +``` + +MinIO: + +```bash +nohup minio server ~/minio_data --console-address :9001 > /tmp/minio.log 2>&1 & +curl http://localhost:9000/minio/health/live +``` + +默认 MinIO 账号密码是: + +```text +minioadmin / minioadmin +``` + +后端启动时会检查并创建 bucket: + +```text +seg-media +``` + +--- + +## 5. 
安装后端 Python 环境 + +推荐使用 Conda: + +```bash +conda create -n seg_server python=3.11 -y +conda activate seg_server +``` + +安装 PyTorch。根据机器 CUDA 版本选择合适 wheel。示例: + +```bash +# CUDA 12.4 示例 +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 + +# 无 GPU / CPU 示例 +# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu +``` + +安装后端依赖: + +```bash +cd backend +pip install -r requirements.txt +cd .. +``` + +确认关键包: + +```bash +python - <<'PY' +import fastapi, sqlalchemy, redis, celery, minio, torch +print("torch:", torch.__version__, "cuda:", torch.cuda.is_available()) +PY +``` + +--- + +## 6. 配置后端环境变量 + +后端从 `backend/.env` 读取配置;该文件被 `.gitignore` 忽略,不要提交真实密码或本机路径。 + +创建 `backend/.env`: + +```bash +cat > backend/.env <<'EOF' +db_url=postgresql://seguser:segpass123@localhost:5432/segserver +redis_url=redis://localhost:6379/0 + +minio_endpoint=localhost:9000 +minio_access_key=minioadmin +minio_secret_key=minioadmin +minio_secure=false + +sam_default_model=sam2.1_hiera_tiny +sam_model_path=/home/wkmgc/Desktop/Seg_Server/models/sam2.1_hiera_tiny.pt +sam_model_config=configs/sam2.1/sam2.1_hiera_t.yaml + +app_env=development +cors_origins=["http://localhost:3000","http://127.0.0.1:3000"] +EOF +``` + +如果前端通过局域网 IP 访问,例如 `http://192.168.3.11:3000`,需要把该地址加入 `cors_origins`,同时前端也要配置 API 地址。 + +--- + +## 7. 准备 SAM 2.1 权重 + +当前产品入口只暴露 SAM 2.1 变体: + +- `sam2.1_hiera_tiny` +- `sam2.1_hiera_small` +- `sam2.1_hiera_base_plus` +- `sam2.1_hiera_large` + +下载脚本: + +```bash +cd backend +python download_sam2.py +cd .. +``` + +脚本默认下载到: + +```text +/home/wkmgc/Desktop/Seg_Server/models/ +``` + +推荐文件名: + +```text +models/sam2.1_hiera_tiny.pt +models/sam2.1_hiera_small.pt +models/sam2.1_hiera_base_plus.pt +models/sam2.1_hiera_large.pt +``` + +可以只部署 tiny;前端会显示四个选项,但只有本地存在 checkpoint 的模型会显示可用。 + +注意:SAM 3 相关脚本和源码是历史保留。当前前端入口隐藏 SAM 3,后端 registry 不暴露 `sam3`,正常部署不需要下载 SAM 3 权重,也不要把 Hugging Face token 写进项目文件。 + +--- + +## 8. 
安装前端依赖 + +```bash +npm install +``` + +如需指定前端访问的后端地址,在项目根目录创建 `.env`: + +```bash +cat > .env <<'EOF' +VITE_API_BASE_URL=http://localhost:8000 +VITE_WS_PROGRESS_URL=ws://localhost:8000/ws/progress +EOF +``` + +如果不设置,前端会按当前浏览器 hostname 推导: + +```text +http://<hostname>:8000 +ws://<hostname>:8000/ws/progress +``` + +--- + +## 9. 手动启动所有服务 + +开 4 个终端分别启动。 + +终端 A:FastAPI 后端 + +```bash +conda activate seg_server +cd /home/wkmgc/Desktop/Seg_Server/backend +uvicorn main:app --host 0.0.0.0 --port 8000 --reload +``` + +终端 B:Celery worker + +```bash +conda activate seg_server +cd /home/wkmgc/Desktop/Seg_Server/backend +celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 +``` + +终端 C:前端开发服务 + +```bash +cd /home/wkmgc/Desktop/Seg_Server +npm run dev +``` + +终端 D:确认基础设施 + +```bash +pg_isready +redis-cli ping +curl http://localhost:9000/minio/health/live +``` + +访问: + +| 服务 | 地址 | +|------|------| +| 前端 | `http://localhost:3000` | +| FastAPI Docs | `http://localhost:8000/docs` | +| Health | `http://localhost:8000/health` | +| MinIO Console | `http://localhost:9001` | + +默认开发登录: + +```text +admin / 123456 +``` + +--- + +## 10. 一键启动脚本 + +项目根目录有 `start_services.sh`: + +```bash +chmod +x start_services.sh +./start_services.sh +``` + +脚本会检查/启动: + +```text +PostgreSQL -> Redis -> MinIO -> FastAPI -> Celery worker -> 前端 +``` + +使用前必须检查脚本里的本机路径和 sudo 逻辑: + +- `PROJECT_DIR="/home/wkmgc/Desktop/Seg_Server"` +- `CONDA_ENV="seg_server"` +- MinIO 数据目录 `/home/wkmgc/minio_data` +- 脚本里包含本机 sudo 密码写法,迁移机器时应移除或改成安全的 systemd/service 管理方式 + +--- + +## 11. 生产构建方式 + +前端构建: + +```bash +npm run build +``` + +生产模式启动前端静态服务: + +```bash +NODE_ENV=production npm start +``` + +后端生产启动示例: + +```bash +cd backend +uvicorn main:app --host 0.0.0.0 --port 8000 +``` + +Celery worker 仍需要单独启动: + +```bash +cd backend +celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 +``` + +实际生产建议用 systemd、supervisor 或容器编排托管 FastAPI、Celery、前端静态服务、MinIO、Redis、PostgreSQL。 + +--- + +## 12. 
部署验收 Checklist + +基础服务: + +```bash +pg_isready +redis-cli ping +curl http://localhost:9000/minio/health/live +curl http://localhost:8000/health +``` + +后端模型状态: + +```bash +curl http://localhost:8000/api/ai/models/status +``` + +前端质量检查: + +```bash +npm run lint +npm run test:run +npm run build +``` + +后端测试: + +```bash +conda activate seg_server +python -m pytest backend/tests +``` + +手工业务验收: + +1. 打开 `http://localhost:3000`。 +2. 使用 `admin / 123456` 登录。 +3. 创建项目或上传视频。 +4. 在项目库点击“生成帧”,选择 FPS。 +5. Dashboard 中应看到任务进度;Celery 日志应显示拆帧任务。 +6. 进入分割工作区,能看到帧、时间轴和画布。 +7. 手工画一个多边形 mask,点击“结构化归档保存”。 +8. 刷新工作区后,已保存标注应回显。 +9. AI 智能分割中选择可用 SAM 2.1 模型,放置点或框,执行分割。 +10. 导出 JSON 或 PNG Mask ZIP。 + +--- + +## 13. 常见问题 + +### 前端打不开或请求后端失败 + +检查: + +```bash +curl http://localhost:8000/health +cat .env +``` + +如果通过局域网 IP 访问前端,确保: + +- `.env` 中 `VITE_API_BASE_URL` 是浏览器可访问的后端地址。 +- `backend/.env` 中 `cors_origins` 包含前端地址。 + +### Dashboard WebSocket 经常断开 + +检查: + +```bash +redis-cli ping +curl http://localhost:8000/health +``` + +同时确认前端 `VITE_WS_PROGRESS_URL` 指向真实可访问的: + +```text +ws://<backend-host>:8000/ws/progress +``` + +### 生成帧没有进度 + +检查 Celery worker 是否启动: + +```bash +ps aux | grep celery +tail -f /tmp/celery.log +``` + +检查 Redis: + +```bash +redis-cli ping +``` + +### MinIO 上传失败 + +检查: + +```bash +curl http://localhost:9000/minio/health/live +tail -f /tmp/minio.log +``` + +如果磁盘空间不足,MinIO 可能拒绝写入。清理 `~/minio_data`、旧日志、旧模型权重或迁移数据目录。 + +### SAM 2 模型不可用 + +检查: + +```bash +ls -lh models/ +curl http://localhost:8000/api/ai/models/status +``` + +常见原因: + +- checkpoint 文件不存在。 +- `backend/.env` 中 `sam_model_path` 指向旧文件名。 +- `sam2` Python 包未正确安装。 +- PyTorch/CUDA 不匹配。 + +### 不需要 SAM 3 + +当前版本不用 SAM 3。不要为了正常部署执行 `backend/setup_sam3_env.sh`,也不要在项目里保存 Hugging Face token。 + diff --git a/doc/README.md b/doc/README.md index f323266..7105344 100644 --- a/doc/README.md +++ b/doc/README.md @@ -20,6 +20,7 @@ | [07-current-requirements-freeze.md](./07-current-requirements-freeze.md) | 当前版本需求冻结,测试以此为准 | | 
[08-current-design-freeze.md](./08-current-design-freeze.md) | 当前版本设计冻结,记录模块、数据流和接口边界 | | [09-test-plan.md](./09-test-plan.md) | 需求到测试文件的覆盖矩阵和运行命令 | +| [10-installation.md](./10-installation.md) | 系统安装部署指南,覆盖 PostgreSQL、Redis、MinIO、后端、Celery、前端和 SAM 2.1 权重 | ## 状态标记 diff --git a/src/components/AISegmentation.test.tsx b/src/components/AISegmentation.test.tsx index 5a526da..6537c50 100644 --- a/src/components/AISegmentation.test.tsx +++ b/src/components/AISegmentation.test.tsx @@ -169,6 +169,26 @@ describe('AISegmentation', () => { })); }); + it('handles stage drag end for move-tool canvas panning', () => { + render(); + + expect(screen.getByTestId('konva-stage')).toHaveAttribute('data-has-drag-end', 'true'); + }); + + it('centers the active frame with a large default fit inside the AI canvas', async () => { + Object.defineProperty(HTMLElement.prototype, 'clientWidth', { configurable: true, get: () => 1000 }); + Object.defineProperty(HTMLElement.prototype, 'clientHeight', { configurable: true, get: () => 700 }); + + render(); + + await waitFor(() => { + const stage = screen.getByTestId('konva-stage'); + expect(Number(stage.getAttribute('data-scale-x'))).toBeCloseTo(1.34375, 4); + expect(Number(stage.getAttribute('data-x'))).toBeCloseTo(70, 0); + expect(Number(stage.getAttribute('data-y'))).toBeCloseTo(108, 0); + }); + }); + it('combines the AI page box prompt with later positive and negative refinement points', async () => { apiMock.predictMask.mockResolvedValueOnce({ masks: [] }); render(); diff --git a/src/components/AISegmentation.tsx b/src/components/AISegmentation.tsx index d39d52e..3c18d6e 100644 --- a/src/components/AISegmentation.tsx +++ b/src/components/AISegmentation.tsx @@ -1,4 +1,4 @@ -import React, { useState, useCallback, useEffect } from 'react'; +import React, { useState, useCallback, useEffect, useRef } from 'react'; import { Target, PlusCircle, MinusCircle, SquareDashed, Sparkles, SendToBack, Undo, Redo, Loader2, XCircle, Trash2 } from 
'lucide-react'; import { cn } from '../lib/utils'; import { Stage, Layer, Image as KonvaImage, Circle, Path, Group, Rect } from 'react-konva'; @@ -14,8 +14,10 @@ interface AISegmentationProps { type PromptPoint = { x: number; y: number; type: 'pos' | 'neg' }; type PromptBox = { x1: number; y1: number; x2: number; y2: number }; type ToolHint = { title: string; body: string }; +const DEFAULT_IMAGE_FIT_RATIO = 0.86; export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { + const canvasContainerRef = useRef(null); const storeActiveTool = useStore((state) => state.activeTool); const setActiveTool = useStore((state) => state.setActiveTool); const masks = useStore((state) => state.masks); @@ -42,8 +44,10 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { const [aiMaskIds, setAiMaskIds] = useState([]); // Canvas state + const [stageSize, setStageSize] = useState({ width: 800, height: 600 }); const [scale, setScale] = useState(1); const [position, setPosition] = useState({ x: 0, y: 0 }); + const lastAutoFitKeyRef = useRef(''); const [points, setPoints] = useState([]); const [promptBox, setPromptBox] = useState(null); const [boxStart, setBoxStart] = useState<{ x: number; y: number } | null>(null); @@ -59,6 +63,42 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { const modelCanInfer = selectedModelStatus?.available ?? 
true; const effectiveTool = storeActiveTool; + + useEffect(() => { + const handleResize = () => { + if (!canvasContainerRef.current) return; + setStageSize({ + width: canvasContainerRef.current.clientWidth, + height: canvasContainerRef.current.clientHeight, + }); + }; + + handleResize(); + window.addEventListener('resize', handleResize); + return () => window.removeEventListener('resize', handleResize); + }, []); + + useEffect(() => { + if (!currentFrame?.id || stageSize.width <= 0 || stageSize.height <= 0) return; + const imageWidth = currentFrame.width || image?.naturalWidth || image?.width || 0; + const imageHeight = currentFrame.height || image?.naturalHeight || image?.height || 0; + if (imageWidth <= 0 || imageHeight <= 0) return; + + const fitKey = `${currentFrame.id}:${stageSize.width}x${stageSize.height}:${imageWidth}x${imageHeight}`; + if (lastAutoFitKeyRef.current === fitKey) return; + lastAutoFitKeyRef.current = fitKey; + + const nextScale = Math.max( + 0.05, + Math.min(stageSize.width / imageWidth, stageSize.height / imageHeight) * DEFAULT_IMAGE_FIT_RATIO, + ); + setScale(nextScale); + setPosition({ + x: (stageSize.width - imageWidth * nextScale) / 2, + y: (stageSize.height - imageHeight * nextScale) / 2, + }); + }, [currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, stageSize.height, stageSize.width]); + const toolHint = React.useMemo(() => { if (!currentFrame) return null; if (effectiveTool === 'point_pos') { @@ -146,6 +186,14 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { }); }; + const handleStageDragEnd = (e: any) => { + const stage = e.target; + setPosition({ + x: stage.x(), + y: stage.y(), + }); + }; + const handleMouseMove = (e: any) => { const stage = e.target.getStage(); if (!stage) return; @@ -557,7 +605,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
-
+
{!currentFrame && (
请先在项目库选择项目并生成帧 @@ -570,8 +618,8 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
)} {/* Background Image */} diff --git a/src/components/CanvasArea.test.tsx b/src/components/CanvasArea.test.tsx index e2a5190..70792f4 100644 --- a/src/components/CanvasArea.test.tsx +++ b/src/components/CanvasArea.test.tsx @@ -292,6 +292,26 @@ describe('CanvasArea', () => { expect(screen.getByText('遮罩数: 1')).toBeInTheDocument(); }); + it('handles stage drag end when the move tool pans the canvas', () => { + render(); + + expect(screen.getByTestId('konva-stage')).toHaveAttribute('data-has-drag-end', 'true'); + }); + + it('centers the frame image with a large default fit that keeps margins', async () => { + Object.defineProperty(HTMLElement.prototype, 'clientWidth', { configurable: true, get: () => 1000 }); + Object.defineProperty(HTMLElement.prototype, 'clientHeight', { configurable: true, get: () => 700 }); + + render(); + + await waitFor(() => { + const stage = screen.getByTestId('konva-stage'); + expect(Number(stage.getAttribute('data-scale-x'))).toBeCloseTo(1.34375, 4); + expect(Number(stage.getAttribute('data-x'))).toBeCloseTo(70, 0); + expect(Number(stage.getAttribute('data-y'))).toBeCloseTo(108, 0); + }); + }); + it('publishes the selected mask ids for the ontology panel', async () => { useStore.setState({ masks: [ @@ -663,7 +683,10 @@ describe('CanvasArea', () => { expect(selectedPaths[0]).toHaveAttribute('data-stroke', '#facc15'); expect(selectedPaths[0]).toHaveAttribute('data-dash', ''); expect(selectedPaths[1]).toHaveAttribute('data-stroke', '#fb7185'); - expect(selectedPaths[1]).toHaveAttribute('data-dash', '6,4'); + const scale = Number(screen.getByTestId('konva-stage').getAttribute('data-scale-x')) || 1; + const dash = selectedPaths[1].getAttribute('data-dash')?.split(',').map(Number); + expect(dash?.[0]).toBeCloseTo(6 / scale, 4); + expect(dash?.[1]).toBeCloseTo(4 / scale, 4); fireEvent.click(screen.getByRole('button', { name: '从主区域去除' })); expect(useStore.getState().masks).toHaveLength(2); diff --git a/src/components/CanvasArea.tsx 
b/src/components/CanvasArea.tsx index a319bb9..b9a42d4 100644 --- a/src/components/CanvasArea.tsx +++ b/src/components/CanvasArea.tsx @@ -24,6 +24,7 @@ const EDIT_POLYGON_TOOL = 'edit_polygon'; const POINT_TOOL = 'create_point'; const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']); const POLYGON_CLOSE_RADIUS = 8; +const DEFAULT_IMAGE_FIT_RATIO = 0.86; function clamp(value: number, min: number, max: number): number { return Math.min(Math.max(value, min), max); @@ -245,6 +246,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota const previousFrameIdRef = useRef(frame?.id); const [isInferencing, setIsInferencing] = useState(false); const [inferenceMessage, setInferenceMessage] = useState(''); + const lastAutoFitKeyRef = useRef(''); const masks = useStore((state) => state.masks); const addMask = useStore((state) => state.addMask); @@ -256,8 +258,6 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota const aiModel = useStore((state) => state.aiModel); const activeTemplateId = useStore((state) => state.activeTemplateId); const activeClass = useStore((state) => state.activeClass); - const undoMasks = useStore((state) => state.undoMasks); - const redoMasks = useStore((state) => state.redoMasks); const effectiveTool = activeTool || storeActiveTool; @@ -374,6 +374,27 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota return () => window.removeEventListener('resize', handleResize); }, []); + useEffect(() => { + if (!frame?.id || stageSize.width <= 0 || stageSize.height <= 0) return; + const imageWidth = frame.width || image?.naturalWidth || image?.width || 0; + const imageHeight = frame.height || image?.naturalHeight || image?.height || 0; + if (imageWidth <= 0 || imageHeight <= 0) return; + + const fitKey = `${frame.id}:${stageSize.width}x${stageSize.height}:${imageWidth}x${imageHeight}`; + if (lastAutoFitKeyRef.current === fitKey) return; + lastAutoFitKeyRef.current = 
fitKey; + + const nextScale = Math.max( + 0.05, + Math.min(stageSize.width / imageWidth, stageSize.height / imageHeight) * DEFAULT_IMAGE_FIT_RATIO, + ); + setScale(nextScale); + setPosition({ + x: (stageSize.width - imageWidth * nextScale) / 2, + y: (stageSize.height - imageHeight * nextScale) / 2, + }); + }, [frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, stageSize.height, stageSize.width]); + useEffect(() => { setManualStart(null); setManualCurrent(null); @@ -458,6 +479,14 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota }); }; + const handleStageDragEnd = (e: any) => { + const stage = e.target; + setPosition({ + x: stage.x(), + y: stage.y(), + }); + }; + const stagePoint = (e: any): CanvasPoint | null => { const stage = e.target.getStage(); const relPos = stage?.getRelativePointerPosition(); @@ -839,18 +868,6 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota useEffect(() => { const handleKeyDown = (event: KeyboardEvent) => { - const key = event.key.toLowerCase(); - if ((event.metaKey || event.ctrlKey) && key === 'z') { - event.preventDefault(); - if (event.shiftKey) redoMasks(); - else undoMasks(); - return; - } - if ((event.metaKey || event.ctrlKey) && key === 'y') { - event.preventDefault(); - redoMasks(); - return; - } if ((event.key === 'Delete' || event.key === 'Backspace') && selectedMask && selectedVertexIndex !== null) { const currentPoints = segmentationToPoints(selectedMask.segmentation, selectedPolygonIndex); if (currentPoints.length > 3) { @@ -880,7 +897,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota window.addEventListener('keydown', handleKeyDown); return () => window.removeEventListener('keydown', handleKeyDown); - }, [deleteMasksById, effectiveTool, finishPolygon, isPolygonEditTool, polygonPoints, redoMasks, selectedMask, selectedMaskIds, selectedPolygonIndex, 
selectedVertexIndex, undoMasks, updatePolygonMask]); + }, [deleteMasksById, effectiveTool, finishPolygon, isPolygonEditTool, polygonPoints, selectedMask, selectedMaskIds, selectedPolygonIndex, selectedVertexIndex, updatePolygonMask]); const boxRect = React.useMemo(() => { if (!boxStart || !boxCurrent) return null; @@ -1080,6 +1097,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota x={position.x} y={position.y} draggable={effectiveTool === 'move'} + onDragEnd={handleStageDragEnd} onClick={handleStageClick} > diff --git a/src/components/Dashboard.tsx b/src/components/Dashboard.tsx index cb6c56d..9bf6a1d 100644 --- a/src/components/Dashboard.tsx +++ b/src/components/Dashboard.tsx @@ -42,7 +42,7 @@ export function Dashboard() { status: task.message || task.status, raw_status: task.status, error: task.error, - frame_count: Number(task.result?.frames_extracted || 0), + frame_count: Number(task.result?.frames_extracted || task.result?.processed_frame_count || 0), updated_at: task.updated_at, }); diff --git a/src/components/FrameTimeline.test.tsx b/src/components/FrameTimeline.test.tsx index 0f3ef03..201ae42 100644 --- a/src/components/FrameTimeline.test.tsx +++ b/src/components/FrameTimeline.test.tsx @@ -51,7 +51,7 @@ describe('FrameTimeline', () => { expect(screen.getAllByText('00:00.20').length).toBeGreaterThan(0); }); - it('marks propagated frames as light-blue progress bar segments', () => { + it('renders a processing progress bar with red annotation markers and blue propagation segments', () => { useStore.setState({ currentFrameIndex: 1, frames: [ @@ -76,13 +76,162 @@ describe('FrameTimeline', () => { render(); - expect(screen.getByText('自动传播 1 帧')).toBeInTheDocument(); + expect(screen.getByLabelText('视频处理进度条')).toBeInTheDocument(); + expect(screen.getByText('人工/AI 1 帧 · 自动传播 1 帧')).toBeInTheDocument(); expect(screen.queryByTestId('current-frame-line')).not.toBeInTheDocument(); 
expect(screen.getAllByTestId('propagated-frame-segment')).toHaveLength(1); - expect(screen.getByTestId('propagated-frame-segment').className).toContain('bg-sky-200'); + expect(screen.getByTestId('propagated-frame-segment').className).toContain('bg-blue-500'); + expect(screen.getAllByTestId('annotated-frame-marker')).toHaveLength(1); + expect(screen.getByTestId('annotated-frame-marker').className).toContain('bg-red-500'); expect(screen.queryByLabelText('跳转到已编辑帧 3')).not.toBeInTheDocument(); }); + it('jumps from the processing progress bar and frame status markers', () => { + useStore.setState({ + frames: [ + { id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 }, + { id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 }, + { id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 }, + { id: 'f4', projectId: 'p1', index: 3, url: '/4.jpg', width: 640, height: 360 }, + { id: 'f5', projectId: 'p1', index: 4, url: '/5.jpg', width: 640, height: 360 }, + ], + masks: [ + { id: 'm1', frameId: 'f2', pathData: 'M 0 0 Z', label: 'Draft', color: '#ef4444' }, + { + id: 'm2', + frameId: 'f4', + pathData: 'M 0 0 Z', + label: 'Tracked', + color: '#3b82f6', + metadata: { source: 'sam2.1_hiera_tiny_propagation' }, + }, + ], + }); + + render(); + + const processingBar = screen.getByLabelText('视频处理进度条'); + vi.spyOn(processingBar, 'getBoundingClientRect').mockReturnValue({ + left: 0, + right: 100, + top: 0, + bottom: 10, + width: 100, + height: 10, + x: 0, + y: 0, + toJSON: () => ({}), + }); + fireEvent.pointerDown(processingBar, { clientX: 50, pointerId: 1 }); + expect(useStore.getState().currentFrameIndex).toBe(2); + + fireEvent.click(screen.getByRole('button', { name: '跳转到人工/AI 标注帧 2' })); + expect(useStore.getState().currentFrameIndex).toBe(1); + + fireEvent.click(screen.getByRole('button', { name: '跳转到自动传播帧 4' })); + expect(useStore.getState().currentFrameIndex).toBe(3); + }); + + it('hides the propagation range 
overlay until range selection is active', () => { + useStore.setState({ + frames: [ + { id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 }, + { id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 }, + { id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 }, + ], + }); + + render(); + + expect(screen.queryByTestId('propagation-range-overlay')).not.toBeInTheDocument(); + }); + + it('uses red thumbnail borders for manual or AI frames and blue for propagated frames', () => { + useStore.setState({ + frames: [ + { id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 }, + { id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 }, + { id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 }, + ], + masks: [ + { id: 'm1', frameId: 'f2', pathData: 'M 0 0 Z', label: 'Draft', color: '#ef4444' }, + { + id: 'm2', + frameId: 'f3', + pathData: 'M 0 0 Z', + label: 'Tracked', + color: '#3b82f6', + metadata: { propagated_from_frame_id: 'f1' }, + }, + ], + }); + + render(); + + expect(screen.getByAltText('frame-0').closest('div')?.className).toContain('border-cyan-500'); + expect(screen.getByAltText('frame-1').closest('div')?.className).toContain('border-red-500'); + expect(screen.getByAltText('frame-2').closest('div')?.className).toContain('border-blue-500'); + }); + + it('keeps the current frame blue border while showing an inner red ring for annotated frames', () => { + useStore.setState({ + currentFrameIndex: 1, + frames: [ + { id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 }, + { id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 }, + ], + masks: [ + { id: 'm1', frameId: 'f2', pathData: 'M 0 0 Z', label: 'Draft', color: '#ef4444' }, + ], + }); + + render(); + + const currentAnnotatedTile = screen.getByAltText('frame-1').closest('div'); + 
expect(currentAnnotatedTile?.className).toContain('border-cyan-500'); + expect(currentAnnotatedTile?.className).toContain('inset_0_0_0_2px_rgba(239,68,68,0.95)'); + }); + + it('selects a propagation range from the playback and processing progress bars', () => { + const onPropagationRangeChange = vi.fn(); + useStore.setState({ + frames: [ + { id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 }, + { id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 }, + { id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 }, + { id: 'f4', projectId: 'p1', index: 3, url: '/4.jpg', width: 640, height: 360 }, + { id: 'f5', projectId: 'p1', index: 4, url: '/5.jpg', width: 640, height: 360 }, + ], + }); + + render( + , + ); + + const playbackBar = screen.getByTestId('playback-progress-bar'); + vi.spyOn(playbackBar, 'getBoundingClientRect').mockReturnValue({ + left: 0, + right: 100, + top: 0, + bottom: 10, + width: 100, + height: 10, + x: 0, + y: 0, + toJSON: () => ({}), + }); + fireEvent.pointerDown(playbackBar, { clientX: 25, pointerId: 1 }); + fireEvent.pointerMove(playbackBar, { clientX: 75, pointerId: 1 }); + fireEvent.pointerUp(playbackBar, { clientX: 75, pointerId: 1 }); + + expect(onPropagationRangeChange).toHaveBeenLastCalledWith(2, 4); + expect(screen.getAllByTestId('propagation-range-overlay')).toHaveLength(2); + }); + it('changes frames with left and right arrow keys without leaving bounds', () => { useStore.setState({ currentFrameIndex: 1, diff --git a/src/components/FrameTimeline.tsx b/src/components/FrameTimeline.tsx index a097820..ef4b4be 100644 --- a/src/components/FrameTimeline.tsx +++ b/src/components/FrameTimeline.tsx @@ -3,13 +3,29 @@ import { Play, Pause } from 'lucide-react'; import { cn } from '../lib/utils'; import { useStore } from '../store/useStore'; -export function FrameTimeline() { +interface FrameTimelineProps { + propagationRange?: { + startFrame: number; + endFrame: number; + }; + 
propagationRangeSelectionActive?: boolean; + propagationRangeDisabled?: boolean; + onPropagationRangeChange?: (startFrame: number, endFrame: number) => void; +} + +export function FrameTimeline({ + propagationRange, + propagationRangeSelectionActive = false, + propagationRangeDisabled = false, + onPropagationRangeChange, +}: FrameTimelineProps = {}) { const frames = useStore((state) => state.frames); const currentProject = useStore((state) => state.currentProject); const currentFrameIndex = useStore((state) => state.currentFrameIndex); const masks = useStore((state) => state.masks); const setCurrentFrame = useStore((state) => state.setCurrentFrame); const [isPlaying, setIsPlaying] = useState(false); + const [rangeDragAnchorFrame, setRangeDragAnchorFrame] = useState(null); const totalFrames = frames.length; const currentFrame = totalFrames > 0 ? currentFrameIndex + 1 : 0; @@ -23,21 +39,42 @@ export function FrameTimeline() { }, [currentProject?.original_fps, currentProject?.parse_fps]); const currentSeconds = totalFrames > 0 ? currentFrameIndex / timeBaseFps : 0; const totalSeconds = totalFrames > 0 ? Math.max(totalFrames - 1, 0) / timeBaseFps : 0; + const isPropagatedMask = (mask: (typeof masks)[number]) => { + const source = typeof mask.metadata?.source === 'string' ? mask.metadata.source : ''; + return source.includes('_propagation') || mask.metadata?.propagated_from_frame_id !== undefined; + }; const propagatedFrameMarkers = useMemo(() => { const frameIds = new Set(frames.map((frame) => frame.id)); const propagatedIds = new Set( masks .filter((mask) => frameIds.has(mask.frameId)) - .filter((mask) => { - const source = typeof mask.metadata?.source === 'string' ? 
mask.metadata.source : ''; - return source.includes('_propagation') || mask.metadata?.propagated_from_frame_id !== undefined; - }) + .filter(isPropagatedMask) .map((mask) => mask.frameId), ); return frames .map((frame, index) => ({ frame, index })) .filter(({ frame }) => propagatedIds.has(frame.id)); }, [frames, masks]); + const propagatedFrameIds = useMemo( + () => new Set(propagatedFrameMarkers.map(({ frame }) => frame.id)), + [propagatedFrameMarkers], + ); + const annotatedFrameMarkers = useMemo(() => { + const frameIds = new Set(frames.map((frame) => frame.id)); + const annotatedIds = new Set( + masks + .filter((mask) => frameIds.has(mask.frameId)) + .filter((mask) => !isPropagatedMask(mask)) + .map((mask) => mask.frameId), + ); + return frames + .map((frame, index) => ({ frame, index })) + .filter(({ frame }) => annotatedIds.has(frame.id)); + }, [frames, masks]); + const annotatedFrameIds = useMemo( + () => new Set(annotatedFrameMarkers.map(({ frame }) => frame.id)), + [annotatedFrameMarkers], + ); const formatTime = (seconds: number) => { const safeSeconds = Math.max(0, seconds); @@ -47,6 +84,80 @@ export function FrameTimeline() { return `${minutes.toString().padStart(2, '0')}:${wholeSeconds.toString().padStart(2, '0')}.${centiseconds.toString().padStart(2, '0')}`; }; + const clampFrame = (frame: number) => Math.min(Math.max(frame, 1), Math.max(totalFrames, 1)); + const normalizeRange = (startFrame: number, endFrame: number) => ({ + startFrame: Math.min(clampFrame(startFrame), clampFrame(endFrame)), + endFrame: Math.max(clampFrame(startFrame), clampFrame(endFrame)), + }); + const selectedRange = propagationRange + ? normalizeRange(propagationRange.startFrame, propagationRange.endFrame) + : null; + const visibleSelectedRange = propagationRangeSelectionActive ? selectedRange : null; + const rangeLeft = visibleSelectedRange && totalFrames > 0 ? 
((visibleSelectedRange.startFrame - 1) / totalFrames) * 100 : 0; + const rangeWidth = visibleSelectedRange && totalFrames > 0 + ? ((visibleSelectedRange.endFrame - visibleSelectedRange.startFrame + 1) / totalFrames) * 100 + : 0; + + const frameFromPointerEvent = (event: React.PointerEvent) => { + const rect = event.currentTarget.getBoundingClientRect(); + const ratio = rect.width > 0 ? (event.clientX - rect.left) / rect.width : 0; + return clampFrame(Math.round(Math.min(Math.max(ratio, 0), 1) * Math.max(totalFrames - 1, 0)) + 1); + }; + + const jumpToFrame = (frame: number) => { + if (totalFrames === 0) return; + setIsPlaying(false); + setCurrentFrame(clampFrame(frame) - 1); + }; + + const updatePropagationRangeFromPointer = ( + event: React.PointerEvent, + anchorFrame = rangeDragAnchorFrame, + ) => { + if (!propagationRangeSelectionActive || propagationRangeDisabled || totalFrames === 0 || !onPropagationRangeChange) return; + const frame = frameFromPointerEvent(event); + const startFrame = anchorFrame ?? 
frame; + const nextRange = normalizeRange(startFrame, frame); + onPropagationRangeChange(nextRange.startFrame, nextRange.endFrame); + }; + + const handleRangePointerDown = (event: React.PointerEvent) => { + if (!propagationRangeSelectionActive || propagationRangeDisabled || totalFrames === 0 || !onPropagationRangeChange) return; + event.preventDefault(); + setIsPlaying(false); + const frame = frameFromPointerEvent(event); + setRangeDragAnchorFrame(frame); + event.currentTarget.setPointerCapture?.(event.pointerId); + onPropagationRangeChange(frame, frame); + }; + + const handleProcessingBarPointerDown = (event: React.PointerEvent) => { + if (propagationRangeSelectionActive) { + handleRangePointerDown(event); + return; + } + if (totalFrames === 0) return; + event.preventDefault(); + jumpToFrame(frameFromPointerEvent(event)); + }; + + const handleFrameMarkerClick = (event: React.MouseEvent, frame: number) => { + event.stopPropagation(); + if (propagationRangeSelectionActive) return; + jumpToFrame(frame); + }; + + const handleRangePointerMove = (event: React.PointerEvent) => { + if (rangeDragAnchorFrame === null) return; + updatePropagationRangeFromPointer(event, rangeDragAnchorFrame); + }; + + const handleRangePointerUp = (event: React.PointerEvent) => { + if (rangeDragAnchorFrame === null) return; + updatePropagationRangeFromPointer(event, rangeDragAnchorFrame); + setRangeDragAnchorFrame(null); + }; + useEffect(() => { if (!isPlaying || totalFrames <= 1) return; @@ -99,41 +210,48 @@ export function FrameTimeline() { : []; return ( -
-
+
+
{formatTime(currentSeconds)}
{formatTime(totalSeconds)}
- setCurrentFrame(parseInt(e.target.value) - 1)} - className="w-full absolute inset-0 opacity-0 cursor-ew-resize z-20" + className={cn( + "w-full absolute left-0 right-0 top-0 h-7 opacity-0 cursor-ew-resize z-20", + propagationRangeSelectionActive && "pointer-events-none", + )} disabled={totalFrames === 0} /> -
+
setRangeDragAnchorFrame(null)} + > + {visibleSelectedRange && ( +
+ )}
0 ? (currentFrame / totalFrames) * 100 : 0}%` }} /> - {propagatedFrameMarkers.map(({ frame, index }) => { - const left = totalFrames > 0 ? (index / totalFrames) * 100 : 0; - const width = totalFrames > 0 ? 100 / totalFrames : 0; - return ( -
- ); - })}
0 ? (currentFrame / totalFrames) * 100 : 0}%` }} @@ -141,8 +259,66 @@ export function FrameTimeline() { {formatTime(currentSeconds)}
+
setRangeDragAnchorFrame(null)} + > + {visibleSelectedRange && ( +
+ )} + {propagatedFrameMarkers.map(({ frame, index }) => { + const left = totalFrames > 0 ? (index / totalFrames) * 100 : 0; + const width = totalFrames > 0 ? 100 / totalFrames : 0; + return ( +
- 自动传播 {propagatedFrameMarkers.length} 帧 + 人工/AI {annotatedFrameMarkers.length} 帧 · 自动传播 {propagatedFrameMarkers.length} 帧
@@ -176,13 +352,33 @@ export function FrameTimeline() { } const frame = frames[idx]; const isCurrent = idx === currentFrameIndex; + const isPropagatedFrame = propagatedFrameIds.has(frame.id); + const isAnnotatedFrame = annotatedFrameIds.has(frame.id); return (
setCurrentFrame(idx)} + title={ + isPropagatedFrame + ? `自动传播帧 ${idx + 1}` + : isAnnotatedFrame + ? `人工/AI 标注帧 ${idx + 1}` + : `视频帧 ${idx + 1}` + } className={cn( "relative shrink-0 rounded-sm transition-all cursor-pointer flex items-center justify-center overflow-hidden group mx-0.5", - isCurrent ? "w-28 h-16 border-2 border-cyan-500 bg-gray-700 shadow-[0_0_15px_rgba(6,182,212,0.3)] z-10" : "w-16 h-12 border border-white/5 bg-gray-800/50 opacity-40 hover:opacity-100" + isCurrent + ? cn( + "w-28 h-16 border-2 border-cyan-500 bg-gray-700 z-10", + isAnnotatedFrame + ? "shadow-[inset_0_0_0_2px_rgba(239,68,68,0.95),0_0_15px_rgba(6,182,212,0.3)]" + : "shadow-[0_0_15px_rgba(6,182,212,0.3)]", + ) + : isPropagatedFrame + ? "w-16 h-12 border border-blue-500 bg-blue-950/30 opacity-80 shadow-[0_0_10px_rgba(59,130,246,0.22)] hover:opacity-100" + : isAnnotatedFrame + ? "w-16 h-12 border border-red-500 bg-red-950/30 opacity-85 shadow-[0_0_10px_rgba(239,68,68,0.22)] hover:opacity-100" + : "w-16 h-12 border border-white/5 bg-gray-800/50 opacity-40 hover:opacity-100" )} > {frame.url ? 
( diff --git a/src/components/ProjectLibrary.test.tsx b/src/components/ProjectLibrary.test.tsx index 4e44d7f..8394506 100644 --- a/src/components/ProjectLibrary.test.tsx +++ b/src/components/ProjectLibrary.test.tsx @@ -42,6 +42,27 @@ describe('ProjectLibrary', () => { expect(onProjectSelect).toHaveBeenCalled(); }); + it('shows the generated frame sequence FPS on project cards instead of source FPS', async () => { + apiMock.getProjects.mockResolvedValueOnce([ + { + id: 'p-fps', + name: 'Frame Rate Demo', + status: 'ready', + frames: 120, + fps: '12FPS', + parse_fps: 12, + original_fps: 29.97, + video_path: 'uploads/demo.mp4', + }, + ]); + + render(); + + expect(await screen.findByText('12FPS')).toBeInTheDocument(); + expect(screen.getByText('原 30.0fps')).toBeInTheDocument(); + expect(screen.queryByText('30FPS')).not.toBeInTheDocument(); + }); + it('creates a new project from the modal', async () => { apiMock.createProject.mockResolvedValueOnce({ id: 'p2', name: 'New Project', status: 'pending' }); @@ -74,6 +95,7 @@ describe('ProjectLibrary', () => { }))); expect(apiMock.uploadMedia).toHaveBeenCalledWith(file, 'p3'); expect(apiMock.parseMedia).not.toHaveBeenCalled(); + expect(await screen.findByRole('status')).toHaveTextContent('视频导入成功'); }); it('generates frames from an imported video with the selected FPS', async () => { @@ -89,6 +111,8 @@ describe('ProjectLibrary', () => { fireEvent.click(screen.getByRole('button', { name: '开始生成帧' })); await waitFor(() => expect(apiMock.parseMedia).toHaveBeenCalledWith('p4', { parseFps: 12 })); + expect(await screen.findByRole('status')).toHaveTextContent('生成帧任务已入队 #22'); + expect(await screen.findByText('12FPS')).toBeInTheDocument(); }); it('deletes a project from the project card without entering the workspace', async () => { @@ -129,5 +153,6 @@ describe('ProjectLibrary', () => { await waitFor(() => expect(apiMock.uploadDicomBatch).toHaveBeenCalledWith([dcm])); expect(apiMock.parseMedia).toHaveBeenCalledWith('77'); + 
expect(await screen.findByRole('status')).toHaveTextContent('DICOM 上传成功: 1 个文件'); }); }); diff --git a/src/components/ProjectLibrary.tsx b/src/components/ProjectLibrary.tsx index cb240bc..08a7dfc 100644 --- a/src/components/ProjectLibrary.tsx +++ b/src/components/ProjectLibrary.tsx @@ -4,6 +4,7 @@ import { cn } from '../lib/utils'; import { useStore } from '../store/useStore'; import { getProjects, createProject, uploadMedia, parseMedia, uploadDicomBatch, deleteProject } from '../lib/api'; import type { Project } from '../store/useStore'; +import { TransientNotice, type NoticeState, type NoticeTone } from './TransientNotice'; interface ProjectLibraryProps { onProjectSelect: () => void; @@ -31,9 +32,24 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) { const [frameParseFps, setFrameParseFps] = useState(30); const [isGeneratingFrames, setIsGeneratingFrames] = useState(false); const [deletingProjectId, setDeletingProjectId] = useState(null); + const [notice, setNotice] = useState(null); const videoInputRef = useRef(null); const dicomInputRef = useRef(null); + const showNotice = (message: string, tone: NoticeTone = 'info') => { + setNotice({ id: Date.now(), message, tone }); + }; + + const frameSequenceLabel = (project: Project) => { + if (project.source_type === 'dicom') return 'DICOM'; + if (project.video_path && (project.frames ?? 0) === 0 && project.status !== 'parsing') return '待生成帧'; + if (project.parse_fps && project.parse_fps > 0) { + const rounded = Math.round(project.parse_fps * 10) / 10; + return `${Number.isInteger(rounded) ? 
rounded.toFixed(0) : rounded.toFixed(1)}FPS`; + } + return project.fps || '30FPS'; + }; + useEffect(() => { setIsLoading(true); getProjects() @@ -81,7 +97,7 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) { } } catch (err) { console.error('Delete project failed:', err); - alert('删除项目失败,请检查后端服务'); + showNotice('删除项目失败,请检查后端服务', 'error'); } finally { setDeletingProjectId(null); } @@ -102,12 +118,12 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) { description: `导入于 ${new Date().toLocaleString()}`, }); const result = await uploadMedia(pendingFile, String(newProject.id)); - alert(`视频导入成功: ${pendingFile.name}\n已保存至: ${result.url}\n需要生成帧时,请在项目卡片点击“生成帧”。`); + showNotice(`视频导入成功: ${pendingFile.name}\n已保存至: ${result.url}\n需要生成帧时,请在项目卡片点击“生成帧”。`, 'success'); const data = await getProjects(); setProjects(data); } catch (err) { console.error('Upload failed:', err); - alert('上传失败,请检查后端服务'); + showNotice('上传失败,请检查后端服务', 'error'); } finally { setIsLoading(false); setPendingFile(null); @@ -127,14 +143,14 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) { setIsGeneratingFrames(true); try { const task = await parseMedia(frameProject.id, { parseFps: frameParseFps }); - alert(`生成帧任务已入队 #${task.id}\n帧率: ${frameParseFps} FPS\n可在 Dashboard 查看进度。`); + showNotice(`生成帧任务已入队 #${task.id}\n帧率: ${frameParseFps} FPS\n可在 Dashboard 查看进度。`, 'success'); const data = await getProjects(); setProjects(data); setShowFrameConfig(false); setFrameProject(null); } catch (err) { console.error('Frame generation failed:', err); - alert('生成帧失败,请检查后端服务或项目源文件'); + showNotice('生成帧失败,请检查后端服务或项目源文件', 'error'); } finally { setIsGeneratingFrames(false); } @@ -144,19 +160,19 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) { if (!files || files.length === 0) return; const dcmFiles = Array.from(files).filter((f) => f.name.toLowerCase().endsWith('.dcm')); if (dcmFiles.length === 0) { - alert('未选择有效的 .dcm 文件'); + 
showNotice('未选择有效的 .dcm 文件', 'error'); return; } setIsLoading(true); try { const result = await uploadDicomBatch(dcmFiles); await parseMedia(String(result.project_id)); - alert(`DICOM 上传成功: ${result.uploaded_count} 个文件`); + showNotice(`DICOM 上传成功: ${result.uploaded_count} 个文件`, 'success'); const data = await getProjects(); setProjects(data); } catch (err) { console.error('DICOM upload failed:', err); - alert('DICOM 上传失败,请检查后端服务'); + showNotice('DICOM 上传失败,请检查后端服务', 'error'); } finally { setIsLoading(false); if (dicomInputRef.current) dicomInputRef.current.value = ''; @@ -175,6 +191,7 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) { return (
+ setNotice(null)} />

视频与连续帧项目库

@@ -263,7 +280,7 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) { )}
- {proj.source_type === 'dicom' ? 'DICOM' : (proj.video_path && (proj.frames ?? 0) === 0 ? '待生成帧' : (proj.fps || '30FPS'))} + {frameSequenceLabel(proj)} {proj.status === 'ready' ? ( diff --git a/src/components/TemplateRegistry.test.tsx b/src/components/TemplateRegistry.test.tsx index 4035394..239a1f3 100644 --- a/src/components/TemplateRegistry.test.tsx +++ b/src/components/TemplateRegistry.test.tsx @@ -83,6 +83,32 @@ describe('TemplateRegistry', () => { expect(screen.getByText('分类A')).toBeInTheDocument(); }); + it('shows JSON import errors as transient notices instead of blocking alerts', async () => { + apiMock.getTemplates.mockResolvedValueOnce([]); + + render(); + fireEvent.click(screen.getByText('新建方案')); + fireEvent.click(screen.getByText('批量导入')); + fireEvent.change(screen.getByPlaceholderText('[[[255,0,0], [0,255,0]], ["分类A", "分类B"]]'), { + target: { value: '{broken-json' }, + }); + fireEvent.click(screen.getByRole('button', { name: '导入' })); + + expect(await screen.findByRole('status')).toHaveTextContent('JSON 解析失败'); + }); + + it('shows template save errors as transient notices', async () => { + apiMock.getTemplates.mockResolvedValueOnce([]); + apiMock.createTemplate.mockRejectedValueOnce(new Error('boom')); + + render(); + fireEvent.click(screen.getByText('新建方案')); + fireEvent.change(screen.getAllByRole('textbox')[0], { target: { value: 'Bad Template' } }); + fireEvent.click(screen.getByRole('button', { name: '保存' })); + + expect(await screen.findByRole('status')).toHaveTextContent('保存失败,请查看控制台'); + }); + it('edits an existing template through the backend and store', async () => { apiMock.getTemplates.mockResolvedValueOnce([ { diff --git a/src/components/TemplateRegistry.tsx b/src/components/TemplateRegistry.tsx index 61e31ab..736376b 100644 --- a/src/components/TemplateRegistry.tsx +++ b/src/components/TemplateRegistry.tsx @@ -4,6 +4,7 @@ import { cn } from '../lib/utils'; import { useStore } from '../store/useStore'; import { getTemplates, 
createTemplate, updateTemplate, deleteTemplate } from '../lib/api'; import type { Template, TemplateClass } from '../store/useStore'; +import { TransientNotice, type NoticeState, type NoticeTone } from './TransientNotice'; // HSL to Hex color generator function hslToHex(h: number, s: number, l: number): string { @@ -59,6 +60,11 @@ export function TemplateRegistry() { const [editClasses, setEditClasses] = useState([]); const [editingClassId, setEditingClassId] = useState(null); const [dragOverIndex, setDragOverIndex] = useState(null); + const [notice, setNotice] = useState(null); + + const showNotice = (message: string, tone: NoticeTone = 'info') => { + setNotice({ id: Date.now(), message, tone }); + }; useEffect(() => { setIsLoading(true); @@ -106,7 +112,7 @@ export function TemplateRegistry() { setShowModal(false); } catch (err) { console.error('Failed to save template:', err); - alert('保存失败,请查看控制台'); + showNotice('保存失败,请查看控制台', 'error'); } finally { setIsSaving(false); } @@ -122,6 +128,7 @@ export function TemplateRegistry() { } } catch (err) { console.error('Failed to delete template:', err); + showNotice('删除失败,请检查后端服务', 'error'); } }; @@ -168,7 +175,7 @@ export function TemplateRegistry() { colors = data.colors; names = data.names; } else { - alert('格式错误:请提供 [[colors...], [names...]] 或 {colors, names}'); + showNotice('格式错误:请提供 [[colors...], [names...]] 或 {colors, names}', 'error'); return; } @@ -188,7 +195,7 @@ export function TemplateRegistry() { setShowImport(false); setImportText(''); } catch (e) { - alert('JSON 解析失败'); + showNotice('JSON 解析失败', 'error'); } }; @@ -212,6 +219,7 @@ export function TemplateRegistry() { return (
+ setNotice(null)} />

分割模板与分类优先级管理库

定义业务语义本体树架构、约束覆盖遮罩优先级(Z-Index裁决权重),以及真实标签数据的向下兼容转换映射(Dict Translation)原则。

diff --git a/src/components/ToolsPalette.test.tsx b/src/components/ToolsPalette.test.tsx index d425e16..9974cf4 100644 --- a/src/components/ToolsPalette.test.tsx +++ b/src/components/ToolsPalette.test.tsx @@ -42,4 +42,13 @@ describe('ToolsPalette', () => { expect(setActiveTool).toHaveBeenCalledWith('sam_trigger'); expect(onTriggerAI).toHaveBeenCalled(); }); + + it('uses compact vertically scrollable layout for smaller workspaces', () => { + const { container } = render(); + const palette = container.firstElementChild; + + expect(palette).toHaveClass('overflow-y-auto'); + expect(screen.getByTitle('创建多边形 (P)')).toHaveClass('h-9'); + expect(screen.getByTitle('打开 AI 智能分割')).toHaveClass('h-9'); + }); }); diff --git a/src/components/ToolsPalette.tsx b/src/components/ToolsPalette.tsx index cfd344c..f32f978 100644 --- a/src/components/ToolsPalette.tsx +++ b/src/components/ToolsPalette.tsx @@ -40,8 +40,8 @@ export function ToolsPalette({ ]; return ( -
-
+
+
{tools.map(tool => { const Icon = tool.icon; const isActive = activeTool === tool.id; @@ -51,7 +51,7 @@ export function ToolsPalette({ onClick={() => setActiveTool(tool.id)} title={tool.label} className={cn( - "w-10 h-10 rounded-lg flex items-center justify-center transition-all p-2", + "w-9 h-9 rounded-md flex items-center justify-center transition-all p-1.5", isActive ? (tool.id.includes('remove') ? "bg-red-500/10 text-red-500" : tool.id.includes('merge') ? "bg-green-500/10 text-green-500" @@ -59,12 +59,12 @@ export function ToolsPalette({ : "text-gray-500 hover:bg-white/5 hover:text-white" )} > - + ) })} -
+
{aiTools.map(tool => { const Icon = tool.icon; @@ -75,13 +75,13 @@ export function ToolsPalette({ onClick={() => setActiveTool(tool.id)} title={tool.label} className={cn( - "w-10 h-10 rounded-lg flex items-center justify-center transition-all p-2 border", + "w-9 h-9 rounded-md flex items-center justify-center transition-all p-1.5 border", isActive ? `${tool.bg} ${tool.color} ${tool.border} shadow-[0_0_10px_rgba(255,255,255,0.05)]` : "text-gray-500 hover:bg-white/5 hover:text-white border-transparent" )} > - + ) })} @@ -93,32 +93,32 @@ export function ToolsPalette({ }} title="打开 AI 智能分割" className={cn( - "w-10 h-10 rounded-lg flex items-center justify-center transition-all", + "w-9 h-9 rounded-md flex items-center justify-center transition-all", activeTool === 'sam_trigger' ? "bg-cyan-600 text-white shadow-lg shadow-cyan-900/20" : "text-gray-500 hover:bg-white/5" )} > - + -
+
diff --git a/src/components/TransientNotice.test.tsx b/src/components/TransientNotice.test.tsx new file mode 100644 index 0000000..59749ba --- /dev/null +++ b/src/components/TransientNotice.test.tsx @@ -0,0 +1,28 @@ +import { act, render, screen } from '@testing-library/react'; +import { describe, expect, it, vi } from 'vitest'; +import { TransientNotice } from './TransientNotice'; + +describe('TransientNotice', () => { + it('renders a non-blocking notice and dismisses it after the timeout', () => { + vi.useFakeTimers(); + const onDismiss = vi.fn(); + + render( + , + ); + + expect(screen.getByRole('status')).toHaveTextContent('操作已完成'); + expect(screen.getByRole('status').parentElement).toHaveClass('pointer-events-none'); + + act(() => { + vi.advanceTimersByTime(1000); + }); + + expect(onDismiss).toHaveBeenCalledTimes(1); + vi.useRealTimers(); + }); +}); diff --git a/src/components/TransientNotice.tsx b/src/components/TransientNotice.tsx new file mode 100644 index 0000000..04d3279 --- /dev/null +++ b/src/components/TransientNotice.tsx @@ -0,0 +1,47 @@ +import React, { useEffect } from 'react'; +import { cn } from '../lib/utils'; + +export type NoticeTone = 'info' | 'success' | 'error'; + +export interface NoticeState { + id: number; + message: string; + tone?: NoticeTone; +} + +interface TransientNoticeProps { + notice: NoticeState | null; + onDismiss: () => void; + durationMs?: number; +} + +const toneClasses: Record = { + info: 'border-cyan-400/30 bg-cyan-950/85 text-cyan-100 shadow-cyan-950/30', + success: 'border-emerald-400/30 bg-emerald-950/85 text-emerald-100 shadow-emerald-950/30', + error: 'border-red-400/30 bg-red-950/85 text-red-100 shadow-red-950/30', +}; + +export function TransientNotice({ notice, onDismiss, durationMs = 3600 }: TransientNoticeProps) { + useEffect(() => { + if (!notice) return undefined; + const timer = window.setTimeout(onDismiss, durationMs); + return () => window.clearTimeout(timer); + }, [durationMs, notice, onDismiss]); + + if 
(!notice) return null; + + const tone = notice.tone || 'info'; + return ( +
+
+ {notice.message} +
+
+ ); +} diff --git a/src/components/VideoWorkspace.test.tsx b/src/components/VideoWorkspace.test.tsx index a22a9a9..86f5c74 100644 --- a/src/components/VideoWorkspace.test.tsx +++ b/src/components/VideoWorkspace.test.tsx @@ -1,4 +1,4 @@ -import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'; +import { act, fireEvent, render, screen, waitFor, within } from '@testing-library/react'; import { beforeEach, describe, expect, it, vi } from 'vitest'; import { resetStore } from '../test/storeTestUtils'; import { useStore } from '../store/useStore'; @@ -8,7 +8,9 @@ const apiMock = vi.hoisted(() => ({ getProjectFrames: vi.fn(), parseMedia: vi.fn(), propagateMasks: vi.fn(), + queuePropagationTask: vi.fn(), getTask: vi.fn(), + cancelTask: vi.fn(), getTemplates: vi.fn(), getProjectAnnotations: vi.fn(), saveAnnotation: vi.fn(), @@ -26,7 +28,9 @@ vi.mock('../lib/api', () => ({ getProjectFrames: apiMock.getProjectFrames, parseMedia: apiMock.parseMedia, propagateMasks: apiMock.propagateMasks, + queuePropagationTask: apiMock.queuePropagationTask, getTask: apiMock.getTask, + cancelTask: apiMock.cancelTask, getTemplates: apiMock.getTemplates, getProjectAnnotations: apiMock.getProjectAnnotations, saveAnnotation: apiMock.saveAnnotation, @@ -48,7 +52,15 @@ describe('VideoWorkspace', () => { apiMock.getTemplates.mockResolvedValue([]); apiMock.getProjectAnnotations.mockResolvedValue([]); apiMock.annotationToMask.mockReturnValue(null); - apiMock.getTask.mockResolvedValue({ id: 1, status: 'success', progress: 100, message: '解析完成' }); + apiMock.queuePropagationTask.mockResolvedValue({ id: 31, status: 'queued', progress: 0, message: '自动传播任务已入队' }); + apiMock.getTask.mockResolvedValue({ + id: 31, + status: 'success', + progress: 100, + message: '自动传播完成', + result: { processed_frame_count: 3, created_annotation_count: 2, completed_steps: 1 }, + }); + apiMock.cancelTask.mockResolvedValue({ id: 31, status: 'cancelled', progress: 100, message: '任务已取消' }); 
apiMock.propagateMasks.mockResolvedValue({ model: 'sam2.1_hiera_tiny', direction: 'forward', @@ -81,6 +93,60 @@ describe('VideoWorkspace', () => { expect(apiMock.getProjectAnnotations).toHaveBeenCalledWith('1'); }); + it('exposes workspace undo/redo buttons and keyboard shortcuts without hijacking inputs', async () => { + const mask = { + id: 'mask-undo', + frameId: '10', + pathData: 'M 0 0 Z', + label: 'Draft', + color: '#06b6d4', + }; + useStore.setState({ + currentProject: null, + masks: [mask], + maskHistory: [[]], + maskFuture: [], + }); + + render(); + fireEvent.click(screen.getByRole('button', { name: '撤销操作' })); + + expect(useStore.getState().masks).toEqual([]); + + fireEvent.click(screen.getByRole('button', { name: '重做操作' })); + expect(useStore.getState().masks).toEqual([mask]); + + fireEvent.keyDown(window, { key: 'z', ctrlKey: true }); + expect(useStore.getState().masks).toEqual([]); + + fireEvent.keyDown(window, { key: 'z', ctrlKey: true, shiftKey: true }); + expect(useStore.getState().masks).toEqual([mask]); + + fireEvent.keyDown(screen.getByLabelText('传播起始帧'), { key: 'z', ctrlKey: true }); + expect(useStore.getState().masks).toEqual([mask]); + }); + + it('auto-dismisses short workspace operation messages without blocking later actions', async () => { + apiMock.getProjectFrames.mockResolvedValueOnce([ + { id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 }, + ]); + + render(); + await waitFor(() => expect(useStore.getState().frames).toHaveLength(1)); + + vi.useFakeTimers(); + fireEvent.click(screen.getByRole('button', { name: '结构化归档保存' })); + expect(screen.getByText('没有待保存标注')).toBeInTheDocument(); + + act(() => { + vi.advanceTimersByTime(3600); + }); + + expect(screen.queryByText('没有待保存标注')).not.toBeInTheDocument(); + expect(screen.getByRole('button', { name: '结构化归档保存' })).not.toBeDisabled(); + vi.useRealTimers(); + }); + it('does not auto-generate frames when a media project has no frames yet', async () => { 
apiMock.getProjectFrames.mockResolvedValueOnce([]); @@ -417,26 +483,190 @@ describe('VideoWorkspace', () => { }); fireEvent.click(screen.getByRole('button', { name: '自动传播' })); + expect(apiMock.queuePropagationTask).not.toHaveBeenCalled(); + fireEvent.click(screen.getByRole('button', { name: '开始传播' })); - await waitFor(() => expect(apiMock.propagateMasks).toHaveBeenCalledWith({ + await waitFor(() => expect(apiMock.queuePropagationTask).toHaveBeenCalledWith({ project_id: 1, frame_id: 10, model: 'sam2.1_hiera_tiny', - direction: 'forward', - max_frames: 2, include_source: false, save_annotations: true, - seed: { + steps: [{ + direction: 'forward', + max_frames: 2, + seed: { + polygons: [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]], + bbox: [0.1, 0.1, 0.2, 0.2], + points: undefined, + label: '胆囊', + color: '#ff0000', + class_metadata: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 }, + template_id: 2, + source_mask_id: 'mask-1', + source_annotation_id: undefined, + }, + }], + })); + await waitFor(() => expect(screen.getByText('已自动传播 1 个参考 mask,处理 3 帧次,删除旧区域 0 个,保存 2 个区域')).toBeInTheDocument()); + }); + + it('uses the separately selected propagation weight when queueing propagation', async () => { + apiMock.getProjectFrames.mockResolvedValueOnce([ + { id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 }, + { id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 }, + ]); + apiMock.buildAnnotationPayload.mockReturnValueOnce({ + project_id: 1, + frame_id: 10, + mask_data: { polygons: [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]], - bbox: [0.1, 0.1, 0.2, 0.2], - points: undefined, label: '胆囊', color: '#ff0000', - class_metadata: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 }, - template_id: 2, }, - })); - await waitFor(() => expect(screen.getByText('已自动传播 1 个参考 mask,处理 3 帧次,保存 2 个区域')).toBeInTheDocument()); + bbox: [0.1, 0.1, 0.2, 0.2], + }); + + render(); + await waitFor(() => 
expect(useStore.getState().frames).toHaveLength(2)); + act(() => { + useStore.setState({ + aiModel: 'sam2.1_hiera_tiny', + masks: [{ + id: 'mask-propagation-model', + frameId: '10', + pathData: 'M 0 0 Z', + label: '胆囊', + color: '#ff0000', + segmentation: [[64, 36, 192, 36, 192, 108]], + bbox: [64, 36, 128, 72], + }], + }); + }); + + const propagationWeightSelect = screen.getByLabelText('传播权重'); + fireEvent.change(propagationWeightSelect, { target: { value: 'sam2.1_hiera_small' } }); + expect(propagationWeightSelect).toHaveValue('sam2.1_hiera_small'); + fireEvent.click(screen.getByRole('button', { name: '自动传播' })); + fireEvent.click(screen.getByRole('button', { name: '开始传播' })); + + await waitFor(() => expect(apiMock.queuePropagationTask).toHaveBeenCalledWith(expect.objectContaining({ + model: 'sam2.1_hiera_small', + }))); + await waitFor(() => expect(screen.getByText('已自动传播 1 个参考 mask,处理 3 帧次,删除旧区域 0 个,保存 2 个区域')).toBeInTheDocument()); + }); + + it('shows propagation task progress and reports empty results', async () => { + apiMock.getProjectFrames.mockResolvedValueOnce([ + { id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 }, + { id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 }, + ]); + apiMock.buildAnnotationPayload.mockReturnValueOnce({ + project_id: 1, + frame_id: 10, + mask_data: { + polygons: [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]], + label: '胆囊', + color: '#ff0000', + }, + bbox: [0.1, 0.1, 0.2, 0.2], + }); + apiMock.queuePropagationTask.mockResolvedValueOnce({ id: 44, status: 'queued', progress: 0, message: '自动传播任务已入队' }); + apiMock.getTask.mockResolvedValueOnce({ + id: 44, + status: 'success', + progress: 100, + message: '自动传播完成,但没有生成新的 mask', + result: { processed_frame_count: 2, created_annotation_count: 0, completed_steps: 1 }, + }); + + render(); + await waitFor(() => expect(useStore.getState().frames).toHaveLength(2)); + act(() => { + useStore.setState({ + masks: [{ + id: 
'mask-progress', + frameId: 10 as unknown as string, + pathData: 'M 0 0 Z', + label: '胆囊', + color: '#ff0000', + segmentation: [[64, 36, 192, 36, 192, 108]], + bbox: [64, 36, 128, 72], + }], + }); + }); + + fireEvent.click(screen.getByRole('button', { name: '自动传播' })); + fireEvent.click(screen.getByRole('button', { name: '开始传播' })); + + const progressPanel = await screen.findByLabelText('自动传播进度'); + expect(progressPanel).toBeInTheDocument(); + expect(within(progressPanel).getByText('0%')).toBeInTheDocument(); + + expect(await screen.findByText(/没有生成新的 mask/)).toBeInTheDocument(); + }); + + it('lets users select the propagation range on the timeline before queueing', async () => { + apiMock.getProjectFrames.mockResolvedValueOnce([ + { id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 }, + { id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 }, + { id: 12, project_id: 1, frame_index: 2, image_url: '/frame-2.jpg', width: 640, height: 360 }, + { id: 13, project_id: 1, frame_index: 3, image_url: '/frame-3.jpg', width: 640, height: 360 }, + { id: 14, project_id: 1, frame_index: 4, image_url: '/frame-4.jpg', width: 640, height: 360 }, + ]); + apiMock.buildAnnotationPayload.mockReturnValueOnce({ + project_id: 1, + frame_id: 10, + mask_data: { + polygons: [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]], + label: '胆囊', + color: '#ff0000', + }, + bbox: [0.1, 0.1, 0.2, 0.2], + }); + + render(); + await waitFor(() => expect(useStore.getState().frames).toHaveLength(5)); + act(() => { + useStore.setState({ + masks: [{ + id: 'mask-timeline-range', + frameId: '10', + pathData: 'M 0 0 Z', + label: '胆囊', + color: '#ff0000', + segmentation: [[64, 36, 192, 36, 192, 108]], + bbox: [64, 36, 128, 72], + }], + }); + }); + + fireEvent.click(screen.getByRole('button', { name: '自动传播' })); + const processingBar = screen.getByLabelText('视频处理进度条'); + vi.spyOn(processingBar, 'getBoundingClientRect').mockReturnValue({ + left: 0, 
+ right: 100, + top: 0, + bottom: 10, + width: 100, + height: 10, + x: 0, + y: 0, + toJSON: () => ({}), + }); + fireEvent.pointerDown(processingBar, { clientX: 25, pointerId: 1 }); + fireEvent.pointerMove(processingBar, { clientX: 100, pointerId: 1 }); + fireEvent.pointerUp(processingBar, { clientX: 100, pointerId: 1 }); + + expect(screen.getByLabelText('传播起始帧')).toHaveValue(2); + expect(screen.getByLabelText('传播结束帧')).toHaveValue(5); + + fireEvent.click(screen.getByRole('button', { name: '开始传播' })); + + await waitFor(() => expect(apiMock.queuePropagationTask).toHaveBeenCalledWith(expect.objectContaining({ + frame_id: 10, + steps: [expect.objectContaining({ direction: 'forward', max_frames: 5 })], + }))); }); it('auto-propagates all reference-frame masks in both directions inside the selected range', async () => { @@ -445,13 +675,16 @@ describe('VideoWorkspace', () => { { id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 }, { id: 12, project_id: 1, frame_index: 2, image_url: '/frame-2.jpg', width: 640, height: 360 }, ]); - apiMock.propagateMasks.mockResolvedValue({ - model: 'sam2.1_hiera_tiny', - direction: 'forward', - source_frame_id: 11, - processed_frame_count: 2, - created_annotation_count: 1, - annotations: [], + apiMock.getTask.mockResolvedValue({ + id: 31, + status: 'success', + progress: 100, + message: '自动传播完成', + result: { + processed_frame_count: 8, + created_annotation_count: 4, + completed_steps: 4, + }, }); apiMock.buildAnnotationPayload .mockReturnValueOnce({ @@ -505,27 +738,14 @@ describe('VideoWorkspace', () => { fireEvent.change(screen.getByLabelText('传播结束帧'), { target: { value: '3' } }); fireEvent.click(screen.getByRole('button', { name: '自动传播' })); - await waitFor(() => expect(apiMock.propagateMasks).toHaveBeenCalledTimes(4)); - expect(apiMock.propagateMasks).toHaveBeenNthCalledWith(1, expect.objectContaining({ - direction: 'backward', - max_frames: 2, - seed: expect.objectContaining({ label: '胆囊' }), - 
})); - expect(apiMock.propagateMasks).toHaveBeenNthCalledWith(2, expect.objectContaining({ - direction: 'forward', - max_frames: 2, - seed: expect.objectContaining({ label: '胆囊' }), - })); - expect(apiMock.propagateMasks).toHaveBeenNthCalledWith(3, expect.objectContaining({ - direction: 'backward', - max_frames: 2, - seed: expect.objectContaining({ label: '肝脏' }), - })); - expect(apiMock.propagateMasks).toHaveBeenNthCalledWith(4, expect.objectContaining({ - direction: 'forward', - max_frames: 2, - seed: expect.objectContaining({ label: '肝脏' }), - })); - await waitFor(() => expect(screen.getByText('已自动传播 2 个参考 mask,处理 8 帧次,保存 4 个区域')).toBeInTheDocument()); + await waitFor(() => expect(apiMock.queuePropagationTask).toHaveBeenCalledTimes(1)); + const queuedPayload = apiMock.queuePropagationTask.mock.calls[0][0]; + expect(queuedPayload.steps).toEqual([ + expect.objectContaining({ direction: 'backward', max_frames: 2, seed: expect.objectContaining({ label: '胆囊' }) }), + expect.objectContaining({ direction: 'forward', max_frames: 2, seed: expect.objectContaining({ label: '胆囊' }) }), + expect.objectContaining({ direction: 'backward', max_frames: 2, seed: expect.objectContaining({ label: '肝脏' }) }), + expect.objectContaining({ direction: 'forward', max_frames: 2, seed: expect.objectContaining({ label: '肝脏' }) }), + ]); + await waitFor(() => expect(screen.getByText('已自动传播 2 个参考 mask,处理 8 帧次,删除旧区域 0 个,保存 4 个区域')).toBeInTheDocument()); }); }); diff --git a/src/components/VideoWorkspace.tsx b/src/components/VideoWorkspace.tsx index d2a06b5..6599a0e 100644 --- a/src/components/VideoWorkspace.tsx +++ b/src/components/VideoWorkspace.tsx @@ -1,16 +1,19 @@ import React, { useCallback, useEffect, useMemo, useState } from 'react'; +import { Redo, Undo } from 'lucide-react'; import { useStore } from '../store/useStore'; import { annotationToMask, buildAnnotationPayload, + cancelTask, deleteAnnotation, exportCoco, exportMasks, getProjectAnnotations, getProjectFrames, + getTask, 
getTemplates, importGtMask, - propagateMasks, + queuePropagationTask, saveAnnotation, updateAnnotation, } from '../lib/api'; @@ -19,9 +22,20 @@ import { ToolsPalette } from './ToolsPalette'; import { OntologyInspector } from './OntologyInspector'; import { FrameTimeline } from './FrameTimeline'; import { ModelStatusBadge } from './ModelStatusBadge'; -import type { Frame, Mask } from '../store/useStore'; +import { DEFAULT_AI_MODEL_ID, SAM2_MODEL_OPTIONS, type AiModelId, type Frame, type Mask } from '../store/useStore'; type PropagationDirection = 'forward' | 'backward'; +type PropagationProgress = { + currentStep: number; + completedSteps: number; + totalSteps: number; + processedCount: number; + createdCount: number; + label: string; +} | null; + +const PROPAGATION_POLL_INTERVAL_MS = 250; +const STATUS_MESSAGE_TTL_MS = 3600; export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) { const gtMaskInputRef = React.useRef(null); @@ -53,6 +67,47 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void const [statusMessage, setStatusMessage] = useState(''); const [propagationStartFrame, setPropagationStartFrame] = useState(1); const [propagationEndFrame, setPropagationEndFrame] = useState(1); + const [isPropagationRangeSelecting, setIsPropagationRangeSelecting] = useState(false); + const [hasExplicitPropagationRange, setHasExplicitPropagationRange] = useState(false); + const [propagationProgress, setPropagationProgress] = useState(null); + const [propagationTaskId, setPropagationTaskId] = useState(null); + const [propagationWeight, setPropagationWeight] = useState(aiModel || DEFAULT_AI_MODEL_ID); + const [hasCustomPropagationWeight, setHasCustomPropagationWeight] = useState(false); + + useEffect(() => { + if (!hasCustomPropagationWeight) { + setPropagationWeight(aiModel || DEFAULT_AI_MODEL_ID); + } + }, [aiModel, hasCustomPropagationWeight]); + + const propagationWeightLabel = useMemo( + () => 
SAM2_MODEL_OPTIONS.find((option) => option.id === propagationWeight)?.label || propagationWeight, + [propagationWeight], + ); + + useEffect(() => { + const handleWorkspaceShortcuts = (event: KeyboardEvent) => { + const target = event.target as HTMLElement | null; + const tagName = target?.tagName?.toLowerCase(); + if (tagName === 'input' || tagName === 'textarea' || tagName === 'select' || target?.isContentEditable) return; + if (!event.metaKey && !event.ctrlKey) return; + + const key = event.key.toLowerCase(); + if (key === 'z') { + event.preventDefault(); + if (event.shiftKey) redoMasks(); + else undoMasks(); + return; + } + if (key === 'y') { + event.preventDefault(); + redoMasks(); + } + }; + + window.addEventListener('keydown', handleWorkspaceShortcuts); + return () => window.removeEventListener('keydown', handleWorkspaceShortcuts); + }, [redoMasks, undoMasks]); const hydrateSavedAnnotations = useCallback(async ( projectId: string, @@ -143,6 +198,13 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void const frameById = useMemo(() => new Map(frames.map((frame) => [frame.id, frame])), [frames]); const projectFrameIds = useMemo(() => new Set(frames.map((frame) => frame.id)), [frames]); const currentFrameNumber = currentFrameIndex + 1; + const isWorkspaceBusy = isSaving || isExporting || isImportingGt || isPropagating || Boolean(propagationProgress); + + useEffect(() => { + if (!statusMessage || isWorkspaceBusy || totalFrames === 0) return undefined; + const timer = window.setTimeout(() => setStatusMessage(''), STATUS_MESSAGE_TTL_MS); + return () => window.clearTimeout(timer); + }, [isWorkspaceBusy, statusMessage, totalFrames]); useEffect(() => { if (totalFrames === 0) { @@ -152,6 +214,8 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void } setPropagationStartFrame(currentFrameNumber); setPropagationEndFrame(Math.min(totalFrames, currentFrameNumber + 29)); + setIsPropagationRangeSelecting(false); + 
setHasExplicitPropagationRange(false); }, [currentFrameNumber, totalFrames]); const savePendingAnnotations = useCallback(async ({ silent = false } = {}) => { @@ -341,6 +405,9 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void if (!seedPayload?.mask_data?.polygons?.length && !seedPayload?.bbox) { return null; } + const sourceAnnotationId = seedMask.annotationId && /^\d+$/.test(seedMask.annotationId) + ? Number(seedMask.annotationId) + : undefined; return { polygons: seedPayload.mask_data?.polygons, bbox: seedPayload.bbox, @@ -349,12 +416,33 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void color: seedPayload.mask_data?.color, class_metadata: seedPayload.mask_data?.class, template_id: seedPayload.template_id, + source_mask_id: seedMask.id, + source_annotation_id: sourceAnnotationId, }; }, [activeTemplateId, currentFrame, currentProject?.id]); - const handleAutoPropagate = async () => { + const handlePropagationRangeChange = useCallback((startFrame: number, endFrame: number) => { + const nextStart = clampFrameNumber(startFrame); + const nextEnd = clampFrameNumber(endFrame); + setPropagationStartFrame(nextStart); + setPropagationEndFrame(nextEnd); + setHasExplicitPropagationRange(true); + setStatusMessage(`已选择自动传播范围:第 ${Math.min(nextStart, nextEnd)}-${Math.max(nextStart, nextEnd)} 帧`); + }, [clampFrameNumber]); + + const handlePropagationStartInput = (value: number) => { + setPropagationStartFrame(clampFrameNumber(value || 1)); + setHasExplicitPropagationRange(true); + }; + + const handlePropagationEndInput = (value: number) => { + setPropagationEndFrame(clampFrameNumber(value || 1)); + setHasExplicitPropagationRange(true); + }; + + const runAutoPropagate = async () => { if (!currentProject?.id || !currentFrame?.id) return; - const seedMasks = masks.filter((mask) => mask.frameId === currentFrame.id); + const seedMasks = masks.filter((mask) => String(mask.frameId) === String(currentFrame.id)); if 
(seedMasks.length === 0) { setStatusMessage('请先在当前参考帧创建或保存至少一个 mask'); return; @@ -390,37 +478,121 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void return; } + setIsPropagationRangeSelecting(false); setIsPropagating(true); - setStatusMessage(`${aiModel.toUpperCase()} 正在以第 ${currentFrameNumber} 帧为参考,自动传播 ${seeds.length} 个 mask 到第 ${rangeStartIndex + 1}-${rangeEndIndex + 1} 帧...`); + const totalSteps = seeds.length * propagationDirections.length; + setPropagationProgress({ + currentStep: 0, + completedSteps: 0, + totalSteps, + processedCount: 0, + createdCount: 0, + label: '准备传播', + }); + setStatusMessage(`${propagationWeightLabel} 权重正在以第 ${currentFrameNumber} 帧为参考,自动传播 ${seeds.length} 个 mask 到第 ${rangeStartIndex + 1}-${rangeEndIndex + 1} 帧...`); try { - let createdCount = 0; - let processedCount = 0; - for (const { seed } of seeds) { - for (const { direction, maxFrames } of propagationDirections) { - const result = await propagateMasks({ - project_id: Number(currentProject.id), - frame_id: Number(currentFrame.id), - model: aiModel, - direction, - max_frames: maxFrames, - include_source: false, - save_annotations: true, - seed, - }); - createdCount += result.created_annotation_count; - processedCount += result.processed_frame_count; + const steps = seeds.flatMap(({ seed }) => ( + propagationDirections.map(({ direction, maxFrames }) => ({ + seed, + direction, + max_frames: maxFrames, + })) + )); + const task = await queuePropagationTask({ + project_id: Number(currentProject.id), + frame_id: Number(currentFrame.id), + model: propagationWeight, + steps, + include_source: false, + save_annotations: true, + }); + setPropagationTaskId(task.id); + setStatusMessage(`自动传播任务已入队 #${task.id},可在 Dashboard 查看进度`); + + let currentTask = task; + while (!['success', 'failed', 'cancelled'].includes(currentTask.status)) { + await new Promise((resolve) => setTimeout(resolve, PROPAGATION_POLL_INTERVAL_MS)); + currentTask = await getTask(task.id); + 
const result = currentTask.result || {}; + const completedSteps = Number(result.completed_steps || 0); + const processedCount = Number(result.processed_frame_count || 0); + const createdCount = Number(result.created_annotation_count || 0); + setPropagationProgress({ + currentStep: Math.min(completedSteps + 1, totalSteps), + completedSteps, + totalSteps, + processedCount, + createdCount, + label: currentTask.message || `自动传播任务 #${task.id}`, + }); + setStatusMessage(currentTask.message || `自动传播任务 #${task.id} 运行中...`); + if (createdCount > 0) { + await hydrateSavedAnnotations(currentProject.id, frames); } } + + const result = currentTask.result || {}; + const createdCount = Number(result.created_annotation_count || 0); + const processedCount = Number(result.processed_frame_count || 0); + const skippedCount = Number(result.skipped_seed_count || 0); + const deletedCount = Number(result.deleted_annotation_count || 0); await hydrateSavedAnnotations(currentProject.id, frames); - setStatusMessage(`已自动传播 ${seeds.length} 个参考 mask,处理 ${processedCount} 帧次,保存 ${createdCount} 个区域`); + if (currentTask.status === 'failed') { + setStatusMessage(currentTask.error ? `传播失败:${currentTask.error}` : '传播失败,请检查权重状态或后端日志'); + return; + } + if (currentTask.status === 'cancelled') { + setStatusMessage('自动传播任务已取消'); + return; + } + setStatusMessage(createdCount > 0 + ? `已自动传播 ${seeds.length} 个参考 mask,处理 ${processedCount} 帧次,删除旧区域 ${deletedCount} 个,保存 ${createdCount} 个区域` + : skippedCount > 0 + ? `自动传播已完成:${skippedCount} 个未改变 mask 已跳过,没有生成重复区域` + : `自动传播已完成,但没有生成新的 mask;请检查参考 mask、传播范围或 ${propagationWeightLabel} 权重状态`); } catch (err) { console.error('Propagation failed:', err); - setStatusMessage('传播失败,请检查模型状态或后端日志'); + const detail = (err as any)?.response?.data?.detail; + setStatusMessage(detail ? 
`传播失败:${detail}` : '传播失败,请检查权重状态或后端日志'); } finally { setIsPropagating(false); + setPropagationProgress(null); + setPropagationTaskId(null); } }; + const handleAutoPropagate = async () => { + if (!hasExplicitPropagationRange && !isPropagationRangeSelecting) { + setIsPropagationRangeSelecting(true); + setStatusMessage('请在播放进度条或视频处理进度条上点击/拖拽选择传播起止帧,再点击“开始传播”'); + return; + } + await runAutoPropagate(); + }; + + const handleCancelPropagationRangeSelection = () => { + setIsPropagationRangeSelecting(false); + setHasExplicitPropagationRange(false); + setPropagationStartFrame(currentFrameNumber || 1); + setPropagationEndFrame(Math.min(Math.max(totalFrames, 1), (currentFrameNumber || 1) + 29)); + setStatusMessage('已取消自动传播范围选择'); + }; + + const handleCancelPropagation = async () => { + if (!propagationTaskId) return; + try { + await cancelTask(propagationTaskId); + setStatusMessage(`正在取消自动传播任务 #${propagationTaskId}...`); + } catch (err) { + console.error('Cancel propagation failed:', err); + setStatusMessage('取消自动传播失败,请稍后重试'); + } + }; + + const propagationPercent = propagationProgress + ? Math.round((propagationProgress.completedSteps / Math.max(propagationProgress.totalSteps, 1)) * 100) + : 0; + return (
{/* Top Header / Status bar */} @@ -436,7 +608,68 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void {statusMessage} )} + {propagationProgress && ( +
+
+ {propagationProgress.label} + {propagationPercent}% +
+
+
+
+
+ )} +
+ + +
+
+ 传播权重 + +
void min={1} max={Math.max(totalFrames, 1)} value={propagationStartFrame} - onChange={(event) => setPropagationStartFrame(clampFrameNumber(Number(event.target.value) || 1))} + onChange={(event) => handlePropagationStartInput(Number(event.target.value))} disabled={isPropagating || isSaving || isExporting || isImportingGt || totalFrames === 0} className="h-6 w-14 rounded bg-black/20 border border-white/10 px-1 text-[10px] text-gray-300 outline-none focus:border-cyan-500/50 disabled:opacity-40" /> @@ -471,7 +704,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void min={1} max={Math.max(totalFrames, 1)} value={propagationEndFrame} - onChange={(event) => setPropagationEndFrame(clampFrameNumber(Number(event.target.value) || 1))} + onChange={(event) => handlePropagationEndInput(Number(event.target.value))} disabled={isPropagating || isSaving || isExporting || isImportingGt || totalFrames === 0} className="h-6 w-14 rounded bg-black/20 border border-white/10 px-1 text-[10px] text-gray-300 outline-none focus:border-cyan-500/50 disabled:opacity-40" /> @@ -481,8 +714,25 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void disabled={!currentProject?.id || !currentFrame?.id || isSaving || isExporting || isImportingGt || isPropagating} className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed" > - {isPropagating ? '传播中...' : '自动传播'} + {isPropagating ? '传播中...' : isPropagationRangeSelecting ? '开始传播' : '自动传播'} + {isPropagationRangeSelecting && ( + + )} + {propagationTaskId && ( + + )}
{/* Bottom Timeline */} - +
); } diff --git a/src/lib/api.test.ts b/src/lib/api.test.ts index b2946db..0ea32c1 100644 --- a/src/lib/api.test.ts +++ b/src/lib/api.test.ts @@ -53,10 +53,12 @@ describe('api client contracts', () => { name: 'Demo', status: 'ready', frames: 12, - fps: '30FPS', + fps: '10FPS', thumbnail_url: 'thumb', video_path: 'uploads/demo.mp4', source_type: 'video', + original_fps: 29.97, + parse_fps: 10, createdAt: 'created', updatedAt: 'updated', }), @@ -184,7 +186,7 @@ describe('api client contracts', () => { }); it('lists, saves, updates, and deletes annotations with the backend annotation contract', async () => { - const { deleteAnnotation, getProjectAnnotations, propagateMasks, saveAnnotation, updateAnnotation } = await import('./api'); + const { deleteAnnotation, getProjectAnnotations, propagateMasks, queuePropagationTask, saveAnnotation, updateAnnotation } = await import('./api'); const saved = { id: 1, project_id: 9, @@ -267,6 +269,48 @@ describe('api client contracts', () => { }, { timeout: 600000, }); + + axiosMock.client.post.mockResolvedValueOnce({ + data: { + id: 33, + task_type: 'propagate_masks', + status: 'queued', + progress: 0, + message: '自动传播任务已入队', + }, + }); + await expect(queuePropagationTask({ + project_id: 9, + frame_id: 5, + model: 'sam2.1_hiera_tiny', + steps: [{ + seed: { + polygons: [[[0, 0], [1, 0], [1, 1]]], + label: 'mask', + color: '#06b6d4', + }, + direction: 'forward', + max_frames: 30, + }], + include_source: false, + save_annotations: true, + })).resolves.toEqual(expect.objectContaining({ id: 33, task_type: 'propagate_masks' })); + expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/propagate/task', { + project_id: 9, + frame_id: 5, + model: 'sam2.1_hiera_tiny', + steps: [{ + seed: { + polygons: [[[0, 0], [1, 0], [1, 1]]], + label: 'mask', + color: '#06b6d4', + }, + direction: 'forward', + max_frames: 30, + }], + include_source: false, + save_annotations: true, + }); }); it('imports GT masks through multipart form data', async () => 
{ diff --git a/src/lib/api.ts b/src/lib/api.ts index 078cd8a..6713713 100644 --- a/src/lib/api.ts +++ b/src/lib/api.ts @@ -49,6 +49,12 @@ function normalizeProjectStatus(status?: string): Project['status'] { return 'pending'; } +function formatProjectFps(value?: number | null): string { + if (!value || value <= 0) return '30FPS'; + const rounded = Math.round(value * 10) / 10; + return `${Number.isInteger(rounded) ? rounded.toFixed(0) : rounded.toFixed(1)}FPS`; +} + function mapProject(p: any): Project { return { id: String(p.id), @@ -56,7 +62,7 @@ function mapProject(p: any): Project { description: p.description, status: normalizeProjectStatus(p.status), frames: p.frame_count ?? 0, - fps: p.original_fps ? `${Math.round(p.original_fps)}FPS` : '30FPS', + fps: formatProjectFps(p.parse_fps ?? p.original_fps), thumbnail_url: p.thumbnail_url, video_path: p.video_path, source_type: p.source_type, @@ -346,6 +352,8 @@ export interface PropagateMasksPayload { category?: string; }; template_id?: number; + source_mask_id?: string; + source_annotation_id?: number; }; direction?: 'forward' | 'backward' | 'both'; max_frames?: number; @@ -353,6 +361,19 @@ export interface PropagateMasksPayload { save_annotations?: boolean; } +export interface PropagateTaskPayload { + project_id: number; + frame_id: number; + model?: AiModelId; + steps: Array<{ + seed: PropagateMasksPayload['seed']; + direction: 'forward' | 'backward'; + max_frames: number; + }>; + include_source?: boolean; + save_annotations?: boolean; +} + export interface PropagateMasksResult { model: AiModelId; direction: string; @@ -652,6 +673,11 @@ export async function propagateMasks(payload: PropagateMasksPayload): Promise { + const response = await apiClient.post('/api/ai/propagate/task', payload); + return response.data; +} + export async function saveAnnotation(payload: SaveAnnotationPayload): Promise { const response = await apiClient.post('/api/ai/annotate', payload); return response.data; diff --git 
a/src/test/setup.tsx b/src/test/setup.tsx index a052e03..c34b3ad 100644 --- a/src/test/setup.tsx +++ b/src/test/setup.tsx @@ -32,7 +32,7 @@ function makeStageEvent(x = 120, y = 80) { } vi.mock('react-konva', () => ({ - Stage: ({ children, onClick, onMouseDown, onMouseUp, onMouseMove, onWheel }: any) => { + Stage: ({ children, onClick, onMouseDown, onMouseUp, onMouseMove, onWheel, onDragEnd, scaleX, scaleY, x, y, width, height }: any) => { const coords = (event: React.MouseEvent, fallbackX: number, fallbackY: number) => ({ x: event.clientX || fallbackX, y: event.clientY || fallbackY, @@ -40,6 +40,13 @@ vi.mock('react-konva', () => ({ return (
{ const point = coords(event, 120, 80); onClick?.(makeStageEvent(point.x, point.y)); @@ -57,6 +64,12 @@ vi.mock('react-konva', () => ({ onMouseMove?.(makeStageEvent(point.x, point.y)); }} onWheel={() => onWheel?.(makeStageEvent())} + onDragEnd={(event) => onDragEnd?.({ + target: { + x: () => event.clientX || 0, + y: () => event.clientY || 0, + }, + })} > {children}