feat: 完善分割工作区传播与交互闭环
功能增加:新增后端传播任务执行器,支持异步自动传播、传播进度、结果统计、取消/重试状态同步。 功能增加:传播请求支持指定 SAM2.1 tiny/small/base+/large 权重,并记录 seed mask、source annotation 和传播范围。 功能增加:传播逻辑增加 seed 签名,未变化的 mask 二次传播会跳过,已变化的 mask 会先清理旧自动传播结果再重新生成,避免重复重叠。 功能增加:工作区增加传播范围二次选择、传播进度提示、人工/AI 标注帧红色标识、自动传播帧蓝色标识和当前帧双层边框。 功能增加:新增临时提示组件,让工具操作提示自动消失且不阻塞后续操作。 功能增加:补充项目删除、模板删除、任务失败详情、任务取消/重试等前后端联动状态。 功能增加:新增安装部署文档,补充当前需求冻结、设计冻结、接口契约、测试计划和 AGENTS/README 项目说明。 Bugfix:修复自动传播接口 404、传播后看不到任务进度、传播结果重复堆叠和已编辑帧提示不清晰的问题。 Bugfix:修复 AI 分割框选/点选交互、单候选 mask、删除选点、工作区保存与候选 mask 推送相关问题。 Bugfix:修复 Canvas 多边形顶点拖动告警、工具栏提示缺失、项目库 FPS 展示和若干 UI 文案/可用性问题。 测试:补充 AI 分割、Canvas、Dashboard、FrameTimeline、ProjectLibrary、TemplateRegistry、ToolsPalette、VideoWorkspace、API 和后端任务/AI/dashboard 测试。 验证:npm run lint;npm run test:run;python -m pytest backend/tests -q。
This commit is contained in:
19
AGENTS.md
19
AGENTS.md
@@ -78,10 +78,11 @@ Seg_Server/
|
||||
│ │ ├── projects.py # /api/projects 与 /api/projects/{id}/frames
|
||||
│ │ ├── templates.py # /api/templates
|
||||
│ │ ├── media.py # /api/media/upload、/upload/dicom、/parse
|
||||
│ │ ├── ai.py # /api/ai/predict、/propagate、/models/status、/auto、/annotate
|
||||
│ │ ├── ai.py # /api/ai/predict、/propagate、/propagate/task、/models/status、/auto、/annotate
|
||||
│ │ └── export.py # /api/export/{project_id}/coco、/masks
|
||||
│ └── services/
|
||||
│ ├── frame_parser.py # FFmpeg/OpenCV 拆帧、pydicom 读片、帧上传
|
||||
│ ├── propagation_task_runner.py # Celery 自动传播任务 runner
|
||||
│ ├── sam2_engine.py # SAM 2.1 变体选择、单帧推理和 video predictor 传播封装
|
||||
│ ├── sam3_engine.py # 历史保留的 SAM 3 桥接实现;当前未接入 registry
|
||||
│ ├── sam3_external_worker.py # 历史保留的独立 sam3 helper;当前未被产品入口调用
|
||||
@@ -105,6 +106,7 @@ Seg_Server/
|
||||
├── OntologyInspector.tsx
|
||||
├── FrameTimeline.tsx
|
||||
├── AISegmentation.tsx
|
||||
├── TransientNotice.tsx
|
||||
└── TemplateRegistry.tsx
|
||||
```
|
||||
|
||||
@@ -115,6 +117,7 @@ Seg_Server/
|
||||
- `doc/03-frontend-element-audit.md`:哪些前端元素是真功能,哪些是 Mock/UI-only。
|
||||
- `doc/04-api-contracts.md`:前后端接口契约,以及当前不一致点。
|
||||
- `doc/05-implementation-plan.md`:建议的后续实施顺序。
|
||||
- `doc/10-installation.md`:完整安装部署流程,覆盖 PostgreSQL、Redis、MinIO、FastAPI、Celery、前端和 SAM 2.1 权重。
|
||||
|
||||
---
|
||||
|
||||
@@ -195,6 +198,7 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
- `POST /api/tasks/{task_id}/retry`
|
||||
- `POST /api/ai/predict`
|
||||
- `POST /api/ai/propagate`
|
||||
- `POST /api/ai/propagate/task`
|
||||
- `GET /api/ai/models/status`
|
||||
- `POST /api/ai/auto`
|
||||
- `POST /api/ai/annotate`
|
||||
@@ -220,12 +224,12 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
1. 登录:`Login.tsx` 调用 `POST /api/auth/login`,默认开发凭证为 `admin / 123456`。
|
||||
2. 项目管理:`ProjectLibrary.tsx` 调用项目 API 创建项目、拉取列表、删除项目;删除当前项目后会清空工作区当前项目、帧、mask 和选区。
|
||||
3. 上传资源:视频走 `/api/media/upload`,只上传源文件并关联项目,不自动拆帧;DICOM 批量走 `/api/media/upload/dicom`。
|
||||
4. 生成帧入队:用户在项目库点击“生成帧”,选择目标 FPS 后前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery,接口支持 `parse_fps`、`max_frames` 和 `target_width` 标准帧序列参数。
|
||||
4. 生成帧入队:用户在项目库点击“生成帧”,选择目标 FPS 后前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery,接口支持 `parse_fps`、`max_frames` 和 `target_width` 标准帧序列参数;项目库和模板库的成功/失败短反馈使用非阻塞 `TransientNotice`,会自动消失。
|
||||
5. worker 执行:Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallback,DICOM 使用 pydicom;视频帧按 `frame_%06d.jpg` 连续命名并记录 `timestamp_ms`、`source_frame_number` 和任务 `frame_sequence` 元数据。
|
||||
6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames`,`CanvasArea.tsx` 和 `FrameTimeline.tsx` 显示当前帧与时间轴缩略图;`FrameTimeline` 会根据已保存标注回显到 `Mask.metadata` 的传播来源,把自动传播生成的帧在顶部进度条对应区段标为浅蓝色,当前帧位置由播放进度条末端、时间提示和缩略图高亮表达;前端 `Frame` 会保留后端返回的帧序列时间戳和源帧号。
|
||||
7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask;多边形可按 Enter 或点击首节点闭合;绘制工具可在已有 mask 上继续落点;工具栏有“调整多边形”入口,点击 mask 后可按住顶点直接拖动并实时更新 polygon,也可删除 polygon 顶点、通过边中点或双击边界插入新顶点,并能选择编辑多 polygon mask 的单个子区域;选中整块 mask 可用 Delete/Backspace 删除,已保存 mask 会同步后端删除;区域合并/去除会隐藏编辑手柄并显示已选数量,第一个选中的主区域用黄色实线轮廓,后续参与合并/扣除的区域用红色虚线轮廓,使用 `polygon-clipping` 做 union/difference,内含去除结果用 even-odd 规则渲染 hole;Zustand 维护 `maskHistory/maskFuture` 支持撤销/重做。
|
||||
8. AI 分割:前端工具包括 SAM 2.1 变体选择、正向点、反向点和框选;工作区和 AI 页面都可点击已有提示点删除单点,AI 页面也可删除最近锚点、删除选中候选或清空本页锚点;这些删除入口会限制在当前提示点/本页 AI 候选范围内,避免误删工作区已有 mask。SAM 2.1 框选会建立候选 mask,后续正/反点通过 `interactive` prompt 携带原始框和累计点细化同一个候选 mask;AI 页面框选会先固化 `promptBox`,执行分割时只框选发送 `box` prompt,框选后继续加正/反点发送 `interactive` prompt;重复执行高精度分割会替换上一次 AI 页候选,只保留最新一个候选。包含反向点时工作区会传 `options.auto_filter_background=true` 和 `min_score=0.05`,如果后端过滤为空则移除旧候选 mask。后端 `ai.py` 期望按 `image_id`、`prompt_type`、`prompt_data`、`model` 和可选 `options` 调用 SAM registry。当前 registry 暴露 `sam2.1_hiera_tiny`、`sam2.1_hiera_small`、`sam2.1_hiera_base_plus`、`sam2.1_hiera_large`,并兼容 `sam2` 作为 tiny 别名;`model=sam3` 会被拒绝,`semantic` 文本提示也被禁用。SAM 2.1 支持点/框/interactive/自动分割和 video predictor 传播;多候选默认只采用最高分区域,避免重叠候选同时显示;AI 页面只渲染本页最新生成的候选 mask,不会把工作区已有 mask 带入 AI 画布;AI 页面生成的 mask 会写入全局 `masks` 并自动选中,右侧分类树可直接改标签,推送到工作区会切到“调整多边形”并保留选择。`options.crop_to_prompt` 可对点/框/interactive prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果。
|
||||
9. 视频片段传播:工作区以当前打开帧作为参考帧,使用该帧全部 mask 作为 seed,并用传播起始帧和传播结束帧指定追踪范围;前端只保留一个“自动传播”按钮,会按 seed mask 和前/后方向顺序调用单 seed `POST /api/ai/propagate`,避免多个视频 tracker 并发抢占 GPU;后端按项目帧序列下载片段帧,当前使用所选 SAM 2.1 变体的 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`,并把后续帧结果保存为 `Annotation`。
|
||||
6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames`,`CanvasArea.tsx` 和 `FrameTimeline.tsx` 显示当前帧与时间轴缩略图;`CanvasArea` 会按容器和帧尺寸默认居中放大底图并保留边距;`FrameTimeline` 会根据已保存标注回显到 `Mask.metadata` 的传播来源,把自动传播生成的帧在视频处理进度条显示为蓝色区段,人工/AI 标注帧显示红色竖线;视频处理进度条和红/蓝标识可点击跳转到对应帧;底部缩略图中人工/AI 标注帧用红色边框、自动传播/推理帧用蓝色边框,当前帧仍以青色外框高亮优先;若当前帧同时是人工/AI 标注帧,则在青色外框内增加红色内描边;只有进入自动传播范围选择模式时,播放进度条和视频处理进度条才显示黄色范围框,并可点击/拖拽选择传播起止帧;前端 `Frame` 会保留后端返回的帧序列时间戳和源帧号。
|
||||
7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask;多边形可按 Enter 或点击首节点闭合;绘制工具可在已有 mask 上继续落点;工具栏有“调整多边形”入口,左侧 `ToolsPalette` 使用紧凑垂直布局并在高度不足时自身滚动;点击 mask 后可按住顶点直接拖动并实时更新 polygon,也可删除 polygon 顶点、通过边中点或双击边界插入新顶点,并能选择编辑多 polygon mask 的单个子区域;选中整块 mask 可用 Delete/Backspace 删除,已保存 mask 会同步后端删除;区域合并/去除会隐藏编辑手柄并显示已选数量,第一个选中的主区域用黄色实线轮廓,后续参与合并/扣除的区域用红色虚线轮廓,使用 `polygon-clipping` 做 union/difference,内含去除结果用 even-odd 规则渲染 hole;Zustand 维护 `maskHistory/maskFuture` 支持撤销/重做。
|
||||
8. AI 分割:前端工具包括 SAM 2.1 变体选择、正向点、反向点和框选;AI 画布会按容器和当前帧尺寸默认居中放大底图并保留边距;工作区和 AI 页面都可点击已有提示点删除单点,AI 页面也可删除最近锚点、删除选中候选或清空本页锚点;这些删除入口会限制在当前提示点/本页 AI 候选范围内,避免误删工作区已有 mask。SAM 2.1 框选会建立候选 mask,后续正/反点通过 `interactive` prompt 携带原始框和累计点细化同一个候选 mask;AI 页面框选会先固化 `promptBox`,执行分割时只框选发送 `box` prompt,框选后继续加正/反点发送 `interactive` prompt;重复执行高精度分割会替换上一次 AI 页候选,只保留最新一个候选。包含反向点时工作区会传 `options.auto_filter_background=true` 和 `min_score=0.05`,如果后端过滤为空则移除旧候选 mask。后端 `ai.py` 期望按 `image_id`、`prompt_type`、`prompt_data`、`model` 和可选 `options` 调用 SAM registry。当前 registry 暴露 `sam2.1_hiera_tiny`、`sam2.1_hiera_small`、`sam2.1_hiera_base_plus`、`sam2.1_hiera_large`,并兼容 `sam2` 作为 tiny 别名;`model=sam3` 会被拒绝,`semantic` 文本提示也被禁用。SAM 2.1 支持点/框/interactive/自动分割和 video predictor 传播;多候选默认只采用最高分区域,避免重叠候选同时显示;AI 页面只渲染本页最新生成的候选 mask,不会把工作区已有 mask 带入 AI 画布;AI 页面生成的 mask 会写入全局 `masks` 并自动选中,右侧分类树可直接改标签,推送到工作区会切到“调整多边形”并保留选择。`options.crop_to_prompt` 可对点/框/interactive prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果。
|
||||
9. 视频片段传播:工作区以当前打开帧作为参考帧,使用该帧全部 mask 作为 seed,并用传播起始帧和传播结束帧指定追踪范围;用户可直接修改数字框,也可点击“自动传播”进入时间轴范围选择模式,在播放进度条或视频处理进度条上点击/拖拽选择范围,再点击“开始传播”。工作区顶栏有独立“传播权重”选择器,可为本次传播二次选择 SAM 2.1 tiny/small/base+/large 权重,不提供 SAM2/SAM3 家族切换,也不影响 AI 单帧分割权重;前端会按传播权重 id、seed mask、seed 来源 id 和前/后方向组装 `steps` 并调用 `POST /api/ai/propagate/task` 创建 `propagate_masks` 后台任务;后端入队时会规范化/校验权重 id 并把规范化后的 id 写入任务 payload/result;Celery worker 顺序执行各 step,避免多个视频 tracker 并发抢占 GPU;每个 step 会根据 seed 来源 id、权重、方向和 seed 签名做幂等判断,未改变的 seed 直接跳过,已改变的 seed 会先删除同源旧自动传播标注再重传;后端按项目帧序列下载片段帧,当前使用所选 SAM 2.1 权重变体的 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`,并把后续帧结果保存为 `Annotation`。工作区轮询 `GET /api/tasks/{task_id}` 展示进度并刷新标注,Dashboard 也能显示/取消/重试传播任务。
|
||||
10. GT 导入:工作区“导入 GT Mask”调用 `/api/ai/import-gt-mask`;后端按非零像素值和连通域生成 polygon 标注,并用 distance transform 生成 seed point;前端回显 seed point,拖动后可归档更新。
|
||||
11. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index;`OntologyInspector.tsx` 在工作区显示当前模板分类树。
|
||||
12. 导出:后端支持 COCO JSON 和 PNG mask ZIP 导出;PNG ZIP 包含单标注 mask、按 zIndex 融合的语义 mask 和 `semantic_classes.json`。
|
||||
@@ -242,7 +246,8 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
- 前端 `importGtMask()` 已对齐后端 `/api/ai/import-gt-mask`;工作区“导入 GT Mask”会导入后端生成的多类别标注和 seed point 并回显。
|
||||
- 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;前端 `exportMasks()` 已对齐后端 `/api/export/{project_id}/masks`;工作区导出按钮会先保存当前待归档 mask。
|
||||
- 工作区“结构化归档保存”按钮已接入 `POST /api/ai/annotate` 和 `PATCH /api/ai/annotations/{id}`;加载工作区时会通过 `GET /api/ai/annotations` 回显已保存标注。
|
||||
- 工作区“自动传播”按钮已接入 `POST /api/ai/propagate`;当前启用所选 SAM 2.1 变体的视频 predictor,完成后刷新后端已保存标注。
|
||||
- 工作区“自动传播”按钮已接入 `POST /api/ai/propagate/task`;若用户尚未显式设置范围,第一次点击会进入时间轴范围选择模式,第二次点击“开始传播”才提交后台任务;当前启用所选 SAM 2.1 变体的视频 predictor 后台任务,运行中轮询任务进度,完成后刷新后端已保存标注;同步 `POST /api/ai/propagate` 仍作为单 seed 兼容接口保留。
|
||||
- 工作区顶栏短状态会自动消失;保存、导出、导入 GT、传播进行中和无帧项目提示会保留到状态变化。
|
||||
- 工作区“清空遮罩”会调用 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。
|
||||
- 项目状态已统一为 `pending`、`parsing`、`ready`、`error`;前端 `src/lib/api.ts` 会兼容归一化旧库中可能存在的 `Ready`、`Parsing`、`Error`。
|
||||
- 项目库的视频导入与生成帧是两个独立动作:导入视频只上传源文件,生成帧按钮才会带 `parse_fps` 调用 `/api/media/parse`;工作区不会再因“有视频但无帧”自动创建拆帧任务。
|
||||
|
||||
15
README.md
15
README.md
@@ -13,8 +13,8 @@
|
||||
## 核心功能
|
||||
|
||||
- **多媒体资产管理** — 支持视频(MP4/AVI/MOV)和 DICOM 医学影像上传;视频导入与生成帧分离,生成帧时选择目标 FPS,项目卡片可删除项目及其关联帧、标注和任务记录
|
||||
- **AI 智能分割引擎** — 当前产品入口启用 SAM 2.1 四个变体(tiny/small/base+/large)选择;支持点分割(point)、框分割(box)、交互式正/反点细化、提示点单点删除、AI 候选单独删除、自动分割(auto)和 video predictor 传播,前端默认只采用最高分候选避免重叠备选同时显示
|
||||
- **交互式画布标注** — 基于 Konva 的高性能 Canvas,支持缩放/平移/手工多边形/矩形/圆/点/线、polygon 顶点直接拖动/删除、边中点插点、双击边界插点、区域合并/去除、选点/框选、撤销/重做,实时渲染 Mask 遮罩
|
||||
- **AI 智能分割引擎** — 当前产品入口启用 SAM 2.1 四个变体(tiny/small/base+/large)选择;支持点分割(point)、框分割(box)、交互式正/反点细化、提示点单点删除、AI 候选单独删除、自动分割(auto)和 Celery 后台 video predictor 传播,前端默认只采用最高分候选避免重叠备选同时显示
|
||||
- **交互式画布标注** — 基于 Konva 的高性能 Canvas,工作区和 AI 画布会默认居中放大底图并保留边距;支持缩放/平移/手工多边形/矩形/圆/点/线、polygon 顶点直接拖动/删除、边中点插点、双击边界插点、区域合并/去除、选点/框选、撤销/重做,实时渲染 Mask 遮罩
|
||||
- **GT Mask 导入** — 工作区可导入 GT mask 图片,后端按非零像素值和连通域生成 polygon 标注并用 distance transform 写入 seed point;前端可回显和拖动 seed point
|
||||
- **本体字典管理** — 可配置的分类体系、颜色映射、图层优先级(z-index)
|
||||
- **项目工作区** — 项目创建、帧浏览、多图层标注、自动传播帧提示、进度追踪
|
||||
@@ -104,6 +104,7 @@ Seg_Server/
|
||||
│ │ ├── ai.py # SAM 推理与模型状态接口
|
||||
│ │ └── export.py # 数据导出
|
||||
│ └── services/ # 业务服务
|
||||
│ ├── propagation_task_runner.py # Celery 自动传播任务 runner
|
||||
│ ├── sam2_engine.py # SAM 2.1 变体选择、单帧推理 + video predictor 传播
|
||||
│ ├── sam3_engine.py # 历史保留的 SAM 3 桥接实现;当前未接入 registry
|
||||
│ ├── sam3_external_worker.py # 历史保留的独立 sam3 helper;当前未被产品入口调用
|
||||
@@ -125,10 +126,11 @@ Seg_Server/
|
||||
│ ├── ProjectLibrary.tsx # 项目库列表
|
||||
│ ├── VideoWorkspace.tsx # 核心分割工作区布局
|
||||
│ ├── CanvasArea.tsx # Konva 画布(缩放/平移/手工绘制/选点/Mask渲染)
|
||||
│ ├── ToolsPalette.tsx # 左侧工具栏
|
||||
│ ├── ToolsPalette.tsx # 左侧紧凑工具栏(高度不足时滚动)
|
||||
│ ├── OntologyInspector.tsx # 右侧本体/属性检查面板
|
||||
│ ├── FrameTimeline.tsx # 底部时间轴
|
||||
│ ├── AISegmentation.tsx # AI 智能分割引擎界面
|
||||
│ ├── TransientNotice.tsx # 非阻塞自动消失短提示
|
||||
│ └── TemplateRegistry.tsx # 模板库管理
|
||||
├── models/ # SAM 2 模型权重(.pt 文件)
|
||||
├── uploads/ # 临时上传目录
|
||||
@@ -154,6 +156,7 @@ Seg_Server/
|
||||
- `doc/03-frontend-element-audit.md` — 前端逐元素功能审计,标注真实可用、部分可用、Mock/UI-only、接口不通
|
||||
- `doc/04-api-contracts.md` — 前后端接口契约和已知不一致
|
||||
- `doc/06-fastapi-docs-explained.md` — `http://192.168.3.11:8000/docs` 的作用说明
|
||||
- `doc/10-installation.md` — 独立安装部署指南,覆盖 PostgreSQL、Redis、MinIO、FastAPI、Celery、前端和 SAM 2.1 权重
|
||||
|
||||
---
|
||||
|
||||
@@ -315,7 +318,7 @@ nohup uvicorn main:app --host 0.0.0.0 --port 8000 > /tmp/fastapi.log 2>&1 &
|
||||
- 测试 Redis 连接
|
||||
- 懒加载所选 SAM 2.1 模型;`GET /api/ai/models/status` 会返回 tiny/small/base+/large 和 GPU 的真实可用状态,`selected_model=sam3` 会返回不支持
|
||||
- `/api/ai/predict` 支持 AI 参数 `crop_to_prompt`、`auto_filter_background` 和 `min_score`,用于点/框 prompt 的局部裁剪推理、回映射和背景过滤
|
||||
- `/api/ai/propagate` 支持从当前帧 seed 区域向视频片段传播:当前使用所选 SAM 2.1 变体的 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`
|
||||
- `/api/ai/propagate/task` 支持从当前帧 seed 区域向视频片段创建后台传播任务:当前使用所选 SAM 2.1 变体的 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`;同步 `/api/ai/propagate` 仍作为单 seed 兼容接口保留
|
||||
|
||||
### 步骤 6.1: 启动 Celery Worker
|
||||
|
||||
@@ -329,7 +332,7 @@ celery -A celery_app:celery_app worker --loglevel=info --concurrency=1
|
||||
nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 &
|
||||
```
|
||||
|
||||
视频导入只创建项目并把源视频保存到 MinIO,不会自动拆帧;用户在项目库点击“生成帧”后,再选择目标 FPS 并调用 `POST /api/media/parse`。该接口只创建 `processing_tasks` 记录并把任务投递给 Celery;真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。接口支持 `parse_fps`、`max_frames` 和 `target_width`,用于生成后续 SAM 2 视频处理可复用的标准帧序列;视频帧按 `frame_%06d.jpg` 连续命名,帧表会记录 `timestamp_ms` 和 `source_frame_number`,任务完成结果会返回 `frame_sequence` 元数据。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道,FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 的任务进度区展示 queued/running/success/failed/cancelled 最近任务,处理中统计只计算 queued/running;WebSocket 状态由浏览器 `onopen/onclose/onerror` 驱动,客户端会定时发送 `ping` 心跳,服务端返回 `status` 确认连接。Dashboard 也可调用 `/api/tasks/{id}/cancel`、`/api/tasks/{id}/retry` 和 `/api/tasks/{id}` 完成任务取消、重试与失败详情查看。
|
||||
视频导入只创建项目并把源视频保存到 MinIO,不会自动拆帧;用户在项目库点击“生成帧”后,再选择目标 FPS 并调用 `POST /api/media/parse`。项目库和模板库的成功/失败反馈使用非阻塞短提示,会自动消失,不再用浏览器 `alert()` 阻塞后续操作。该接口只创建 `processing_tasks` 记录并把任务投递给 Celery;真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。接口支持 `parse_fps`、`max_frames` 和 `target_width`,用于生成后续 SAM 2 视频处理可复用的标准帧序列;视频帧按 `frame_%06d.jpg` 连续命名,帧表会记录 `timestamp_ms` 和 `source_frame_number`,任务完成结果会返回 `frame_sequence` 元数据。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道,FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 的任务进度区展示 queued/running/success/failed/cancelled 最近任务,处理中统计只计算 queued/running;WebSocket 状态由浏览器 `onopen/onclose/onerror` 驱动,客户端会定时发送 `ping` 心跳,服务端返回 `status` 确认连接。Dashboard 也可调用 `/api/tasks/{id}/cancel`、`/api/tasks/{id}/retry` 和 `/api/tasks/{id}` 完成任务取消、重试与失败详情查看。
|
||||
|
||||
### 步骤 7: 安装前端依赖并构建
|
||||
|
||||
@@ -467,7 +470,7 @@ pip install -e . --no-build-isolation
|
||||
- 工作区点选/框选会使用当前帧的数据库 `frame.id` 调用 `/api/ai/predict`。
|
||||
- 工作区 SAM 2.1 交互式细化包含反向点时会启用后端背景过滤;若反向点排除了当前候选区域并返回空结果,前端会移除旧候选 mask。
|
||||
- AI 页面只显示本页最新生成的 SAM 2.1 候选,不会把工作区已有 mask 带入 AI 画布;重复执行高精度分割会替换上一次 AI 页候选;新生成 mask 会写入全局 `masks` 并自动选中,右侧分类树可直接给生成结果换标签,“推送至工作区编辑”会切回工作区的多边形调整工具并保留选择。
|
||||
- 工作区传播功能会使用当前打开参考帧的全部 mask 作为 seed,按用户设置的传播起始帧和传播结束帧向前/向后追踪;前端只保留一个“自动传播”按钮,会按 seed 和方向顺序调用 `/api/ai/propagate`,并在完成后刷新已保存标注。传播结果回显后,时间进度条会把自动传播生成的帧区段标为浅蓝色。
|
||||
- 工作区传播功能会使用当前打开参考帧的全部 mask 作为 seed,按用户设置的传播起始帧和传播结束帧向前/向后追踪;用户可直接修改数字框,也可先点击“自动传播”进入时间轴范围选择模式,在播放进度条或视频处理进度条上点击/拖拽选择范围,再点击“开始传播”。工作区顶栏可单独选择本次传播使用的 SAM 2.1 tiny/small/base+/large 权重,不提供 SAM2/SAM3 家族切换;前端会把传播权重 id、seed、seed 来源 id 和方向组装为 `/api/ai/propagate/task` 后台任务。后端入队时会规范化/校验权重 id,并把规范化后的 id 写入任务 payload/result;worker 会按 seed 来源、权重、方向和 seed 签名去重,未改变的 mask 二次传播时直接跳过,已改变的 mask 会先删除同源旧自动传播标注再重传,避免同一个 mask 传播两次产生重叠。任务进度写入 `processing_tasks` 并可在 Dashboard 查看/取消/重试,工作区轮询任务状态并刷新已保存标注。传播结果回显后,视频处理进度条会把自动传播生成的帧区段标为蓝色,人工/AI 标注帧显示为红色竖线;普通状态下点击视频处理进度条或红/蓝帧标识可跳转到对应帧,底部缩略图也会用红色边框标识人工/AI 标注帧、蓝色边框标识传播/推理帧;当前帧如果同时是人工/AI 标注帧,会显示青色外框加红色内描边。
|
||||
- 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco`。
|
||||
- 工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮已绑定下载流程;导出前会先保存当前待归档的前端 mask。
|
||||
- 工作区“导入 GT Mask”按钮已绑定 `/api/ai/import-gt-mask`,导入后会刷新并回显已保存标注和 seed point。
|
||||
|
||||
@@ -12,7 +12,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import download_file
|
||||
from models import Project, Frame, Template, Annotation
|
||||
from models import Project, Frame, Template, Annotation, ProcessingTask
|
||||
from schemas import (
|
||||
AiRuntimeStatus,
|
||||
MaskAnalysisRequest,
|
||||
@@ -21,10 +21,15 @@ from schemas import (
|
||||
PredictResponse,
|
||||
PropagateRequest,
|
||||
PropagateResponse,
|
||||
PropagateTaskRequest,
|
||||
ProcessingTaskOut,
|
||||
AnnotationOut,
|
||||
AnnotationCreate,
|
||||
AnnotationUpdate,
|
||||
)
|
||||
from progress_events import publish_task_progress_event
|
||||
from statuses import TASK_STATUS_QUEUED
|
||||
from worker_tasks import propagate_project_masks
|
||||
from services.sam_registry import ModelUnavailableError, sam_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -586,6 +591,66 @@ def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/propagate/task",
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
response_model=ProcessingTaskOut,
|
||||
summary="Queue a background video propagation task",
|
||||
)
|
||||
def queue_propagate_task(payload: PropagateTaskRequest, db: Session = Depends(get_db)) -> ProcessingTaskOut:
|
||||
"""Queue multiple seed/direction propagation steps as one background task."""
|
||||
project = db.query(Project).filter(Project.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
source_frame = db.query(Frame).filter(
|
||||
Frame.id == payload.frame_id,
|
||||
Frame.project_id == payload.project_id,
|
||||
).first()
|
||||
if not source_frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
|
||||
if not payload.steps:
|
||||
raise HTTPException(status_code=400, detail="Propagation task requires at least one step")
|
||||
|
||||
try:
|
||||
model_id = sam_registry.normalize_model_id(payload.model)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
for step in payload.steps:
|
||||
direction = step.direction.lower()
|
||||
if direction not in {"forward", "backward"}:
|
||||
raise HTTPException(status_code=400, detail="direction must be forward or backward")
|
||||
seed = step.seed.model_dump(exclude_none=True)
|
||||
if not (seed.get("polygons") or seed.get("bbox") or seed.get("points")):
|
||||
raise HTTPException(status_code=400, detail="Propagation requires seed polygons, bbox, or points")
|
||||
|
||||
task_payload = payload.model_dump(exclude_none=True)
|
||||
task_payload["model"] = model_id
|
||||
task = ProcessingTask(
|
||||
task_type="propagate_masks",
|
||||
status=TASK_STATUS_QUEUED,
|
||||
progress=0,
|
||||
message="自动传播任务已入队",
|
||||
project_id=payload.project_id,
|
||||
payload=task_payload,
|
||||
)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
|
||||
async_result = propagate_project_masks.delay(task.id)
|
||||
task.celery_task_id = async_result.id
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
|
||||
logger.info("Queued propagation task id=%s project_id=%s celery_id=%s", task.id, payload.project_id, async_result.id)
|
||||
return task
|
||||
|
||||
|
||||
@router.post(
|
||||
"/auto",
|
||||
response_model=PredictResponse,
|
||||
|
||||
@@ -36,6 +36,7 @@ def _iso_or_none(value: datetime | None) -> str | None:
|
||||
|
||||
|
||||
def _task_payload(task: ProcessingTask) -> dict[str, Any]:
|
||||
result = task.result or {}
|
||||
return {
|
||||
"id": f"task-{task.id}",
|
||||
"task_id": task.id,
|
||||
@@ -44,7 +45,7 @@ def _task_payload(task: ProcessingTask) -> dict[str, Any]:
|
||||
"progress": task.progress,
|
||||
"status": task.message or task.status,
|
||||
"raw_status": task.status,
|
||||
"frame_count": (task.result or {}).get("frames_extracted", 0),
|
||||
"frame_count": result.get("frames_extracted", result.get("processed_frame_count", 0)),
|
||||
"error": task.error,
|
||||
"updated_at": _iso_or_none(task.updated_at),
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ from statuses import (
|
||||
TASK_STATUS_FAILED,
|
||||
TASK_STATUS_QUEUED,
|
||||
)
|
||||
from worker_tasks import parse_project_media
|
||||
from worker_tasks import parse_project_media, propagate_project_masks
|
||||
|
||||
router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -109,7 +109,8 @@ def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
project = db.query(Project).filter(Project.id == previous.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
if not project.video_path:
|
||||
is_propagation_task = previous.task_type == "propagate_masks"
|
||||
if not is_propagation_task and not project.video_path:
|
||||
raise HTTPException(status_code=400, detail="Project has no media uploaded")
|
||||
|
||||
payload = dict(previous.payload or {})
|
||||
@@ -124,13 +125,14 @@ def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
project_id=project.id,
|
||||
payload=payload,
|
||||
)
|
||||
project.status = PROJECT_STATUS_PARSING
|
||||
if not is_propagation_task:
|
||||
project.status = PROJECT_STATUS_PARSING
|
||||
db.add(task)
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
|
||||
async_result = parse_project_media.delay(task.id)
|
||||
async_result = propagate_project_masks.delay(task.id) if is_propagation_task else parse_project_media.delay(task.id)
|
||||
task.celery_task_id = async_result.id
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
|
||||
@@ -218,6 +218,8 @@ class PropagationSeed(BaseModel):
|
||||
color: Optional[str] = None
|
||||
class_metadata: Optional[dict[str, Any]] = None
|
||||
template_id: Optional[int] = None
|
||||
source_mask_id: Optional[str] = None
|
||||
source_annotation_id: Optional[int] = None
|
||||
|
||||
|
||||
class PropagateRequest(BaseModel):
|
||||
@@ -240,6 +242,21 @@ class PropagateResponse(BaseModel):
|
||||
annotations: list[AnnotationOut]
|
||||
|
||||
|
||||
class PropagateTaskStep(BaseModel):
|
||||
seed: PropagationSeed
|
||||
direction: str = "forward"
|
||||
max_frames: int = 30
|
||||
|
||||
|
||||
class PropagateTaskRequest(BaseModel):
|
||||
project_id: int
|
||||
frame_id: int
|
||||
model: Optional[str] = "sam2.1_hiera_tiny"
|
||||
steps: list[PropagateTaskStep]
|
||||
include_source: bool = False
|
||||
save_annotations: bool = True
|
||||
|
||||
|
||||
class AiModelStatus(BaseModel):
|
||||
id: str
|
||||
label: str
|
||||
|
||||
512
backend/services/propagation_task_runner.py
Normal file
512
backend/services/propagation_task_runner.py
Normal file
@@ -0,0 +1,512 @@
|
||||
"""Background SAM video propagation runner used by Celery workers."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from minio_client import download_file
|
||||
from models import Annotation, Frame, ProcessingTask, Project
|
||||
from progress_events import publish_task_progress_event
|
||||
from services.sam_registry import ModelUnavailableError, sam_registry
|
||||
from statuses import (
|
||||
TASK_STATUS_CANCELLED,
|
||||
TASK_STATUS_FAILED,
|
||||
TASK_STATUS_RUNNING,
|
||||
TASK_STATUS_SUCCESS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PropagationTaskCancelled(RuntimeError):
|
||||
"""Raised internally when a persisted propagation task has been cancelled."""
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _set_task_state(
|
||||
db: Session,
|
||||
task: ProcessingTask,
|
||||
*,
|
||||
status: str | None = None,
|
||||
progress: int | None = None,
|
||||
message: str | None = None,
|
||||
result: dict[str, Any] | None = None,
|
||||
error: str | None = None,
|
||||
started: bool = False,
|
||||
finished: bool = False,
|
||||
) -> None:
|
||||
if status is not None:
|
||||
task.status = status
|
||||
if progress is not None:
|
||||
task.progress = max(0, min(100, progress))
|
||||
if message is not None:
|
||||
task.message = message
|
||||
if result is not None:
|
||||
task.result = result
|
||||
if error is not None:
|
||||
task.error = error
|
||||
if started:
|
||||
task.started_at = _now()
|
||||
if finished:
|
||||
task.finished_at = _now()
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
|
||||
|
||||
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
|
||||
db.refresh(task)
|
||||
if task.status == TASK_STATUS_CANCELLED:
|
||||
raise PropagationTaskCancelled("Task was cancelled")
|
||||
|
||||
|
||||
def _clamp01(value: float) -> float:
|
||||
return min(max(float(value), 0.0), 1.0)
|
||||
|
||||
|
||||
def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
|
||||
xs = [_clamp01(point[0]) for point in polygon]
|
||||
ys = [_clamp01(point[1]) for point in polygon]
|
||||
left, right = min(xs), max(xs)
|
||||
top, bottom = min(ys), max(ys)
|
||||
return [left, top, max(right - left, 0.0), max(bottom - top, 0.0)]
|
||||
|
||||
|
||||
def _stable_json(value: Any) -> str:
|
||||
return json.dumps(value, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
|
||||
|
||||
|
||||
def _seed_signature(seed: dict[str, Any]) -> str:
|
||||
"""Return a stable signature for seed geometry and semantic attrs."""
|
||||
signature_payload = {
|
||||
"polygons": seed.get("polygons") or [],
|
||||
"bbox": seed.get("bbox") or [],
|
||||
"points": seed.get("points") or [],
|
||||
"labels": seed.get("labels") or [],
|
||||
"label": seed.get("label"),
|
||||
"color": seed.get("color"),
|
||||
"class_metadata": seed.get("class_metadata") or {},
|
||||
"template_id": seed.get("template_id"),
|
||||
}
|
||||
return hashlib.sha256(_stable_json(signature_payload).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _seed_key(seed: dict[str, Any]) -> str:
|
||||
"""Prefer stable persisted ids; fall back to semantic attrs for legacy callers."""
|
||||
source_annotation_id = seed.get("source_annotation_id")
|
||||
if source_annotation_id is not None:
|
||||
return f"annotation:{source_annotation_id}"
|
||||
source_mask_id = seed.get("source_mask_id")
|
||||
if source_mask_id:
|
||||
return f"mask:{source_mask_id}"
|
||||
class_metadata = seed.get("class_metadata") or {}
|
||||
class_id = class_metadata.get("id") or class_metadata.get("name")
|
||||
return _stable_json({
|
||||
"template_id": seed.get("template_id"),
|
||||
"class_id": class_id,
|
||||
"label": seed.get("label"),
|
||||
"color": seed.get("color"),
|
||||
})
|
||||
|
||||
|
||||
def _legacy_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
|
||||
"""Best-effort match for propagation annotations created before seed keys."""
|
||||
class_metadata = seed.get("class_metadata") or {}
|
||||
previous_class = mask_data.get("class") or {}
|
||||
previous_class_id = previous_class.get("id") or previous_class.get("name")
|
||||
class_id = class_metadata.get("id") or class_metadata.get("name")
|
||||
return (
|
||||
mask_data.get("label") == seed.get("label")
|
||||
and mask_data.get("color") == seed.get("color")
|
||||
and previous_class_id == class_id
|
||||
)
|
||||
|
||||
|
||||
def _is_propagation_annotation(
|
||||
annotation: Annotation,
|
||||
model_id: str,
|
||||
source_frame: Frame,
|
||||
seed_key: str,
|
||||
seed: dict[str, Any],
|
||||
) -> bool:
|
||||
mask_data = annotation.mask_data or {}
|
||||
source = str(mask_data.get("source") or "")
|
||||
if source != f"{model_id}_propagation":
|
||||
return False
|
||||
if int(mask_data.get("propagated_from_frame_id") or 0) != int(source_frame.id):
|
||||
return False
|
||||
previous_seed_key = mask_data.get("propagation_seed_key")
|
||||
if previous_seed_key is not None:
|
||||
return previous_seed_key == seed_key
|
||||
return _legacy_seed_matches(mask_data, seed)
|
||||
|
||||
|
||||
def _direction_matches(mask_data: dict[str, Any], direction: str) -> bool:
|
||||
previous_direction = mask_data.get("propagation_direction")
|
||||
return previous_direction in {None, direction}
|
||||
|
||||
|
||||
def _prepare_seed_propagation(
|
||||
db: Session,
|
||||
*,
|
||||
payload: dict[str, Any],
|
||||
model_id: str,
|
||||
source_frame: Frame,
|
||||
seed: dict[str, Any],
|
||||
direction: str,
|
||||
) -> dict[str, Any]:
|
||||
seed_key = _seed_key(seed)
|
||||
seed_signature = _seed_signature(seed)
|
||||
previous_annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == int(payload["project_id"]))
|
||||
.all()
|
||||
)
|
||||
matching = [
|
||||
annotation for annotation in previous_annotations
|
||||
if _is_propagation_annotation(annotation, model_id, source_frame, seed_key, seed)
|
||||
and _direction_matches(annotation.mask_data or {}, direction)
|
||||
]
|
||||
if matching and all((annotation.mask_data or {}).get("propagation_seed_signature") == seed_signature for annotation in matching):
|
||||
return {
|
||||
"skip": True,
|
||||
"seed_key": seed_key,
|
||||
"seed_signature": seed_signature,
|
||||
"deleted_annotation_count": 0,
|
||||
}
|
||||
|
||||
deleted_count = 0
|
||||
if matching:
|
||||
for annotation in matching:
|
||||
db.delete(annotation)
|
||||
deleted_count += 1
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"skip": False,
|
||||
"seed_key": seed_key,
|
||||
"seed_signature": seed_signature,
|
||||
"deleted_annotation_count": deleted_count,
|
||||
}
|
||||
|
||||
|
||||
def _frame_window(
|
||||
frames: list[Frame],
|
||||
source_position: int,
|
||||
direction: str,
|
||||
max_frames: int,
|
||||
) -> tuple[list[Frame], int]:
|
||||
count = max(1, min(max_frames, len(frames)))
|
||||
if direction == "backward":
|
||||
start = max(0, source_position - count + 1)
|
||||
return frames[start:source_position + 1], source_position - start
|
||||
end = min(len(frames), source_position + count)
|
||||
return frames[source_position:end], 0
|
||||
|
||||
|
||||
def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]:
|
||||
paths = []
|
||||
for index, frame in enumerate(frames):
|
||||
data = download_file(frame.image_url)
|
||||
# SAM2VideoPredictor sorts frames by converting the filename stem to int.
|
||||
path = directory / f"{index:06d}.jpg"
|
||||
path.write_bytes(data)
|
||||
paths.append(str(path))
|
||||
return paths
|
||||
|
||||
|
||||
def _save_propagated_annotations(
|
||||
db: Session,
|
||||
*,
|
||||
payload: dict[str, Any],
|
||||
selected_frames: list[Frame],
|
||||
source_frame: Frame,
|
||||
propagated: list[dict[str, Any]],
|
||||
seed: dict[str, Any],
|
||||
) -> list[Annotation]:
|
||||
created: list[Annotation] = []
|
||||
if payload.get("save_annotations", True) is False:
|
||||
return created
|
||||
|
||||
class_metadata = seed.get("class_metadata")
|
||||
template_id = seed.get("template_id")
|
||||
label = seed.get("label") or "Propagated Mask"
|
||||
color = seed.get("color") or "#06b6d4"
|
||||
model_id = sam_registry.normalize_model_id(payload.get("model"))
|
||||
include_source = bool(payload.get("include_source", False))
|
||||
seed_key = _seed_key(seed)
|
||||
seed_signature = _seed_signature(seed)
|
||||
source_annotation_id = seed.get("source_annotation_id")
|
||||
source_mask_id = seed.get("source_mask_id")
|
||||
direction = str(payload.get("current_direction") or "")
|
||||
|
||||
for frame_result in propagated:
|
||||
relative_index = int(frame_result.get("frame_index", -1))
|
||||
if relative_index < 0 or relative_index >= len(selected_frames):
|
||||
continue
|
||||
frame = selected_frames[relative_index]
|
||||
if not include_source and frame.id == source_frame.id:
|
||||
continue
|
||||
result_polygons = frame_result.get("polygons") or []
|
||||
scores = frame_result.get("scores") or []
|
||||
for polygon_index, polygon in enumerate(result_polygons):
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
annotation = Annotation(
|
||||
project_id=int(payload["project_id"]),
|
||||
frame_id=frame.id,
|
||||
template_id=template_id,
|
||||
mask_data={
|
||||
"polygons": [polygon],
|
||||
"label": label,
|
||||
"color": color,
|
||||
"source": f"{model_id}_propagation",
|
||||
"propagated_from_frame_id": source_frame.id,
|
||||
"propagated_from_frame_index": source_frame.frame_index,
|
||||
"propagation_seed_key": seed_key,
|
||||
"propagation_seed_signature": seed_signature,
|
||||
"propagation_direction": direction,
|
||||
"source_annotation_id": source_annotation_id,
|
||||
"source_mask_id": source_mask_id,
|
||||
"score": scores[polygon_index] if polygon_index < len(scores) else None,
|
||||
**({"class": class_metadata} if class_metadata else {}),
|
||||
},
|
||||
points=None,
|
||||
bbox=_polygon_bbox(polygon),
|
||||
)
|
||||
db.add(annotation)
|
||||
created.append(annotation)
|
||||
|
||||
db.commit()
|
||||
for annotation in created:
|
||||
db.refresh(annotation)
|
||||
return created
|
||||
|
||||
|
||||
def _run_one_step(
|
||||
db: Session,
|
||||
*,
|
||||
payload: dict[str, Any],
|
||||
frames: list[Frame],
|
||||
source_frame: Frame,
|
||||
source_position: int,
|
||||
step: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
direction = str(step.get("direction") or "forward").lower()
|
||||
if direction not in {"forward", "backward"}:
|
||||
raise ValueError("direction must be forward or backward")
|
||||
max_frames = max(1, min(int(step.get("max_frames") or payload.get("max_frames") or 30), 500))
|
||||
seed = step.get("seed") or {}
|
||||
if not (seed.get("polygons") or seed.get("bbox") or seed.get("points")):
|
||||
raise ValueError("Propagation requires seed polygons, bbox, or points")
|
||||
|
||||
model_id = sam_registry.normalize_model_id(payload.get("model"))
|
||||
seed_state = _prepare_seed_propagation(
|
||||
db,
|
||||
payload=payload,
|
||||
model_id=model_id,
|
||||
source_frame=source_frame,
|
||||
seed=seed,
|
||||
direction=direction,
|
||||
)
|
||||
if seed_state["skip"]:
|
||||
return {
|
||||
"model": model_id,
|
||||
"direction": direction,
|
||||
"processed_frame_count": 0,
|
||||
"created_annotation_count": 0,
|
||||
"deleted_annotation_count": 0,
|
||||
"skipped_seed_count": 1,
|
||||
"seed_label": seed.get("label"),
|
||||
"seed_key": seed_state["seed_key"],
|
||||
}
|
||||
|
||||
selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames)
|
||||
with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload['project_id']}_") as tmpdir:
|
||||
frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir))
|
||||
propagated = sam_registry.propagate_video(
|
||||
model_id,
|
||||
frame_paths,
|
||||
source_relative_index,
|
||||
seed,
|
||||
direction,
|
||||
len(selected_frames),
|
||||
)
|
||||
|
||||
save_payload = {**payload, "current_direction": direction}
|
||||
created = _save_propagated_annotations(
|
||||
db,
|
||||
payload=save_payload,
|
||||
selected_frames=selected_frames,
|
||||
source_frame=source_frame,
|
||||
propagated=propagated,
|
||||
seed=seed,
|
||||
)
|
||||
return {
|
||||
"model": model_id,
|
||||
"direction": direction,
|
||||
"processed_frame_count": len(selected_frames),
|
||||
"created_annotation_count": len(created),
|
||||
"deleted_annotation_count": int(seed_state["deleted_annotation_count"]),
|
||||
"skipped_seed_count": 0,
|
||||
"seed_label": seed.get("label"),
|
||||
"seed_key": seed_state["seed_key"],
|
||||
}
|
||||
|
||||
|
||||
def run_propagate_project_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
"""Run one queued SAM propagation task and update persisted progress."""
|
||||
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
|
||||
if not task:
|
||||
raise ValueError(f"Task not found: {task_id}")
|
||||
|
||||
if task.status == TASK_STATUS_CANCELLED:
|
||||
return {"task_id": task.id, "status": TASK_STATUS_CANCELLED, "message": task.message or "任务已取消"}
|
||||
|
||||
payload = task.payload or {}
|
||||
project_id = int(payload.get("project_id") or task.project_id or 0)
|
||||
source_frame_id = int(payload.get("frame_id") or 0)
|
||||
try:
|
||||
model_id = sam_registry.normalize_model_id(payload.get("model"))
|
||||
except ValueError as exc:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
|
||||
raise
|
||||
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="项目不存在", error="Project not found", finished=True)
|
||||
raise ValueError(f"Project not found: {project_id}")
|
||||
|
||||
source_frame = db.query(Frame).filter(Frame.id == source_frame_id, Frame.project_id == project_id).first()
|
||||
if not source_frame:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="参考帧不存在", error="Frame not found", finished=True)
|
||||
raise ValueError(f"Frame not found: {source_frame_id}")
|
||||
|
||||
frames = db.query(Frame).filter(Frame.project_id == project_id).order_by(Frame.frame_index).all()
|
||||
source_position = next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None)
|
||||
if source_position is None:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="参考帧不在项目帧序列中", error="Source frame is not in project frame sequence", finished=True)
|
||||
raise ValueError("Source frame is not in project frame sequence")
|
||||
|
||||
steps = payload.get("steps") or []
|
||||
if not steps:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="传播任务缺少步骤", error="Propagation task has no steps", finished=True)
|
||||
raise ValueError("Propagation task has no steps")
|
||||
|
||||
_ensure_not_cancelled(db, task)
|
||||
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="自动传播任务已启动", started=True)
|
||||
|
||||
step_results: list[dict[str, Any]] = []
|
||||
created_count = 0
|
||||
processed_count = 0
|
||||
deleted_count = 0
|
||||
skipped_count = 0
|
||||
total_steps = len(steps)
|
||||
|
||||
try:
|
||||
for index, step in enumerate(steps, start=1):
|
||||
_ensure_not_cancelled(db, task)
|
||||
seed_label = (step.get("seed") or {}).get("label") or "mask"
|
||||
direction_label = "向前传播" if step.get("direction") == "backward" else "向后传播"
|
||||
progress_before = 5 + int(((index - 1) / total_steps) * 90)
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
progress=progress_before,
|
||||
message=f"{direction_label} {seed_label} ({index}/{total_steps})",
|
||||
result={
|
||||
"project_id": project_id,
|
||||
"source_frame_id": source_frame_id,
|
||||
"model": model_id,
|
||||
"total_steps": total_steps,
|
||||
"completed_steps": index - 1,
|
||||
"processed_frame_count": processed_count,
|
||||
"created_annotation_count": created_count,
|
||||
"deleted_annotation_count": deleted_count,
|
||||
"skipped_seed_count": skipped_count,
|
||||
"steps": step_results,
|
||||
},
|
||||
)
|
||||
|
||||
result = _run_one_step(
|
||||
db,
|
||||
payload=payload,
|
||||
frames=frames,
|
||||
source_frame=source_frame,
|
||||
source_position=source_position,
|
||||
step=step,
|
||||
)
|
||||
step_results.append(result)
|
||||
created_count += int(result["created_annotation_count"])
|
||||
processed_count += int(result["processed_frame_count"])
|
||||
deleted_count += int(result.get("deleted_annotation_count") or 0)
|
||||
skipped_count += int(result.get("skipped_seed_count") or 0)
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
progress=5 + int((index / total_steps) * 90),
|
||||
message=f"{direction_label} {seed_label} 完成 ({index}/{total_steps})",
|
||||
result={
|
||||
"project_id": project_id,
|
||||
"source_frame_id": source_frame_id,
|
||||
"model": model_id,
|
||||
"total_steps": total_steps,
|
||||
"completed_steps": index,
|
||||
"processed_frame_count": processed_count,
|
||||
"created_annotation_count": created_count,
|
||||
"deleted_annotation_count": deleted_count,
|
||||
"skipped_seed_count": skipped_count,
|
||||
"steps": step_results,
|
||||
},
|
||||
)
|
||||
|
||||
result = {
|
||||
"project_id": project_id,
|
||||
"source_frame_id": source_frame_id,
|
||||
"model": model_id,
|
||||
"total_steps": total_steps,
|
||||
"completed_steps": total_steps,
|
||||
"processed_frame_count": processed_count,
|
||||
"created_annotation_count": created_count,
|
||||
"deleted_annotation_count": deleted_count,
|
||||
"skipped_seed_count": skipped_count,
|
||||
"steps": step_results,
|
||||
}
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
status=TASK_STATUS_SUCCESS,
|
||||
progress=100,
|
||||
message="自动传播完成" if created_count > 0 else (
|
||||
"自动传播完成,未改变的 mask 已跳过" if skipped_count > 0 else "自动传播完成,但没有生成新的 mask"
|
||||
),
|
||||
result=result,
|
||||
finished=True,
|
||||
)
|
||||
return result
|
||||
except PropagationTaskCancelled:
|
||||
task.status = TASK_STATUS_CANCELLED
|
||||
task.progress = 100
|
||||
task.message = task.message or "任务已取消"
|
||||
task.error = task.error or "Cancelled by user"
|
||||
task.finished_at = task.finished_at or _now()
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
return {"task_id": task.id, "project_id": project_id, "status": TASK_STATUS_CANCELLED, "message": task.message}
|
||||
except (ModelUnavailableError, NotImplementedError, ValueError) as exc:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
|
||||
raise
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception("Propagation task failed: task_id=%s", task.id)
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
|
||||
raise
|
||||
@@ -1,5 +1,8 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
from pathlib import Path
|
||||
from models import Annotation, ProcessingTask
|
||||
from services.propagation_task_runner import run_propagate_project_task
|
||||
|
||||
|
||||
def _create_project_and_frame(client):
|
||||
@@ -294,6 +297,245 @@ def test_propagate_saves_tracked_annotations(client, monkeypatch):
|
||||
assert len(listing.json()) == 1
|
||||
|
||||
|
||||
def test_queue_propagation_task_creates_processing_task(client, monkeypatch):
|
||||
project = client.post("/api/projects", json={"name": "Queued Propagation"}).json()
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/0.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
}).json()
|
||||
|
||||
class FakeAsyncResult:
|
||||
id = "celery-propagate-1"
|
||||
|
||||
queued = []
|
||||
monkeypatch.setattr("routers.ai.propagate_project_masks.delay", lambda task_id: queued.append(task_id) or FakeAsyncResult())
|
||||
monkeypatch.setattr("routers.ai.publish_task_progress_event", lambda task: None)
|
||||
|
||||
response = client.post("/api/ai/propagate/task", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"model": "sam2.1_hiera_tiny",
|
||||
"steps": [{
|
||||
"direction": "forward",
|
||||
"max_frames": 2,
|
||||
"seed": {
|
||||
"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
|
||||
"label": "胆囊",
|
||||
},
|
||||
}],
|
||||
})
|
||||
|
||||
assert response.status_code == 202
|
||||
body = response.json()
|
||||
assert body["task_type"] == "propagate_masks"
|
||||
assert body["status"] == "queued"
|
||||
assert body["celery_task_id"] == "celery-propagate-1"
|
||||
assert body["payload"]["model"] == "sam2.1_hiera_tiny"
|
||||
assert body["payload"]["steps"][0]["seed"]["label"] == "胆囊"
|
||||
assert queued == [body["id"]]
|
||||
|
||||
|
||||
def test_queue_propagation_task_normalizes_model_and_rejects_unsupported(client, monkeypatch):
|
||||
project = client.post("/api/projects", json={"name": "Propagation Model"}).json()
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/0.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
}).json()
|
||||
|
||||
class FakeAsyncResult:
|
||||
id = "celery-propagate-model"
|
||||
|
||||
monkeypatch.setattr("routers.ai.propagate_project_masks.delay", lambda task_id: FakeAsyncResult())
|
||||
monkeypatch.setattr("routers.ai.publish_task_progress_event", lambda task: None)
|
||||
|
||||
response = client.post("/api/ai/propagate/task", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"model": "sam2",
|
||||
"steps": [{
|
||||
"direction": "forward",
|
||||
"max_frames": 2,
|
||||
"seed": {
|
||||
"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
|
||||
},
|
||||
}],
|
||||
})
|
||||
|
||||
assert response.status_code == 202
|
||||
assert response.json()["payload"]["model"] == "sam2.1_hiera_tiny"
|
||||
|
||||
unsupported = client.post("/api/ai/propagate/task", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"model": "sam3",
|
||||
"steps": [{
|
||||
"direction": "forward",
|
||||
"max_frames": 2,
|
||||
"seed": {
|
||||
"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
|
||||
},
|
||||
}],
|
||||
})
|
||||
|
||||
assert unsupported.status_code == 400
|
||||
assert "Unsupported model" in unsupported.json()["detail"]
|
||||
|
||||
|
||||
def test_propagation_task_runner_saves_annotations_and_progress(client, db_session, monkeypatch):
|
||||
project = client.post("/api/projects", json={"name": "Propagation Worker"}).json()
|
||||
frames = [
|
||||
client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": idx,
|
||||
"image_url": f"frames/{idx}.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
}).json()
|
||||
for idx in range(2)
|
||||
]
|
||||
task = ProcessingTask(
|
||||
task_type="propagate_masks",
|
||||
status="queued",
|
||||
progress=0,
|
||||
project_id=project["id"],
|
||||
payload={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frames[0]["id"],
|
||||
"model": "sam2.1_hiera_tiny",
|
||||
"include_source": False,
|
||||
"save_annotations": True,
|
||||
"steps": [{
|
||||
"direction": "forward",
|
||||
"max_frames": 2,
|
||||
"seed": {
|
||||
"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
|
||||
"label": "胆囊",
|
||||
"color": "#ff0000",
|
||||
"class_metadata": {"id": "c1", "name": "胆囊"},
|
||||
},
|
||||
}],
|
||||
},
|
||||
)
|
||||
db_session.add(task)
|
||||
db_session.commit()
|
||||
db_session.refresh(task)
|
||||
|
||||
published = []
|
||||
monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg")
|
||||
monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: published.append((event_task.status, event_task.progress)))
|
||||
def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
|
||||
assert [Path(path).name for path in frame_paths] == ["000000.jpg", "000001.jpg"]
|
||||
return [
|
||||
{"frame_index": 0, "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]], "scores": [0.9]},
|
||||
{"frame_index": 1, "polygons": [[[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]], "scores": [0.8]},
|
||||
]
|
||||
|
||||
monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video)
|
||||
|
||||
result = run_propagate_project_task(db_session, task.id)
|
||||
|
||||
db_session.refresh(task)
|
||||
assert task.status == "success"
|
||||
assert task.progress == 100
|
||||
assert task.result["model"] == "sam2.1_hiera_tiny"
|
||||
assert task.result["steps"][0]["model"] == "sam2.1_hiera_tiny"
|
||||
assert result["created_annotation_count"] == 1
|
||||
assert result["processed_frame_count"] == 2
|
||||
assert published[0][0] == "running"
|
||||
assert published[-1] == ("success", 100)
|
||||
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
|
||||
assert listing.json()[0]["frame_id"] == frames[1]["id"]
|
||||
assert listing.json()[0]["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation"
|
||||
|
||||
|
||||
def test_propagation_task_runner_skips_unchanged_seed_and_replaces_changed_seed(client, db_session, monkeypatch):
|
||||
project = client.post("/api/projects", json={"name": "Propagation Dedupe"}).json()
|
||||
frames = [
|
||||
client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": idx,
|
||||
"image_url": f"frames/{idx}.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
}).json()
|
||||
for idx in range(2)
|
||||
]
|
||||
|
||||
def make_task(seed_polygon):
|
||||
task = ProcessingTask(
|
||||
task_type="propagate_masks",
|
||||
status="queued",
|
||||
progress=0,
|
||||
project_id=project["id"],
|
||||
payload={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frames[0]["id"],
|
||||
"model": "sam2.1_hiera_tiny",
|
||||
"include_source": False,
|
||||
"save_annotations": True,
|
||||
"steps": [{
|
||||
"direction": "forward",
|
||||
"max_frames": 2,
|
||||
"seed": {
|
||||
"polygons": [seed_polygon],
|
||||
"label": "胆囊",
|
||||
"color": "#ff0000",
|
||||
"source_annotation_id": 7,
|
||||
"source_mask_id": "annotation-7",
|
||||
},
|
||||
}],
|
||||
},
|
||||
)
|
||||
db_session.add(task)
|
||||
db_session.commit()
|
||||
db_session.refresh(task)
|
||||
return task
|
||||
|
||||
seed_polygon = [[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]
|
||||
first_output_polygon = [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]
|
||||
changed_seed_polygon = [[0.2, 0.2], [0.3, 0.2], [0.3, 0.3]]
|
||||
replacement_output_polygon = [[0.22, 0.22], [0.32, 0.22], [0.32, 0.32]]
|
||||
|
||||
monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg")
|
||||
monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: None)
|
||||
propagate_calls = []
|
||||
|
||||
def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
|
||||
propagate_calls.append(seed["polygons"][0])
|
||||
output_polygon = replacement_output_polygon if seed["polygons"][0] == changed_seed_polygon else first_output_polygon
|
||||
return [
|
||||
{"frame_index": 0, "polygons": [seed["polygons"][0]], "scores": [0.9]},
|
||||
{"frame_index": 1, "polygons": [output_polygon], "scores": [0.8]},
|
||||
]
|
||||
|
||||
monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video)
|
||||
|
||||
first_result = run_propagate_project_task(db_session, make_task(seed_polygon).id)
|
||||
assert first_result["created_annotation_count"] == 1
|
||||
assert len(propagate_calls) == 1
|
||||
|
||||
unchanged_result = run_propagate_project_task(db_session, make_task(seed_polygon).id)
|
||||
assert unchanged_result["created_annotation_count"] == 0
|
||||
assert unchanged_result["skipped_seed_count"] == 1
|
||||
assert len(propagate_calls) == 1
|
||||
assert db_session.query(Annotation).filter(Annotation.project_id == project["id"]).count() == 1
|
||||
|
||||
changed_result = run_propagate_project_task(db_session, make_task(changed_seed_polygon).id)
|
||||
assert changed_result["created_annotation_count"] == 1
|
||||
assert changed_result["deleted_annotation_count"] == 1
|
||||
assert len(propagate_calls) == 2
|
||||
annotations = db_session.query(Annotation).filter(Annotation.project_id == project["id"]).all()
|
||||
assert len(annotations) == 1
|
||||
assert annotations[0].mask_data["polygons"] == [replacement_output_polygon]
|
||||
assert annotations[0].mask_data["source_annotation_id"] == 7
|
||||
|
||||
|
||||
def test_predict_validation_errors(client, monkeypatch):
|
||||
project, _, _ = _create_project_and_frame(client)
|
||||
|
||||
|
||||
@@ -110,3 +110,31 @@ def test_dashboard_overview_keeps_recent_success_tasks_in_progress_list(client,
|
||||
"updated_at": body["tasks"][0]["updated_at"],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_dashboard_overview_uses_processed_frame_count_for_propagation_tasks(client, db_session):
|
||||
from models import ProcessingTask
|
||||
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Propagation Project",
|
||||
"status": "ready",
|
||||
}).json()
|
||||
task = ProcessingTask(
|
||||
task_type="propagate_masks",
|
||||
status="running",
|
||||
progress=45,
|
||||
message="向后传播 胆囊 (1/2)",
|
||||
project_id=project["id"],
|
||||
payload={"project_id": project["id"]},
|
||||
result={"processed_frame_count": 8, "created_annotation_count": 3},
|
||||
)
|
||||
db_session.add(task)
|
||||
db_session.commit()
|
||||
db_session.refresh(task)
|
||||
|
||||
response = client.get("/api/dashboard/overview")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["tasks"][0]["task_id"] == task.id
|
||||
assert body["tasks"][0]["frame_count"] == 8
|
||||
|
||||
@@ -84,6 +84,42 @@ def test_retry_task_creates_fresh_parse_task(client, db_session, monkeypatch):
|
||||
assert client.get(f"/api/projects/{project['id']}").json()["status"] == "parsing"
|
||||
|
||||
|
||||
def test_retry_task_dispatches_propagation_worker_without_media_requirement(client, db_session, monkeypatch):
|
||||
project = client.post("/api/projects", json={"name": "Retry Propagation"}).json()
|
||||
task = ProcessingTask(
|
||||
task_type="propagate_masks",
|
||||
status="failed",
|
||||
progress=100,
|
||||
message="自动传播失败",
|
||||
error="model unavailable",
|
||||
project_id=project["id"],
|
||||
payload={
|
||||
"project_id": project["id"],
|
||||
"frame_id": 1,
|
||||
"steps": [],
|
||||
},
|
||||
)
|
||||
db_session.add(task)
|
||||
db_session.commit()
|
||||
db_session.refresh(task)
|
||||
|
||||
class FakeAsyncResult:
|
||||
id = "celery-propagation-retry"
|
||||
|
||||
queued = []
|
||||
monkeypatch.setattr("routers.tasks.propagate_project_masks.delay", lambda task_id: queued.append(task_id) or FakeAsyncResult())
|
||||
monkeypatch.setattr("routers.tasks.publish_task_progress_event", lambda event_task: None)
|
||||
|
||||
response = client.post(f"/api/tasks/{task.id}/retry")
|
||||
|
||||
assert response.status_code == 202
|
||||
body = response.json()
|
||||
assert body["task_type"] == "propagate_masks"
|
||||
assert body["celery_task_id"] == "celery-propagation-retry"
|
||||
assert queued == [body["id"]]
|
||||
assert client.get(f"/api/projects/{project['id']}").json()["status"] == "pending"
|
||||
|
||||
|
||||
def test_task_actions_reject_invalid_states(client, db_session):
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Done",
|
||||
|
||||
@@ -5,6 +5,7 @@ import logging
|
||||
from celery_app import celery_app
|
||||
from database import SessionLocal
|
||||
from services.media_task_runner import run_parse_media_task
|
||||
from services.propagation_task_runner import run_propagate_project_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -20,3 +21,16 @@ def parse_project_media(task_id: int) -> dict:
|
||||
raise exc
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@celery_app.task(name="ai.propagate_project")
|
||||
def propagate_project_masks(task_id: int) -> dict:
|
||||
"""Run SAM video propagation for one queued task."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
return run_propagate_project_task(db, task_id)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception("Propagation task failed: task_id=%s", task_id)
|
||||
raise exc
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -49,16 +49,18 @@
|
||||
| 导入视频文件 | 真实可用 | 创建项目、上传源视频、刷新项目列表;不会自动拆帧 |
|
||||
| 生成帧按钮 | 真实可用 | 仅对已导入源视频且尚无帧、非 parsing 状态的项目显示,调用 `parseMedia(projectId, { parseFps })` |
|
||||
| 生成帧 FPS 滑块 | 真实可用 | 值传入 `/api/media/parse?parse_fps=...`,决定后台拆帧目标 FPS |
|
||||
| 项目卡片 FPS 徽标 | 真实可用 | 右上角显示关键帧序列目标 `parse_fps`;原始视频帧率只在卡片底部以“原 xx fps”显示 |
|
||||
| 导入 DICOM 序列 | 部分可用 | 可上传 `.dcm` 并触发解析;体验和错误反馈较粗 |
|
||||
| 项目状态徽标 | 真实可用 | 项目状态统一为 `pending/parsing/ready/error`,前端兼容归一化旧状态值 |
|
||||
| 删除项目按钮 | 真实可用 | 点击垃圾桶按钮会确认删除,调用 `DELETE /api/projects/{id}`,成功后从项目库移除;若删除的是当前项目,会清空工作区当前项目、帧、mask 和选区 |
|
||||
| alert 成功/失败提示 | 真实可用但粗糙 | 使用浏览器 `alert` |
|
||||
| 操作成功/失败提示 | 真实可用 | 使用非阻塞 `TransientNotice` 浮层,自动消失,不会拦截后续按钮、输入框或画布操作 |
|
||||
|
||||
## 工作区 VideoWorkspace
|
||||
|
||||
| 元素 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| 当前项目名 | 真实可用 | 读取 `currentProject.name` |
|
||||
| 顶栏操作提示 | 真实可用 | 保存、导出、传播范围选择等短反馈会自动消失;保存/导出/传播进行中和无帧项目提示会保留到状态变化 |
|
||||
| 自动加载项目帧 | 真实可用 | 调用 `GET /api/projects/{id}/frames` |
|
||||
| 无帧项目提示 | 真实可用 | 如果 `video_path` 存在但无帧,只提示回到项目库生成帧,不自动创建拆帧任务 |
|
||||
| SAM 模型状态徽标 | 真实可用 | 调用 `GET /api/ai/models/status`,显示当前启用的 SAM 2 与 GPU 状态 |
|
||||
@@ -66,16 +68,16 @@
|
||||
| “导出 JSON 标注集”按钮 | 真实可用 | 导出前会保存未归档 mask,然后调用 `exportCoco()` 下载 JSON |
|
||||
| “导出 PNG Mask ZIP”按钮 | 真实可用 | 导出前会保存未归档 mask,然后调用 `GET /api/export/{project_id}/masks` 下载 ZIP;后端同时包含单标注 mask、每帧语义融合 mask 和 `semantic_classes.json` |
|
||||
| “导入 GT Mask”按钮 | 真实可用 | 选择图片后调用 `POST /api/ai/import-gt-mask`,后端按非零像素值和连通域生成 polygon 标注与距离变换 seed point,再回显到工作区 |
|
||||
| 参考帧/起止帧/自动传播 | 真实可用 | 当前打开帧即参考帧,前端会使用该帧全部 mask 作为 seed;用户设置传播起始帧和传播结束帧后,单个“自动传播”按钮会按 seed mask 和前/后方向顺序调用 `POST /api/ai/propagate`,当前启用 SAM 2 video predictor,完成后刷新已保存标注 |
|
||||
| 参考帧/起止帧/传播权重/自动传播 | 真实可用 | 当前打开帧即参考帧,前端会使用该帧全部 mask 作为 seed;工作区顶栏有独立“传播权重”下拉,可在传播前二次选择 SAM 2.1 tiny/small/base+/large 权重,不提供 SAM2/SAM3 家族切换,不影响 AI 智能分割页的单帧推理权重选择;如果用户尚未显式设置范围,点击“自动传播”会先进入时间轴范围选择模式,播放进度条和视频处理进度条都可点击/拖拽回填传播起始帧和传播结束帧,再点击“开始传播”提交;用户也可直接改数字框后点击按钮传播。提交后前端把传播权重 id、seed mask、seed 来源 id 和前/后方向步骤提交到 `POST /api/ai/propagate/task`,后端先规范化/校验权重 id,再创建 `processing_tasks` 并由 Celery 执行对应 SAM 2.1 video predictor;worker 会按 seed 来源和几何/语义签名做幂等判断,未改变的 seed 直接跳过,已改变的 seed 会先删除同源旧自动传播标注再重新传播,避免重复传播产生重叠 mask;传播中顶栏显示任务进度、已处理帧次、删除旧区域数和已保存区域数,前端轮询 `GET /api/tasks/{task_id}` 并刷新已保存标注;任务可取消,若完成后 0 个新区域会明确提示没有生成新 mask 或已跳过未改变 mask |
|
||||
| “结构化归档保存”按钮 | 真实可用 | 未保存 mask 写入 `POST /api/ai/annotate`;dirty mask 写入 `PATCH /api/ai/annotations/{id}`;保存成功后会重新拉取后端标注,并用 saved annotation 替换本次提交的 draft mask,避免仍显示未保存 |
|
||||
|
||||
## CanvasArea 画布
|
||||
|
||||
| 元素 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| 当前帧底图显示 | 真实可用 | `useImage(frameUrl)` 加载当前帧 URL |
|
||||
| 当前帧底图显示 | 真实可用 | `useImage(frameUrl)` 加载当前帧 URL;切换帧或容器尺寸变化时会按 86% 适配比例居中放大显示,默认留出画布边距,不铺满整个画布 |
|
||||
| 滚轮缩放 | 真实可用 | 改变 Konva Stage scale |
|
||||
| 拖拽平移 | 真实可用 | activeTool 为 `move` 时 Stage draggable |
|
||||
| 拖拽平移 | 真实可用 | activeTool 为 `move` 时 Stage draggable,拖拽结束会回写 React position state,避免 Konva 节点位置和前端状态脱节 |
|
||||
| 光标坐标显示 | 真实可用 | 根据 pointer position 计算 |
|
||||
| 正向/反向选点 | 真实可用 | UI 能加点,并按当前帧 `frame.id` 调用 `/api/ai/predict`;结果需点击归档保存才持久化 |
|
||||
| 框选 | 真实可用 | UI 能画框,并把框坐标归一化后调用后端推理;结果需点击归档保存才持久化 |
|
||||
@@ -100,17 +102,18 @@
|
||||
| 区域合并/去除 | 真实可用 | 选择工具后点击多个 mask,右下角显示已选数量和操作按钮;合并/去除模式会隐藏 polygon 编辑手柄,避免手柄抢占多选点击;布尔选择态中第一个选中的主区域用黄色实线轮廓,后续参与合并/扣除的区域用红色虚线轮廓,避免主区域和扣除区域看起来像随机阴影差异;使用 `polygon-clipping` 做 union / difference;合并会保留主 mask 并移除被合并 mask,去除会从主 mask 扣除后续选中 mask;内含扣除会保留 hole ring 并用 even-odd 规则渲染 |
|
||||
| 正向选点/反向选点/框选 | 部分可用 | 会影响 Canvas 交互,并能触发已对齐的 AI 推理接口;点击工作区内已有 SAM 提示点会优先删除该提示点并重新推理,不会冒泡成新增提示点或 mask 选择 |
|
||||
| 魔法棒 SAM 触发 | 部分可用 | 切到 AI 页面;不是直接执行推理 |
|
||||
| 撤销/重做 | 真实可用 | 绑定 Zustand `maskHistory/maskFuture`,支持工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y |
|
||||
| 撤销/重做 | 真实可用 | 绑定 Zustand `maskHistory/maskFuture`,支持工作区顶栏按钮、工具栏按钮、AI 页按钮和快捷键 `Ctrl/Cmd+Z`、`Ctrl/Cmd+Shift+Z`、`Ctrl/Cmd+Y`;输入框聚焦时不拦截快捷键 |
|
||||
| 紧凑/滚动布局 | 真实可用 | 工具按钮使用较紧凑的垂直间距;左侧高度不足时工具栏自身出现纵向滚动,不挤压画布 |
|
||||
|
||||
## FrameTimeline 时间轴
|
||||
|
||||
| 元素 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| 帧缩略图 | 真实可用 | 使用 `frames[].url` |
|
||||
| 点击缩略图跳帧 | 真实可用 | 调用 `setCurrentFrame(idx)` |
|
||||
| 点击缩略图跳帧 | 真实可用 | 调用 `setCurrentFrame(idx)`;非当前帧中,人工/AI 标注帧使用红色边框,自动传播/推理帧使用蓝色边框;当前帧仍用青色外框高亮优先,若当前帧同时是人工/AI 标注帧,则在青色外框内增加红色内描边,避免状态颜色互相覆盖 |
|
||||
| 顶部 range 拖动 | 真实可用 | 改变当前帧 |
|
||||
| 具体时间显示 | 真实可用 | 根据项目 `parse_fps/original_fps` 显示当前时间和总时长,格式为 `mm:ss.cc` |
|
||||
| 自动传播帧进度条标记 | 真实可用 | 根据已保存标注回显的 `mask_data.source` / `propagated_from_frame_id` 识别自动传播生成的帧,并在顶部进度条对应帧区段覆盖浅蓝色;当前帧位置由播放进度条末端、时间提示和缩略图高亮表达 |
|
||||
| 播放进度条 / 视频处理进度条 | 真实可用 | 播放进度条位于上方,视频处理进度条位于下方;视频处理进度条普通状态下可点击跳转到对应帧;根据已保存标注回显的 `mask_data.source` / `propagated_from_frame_id` 识别自动传播生成的帧并显示蓝色区段,人工绘制或 AI 智能分割生成的帧显示红色竖线,红/蓝标识也可点击跳转到对应帧;未处理背景使用中性灰以和红/蓝标记区分;只有工作区进入自动传播范围选择模式时,两条进度条才显示 amber 选区,并可点击/拖拽选择起止帧 |
|
||||
| 播放/暂停 | 真实可用 | 当前代码按 `parse_fps/original_fps` 推进帧,最多 30fps |
|
||||
| 方向键切帧 | 真实可用 | 全局监听左右方向键切到上一帧/下一帧;焦点在 input、textarea、select 或 contentEditable 内时不会拦截 |
|
||||
|
||||
@@ -144,6 +147,7 @@
|
||||
| 删除选中候选 | 真实可用 | 删除 AI 页当前选中的本页候选 mask;不会删除工作区已有 mask,Delete/Backspace 也遵循同一范围 |
|
||||
| 清空全体锚点 | 真实可用 | 清空 AI 页提示点和本页生成的候选 mask,不删除工作区已有 mask |
|
||||
| 背景图 / 空状态 | 真实可用 | 优先显示当前项目帧;没有项目帧时显示空状态提示,不再回退到外部演示图片 |
|
||||
| AI 画布初始视图 | 真实可用 | 当前帧在 AI 画布中默认居中,并按 86% 适配比例尽量放大但保留边距 |
|
||||
|
||||
## TemplateRegistry 模板库
|
||||
|
||||
|
||||
@@ -38,7 +38,8 @@ Authorization: Bearer <token>
|
||||
| `retryTask(taskId)` | `POST /api/tasks/{task_id}/retry` | 对齐 | 对 failed/cancelled 任务创建新的 queued 重试任务 |
|
||||
| `getProjectFrames(projectId)` | `GET /api/projects/{id}/frames` | 对齐 | 后端返回预签名 image_url,以及 `timestamp_ms`、`source_frame_number` |
|
||||
| `predictMask(payload)` | `POST /api/ai/predict` | 对齐 | 前端发送 `image_id/prompt_type/prompt_data/model`,并把后端 `polygons` 转为 `masks[].pathData` |
|
||||
| `propagateMasks(payload)` | `POST /api/ai/propagate` | 对齐 | 当前帧 seed mask 向视频片段传播,并保存后续帧标注 |
|
||||
| `propagateMasks(payload)` | `POST /api/ai/propagate` | 对齐 | 单 seed 同步传播接口,供后端兼容和测试使用 |
|
||||
| `queuePropagationTask(payload)` | `POST /api/ai/propagate/task` | 对齐 | 工作区自动传播入口;创建 Celery 后台任务并由任务表/进度流追踪 |
|
||||
| `getAiModelStatus(selectedModel?)` | `GET /api/ai/models/status` | 对齐 | 返回 GPU 和四个 SAM 2.1 变体状态;`selected_model=sam3` 返回不支持 |
|
||||
| `analyzeMask(mask, frame, options?)` | `POST /api/ai/analyze-mask` | 对齐 | 后端计算选中 mask 的置信度来源、拓扑锚点数量、面积和 bbox |
|
||||
| `getProjectAnnotations(projectId, frameId?)` | `GET /api/ai/annotations` | 对齐 | 前端加载工作区时用于回显已保存标注 |
|
||||
@@ -78,7 +79,8 @@ Authorization: Bearer <token>
|
||||
| POST | `/api/tasks/{task_id}/cancel` | 取消后台任务 |
|
||||
| POST | `/api/tasks/{task_id}/retry` | 重试失败或取消的后台任务 |
|
||||
| POST | `/api/ai/predict` | 当前启用 SAM 2 点/框/interactive 推理 |
|
||||
| POST | `/api/ai/propagate` | 当前启用 SAM 2 视频片段传播并保存标注 |
|
||||
| POST | `/api/ai/propagate` | 当前启用 SAM 2 单 seed 同步视频片段传播并保存标注 |
|
||||
| POST | `/api/ai/propagate/task` | 创建 SAM 2 自动传播后台任务;payload 可包含多个 seed/direction step |
|
||||
| POST | `/api/ai/analyze-mask` | 分析前端选中 mask 的后端几何属性和拓扑锚点 |
|
||||
| GET | `/api/ai/models/status` | GPU 和 SAM 模型状态 |
|
||||
| POST | `/api/ai/auto` | 自动分割 |
|
||||
@@ -230,7 +232,7 @@ SAM 2 点提示和 auto fallback 当前只采用最高分候选 mask,避免同
|
||||
|
||||
### 视频片段传播请求体
|
||||
|
||||
后端接口仍以单个 seed 为单位。工作区前端当前只提供一个“自动传播”按钮:当前打开帧作为参考帧,该帧全部 mask 作为 seed;用户设置传播起始帧和传播结束帧后,前端会在本地把多个 seed 或前后双向范围拆成多次顺序调用,避免同时启动多个视频 tracker。
|
||||
`POST /api/ai/propagate` 仍是单 seed 同步接口。工作区实际使用 `POST /api/ai/propagate/task`:当前打开帧作为参考帧,该帧全部 mask 作为 seed;用户设置传播起始帧和传播结束帧后,前端会在本地把多个 seed 或前后双向范围拆成 `steps`,一次提交为 `propagate_masks` 后台任务,避免长 HTTP 请求和多个视频 tracker 并发抢占 GPU。
|
||||
|
||||
单次调用示例:
|
||||
|
||||
@@ -254,7 +256,31 @@ SAM 2 点提示和 auto fallback 当前只采用最高分候选 mask,避免同
|
||||
}
|
||||
```
|
||||
|
||||
SAM 2.1 变体使用对应 video predictor 的 mask seed 传播;`model=sam2` 会兼容归一化为 tiny,`model=sam3` 当前不支持。响应会返回已创建的 `annotations`,保存的 `mask_data.source` 为 `<model_id>_propagation`,前端回显时会把该字段保留到 `Mask.metadata`,用于把自动传播帧在时间进度条上标为浅蓝色。
|
||||
后台任务调用示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"project_id": 1,
|
||||
"frame_id": 123,
|
||||
"model": "sam2.1_hiera_tiny",
|
||||
"include_source": false,
|
||||
"save_annotations": true,
|
||||
"steps": [
|
||||
{
|
||||
"direction": "forward",
|
||||
"max_frames": 30,
|
||||
"seed": {
|
||||
"polygons": [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]],
|
||||
"label": "胆囊",
|
||||
"color": "#ff0000"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
SAM 2.1 变体使用对应 video predictor 的 mask seed 传播;`model=sam2` 会兼容归一化为 tiny,`model=sam3` 当前不支持。响应会返回已创建的 `annotations`,保存的 `mask_data.source` 为 `<model_id>_propagation`,前端回显时会把该字段保留到 `Mask.metadata`,用于在视频处理进度条上把自动传播帧显示为蓝色区段。
|
||||
后台任务入队接口会先规范化/校验 `model` 字段中的 SAM 2.1 权重 id,再把规范化后的权重 id 写入 `processing_tasks.payload.model`;前端 seed 会携带 `source_mask_id` 和可用时的 `source_annotation_id`,worker 保存传播结果时会写入 `propagation_seed_key`、`propagation_seed_signature` 和 `propagation_direction`。同一 seed、同一权重、同一方向再次传播时,如果签名未变化,worker 会跳过该 seed;如果签名变化,worker 会先删除旧自动传播标注再保存新结果。任务运行中/完成后会写入 `processing_tasks.result.model`、`completed_steps`、`processed_frame_count`、`created_annotation_count`、`deleted_annotation_count`、`skipped_seed_count` 和每个 step 的权重/方向/数量结果;前端通过 `GET /api/tasks/{task_id}` 轮询,Dashboard 同时可通过 Redis/WebSocket 进度流显示该任务。
|
||||
|
||||
## 已完成的接口对齐
|
||||
|
||||
@@ -270,6 +296,7 @@ SAM 2.1 变体使用对应 video predictor 的 mask seed 传播;`model=sam2`
|
||||
- `importGtMask()` 已接入 `POST /api/ai/import-gt-mask`,导入后端生成的 polygon 标注、原始 `gt_label_value` 和 seed point。
|
||||
- `exportMasks()` 已接入 `GET /api/export/{projectId}/masks`。
|
||||
- `parseMedia()` 已改为创建 Celery 后台任务,并返回 `ProcessingTask`。
|
||||
- `queuePropagationTask()` 已接入 `/api/ai/propagate/task`,自动传播不再依赖长时间同步 HTTP 请求。
|
||||
- `getTask()` 已接入 `GET /api/tasks/{taskId}`。
|
||||
- `cancelTask()` 已接入 `POST /api/tasks/{taskId}/cancel`。
|
||||
- `retryTask()` 已接入 `POST /api/tasks/{taskId}/retry`。
|
||||
|
||||
@@ -48,10 +48,11 @@
|
||||
- 若项目有媒体但无帧,工作区只提示需要先在项目库生成帧,不再自动触发拆帧。
|
||||
- Canvas 显示当前帧图片。
|
||||
- Canvas 支持滚轮缩放、移动工具拖拽、鼠标坐标显示。
|
||||
- 时间轴支持缩略图点击切帧、range 拖动切帧、键盘左右方向键切帧、播放/暂停顺序推进帧。
|
||||
- 时间轴支持缩略图点击切帧、range 拖动切帧、视频处理进度条点击切帧、人工/AI 标注帧和自动传播帧标识点击切帧、键盘左右方向键切帧、播放/暂停顺序推进帧。
|
||||
- 播放帧率使用项目 `parse_fps` 或 `original_fps`,限制在 1 到 30 FPS。
|
||||
- 时间轴显示当前帧时间和总时长,时间基准使用项目 `parse_fps` 或 `original_fps`,格式为 `mm:ss.cc`。
|
||||
- 时间轴根据已保存标注回显的传播来源字段,把自动传播生成的帧在顶部进度条对应区段标为浅蓝色;不再使用竖线模式标记已编辑帧,当前帧位置由播放进度条末端、时间提示和缩略图高亮表达。
|
||||
- 时间轴顶部播放进度条只表达当前播放位置;其下方的视频处理进度条表达处理状态:人工绘制或 AI 智能分割生成的帧显示红色竖线,自动传播生成的帧显示蓝色区段,未处理背景使用中性灰以和标记保持明显区分。底部帧可视化栏中,人工/AI 标注帧缩略图边框为红色,自动传播/推理帧缩略图边框为蓝色,当前帧仍用青色外框高亮优先;如果当前帧同时是人工/AI 标注帧,则显示青色外框加红色内描边。
|
||||
- 自动传播提交前支持独立选择传播权重,范围限定为 SAM 2.1 tiny/small/base+/large 四个权重变体;该选择只影响传播任务,不提供 SAM2/SAM3 家族切换,也不改变 AI 智能分割页的单帧推理权重。
|
||||
|
||||
## R5 工具栏
|
||||
|
||||
@@ -101,9 +102,11 @@
|
||||
- 工作区传播功能以当前打开帧作为参考帧,并使用该帧全部 mask 作为 seed;用户不再选择“选中区域/当前帧全部”传播对象。
|
||||
- 工作区传播功能允许设置传播起始帧和传播结束帧;前端以当前参考帧为 seed,只向起止范围内位于参考帧之前和之后的帧传播,源帧不重复保存。
|
||||
- 工作区只保留一个“自动传播”按钮,点击后在指定范围内按前向/后向自动生成 mask。
|
||||
- 前端复用单 seed 后端接口;多个 seed 或双向范围会被拆成多次顺序调用 `POST /api/ai/propagate`,避免并发抢占 GPU。
|
||||
- `POST /api/ai/propagate` 当前支持四个 SAM 2.1 变体;兼容 `model=sam2` 并归一化为 tiny。SAM 2.1 使用官方 `SAM2VideoPredictor.add_new_mask()` 和 `propagate_in_video()`。
|
||||
- 传播结果会写入后续帧 `annotations`,`mask_data.source` 标记为 `<model_id>_propagation`,并保留 label、color 和 class 元数据。
|
||||
- 前端会把多个 seed 或双向范围拆成 `steps`,通过 `POST /api/ai/propagate/task` 创建 `propagate_masks` 后台任务,避免长 HTTP 请求卡在浏览器侧,同时避免并发抢占 GPU。
|
||||
- `POST /api/ai/propagate` 作为单 seed 同步兼容接口保留;`POST /api/ai/propagate/task` 是工作区自动传播使用的任务接口。两者当前支持四个 SAM 2.1 变体;兼容 `model=sam2` 并归一化为 tiny。SAM 2.1 使用官方 `SAM2VideoPredictor.add_new_mask()` 和 `propagate_in_video()`。
|
||||
- 自动传播任务写入 `processing_tasks`,前端轮询 `GET /api/tasks/{task_id}` 显示进度并刷新标注;Dashboard 也能看到该任务,任务可取消和重试。
|
||||
- 传播结果会写入后续帧 `annotations`,`mask_data.source` 标记为 `<model_id>_propagation`,并保留 label、color、class 元数据、seed 来源 id、seed 签名和传播方向。
|
||||
- 自动传播任务必须避免重复叠加:同一参考 seed、同一权重、同一方向且 seed 签名未变化时,worker 直接跳过;同一参考 seed 已变化时,worker 先删除对应旧自动传播标注,再保存新传播结果。
|
||||
- AI 页面会对未放置点提示、后端错误和返回 0 个 mask 的情况显示明确反馈。
|
||||
- AI 参数支持 `crop_to_prompt`、`auto_filter_background` 和 `min_score`;点/框 prompt 可以裁剪局部区域推理并回映射结果,背景过滤会移除低分结果和包含负向点的 polygon。
|
||||
- 后端返回 `polygons` 和 `scores`。
|
||||
|
||||
@@ -29,11 +29,13 @@
|
||||
| 项目库 | `src/components/ProjectLibrary.tsx` | 项目列表、新建、删除、导入视频/DICOM、显式生成帧 |
|
||||
| 工作区 | `src/components/VideoWorkspace.tsx` | 加载帧和模板,组织工具栏、Canvas、本体面板、时间轴 |
|
||||
| Canvas | `src/components/CanvasArea.tsx` | 显示帧、缩放平移、点/框提示、渲染 mask |
|
||||
| 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具、跳转 AI 页面、触发 mask 撤销/重做 |
|
||||
| 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航、自动传播帧浅蓝区段标记、左右方向键切帧、播放和当前/总时长显示 |
|
||||
| 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具、跳转 AI 页面、触发 mask 撤销/重做;紧凑垂直布局,高度不足时自身滚动 |
|
||||
| 工作区顶栏 | `src/components/VideoWorkspace.tsx` | 保存/导出/传播/导入 GT、显式撤销/重做按钮和工作区快捷键 |
|
||||
| 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航、播放进度、视频处理进度条、自动传播范围选择、左右方向键切帧、播放和当前/总时长显示 |
|
||||
| 本体面板 | `src/components/OntologyInspector.tsx` | 模板选择、分类树、后端自定义分类、mask 后端属性分析 |
|
||||
| AI 页面 | `src/components/AISegmentation.tsx` | 独立 AI 推理视图,使用当前项目帧 |
|
||||
| 模板库 | `src/components/TemplateRegistry.tsx` | 模板 CRUD、分类编辑、导入、排序 |
|
||||
| 短提示浮层 | `src/components/TransientNotice.tsx` | 项目库和模板库的非阻塞成功/失败提示,自动消失 |
|
||||
|
||||
## 后端模块
|
||||
|
||||
@@ -49,6 +51,7 @@
|
||||
| Templates | `backend/routers/templates.py` | 模板 CRUD 和 mapping_rules 打包/解包 |
|
||||
| Media | `backend/routers/media.py` | 上传媒体和拆帧 |
|
||||
| AI | `backend/routers/ai.py` | 当前启用 SAM 2 推理、视频传播、模型状态和标注保存 |
|
||||
| 传播任务 | `backend/services/propagation_task_runner.py` | Celery 中执行自动传播 steps,写任务进度并保存传播标注 |
|
||||
| Export | `backend/routers/export.py` | COCO 和 PNG mask 导出 |
|
||||
| SAM 2 | `backend/services/sam2_engine.py` | SAM 2 懒加载、状态检测、点/框/自动推理和视频 mask 传播 |
|
||||
| SAM 3 | `backend/services/sam3_engine.py`, `backend/services/sam3_external_worker.py`, `backend/setup_sam3_env.sh` | 历史保留的 SAM 3 桥接源码和脚本;当前未接入 registry |
|
||||
@@ -85,7 +88,8 @@
|
||||
5. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,视频帧按 `frame_%06d.jpg` 从 `frame_000000.jpg` 连续命名,并按目标宽度缩放。
|
||||
6. worker 写入 `frames.timestamp_ms` 和 `frames.source_frame_number`,并在任务 `result.frame_sequence` 中记录 FPS、帧数、时长、尺寸和对象存储前缀。
|
||||
7. worker 持续更新 `processing_tasks`,并发布 Redis `seg:progress`。
|
||||
8. 刷新项目列表。
|
||||
8. 刷新项目列表;项目卡片右上角 FPS 徽标显示生成关键帧序列时选择的 `parse_fps`,原始视频 FPS 仅作为底部“原 xx fps”辅助信息显示。
|
||||
9. 导入视频、生成帧、上传 DICOM 和失败反馈使用 `TransientNotice`,不再使用浏览器 `alert()` 阻塞操作;提示默认数秒后自动消失。
|
||||
|
||||
### 任务控制
|
||||
|
||||
@@ -103,8 +107,10 @@
|
||||
3. 帧数据映射为 store `Frame[]`,包含 `timestampMs` 和 `sourceFrameNumber`,供时间轴和后续视频传播使用。
|
||||
4. 工作区调用 `GET /api/ai/annotations` 回显已保存标注时,会替换当前项目帧中的已保存 mask,但保留没有 `annotationId` 的未保存 draft mask;这保证 AI 页推送到工作区的候选 mask 不会被异步回显覆盖,并会在合并完成后恢复仍然存在的已选 mask id。
|
||||
5. `CanvasArea` 会把全局 `selectedMaskIds` 中仍存在于当前帧的 id 同步回本地选区,避免帧初始化时的临时清空覆盖 AI 页推送过来的选中态。
|
||||
6. `FrameTimeline` 根据已保存标注回显到 `Mask.metadata` 的 `source` / `propagated_from_frame_id` 计算自动传播生成的帧,并在顶部时间进度条对应帧区段覆盖浅蓝色;当前帧不额外渲染竖线,由播放进度条末端、时间提示和缩略图高亮表达。
|
||||
7. 当前帧传入 `CanvasArea`。
|
||||
6. `CanvasArea` 根据容器和帧尺寸按 86% 适配比例计算初始 scale/position,使底图默认居中且尽量大,但保留画布边距;滚轮缩放和拖拽平移仍由用户后续控制。
|
||||
7. `FrameTimeline` 顶部播放进度条显示当前播放位置;其下方视频处理进度条根据 `Mask.metadata.source` / `propagated_from_frame_id` 计算自动传播帧并显示蓝色区段,对人工绘制或 AI 智能分割等非传播 mask 帧显示红色竖线。普通状态下,视频处理进度条可点击跳转到对应帧,红色人工/AI 标注帧和蓝色自动传播帧标识本身也可点击跳转。处理条未处理背景使用中性灰,和红色/蓝色标记保持明显区分。底部缩略图导航轴对非当前帧使用红色边框标识人工/AI 标注帧,使用蓝色边框标识自动传播/推理帧;当前帧使用青色外框高亮优先,若当前帧同时是人工/AI 标注帧,则以青色外框加红色内描边同时表达两个状态。工作区只有进入自动传播范围选择模式时,播放进度条和视频处理进度条才显示 amber 覆盖层,并可点击/拖拽设置传播起止帧。
|
||||
8. 当前帧传入 `CanvasArea`。
|
||||
9. 工作区顶栏短状态文本会在空闲状态下自动消失;保存、导出、导入 GT 和传播任务运行中仍保留进度状态,无帧项目提示也会保留。
|
||||
|
||||
### AI 点/框推理
|
||||
|
||||
@@ -126,8 +132,9 @@
|
||||
16. AI 页面参数开关文案只做展示增强:“局部专注模式(自动裁剪无锚区域)”仍控制 `cropMode/crop_to_prompt`,“严格除杂模式(自动清理干涉点)”仍控制 `autoDeleteBg/auto_filter_background/min_score`。
|
||||
17. AI 页面“遮罩清晰度”滑杆只调节候选 mask 的 Konva preview opacity,不写入 `Mask.segmentation`、分类元数据或后端 payload。
|
||||
18. AI 画布左上角根据正向点、反向点、边界框选和视口控制显示上下文提示,说明点击/拖拽、删除提示点和执行推理的操作方式。
|
||||
19. Canvas 按当前帧过滤并渲染 mask。
|
||||
19. 新 mask 会带上当前选择的模板分类元数据,包括 `classId`、`className`、`classZIndex`、`metadata.source=ai_segmentation` 和保存状态 `draft`。
|
||||
19. AI 画布根据容器和当前帧尺寸按 86% 适配比例计算初始 scale/position,使底图默认居中且尽量大,但保留画布边距。
|
||||
20. Canvas 按当前帧过滤并渲染 mask。
|
||||
21. 新 mask 会带上当前选择的模板分类元数据,包括 `classId`、`className`、`classZIndex`、`metadata.source=ai_segmentation` 和保存状态 `draft`。
|
||||
20. 用户点击“结构化归档保存”后,前端将像素 `segmentation` 转成 normalized `mask_data.polygons`;未保存 mask 调用 `POST /api/ai/annotate`,dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`;保存成功后本次提交的 draft mask id 会从本地保留列表中排除,并由后端 saved annotation 回显替换。
|
||||
21. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。
|
||||
22. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。
|
||||
@@ -135,15 +142,17 @@
|
||||
### 视频片段传播
|
||||
|
||||
1. 用户在工作区打开一帧作为参考帧;该帧全部 mask 都会作为传播 seed,不再提供传播对象下拉。
|
||||
2. 用户设置传播起始帧和传播结束帧,并点击唯一的“自动传播”按钮。
|
||||
2. 用户可以直接修改传播起始帧/结束帧数字框,并可通过工作区顶栏“传播权重”下拉独立选择本次传播使用的 SAM 2.1 tiny/small/base+/large 权重;该入口不提供 SAM2/SAM3 家族切换,默认跟随全局 AI 权重,用户手动选择后不再被 AI 页权重切换覆盖。
|
||||
3. `VideoWorkspace` 以当前参考帧为 seed,将起止帧拆成 `backward` 和/或 `forward` 两段;只包含当前帧时不传播。
|
||||
4. `VideoWorkspace` 用 `buildAnnotationPayload()` 把每个 seed mask 转成 normalized polygon、bbox、label、color 和 class 元数据。
|
||||
5. 前端对每个 seed、每个方向顺序调用 `POST /api/ai/propagate`,`include_source=false`、`save_annotations=true`;顺序调用是为了避免多个视频 tracker 并发抢占 GPU。
|
||||
6. 后端按项目帧序列截取片段,下载对应帧到临时 `frame_%06d.jpg` 目录,保持当前帧在片段中的相对索引。
|
||||
7. `model` 为任一 SAM 2.1 变体时,`sam2_engine` 使用对应 checkpoint/config 加载 `SAM2VideoPredictor.add_new_mask()` 注入 seed mask,再用 `propagate_in_video()` 传播。
|
||||
8. `model=sam3` 当前不支持;SAM 3 video tracker 代码保留但没有接入产品路径。
|
||||
9. 后端把传播返回的 normalized polygon 保存为后续帧 `Annotation`,跳过源帧,`mask_data.source` 记录模型传播来源。
|
||||
10. 前端传播完成后重新调用 `GET /api/ai/annotations` 并回显新标注;`annotationToMask()` 会保留传播来源 metadata,供时间轴浅蓝色进度条区段显示。
|
||||
4. `VideoWorkspace` 用 `buildAnnotationPayload()` 把每个 seed mask 转成 normalized polygon、bbox、label、color、class 元数据、`source_mask_id` 和可用时的 `source_annotation_id`。
|
||||
5. 前端把传播权重 id、每个 seed、每个方向组装成 `steps`,一次调用 `POST /api/ai/propagate/task`,`include_source=false`、`save_annotations=true`;接口先规范化/校验 `model` 字段中的权重 id,再创建 `processing_tasks.task_type=propagate_masks` 并投递 Celery,避免长 HTTP 请求阻塞前端等待。
|
||||
6. `VideoWorkspace` 记录返回的 `task_id`,轮询 `GET /api/tasks/{task_id}` 显示任务 message、步骤进度、已处理帧次和已保存区域数;任务运行期间提供取消传播按钮,调用通用 `POST /api/tasks/{task_id}/cancel`。
|
||||
7. Celery worker 逐 step 顺序执行传播,避免多个视频 tracker 并发抢占 GPU;每个 step 开始/完成都会写入 `processing_tasks.progress/result/message` 并发布 Redis `seg:progress`,Dashboard 可同步显示。每个 step 开始前,worker 会用 seed 来源 id、规范化权重 id、传播方向和 seed 签名查找旧传播标注:签名相同则跳过该 seed;签名不同则先删除对应方向的旧自动传播标注,再执行新的 video predictor 传播。
|
||||
8. 后端按项目帧序列截取片段,下载对应帧到临时目录,并写成 `000000.jpg` 这类纯数字文件名;这是 `SAM2VideoPredictor` 对视频帧排序的要求,和项目库中持久化的 `frame_%06d.jpg` 对象名无关。
|
||||
9. `model` 为任一 SAM 2.1 权重变体时,`sam2_engine` 使用对应 checkpoint/config 加载 `SAM2VideoPredictor.add_new_mask()` 注入 seed mask,再用 `propagate_in_video()` 传播;`model=sam2` 会在入队时规范化为 tiny,任务 payload/result 会保留规范化后的权重 id;单个 SAM2 video predictor 调用内部暂不提供逐帧流式进度。
|
||||
10. `model=sam3` 当前不支持;SAM 3 video tracker 代码保留但没有接入产品路径。
|
||||
11. 后端把传播返回的 normalized polygon 保存为后续帧 `Annotation`,跳过源帧,`mask_data.source` 记录权重传播来源,同时写入 `propagation_seed_key`、`propagation_seed_signature`、`propagation_direction`、`source_annotation_id` 和 `source_mask_id` 供后续幂等传播判断。
|
||||
12. 前端轮询到已创建区域后刷新 `GET /api/ai/annotations` 并回显新标注;任务结束后如果后端返回 0 个新区域,工作区会明确提示没有生成新的 mask,若是未改变 seed 被跳过则提示未改变 mask 已跳过。`annotationToMask()` 会保留传播来源 metadata,供时间轴视频处理进度条显示蓝色传播区段。
|
||||
|
||||
### 手工绘制与历史栈
|
||||
|
||||
@@ -154,7 +163,7 @@
|
||||
5. mask path 只在 `move`、`edit_polygon`、`area_merge` 和 `area_remove` 工具下拦截点击;绘制和 AI prompt 工具点击已有 mask 时继续冒泡给 Stage。
|
||||
6. 新 mask 写入 `pathData`、像素 `segmentation`、`bbox`、`area` 和当前模板分类元数据。
|
||||
7. `addMask()`、`setMasks()`、`updateMask()`、`clearMasks()` 会维护 `maskHistory/maskFuture`。
|
||||
8. 工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y 调用 `undoMasks()` / `redoMasks()`。
|
||||
8. 工具栏按钮、工作区顶栏按钮和 AI 页按钮调用 `undoMasks()` / `redoMasks()`;工作区由 `VideoWorkspace` 统一处理 `Ctrl/Cmd+Z`、`Ctrl/Cmd+Shift+Z` 和 `Ctrl/Cmd+Y`,并在输入框、下拉框和可编辑文本聚焦时跳过快捷键,避免影响帧范围输入。
|
||||
|
||||
### Polygon 逐点编辑
|
||||
|
||||
@@ -203,6 +212,7 @@
|
||||
11. 如果 `selectedMaskIds` 中存在当前 store 的 mask,点击分类时会立即更新这些 mask 的 `templateId`、`classId`、`className`、`classZIndex`、`label` 和 `color`。
|
||||
12. 同一次点击会把这些已选 mask 移动到前端 `masks` 数组末尾;`CanvasArea` 按数组顺序渲染,后渲染的 Path 显示在最上层,方便用户继续编辑刚换标签的区域。该显示置顶不改变模板 `zIndex` 或后端导出语义覆盖规则。
|
||||
13. 已保存 mask 被重新分类后进入 `dirty` 且 `saved=false`,继续复用工作区归档保存的 PATCH 链路。
|
||||
14. 模板保存、删除和 JSON 导入失败使用 `TransientNotice` 非阻塞提示,默认数秒后自动消失。
|
||||
|
||||
### 导出
|
||||
|
||||
@@ -222,7 +232,8 @@
|
||||
- `cancelTask()` 使用 `POST /api/tasks/{taskId}/cancel`。
|
||||
- `retryTask()` 使用 `POST /api/tasks/{taskId}/retry`。
|
||||
- `predictMask()` 使用 `POST /api/ai/predict`,请求体为 `image_id`、`prompt_type`、`prompt_data`、`model`。
|
||||
- `propagateMasks()` 使用 `POST /api/ai/propagate`,请求体为 `project_id`、`frame_id`、`model`、`seed`、`direction`、`max_frames`。
|
||||
- `propagateMasks()` 使用 `POST /api/ai/propagate`,请求体为 `project_id`、`frame_id`、`model`、`seed`、`direction`、`max_frames`,作为单 seed 同步兼容接口保留。
|
||||
- `queuePropagationTask()` 使用 `POST /api/ai/propagate/task`,请求体为 `project_id`、`frame_id`、`model`、`steps`、`include_source`、`save_annotations`,返回 `ProcessingTask`。
|
||||
- `saveAnnotation()` 使用 `POST /api/ai/annotate`。
|
||||
- `importGtMask()` 使用 `POST /api/ai/import-gt-mask` multipart form-data。
|
||||
- `getProjectAnnotations()` 使用 `GET /api/ai/annotations`。
|
||||
@@ -235,7 +246,7 @@
|
||||
- SAM 2.1 点提示和 auto fallback 只返回一个最高分候选,避免同一提示产生多个重叠候选 mask。
|
||||
- SAM 3 前端入口、后端 registry 入口和状态展示均已禁用;`model=sam3` 会返回不支持。
|
||||
- 后端 `/api/ai/predict` 支持可选 `options`:`crop_to_prompt` 会对 point/box/interactive prompt 做局部裁剪推理并回映射 polygon,`auto_filter_background` 会按 `min_score` 和负向点过滤结果。
|
||||
- 后端 `/api/ai/propagate` 当前支持所选 SAM 2.1 mask seed 视频传播;当前前端默认向后传播 30 帧并保存结果标注。
|
||||
- 后端 `/api/ai/propagate/task` 当前支持所选 SAM 2.1 mask seed 视频传播后台任务;同步 `/api/ai/propagate` 仍保留为单 seed 兼容接口。
|
||||
- 后端 `/api/ai/models/status` 返回 GPU 和四个 SAM 2.1 变体的真实运行状态。
|
||||
- point prompt 支持旧数组形式和 `{ points, labels }` 对象形式。
|
||||
|
||||
|
||||
@@ -16,12 +16,12 @@
|
||||
|------|----------|--------|
|
||||
| R1 登录与会话 | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | 成功登录、失败提示、后端 401 |
|
||||
| R2 项目管理 | `src/lib/api.test.ts`, `src/components/ProjectLibrary.test.tsx`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、项目卡片删除、DELETE 契约、后端 CRUD、删除级联、帧列表 |
|
||||
| R3 媒体上传与拆帧 | `src/components/ProjectLibrary.test.tsx`, `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 视频导入不自动拆帧、显式生成帧 FPS 选择、扩展名校验、自动建项目、关联项目、创建异步任务、标准帧序列参数、帧时间戳/源帧号、任务序列元数据、worker 注册帧、取消任务、重试任务、取消后 worker 停止 |
|
||||
| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧项目不自动解析并提示生成帧、回显已保存标注时保留本地未保存 draft mask、缩略图/range/自动传播帧浅蓝进度条区段标记、当前帧由进度条末端和缩略图高亮表达/左右方向键切帧、播放、按项目 FPS 显示当前/总时长 |
|
||||
| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 工具切换、调整多边形工具、AI 跳转、矩形/圆/线/点/多边形手工 mask 绘制、点工具在已有 mask 上落点、多边形 Enter/首节点闭合、上下文提示提示 Enter/Esc/首节点闭合、polygon 顶点直接拖动/删除、边中点插点、双击边界按位置插点、整块 mask 删除、区域合并/去除、布尔选择主区域/扣除区域视觉区分和选择顺序提示、内含去除 hole 渲染、合并模式隐藏编辑手柄、工作区 SAM 提示点点击删除且不冒泡新增点、撤销/重做历史栈 |
|
||||
| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam2_engine.py` | SAM 2.1 变体选择、点/框/interactive 契约、semantic 禁用、SAM 3 入口隐藏和后端拒绝、SAM 2.1 最高分候选去重、SAM 2.1 框选后正负点细化同一候选 mask、AI 页框选发送 box prompt、AI 页框选后加点发送 interactive prompt、AI 页提示工具上下文提示、AI 页重复执行替换旧候选、SAM 2.1 反向点启用背景过滤且空结果移除旧候选、AI 页不渲染工作区已有 mask、AI 页可在候选 mask 上继续添加正/反点、AI 页可单点删除提示点并删除最近锚点、AI 页可删除选中候选且不删除工作区 mask、AI 页清空只移除本页候选、AI 页参数开关可读性文案且 options 字段不变、AI 页遮罩清晰度只改预览 opacity、AI 页生成 mask 自动选中并可通过分类树换标签、AI 页推送到工作区编辑保留选择、SAM 2.1 视频以当前参考帧全部 mask 和起止帧范围自动传播、传播来源 metadata 回显、空提示/空结果反馈、GPU/SAM2.1 状态、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry |
|
||||
| R3 媒体上传与拆帧 | `src/components/ProjectLibrary.test.tsx`, `src/components/TransientNotice.test.tsx`, `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 视频导入不自动拆帧、显式生成帧 FPS 选择、项目卡片显示目标 parse_fps 而非原视频 FPS、扩展名校验、自动建项目、关联项目、创建异步任务、非阻塞自动消失操作提示、标准帧序列参数、帧时间戳/源帧号、任务序列元数据、worker 注册帧、取消任务、重试任务、取消后 worker 停止 |
|
||||
| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧项目不自动解析并提示生成帧、工作区短状态自动消失、工作区/AI 画布底图默认居中且保留边距、回显已保存标注时保留本地未保存 draft mask、缩略图/range/视频处理进度条、视频处理进度条点击跳帧、人工/AI 标注帧红色竖线和标识点击跳帧、自动传播帧蓝色区段和标识点击跳帧、缩略图红/蓝边框、当前人工/AI 标注帧青色外框加红色内描边、普通状态不显示传播范围黄色选区、播放进度条和视频处理进度条选择传播范围、当前帧由播放进度条末端和缩略图青色高亮表达/左右方向键切帧、播放、按项目 FPS 显示当前/总时长 |
|
||||
| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/store/useStore.test.ts` | 工具切换、工具栏紧凑垂直布局和高度不足时滚动、调整多边形工具、AI 跳转、矩形/圆/线/点/多边形手工 mask 绘制、点工具在已有 mask 上落点、多边形 Enter/首节点闭合、上下文提示提示 Enter/Esc/首节点闭合、polygon 顶点直接拖动/删除、边中点插点、双击边界按位置插点、整块 mask 删除、区域合并/去除、布尔选择主区域/扣除区域视觉区分和选择顺序提示、内含去除 hole 渲染、合并模式隐藏编辑手柄、工作区 SAM 提示点点击删除且不冒泡新增点、工作区顶栏撤销/重做按钮、撤销/重做快捷键和输入框快捷键跳过、撤销/重做历史栈 |
|
||||
| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam2_engine.py` | SAM 2.1 变体选择、点/框/interactive 契约、semantic 禁用、SAM 3 入口隐藏和后端拒绝、SAM 2.1 最高分候选去重、SAM 2.1 框选后正负点细化同一候选 mask、AI 页框选发送 box prompt、AI 页框选后加点发送 interactive prompt、AI 页提示工具上下文提示、AI 页重复执行替换旧候选、SAM 2.1 反向点启用背景过滤且空结果移除旧候选、AI 页不渲染工作区已有 mask、AI 页可在候选 mask 上继续添加正/反点、AI 页可单点删除提示点并删除最近锚点、AI 页可删除选中候选且不删除工作区 mask、AI 页清空只移除本页候选、AI 页参数开关可读性文案且 options 字段不变、AI 页遮罩清晰度只改预览 opacity、AI 页生成 mask 自动选中并可通过分类树换标签、AI 页推送到工作区编辑保留选择、SAM 2.1 视频以当前参考帧全部 mask 和起止帧范围自动传播、传播前独立选择 SAM 2.1 tiny/small/base+/large 权重、自动传播创建 Celery 任务、传播入队权重 id 规范化/拒绝不支持 id、传播 seed 来源 id/签名 metadata、未改变 seed 跳过、已改变 seed 先删旧自动传播标注再重传、传播中轮询任务进度、传播任务取消/重试、传播来源 metadata 回显、空提示/空结果反馈、GPU/SAM2.1 状态、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry |
|
||||
| R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、保存后用后端 saved annotation 替换已提交 draft、加载回显、更新 dirty 标注、清空删除已保存标注、GT mask 多类别导入、seed point 回显/归一化、项目不存在、帧不存在 |
|
||||
| R8 模板库 | `src/components/TemplateRegistry.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_templates.py` | 前端模板加载/新建/编辑/删除、JSON 分类导入、mapping_rules 解包/打包、后端模板 CRUD |
|
||||
| R8 模板库 | `src/components/TemplateRegistry.test.tsx`, `src/components/TransientNotice.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_templates.py` | 前端模板加载/新建/编辑/删除、JSON 分类导入、JSON/保存错误非阻塞提示、mapping_rules 解包/打包、后端模板 CRUD |
|
||||
| R9 本体检查面板 | `src/components/OntologyInspector.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts`, `backend/tests/test_ai.py` | 模板选择、分类展示、具体分类选择、Canvas 选区同步、点击分类给已选 mask 换标签并移动到前端渲染最上层、自定义分类 PATCH 后端模板、选中 mask 后端属性分析、重新提取拓扑锚点 |
|
||||
| R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py`, `backend/tests/test_tasks.py` | 后端概览接口、任务表驱动进度区、最近完成任务保留显示、任务取消/重试/详情、cancelled 事件、Redis 进度事件 payload/发布、地址推导、消息订阅、连接状态回调、队列更新、heartbeat |
|
||||
| R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO/PNG 按钮下载、导出前自动保存、导出路径、JSON 结构、mask ZIP、zIndex 语义融合 |
|
||||
@@ -34,13 +34,13 @@
|
||||
|------|--------|----------|----------|
|
||||
| R1 | 登录页、默认开发凭证、token 写入、失败提示、后端 401 | `Login.test.tsx`, `test_auth.py` | 已覆盖 |
|
||||
| R2 | 项目列表/创建/选择、视频导入、DICOM 导入、后端项目和帧 CRUD | `ProjectLibrary.test.tsx`, `api.test.ts`, `test_projects.py` | 已覆盖 |
|
||||
| R3 | 文件类型校验、自动/指定项目上传、视频导入与生成帧分离、显式 FPS 生成帧、视频/DICOM 拆帧任务、`parse_fps/max_frames/target_width`、标准帧序列 metadata、任务查询、取消、重试、worker 取消停止 | `ProjectLibrary.test.tsx`, `test_media.py`, `test_tasks.py` | 已覆盖 |
|
||||
| R4 | 工作区加载帧、无帧项目不自动解析、后端标注回显保留本地未保存 draft mask、Canvas 底图、缩略图/range/自动传播帧浅蓝进度条区段标记、当前帧由进度条末端和缩略图高亮表达/左右方向键切帧、播放、按 FPS 显示时间 | `VideoWorkspace.test.tsx`, `FrameTimeline.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 |
|
||||
| R5 | 工具切换、调整多边形入口、AI 跳转、矩形/圆/线/点/多边形绘制、已有 mask 上继续绘制、多边形和布尔工具上下文提示 | `ToolsPalette.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 |
|
||||
| R5 | 顶点直接拖动编辑、边中点插点、双击边界按位置插点、顶点删除、整块删除、工作区 SAM 提示点删除优先级、撤销/重做、区域合并、区域去除、布尔选择主区域黄色实线/扣除区域红色虚线、布尔选择顺序提示、hole even-odd 渲染 | `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 |
|
||||
| R6 | SAM 2.1 变体选择、点/框/interactive、semantic 禁用、SAM 3 入口隐藏和后端拒绝、SAM 2.1 最高分候选去重、AI 页框选/框选后加点、AI 页提示工具上下文提示、AI 页重复执行替换旧候选、AI 页不渲染工作区已有 mask、AI 页可在候选 mask 上继续添加正/反点、AI 页可删除提示点、AI 页可删除选中候选、AI 页清空只移除本页候选、AI 页遮罩清晰度只改预览 opacity、AI 页生成 mask 自动选中并可换标签、AI 页推送到工作区编辑保留选择、SAM 2.1 视频按参考帧全部 mask 和范围自动传播、GPU/模型状态、参数 options、polygons 转 mask | `api.test.ts`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx`, `VideoWorkspace.test.tsx`, `ModelStatusBadge.test.tsx`, `test_ai.py`, `test_sam2_engine.py` | 已覆盖 |
|
||||
| R3 | 文件类型校验、自动/指定项目上传、视频导入与生成帧分离、显式 FPS 生成帧、项目卡片 FPS 徽标显示 `parse_fps`、视频/DICOM 拆帧任务、非阻塞自动消失操作提示、`parse_fps/max_frames/target_width`、标准帧序列 metadata、任务查询、取消、重试、worker 取消停止 | `ProjectLibrary.test.tsx`, `TransientNotice.test.tsx`, `api.test.ts`, `test_media.py`, `test_tasks.py` | 已覆盖 |
|
||||
| R4 | 工作区加载帧、无帧项目不自动解析、工作区短状态自动消失、后端标注回显保留本地未保存 draft mask、Canvas/AI 底图居中适配且保留边距、缩略图/range/视频处理进度条、视频处理进度条点击跳帧、人工/AI 标注帧红色竖线和标识点击跳帧、自动传播帧蓝色区段和标识点击跳帧、缩略图红/蓝边框、当前人工/AI 标注帧青色外框加红色内描边、普通状态不显示传播范围黄色选区、播放进度条/视频处理进度条拖拽选择传播范围、Canvas/AI 画布拖拽平移回写 position state、当前帧由播放进度条末端和缩略图青色高亮表达/左右方向键切帧、播放、按 FPS 显示时间 | `VideoWorkspace.test.tsx`, `FrameTimeline.test.tsx`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx` | 已覆盖 |
|
||||
| R5 | 工具切换、工具栏紧凑滚动布局、调整多边形入口、AI 跳转、矩形/圆/线/点/多边形绘制、已有 mask 上继续绘制、多边形和布尔工具上下文提示 | `ToolsPalette.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 |
|
||||
| R5 | 顶点直接拖动编辑、边中点插点、双击边界按位置插点、顶点删除、整块删除、工作区 SAM 提示点删除优先级、工作区顶栏撤销/重做按钮、撤销/重做快捷键、区域合并、区域去除、布尔选择主区域黄色实线/扣除区域红色虚线、布尔选择顺序提示、hole even-odd 渲染 | `CanvasArea.test.tsx`, `VideoWorkspace.test.tsx`, `useStore.test.ts` | 已覆盖 |
|
||||
| R6 | SAM 2.1 变体选择、点/框/interactive、semantic 禁用、SAM 3 入口隐藏和后端拒绝、SAM 2.1 最高分候选去重、AI 页框选/框选后加点、AI 页提示工具上下文提示、AI 页重复执行替换旧候选、AI 页不渲染工作区已有 mask、AI 页可在候选 mask 上继续添加正/反点、AI 页可删除提示点、AI 页可删除选中候选、AI 页清空只移除本页候选、AI 页遮罩清晰度只改预览 opacity、AI 页生成 mask 自动选中并可换标签、AI 页推送到工作区编辑保留选择、SAM 2.1 视频按参考帧全部 mask 和范围自动传播、传播前独立选择 SAM 2.1 tiny/small/base+/large 权重、自动传播 Celery 任务入队、传播入队权重 id 规范化/拒绝不支持 id、传播 seed 来源 id/签名 metadata、未改变 seed 跳过、已改变 seed 先删旧自动传播标注再重传、前端任务轮询进度、传播任务 runner 保存标注和结果权重 id、传播任务重试、传播空结果提示、GPU/模型状态、参数 options、polygons 转 mask | `api.test.ts`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx`, `VideoWorkspace.test.tsx`, `ModelStatusBadge.test.tsx`, `test_ai.py`, `test_tasks.py`, `test_sam2_engine.py` | 已覆盖 |
|
||||
| R7 | 保存、保存后替换已提交 draft、查询、更新、删除标注、工作区回显、清空已保存标注、GT mask 导入和 seed point 回写 | `VideoWorkspace.test.tsx`, `CanvasArea.test.tsx`, `api.test.ts`, `test_ai.py` | 已覆盖 |
|
||||
| R8 | 模板加载、新建、编辑、删除、JSON 分类导入、mapping_rules 映射、后端 CRUD | `TemplateRegistry.test.tsx`, `api.test.ts`, `test_templates.py` | 已覆盖 |
|
||||
| R8 | 模板加载、新建、编辑、删除、JSON 分类导入、JSON/保存错误非阻塞提示、mapping_rules 映射、后端 CRUD | `TemplateRegistry.test.tsx`, `TransientNotice.test.tsx`, `api.test.ts`, `test_templates.py` | 已覆盖 |
|
||||
| R9 | 模板选择、分类展示、分类选择、已选 mask 换标签并置顶显示、自定义分类写入后端模板、后端属性分析、占位状态 | `OntologyInspector.test.tsx`, `CanvasArea.test.tsx`, `useStore.test.ts`, `test_ai.py` | 已覆盖 |
|
||||
| R10 | Dashboard 概览、任务进度区、最近完成任务保留显示、活动日志、WebSocket progress/complete/error/status/cancelled、取消/重试/详情、连接状态回调、heartbeat | `Dashboard.test.tsx`, `websocket.test.ts`, `test_dashboard.py`, `test_main.py`, `test_progress_events.py`, `test_tasks.py` | 已覆盖 |
|
||||
| R11 | COCO/PNG ZIP 导出、导出前保存、路径和 JSON/ZIP 结构、zIndex 融合 | `VideoWorkspace.test.tsx`, `api.test.ts`, `test_export.py` | 已覆盖 |
|
||||
@@ -56,7 +56,10 @@
|
||||
- R6:补充 `ModelStatusBadge.test.tsx` 中 SAM 3 不展示测试,避免禁用入口重新出现在前端。
|
||||
- R6:补充后端 `selected_model=sam3` 拒绝测试和 semantic 禁用测试,避免后端继续暴露 SAM 3 产品能力。
|
||||
- R6:补充 `POST /api/ai/propagate` 后端测试,验证 seed mask 传播结果会保存为后续帧标注并保留 class 元数据。
|
||||
- R6:补充 `propagateMasks()` API 封装和 `VideoWorkspace` 自动传播按钮测试,验证当前参考帧全部 mask 会按范围发送到后端视频传播接口。
|
||||
- R6:补充 `propagateMasks()` 同步兼容接口和 `queuePropagationTask()` 任务接口测试,验证当前参考帧全部 mask 会按范围组装为后台传播 steps。
|
||||
- R6:补充 `VideoWorkspace` 自动传播进度测试,验证传播任务运行中显示进度,后端返回 0 个新区域时给出明确反馈。
|
||||
- R4/R6:补充时间轴传播范围选择测试,验证点击“自动传播”后可在播放进度条或视频处理进度条上拖拽回填起止帧,再提交后台传播任务。
|
||||
- R6/R10:补充 `queuePropagationTask()`、`POST /api/ai/propagate/task`、传播 Celery runner 和传播任务重试测试,验证工作区自动传播不再依赖长 HTTP 请求,并验证传给 `SAM2VideoPredictor` 的临时帧文件名是纯数字序列。
|
||||
- R6:`backend/tests/test_sam3_engine.py` 已标记跳过,仅作为历史保留实现的参考测试,不计入当前产品功能覆盖。
|
||||
- R3:补充 `parseMedia()` 查询参数和后端拆帧任务 payload 测试,验证 `parse_fps`、`max_frames`、`target_width` 会进入任务。
|
||||
- R3:补充 worker 注册标准帧序列测试,验证帧 `timestamp_ms`、`source_frame_number` 和 `result.frame_sequence` 元数据。
|
||||
|
||||
490
doc/10-installation.md
Normal file
490
doc/10-installation.md
Normal file
@@ -0,0 +1,490 @@
|
||||
# Installation / 部署安装指南
|
||||
|
||||
本文件记录当前仓库的真实安装和部署方式。它面向一台新的 Linux 机器,目标是跑起完整系统:
|
||||
|
||||
- React 前端:默认 `http://localhost:3000`
|
||||
- FastAPI 后端:默认 `http://localhost:8000`
|
||||
- PostgreSQL:项目、帧、模板、标注、任务元数据
|
||||
- Redis:Celery broker/result backend 与进度 pub/sub
|
||||
- MinIO:视频、DICOM、拆帧图片等对象存储
|
||||
- Celery worker:执行视频/DICOM 拆帧等后台任务
|
||||
- SAM 2.1:当前产品启用 tiny/small/base+/large;SAM 3 源码保留但产品入口禁用,正常部署不需要安装 SAM 3
|
||||
|
||||
---
|
||||
|
||||
## 1. 前置条件
|
||||
|
||||
推荐环境:
|
||||
|
||||
| 项 | 建议 |
|
||||
|----|------|
|
||||
| OS | Ubuntu 22.04 LTS 或相近 Linux |
|
||||
| Python | 3.11 |
|
||||
| Node.js | 22.x |
|
||||
| 数据库 | PostgreSQL 14+ |
|
||||
| 缓存/队列 | Redis 6+ |
|
||||
| 对象存储 | MinIO |
|
||||
| 视频处理 | FFmpeg |
|
||||
| GPU | NVIDIA GPU + CUDA,用于 SAM 2.1 推理;无 GPU 时可 CPU 运行但会很慢 |
|
||||
|
||||
安装系统依赖:
|
||||
|
||||
```bash
|
||||
sudo apt update
|
||||
sudo apt install -y \
|
||||
postgresql postgresql-contrib \
|
||||
redis-server \
|
||||
ffmpeg \
|
||||
libpq-dev build-essential curl ca-certificates gnupg wget
|
||||
```
|
||||
|
||||
安装 MinIO:
|
||||
|
||||
```bash
|
||||
cd /tmp
|
||||
wget https://dl.min.io/server/minio/release/linux-amd64/minio
|
||||
chmod +x minio
|
||||
sudo mv minio /usr/local/bin/
|
||||
mkdir -p ~/minio_data
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. 获取代码
|
||||
|
||||
```bash
|
||||
cd ~/Desktop
|
||||
git clone <your-gitea-or-git-url> Seg_Server
|
||||
cd Seg_Server
|
||||
```
|
||||
|
||||
如果已经有仓库,进入项目根目录即可:
|
||||
|
||||
```bash
|
||||
cd /home/wkmgc/Desktop/Seg_Server
|
||||
```
|
||||
|
||||
后续命令默认在项目根目录执行,除非特别说明。
|
||||
|
||||
---
|
||||
|
||||
## 3. 配置 PostgreSQL
|
||||
|
||||
默认后端配置来自 `backend/config.py`:
|
||||
|
||||
```text
|
||||
postgresql://seguser:segpass123@localhost:5432/segserver
|
||||
```
|
||||
|
||||
创建数据库和用户:
|
||||
|
||||
```bash
|
||||
sudo systemctl start postgresql
|
||||
|
||||
sudo -u postgres psql -c "CREATE DATABASE segserver;"
|
||||
sudo -u postgres psql -c "CREATE USER seguser WITH PASSWORD 'segpass123';"
|
||||
sudo -u postgres psql -c "GRANT ALL PRIVILEGES ON DATABASE segserver TO seguser;"
|
||||
sudo -u postgres psql -d segserver -c "GRANT ALL ON SCHEMA public TO seguser;"
|
||||
sudo -u postgres psql -c "ALTER DATABASE segserver OWNER TO seguser;"
|
||||
```
|
||||
|
||||
验收:
|
||||
|
||||
```bash
|
||||
pg_isready
|
||||
psql "postgresql://seguser:segpass123@localhost:5432/segserver" -c "select 1;"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. 启动 Redis 和 MinIO
|
||||
|
||||
Redis:
|
||||
|
||||
```bash
|
||||
sudo systemctl start redis-server
|
||||
redis-cli ping
|
||||
```
|
||||
|
||||
MinIO:
|
||||
|
||||
```bash
|
||||
nohup minio server ~/minio_data --console-address :9001 > /tmp/minio.log 2>&1 &
|
||||
curl http://localhost:9000/minio/health/live
|
||||
```
|
||||
|
||||
默认 MinIO 账号密码是:
|
||||
|
||||
```text
|
||||
minioadmin / minioadmin
|
||||
```
|
||||
|
||||
后端启动时会检查并创建 bucket:
|
||||
|
||||
```text
|
||||
seg-media
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. 安装后端 Python 环境
|
||||
|
||||
推荐使用 Conda:
|
||||
|
||||
```bash
|
||||
conda create -n seg_server python=3.11 -y
|
||||
conda activate seg_server
|
||||
```
|
||||
|
||||
安装 PyTorch。根据机器 CUDA 版本选择合适 wheel。示例:
|
||||
|
||||
```bash
|
||||
# CUDA 12.4 示例
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
|
||||
|
||||
# 无 GPU / CPU 示例
|
||||
# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
```
|
||||
|
||||
安装后端依赖:
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
pip install -r requirements.txt
|
||||
cd ..
|
||||
```
|
||||
|
||||
确认关键包:
|
||||
|
||||
```bash
|
||||
python - <<'PY'
|
||||
import fastapi, sqlalchemy, redis, celery, minio, torch
|
||||
print("torch:", torch.__version__, "cuda:", torch.cuda.is_available())
|
||||
PY
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. 配置后端环境变量
|
||||
|
||||
后端从 `backend/.env` 读取配置;该文件被 `.gitignore` 忽略,不要提交真实密码或本机路径。
|
||||
|
||||
创建 `backend/.env`:
|
||||
|
||||
```bash
|
||||
cat > backend/.env <<'EOF'
|
||||
db_url=postgresql://seguser:segpass123@localhost:5432/segserver
|
||||
redis_url=redis://localhost:6379/0
|
||||
|
||||
minio_endpoint=localhost:9000
|
||||
minio_access_key=minioadmin
|
||||
minio_secret_key=minioadmin
|
||||
minio_secure=false
|
||||
|
||||
sam_default_model=sam2.1_hiera_tiny
|
||||
sam_model_path=/home/wkmgc/Desktop/Seg_Server/models/sam2.1_hiera_tiny.pt
|
||||
sam_model_config=configs/sam2.1/sam2.1_hiera_t.yaml
|
||||
|
||||
app_env=development
|
||||
cors_origins=["http://localhost:3000","http://127.0.0.1:3000"]
|
||||
EOF
|
||||
```
|
||||
|
||||
如果前端通过局域网 IP 访问,例如 `http://192.168.3.11:3000`,需要把该地址加入 `cors_origins`,同时前端也要配置 API 地址。
|
||||
|
||||
---
|
||||
|
||||
## 7. 准备 SAM 2.1 权重
|
||||
|
||||
当前产品入口只暴露 SAM 2.1 变体:
|
||||
|
||||
- `sam2.1_hiera_tiny`
|
||||
- `sam2.1_hiera_small`
|
||||
- `sam2.1_hiera_base_plus`
|
||||
- `sam2.1_hiera_large`
|
||||
|
||||
下载脚本:
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
python download_sam2.py
|
||||
cd ..
|
||||
```
|
||||
|
||||
脚本默认下载到:
|
||||
|
||||
```text
|
||||
/home/wkmgc/Desktop/Seg_Server/models/
|
||||
```
|
||||
|
||||
推荐文件名:
|
||||
|
||||
```text
|
||||
models/sam2.1_hiera_tiny.pt
|
||||
models/sam2.1_hiera_small.pt
|
||||
models/sam2.1_hiera_base_plus.pt
|
||||
models/sam2.1_hiera_large.pt
|
||||
```
|
||||
|
||||
可以只部署 tiny;前端会显示四个选项,但只有本地存在 checkpoint 的模型会显示可用。
|
||||
|
||||
注意:SAM 3 相关脚本和源码是历史保留。当前前端入口隐藏 SAM 3,后端 registry 不暴露 `sam3`,正常部署不需要下载 SAM 3 权重,也不要把 Hugging Face token 写进项目文件。
|
||||
|
||||
---
|
||||
|
||||
## 8. 安装前端依赖
|
||||
|
||||
```bash
|
||||
npm install
|
||||
```
|
||||
|
||||
如需指定前端访问的后端地址,在项目根目录创建 `.env`:
|
||||
|
||||
```bash
|
||||
cat > .env <<'EOF'
|
||||
VITE_API_BASE_URL=http://localhost:8000
|
||||
VITE_WS_PROGRESS_URL=ws://localhost:8000/ws/progress
|
||||
EOF
|
||||
```
|
||||
|
||||
如果不设置,前端会按当前浏览器 hostname 推导:
|
||||
|
||||
```text
|
||||
http://<browser-host>:8000
|
||||
ws://<browser-host>:8000/ws/progress
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 9. 手动启动所有服务
|
||||
|
||||
开 4 个终端分别启动。
|
||||
|
||||
终端 A:FastAPI 后端
|
||||
|
||||
```bash
|
||||
conda activate seg_server
|
||||
cd /home/wkmgc/Desktop/Seg_Server/backend
|
||||
uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
```
|
||||
|
||||
终端 B:Celery worker
|
||||
|
||||
```bash
|
||||
conda activate seg_server
|
||||
cd /home/wkmgc/Desktop/Seg_Server/backend
|
||||
celery -A celery_app:celery_app worker --loglevel=info --concurrency=1
|
||||
```
|
||||
|
||||
终端 C:前端开发服务
|
||||
|
||||
```bash
|
||||
cd /home/wkmgc/Desktop/Seg_Server
|
||||
npm run dev
|
||||
```
|
||||
|
||||
终端 D:确认基础设施
|
||||
|
||||
```bash
|
||||
pg_isready
|
||||
redis-cli ping
|
||||
curl http://localhost:9000/minio/health/live
|
||||
```
|
||||
|
||||
访问:
|
||||
|
||||
| 服务 | 地址 |
|
||||
|------|------|
|
||||
| 前端 | `http://localhost:3000` |
|
||||
| FastAPI Docs | `http://localhost:8000/docs` |
|
||||
| Health | `http://localhost:8000/health` |
|
||||
| MinIO Console | `http://localhost:9001` |
|
||||
|
||||
默认开发登录:
|
||||
|
||||
```text
|
||||
admin / 123456
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 10. 一键启动脚本
|
||||
|
||||
项目根目录有 `start_services.sh`:
|
||||
|
||||
```bash
|
||||
chmod +x start_services.sh
|
||||
./start_services.sh
|
||||
```
|
||||
|
||||
脚本会检查/启动:
|
||||
|
||||
```text
|
||||
PostgreSQL -> Redis -> MinIO -> FastAPI -> Celery worker -> 前端
|
||||
```
|
||||
|
||||
使用前必须检查脚本里的本机路径和 sudo 逻辑:
|
||||
|
||||
- `PROJECT_DIR="/home/wkmgc/Desktop/Seg_Server"`
|
||||
- `CONDA_ENV="seg_server"`
|
||||
- MinIO 数据目录 `/home/wkmgc/minio_data`
|
||||
- 脚本里包含本机 sudo 密码写法,迁移机器时应移除或改成安全的 systemd/service 管理方式
|
||||
|
||||
---
|
||||
|
||||
## 11. 生产构建方式
|
||||
|
||||
前端构建:
|
||||
|
||||
```bash
|
||||
npm run build
|
||||
```
|
||||
|
||||
生产模式启动前端静态服务:
|
||||
|
||||
```bash
|
||||
NODE_ENV=production npm start
|
||||
```
|
||||
|
||||
后端生产启动示例:
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
uvicorn main:app --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
Celery worker 仍需要单独启动:
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
celery -A celery_app:celery_app worker --loglevel=info --concurrency=1
|
||||
```
|
||||
|
||||
实际生产建议用 systemd、supervisor 或容器编排托管 FastAPI、Celery、前端静态服务、MinIO、Redis、PostgreSQL。
|
||||
|
||||
---
|
||||
|
||||
## 12. 部署验收 Checklist
|
||||
|
||||
基础服务:
|
||||
|
||||
```bash
|
||||
pg_isready
|
||||
redis-cli ping
|
||||
curl http://localhost:9000/minio/health/live
|
||||
curl http://localhost:8000/health
|
||||
```
|
||||
|
||||
后端模型状态:
|
||||
|
||||
```bash
|
||||
curl http://localhost:8000/api/ai/models/status
|
||||
```
|
||||
|
||||
前端质量检查:
|
||||
|
||||
```bash
|
||||
npm run lint
|
||||
npm run test:run
|
||||
npm run build
|
||||
```
|
||||
|
||||
后端测试:
|
||||
|
||||
```bash
|
||||
conda activate seg_server
|
||||
python -m pytest backend/tests
|
||||
```
|
||||
|
||||
手工业务验收:
|
||||
|
||||
1. 打开 `http://localhost:3000`。
|
||||
2. 使用 `admin / 123456` 登录。
|
||||
3. 创建项目或上传视频。
|
||||
4. 在项目库点击“生成帧”,选择 FPS。
|
||||
5. Dashboard 中应看到任务进度;Celery 日志应显示拆帧任务。
|
||||
6. 进入分割工作区,能看到帧、时间轴和画布。
|
||||
7. 手工画一个多边形 mask,点击“结构化归档保存”。
|
||||
8. 刷新工作区后,已保存标注应回显。
|
||||
9. AI 智能分割中选择可用 SAM 2.1 模型,放置点或框,执行分割。
|
||||
10. 导出 JSON 或 PNG Mask ZIP。
|
||||
|
||||
---
|
||||
|
||||
## 13. 常见问题
|
||||
|
||||
### 前端打不开或请求后端失败
|
||||
|
||||
检查:
|
||||
|
||||
```bash
|
||||
curl http://localhost:8000/health
|
||||
cat .env
|
||||
```
|
||||
|
||||
如果通过局域网 IP 访问前端,确保:
|
||||
|
||||
- `.env` 中 `VITE_API_BASE_URL` 是浏览器可访问的后端地址。
|
||||
- `backend/.env` 中 `cors_origins` 包含前端地址。
|
||||
|
||||
### Dashboard WebSocket 经常断开
|
||||
|
||||
检查:
|
||||
|
||||
```bash
|
||||
redis-cli ping
|
||||
curl http://localhost:8000/health
|
||||
```
|
||||
|
||||
同时确认前端 `VITE_WS_PROGRESS_URL` 指向真实可访问的:
|
||||
|
||||
```text
|
||||
ws://<host>:8000/ws/progress
|
||||
```
|
||||
|
||||
### 生成帧没有进度
|
||||
|
||||
检查 Celery worker 是否启动:
|
||||
|
||||
```bash
|
||||
ps aux | grep celery
|
||||
tail -f /tmp/celery.log
|
||||
```
|
||||
|
||||
检查 Redis:
|
||||
|
||||
```bash
|
||||
redis-cli ping
|
||||
```
|
||||
|
||||
### MinIO 上传失败
|
||||
|
||||
检查:
|
||||
|
||||
```bash
|
||||
curl http://localhost:9000/minio/health/live
|
||||
tail -f /tmp/minio.log
|
||||
```
|
||||
|
||||
如果磁盘空间不足,MinIO 可能拒绝写入。清理 `~/minio_data`、旧日志、旧模型权重或迁移数据目录。
|
||||
|
||||
### SAM 2 模型不可用
|
||||
|
||||
检查:
|
||||
|
||||
```bash
|
||||
ls -lh models/
|
||||
curl http://localhost:8000/api/ai/models/status
|
||||
```
|
||||
|
||||
常见原因:
|
||||
|
||||
- checkpoint 文件不存在。
|
||||
- `backend/.env` 中 `sam_model_path` 指向旧文件名。
|
||||
- `sam2` Python 包未正确安装。
|
||||
- PyTorch/CUDA 不匹配。
|
||||
|
||||
### 不需要 SAM 3
|
||||
|
||||
当前版本不用 SAM 3。不要为了正常部署执行 `backend/setup_sam3_env.sh`,也不要在项目里保存 Hugging Face token。
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
| [07-current-requirements-freeze.md](./07-current-requirements-freeze.md) | 当前版本需求冻结,测试以此为准 |
|
||||
| [08-current-design-freeze.md](./08-current-design-freeze.md) | 当前版本设计冻结,记录模块、数据流和接口边界 |
|
||||
| [09-test-plan.md](./09-test-plan.md) | 需求到测试文件的覆盖矩阵和运行命令 |
|
||||
| [10-installation.md](./10-installation.md) | 系统安装部署指南,覆盖 PostgreSQL、Redis、MinIO、后端、Celery、前端和 SAM 2.1 权重 |
|
||||
|
||||
## 状态标记
|
||||
|
||||
|
||||
@@ -169,6 +169,26 @@ describe('AISegmentation', () => {
|
||||
}));
|
||||
});
|
||||
|
||||
it('handles stage drag end for move-tool canvas panning', () => {
|
||||
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||
|
||||
expect(screen.getByTestId('konva-stage')).toHaveAttribute('data-has-drag-end', 'true');
|
||||
});
|
||||
|
||||
it('centers the active frame with a large default fit inside the AI canvas', async () => {
|
||||
Object.defineProperty(HTMLElement.prototype, 'clientWidth', { configurable: true, get: () => 1000 });
|
||||
Object.defineProperty(HTMLElement.prototype, 'clientHeight', { configurable: true, get: () => 700 });
|
||||
|
||||
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||
|
||||
await waitFor(() => {
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
expect(Number(stage.getAttribute('data-scale-x'))).toBeCloseTo(1.34375, 4);
|
||||
expect(Number(stage.getAttribute('data-x'))).toBeCloseTo(70, 0);
|
||||
expect(Number(stage.getAttribute('data-y'))).toBeCloseTo(108, 0);
|
||||
});
|
||||
});
|
||||
|
||||
it('combines the AI page box prompt with later positive and negative refinement points', async () => {
|
||||
apiMock.predictMask.mockResolvedValueOnce({ masks: [] });
|
||||
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useState, useCallback, useEffect } from 'react';
|
||||
import React, { useState, useCallback, useEffect, useRef } from 'react';
|
||||
import { Target, PlusCircle, MinusCircle, SquareDashed, Sparkles, SendToBack, Undo, Redo, Loader2, XCircle, Trash2 } from 'lucide-react';
|
||||
import { cn } from '../lib/utils';
|
||||
import { Stage, Layer, Image as KonvaImage, Circle, Path, Group, Rect } from 'react-konva';
|
||||
@@ -14,8 +14,10 @@ interface AISegmentationProps {
|
||||
type PromptPoint = { x: number; y: number; type: 'pos' | 'neg' };
|
||||
type PromptBox = { x1: number; y1: number; x2: number; y2: number };
|
||||
type ToolHint = { title: string; body: string };
|
||||
const DEFAULT_IMAGE_FIT_RATIO = 0.86;
|
||||
|
||||
export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
const canvasContainerRef = useRef<HTMLDivElement>(null);
|
||||
const storeActiveTool = useStore((state) => state.activeTool);
|
||||
const setActiveTool = useStore((state) => state.setActiveTool);
|
||||
const masks = useStore((state) => state.masks);
|
||||
@@ -42,8 +44,10 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
const [aiMaskIds, setAiMaskIds] = useState<string[]>([]);
|
||||
|
||||
// Canvas state
|
||||
const [stageSize, setStageSize] = useState({ width: 800, height: 600 });
|
||||
const [scale, setScale] = useState(1);
|
||||
const [position, setPosition] = useState({ x: 0, y: 0 });
|
||||
const lastAutoFitKeyRef = useRef('');
|
||||
const [points, setPoints] = useState<PromptPoint[]>([]);
|
||||
const [promptBox, setPromptBox] = useState<PromptBox | null>(null);
|
||||
const [boxStart, setBoxStart] = useState<{ x: number; y: number } | null>(null);
|
||||
@@ -59,6 +63,42 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
const modelCanInfer = selectedModelStatus?.available ?? true;
|
||||
|
||||
const effectiveTool = storeActiveTool;
|
||||
|
||||
useEffect(() => {
|
||||
const handleResize = () => {
|
||||
if (!canvasContainerRef.current) return;
|
||||
setStageSize({
|
||||
width: canvasContainerRef.current.clientWidth,
|
||||
height: canvasContainerRef.current.clientHeight,
|
||||
});
|
||||
};
|
||||
|
||||
handleResize();
|
||||
window.addEventListener('resize', handleResize);
|
||||
return () => window.removeEventListener('resize', handleResize);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (!currentFrame?.id || stageSize.width <= 0 || stageSize.height <= 0) return;
|
||||
const imageWidth = currentFrame.width || image?.naturalWidth || image?.width || 0;
|
||||
const imageHeight = currentFrame.height || image?.naturalHeight || image?.height || 0;
|
||||
if (imageWidth <= 0 || imageHeight <= 0) return;
|
||||
|
||||
const fitKey = `${currentFrame.id}:${stageSize.width}x${stageSize.height}:${imageWidth}x${imageHeight}`;
|
||||
if (lastAutoFitKeyRef.current === fitKey) return;
|
||||
lastAutoFitKeyRef.current = fitKey;
|
||||
|
||||
const nextScale = Math.max(
|
||||
0.05,
|
||||
Math.min(stageSize.width / imageWidth, stageSize.height / imageHeight) * DEFAULT_IMAGE_FIT_RATIO,
|
||||
);
|
||||
setScale(nextScale);
|
||||
setPosition({
|
||||
x: (stageSize.width - imageWidth * nextScale) / 2,
|
||||
y: (stageSize.height - imageHeight * nextScale) / 2,
|
||||
});
|
||||
}, [currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, stageSize.height, stageSize.width]);
|
||||
|
||||
const toolHint = React.useMemo<ToolHint | null>(() => {
|
||||
if (!currentFrame) return null;
|
||||
if (effectiveTool === 'point_pos') {
|
||||
@@ -146,6 +186,14 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
});
|
||||
};
|
||||
|
||||
const handleStageDragEnd = (e: any) => {
|
||||
const stage = e.target;
|
||||
setPosition({
|
||||
x: stage.x(),
|
||||
y: stage.y(),
|
||||
});
|
||||
};
|
||||
|
||||
const handleMouseMove = (e: any) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) return;
|
||||
@@ -557,7 +605,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
</header>
|
||||
|
||||
<div className="flex-1 relative p-8">
|
||||
<div className="w-full h-full relative border border-white/5 rounded shadow-2xl bg-[#1e1e1e] overflow-hidden cursor-crosshair">
|
||||
<div ref={canvasContainerRef} className="w-full h-full relative border border-white/5 rounded shadow-2xl bg-[#1e1e1e] overflow-hidden cursor-crosshair">
|
||||
{!currentFrame && (
|
||||
<div className="absolute inset-0 z-20 flex items-center justify-center bg-[#151515] text-xs text-gray-500">
|
||||
请先在项目库选择项目并生成帧
|
||||
@@ -570,8 +618,8 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
</div>
|
||||
)}
|
||||
<Stage
|
||||
width={window.innerWidth - 320 - 64}
|
||||
height={window.innerHeight - 64 - 64}
|
||||
width={stageSize.width}
|
||||
height={stageSize.height}
|
||||
onWheel={handleWheel}
|
||||
onMouseMove={handleMouseMove}
|
||||
onMouseDown={handleStageMouseDown}
|
||||
@@ -582,6 +630,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
x={position.x}
|
||||
y={position.y}
|
||||
draggable={effectiveTool === 'move'}
|
||||
onDragEnd={handleStageDragEnd}
|
||||
>
|
||||
<Layer>
|
||||
{/* Background Image */}
|
||||
|
||||
@@ -292,6 +292,26 @@ describe('CanvasArea', () => {
|
||||
expect(screen.getByText('遮罩数: 1')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('handles stage drag end when the move tool pans the canvas', () => {
|
||||
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||
|
||||
expect(screen.getByTestId('konva-stage')).toHaveAttribute('data-has-drag-end', 'true');
|
||||
});
|
||||
|
||||
it('centers the frame image with a large default fit that keeps margins', async () => {
|
||||
Object.defineProperty(HTMLElement.prototype, 'clientWidth', { configurable: true, get: () => 1000 });
|
||||
Object.defineProperty(HTMLElement.prototype, 'clientHeight', { configurable: true, get: () => 700 });
|
||||
|
||||
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||
|
||||
await waitFor(() => {
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
expect(Number(stage.getAttribute('data-scale-x'))).toBeCloseTo(1.34375, 4);
|
||||
expect(Number(stage.getAttribute('data-x'))).toBeCloseTo(70, 0);
|
||||
expect(Number(stage.getAttribute('data-y'))).toBeCloseTo(108, 0);
|
||||
});
|
||||
});
|
||||
|
||||
it('publishes the selected mask ids for the ontology panel', async () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
@@ -663,7 +683,10 @@ describe('CanvasArea', () => {
|
||||
expect(selectedPaths[0]).toHaveAttribute('data-stroke', '#facc15');
|
||||
expect(selectedPaths[0]).toHaveAttribute('data-dash', '');
|
||||
expect(selectedPaths[1]).toHaveAttribute('data-stroke', '#fb7185');
|
||||
expect(selectedPaths[1]).toHaveAttribute('data-dash', '6,4');
|
||||
const scale = Number(screen.getByTestId('konva-stage').getAttribute('data-scale-x')) || 1;
|
||||
const dash = selectedPaths[1].getAttribute('data-dash')?.split(',').map(Number);
|
||||
expect(dash?.[0]).toBeCloseTo(6 / scale, 4);
|
||||
expect(dash?.[1]).toBeCloseTo(4 / scale, 4);
|
||||
fireEvent.click(screen.getByRole('button', { name: '从主区域去除' }));
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(2);
|
||||
|
||||
@@ -24,6 +24,7 @@ const EDIT_POLYGON_TOOL = 'edit_polygon';
|
||||
const POINT_TOOL = 'create_point';
|
||||
const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']);
|
||||
const POLYGON_CLOSE_RADIUS = 8;
|
||||
const DEFAULT_IMAGE_FIT_RATIO = 0.86;
|
||||
|
||||
function clamp(value: number, min: number, max: number): number {
|
||||
return Math.min(Math.max(value, min), max);
|
||||
@@ -245,6 +246,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
const previousFrameIdRef = useRef<string | undefined>(frame?.id);
|
||||
const [isInferencing, setIsInferencing] = useState(false);
|
||||
const [inferenceMessage, setInferenceMessage] = useState('');
|
||||
const lastAutoFitKeyRef = useRef('');
|
||||
|
||||
const masks = useStore((state) => state.masks);
|
||||
const addMask = useStore((state) => state.addMask);
|
||||
@@ -256,8 +258,6 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
const aiModel = useStore((state) => state.aiModel);
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
const activeClass = useStore((state) => state.activeClass);
|
||||
const undoMasks = useStore((state) => state.undoMasks);
|
||||
const redoMasks = useStore((state) => state.redoMasks);
|
||||
|
||||
const effectiveTool = activeTool || storeActiveTool;
|
||||
|
||||
@@ -374,6 +374,27 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
return () => window.removeEventListener('resize', handleResize);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (!frame?.id || stageSize.width <= 0 || stageSize.height <= 0) return;
|
||||
const imageWidth = frame.width || image?.naturalWidth || image?.width || 0;
|
||||
const imageHeight = frame.height || image?.naturalHeight || image?.height || 0;
|
||||
if (imageWidth <= 0 || imageHeight <= 0) return;
|
||||
|
||||
const fitKey = `${frame.id}:${stageSize.width}x${stageSize.height}:${imageWidth}x${imageHeight}`;
|
||||
if (lastAutoFitKeyRef.current === fitKey) return;
|
||||
lastAutoFitKeyRef.current = fitKey;
|
||||
|
||||
const nextScale = Math.max(
|
||||
0.05,
|
||||
Math.min(stageSize.width / imageWidth, stageSize.height / imageHeight) * DEFAULT_IMAGE_FIT_RATIO,
|
||||
);
|
||||
setScale(nextScale);
|
||||
setPosition({
|
||||
x: (stageSize.width - imageWidth * nextScale) / 2,
|
||||
y: (stageSize.height - imageHeight * nextScale) / 2,
|
||||
});
|
||||
}, [frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, stageSize.height, stageSize.width]);
|
||||
|
||||
useEffect(() => {
|
||||
setManualStart(null);
|
||||
setManualCurrent(null);
|
||||
@@ -458,6 +479,14 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
});
|
||||
};
|
||||
|
||||
const handleStageDragEnd = (e: any) => {
|
||||
const stage = e.target;
|
||||
setPosition({
|
||||
x: stage.x(),
|
||||
y: stage.y(),
|
||||
});
|
||||
};
|
||||
|
||||
const stagePoint = (e: any): CanvasPoint | null => {
|
||||
const stage = e.target.getStage();
|
||||
const relPos = stage?.getRelativePointerPosition();
|
||||
@@ -839,18 +868,6 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
const key = event.key.toLowerCase();
|
||||
if ((event.metaKey || event.ctrlKey) && key === 'z') {
|
||||
event.preventDefault();
|
||||
if (event.shiftKey) redoMasks();
|
||||
else undoMasks();
|
||||
return;
|
||||
}
|
||||
if ((event.metaKey || event.ctrlKey) && key === 'y') {
|
||||
event.preventDefault();
|
||||
redoMasks();
|
||||
return;
|
||||
}
|
||||
if ((event.key === 'Delete' || event.key === 'Backspace') && selectedMask && selectedVertexIndex !== null) {
|
||||
const currentPoints = segmentationToPoints(selectedMask.segmentation, selectedPolygonIndex);
|
||||
if (currentPoints.length > 3) {
|
||||
@@ -880,7 +897,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
|
||||
window.addEventListener('keydown', handleKeyDown);
|
||||
return () => window.removeEventListener('keydown', handleKeyDown);
|
||||
}, [deleteMasksById, effectiveTool, finishPolygon, isPolygonEditTool, polygonPoints, redoMasks, selectedMask, selectedMaskIds, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
|
||||
}, [deleteMasksById, effectiveTool, finishPolygon, isPolygonEditTool, polygonPoints, selectedMask, selectedMaskIds, selectedPolygonIndex, selectedVertexIndex, updatePolygonMask]);
|
||||
|
||||
const boxRect = React.useMemo(() => {
|
||||
if (!boxStart || !boxCurrent) return null;
|
||||
@@ -1080,6 +1097,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
x={position.x}
|
||||
y={position.y}
|
||||
draggable={effectiveTool === 'move'}
|
||||
onDragEnd={handleStageDragEnd}
|
||||
onClick={handleStageClick}
|
||||
>
|
||||
<Layer>
|
||||
|
||||
@@ -42,7 +42,7 @@ export function Dashboard() {
|
||||
status: task.message || task.status,
|
||||
raw_status: task.status,
|
||||
error: task.error,
|
||||
frame_count: Number(task.result?.frames_extracted || 0),
|
||||
frame_count: Number(task.result?.frames_extracted || task.result?.processed_frame_count || 0),
|
||||
updated_at: task.updated_at,
|
||||
});
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ describe('FrameTimeline', () => {
|
||||
expect(screen.getAllByText('00:00.20').length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('marks propagated frames as light-blue progress bar segments', () => {
|
||||
it('renders a processing progress bar with red annotation markers and blue propagation segments', () => {
|
||||
useStore.setState({
|
||||
currentFrameIndex: 1,
|
||||
frames: [
|
||||
@@ -76,13 +76,162 @@ describe('FrameTimeline', () => {
|
||||
|
||||
render(<FrameTimeline />);
|
||||
|
||||
expect(screen.getByText('自动传播 1 帧')).toBeInTheDocument();
|
||||
expect(screen.getByLabelText('视频处理进度条')).toBeInTheDocument();
|
||||
expect(screen.getByText('人工/AI 1 帧 · 自动传播 1 帧')).toBeInTheDocument();
|
||||
expect(screen.queryByTestId('current-frame-line')).not.toBeInTheDocument();
|
||||
expect(screen.getAllByTestId('propagated-frame-segment')).toHaveLength(1);
|
||||
expect(screen.getByTestId('propagated-frame-segment').className).toContain('bg-sky-200');
|
||||
expect(screen.getByTestId('propagated-frame-segment').className).toContain('bg-blue-500');
|
||||
expect(screen.getAllByTestId('annotated-frame-marker')).toHaveLength(1);
|
||||
expect(screen.getByTestId('annotated-frame-marker').className).toContain('bg-red-500');
|
||||
expect(screen.queryByLabelText('跳转到已编辑帧 3')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('jumps from the processing progress bar and frame status markers', () => {
|
||||
useStore.setState({
|
||||
frames: [
|
||||
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
|
||||
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
|
||||
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
|
||||
{ id: 'f4', projectId: 'p1', index: 3, url: '/4.jpg', width: 640, height: 360 },
|
||||
{ id: 'f5', projectId: 'p1', index: 4, url: '/5.jpg', width: 640, height: 360 },
|
||||
],
|
||||
masks: [
|
||||
{ id: 'm1', frameId: 'f2', pathData: 'M 0 0 Z', label: 'Draft', color: '#ef4444' },
|
||||
{
|
||||
id: 'm2',
|
||||
frameId: 'f4',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: 'Tracked',
|
||||
color: '#3b82f6',
|
||||
metadata: { source: 'sam2.1_hiera_tiny_propagation' },
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<FrameTimeline />);
|
||||
|
||||
const processingBar = screen.getByLabelText('视频处理进度条');
|
||||
vi.spyOn(processingBar, 'getBoundingClientRect').mockReturnValue({
|
||||
left: 0,
|
||||
right: 100,
|
||||
top: 0,
|
||||
bottom: 10,
|
||||
width: 100,
|
||||
height: 10,
|
||||
x: 0,
|
||||
y: 0,
|
||||
toJSON: () => ({}),
|
||||
});
|
||||
fireEvent.pointerDown(processingBar, { clientX: 50, pointerId: 1 });
|
||||
expect(useStore.getState().currentFrameIndex).toBe(2);
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '跳转到人工/AI 标注帧 2' }));
|
||||
expect(useStore.getState().currentFrameIndex).toBe(1);
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '跳转到自动传播帧 4' }));
|
||||
expect(useStore.getState().currentFrameIndex).toBe(3);
|
||||
});
|
||||
|
||||
it('hides the propagation range overlay until range selection is active', () => {
|
||||
useStore.setState({
|
||||
frames: [
|
||||
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
|
||||
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
|
||||
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
|
||||
],
|
||||
});
|
||||
|
||||
render(<FrameTimeline propagationRange={{ startFrame: 1, endFrame: 3 }} />);
|
||||
|
||||
expect(screen.queryByTestId('propagation-range-overlay')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('uses red thumbnail borders for manual or AI frames and blue for propagated frames', () => {
|
||||
useStore.setState({
|
||||
frames: [
|
||||
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
|
||||
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
|
||||
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
|
||||
],
|
||||
masks: [
|
||||
{ id: 'm1', frameId: 'f2', pathData: 'M 0 0 Z', label: 'Draft', color: '#ef4444' },
|
||||
{
|
||||
id: 'm2',
|
||||
frameId: 'f3',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: 'Tracked',
|
||||
color: '#3b82f6',
|
||||
metadata: { propagated_from_frame_id: 'f1' },
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<FrameTimeline />);
|
||||
|
||||
expect(screen.getByAltText('frame-0').closest('div')?.className).toContain('border-cyan-500');
|
||||
expect(screen.getByAltText('frame-1').closest('div')?.className).toContain('border-red-500');
|
||||
expect(screen.getByAltText('frame-2').closest('div')?.className).toContain('border-blue-500');
|
||||
});
|
||||
|
||||
it('keeps the current frame blue border while showing an inner red ring for annotated frames', () => {
|
||||
useStore.setState({
|
||||
currentFrameIndex: 1,
|
||||
frames: [
|
||||
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
|
||||
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
|
||||
],
|
||||
masks: [
|
||||
{ id: 'm1', frameId: 'f2', pathData: 'M 0 0 Z', label: 'Draft', color: '#ef4444' },
|
||||
],
|
||||
});
|
||||
|
||||
render(<FrameTimeline />);
|
||||
|
||||
const currentAnnotatedTile = screen.getByAltText('frame-1').closest('div');
|
||||
expect(currentAnnotatedTile?.className).toContain('border-cyan-500');
|
||||
expect(currentAnnotatedTile?.className).toContain('inset_0_0_0_2px_rgba(239,68,68,0.95)');
|
||||
});
|
||||
|
||||
it('selects a propagation range from the playback and processing progress bars', () => {
|
||||
const onPropagationRangeChange = vi.fn();
|
||||
useStore.setState({
|
||||
frames: [
|
||||
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
|
||||
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
|
||||
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
|
||||
{ id: 'f4', projectId: 'p1', index: 3, url: '/4.jpg', width: 640, height: 360 },
|
||||
{ id: 'f5', projectId: 'p1', index: 4, url: '/5.jpg', width: 640, height: 360 },
|
||||
],
|
||||
});
|
||||
|
||||
render(
|
||||
<FrameTimeline
|
||||
propagationRange={{ startFrame: 2, endFrame: 4 }}
|
||||
propagationRangeSelectionActive
|
||||
onPropagationRangeChange={onPropagationRangeChange}
|
||||
/>,
|
||||
);
|
||||
|
||||
const playbackBar = screen.getByTestId('playback-progress-bar');
|
||||
vi.spyOn(playbackBar, 'getBoundingClientRect').mockReturnValue({
|
||||
left: 0,
|
||||
right: 100,
|
||||
top: 0,
|
||||
bottom: 10,
|
||||
width: 100,
|
||||
height: 10,
|
||||
x: 0,
|
||||
y: 0,
|
||||
toJSON: () => ({}),
|
||||
});
|
||||
fireEvent.pointerDown(playbackBar, { clientX: 25, pointerId: 1 });
|
||||
fireEvent.pointerMove(playbackBar, { clientX: 75, pointerId: 1 });
|
||||
fireEvent.pointerUp(playbackBar, { clientX: 75, pointerId: 1 });
|
||||
|
||||
expect(onPropagationRangeChange).toHaveBeenLastCalledWith(2, 4);
|
||||
expect(screen.getAllByTestId('propagation-range-overlay')).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('changes frames with left and right arrow keys without leaving bounds', () => {
|
||||
useStore.setState({
|
||||
currentFrameIndex: 1,
|
||||
|
||||
@@ -3,13 +3,29 @@ import { Play, Pause } from 'lucide-react';
|
||||
import { cn } from '../lib/utils';
|
||||
import { useStore } from '../store/useStore';
|
||||
|
||||
export function FrameTimeline() {
|
||||
interface FrameTimelineProps {
|
||||
propagationRange?: {
|
||||
startFrame: number;
|
||||
endFrame: number;
|
||||
};
|
||||
propagationRangeSelectionActive?: boolean;
|
||||
propagationRangeDisabled?: boolean;
|
||||
onPropagationRangeChange?: (startFrame: number, endFrame: number) => void;
|
||||
}
|
||||
|
||||
export function FrameTimeline({
|
||||
propagationRange,
|
||||
propagationRangeSelectionActive = false,
|
||||
propagationRangeDisabled = false,
|
||||
onPropagationRangeChange,
|
||||
}: FrameTimelineProps = {}) {
|
||||
const frames = useStore((state) => state.frames);
|
||||
const currentProject = useStore((state) => state.currentProject);
|
||||
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
|
||||
const masks = useStore((state) => state.masks);
|
||||
const setCurrentFrame = useStore((state) => state.setCurrentFrame);
|
||||
const [isPlaying, setIsPlaying] = useState(false);
|
||||
const [rangeDragAnchorFrame, setRangeDragAnchorFrame] = useState<number | null>(null);
|
||||
|
||||
const totalFrames = frames.length;
|
||||
const currentFrame = totalFrames > 0 ? currentFrameIndex + 1 : 0;
|
||||
@@ -23,21 +39,42 @@ export function FrameTimeline() {
|
||||
}, [currentProject?.original_fps, currentProject?.parse_fps]);
|
||||
const currentSeconds = totalFrames > 0 ? currentFrameIndex / timeBaseFps : 0;
|
||||
const totalSeconds = totalFrames > 0 ? Math.max(totalFrames - 1, 0) / timeBaseFps : 0;
|
||||
const isPropagatedMask = (mask: (typeof masks)[number]) => {
|
||||
const source = typeof mask.metadata?.source === 'string' ? mask.metadata.source : '';
|
||||
return source.includes('_propagation') || mask.metadata?.propagated_from_frame_id !== undefined;
|
||||
};
|
||||
const propagatedFrameMarkers = useMemo(() => {
|
||||
const frameIds = new Set(frames.map((frame) => frame.id));
|
||||
const propagatedIds = new Set(
|
||||
masks
|
||||
.filter((mask) => frameIds.has(mask.frameId))
|
||||
.filter((mask) => {
|
||||
const source = typeof mask.metadata?.source === 'string' ? mask.metadata.source : '';
|
||||
return source.includes('_propagation') || mask.metadata?.propagated_from_frame_id !== undefined;
|
||||
})
|
||||
.filter(isPropagatedMask)
|
||||
.map((mask) => mask.frameId),
|
||||
);
|
||||
return frames
|
||||
.map((frame, index) => ({ frame, index }))
|
||||
.filter(({ frame }) => propagatedIds.has(frame.id));
|
||||
}, [frames, masks]);
|
||||
const propagatedFrameIds = useMemo(
|
||||
() => new Set(propagatedFrameMarkers.map(({ frame }) => frame.id)),
|
||||
[propagatedFrameMarkers],
|
||||
);
|
||||
const annotatedFrameMarkers = useMemo(() => {
|
||||
const frameIds = new Set(frames.map((frame) => frame.id));
|
||||
const annotatedIds = new Set(
|
||||
masks
|
||||
.filter((mask) => frameIds.has(mask.frameId))
|
||||
.filter((mask) => !isPropagatedMask(mask))
|
||||
.map((mask) => mask.frameId),
|
||||
);
|
||||
return frames
|
||||
.map((frame, index) => ({ frame, index }))
|
||||
.filter(({ frame }) => annotatedIds.has(frame.id));
|
||||
}, [frames, masks]);
|
||||
const annotatedFrameIds = useMemo(
|
||||
() => new Set(annotatedFrameMarkers.map(({ frame }) => frame.id)),
|
||||
[annotatedFrameMarkers],
|
||||
);
|
||||
|
||||
const formatTime = (seconds: number) => {
|
||||
const safeSeconds = Math.max(0, seconds);
|
||||
@@ -47,6 +84,80 @@ export function FrameTimeline() {
|
||||
return `${minutes.toString().padStart(2, '0')}:${wholeSeconds.toString().padStart(2, '0')}.${centiseconds.toString().padStart(2, '0')}`;
|
||||
};
|
||||
|
||||
const clampFrame = (frame: number) => Math.min(Math.max(frame, 1), Math.max(totalFrames, 1));
|
||||
const normalizeRange = (startFrame: number, endFrame: number) => ({
|
||||
startFrame: Math.min(clampFrame(startFrame), clampFrame(endFrame)),
|
||||
endFrame: Math.max(clampFrame(startFrame), clampFrame(endFrame)),
|
||||
});
|
||||
const selectedRange = propagationRange
|
||||
? normalizeRange(propagationRange.startFrame, propagationRange.endFrame)
|
||||
: null;
|
||||
const visibleSelectedRange = propagationRangeSelectionActive ? selectedRange : null;
|
||||
const rangeLeft = visibleSelectedRange && totalFrames > 0 ? ((visibleSelectedRange.startFrame - 1) / totalFrames) * 100 : 0;
|
||||
const rangeWidth = visibleSelectedRange && totalFrames > 0
|
||||
? ((visibleSelectedRange.endFrame - visibleSelectedRange.startFrame + 1) / totalFrames) * 100
|
||||
: 0;
|
||||
|
||||
const frameFromPointerEvent = (event: React.PointerEvent<HTMLElement>) => {
|
||||
const rect = event.currentTarget.getBoundingClientRect();
|
||||
const ratio = rect.width > 0 ? (event.clientX - rect.left) / rect.width : 0;
|
||||
return clampFrame(Math.round(Math.min(Math.max(ratio, 0), 1) * Math.max(totalFrames - 1, 0)) + 1);
|
||||
};
|
||||
|
||||
const jumpToFrame = (frame: number) => {
|
||||
if (totalFrames === 0) return;
|
||||
setIsPlaying(false);
|
||||
setCurrentFrame(clampFrame(frame) - 1);
|
||||
};
|
||||
|
||||
const updatePropagationRangeFromPointer = (
|
||||
event: React.PointerEvent<HTMLElement>,
|
||||
anchorFrame = rangeDragAnchorFrame,
|
||||
) => {
|
||||
if (!propagationRangeSelectionActive || propagationRangeDisabled || totalFrames === 0 || !onPropagationRangeChange) return;
|
||||
const frame = frameFromPointerEvent(event);
|
||||
const startFrame = anchorFrame ?? frame;
|
||||
const nextRange = normalizeRange(startFrame, frame);
|
||||
onPropagationRangeChange(nextRange.startFrame, nextRange.endFrame);
|
||||
};
|
||||
|
||||
const handleRangePointerDown = (event: React.PointerEvent<HTMLElement>) => {
|
||||
if (!propagationRangeSelectionActive || propagationRangeDisabled || totalFrames === 0 || !onPropagationRangeChange) return;
|
||||
event.preventDefault();
|
||||
setIsPlaying(false);
|
||||
const frame = frameFromPointerEvent(event);
|
||||
setRangeDragAnchorFrame(frame);
|
||||
event.currentTarget.setPointerCapture?.(event.pointerId);
|
||||
onPropagationRangeChange(frame, frame);
|
||||
};
|
||||
|
||||
const handleProcessingBarPointerDown = (event: React.PointerEvent<HTMLElement>) => {
|
||||
if (propagationRangeSelectionActive) {
|
||||
handleRangePointerDown(event);
|
||||
return;
|
||||
}
|
||||
if (totalFrames === 0) return;
|
||||
event.preventDefault();
|
||||
jumpToFrame(frameFromPointerEvent(event));
|
||||
};
|
||||
|
||||
const handleFrameMarkerClick = (event: React.MouseEvent<HTMLElement>, frame: number) => {
|
||||
event.stopPropagation();
|
||||
if (propagationRangeSelectionActive) return;
|
||||
jumpToFrame(frame);
|
||||
};
|
||||
|
||||
const handleRangePointerMove = (event: React.PointerEvent<HTMLElement>) => {
|
||||
if (rangeDragAnchorFrame === null) return;
|
||||
updatePropagationRangeFromPointer(event, rangeDragAnchorFrame);
|
||||
};
|
||||
|
||||
const handleRangePointerUp = (event: React.PointerEvent<HTMLElement>) => {
|
||||
if (rangeDragAnchorFrame === null) return;
|
||||
updatePropagationRangeFromPointer(event, rangeDragAnchorFrame);
|
||||
setRangeDragAnchorFrame(null);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (!isPlaying || totalFrames <= 1) return;
|
||||
|
||||
@@ -99,41 +210,48 @@ export function FrameTimeline() {
|
||||
: [];
|
||||
|
||||
return (
|
||||
<div className="h-32 bg-[#111] border-t border-white/5 flex flex-col shrink-0 z-20">
|
||||
<div className="h-7 bg-[#0d0d0d] flex items-center group relative">
|
||||
<div className="h-36 bg-[#111] border-t border-white/5 flex flex-col shrink-0 z-20">
|
||||
<div className="h-12 bg-[#0d0d0d] flex flex-col justify-center group relative">
|
||||
<div className="absolute left-3 -top-5 text-[10px] font-mono text-gray-500 pointer-events-none">
|
||||
{formatTime(currentSeconds)}
|
||||
</div>
|
||||
<div className="absolute right-3 -top-5 text-[10px] font-mono text-gray-500 pointer-events-none">
|
||||
{formatTime(totalSeconds)}
|
||||
</div>
|
||||
<input
|
||||
<input
|
||||
type="range"
|
||||
min="1"
|
||||
max={Math.max(totalFrames, 1)}
|
||||
value={currentFrame}
|
||||
onChange={(e) => setCurrentFrame(parseInt(e.target.value) - 1)}
|
||||
className="w-full absolute inset-0 opacity-0 cursor-ew-resize z-20"
|
||||
className={cn(
|
||||
"w-full absolute left-0 right-0 top-0 h-7 opacity-0 cursor-ew-resize z-20",
|
||||
propagationRangeSelectionActive && "pointer-events-none",
|
||||
)}
|
||||
disabled={totalFrames === 0}
|
||||
/>
|
||||
<div className="h-1 bg-white/10 w-full relative group-hover:h-2 transition-all">
|
||||
<div
|
||||
data-testid="playback-progress-bar"
|
||||
className={cn(
|
||||
"h-1 bg-white/10 w-full relative group-hover:h-2 transition-all",
|
||||
propagationRangeSelectionActive && !propagationRangeDisabled && "cursor-crosshair",
|
||||
)}
|
||||
onPointerDown={handleRangePointerDown}
|
||||
onPointerMove={handleRangePointerMove}
|
||||
onPointerUp={handleRangePointerUp}
|
||||
onPointerCancel={() => setRangeDragAnchorFrame(null)}
|
||||
>
|
||||
{visibleSelectedRange && (
|
||||
<div
|
||||
data-testid="propagation-range-overlay"
|
||||
className="absolute inset-y-0 z-10 rounded-sm border border-amber-300/80 bg-amber-300/30 shadow-[0_0_12px_rgba(251,191,36,0.45)]"
|
||||
style={{ left: `${rangeLeft}%`, width: `${rangeWidth}%` }}
|
||||
/>
|
||||
)}
|
||||
<div
|
||||
className="h-full bg-cyan-500 absolute left-0"
|
||||
style={{ width: `${totalFrames > 0 ? (currentFrame / totalFrames) * 100 : 0}%` }}
|
||||
/>
|
||||
{propagatedFrameMarkers.map(({ frame, index }) => {
|
||||
const left = totalFrames > 0 ? (index / totalFrames) * 100 : 0;
|
||||
const width = totalFrames > 0 ? 100 / totalFrames : 0;
|
||||
return (
|
||||
<div
|
||||
key={frame.id}
|
||||
data-testid="propagated-frame-segment"
|
||||
title={`自动传播帧 ${index + 1}`}
|
||||
className="absolute inset-y-0 z-10 bg-sky-200/80 shadow-[0_0_10px_rgba(186,230,253,0.55)]"
|
||||
style={{ left: `${left}%`, width: `${width}%` }}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
<div
|
||||
className="absolute -top-7 -translate-x-1/2 rounded bg-black/80 border border-white/10 px-2 py-0.5 text-[10px] font-mono text-cyan-300 opacity-0 group-hover:opacity-100 transition-opacity pointer-events-none"
|
||||
style={{ left: `${totalFrames > 0 ? (currentFrame / totalFrames) * 100 : 0}%` }}
|
||||
@@ -141,8 +259,66 @@ export function FrameTimeline() {
|
||||
{formatTime(currentSeconds)}
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
className={cn(
|
||||
"mt-2 h-2.5 w-full relative bg-zinc-700/80 border-y border-white/10 shadow-inner",
|
||||
propagationRangeSelectionActive && !propagationRangeDisabled && "cursor-crosshair",
|
||||
)}
|
||||
aria-label="视频处理进度条"
|
||||
onPointerDown={handleProcessingBarPointerDown}
|
||||
onPointerMove={handleRangePointerMove}
|
||||
onPointerUp={handleRangePointerUp}
|
||||
onPointerCancel={() => setRangeDragAnchorFrame(null)}
|
||||
>
|
||||
{visibleSelectedRange && (
|
||||
<div
|
||||
data-testid="propagation-range-overlay"
|
||||
className="absolute inset-y-0 z-30 rounded-sm border border-amber-300/80 bg-amber-300/30 shadow-[0_0_12px_rgba(251,191,36,0.45)]"
|
||||
style={{ left: `${rangeLeft}%`, width: `${rangeWidth}%` }}
|
||||
/>
|
||||
)}
|
||||
{propagatedFrameMarkers.map(({ frame, index }) => {
|
||||
const left = totalFrames > 0 ? (index / totalFrames) * 100 : 0;
|
||||
const width = totalFrames > 0 ? 100 / totalFrames : 0;
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
key={frame.id}
|
||||
data-testid="propagated-frame-segment"
|
||||
title={`自动传播帧 ${index + 1}`}
|
||||
aria-label={`跳转到自动传播帧 ${index + 1}`}
|
||||
onPointerDown={(event) => event.stopPropagation()}
|
||||
onClick={(event) => handleFrameMarkerClick(event, index + 1)}
|
||||
className={cn(
|
||||
"absolute inset-y-0 z-10 border-0 bg-blue-500/85 p-0 shadow-[0_0_8px_rgba(59,130,246,0.65)]",
|
||||
propagationRangeSelectionActive ? "pointer-events-none" : "cursor-pointer hover:bg-blue-400",
|
||||
)}
|
||||
style={{ left: `${left}%`, width: `${width}%` }}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
{annotatedFrameMarkers.map(({ frame, index }) => {
|
||||
const left = totalFrames > 1 ? (index / Math.max(totalFrames - 1, 1)) * 100 : 0;
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
key={frame.id}
|
||||
data-testid="annotated-frame-marker"
|
||||
title={`人工/AI 标注帧 ${index + 1}`}
|
||||
aria-label={`跳转到人工/AI 标注帧 ${index + 1}`}
|
||||
onPointerDown={(event) => event.stopPropagation()}
|
||||
onClick={(event) => handleFrameMarkerClick(event, index + 1)}
|
||||
className={cn(
|
||||
"absolute -top-1 z-20 h-4 w-0.5 -translate-x-1/2 rounded-full border-0 bg-red-500 p-0 shadow-[0_0_9px_rgba(239,68,68,0.8)]",
|
||||
propagationRangeSelectionActive ? "pointer-events-none" : "cursor-pointer hover:bg-red-400",
|
||||
)}
|
||||
style={{ left: `${left}%` }}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
<div className="absolute bottom-0 right-3 text-[9px] font-mono text-gray-500 pointer-events-none">
|
||||
自动传播 {propagatedFrameMarkers.length} 帧
|
||||
人工/AI {annotatedFrameMarkers.length} 帧 · 自动传播 {propagatedFrameMarkers.length} 帧
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -176,13 +352,33 @@ export function FrameTimeline() {
|
||||
}
|
||||
const frame = frames[idx];
|
||||
const isCurrent = idx === currentFrameIndex;
|
||||
const isPropagatedFrame = propagatedFrameIds.has(frame.id);
|
||||
const isAnnotatedFrame = annotatedFrameIds.has(frame.id);
|
||||
return (
|
||||
<div
|
||||
key={frame.id}
|
||||
onClick={() => setCurrentFrame(idx)}
|
||||
title={
|
||||
isPropagatedFrame
|
||||
? `自动传播帧 ${idx + 1}`
|
||||
: isAnnotatedFrame
|
||||
? `人工/AI 标注帧 ${idx + 1}`
|
||||
: `视频帧 ${idx + 1}`
|
||||
}
|
||||
className={cn(
|
||||
"relative shrink-0 rounded-sm transition-all cursor-pointer flex items-center justify-center overflow-hidden group mx-0.5",
|
||||
isCurrent ? "w-28 h-16 border-2 border-cyan-500 bg-gray-700 shadow-[0_0_15px_rgba(6,182,212,0.3)] z-10" : "w-16 h-12 border border-white/5 bg-gray-800/50 opacity-40 hover:opacity-100"
|
||||
isCurrent
|
||||
? cn(
|
||||
"w-28 h-16 border-2 border-cyan-500 bg-gray-700 z-10",
|
||||
isAnnotatedFrame
|
||||
? "shadow-[inset_0_0_0_2px_rgba(239,68,68,0.95),0_0_15px_rgba(6,182,212,0.3)]"
|
||||
: "shadow-[0_0_15px_rgba(6,182,212,0.3)]",
|
||||
)
|
||||
: isPropagatedFrame
|
||||
? "w-16 h-12 border border-blue-500 bg-blue-950/30 opacity-80 shadow-[0_0_10px_rgba(59,130,246,0.22)] hover:opacity-100"
|
||||
: isAnnotatedFrame
|
||||
? "w-16 h-12 border border-red-500 bg-red-950/30 opacity-85 shadow-[0_0_10px_rgba(239,68,68,0.22)] hover:opacity-100"
|
||||
: "w-16 h-12 border border-white/5 bg-gray-800/50 opacity-40 hover:opacity-100"
|
||||
)}
|
||||
>
|
||||
{frame.url ? (
|
||||
|
||||
@@ -42,6 +42,27 @@ describe('ProjectLibrary', () => {
|
||||
expect(onProjectSelect).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('shows the generated frame sequence FPS on project cards instead of source FPS', async () => {
|
||||
apiMock.getProjects.mockResolvedValueOnce([
|
||||
{
|
||||
id: 'p-fps',
|
||||
name: 'Frame Rate Demo',
|
||||
status: 'ready',
|
||||
frames: 120,
|
||||
fps: '12FPS',
|
||||
parse_fps: 12,
|
||||
original_fps: 29.97,
|
||||
video_path: 'uploads/demo.mp4',
|
||||
},
|
||||
]);
|
||||
|
||||
render(<ProjectLibrary onProjectSelect={vi.fn()} />);
|
||||
|
||||
expect(await screen.findByText('12FPS')).toBeInTheDocument();
|
||||
expect(screen.getByText('原 30.0fps')).toBeInTheDocument();
|
||||
expect(screen.queryByText('30FPS')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('creates a new project from the modal', async () => {
|
||||
apiMock.createProject.mockResolvedValueOnce({ id: 'p2', name: 'New Project', status: 'pending' });
|
||||
|
||||
@@ -74,6 +95,7 @@ describe('ProjectLibrary', () => {
|
||||
})));
|
||||
expect(apiMock.uploadMedia).toHaveBeenCalledWith(file, 'p3');
|
||||
expect(apiMock.parseMedia).not.toHaveBeenCalled();
|
||||
expect(await screen.findByRole('status')).toHaveTextContent('视频导入成功');
|
||||
});
|
||||
|
||||
it('generates frames from an imported video with the selected FPS', async () => {
|
||||
@@ -89,6 +111,8 @@ describe('ProjectLibrary', () => {
|
||||
fireEvent.click(screen.getByRole('button', { name: '开始生成帧' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.parseMedia).toHaveBeenCalledWith('p4', { parseFps: 12 }));
|
||||
expect(await screen.findByRole('status')).toHaveTextContent('生成帧任务已入队 #22');
|
||||
expect(await screen.findByText('12FPS')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('deletes a project from the project card without entering the workspace', async () => {
|
||||
@@ -129,5 +153,6 @@ describe('ProjectLibrary', () => {
|
||||
|
||||
await waitFor(() => expect(apiMock.uploadDicomBatch).toHaveBeenCalledWith([dcm]));
|
||||
expect(apiMock.parseMedia).toHaveBeenCalledWith('77');
|
||||
expect(await screen.findByRole('status')).toHaveTextContent('DICOM 上传成功: 1 个文件');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -4,6 +4,7 @@ import { cn } from '../lib/utils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { getProjects, createProject, uploadMedia, parseMedia, uploadDicomBatch, deleteProject } from '../lib/api';
|
||||
import type { Project } from '../store/useStore';
|
||||
import { TransientNotice, type NoticeState, type NoticeTone } from './TransientNotice';
|
||||
|
||||
interface ProjectLibraryProps {
|
||||
onProjectSelect: () => void;
|
||||
@@ -31,9 +32,24 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
||||
const [frameParseFps, setFrameParseFps] = useState(30);
|
||||
const [isGeneratingFrames, setIsGeneratingFrames] = useState(false);
|
||||
const [deletingProjectId, setDeletingProjectId] = useState<string | null>(null);
|
||||
const [notice, setNotice] = useState<NoticeState | null>(null);
|
||||
const videoInputRef = useRef<HTMLInputElement>(null);
|
||||
const dicomInputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const showNotice = (message: string, tone: NoticeTone = 'info') => {
|
||||
setNotice({ id: Date.now(), message, tone });
|
||||
};
|
||||
|
||||
const frameSequenceLabel = (project: Project) => {
|
||||
if (project.source_type === 'dicom') return 'DICOM';
|
||||
if (project.video_path && (project.frames ?? 0) === 0 && project.status !== 'parsing') return '待生成帧';
|
||||
if (project.parse_fps && project.parse_fps > 0) {
|
||||
const rounded = Math.round(project.parse_fps * 10) / 10;
|
||||
return `${Number.isInteger(rounded) ? rounded.toFixed(0) : rounded.toFixed(1)}FPS`;
|
||||
}
|
||||
return project.fps || '30FPS';
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
setIsLoading(true);
|
||||
getProjects()
|
||||
@@ -81,7 +97,7 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Delete project failed:', err);
|
||||
alert('删除项目失败,请检查后端服务');
|
||||
showNotice('删除项目失败,请检查后端服务', 'error');
|
||||
} finally {
|
||||
setDeletingProjectId(null);
|
||||
}
|
||||
@@ -102,12 +118,12 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
||||
description: `导入于 ${new Date().toLocaleString()}`,
|
||||
});
|
||||
const result = await uploadMedia(pendingFile, String(newProject.id));
|
||||
alert(`视频导入成功: ${pendingFile.name}\n已保存至: ${result.url}\n需要生成帧时,请在项目卡片点击“生成帧”。`);
|
||||
showNotice(`视频导入成功: ${pendingFile.name}\n已保存至: ${result.url}\n需要生成帧时,请在项目卡片点击“生成帧”。`, 'success');
|
||||
const data = await getProjects();
|
||||
setProjects(data);
|
||||
} catch (err) {
|
||||
console.error('Upload failed:', err);
|
||||
alert('上传失败,请检查后端服务');
|
||||
showNotice('上传失败,请检查后端服务', 'error');
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
setPendingFile(null);
|
||||
@@ -127,14 +143,14 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
||||
setIsGeneratingFrames(true);
|
||||
try {
|
||||
const task = await parseMedia(frameProject.id, { parseFps: frameParseFps });
|
||||
alert(`生成帧任务已入队 #${task.id}\n帧率: ${frameParseFps} FPS\n可在 Dashboard 查看进度。`);
|
||||
showNotice(`生成帧任务已入队 #${task.id}\n帧率: ${frameParseFps} FPS\n可在 Dashboard 查看进度。`, 'success');
|
||||
const data = await getProjects();
|
||||
setProjects(data);
|
||||
setShowFrameConfig(false);
|
||||
setFrameProject(null);
|
||||
} catch (err) {
|
||||
console.error('Frame generation failed:', err);
|
||||
alert('生成帧失败,请检查后端服务或项目源文件');
|
||||
showNotice('生成帧失败,请检查后端服务或项目源文件', 'error');
|
||||
} finally {
|
||||
setIsGeneratingFrames(false);
|
||||
}
|
||||
@@ -144,19 +160,19 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
||||
if (!files || files.length === 0) return;
|
||||
const dcmFiles = Array.from(files).filter((f) => f.name.toLowerCase().endsWith('.dcm'));
|
||||
if (dcmFiles.length === 0) {
|
||||
alert('未选择有效的 .dcm 文件');
|
||||
showNotice('未选择有效的 .dcm 文件', 'error');
|
||||
return;
|
||||
}
|
||||
setIsLoading(true);
|
||||
try {
|
||||
const result = await uploadDicomBatch(dcmFiles);
|
||||
await parseMedia(String(result.project_id));
|
||||
alert(`DICOM 上传成功: ${result.uploaded_count} 个文件`);
|
||||
showNotice(`DICOM 上传成功: ${result.uploaded_count} 个文件`, 'success');
|
||||
const data = await getProjects();
|
||||
setProjects(data);
|
||||
} catch (err) {
|
||||
console.error('DICOM upload failed:', err);
|
||||
alert('DICOM 上传失败,请检查后端服务');
|
||||
showNotice('DICOM 上传失败,请检查后端服务', 'error');
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
if (dicomInputRef.current) dicomInputRef.current.value = '';
|
||||
@@ -175,6 +191,7 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
||||
|
||||
return (
|
||||
<div className="p-8 w-full h-full overflow-y-auto bg-[#0a0a0a]">
|
||||
<TransientNotice notice={notice} onDismiss={() => setNotice(null)} />
|
||||
<div className="flex justify-between items-end mb-8 border-b border-white/5 pb-6">
|
||||
<div>
|
||||
<h1 className="text-3xl font-medium tracking-tight text-white mb-2">视频与连续帧项目库</h1>
|
||||
@@ -263,7 +280,7 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
||||
)}
|
||||
<div className="absolute top-2 right-2 flex gap-2">
|
||||
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] font-mono px-2 py-1 rounded border border-white/10 uppercase tracking-widest">
|
||||
{proj.source_type === 'dicom' ? 'DICOM' : (proj.video_path && (proj.frames ?? 0) === 0 ? '待生成帧' : (proj.fps || '30FPS'))}
|
||||
{frameSequenceLabel(proj)}
|
||||
</span>
|
||||
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] px-2 py-1 rounded border border-white/10 flex items-center gap-1 uppercase tracking-widest">
|
||||
{proj.status === 'ready' ? (
|
||||
|
||||
@@ -83,6 +83,32 @@ describe('TemplateRegistry', () => {
|
||||
expect(screen.getByText('分类A')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('shows JSON import errors as transient notices instead of blocking alerts', async () => {
|
||||
apiMock.getTemplates.mockResolvedValueOnce([]);
|
||||
|
||||
render(<TemplateRegistry />);
|
||||
fireEvent.click(screen.getByText('新建方案'));
|
||||
fireEvent.click(screen.getByText('批量导入'));
|
||||
fireEvent.change(screen.getByPlaceholderText('[[[255,0,0], [0,255,0]], ["分类A", "分类B"]]'), {
|
||||
target: { value: '{broken-json' },
|
||||
});
|
||||
fireEvent.click(screen.getByRole('button', { name: '导入' }));
|
||||
|
||||
expect(await screen.findByRole('status')).toHaveTextContent('JSON 解析失败');
|
||||
});
|
||||
|
||||
it('shows template save errors as transient notices', async () => {
|
||||
apiMock.getTemplates.mockResolvedValueOnce([]);
|
||||
apiMock.createTemplate.mockRejectedValueOnce(new Error('boom'));
|
||||
|
||||
render(<TemplateRegistry />);
|
||||
fireEvent.click(screen.getByText('新建方案'));
|
||||
fireEvent.change(screen.getAllByRole('textbox')[0], { target: { value: 'Bad Template' } });
|
||||
fireEvent.click(screen.getByRole('button', { name: '保存' }));
|
||||
|
||||
expect(await screen.findByRole('status')).toHaveTextContent('保存失败,请查看控制台');
|
||||
});
|
||||
|
||||
it('edits an existing template through the backend and store', async () => {
|
||||
apiMock.getTemplates.mockResolvedValueOnce([
|
||||
{
|
||||
|
||||
@@ -4,6 +4,7 @@ import { cn } from '../lib/utils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { getTemplates, createTemplate, updateTemplate, deleteTemplate } from '../lib/api';
|
||||
import type { Template, TemplateClass } from '../store/useStore';
|
||||
import { TransientNotice, type NoticeState, type NoticeTone } from './TransientNotice';
|
||||
|
||||
// HSL to Hex color generator
|
||||
function hslToHex(h: number, s: number, l: number): string {
|
||||
@@ -59,6 +60,11 @@ export function TemplateRegistry() {
|
||||
const [editClasses, setEditClasses] = useState<TemplateClass[]>([]);
|
||||
const [editingClassId, setEditingClassId] = useState<string | null>(null);
|
||||
const [dragOverIndex, setDragOverIndex] = useState<number | null>(null);
|
||||
const [notice, setNotice] = useState<NoticeState | null>(null);
|
||||
|
||||
const showNotice = (message: string, tone: NoticeTone = 'info') => {
|
||||
setNotice({ id: Date.now(), message, tone });
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
setIsLoading(true);
|
||||
@@ -106,7 +112,7 @@ export function TemplateRegistry() {
|
||||
setShowModal(false);
|
||||
} catch (err) {
|
||||
console.error('Failed to save template:', err);
|
||||
alert('保存失败,请查看控制台');
|
||||
showNotice('保存失败,请查看控制台', 'error');
|
||||
} finally {
|
||||
setIsSaving(false);
|
||||
}
|
||||
@@ -122,6 +128,7 @@ export function TemplateRegistry() {
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to delete template:', err);
|
||||
showNotice('删除失败,请检查后端服务', 'error');
|
||||
}
|
||||
};
|
||||
|
||||
@@ -168,7 +175,7 @@ export function TemplateRegistry() {
|
||||
colors = data.colors;
|
||||
names = data.names;
|
||||
} else {
|
||||
alert('格式错误:请提供 [[colors...], [names...]] 或 {colors, names}');
|
||||
showNotice('格式错误:请提供 [[colors...], [names...]] 或 {colors, names}', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -188,7 +195,7 @@ export function TemplateRegistry() {
|
||||
setShowImport(false);
|
||||
setImportText('');
|
||||
} catch (e) {
|
||||
alert('JSON 解析失败');
|
||||
showNotice('JSON 解析失败', 'error');
|
||||
}
|
||||
};
|
||||
|
||||
@@ -212,6 +219,7 @@ export function TemplateRegistry() {
|
||||
|
||||
return (
|
||||
<div className="p-8 w-full h-full overflow-y-auto bg-[#0a0a0a]">
|
||||
<TransientNotice notice={notice} onDismiss={() => setNotice(null)} />
|
||||
<div className="mb-8 border-b border-white/5 pb-6">
|
||||
<h1 className="text-3xl font-medium tracking-tight text-white mb-2">分割模板与分类优先级管理库</h1>
|
||||
<p className="text-gray-400 text-sm">定义业务语义本体树架构、约束覆盖遮罩优先级(Z-Index裁决权重),以及真实标签数据的向下兼容转换映射(Dict Translation)原则。</p>
|
||||
|
||||
@@ -42,4 +42,13 @@ describe('ToolsPalette', () => {
|
||||
expect(setActiveTool).toHaveBeenCalledWith('sam_trigger');
|
||||
expect(onTriggerAI).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('uses compact vertically scrollable layout for smaller workspaces', () => {
|
||||
const { container } = render(<ToolsPalette activeTool="move" setActiveTool={vi.fn()} />);
|
||||
const palette = container.firstElementChild;
|
||||
|
||||
expect(palette).toHaveClass('overflow-y-auto');
|
||||
expect(screen.getByTitle('创建多边形 (P)')).toHaveClass('h-9');
|
||||
expect(screen.getByTitle('打开 AI 智能分割')).toHaveClass('h-9');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -40,8 +40,8 @@ export function ToolsPalette({
|
||||
];
|
||||
|
||||
return (
|
||||
<div className="w-12 bg-[#0d0d0d] border-r border-white/5 flex flex-col items-center py-4 shrink-0 z-10">
|
||||
<div className="flex flex-col gap-4 w-full px-2">
|
||||
<div className="h-full w-12 bg-[#0d0d0d] border-r border-white/5 flex flex-col items-center py-2 shrink-0 z-10 overflow-y-auto overflow-x-hidden overscroll-contain">
|
||||
<div className="flex flex-col gap-1.5 w-full px-1.5">
|
||||
{tools.map(tool => {
|
||||
const Icon = tool.icon;
|
||||
const isActive = activeTool === tool.id;
|
||||
@@ -51,7 +51,7 @@ export function ToolsPalette({
|
||||
onClick={() => setActiveTool(tool.id)}
|
||||
title={tool.label}
|
||||
className={cn(
|
||||
"w-10 h-10 rounded-lg flex items-center justify-center transition-all p-2",
|
||||
"w-9 h-9 rounded-md flex items-center justify-center transition-all p-1.5",
|
||||
isActive
|
||||
? (tool.id.includes('remove') ? "bg-red-500/10 text-red-500"
|
||||
: tool.id.includes('merge') ? "bg-green-500/10 text-green-500"
|
||||
@@ -59,12 +59,12 @@ export function ToolsPalette({
|
||||
: "text-gray-500 hover:bg-white/5 hover:text-white"
|
||||
)}
|
||||
>
|
||||
<Icon size={18} strokeWidth={isActive ? 2.5 : 2} />
|
||||
<Icon size={16} strokeWidth={isActive ? 2.5 : 2} />
|
||||
</button>
|
||||
)
|
||||
})}
|
||||
|
||||
<div className="w-full h-px bg-white/10 my-1" />
|
||||
<div className="w-full h-px bg-white/10 my-0.5" />
|
||||
|
||||
{aiTools.map(tool => {
|
||||
const Icon = tool.icon;
|
||||
@@ -75,13 +75,13 @@ export function ToolsPalette({
|
||||
onClick={() => setActiveTool(tool.id)}
|
||||
title={tool.label}
|
||||
className={cn(
|
||||
"w-10 h-10 rounded-lg flex items-center justify-center transition-all p-2 border",
|
||||
"w-9 h-9 rounded-md flex items-center justify-center transition-all p-1.5 border",
|
||||
isActive
|
||||
? `${tool.bg} ${tool.color} ${tool.border} shadow-[0_0_10px_rgba(255,255,255,0.05)]`
|
||||
: "text-gray-500 hover:bg-white/5 hover:text-white border-transparent"
|
||||
)}
|
||||
>
|
||||
<Icon size={18} strokeWidth={isActive ? 2.5 : 2} />
|
||||
<Icon size={16} strokeWidth={isActive ? 2.5 : 2} />
|
||||
</button>
|
||||
)
|
||||
})}
|
||||
@@ -93,32 +93,32 @@ export function ToolsPalette({
|
||||
}}
|
||||
title="打开 AI 智能分割"
|
||||
className={cn(
|
||||
"w-10 h-10 rounded-lg flex items-center justify-center transition-all",
|
||||
"w-9 h-9 rounded-md flex items-center justify-center transition-all",
|
||||
activeTool === 'sam_trigger'
|
||||
? "bg-cyan-600 text-white shadow-lg shadow-cyan-900/20"
|
||||
: "text-gray-500 hover:bg-white/5"
|
||||
)}
|
||||
>
|
||||
<Wand2 size={18} strokeWidth={2} />
|
||||
<Wand2 size={16} strokeWidth={2} />
|
||||
</button>
|
||||
|
||||
<div className="w-full h-px bg-white/10 my-1" />
|
||||
<div className="w-full h-px bg-white/10 my-0.5" />
|
||||
|
||||
<button
|
||||
onClick={onUndo}
|
||||
disabled={!canUndo}
|
||||
className="w-10 h-10 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-500 disabled:cursor-not-allowed"
|
||||
className="w-9 h-9 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-500 disabled:cursor-not-allowed"
|
||||
title="撤销操作 (Ctrl+Z)"
|
||||
>
|
||||
<Undo size={18} />
|
||||
<Undo size={16} />
|
||||
</button>
|
||||
<button
|
||||
onClick={onRedo}
|
||||
disabled={!canRedo}
|
||||
className="w-10 h-10 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-500 disabled:cursor-not-allowed"
|
||||
className="w-9 h-9 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-500 disabled:cursor-not-allowed"
|
||||
title="重做操作 (Ctrl+Shift+Z)"
|
||||
>
|
||||
<Redo size={18} />
|
||||
<Redo size={16} />
|
||||
</button>
|
||||
|
||||
</div>
|
||||
|
||||
28
src/components/TransientNotice.test.tsx
Normal file
28
src/components/TransientNotice.test.tsx
Normal file
@@ -0,0 +1,28 @@
|
||||
import { act, render, screen } from '@testing-library/react';
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
import { TransientNotice } from './TransientNotice';
|
||||
|
||||
describe('TransientNotice', () => {
|
||||
it('renders a non-blocking notice and dismisses it after the timeout', () => {
|
||||
vi.useFakeTimers();
|
||||
const onDismiss = vi.fn();
|
||||
|
||||
render(
|
||||
<TransientNotice
|
||||
notice={{ id: 1, message: '操作已完成', tone: 'success' }}
|
||||
onDismiss={onDismiss}
|
||||
durationMs={1000}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.getByRole('status')).toHaveTextContent('操作已完成');
|
||||
expect(screen.getByRole('status').parentElement).toHaveClass('pointer-events-none');
|
||||
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(1000);
|
||||
});
|
||||
|
||||
expect(onDismiss).toHaveBeenCalledTimes(1);
|
||||
vi.useRealTimers();
|
||||
});
|
||||
});
|
||||
47
src/components/TransientNotice.tsx
Normal file
47
src/components/TransientNotice.tsx
Normal file
@@ -0,0 +1,47 @@
|
||||
import React, { useEffect } from 'react';
|
||||
import { cn } from '../lib/utils';
|
||||
|
||||
export type NoticeTone = 'info' | 'success' | 'error';
|
||||
|
||||
export interface NoticeState {
|
||||
id: number;
|
||||
message: string;
|
||||
tone?: NoticeTone;
|
||||
}
|
||||
|
||||
interface TransientNoticeProps {
|
||||
notice: NoticeState | null;
|
||||
onDismiss: () => void;
|
||||
durationMs?: number;
|
||||
}
|
||||
|
||||
const toneClasses: Record<NoticeTone, string> = {
|
||||
info: 'border-cyan-400/30 bg-cyan-950/85 text-cyan-100 shadow-cyan-950/30',
|
||||
success: 'border-emerald-400/30 bg-emerald-950/85 text-emerald-100 shadow-emerald-950/30',
|
||||
error: 'border-red-400/30 bg-red-950/85 text-red-100 shadow-red-950/30',
|
||||
};
|
||||
|
||||
export function TransientNotice({ notice, onDismiss, durationMs = 3600 }: TransientNoticeProps) {
|
||||
useEffect(() => {
|
||||
if (!notice) return undefined;
|
||||
const timer = window.setTimeout(onDismiss, durationMs);
|
||||
return () => window.clearTimeout(timer);
|
||||
}, [durationMs, notice, onDismiss]);
|
||||
|
||||
if (!notice) return null;
|
||||
|
||||
const tone = notice.tone || 'info';
|
||||
return (
|
||||
<div className="pointer-events-none fixed right-6 top-6 z-[80] max-w-sm" aria-live="polite">
|
||||
<div
|
||||
role="status"
|
||||
className={cn(
|
||||
'rounded-md border px-4 py-3 text-xs font-medium leading-relaxed shadow-2xl backdrop-blur whitespace-pre-line',
|
||||
toneClasses[tone],
|
||||
)}
|
||||
>
|
||||
{notice.message}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react';
|
||||
import { act, fireEvent, render, screen, waitFor, within } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
@@ -8,7 +8,9 @@ const apiMock = vi.hoisted(() => ({
|
||||
getProjectFrames: vi.fn(),
|
||||
parseMedia: vi.fn(),
|
||||
propagateMasks: vi.fn(),
|
||||
queuePropagationTask: vi.fn(),
|
||||
getTask: vi.fn(),
|
||||
cancelTask: vi.fn(),
|
||||
getTemplates: vi.fn(),
|
||||
getProjectAnnotations: vi.fn(),
|
||||
saveAnnotation: vi.fn(),
|
||||
@@ -26,7 +28,9 @@ vi.mock('../lib/api', () => ({
|
||||
getProjectFrames: apiMock.getProjectFrames,
|
||||
parseMedia: apiMock.parseMedia,
|
||||
propagateMasks: apiMock.propagateMasks,
|
||||
queuePropagationTask: apiMock.queuePropagationTask,
|
||||
getTask: apiMock.getTask,
|
||||
cancelTask: apiMock.cancelTask,
|
||||
getTemplates: apiMock.getTemplates,
|
||||
getProjectAnnotations: apiMock.getProjectAnnotations,
|
||||
saveAnnotation: apiMock.saveAnnotation,
|
||||
@@ -48,7 +52,15 @@ describe('VideoWorkspace', () => {
|
||||
apiMock.getTemplates.mockResolvedValue([]);
|
||||
apiMock.getProjectAnnotations.mockResolvedValue([]);
|
||||
apiMock.annotationToMask.mockReturnValue(null);
|
||||
apiMock.getTask.mockResolvedValue({ id: 1, status: 'success', progress: 100, message: '解析完成' });
|
||||
apiMock.queuePropagationTask.mockResolvedValue({ id: 31, status: 'queued', progress: 0, message: '自动传播任务已入队' });
|
||||
apiMock.getTask.mockResolvedValue({
|
||||
id: 31,
|
||||
status: 'success',
|
||||
progress: 100,
|
||||
message: '自动传播完成',
|
||||
result: { processed_frame_count: 3, created_annotation_count: 2, completed_steps: 1 },
|
||||
});
|
||||
apiMock.cancelTask.mockResolvedValue({ id: 31, status: 'cancelled', progress: 100, message: '任务已取消' });
|
||||
apiMock.propagateMasks.mockResolvedValue({
|
||||
model: 'sam2.1_hiera_tiny',
|
||||
direction: 'forward',
|
||||
@@ -81,6 +93,60 @@ describe('VideoWorkspace', () => {
|
||||
expect(apiMock.getProjectAnnotations).toHaveBeenCalledWith('1');
|
||||
});
|
||||
|
||||
it('exposes workspace undo/redo buttons and keyboard shortcuts without hijacking inputs', async () => {
|
||||
const mask = {
|
||||
id: 'mask-undo',
|
||||
frameId: '10',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: 'Draft',
|
||||
color: '#06b6d4',
|
||||
};
|
||||
useStore.setState({
|
||||
currentProject: null,
|
||||
masks: [mask],
|
||||
maskHistory: [[]],
|
||||
maskFuture: [],
|
||||
});
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
fireEvent.click(screen.getByRole('button', { name: '撤销操作' }));
|
||||
|
||||
expect(useStore.getState().masks).toEqual([]);
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '重做操作' }));
|
||||
expect(useStore.getState().masks).toEqual([mask]);
|
||||
|
||||
fireEvent.keyDown(window, { key: 'z', ctrlKey: true });
|
||||
expect(useStore.getState().masks).toEqual([]);
|
||||
|
||||
fireEvent.keyDown(window, { key: 'z', ctrlKey: true, shiftKey: true });
|
||||
expect(useStore.getState().masks).toEqual([mask]);
|
||||
|
||||
fireEvent.keyDown(screen.getByLabelText('传播起始帧'), { key: 'z', ctrlKey: true });
|
||||
expect(useStore.getState().masks).toEqual([mask]);
|
||||
});
|
||||
|
||||
it('auto-dismisses short workspace operation messages without blocking later actions', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
|
||||
|
||||
vi.useFakeTimers();
|
||||
fireEvent.click(screen.getByRole('button', { name: '结构化归档保存' }));
|
||||
expect(screen.getByText('没有待保存标注')).toBeInTheDocument();
|
||||
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(3600);
|
||||
});
|
||||
|
||||
expect(screen.queryByText('没有待保存标注')).not.toBeInTheDocument();
|
||||
expect(screen.getByRole('button', { name: '结构化归档保存' })).not.toBeDisabled();
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('does not auto-generate frames when a media project has no frames yet', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([]);
|
||||
|
||||
@@ -417,26 +483,190 @@ describe('VideoWorkspace', () => {
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '自动传播' }));
|
||||
expect(apiMock.queuePropagationTask).not.toHaveBeenCalled();
|
||||
fireEvent.click(screen.getByRole('button', { name: '开始传播' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.propagateMasks).toHaveBeenCalledWith({
|
||||
await waitFor(() => expect(apiMock.queuePropagationTask).toHaveBeenCalledWith({
|
||||
project_id: 1,
|
||||
frame_id: 10,
|
||||
model: 'sam2.1_hiera_tiny',
|
||||
direction: 'forward',
|
||||
max_frames: 2,
|
||||
include_source: false,
|
||||
save_annotations: true,
|
||||
seed: {
|
||||
steps: [{
|
||||
direction: 'forward',
|
||||
max_frames: 2,
|
||||
seed: {
|
||||
polygons: [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]],
|
||||
bbox: [0.1, 0.1, 0.2, 0.2],
|
||||
points: undefined,
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
class_metadata: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
template_id: 2,
|
||||
source_mask_id: 'mask-1',
|
||||
source_annotation_id: undefined,
|
||||
},
|
||||
}],
|
||||
}));
|
||||
await waitFor(() => expect(screen.getByText('已自动传播 1 个参考 mask,处理 3 帧次,删除旧区域 0 个,保存 2 个区域')).toBeInTheDocument());
|
||||
});
|
||||
|
||||
it('uses the separately selected propagation weight when queueing propagation', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.buildAnnotationPayload.mockReturnValueOnce({
|
||||
project_id: 1,
|
||||
frame_id: 10,
|
||||
mask_data: {
|
||||
polygons: [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]],
|
||||
bbox: [0.1, 0.1, 0.2, 0.2],
|
||||
points: undefined,
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
class_metadata: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
template_id: 2,
|
||||
},
|
||||
}));
|
||||
await waitFor(() => expect(screen.getByText('已自动传播 1 个参考 mask,处理 3 帧次,保存 2 个区域')).toBeInTheDocument());
|
||||
bbox: [0.1, 0.1, 0.2, 0.2],
|
||||
});
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(2));
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
aiModel: 'sam2.1_hiera_tiny',
|
||||
masks: [{
|
||||
id: 'mask-propagation-model',
|
||||
frameId: '10',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
segmentation: [[64, 36, 192, 36, 192, 108]],
|
||||
bbox: [64, 36, 128, 72],
|
||||
}],
|
||||
});
|
||||
});
|
||||
|
||||
const propagationWeightSelect = screen.getByLabelText('传播权重');
|
||||
fireEvent.change(propagationWeightSelect, { target: { value: 'sam2.1_hiera_small' } });
|
||||
expect(propagationWeightSelect).toHaveValue('sam2.1_hiera_small');
|
||||
fireEvent.click(screen.getByRole('button', { name: '自动传播' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: '开始传播' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.queuePropagationTask).toHaveBeenCalledWith(expect.objectContaining({
|
||||
model: 'sam2.1_hiera_small',
|
||||
})));
|
||||
await waitFor(() => expect(screen.getByText('已自动传播 1 个参考 mask,处理 3 帧次,删除旧区域 0 个,保存 2 个区域')).toBeInTheDocument());
|
||||
});
|
||||
|
||||
it('shows propagation task progress and reports empty results', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.buildAnnotationPayload.mockReturnValueOnce({
|
||||
project_id: 1,
|
||||
frame_id: 10,
|
||||
mask_data: {
|
||||
polygons: [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]],
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
},
|
||||
bbox: [0.1, 0.1, 0.2, 0.2],
|
||||
});
|
||||
apiMock.queuePropagationTask.mockResolvedValueOnce({ id: 44, status: 'queued', progress: 0, message: '自动传播任务已入队' });
|
||||
apiMock.getTask.mockResolvedValueOnce({
|
||||
id: 44,
|
||||
status: 'success',
|
||||
progress: 100,
|
||||
message: '自动传播完成,但没有生成新的 mask',
|
||||
result: { processed_frame_count: 2, created_annotation_count: 0, completed_steps: 1 },
|
||||
});
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(2));
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
masks: [{
|
||||
id: 'mask-progress',
|
||||
frameId: 10 as unknown as string,
|
||||
pathData: 'M 0 0 Z',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
segmentation: [[64, 36, 192, 36, 192, 108]],
|
||||
bbox: [64, 36, 128, 72],
|
||||
}],
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '自动传播' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: '开始传播' }));
|
||||
|
||||
const progressPanel = await screen.findByLabelText('自动传播进度');
|
||||
expect(progressPanel).toBeInTheDocument();
|
||||
expect(within(progressPanel).getByText('0%')).toBeInTheDocument();
|
||||
|
||||
expect(await screen.findByText(/没有生成新的 mask/)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('lets users select the propagation range on the timeline before queueing', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },
|
||||
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
|
||||
{ id: 12, project_id: 1, frame_index: 2, image_url: '/frame-2.jpg', width: 640, height: 360 },
|
||||
{ id: 13, project_id: 1, frame_index: 3, image_url: '/frame-3.jpg', width: 640, height: 360 },
|
||||
{ id: 14, project_id: 1, frame_index: 4, image_url: '/frame-4.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.buildAnnotationPayload.mockReturnValueOnce({
|
||||
project_id: 1,
|
||||
frame_id: 10,
|
||||
mask_data: {
|
||||
polygons: [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]],
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
},
|
||||
bbox: [0.1, 0.1, 0.2, 0.2],
|
||||
});
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(5));
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
masks: [{
|
||||
id: 'mask-timeline-range',
|
||||
frameId: '10',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
segmentation: [[64, 36, 192, 36, 192, 108]],
|
||||
bbox: [64, 36, 128, 72],
|
||||
}],
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '自动传播' }));
|
||||
const processingBar = screen.getByLabelText('视频处理进度条');
|
||||
vi.spyOn(processingBar, 'getBoundingClientRect').mockReturnValue({
|
||||
left: 0,
|
||||
right: 100,
|
||||
top: 0,
|
||||
bottom: 10,
|
||||
width: 100,
|
||||
height: 10,
|
||||
x: 0,
|
||||
y: 0,
|
||||
toJSON: () => ({}),
|
||||
});
|
||||
fireEvent.pointerDown(processingBar, { clientX: 25, pointerId: 1 });
|
||||
fireEvent.pointerMove(processingBar, { clientX: 100, pointerId: 1 });
|
||||
fireEvent.pointerUp(processingBar, { clientX: 100, pointerId: 1 });
|
||||
|
||||
expect(screen.getByLabelText('传播起始帧')).toHaveValue(2);
|
||||
expect(screen.getByLabelText('传播结束帧')).toHaveValue(5);
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '开始传播' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.queuePropagationTask).toHaveBeenCalledWith(expect.objectContaining({
|
||||
frame_id: 10,
|
||||
steps: [expect.objectContaining({ direction: 'forward', max_frames: 5 })],
|
||||
})));
|
||||
});
|
||||
|
||||
it('auto-propagates all reference-frame masks in both directions inside the selected range', async () => {
|
||||
@@ -445,13 +675,16 @@ describe('VideoWorkspace', () => {
|
||||
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
|
||||
{ id: 12, project_id: 1, frame_index: 2, image_url: '/frame-2.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.propagateMasks.mockResolvedValue({
|
||||
model: 'sam2.1_hiera_tiny',
|
||||
direction: 'forward',
|
||||
source_frame_id: 11,
|
||||
processed_frame_count: 2,
|
||||
created_annotation_count: 1,
|
||||
annotations: [],
|
||||
apiMock.getTask.mockResolvedValue({
|
||||
id: 31,
|
||||
status: 'success',
|
||||
progress: 100,
|
||||
message: '自动传播完成',
|
||||
result: {
|
||||
processed_frame_count: 8,
|
||||
created_annotation_count: 4,
|
||||
completed_steps: 4,
|
||||
},
|
||||
});
|
||||
apiMock.buildAnnotationPayload
|
||||
.mockReturnValueOnce({
|
||||
@@ -505,27 +738,14 @@ describe('VideoWorkspace', () => {
|
||||
fireEvent.change(screen.getByLabelText('传播结束帧'), { target: { value: '3' } });
|
||||
fireEvent.click(screen.getByRole('button', { name: '自动传播' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.propagateMasks).toHaveBeenCalledTimes(4));
|
||||
expect(apiMock.propagateMasks).toHaveBeenNthCalledWith(1, expect.objectContaining({
|
||||
direction: 'backward',
|
||||
max_frames: 2,
|
||||
seed: expect.objectContaining({ label: '胆囊' }),
|
||||
}));
|
||||
expect(apiMock.propagateMasks).toHaveBeenNthCalledWith(2, expect.objectContaining({
|
||||
direction: 'forward',
|
||||
max_frames: 2,
|
||||
seed: expect.objectContaining({ label: '胆囊' }),
|
||||
}));
|
||||
expect(apiMock.propagateMasks).toHaveBeenNthCalledWith(3, expect.objectContaining({
|
||||
direction: 'backward',
|
||||
max_frames: 2,
|
||||
seed: expect.objectContaining({ label: '肝脏' }),
|
||||
}));
|
||||
expect(apiMock.propagateMasks).toHaveBeenNthCalledWith(4, expect.objectContaining({
|
||||
direction: 'forward',
|
||||
max_frames: 2,
|
||||
seed: expect.objectContaining({ label: '肝脏' }),
|
||||
}));
|
||||
await waitFor(() => expect(screen.getByText('已自动传播 2 个参考 mask,处理 8 帧次,保存 4 个区域')).toBeInTheDocument());
|
||||
await waitFor(() => expect(apiMock.queuePropagationTask).toHaveBeenCalledTimes(1));
|
||||
const queuedPayload = apiMock.queuePropagationTask.mock.calls[0][0];
|
||||
expect(queuedPayload.steps).toEqual([
|
||||
expect.objectContaining({ direction: 'backward', max_frames: 2, seed: expect.objectContaining({ label: '胆囊' }) }),
|
||||
expect.objectContaining({ direction: 'forward', max_frames: 2, seed: expect.objectContaining({ label: '胆囊' }) }),
|
||||
expect.objectContaining({ direction: 'backward', max_frames: 2, seed: expect.objectContaining({ label: '肝脏' }) }),
|
||||
expect.objectContaining({ direction: 'forward', max_frames: 2, seed: expect.objectContaining({ label: '肝脏' }) }),
|
||||
]);
|
||||
await waitFor(() => expect(screen.getByText('已自动传播 2 个参考 mask,处理 8 帧次,删除旧区域 0 个,保存 4 个区域')).toBeInTheDocument());
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
import React, { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { Redo, Undo } from 'lucide-react';
|
||||
import { useStore } from '../store/useStore';
|
||||
import {
|
||||
annotationToMask,
|
||||
buildAnnotationPayload,
|
||||
cancelTask,
|
||||
deleteAnnotation,
|
||||
exportCoco,
|
||||
exportMasks,
|
||||
getProjectAnnotations,
|
||||
getProjectFrames,
|
||||
getTask,
|
||||
getTemplates,
|
||||
importGtMask,
|
||||
propagateMasks,
|
||||
queuePropagationTask,
|
||||
saveAnnotation,
|
||||
updateAnnotation,
|
||||
} from '../lib/api';
|
||||
@@ -19,9 +22,20 @@ import { ToolsPalette } from './ToolsPalette';
|
||||
import { OntologyInspector } from './OntologyInspector';
|
||||
import { FrameTimeline } from './FrameTimeline';
|
||||
import { ModelStatusBadge } from './ModelStatusBadge';
|
||||
import type { Frame, Mask } from '../store/useStore';
|
||||
import { DEFAULT_AI_MODEL_ID, SAM2_MODEL_OPTIONS, type AiModelId, type Frame, type Mask } from '../store/useStore';
|
||||
|
||||
type PropagationDirection = 'forward' | 'backward';
|
||||
type PropagationProgress = {
|
||||
currentStep: number;
|
||||
completedSteps: number;
|
||||
totalSteps: number;
|
||||
processedCount: number;
|
||||
createdCount: number;
|
||||
label: string;
|
||||
} | null;
|
||||
|
||||
const PROPAGATION_POLL_INTERVAL_MS = 250;
|
||||
const STATUS_MESSAGE_TTL_MS = 3600;
|
||||
|
||||
export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) {
|
||||
const gtMaskInputRef = React.useRef<HTMLInputElement>(null);
|
||||
@@ -53,6 +67,47 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
const [statusMessage, setStatusMessage] = useState('');
|
||||
const [propagationStartFrame, setPropagationStartFrame] = useState(1);
|
||||
const [propagationEndFrame, setPropagationEndFrame] = useState(1);
|
||||
const [isPropagationRangeSelecting, setIsPropagationRangeSelecting] = useState(false);
|
||||
const [hasExplicitPropagationRange, setHasExplicitPropagationRange] = useState(false);
|
||||
const [propagationProgress, setPropagationProgress] = useState<PropagationProgress>(null);
|
||||
const [propagationTaskId, setPropagationTaskId] = useState<number | null>(null);
|
||||
const [propagationWeight, setPropagationWeight] = useState<AiModelId>(aiModel || DEFAULT_AI_MODEL_ID);
|
||||
const [hasCustomPropagationWeight, setHasCustomPropagationWeight] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (!hasCustomPropagationWeight) {
|
||||
setPropagationWeight(aiModel || DEFAULT_AI_MODEL_ID);
|
||||
}
|
||||
}, [aiModel, hasCustomPropagationWeight]);
|
||||
|
||||
const propagationWeightLabel = useMemo(
|
||||
() => SAM2_MODEL_OPTIONS.find((option) => option.id === propagationWeight)?.label || propagationWeight,
|
||||
[propagationWeight],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const handleWorkspaceShortcuts = (event: KeyboardEvent) => {
|
||||
const target = event.target as HTMLElement | null;
|
||||
const tagName = target?.tagName?.toLowerCase();
|
||||
if (tagName === 'input' || tagName === 'textarea' || tagName === 'select' || target?.isContentEditable) return;
|
||||
if (!event.metaKey && !event.ctrlKey) return;
|
||||
|
||||
const key = event.key.toLowerCase();
|
||||
if (key === 'z') {
|
||||
event.preventDefault();
|
||||
if (event.shiftKey) redoMasks();
|
||||
else undoMasks();
|
||||
return;
|
||||
}
|
||||
if (key === 'y') {
|
||||
event.preventDefault();
|
||||
redoMasks();
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener('keydown', handleWorkspaceShortcuts);
|
||||
return () => window.removeEventListener('keydown', handleWorkspaceShortcuts);
|
||||
}, [redoMasks, undoMasks]);
|
||||
|
||||
const hydrateSavedAnnotations = useCallback(async (
|
||||
projectId: string,
|
||||
@@ -143,6 +198,13 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
const frameById = useMemo(() => new Map(frames.map((frame) => [frame.id, frame])), [frames]);
|
||||
const projectFrameIds = useMemo(() => new Set(frames.map((frame) => frame.id)), [frames]);
|
||||
const currentFrameNumber = currentFrameIndex + 1;
|
||||
const isWorkspaceBusy = isSaving || isExporting || isImportingGt || isPropagating || Boolean(propagationProgress);
|
||||
|
||||
useEffect(() => {
|
||||
if (!statusMessage || isWorkspaceBusy || totalFrames === 0) return undefined;
|
||||
const timer = window.setTimeout(() => setStatusMessage(''), STATUS_MESSAGE_TTL_MS);
|
||||
return () => window.clearTimeout(timer);
|
||||
}, [isWorkspaceBusy, statusMessage, totalFrames]);
|
||||
|
||||
useEffect(() => {
|
||||
if (totalFrames === 0) {
|
||||
@@ -152,6 +214,8 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
}
|
||||
setPropagationStartFrame(currentFrameNumber);
|
||||
setPropagationEndFrame(Math.min(totalFrames, currentFrameNumber + 29));
|
||||
setIsPropagationRangeSelecting(false);
|
||||
setHasExplicitPropagationRange(false);
|
||||
}, [currentFrameNumber, totalFrames]);
|
||||
|
||||
const savePendingAnnotations = useCallback(async ({ silent = false } = {}) => {
|
||||
@@ -341,6 +405,9 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
if (!seedPayload?.mask_data?.polygons?.length && !seedPayload?.bbox) {
|
||||
return null;
|
||||
}
|
||||
const sourceAnnotationId = seedMask.annotationId && /^\d+$/.test(seedMask.annotationId)
|
||||
? Number(seedMask.annotationId)
|
||||
: undefined;
|
||||
return {
|
||||
polygons: seedPayload.mask_data?.polygons,
|
||||
bbox: seedPayload.bbox,
|
||||
@@ -349,12 +416,33 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
color: seedPayload.mask_data?.color,
|
||||
class_metadata: seedPayload.mask_data?.class,
|
||||
template_id: seedPayload.template_id,
|
||||
source_mask_id: seedMask.id,
|
||||
source_annotation_id: sourceAnnotationId,
|
||||
};
|
||||
}, [activeTemplateId, currentFrame, currentProject?.id]);
|
||||
|
||||
const handleAutoPropagate = async () => {
|
||||
const handlePropagationRangeChange = useCallback((startFrame: number, endFrame: number) => {
|
||||
const nextStart = clampFrameNumber(startFrame);
|
||||
const nextEnd = clampFrameNumber(endFrame);
|
||||
setPropagationStartFrame(nextStart);
|
||||
setPropagationEndFrame(nextEnd);
|
||||
setHasExplicitPropagationRange(true);
|
||||
setStatusMessage(`已选择自动传播范围:第 ${Math.min(nextStart, nextEnd)}-${Math.max(nextStart, nextEnd)} 帧`);
|
||||
}, [clampFrameNumber]);
|
||||
|
||||
const handlePropagationStartInput = (value: number) => {
|
||||
setPropagationStartFrame(clampFrameNumber(value || 1));
|
||||
setHasExplicitPropagationRange(true);
|
||||
};
|
||||
|
||||
const handlePropagationEndInput = (value: number) => {
|
||||
setPropagationEndFrame(clampFrameNumber(value || 1));
|
||||
setHasExplicitPropagationRange(true);
|
||||
};
|
||||
|
||||
const runAutoPropagate = async () => {
|
||||
if (!currentProject?.id || !currentFrame?.id) return;
|
||||
const seedMasks = masks.filter((mask) => mask.frameId === currentFrame.id);
|
||||
const seedMasks = masks.filter((mask) => String(mask.frameId) === String(currentFrame.id));
|
||||
if (seedMasks.length === 0) {
|
||||
setStatusMessage('请先在当前参考帧创建或保存至少一个 mask');
|
||||
return;
|
||||
@@ -390,37 +478,121 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
return;
|
||||
}
|
||||
|
||||
setIsPropagationRangeSelecting(false);
|
||||
setIsPropagating(true);
|
||||
setStatusMessage(`${aiModel.toUpperCase()} 正在以第 ${currentFrameNumber} 帧为参考,自动传播 ${seeds.length} 个 mask 到第 ${rangeStartIndex + 1}-${rangeEndIndex + 1} 帧...`);
|
||||
const totalSteps = seeds.length * propagationDirections.length;
|
||||
setPropagationProgress({
|
||||
currentStep: 0,
|
||||
completedSteps: 0,
|
||||
totalSteps,
|
||||
processedCount: 0,
|
||||
createdCount: 0,
|
||||
label: '准备传播',
|
||||
});
|
||||
setStatusMessage(`${propagationWeightLabel} 权重正在以第 ${currentFrameNumber} 帧为参考,自动传播 ${seeds.length} 个 mask 到第 ${rangeStartIndex + 1}-${rangeEndIndex + 1} 帧...`);
|
||||
try {
|
||||
let createdCount = 0;
|
||||
let processedCount = 0;
|
||||
for (const { seed } of seeds) {
|
||||
for (const { direction, maxFrames } of propagationDirections) {
|
||||
const result = await propagateMasks({
|
||||
project_id: Number(currentProject.id),
|
||||
frame_id: Number(currentFrame.id),
|
||||
model: aiModel,
|
||||
direction,
|
||||
max_frames: maxFrames,
|
||||
include_source: false,
|
||||
save_annotations: true,
|
||||
seed,
|
||||
});
|
||||
createdCount += result.created_annotation_count;
|
||||
processedCount += result.processed_frame_count;
|
||||
const steps = seeds.flatMap(({ seed }) => (
|
||||
propagationDirections.map(({ direction, maxFrames }) => ({
|
||||
seed,
|
||||
direction,
|
||||
max_frames: maxFrames,
|
||||
}))
|
||||
));
|
||||
const task = await queuePropagationTask({
|
||||
project_id: Number(currentProject.id),
|
||||
frame_id: Number(currentFrame.id),
|
||||
model: propagationWeight,
|
||||
steps,
|
||||
include_source: false,
|
||||
save_annotations: true,
|
||||
});
|
||||
setPropagationTaskId(task.id);
|
||||
setStatusMessage(`自动传播任务已入队 #${task.id},可在 Dashboard 查看进度`);
|
||||
|
||||
let currentTask = task;
|
||||
while (!['success', 'failed', 'cancelled'].includes(currentTask.status)) {
|
||||
await new Promise((resolve) => setTimeout(resolve, PROPAGATION_POLL_INTERVAL_MS));
|
||||
currentTask = await getTask(task.id);
|
||||
const result = currentTask.result || {};
|
||||
const completedSteps = Number(result.completed_steps || 0);
|
||||
const processedCount = Number(result.processed_frame_count || 0);
|
||||
const createdCount = Number(result.created_annotation_count || 0);
|
||||
setPropagationProgress({
|
||||
currentStep: Math.min(completedSteps + 1, totalSteps),
|
||||
completedSteps,
|
||||
totalSteps,
|
||||
processedCount,
|
||||
createdCount,
|
||||
label: currentTask.message || `自动传播任务 #${task.id}`,
|
||||
});
|
||||
setStatusMessage(currentTask.message || `自动传播任务 #${task.id} 运行中...`);
|
||||
if (createdCount > 0) {
|
||||
await hydrateSavedAnnotations(currentProject.id, frames);
|
||||
}
|
||||
}
|
||||
|
||||
const result = currentTask.result || {};
|
||||
const createdCount = Number(result.created_annotation_count || 0);
|
||||
const processedCount = Number(result.processed_frame_count || 0);
|
||||
const skippedCount = Number(result.skipped_seed_count || 0);
|
||||
const deletedCount = Number(result.deleted_annotation_count || 0);
|
||||
await hydrateSavedAnnotations(currentProject.id, frames);
|
||||
setStatusMessage(`已自动传播 ${seeds.length} 个参考 mask,处理 ${processedCount} 帧次,保存 ${createdCount} 个区域`);
|
||||
if (currentTask.status === 'failed') {
|
||||
setStatusMessage(currentTask.error ? `传播失败:${currentTask.error}` : '传播失败,请检查权重状态或后端日志');
|
||||
return;
|
||||
}
|
||||
if (currentTask.status === 'cancelled') {
|
||||
setStatusMessage('自动传播任务已取消');
|
||||
return;
|
||||
}
|
||||
setStatusMessage(createdCount > 0
|
||||
? `已自动传播 ${seeds.length} 个参考 mask,处理 ${processedCount} 帧次,删除旧区域 ${deletedCount} 个,保存 ${createdCount} 个区域`
|
||||
: skippedCount > 0
|
||||
? `自动传播已完成:${skippedCount} 个未改变 mask 已跳过,没有生成重复区域`
|
||||
: `自动传播已完成,但没有生成新的 mask;请检查参考 mask、传播范围或 ${propagationWeightLabel} 权重状态`);
|
||||
} catch (err) {
|
||||
console.error('Propagation failed:', err);
|
||||
setStatusMessage('传播失败,请检查模型状态或后端日志');
|
||||
const detail = (err as any)?.response?.data?.detail;
|
||||
setStatusMessage(detail ? `传播失败:${detail}` : '传播失败,请检查权重状态或后端日志');
|
||||
} finally {
|
||||
setIsPropagating(false);
|
||||
setPropagationProgress(null);
|
||||
setPropagationTaskId(null);
|
||||
}
|
||||
};
|
||||
|
||||
const handleAutoPropagate = async () => {
|
||||
if (!hasExplicitPropagationRange && !isPropagationRangeSelecting) {
|
||||
setIsPropagationRangeSelecting(true);
|
||||
setStatusMessage('请在播放进度条或视频处理进度条上点击/拖拽选择传播起止帧,再点击“开始传播”');
|
||||
return;
|
||||
}
|
||||
await runAutoPropagate();
|
||||
};
|
||||
|
||||
const handleCancelPropagationRangeSelection = () => {
|
||||
setIsPropagationRangeSelecting(false);
|
||||
setHasExplicitPropagationRange(false);
|
||||
setPropagationStartFrame(currentFrameNumber || 1);
|
||||
setPropagationEndFrame(Math.min(Math.max(totalFrames, 1), (currentFrameNumber || 1) + 29));
|
||||
setStatusMessage('已取消自动传播范围选择');
|
||||
};
|
||||
|
||||
const handleCancelPropagation = async () => {
|
||||
if (!propagationTaskId) return;
|
||||
try {
|
||||
await cancelTask(propagationTaskId);
|
||||
setStatusMessage(`正在取消自动传播任务 #${propagationTaskId}...`);
|
||||
} catch (err) {
|
||||
console.error('Cancel propagation failed:', err);
|
||||
setStatusMessage('取消自动传播失败,请稍后重试');
|
||||
}
|
||||
};
|
||||
|
||||
const propagationPercent = propagationProgress
|
||||
? Math.round((propagationProgress.completedSteps / Math.max(propagationProgress.totalSteps, 1)) * 100)
|
||||
: 0;
|
||||
|
||||
return (
|
||||
<div className="w-full h-full flex flex-col bg-[#0a0a0a]">
|
||||
{/* Top Header / Status bar */}
|
||||
@@ -436,7 +608,68 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
{statusMessage}
|
||||
</span>
|
||||
)}
|
||||
{propagationProgress && (
|
||||
<div
|
||||
className="w-56 rounded-md border border-blue-500/20 bg-blue-500/5 px-2 py-1"
|
||||
aria-label="自动传播进度"
|
||||
title={`已处理 ${propagationProgress.processedCount} 帧次,已保存 ${propagationProgress.createdCount} 个区域`}
|
||||
>
|
||||
<div className="mb-1 flex items-center justify-between gap-2 text-[10px] font-mono text-blue-200">
|
||||
<span className="truncate">{propagationProgress.label}</span>
|
||||
<span>{propagationPercent}%</span>
|
||||
</div>
|
||||
<div className="h-1.5 overflow-hidden rounded-full bg-zinc-700">
|
||||
<div
|
||||
className="h-full rounded-full bg-blue-400 transition-all"
|
||||
style={{ width: `${propagationPercent}%` }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
<div className="flex items-center gap-1 rounded-md border border-white/10 bg-white/[0.03] px-1 py-1">
|
||||
<button
|
||||
type="button"
|
||||
onClick={undoMasks}
|
||||
disabled={maskHistory.length === 0}
|
||||
aria-label="撤销操作"
|
||||
title="撤销操作 (Ctrl+Z)"
|
||||
className="h-7 px-2 rounded text-gray-400 hover:bg-white/5 hover:text-white inline-flex items-center gap-1.5 text-xs transition-colors disabled:opacity-35 disabled:hover:bg-transparent disabled:hover:text-gray-400 disabled:cursor-not-allowed"
|
||||
>
|
||||
<Undo size={14} />
|
||||
撤销
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={redoMasks}
|
||||
disabled={maskFuture.length === 0}
|
||||
aria-label="重做操作"
|
||||
title="重做操作 (Ctrl+Shift+Z / Ctrl+Y)"
|
||||
className="h-7 px-2 rounded text-gray-400 hover:bg-white/5 hover:text-white inline-flex items-center gap-1.5 text-xs transition-colors disabled:opacity-35 disabled:hover:bg-transparent disabled:hover:text-gray-400 disabled:cursor-not-allowed"
|
||||
>
|
||||
<Redo size={14} />
|
||||
重做
|
||||
</button>
|
||||
</div>
|
||||
<ModelStatusBadge />
|
||||
<div className="flex items-center gap-1 rounded-md border border-white/10 bg-white/[0.03] px-2 py-1">
|
||||
<span className="text-[10px] text-gray-500 whitespace-nowrap">传播权重</span>
|
||||
<select
|
||||
aria-label="传播权重"
|
||||
value={propagationWeight}
|
||||
onChange={(event) => {
|
||||
setHasCustomPropagationWeight(true);
|
||||
setPropagationWeight(event.target.value as AiModelId);
|
||||
}}
|
||||
disabled={isPropagating || isSaving || isExporting || isImportingGt}
|
||||
className="h-6 w-24 rounded border border-white/10 bg-black/20 px-1 text-[10px] text-gray-300 outline-none focus:border-cyan-500/50 disabled:opacity-40"
|
||||
>
|
||||
{SAM2_MODEL_OPTIONS.map((option) => (
|
||||
<option key={option.id} value={option.id}>
|
||||
{option.shortLabel}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</div>
|
||||
<input
|
||||
ref={gtMaskInputRef}
|
||||
type="file"
|
||||
@@ -460,7 +693,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
min={1}
|
||||
max={Math.max(totalFrames, 1)}
|
||||
value={propagationStartFrame}
|
||||
onChange={(event) => setPropagationStartFrame(clampFrameNumber(Number(event.target.value) || 1))}
|
||||
onChange={(event) => handlePropagationStartInput(Number(event.target.value))}
|
||||
disabled={isPropagating || isSaving || isExporting || isImportingGt || totalFrames === 0}
|
||||
className="h-6 w-14 rounded bg-black/20 border border-white/10 px-1 text-[10px] text-gray-300 outline-none focus:border-cyan-500/50 disabled:opacity-40"
|
||||
/>
|
||||
@@ -471,7 +704,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
min={1}
|
||||
max={Math.max(totalFrames, 1)}
|
||||
value={propagationEndFrame}
|
||||
onChange={(event) => setPropagationEndFrame(clampFrameNumber(Number(event.target.value) || 1))}
|
||||
onChange={(event) => handlePropagationEndInput(Number(event.target.value))}
|
||||
disabled={isPropagating || isSaving || isExporting || isImportingGt || totalFrames === 0}
|
||||
className="h-6 w-14 rounded bg-black/20 border border-white/10 px-1 text-[10px] text-gray-300 outline-none focus:border-cyan-500/50 disabled:opacity-40"
|
||||
/>
|
||||
@@ -481,8 +714,25 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
disabled={!currentProject?.id || !currentFrame?.id || isSaving || isExporting || isImportingGt || isPropagating}
|
||||
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isPropagating ? '传播中...' : '自动传播'}
|
||||
{isPropagating ? '传播中...' : isPropagationRangeSelecting ? '开始传播' : '自动传播'}
|
||||
</button>
|
||||
{isPropagationRangeSelecting && (
|
||||
<button
|
||||
onClick={handleCancelPropagationRangeSelection}
|
||||
disabled={isPropagating}
|
||||
className="px-3 py-1.5 bg-amber-500/10 hover:bg-amber-500/20 border border-amber-500/25 rounded-md text-xs transition-colors text-amber-100 disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
取消选区
|
||||
</button>
|
||||
)}
|
||||
{propagationTaskId && (
|
||||
<button
|
||||
onClick={handleCancelPropagation}
|
||||
className="px-3 py-1.5 bg-red-500/10 hover:bg-red-500/20 border border-red-500/20 rounded-md text-xs transition-colors text-red-200"
|
||||
>
|
||||
取消传播
|
||||
</button>
|
||||
)}
|
||||
<button
|
||||
onClick={handleExportMasks}
|
||||
disabled={!currentProject?.id || isExporting || isSaving || isPropagating}
|
||||
@@ -534,7 +784,15 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
</div>
|
||||
|
||||
{/* Bottom Timeline */}
|
||||
<FrameTimeline />
|
||||
<FrameTimeline
|
||||
propagationRange={{
|
||||
startFrame: propagationStartFrame,
|
||||
endFrame: propagationEndFrame,
|
||||
}}
|
||||
propagationRangeSelectionActive={isPropagationRangeSelecting}
|
||||
propagationRangeDisabled={isPropagating || isSaving || isExporting || isImportingGt}
|
||||
onPropagationRangeChange={handlePropagationRangeChange}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -53,10 +53,12 @@ describe('api client contracts', () => {
|
||||
name: 'Demo',
|
||||
status: 'ready',
|
||||
frames: 12,
|
||||
fps: '30FPS',
|
||||
fps: '10FPS',
|
||||
thumbnail_url: 'thumb',
|
||||
video_path: 'uploads/demo.mp4',
|
||||
source_type: 'video',
|
||||
original_fps: 29.97,
|
||||
parse_fps: 10,
|
||||
createdAt: 'created',
|
||||
updatedAt: 'updated',
|
||||
}),
|
||||
@@ -184,7 +186,7 @@ describe('api client contracts', () => {
|
||||
});
|
||||
|
||||
it('lists, saves, updates, and deletes annotations with the backend annotation contract', async () => {
|
||||
const { deleteAnnotation, getProjectAnnotations, propagateMasks, saveAnnotation, updateAnnotation } = await import('./api');
|
||||
const { deleteAnnotation, getProjectAnnotations, propagateMasks, queuePropagationTask, saveAnnotation, updateAnnotation } = await import('./api');
|
||||
const saved = {
|
||||
id: 1,
|
||||
project_id: 9,
|
||||
@@ -267,6 +269,48 @@ describe('api client contracts', () => {
|
||||
}, {
|
||||
timeout: 600000,
|
||||
});
|
||||
|
||||
axiosMock.client.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
id: 33,
|
||||
task_type: 'propagate_masks',
|
||||
status: 'queued',
|
||||
progress: 0,
|
||||
message: '自动传播任务已入队',
|
||||
},
|
||||
});
|
||||
await expect(queuePropagationTask({
|
||||
project_id: 9,
|
||||
frame_id: 5,
|
||||
model: 'sam2.1_hiera_tiny',
|
||||
steps: [{
|
||||
seed: {
|
||||
polygons: [[[0, 0], [1, 0], [1, 1]]],
|
||||
label: 'mask',
|
||||
color: '#06b6d4',
|
||||
},
|
||||
direction: 'forward',
|
||||
max_frames: 30,
|
||||
}],
|
||||
include_source: false,
|
||||
save_annotations: true,
|
||||
})).resolves.toEqual(expect.objectContaining({ id: 33, task_type: 'propagate_masks' }));
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/propagate/task', {
|
||||
project_id: 9,
|
||||
frame_id: 5,
|
||||
model: 'sam2.1_hiera_tiny',
|
||||
steps: [{
|
||||
seed: {
|
||||
polygons: [[[0, 0], [1, 0], [1, 1]]],
|
||||
label: 'mask',
|
||||
color: '#06b6d4',
|
||||
},
|
||||
direction: 'forward',
|
||||
max_frames: 30,
|
||||
}],
|
||||
include_source: false,
|
||||
save_annotations: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('imports GT masks through multipart form data', async () => {
|
||||
|
||||
@@ -49,6 +49,12 @@ function normalizeProjectStatus(status?: string): Project['status'] {
|
||||
return 'pending';
|
||||
}
|
||||
|
||||
function formatProjectFps(value?: number | null): string {
|
||||
if (!value || value <= 0) return '30FPS';
|
||||
const rounded = Math.round(value * 10) / 10;
|
||||
return `${Number.isInteger(rounded) ? rounded.toFixed(0) : rounded.toFixed(1)}FPS`;
|
||||
}
|
||||
|
||||
function mapProject(p: any): Project {
|
||||
return {
|
||||
id: String(p.id),
|
||||
@@ -56,7 +62,7 @@ function mapProject(p: any): Project {
|
||||
description: p.description,
|
||||
status: normalizeProjectStatus(p.status),
|
||||
frames: p.frame_count ?? 0,
|
||||
fps: p.original_fps ? `${Math.round(p.original_fps)}FPS` : '30FPS',
|
||||
fps: formatProjectFps(p.parse_fps ?? p.original_fps),
|
||||
thumbnail_url: p.thumbnail_url,
|
||||
video_path: p.video_path,
|
||||
source_type: p.source_type,
|
||||
@@ -346,6 +352,8 @@ export interface PropagateMasksPayload {
|
||||
category?: string;
|
||||
};
|
||||
template_id?: number;
|
||||
source_mask_id?: string;
|
||||
source_annotation_id?: number;
|
||||
};
|
||||
direction?: 'forward' | 'backward' | 'both';
|
||||
max_frames?: number;
|
||||
@@ -353,6 +361,19 @@ export interface PropagateMasksPayload {
|
||||
save_annotations?: boolean;
|
||||
}
|
||||
|
||||
export interface PropagateTaskPayload {
|
||||
project_id: number;
|
||||
frame_id: number;
|
||||
model?: AiModelId;
|
||||
steps: Array<{
|
||||
seed: PropagateMasksPayload['seed'];
|
||||
direction: 'forward' | 'backward';
|
||||
max_frames: number;
|
||||
}>;
|
||||
include_source?: boolean;
|
||||
save_annotations?: boolean;
|
||||
}
|
||||
|
||||
export interface PropagateMasksResult {
|
||||
model: AiModelId;
|
||||
direction: string;
|
||||
@@ -652,6 +673,11 @@ export async function propagateMasks(payload: PropagateMasksPayload): Promise<Pr
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function queuePropagationTask(payload: PropagateTaskPayload): Promise<ProcessingTask> {
|
||||
const response = await apiClient.post('/api/ai/propagate/task', payload);
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function saveAnnotation(payload: SaveAnnotationPayload): Promise<SavedAnnotation> {
|
||||
const response = await apiClient.post('/api/ai/annotate', payload);
|
||||
return response.data;
|
||||
|
||||
@@ -32,7 +32,7 @@ function makeStageEvent(x = 120, y = 80) {
|
||||
}
|
||||
|
||||
vi.mock('react-konva', () => ({
|
||||
Stage: ({ children, onClick, onMouseDown, onMouseUp, onMouseMove, onWheel }: any) => {
|
||||
Stage: ({ children, onClick, onMouseDown, onMouseUp, onMouseMove, onWheel, onDragEnd, scaleX, scaleY, x, y, width, height }: any) => {
|
||||
const coords = (event: React.MouseEvent<HTMLDivElement>, fallbackX: number, fallbackY: number) => ({
|
||||
x: event.clientX || fallbackX,
|
||||
y: event.clientY || fallbackY,
|
||||
@@ -40,6 +40,13 @@ vi.mock('react-konva', () => ({
|
||||
return (
|
||||
<div
|
||||
data-testid="konva-stage"
|
||||
data-has-drag-end={Boolean(onDragEnd)}
|
||||
data-scale-x={scaleX}
|
||||
data-scale-y={scaleY}
|
||||
data-x={x}
|
||||
data-y={y}
|
||||
data-width={width}
|
||||
data-height={height}
|
||||
onClick={(event) => {
|
||||
const point = coords(event, 120, 80);
|
||||
onClick?.(makeStageEvent(point.x, point.y));
|
||||
@@ -57,6 +64,12 @@ vi.mock('react-konva', () => ({
|
||||
onMouseMove?.(makeStageEvent(point.x, point.y));
|
||||
}}
|
||||
onWheel={() => onWheel?.(makeStageEvent())}
|
||||
onDragEnd={(event) => onDragEnd?.({
|
||||
target: {
|
||||
x: () => event.clientX || 0,
|
||||
y: () => event.clientY || 0,
|
||||
},
|
||||
})}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
|
||||
Reference in New Issue
Block a user