feat: 完善 SAM2.1 模型选择与标注工作流

- 后端 SAM2 引擎新增 sam2.1_hiera_tiny、sam2.1_hiera_small、sam2.1_hiera_base_plus、sam2.1_hiera_large 四个变体定义,并按变体维护 checkpoint/config、image predictor、video predictor、加载状态、错误信息和真实状态回报。

- 后端 SAM registry 仅暴露当前产品启用的 SAM2.1 变体,保留 sam2 作为 tiny 兼容别名,拒绝 sam3 产品入口,并把 point、box、interactive、auto、propagate 都分发到所选 SAM2.1 变体。

- 后端默认配置和下载脚本切换到 SAM2.1 checkpoint 命名,支持 legacy SAM2 checkpoint fallback,并在状态消息中标出 fallback 使用情况。

- 前端全局 AI 模型状态新增 SAM2.1 tiny/small/base+/large 类型和默认 tiny,API 请求默认携带 sam2.1_hiera_tiny,AI 页面提供模型变体选择和所选模型状态展示。

- AI 智能分割页移除当前产品不使用的 SAM3/文本提示入口,保留正向点、反向点、框选和参数开关;AI 页只展示本页生成的候选 mask,并支持遮罩清晰度调节、候选 mask 上继续加正/反点、清空本页候选、推送到工作区编辑。

- 工作区和 Canvas 补强 SAM2 交互式细化链路:框选后正/反点继续细化同一个候选 mask,反向点请求启用背景过滤,空结果会移除被否定候选;AI 推送到工作区后保留选中态和未保存 draft mask。

- 工作区标注保存闭环补强:未保存 mask 可归档保存,dirty saved mask 可更新,保存后用后端 saved annotation 替换已提交 draft,清空/删除已保存 mask 时同步后端删除。

- Dashboard 任务进度区改为展示 queued、running、success、failed、cancelled 最近任务,处理中统计只计算 queued/running,并保留近期完成记录。

- 时间轴在顶部时间进度条和底部缩略图导航轴之间新增已编辑帧标记带,基于当前项目帧内 masks 标出已有编辑/标注的帧,并支持点击标记跳转。

- 前端测试覆盖 SAM2.1 变体选择、模型状态徽标、AI 页候选隔离、遮罩透明度、候选上追加正/反点、推送工作区保留选择、Canvas 交互式细化、VideoWorkspace 传播/保存、Dashboard 进度和时间轴已编辑帧标记。

- 后端测试覆盖 SAM2.1 变体状态、sam2 alias 兼容、sam3 禁用、semantic 禁用、传播标注保存、Dashboard 最近任务状态和 SAM3 历史测试跳过说明。

- README、AGENTS 和 doc 文档同步当前真实进度,更新 SAM2.1 变体、SAM3 禁用、接口契约、设计冻结、需求冻结、前端元素审计、实施计划、FastAPI docs 说明和测试矩阵。
This commit is contained in:
2026-05-01 23:39:53 +08:00
parent 8a9247075e
commit 29a1a87e52
38 changed files with 1087 additions and 631 deletions

View File

@@ -6,7 +6,7 @@
## 项目概述 ## 项目概述
本项目是一个**语义分割系统**Semantic Segmentation System当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、显式视频生成帧、交互式 Canvas 标注、视频片段传播、GT mask 导入、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出 本项目是一个**语义分割系统**Semantic Segmentation System当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、显式视频生成帧、交互式 Canvas 标注、视频片段传播、GT mask 导入、可选 SAM 2.1 tiny/small/base+/large 辅助分割、模板分类管理和标注导出。SAM 3 相关源码和安装脚本保留在仓库中,但由于当前产品不提供文本提示,前端入口已隐藏,后端 registry 也不暴露 `sam3` 模型
- **项目名称**: `react-example``package.json` 中的 `name` - **项目名称**: `react-example``package.json` 中的 `name`
- **前端入口**: `src/main.tsx``src/App.tsx` - **前端入口**: `src/main.tsx``src/App.tsx`
@@ -39,9 +39,9 @@
| 缓存 / 队列 Broker | Redis | | 缓存 / 队列 Broker | Redis |
| 后台任务 | Celery worker | | 后台任务 | Celery worker |
| 对象存储 | MinIO | | 对象存储 | MinIO |
| AI 推理 | SAM 2 / SAM 3 可选模型 + PyTorchSAM 3 通过独立 Python 3.12 conda 环境桥接`GET /api/ai/models/status` 返回真实 GPU/模型/本地 checkpoint 状态 | | AI 推理 | 当前启用 SAM 2.1 + PyTorch可选 tiny/small/base+/large`GET /api/ai/models/status` 返回真实 GPU 和各 SAM 2.1 变体状态SAM 3 源码保留但产品入口禁用 |
| 视频 / 影像处理 | FFmpeg / OpenCV / pydicom | | 视频 / 影像处理 | FFmpeg / OpenCV / pydicom |
| 运行时 | Node.js ES ModulesPython 3.11 后端环境;可选 `sam3` Python 3.12 conda 环境 | | 运行时 | Node.js ES ModulesPython 3.11 后端环境;历史保留的 `sam3` Python 3.12 conda 环境不是当前必需运行条件 |
--- ---
@@ -71,7 +71,7 @@ Seg_Server/
│ ├── celery_app.py # Celery app 配置 │ ├── celery_app.py # Celery app 配置
│ ├── worker_tasks.py # Celery 任务入口 │ ├── worker_tasks.py # Celery 任务入口
│ ├── download_sam2.py # SAM 2 权重下载脚本 │ ├── download_sam2.py # SAM 2 权重下载脚本
│ ├── setup_sam3_env.sh # SAM 3 独立 Python 3.12 环境安装脚本 │ ├── setup_sam3_env.sh # 历史保留的 SAM 3 独立 Python 3.12 环境安装脚本;当前产品入口禁用
│ ├── requirements.txt # Python 依赖 │ ├── requirements.txt # Python 依赖
│ ├── routers/ │ ├── routers/
│ │ ├── auth.py # /api/auth/login │ │ ├── auth.py # /api/auth/login
@@ -82,10 +82,10 @@ Seg_Server/
│ │ └── export.py # /api/export/{project_id}/coco、/masks │ │ └── export.py # /api/export/{project_id}/coco、/masks
│ └── services/ │ └── services/
│ ├── frame_parser.py # FFmpeg/OpenCV 拆帧、pydicom 读片、帧上传 │ ├── frame_parser.py # FFmpeg/OpenCV 拆帧、pydicom 读片、帧上传
│ ├── sam2_engine.py # SAM 2 单帧推理和 video predictor 传播封装 │ ├── sam2_engine.py # SAM 2.1 变体选择、单帧推理和 video predictor 传播封装
│ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接、文本语义推理、框选与 video tracker 适配器 │ ├── sam3_engine.py # 历史保留的 SAM 3 桥接实现;当前未接入 registry
│ ├── sam3_external_worker.py # 独立 sam3 conda 环境中执行的状态/推理 helper │ ├── sam3_external_worker.py # 历史保留的独立 sam3 helper当前未被产品入口调用
│ └── sam_registry.py # SAM 模型选择、GPU 状态与推理分发 │ └── sam_registry.py # 当前暴露 SAM 2.1 变体、GPU 状态与推理分发
└── src/ # React 前端 └── src/ # React 前端
├── main.tsx # React StrictMode 挂载 ├── main.tsx # React StrictMode 挂载
├── App.tsx # 登录拦截 + 模块切换 ├── App.tsx # 登录拦截 + 模块切换
@@ -222,10 +222,10 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
3. 上传资源:视频走 `/api/media/upload`只上传源文件并关联项目不自动拆帧DICOM 批量走 `/api/media/upload/dicom` 3. 上传资源:视频走 `/api/media/upload`只上传源文件并关联项目不自动拆帧DICOM 批量走 `/api/media/upload/dicom`
4. 生成帧入队:用户在项目库点击“生成帧”,选择目标 FPS 后前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery接口支持 `parse_fps``max_frames``target_width` 标准帧序列参数。 4. 生成帧入队:用户在项目库点击“生成帧”,选择目标 FPS 后前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery接口支持 `parse_fps``max_frames``target_width` 标准帧序列参数。
5. worker 执行Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallbackDICOM 使用 pydicom视频帧按 `frame_%06d.jpg` 连续命名并记录 `timestamp_ms``source_frame_number` 和任务 `frame_sequence` 元数据。 5. worker 执行Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallbackDICOM 使用 pydicom视频帧按 `frame_%06d.jpg` 连续命名并记录 `timestamp_ms``source_frame_number` 和任务 `frame_sequence` 元数据。
6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames``CanvasArea.tsx``FrameTimeline.tsx` 显示当前帧与时间轴缩略图;前端 `Frame` 会保留后端返回的帧序列时间戳和源帧号。 6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames``CanvasArea.tsx``FrameTimeline.tsx` 显示当前帧与时间轴缩略图;`FrameTimeline` 会根据当前项目帧内的 `masks` 在进度条和缩略图导航轴之间标出已有编辑/标注的帧;前端 `Frame` 会保留后端返回的帧序列时间戳和源帧号。
7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask多边形可按 Enter 或点击首节点闭合;绘制工具可在已有 mask 上继续落点;工具栏有“调整多边形”入口,点击 mask 可拖动/删除 polygon 顶点、通过边中点或双击边界插入新顶点,并能选择编辑多 polygon mask 的单个子区域;选中整块 mask 可用 Delete/Backspace 删除,已保存 mask 会同步后端删除;区域合并/去除会隐藏编辑手柄并显示已选数量,使用 `polygon-clipping` 做 union/difference内含去除结果用 even-odd 规则渲染 holeZustand 维护 `maskHistory/maskFuture` 支持撤销/重做。 7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask多边形可按 Enter 或点击首节点闭合;绘制工具可在已有 mask 上继续落点;工具栏有“调整多边形”入口,点击 mask 可拖动/删除 polygon 顶点、通过边中点或双击边界插入新顶点,并能选择编辑多 polygon mask 的单个子区域;选中整块 mask 可用 Delete/Backspace 删除,已保存 mask 会同步后端删除;区域合并/去除会隐藏编辑手柄并显示已选数量,使用 `polygon-clipping` 做 union/difference内含去除结果用 even-odd 规则渲染 holeZustand 维护 `maskHistory/maskFuture` 支持撤销/重做。
8. AI 分割前端工具包括正向点、反向点和框选SAM 2 框选会建立候选 mask后续正/反点通过 `interactive` prompt 携带原始框和累计点细化同一个候选 mask包含反向点时工作区会传 `options.auto_filter_background=true``min_score=0.05`,如果后端过滤为空则移除旧候选 mask。后端 `ai.py` 期望按 `image_id``prompt_type``prompt_data``model` 和可选 `options` 调用 SAM registry。SAM 2 支持点/框/interactive/自动分割和 video predictor 传播但不支持文本语义提示AI 页面在 SAM 2 纯文本时提示改用点提示或切换 SAM 3SAM 2 多候选默认只采用最高分区域,避免重叠候选同时显示AI 页面生成的 mask 会写入全局 `masks` 并自动选中,右侧分类树可直接改标签,推送到工作区会切到“调整多边形”并保留选择。`options.crop_to_prompt` 可对点/框/interactive prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果SAM 3 入口支持文本语义推理、框选提示和 external video trackersemantic 请求会把正数 `options.min_score` 传给 external worker 作为置信度阈值,主后端会通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境,并优先使用 `sam3_checkpoint_path` 指向的本地 `sam3权重/sam3.pt`;如果 Python/CUDA/包/本地 checkpoint 均满足,会在状态接口中标为可用 8. AI 分割:前端工具包括 SAM 2.1 变体选择、正向点、反向点和框选SAM 2.1 框选会建立候选 mask后续正/反点通过 `interactive` prompt 携带原始框和累计点细化同一个候选 mask包含反向点时工作区会传 `options.auto_filter_background=true``min_score=0.05`,如果后端过滤为空则移除旧候选 mask。后端 `ai.py` 期望按 `image_id``prompt_type``prompt_data``model` 和可选 `options` 调用 SAM registry。当前 registry 暴露 `sam2.1_hiera_tiny``sam2.1_hiera_small``sam2.1_hiera_base_plus``sam2.1_hiera_large`,并兼容 `sam2` 作为 tiny 别名;`model=sam3` 会被拒绝,`semantic` 文本提示也被禁用。SAM 2.1 支持点/框/interactive/自动分割和 video predictor 传播多候选默认只采用最高分区域避免重叠候选同时显示AI 页面只渲染本页新生成的候选 mask不会把工作区已有 mask 带入 AI 画布AI 页面生成的 mask 会写入全局 `masks` 并自动选中,右侧分类树可直接改标签,推送到工作区会切到“调整多边形”并保留选择。`options.crop_to_prompt` 可对点/框/interactive prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果。
9. 视频片段传播:工作区“传播片段”把当前选中 mask 或当前帧第一个 mask 作为 seed调用 `POST /api/ai/propagate`;后端按项目帧序列下载片段帧,SAM 2 用 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`SAM 3 用独立 helper 的官方 `build_sam3_video_predictor()`并把后续帧结果保存为 `Annotation` 9. 视频片段传播:工作区“传播片段”把当前选中 mask 或当前帧第一个 mask 作为 seed调用 `POST /api/ai/propagate`;后端按项目帧序列下载片段帧,当前使用所选 SAM 2.1 变体的 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`,并把后续帧结果保存为 `Annotation`
10. GT 导入:工作区“导入 GT Mask”调用 `/api/ai/import-gt-mask`;后端按非零像素值和连通域生成 polygon 标注,并用 distance transform 生成 seed point前端回显 seed point拖动后可归档更新。 10. GT 导入:工作区“导入 GT Mask”调用 `/api/ai/import-gt-mask`;后端按非零像素值和连通域生成 polygon 标注,并用 distance transform 生成 seed point前端回显 seed point拖动后可归档更新。
11. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index`OntologyInspector.tsx` 在工作区显示当前模板分类树。 11. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index`OntologyInspector.tsx` 在工作区显示当前模板分类树。
12. 导出:后端支持 COCO JSON 和 PNG mask ZIP 导出PNG ZIP 包含单标注 mask、按 zIndex 融合的语义 mask 和 `semantic_classes.json` 12. 导出:后端支持 COCO JSON 和 PNG mask ZIP 导出PNG ZIP 包含单标注 mask、按 zIndex 融合的语义 mask 和 `semantic_classes.json`
@@ -242,12 +242,12 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
- 前端 `importGtMask()` 已对齐后端 `/api/ai/import-gt-mask`;工作区“导入 GT Mask”会导入后端生成的多类别标注和 seed point 并回显。 - 前端 `importGtMask()` 已对齐后端 `/api/ai/import-gt-mask`;工作区“导入 GT Mask”会导入后端生成的多类别标注和 seed point 并回显。
- 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;前端 `exportMasks()` 已对齐后端 `/api/export/{project_id}/masks`;工作区导出按钮会先保存当前待归档 mask。 - 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;前端 `exportMasks()` 已对齐后端 `/api/export/{project_id}/masks`;工作区导出按钮会先保存当前待归档 mask。
- 工作区“结构化归档保存”按钮已接入 `POST /api/ai/annotate``PATCH /api/ai/annotations/{id}`;加载工作区时会通过 `GET /api/ai/annotations` 回显已保存标注。 - 工作区“结构化归档保存”按钮已接入 `POST /api/ai/annotate``PATCH /api/ai/annotations/{id}`;加载工作区时会通过 `GET /api/ai/annotations` 回显已保存标注。
- 工作区“传播片段”按钮已接入 `POST /api/ai/propagate`SAM 2 路径使用视频 predictorSAM 3 路径使用独立 Python helper 的官方 video tracker完成后刷新后端已保存标注。 - 工作区“传播片段”按钮已接入 `POST /api/ai/propagate`当前启用所选 SAM 2.1 变体的视频 predictor完成后刷新后端已保存标注。
- 工作区“清空遮罩”会调用 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。 - 工作区“清空遮罩”会调用 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。
- 项目状态已统一为 `pending``parsing``ready``error`;前端 `src/lib/api.ts` 会兼容归一化旧库中可能存在的 `Ready``Parsing``Error` - 项目状态已统一为 `pending``parsing``ready``error`;前端 `src/lib/api.ts` 会兼容归一化旧库中可能存在的 `Ready``Parsing``Error`
- 项目库的视频导入与生成帧是两个独立动作:导入视频只上传源文件,生成帧按钮才会带 `parse_fps` 调用 `/api/media/parse`;工作区不会再因“有视频但无帧”自动创建拆帧任务。 - 项目库的视频导入与生成帧是两个独立动作:导入视频只上传源文件,生成帧按钮才会带 `parse_fps` 调用 `/api/media/parse`;工作区不会再因“有视频但无帧”自动创建拆帧任务。
- `server.ts` 仍有旧版 `/api/login``/api/projects``/api/templates` mock当前前端真实 API 调用主要走 FastAPI 的 `/api/auth/*``/api/projects``/api/templates` 等接口。 - `server.ts` 仍有旧版 `/api/login``/api/projects``/api/templates` mock当前前端真实 API 调用主要走 FastAPI 的 `/api/auth/*``/api/projects``/api/templates` 等接口。
- `Dashboard.tsx` 初始统计、队列和活动日志来自 `GET /api/dashboard/overview`解析队列来自 `processing_tasks`,支持取消 queued/running 任务、重试 failed/cancelled 任务和查看失败详情。Celery worker 通过 Redis pub/sub 的 `seg:progress` 频道推送细粒度进度,再由 FastAPI 广播到 `/ws/progress`;前端 WebSocket 客户端通过 `onopen/onclose/onerror` 更新连接状态,并定时发送 `ping` 心跳。 - `Dashboard.tsx` 初始统计、任务进度和活动日志来自 `GET /api/dashboard/overview`任务进度来自 `processing_tasks` queued/running/success/failed/cancelled处理中统计只计算 queued/running,支持取消 queued/running 任务、重试 failed/cancelled 任务和查看失败详情。Celery worker 通过 Redis pub/sub 的 `seg:progress` 频道推送细粒度进度,再由 FastAPI 广播到 `/ws/progress`;前端 WebSocket 客户端通过 `onopen/onclose/onerror` 更新连接状态,并定时发送 `ping` 心跳。
--- ---

View File

@@ -4,20 +4,20 @@
# 语义分割系统SegServer # 语义分割系统SegServer
> 基于 React + FastAPI + 可选 SAM 2 / SAM 3 的全栈交互式图像/视频语义分割与标注平台。 > 基于 React + FastAPI + SAM 2 的全栈交互式图像/视频语义分割与标注平台。
> >
> 支持本地多媒体资产上传、服务器端按帧解析、交互式 Canvas 标注、视频片段传播、GT mask 导入、模板分类管理和标注数据结构化导出;工作区点/框 AI 推理默认走 SAM 2SAM 3 支持语义文本、框选提示和 video tracker前端会显示真实 GPU/模型状态 > 支持本地多媒体资产上传、服务器端按帧解析、交互式 Canvas 标注、视频片段传播、GT mask 导入、模板分类管理和标注数据结构化导出;工作区点/框 AI 推理走可选 SAM 2.1 tiny/small/base+/large前端会显示真实 GPU/模型状态。SAM 3 源码和脚本在仓库中保留,但由于当前系统不提供文本提示,产品入口已隐藏,后端也不暴露 `sam3` 模型
--- ---
## 核心功能 ## 核心功能
- **多媒体资产管理** — 支持视频MP4/AVI/MOV和 DICOM 医学影像上传;视频导入与生成帧分离,生成帧时选择目标 FPS - **多媒体资产管理** — 支持视频MP4/AVI/MOV和 DICOM 医学影像上传;视频导入与生成帧分离,生成帧时选择目标 FPS
- **AI 智能分割引擎** — 后端提供 SAM 2 / SAM 3 模型选择SAM 2 支持点分割point、框分割box、自动分割auto和 video predictor 传播,前端默认只采用最高分候选避免重叠备选同时显示SAM 3 入口支持文本语义提示、框选提示和 external video tracker并按真实运行环境显示可用性 - **AI 智能分割引擎** — 当前产品入口启用 SAM 2.1 四个变体tiny/small/base+/large选择支持点分割point、框分割box交互式正/反点细化、自动分割auto和 video predictor 传播,前端默认只采用最高分候选避免重叠备选同时显示
- **交互式画布标注** — 基于 Konva 的高性能 Canvas支持缩放/平移/手工多边形/矩形/圆/点/线、polygon 顶点拖动/删除、边中点插点、双击边界插点、区域合并/去除、选点/框选、撤销/重做,实时渲染 Mask 遮罩 - **交互式画布标注** — 基于 Konva 的高性能 Canvas支持缩放/平移/手工多边形/矩形/圆/点/线、polygon 顶点拖动/删除、边中点插点、双击边界插点、区域合并/去除、选点/框选、撤销/重做,实时渲染 Mask 遮罩
- **GT Mask 导入** — 工作区可导入 GT mask 图片,后端按非零像素值和连通域生成 polygon 标注并用 distance transform 写入 seed point前端可回显和拖动 seed point - **GT Mask 导入** — 工作区可导入 GT mask 图片,后端按非零像素值和连通域生成 polygon 标注并用 distance transform 写入 seed point前端可回显和拖动 seed point
- **本体字典管理** — 可配置的分类体系、颜色映射、图层优先级z-index - **本体字典管理** — 可配置的分类体系、颜色映射、图层优先级z-index
- **项目工作区** — 项目创建、帧浏览、多图层标注、进度追踪 - **项目工作区** — 项目创建、帧浏览、多图层标注、已编辑帧提示、进度追踪
- **数据导出** — 支持 COCO JSON 格式和 PNG Mask 批量导出PNG ZIP 包含单标注 mask、按 z-index 融合的语义 mask 和类别映射 - **数据导出** — 支持 COCO JSON 格式和 PNG Mask 批量导出PNG ZIP 包含单标注 mask、按 z-index 融合的语义 mask 和类别映射
--- ---
@@ -40,7 +40,7 @@
│ ├── /api/templates 本体字典(分类/颜色/z-index │ ├── /api/templates 本体字典(分类/颜色/z-index
│ ├── /api/media 文件上传 & 异步拆帧任务创建 │ │ ├── /api/media 文件上传 & 异步拆帧任务创建 │
│ ├── /api/tasks Celery 后台任务状态/取消/重试/详情 │ │ ├── /api/tasks Celery 后台任务状态/取消/重试/详情 │
│ ├── /api/ai SAM 2 / SAM 3 推理与模型状态 │ │ ├── /api/ai SAM 2 推理与模型状态
│ └── /api/export COCO JSON / PNG Masks 导出 │ │ └── /api/export COCO JSON / PNG Masks 导出 │
└──────────────────────────┬──────────────────────────────────┘ └──────────────────────────┬──────────────────────────────────┘
│ SQLAlchemy 2.0 │ SQLAlchemy 2.0
@@ -71,7 +71,7 @@
| 队列 Broker | Redis | 6 | | 队列 Broker | Redis | 6 |
| 后台任务 | Celery worker | 5.6+ | | 后台任务 | Celery worker | 5.6+ |
| 对象存储 | MinIO | 2025+ | | 对象存储 | MinIO | 2025+ |
| AI 推理 | SAM 2 / SAM 3 (Meta) + PyTorch | - | | AI 推理 | SAM 2.1 (Meta) + PyTorch,可选 tiny/small/base+/large | - |
| 视频处理 | FFmpeg + OpenCV | 4.4+ | | 视频处理 | FFmpeg + OpenCV | 4.4+ |
| DICOM 处理 | pydicom | 3.0+ | | DICOM 处理 | pydicom | 3.0+ |
@@ -94,7 +94,7 @@ Seg_Server/
│ ├── celery_app.py # Celery app 配置 │ ├── celery_app.py # Celery app 配置
│ ├── worker_tasks.py # Celery 任务入口 │ ├── worker_tasks.py # Celery 任务入口
│ ├── download_sam2.py # SAM 2 模型权重自动下载脚本 │ ├── download_sam2.py # SAM 2 模型权重自动下载脚本
│ ├── setup_sam3_env.sh # SAM 3 独立 Python 3.12 环境安装脚本 │ ├── setup_sam3_env.sh # 历史保留的 SAM 3 独立 Python 3.12 环境安装脚本;当前产品入口禁用
│ ├── requirements.txt # Python 依赖 │ ├── requirements.txt # Python 依赖
│ ├── routers/ # API 路由 │ ├── routers/ # API 路由
│ │ ├── auth.py # 登录认证 │ │ ├── auth.py # 登录认证
@@ -104,10 +104,10 @@ Seg_Server/
│ │ ├── ai.py # SAM 推理与模型状态接口 │ │ ├── ai.py # SAM 推理与模型状态接口
│ │ └── export.py # 数据导出 │ │ └── export.py # 数据导出
│ └── services/ # 业务服务 │ └── services/ # 业务服务
│ ├── sam2_engine.py # SAM 2 推理引擎(单帧推理 + video predictor 传播 │ ├── sam2_engine.py # SAM 2.1 变体选择、单帧推理 + video predictor 传播
│ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接、文本语义推理、框选与 video tracker 适配器 │ ├── sam3_engine.py # 历史保留的 SAM 3 桥接实现;当前未接入 registry
│ ├── sam3_external_worker.py # 独立 sam3 conda 环境中执行的状态/推理 helper │ ├── sam3_external_worker.py # 历史保留的独立 sam3 helper当前未被产品入口调用
│ ├── sam_registry.py # SAM 模型选择、GPU 状态与推理分发 │ ├── sam_registry.py # 当前暴露 SAM 2.1 变体、GPU 状态与推理分发
│ └── frame_parser.py # FFmpeg 拆帧 / pydicom 读片 │ └── frame_parser.py # FFmpeg 拆帧 / pydicom 读片
├── src/ # React 前端 ├── src/ # React 前端
│ ├── main.tsx # 应用挂载点 │ ├── main.tsx # 应用挂载点
@@ -121,7 +121,7 @@ Seg_Server/
│ └── components/ # 组件(扁平化目录) │ └── components/ # 组件(扁平化目录)
│ ├── Login.tsx # 登录页 │ ├── Login.tsx # 登录页
│ ├── Sidebar.tsx # 左侧导航栏 │ ├── Sidebar.tsx # 左侧导航栏
│ ├── Dashboard.tsx # 总体概况仪表盘(解析队列/任务控制) │ ├── Dashboard.tsx # 总体概况仪表盘(任务进度/任务控制)
│ ├── ProjectLibrary.tsx # 项目库列表 │ ├── ProjectLibrary.tsx # 项目库列表
│ ├── VideoWorkspace.tsx # 核心分割工作区布局 │ ├── VideoWorkspace.tsx # 核心分割工作区布局
│ ├── CanvasArea.tsx # Konva 画布(缩放/平移/手工绘制/选点/Mask渲染 │ ├── CanvasArea.tsx # Konva 画布(缩放/平移/手工绘制/选点/Mask渲染
@@ -162,10 +162,10 @@ Seg_Server/
### 系统要求 ### 系统要求
- **OS**: Ubuntu 22.04 LTS - **OS**: Ubuntu 22.04 LTS
- **GPU**: NVIDIA GPU推荐 RTX 4090 或同等算力),用于 SAM 推理SAM 3 官方要求 Python 3.12+、PyTorch 2.7+ 和 CUDA 12.6+ 环境 - **GPU**: NVIDIA GPU推荐 RTX 4090 或同等算力),用于 SAM 2 推理
- **CUDA**: 12.x / 13.x - **CUDA**: 12.x / 13.x
- **Node.js**: 22.x+ - **Node.js**: 22.x+
- **Python**: 主后端使用 3.11(通过 Miniconda/Anaconda 管理);SAM 3 使用独立 `sam3` Python 3.12 conda 环境 - **Python**: 主后端使用 3.11(通过 Miniconda/Anaconda 管理);历史保留的 SAM 3 环境不是当前必需运行条件
### 安装系统级依赖 ### 安装系统级依赖
@@ -239,27 +239,30 @@ cd ~/Desktop/Seg_Server/backend
python download_sam2.py python download_sam2.py
# 模型将下载到 ~/Desktop/Seg_Server/models/ # 模型将下载到 ~/Desktop/Seg_Server/models/
# sam2_hiera_tiny.pt (149 MB) # 推荐放置 SAM 2.1 文件名:
# sam2_hiera_small.pt (176 MB) # sam2.1_hiera_tiny.pt
# sam2_hiera_base_plus.pt (309 MB) # sam2.1_hiera_small.pt
# sam2_hiera_large.pt (856 MB) # sam2.1_hiera_base_plus.pt
# sam2.1_hiera_large.pt
#
# 兼容旧版 SAM 2 文件名:
# sam2_hiera_tiny.pt / sam2_hiera_small.pt / sam2_hiera_base_plus.pt / sam2_hiera_large.pt
``` ```
> **注意**:当前系统磁盘紧张时,建议仅保留 `sam2_hiera_tiny.pt`,删除其他模型以释放空间 > **注意**:当前系统磁盘紧张时,建议仅保留 `sam2.1_hiera_tiny.pt` 或兼容旧名 `sam2_hiera_tiny.pt`,删除其他模型以释放空间。前端可以选择四个变体,但只有本地存在对应 checkpoint 的变体会显示可用
### 步骤 5: 可选安装 SAM 3 环境 ### 步骤 5: 历史保留的 SAM 3 环境
当前后端不会把 SAM 3 直接装进 `seg_server`,而是通过独立 `sam3` conda 环境执行 `backend/services/sam3_external_worker.py`。这样可以保留现有 Python 3.11 / SAM 2 环境 当前产品入口不再启用 SAM 3前端隐藏 SAM 3 相关入口,后端 registry 只暴露 SAM 2.1 变体,`model=sam3` 会返回不支持。以下脚本和 helper 仅作为以后恢复 SAM 3 研究路径的保留文件,正常部署不需要执行
```bash ```bash
cd ~/Desktop/Seg_Server cd ~/Desktop/Seg_Server
./backend/setup_sam3_env.sh ./backend/setup_sam3_env.sh
# 如果已把权重放在 sam3权重/sam3.pt可直接走本地 checkpoint # 仅在后续恢复 SAM 3 实验路径时使用。
# 未配置本地 checkpoint 时,才需要 Hugging Face gated repo 授权和登录。
``` ```
官方 `facebook/sam3` 权重约 3.45 GB当前没有类似 SAM 2 `tiny/small/base/large` 的官方小权重梯度。当前仓库默认使用本机 `sam3权重/sam3.pt`,不会提交权重文件;未配置本地 checkpoint 且未获得 gated model 授权时,`GET /api/ai/models/status` 会把 SAM 3 标为不可用并说明 checkpoint access 不满足 官方 `facebook/sam3` 权重约 3.45 GB当前没有类似 SAM 2 `tiny/small/base/large` 的官方小权重梯度。本项目不会提交权重文件由于当前系统不提供文本提示SAM 3 不在模型状态接口和前端 UI 中展示
### 步骤 6: 配置环境变量 ### 步骤 6: 配置环境变量
@@ -273,14 +276,15 @@ minio_endpoint=192.168.3.11:9000
minio_access_key=minioadmin minio_access_key=minioadmin
minio_secret_key=minioadmin minio_secret_key=minioadmin
minio_secure=false minio_secure=false
sam_model_path=/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt sam_model_path=/home/wkmgc/Desktop/Seg_Server/models/sam2.1_hiera_tiny.pt
sam_model_config=configs/sam2/sam2_hiera_t.yaml sam_model_config=configs/sam2.1/sam2.1_hiera_t.yaml
sam_default_model=sam2 sam_default_model=sam2.1_hiera_tiny
sam3_model_version=sam3 # 以下 sam3_* 配置为历史保留项;当前产品入口不读取它们来暴露 SAM 3。
sam3_checkpoint_path=/home/wkmgc/Desktop/Seg_Server/sam3权重/sam3.pt # sam3_model_version=sam3
sam3_external_enabled=true # sam3_checkpoint_path=/home/wkmgc/Desktop/Seg_Server/sam3权重/sam3.pt
sam3_external_python=/home/wkmgc/miniconda3/envs/sam3/bin/python # sam3_external_enabled=true
sam3_timeout_seconds=300 # sam3_external_python=/home/wkmgc/miniconda3/envs/sam3/bin/python
# sam3_timeout_seconds=300
cors_origins=["http://localhost:3000","http://192.168.3.11:3000"] cors_origins=["http://localhost:3000","http://192.168.3.11:3000"]
``` ```
@@ -309,9 +313,9 @@ nohup uvicorn main:app --host 0.0.0.0 --port 8000 > /tmp/fastapi.log 2>&1 &
- 创建数据库表(如果不存在) - 创建数据库表(如果不存在)
- 检查 MinIO bucket 是否存在 - 检查 MinIO bucket 是否存在
- 测试 Redis 连接 - 测试 Redis 连接
- 懒加载 SAM 模型;`GET /api/ai/models/status` 会返回 SAM 2、SAM 3、GPU 和 SAM 3 checkpoint access 的真实可用状态 - 懒加载所选 SAM 2.1 模型;`GET /api/ai/models/status` 会返回 tiny/small/base+/large 和 GPU 的真实可用状态,`selected_model=sam3` 会返回不支持
- `/api/ai/predict` 支持 AI 参数 `crop_to_prompt``auto_filter_background``min_score`,用于点/框 prompt 的局部裁剪推理、回映射和背景过滤 - `/api/ai/predict` 支持 AI 参数 `crop_to_prompt``auto_filter_background``min_score`,用于点/框 prompt 的局部裁剪推理、回映射和背景过滤
- `/api/ai/propagate` 支持从当前帧 seed 区域向视频片段传播:SAM 2 使用 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`SAM 3 通过独立 Python 3.12 helper 调用官方 `build_sam3_video_predictor()` video tracker - `/api/ai/propagate` 支持从当前帧 seed 区域向视频片段传播:当前使用所选 SAM 2.1 变体的 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`
### 步骤 6.1: 启动 Celery Worker ### 步骤 6.1: 启动 Celery Worker
@@ -325,7 +329,7 @@ celery -A celery_app:celery_app worker --loglevel=info --concurrency=1
nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 & nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 &
``` ```
视频导入只创建项目并把源视频保存到 MinIO不会自动拆帧用户在项目库点击“生成帧”后再选择目标 FPS 并调用 `POST /api/media/parse`。该接口只创建 `processing_tasks` 记录并把任务投递给 Celery真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。接口支持 `parse_fps``max_frames``target_width`,用于生成后续 SAM 2 / SAM 3 视频处理可复用的标准帧序列;视频帧按 `frame_%06d.jpg` 连续命名,帧表会记录 `timestamp_ms``source_frame_number`,任务完成结果会返回 `frame_sequence` 元数据。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 的 WebSocket 状态由浏览器 `onopen/onclose/onerror` 驱动,客户端会定时发送 `ping` 心跳,服务端返回 `status` 确认连接。Dashboard 也可调用 `/api/tasks/{id}/cancel``/api/tasks/{id}/retry``/api/tasks/{id}` 完成任务取消、重试与失败详情查看。 视频导入只创建项目并把源视频保存到 MinIO不会自动拆帧用户在项目库点击“生成帧”后再选择目标 FPS 并调用 `POST /api/media/parse`。该接口只创建 `processing_tasks` 记录并把任务投递给 Celery真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。接口支持 `parse_fps``max_frames``target_width`,用于生成后续 SAM 2 视频处理可复用的标准帧序列;视频帧按 `frame_%06d.jpg` 连续命名,帧表会记录 `timestamp_ms``source_frame_number`,任务完成结果会返回 `frame_sequence` 元数据。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 的任务进度区展示 queued/running/success/failed/cancelled 最近任务,处理中统计只计算 queued/runningWebSocket 状态由浏览器 `onopen/onclose/onerror` 驱动,客户端会定时发送 `ping` 心跳,服务端返回 `status` 确认连接。Dashboard 也可调用 `/api/tasks/{id}/cancel``/api/tasks/{id}/retry``/api/tasks/{id}` 完成任务取消、重试与失败详情查看。
### 步骤 7: 安装前端依赖并构建 ### 步骤 7: 安装前端依赖并构建
@@ -417,9 +421,9 @@ sudo apt autoremove -y && sudo apt clean
conda clean --all -y conda clean --all -y
# 3. 仅保留最小模型 # 3. 仅保留最小模型
rm ~/Desktop/Seg_Server/models/sam2_hiera_large.pt rm ~/Desktop/Seg_Server/models/sam2.1_hiera_large.pt
rm ~/Desktop/Seg_Server/models/sam2_hiera_base_plus.pt rm ~/Desktop/Seg_Server/models/sam2.1_hiera_base_plus.pt
rm ~/Desktop/Seg_Server/models/sam2_hiera_small.pt rm ~/Desktop/Seg_Server/models/sam2.1_hiera_small.pt
# 4. 如需安装 sam2 包,确保有 >5GB 可用空间后再执行 # 4. 如需安装 sam2 包,确保有 >5GB 可用空间后再执行
``` ```
@@ -431,7 +435,7 @@ rm ~/Desktop/Seg_Server/models/sam2_hiera_small.pt
**解决**: **解决**:
```bash ```bash
# 1. 确认模型文件存在 # 1. 确认模型文件存在
ls ~/Desktop/Seg_Server/models/sam2_hiera_tiny.pt ls ~/Desktop/Seg_Server/models/sam2.1_hiera_tiny.pt
# 2. 安装 sam2需 >5GB 磁盘空间) # 2. 安装 sam2需 >5GB 磁盘空间)
cd /tmp cd /tmp
@@ -461,8 +465,8 @@ pip install -e . --no-build-isolation
- 前端 `predictMask()` 已发送后端需要的 `image_id``prompt_type``prompt_data`,并把后端 `polygons` 转成 Konva `pathData` - 前端 `predictMask()` 已发送后端需要的 `image_id``prompt_type``prompt_data`,并把后端 `polygons` 转成 Konva `pathData`
- 工作区点选/框选会使用当前帧的数据库 `frame.id` 调用 `/api/ai/predict` - 工作区点选/框选会使用当前帧的数据库 `frame.id` 调用 `/api/ai/predict`
- 工作区 SAM 2 交互式细化包含反向点时会启用后端背景过滤;若反向点排除了当前候选区域并返回空结果,前端会移除旧候选 mask。 - 工作区 SAM 2.1 交互式细化包含反向点时会启用后端背景过滤;若反向点排除了当前候选区域并返回空结果,前端会移除旧候选 mask。
- AI 页面生成的 SAM 2/SAM 3 mask 会写入全局 `masks` 并自动选中右侧分类树可直接给生成结果换标签,“推送至工作区编辑”会切回工作区的多边形调整工具并保留选择。 - AI 页面只显示本页新生成的 SAM 2.1 候选,不会把工作区已有 mask 带入 AI 画布;新生成 mask 会写入全局 `masks` 并自动选中右侧分类树可直接给生成结果换标签,“推送至工作区编辑”会切回工作区的多边形调整工具并保留选择。
- 工作区“传播片段”会使用当前选中区域或当前帧第一个区域作为 seed调用 `/api/ai/propagate`,并在完成后刷新已保存标注。 - 工作区“传播片段”会使用当前选中区域或当前帧第一个区域作为 seed调用 `/api/ai/propagate`,并在完成后刷新已保存标注。
- 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco` - 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco`
- 工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮已绑定下载流程导出前会先保存当前待归档的前端 mask。 - 工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮已绑定下载流程导出前会先保存当前待归档的前端 mask。

View File

@@ -19,9 +19,9 @@ class Settings(BaseSettings):
minio_secure: bool = False minio_secure: bool = False
# SAM # SAM
sam_default_model: str = "sam2" sam_default_model: str = "sam2.1_hiera_tiny"
sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt" sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2.1_hiera_tiny.pt"
sam_model_config: str = "configs/sam2/sam2_hiera_t.yaml" sam_model_config: str = "configs/sam2.1/sam2.1_hiera_t.yaml"
sam3_model_version: str = "sam3" sam3_model_version: str = "sam3"
sam3_checkpoint_path: str = "/home/wkmgc/Desktop/Seg_Server/sam3权重/sam3.pt" sam3_checkpoint_path: str = "/home/wkmgc/Desktop/Seg_Server/sam3权重/sam3.pt"
sam3_external_enabled: bool = True sam3_external_enabled: bool = True

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
SAM 2 模型权重下载脚本 SAM 2.1 模型权重下载脚本
运行: python download_sam2.py 运行: python download_sam2.py
""" """
import os import os
@@ -10,12 +10,12 @@ import sys
MODEL_DIR = "/home/wkmgc/Desktop/Seg_Server/models" MODEL_DIR = "/home/wkmgc/Desktop/Seg_Server/models"
os.makedirs(MODEL_DIR, exist_ok=True) os.makedirs(MODEL_DIR, exist_ok=True)
# SAM 2 模型权重 (Meta AI 官方) # SAM 2.1 模型权重 (Meta AI 官方)
MODELS = { MODELS = {
"sam2_hiera_tiny.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt", "sam2.1_hiera_tiny.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
"sam2_hiera_small.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt", "sam2.1_hiera_small.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
"sam2_hiera_base_plus.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt", "sam2.1_hiera_base_plus.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
"sam2_hiera_large.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt", "sam2.1_hiera_large.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
} }
def download_file(url: str, dest: str): def download_file(url: str, dest: str):
@@ -35,7 +35,7 @@ def download_file(url: str, dest: str):
def main(): def main():
print("=" * 50) print("=" * 50)
print("SAM 2 模型权重下载") print("SAM 2.1 模型权重下载")
print("=" * 50) print("=" * 50)
for name, url in MODELS.items(): for name, url in MODELS.items():
dest = os.path.join(MODEL_DIR, name) dest = os.path.join(MODEL_DIR, name)

View File

@@ -231,7 +231,7 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
coordinates or `{ "points": [[x, y], ...], "labels": [1, 0, ...] }`. coordinates or `{ "points": [[x, y], ...], "labels": [1, 0, ...] }`.
- **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates. - **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
- **interactive**: `prompt_data` is `{ "box": [...], "points": [[x, y]], "labels": [1, 0] }`. - **interactive**: `prompt_data` is `{ "box": [...], "points": [[x, y]], "labels": [1, 0] }`.
- **semantic**: SAM 3 text prompt when model=`sam3`; SAM 2 falls back to auto. - **semantic**: disabled in the current SAM 2.1 point/box product flow.
""" """
frame = db.query(Frame).filter(Frame.id == payload.image_id).first() frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
if not frame: if not frame:
@@ -382,7 +382,7 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
summary="Get SAM model and GPU runtime status", summary="Get SAM model and GPU runtime status",
) )
def model_status(selected_model: str | None = None) -> dict: def model_status(selected_model: str | None = None) -> dict:
"""Return real runtime availability for GPU, SAM 2, and SAM 3.""" """Return real runtime availability for GPU and the currently enabled SAM model."""
try: try:
return sam_registry.runtime_status(selected_model) return sam_registry.runtime_status(selected_model)
except ValueError as exc: except ValueError as exc:
@@ -398,7 +398,7 @@ def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
"""Track one selected region from the current frame across nearby frames. """Track one selected region from the current frame across nearby frames.
SAM 2 uses the official video predictor with the selected mask as the seed. SAM 2 uses the official video predictor with the selected mask as the seed.
SAM 3 uses the external Python 3.12 video tracker with the seed bbox. SAM 3 video tracking is currently disabled in this product flow.
""" """
direction = payload.direction.lower() direction = payload.direction.lower()
if direction not in {"forward", "backward", "both"}: if direction not in {"forward", "backward", "both"}:

View File

@@ -14,7 +14,7 @@ from models import Annotation, Frame, ProcessingTask, Project, Template
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"]) router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
ACTIVE_TASK_STATUSES = {"queued", "running"} ACTIVE_TASK_STATUSES = {"queued", "running"}
MONITORED_TASK_STATUSES = {"queued", "running", "failed", "cancelled"} MONITORED_TASK_STATUSES = {"queued", "running", "success", "failed", "cancelled"}
def _system_load_percent() -> int: def _system_load_percent() -> int:

View File

@@ -204,7 +204,7 @@ class PropagationSeed(BaseModel):
class PropagateRequest(BaseModel): class PropagateRequest(BaseModel):
project_id: int project_id: int
frame_id: int frame_id: int
model: Optional[str] = "sam2" model: Optional[str] = "sam2.1_hiera_tiny"
seed: PropagationSeed seed: PropagationSeed
direction: str = "forward" direction: str = "forward"
max_frames: int = 30 max_frames: int = 30

View File

@@ -2,6 +2,8 @@
import logging import logging
import os import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional from typing import Optional
import numpy as np import numpy as np
@@ -10,6 +12,67 @@ from config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_SAM2_MODEL_ID = "sam2.1_hiera_tiny"
@dataclass(frozen=True)
class SAM2Variant:
"""One selectable SAM 2.1 runtime variant."""
id: str
label: str
short_label: str
config: str
legacy_config: str
checkpoint_filename: str
legacy_checkpoint_filename: str
SAM2_VARIANTS: dict[str, SAM2Variant] = {
"sam2.1_hiera_tiny": SAM2Variant(
id="sam2.1_hiera_tiny",
label="SAM 2.1 Tiny",
short_label="tiny",
config="configs/sam2.1/sam2.1_hiera_t.yaml",
legacy_config="configs/sam2/sam2_hiera_t.yaml",
checkpoint_filename="sam2.1_hiera_tiny.pt",
legacy_checkpoint_filename="sam2_hiera_tiny.pt",
),
"sam2.1_hiera_small": SAM2Variant(
id="sam2.1_hiera_small",
label="SAM 2.1 Small",
short_label="small",
config="configs/sam2.1/sam2.1_hiera_s.yaml",
legacy_config="configs/sam2/sam2_hiera_s.yaml",
checkpoint_filename="sam2.1_hiera_small.pt",
legacy_checkpoint_filename="sam2_hiera_small.pt",
),
"sam2.1_hiera_base_plus": SAM2Variant(
id="sam2.1_hiera_base_plus",
label="SAM 2.1 Base+",
short_label="base+",
config="configs/sam2.1/sam2.1_hiera_b+.yaml",
legacy_config="configs/sam2/sam2_hiera_b+.yaml",
checkpoint_filename="sam2.1_hiera_base_plus.pt",
legacy_checkpoint_filename="sam2_hiera_base_plus.pt",
),
"sam2.1_hiera_large": SAM2Variant(
id="sam2.1_hiera_large",
label="SAM 2.1 Large",
short_label="large",
config="configs/sam2.1/sam2.1_hiera_l.yaml",
legacy_config="configs/sam2/sam2_hiera_l.yaml",
checkpoint_filename="sam2.1_hiera_large.pt",
legacy_checkpoint_filename="sam2_hiera_large.pt",
),
}
SAM2_MODEL_ALIASES = {
"sam2": DEFAULT_SAM2_MODEL_ID,
"sam2.1": DEFAULT_SAM2_MODEL_ID,
"sam2_tiny": DEFAULT_SAM2_MODEL_ID,
}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Attempt to import PyTorch and SAM 2; fall back to stubs if unavailable. # Attempt to import PyTorch and SAM 2; fall back to stubs if unavailable.
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -38,115 +101,173 @@ class SAM2Engine:
"""Lazy-loaded SAM 2 inference engine.""" """Lazy-loaded SAM 2 inference engine."""
def __init__(self) -> None: def __init__(self) -> None:
self._predictor: Optional[SAM2ImagePredictor] = None self._predictors: dict[str, Optional[SAM2ImagePredictor]] = {}
self._video_predictor = None self._video_predictors: dict[str, object | None] = {}
self._model_loaded = False self._model_loaded: dict[str, bool] = {}
self._video_model_loaded = False self._video_model_loaded: dict[str, bool] = {}
self._loaded_device: str | None = None self._loaded_device: dict[str, str] = {}
self._last_error: str | None = None self._last_error: dict[str, str | None] = {}
self._video_last_error: str | None = None self._video_last_error: dict[str, str | None] = {}
# ----------------------------------------------------------------------- # -----------------------------------------------------------------------
# Internal helpers # Internal helpers
# ----------------------------------------------------------------------- # -----------------------------------------------------------------------
def _load_model(self) -> None: def variant_ids(self) -> list[str]:
return list(SAM2_VARIANTS.keys())
def normalize_model_id(self, model_id: str | None) -> str:
selected = (model_id or settings.sam_default_model or DEFAULT_SAM2_MODEL_ID).lower()
selected = SAM2_MODEL_ALIASES.get(selected, selected)
if selected not in SAM2_VARIANTS:
raise ValueError(f"Unsupported SAM2 model: {model_id}")
return selected
def is_sam2_model(self, model_id: str | None) -> bool:
try:
self.normalize_model_id(model_id)
return True
except ValueError:
return False
def _models_dir(self) -> Path:
configured_path = Path(settings.sam_model_path)
return configured_path.parent if configured_path.parent else Path("models")
def _variant(self, model_id: str | None) -> SAM2Variant:
return SAM2_VARIANTS[self.normalize_model_id(model_id)]
def _checkpoint_config(self, model_id: str | None) -> tuple[str, str]:
variant_id = self.normalize_model_id(model_id)
variant = SAM2_VARIANTS[variant_id]
models_dir = self._models_dir()
candidates: list[tuple[str, str]] = []
configured_path = Path(settings.sam_model_path)
if variant_id == DEFAULT_SAM2_MODEL_ID and configured_path.is_file():
candidates.append((settings.sam_model_config, str(configured_path)))
candidates.extend([
(variant.config, str(models_dir / variant.checkpoint_filename)),
(variant.legacy_config, str(models_dir / variant.legacy_checkpoint_filename)),
])
for config, checkpoint_path in candidates:
if os.path.isfile(checkpoint_path):
return config, checkpoint_path
return candidates[0]
def _load_model(self, model_id: str | None = None) -> None:
"""Load the SAM 2 model and predictor on first use.""" """Load the SAM 2 model and predictor on first use."""
if self._model_loaded: variant_id = self.normalize_model_id(model_id)
if self._model_loaded.get(variant_id):
return return
if not TORCH_AVAILABLE: if not TORCH_AVAILABLE:
self._last_error = "PyTorch is not installed." self._last_error[variant_id] = "PyTorch is not installed."
logger.warning("PyTorch not available; skipping SAM2 model load.") logger.warning("PyTorch not available; skipping SAM2 model load.")
self._model_loaded = True self._model_loaded[variant_id] = True
return return
if not SAM2_AVAILABLE: if not SAM2_AVAILABLE:
self._last_error = "sam2 package is not installed." self._last_error[variant_id] = "sam2 package is not installed."
logger.warning("SAM2 not available; skipping model load.") logger.warning("SAM2 not available; skipping model load.")
self._model_loaded = True self._model_loaded[variant_id] = True
return return
if not os.path.isfile(settings.sam_model_path): config, checkpoint_path = self._checkpoint_config(variant_id)
self._last_error = f"SAM2 checkpoint not found: {settings.sam_model_path}" if not os.path.isfile(checkpoint_path):
logger.error("SAM checkpoint not found at %s", settings.sam_model_path) self._last_error[variant_id] = f"SAM2 checkpoint not found: {checkpoint_path}"
self._model_loaded = True logger.error("SAM checkpoint not found at %s", checkpoint_path)
self._model_loaded[variant_id] = True
return return
try: try:
device = self._best_device() device = self._best_device()
model = build_sam2( model = build_sam2(
settings.sam_model_config, config,
settings.sam_model_path, checkpoint_path,
device=device, device=device,
) )
self._predictor = SAM2ImagePredictor(model) self._predictors[variant_id] = SAM2ImagePredictor(model)
self._model_loaded = True self._model_loaded[variant_id] = True
self._loaded_device = device self._loaded_device[variant_id] = device
self._last_error = None self._last_error[variant_id] = None
logger.info("SAM 2 model loaded from %s on %s", settings.sam_model_path, device) logger.info("SAM 2 model %s loaded from %s on %s", variant_id, checkpoint_path, device)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
self._last_error = str(exc) self._last_error[variant_id] = str(exc)
logger.error("Failed to load SAM 2 model: %s", exc) logger.error("Failed to load SAM 2 model %s: %s", variant_id, exc)
self._model_loaded = True # Prevent repeated load attempts self._model_loaded[variant_id] = True # Prevent repeated load attempts
def _load_video_model(self) -> None: def _load_video_model(self, model_id: str | None = None) -> None:
"""Load the SAM 2 video predictor on first propagation use.""" """Load the SAM 2 video predictor on first propagation use."""
if self._video_model_loaded: variant_id = self.normalize_model_id(model_id)
if self._video_model_loaded.get(variant_id):
return return
if not TORCH_AVAILABLE: if not TORCH_AVAILABLE:
self._video_last_error = "PyTorch is not installed." self._video_last_error[variant_id] = "PyTorch is not installed."
self._video_model_loaded = True self._video_model_loaded[variant_id] = True
return return
if not SAM2_AVAILABLE: if not SAM2_AVAILABLE:
self._video_last_error = "sam2 package is not installed." self._video_last_error[variant_id] = "sam2 package is not installed."
self._video_model_loaded = True self._video_model_loaded[variant_id] = True
return return
if not os.path.isfile(settings.sam_model_path):
self._video_last_error = f"SAM2 checkpoint not found: {settings.sam_model_path}" config, checkpoint_path = self._checkpoint_config(variant_id)
self._video_model_loaded = True if not os.path.isfile(checkpoint_path):
self._video_last_error[variant_id] = f"SAM2 checkpoint not found: {checkpoint_path}"
self._video_model_loaded[variant_id] = True
return return
try: try:
device = self._best_device() device = self._best_device()
self._video_predictor = build_sam2_video_predictor( self._video_predictors[variant_id] = build_sam2_video_predictor(
settings.sam_model_config, config,
settings.sam_model_path, checkpoint_path,
device=device, device=device,
) )
self._video_model_loaded = True self._video_model_loaded[variant_id] = True
self._loaded_device = device self._loaded_device[variant_id] = device
self._video_last_error = None self._video_last_error[variant_id] = None
logger.info("SAM 2 video predictor loaded from %s on %s", settings.sam_model_path, device) logger.info("SAM 2 video predictor %s loaded from %s on %s", variant_id, checkpoint_path, device)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
self._video_last_error = str(exc) self._video_last_error[variant_id] = str(exc)
self._video_model_loaded = True self._video_model_loaded[variant_id] = True
logger.error("Failed to load SAM 2 video predictor: %s", exc) logger.error("Failed to load SAM 2 video predictor %s: %s", variant_id, exc)
def _best_device(self) -> str: def _best_device(self) -> str:
if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available(): if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available():
return "cuda" return "cuda"
return "cpu" return "cpu"
def _ensure_ready(self) -> bool: def _ensure_ready(self, model_id: str | None = None) -> bool:
"""Ensure the model is loaded; return whether it is usable.""" """Ensure the model is loaded; return whether it is usable."""
self._load_model() variant_id = self.normalize_model_id(model_id)
return SAM2_AVAILABLE and self._predictor is not None self._load_model(variant_id)
return SAM2_AVAILABLE and self._predictors.get(variant_id) is not None
def _ensure_video_ready(self) -> bool: def _ensure_video_ready(self, model_id: str | None = None) -> bool:
"""Ensure the video predictor is loaded; return whether it is usable.""" """Ensure the video predictor is loaded; return whether it is usable."""
self._load_video_model() variant_id = self.normalize_model_id(model_id)
return SAM2_AVAILABLE and self._video_predictor is not None self._load_video_model(variant_id)
return SAM2_AVAILABLE and self._video_predictors.get(variant_id) is not None
def status(self) -> dict: def status(self, model_id: str | None = None) -> dict:
"""Return lightweight, real runtime status without forcing model load.""" """Return lightweight, real runtime status without forcing model load."""
checkpoint_exists = os.path.isfile(settings.sam_model_path) variant_id = self.normalize_model_id(model_id)
device = self._loaded_device or self._best_device() variant = SAM2_VARIANTS[variant_id]
_, checkpoint_path = self._checkpoint_config(variant_id)
checkpoint_exists = os.path.isfile(checkpoint_path)
using_legacy_checkpoint = Path(checkpoint_path).name == variant.legacy_checkpoint_filename
predictor = self._predictors.get(variant_id)
device = self._loaded_device.get(variant_id) or self._best_device()
available = bool(TORCH_AVAILABLE and SAM2_AVAILABLE and checkpoint_exists) available = bool(TORCH_AVAILABLE and SAM2_AVAILABLE and checkpoint_exists)
if self._predictor is not None: if predictor is not None:
message = "SAM 2 model loaded and ready." message = f"{variant.label} model loaded and ready."
elif available: elif available:
message = "SAM 2 dependencies and checkpoint are present; model will load on first inference." message = f"{variant.label} dependencies and checkpoint are present; model will load on first inference."
if using_legacy_checkpoint:
message += " Using legacy SAM 2 checkpoint fallback."
else: else:
missing = [] missing = []
if not TORCH_AVAILABLE: if not TORCH_AVAILABLE:
@@ -155,20 +276,21 @@ class SAM2Engine:
missing.append("sam2 package") missing.append("sam2 package")
if not checkpoint_exists: if not checkpoint_exists:
missing.append("checkpoint") missing.append("checkpoint")
message = f"SAM 2 unavailable: missing {', '.join(missing)}." message = f"{variant.label} unavailable: missing {', '.join(missing)}."
if self._last_error and not self._predictor: last_error = self._last_error.get(variant_id)
message = self._last_error if last_error and not predictor:
message = last_error
return { return {
"id": "sam2", "id": variant.id,
"label": "SAM 2", "label": variant.label,
"available": available, "available": available,
"loaded": self._predictor is not None, "loaded": predictor is not None,
"device": device, "device": device,
"supports": ["point", "box", "interactive", "auto", "propagate"], "supports": ["point", "box", "interactive", "auto", "propagate"],
"message": message, "message": message,
"package_available": SAM2_AVAILABLE, "package_available": SAM2_AVAILABLE,
"checkpoint_exists": checkpoint_exists, "checkpoint_exists": checkpoint_exists,
"checkpoint_path": settings.sam_model_path, "checkpoint_path": checkpoint_path,
"python_ok": True, "python_ok": True,
"torch_ok": TORCH_AVAILABLE, "torch_ok": TORCH_AVAILABLE,
"cuda_required": False, "cuda_required": False,
@@ -179,6 +301,7 @@ class SAM2Engine:
# ----------------------------------------------------------------------- # -----------------------------------------------------------------------
def predict_points( def predict_points(
self, self,
model_id: str | None,
image: np.ndarray, image: np.ndarray,
points: list[list[float]], points: list[list[float]],
labels: list[int], labels: list[int],
@@ -193,18 +316,20 @@ class SAM2Engine:
Returns: Returns:
Tuple of (polygons, scores). Tuple of (polygons, scores).
""" """
if not self._ensure_ready(): variant_id = self.normalize_model_id(model_id)
if not self._ensure_ready(variant_id):
logger.warning("SAM2 not ready; returning dummy masks.") logger.warning("SAM2 not ready; returning dummy masks.")
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5] return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
try: try:
predictor = self._predictors[variant_id]
h, w = image.shape[:2] h, w = image.shape[:2]
pts = np.array([[p[0] * w, p[1] * h] for p in points], dtype=np.float32) pts = np.array([[p[0] * w, p[1] * h] for p in points], dtype=np.float32)
lbls = np.array(labels, dtype=np.int32) lbls = np.array(labels, dtype=np.int32)
with torch.inference_mode(): # type: ignore[name-defined] with torch.inference_mode(): # type: ignore[name-defined]
self._predictor.set_image(image) predictor.set_image(image)
masks, scores, _ = self._predictor.predict( masks, scores, _ = predictor.predict(
point_coords=pts, point_coords=pts,
point_labels=lbls, point_labels=lbls,
multimask_output=False, multimask_output=False,
@@ -223,6 +348,7 @@ class SAM2Engine:
def predict_box( def predict_box(
self, self,
model_id: str | None,
image: np.ndarray, image: np.ndarray,
box: list[float], box: list[float],
) -> tuple[list[list[list[float]]], list[float]]: ) -> tuple[list[list[list[float]]], list[float]]:
@@ -235,11 +361,13 @@ class SAM2Engine:
Returns: Returns:
Tuple of (polygons, scores). Tuple of (polygons, scores).
""" """
if not self._ensure_ready(): variant_id = self.normalize_model_id(model_id)
if not self._ensure_ready(variant_id):
logger.warning("SAM2 not ready; returning dummy masks.") logger.warning("SAM2 not ready; returning dummy masks.")
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5] return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
try: try:
predictor = self._predictors[variant_id]
h, w = image.shape[:2] h, w = image.shape[:2]
bbox = np.array( bbox = np.array(
[box[0] * w, box[1] * h, box[2] * w, box[3] * h], [box[0] * w, box[1] * h, box[2] * w, box[3] * h],
@@ -247,8 +375,8 @@ class SAM2Engine:
) )
with torch.inference_mode(): # type: ignore[name-defined] with torch.inference_mode(): # type: ignore[name-defined]
self._predictor.set_image(image) predictor.set_image(image)
masks, scores, _ = self._predictor.predict( masks, scores, _ = predictor.predict(
box=bbox[None, :], box=bbox[None, :],
multimask_output=False, multimask_output=False,
) )
@@ -266,17 +394,20 @@ class SAM2Engine:
def predict_interactive( def predict_interactive(
self, self,
model_id: str | None,
image: np.ndarray, image: np.ndarray,
box: list[float] | None, box: list[float] | None,
points: list[list[float]], points: list[list[float]],
labels: list[int], labels: list[int],
) -> tuple[list[list[list[float]]], list[float]]: ) -> tuple[list[list[list[float]]], list[float]]:
"""Run combined box and point prompt segmentation for refinement.""" """Run combined box and point prompt segmentation for refinement."""
if not self._ensure_ready(): variant_id = self.normalize_model_id(model_id)
if not self._ensure_ready(variant_id):
logger.warning("SAM2 not ready; returning dummy masks.") logger.warning("SAM2 not ready; returning dummy masks.")
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5] return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
try: try:
predictor = self._predictors[variant_id]
h, w = image.shape[:2] h, w = image.shape[:2]
bbox = None bbox = None
if box: if box:
@@ -291,8 +422,8 @@ class SAM2Engine:
lbls = np.array(labels, dtype=np.int32) lbls = np.array(labels, dtype=np.int32)
with torch.inference_mode(): # type: ignore[name-defined] with torch.inference_mode(): # type: ignore[name-defined]
self._predictor.set_image(image) predictor.set_image(image)
masks, scores, _ = self._predictor.predict( masks, scores, _ = predictor.predict(
point_coords=pts, point_coords=pts,
point_labels=lbls, point_labels=lbls,
box=bbox, box=bbox,
@@ -310,7 +441,7 @@ class SAM2Engine:
logger.error("SAM2 interactive prediction failed: %s", exc) logger.error("SAM2 interactive prediction failed: %s", exc)
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5] return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
def predict_auto(self, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]: def predict_auto(self, model_id: str | None, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
"""Run automatic mask generation (grid of points). """Run automatic mask generation (grid of points).
Args: Args:
@@ -319,20 +450,22 @@ class SAM2Engine:
Returns: Returns:
Tuple of (polygons, scores). Tuple of (polygons, scores).
""" """
if not self._ensure_ready(): variant_id = self.normalize_model_id(model_id)
if not self._ensure_ready(variant_id):
logger.warning("SAM2 not ready; returning dummy masks.") logger.warning("SAM2 not ready; returning dummy masks.")
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5] return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
try: try:
predictor = self._predictors[variant_id]
with torch.inference_mode(): # type: ignore[name-defined] with torch.inference_mode(): # type: ignore[name-defined]
self._predictor.set_image(image) predictor.set_image(image)
# Generate a uniform 16x16 grid of point prompts # Generate a uniform 16x16 grid of point prompts
h, w = image.shape[:2] h, w = image.shape[:2]
grid = np.mgrid[0:1:17j, 0:1:17j].reshape(2, -1).T grid = np.mgrid[0:1:17j, 0:1:17j].reshape(2, -1).T
pts = grid * np.array([w, h]) pts = grid * np.array([w, h])
lbls = np.ones(pts.shape[0], dtype=np.int32) lbls = np.ones(pts.shape[0], dtype=np.int32)
masks, scores, _ = self._predictor.predict( masks, scores, _ = predictor.predict(
point_coords=pts, point_coords=pts,
point_labels=lbls, point_labels=lbls,
multimask_output=False, multimask_output=False,
@@ -351,6 +484,7 @@ class SAM2Engine:
def propagate_video( def propagate_video(
self, self,
model_id: str | None,
frame_paths: list[str], frame_paths: list[str],
source_frame_index: int, source_frame_index: int,
seed: dict, seed: dict,
@@ -358,8 +492,10 @@ class SAM2Engine:
max_frames: int | None = None, max_frames: int | None = None,
) -> list[dict]: ) -> list[dict]:
"""Propagate one seed mask across a prepared frame directory with SAM 2 video.""" """Propagate one seed mask across a prepared frame directory with SAM 2 video."""
if not self._ensure_video_ready(): variant_id = self.normalize_model_id(model_id)
raise RuntimeError(self._video_last_error or self.status()["message"]) if not self._ensure_video_ready(variant_id):
raise RuntimeError(self._video_last_error.get(variant_id) or self.status(variant_id)["message"])
video_predictor = self._video_predictors[variant_id]
if not frame_paths: if not frame_paths:
return [] return []
if source_frame_index < 0 or source_frame_index >= len(frame_paths): if source_frame_index < 0 or source_frame_index >= len(frame_paths):
@@ -379,12 +515,12 @@ class SAM2Engine:
if not seed_mask.any(): if not seed_mask.any():
raise ValueError("SAM 2 propagation requires a non-empty seed polygon or bbox.") raise ValueError("SAM 2 propagation requires a non-empty seed polygon or bbox.")
inference_state = self._video_predictor.init_state( inference_state = video_predictor.init_state(
video_path=os.path.dirname(frame_paths[0]), video_path=os.path.dirname(frame_paths[0]),
offload_video_to_cpu=True, offload_video_to_cpu=True,
offload_state_to_cpu=True, offload_state_to_cpu=True,
) )
self._video_predictor.add_new_mask( video_predictor.add_new_mask(
inference_state, inference_state,
frame_idx=source_frame_index, frame_idx=source_frame_index,
obj_id=1, obj_id=1,
@@ -394,7 +530,7 @@ class SAM2Engine:
results: dict[int, dict] = {} results: dict[int, dict] = {}
def collect(reverse: bool) -> None: def collect(reverse: bool) -> None:
for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video( for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(
inference_state, inference_state,
start_frame_idx=source_frame_index, start_frame_idx=source_frame_index,
max_frame_num_to_track=max_frames, max_frame_num_to_track=max_frames,
@@ -427,7 +563,7 @@ class SAM2Engine:
collect(reverse=True) collect(reverse=True)
try: try:
self._video_predictor.reset_state(inference_state) video_predictor.reset_state(inference_state)
except Exception: # noqa: BLE001 except Exception: # noqa: BLE001
pass pass
return [results[index] for index in sorted(results)] return [results[index] for index in sorted(results)]

View File

@@ -5,8 +5,12 @@ from __future__ import annotations
from typing import Any from typing import Any
from config import settings from config import settings
from services.sam2_engine import TORCH_AVAILABLE, sam_engine as sam2_engine from services.sam2_engine import DEFAULT_SAM2_MODEL_ID, TORCH_AVAILABLE, sam_engine as sam2_engine
from services.sam3_engine import sam3_engine
# SAM 3 integration is intentionally disabled for the current product flow.
# The source files are kept in the repository so the integration can be
# restored later, but the active registry only exposes SAM 2.
# from services.sam3_engine import sam3_engine
try: try:
import torch import torch
@@ -24,20 +28,23 @@ class SAMRegistry:
def __init__(self) -> None: def __init__(self) -> None:
self._engines = { self._engines = {
"sam2": sam2_engine, "sam2": sam2_engine,
"sam3": sam3_engine, # "sam3": sam3_engine,
} }
def normalize_model_id(self, model_id: str | None) -> str: def normalize_model_id(self, model_id: str | None) -> str:
selected = (model_id or settings.sam_default_model or "sam2").lower() selected = (model_id or settings.sam_default_model or DEFAULT_SAM2_MODEL_ID).lower()
if self._engines["sam2"].is_sam2_model(selected):
return self._engines["sam2"].normalize_model_id(selected)
if selected not in self._engines: if selected not in self._engines:
raise ValueError(f"Unsupported model: {model_id}") raise ValueError(f"Unsupported model: {model_id}")
return selected return selected
def runtime_status(self, selected_model: str | None = None) -> dict[str, Any]: def runtime_status(self, selected_model: str | None = None) -> dict[str, Any]:
selected = self.normalize_model_id(selected_model)
return { return {
"selected_model": self.normalize_model_id(selected_model), "selected_model": selected,
"gpu": self.gpu_status(), "gpu": self.gpu_status(),
"models": [engine.status() for engine in self._engines.values()], "models": [sam2_engine.status(model_id) for model_id in sam2_engine.variant_ids()],
} }
def gpu_status(self) -> dict[str, Any]: def gpu_status(self) -> dict[str, Any]:
@@ -52,20 +59,26 @@ class SAMRegistry:
} }
def _engine(self, model_id: str | None) -> Any: def _engine(self, model_id: str | None) -> Any:
return self._engines[self.normalize_model_id(model_id)] normalized = self.normalize_model_id(model_id)
if self._engines["sam2"].is_sam2_model(normalized):
return self._engines["sam2"]
return self._engines[normalized]
def _ensure_available(self, model_id: str | None) -> Any: def _ensure_available(self, model_id: str | None) -> Any:
normalized = self.normalize_model_id(model_id)
engine = self._engine(model_id) engine = self._engine(model_id)
status = engine.status() status = engine.status(normalized) if engine is sam2_engine else engine.status()
if not status["available"]: if not status["available"]:
raise ModelUnavailableError(status["message"]) raise ModelUnavailableError(status["message"])
return engine return engine
def predict_points(self, model_id: str | None, image: Any, points: list[list[float]], labels: list[int]): def predict_points(self, model_id: str | None, image: Any, points: list[list[float]], labels: list[int]):
return self._ensure_available(model_id).predict_points(image, points, labels) model = self.normalize_model_id(model_id)
return self._ensure_available(model).predict_points(model, image, points, labels)
def predict_box(self, model_id: str | None, image: Any, box: list[float]): def predict_box(self, model_id: str | None, image: Any, box: list[float]):
return self._ensure_available(model_id).predict_box(image, box) model = self.normalize_model_id(model_id)
return self._ensure_available(model).predict_box(model, image, box)
def predict_interactive( def predict_interactive(
self, self,
@@ -76,12 +89,13 @@ class SAMRegistry:
labels: list[int], labels: list[int],
): ):
model = self.normalize_model_id(model_id) model = self.normalize_model_id(model_id)
if model != "sam2": if not sam2_engine.is_sam2_model(model):
raise NotImplementedError("Interactive box + point refinement is currently supported by SAM 2.") raise NotImplementedError("Interactive box + point refinement is currently supported by SAM 2.")
return self._ensure_available(model).predict_interactive(image, box, points, labels) return self._ensure_available(model).predict_interactive(model, image, box, points, labels)
def predict_auto(self, model_id: str | None, image: Any): def predict_auto(self, model_id: str | None, image: Any):
return self._ensure_available(model_id).predict_auto(image) model = self.normalize_model_id(model_id)
return self._ensure_available(model).predict_auto(model, image)
def predict_semantic( def predict_semantic(
self, self,
@@ -90,14 +104,8 @@ class SAMRegistry:
text: str, text: str,
confidence_threshold: float | None = None, confidence_threshold: float | None = None,
): ):
model = self.normalize_model_id(model_id) self.normalize_model_id(model_id)
if model == "sam3": raise NotImplementedError("Semantic text prompting is disabled; use SAM 2 point or box prompts.")
return self._ensure_available(model).predict_semantic(
image,
text,
confidence_threshold=confidence_threshold,
)
return self._ensure_available(model).predict_auto(image)
def propagate_video( def propagate_video(
self, self,
@@ -108,7 +116,9 @@ class SAMRegistry:
direction: str, direction: str,
max_frames: int | None, max_frames: int | None,
): ):
return self._ensure_available(model_id).propagate_video( model = self.normalize_model_id(model_id)
return self._ensure_available(model).propagate_video(
model,
frame_paths, frame_paths,
source_frame_index, source_frame_index,
seed, seed,

View File

@@ -87,28 +87,14 @@ def test_predict_applies_crop_and_background_filter_options(client, monkeypatch)
assert all(0.0 <= coord <= 1.0 for point in polygon for coord in point) assert all(0.0 <= coord <= 1.0 for point in polygon for coord in point)
def test_predict_box_and_semantic_fallback(client, monkeypatch): def test_predict_box_and_rejects_semantic_prompt(client, monkeypatch):
_, frame, _ = _create_project_and_frame(client) _, frame, _ = _create_project_and_frame(client)
calls = {}
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8)) monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
monkeypatch.setattr("routers.ai.sam_registry.predict_box", lambda model, image, box: ( monkeypatch.setattr("routers.ai.sam_registry.predict_box", lambda model, image, box: (
[[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]], [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
[0.8], [0.8],
)) ))
def fake_predict_semantic(model, image, text, confidence_threshold=None):
calls["semantic"] = {
"model": model,
"text": text,
"confidence_threshold": confidence_threshold,
}
return (
[[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]],
[0.5],
)
monkeypatch.setattr("routers.ai.sam_registry.predict_semantic", fake_predict_semantic)
box_response = client.post("/api/ai/predict", json={ box_response = client.post("/api/ai/predict", json={
"image_id": frame["id"], "image_id": frame["id"],
"prompt_type": "box", "prompt_type": "box",
@@ -124,13 +110,8 @@ def test_predict_box_and_semantic_fallback(client, monkeypatch):
assert box_response.status_code == 200 assert box_response.status_code == 200
assert box_response.json()["scores"] == [0.8] assert box_response.json()["scores"] == [0.8]
assert semantic_response.status_code == 200 assert semantic_response.status_code == 400
assert semantic_response.json()["scores"] == [0.5] assert "Unsupported model: sam3" in semantic_response.json()["detail"]
assert calls["semantic"] == {
"model": "sam3",
"text": "胆囊",
"confidence_threshold": 0.05,
}
def test_predict_interactive_combines_box_and_points(client, monkeypatch): def test_predict_interactive_combines_box_and_points(client, monkeypatch):
@@ -158,13 +139,13 @@ def test_predict_interactive_combines_box_and_points(client, monkeypatch):
"points": [[0.5, 0.5], [0.2, 0.2]], "points": [[0.5, 0.5], [0.2, 0.2]],
"labels": [1, 0], "labels": [1, 0],
}, },
"model": "sam2", "model": "sam2.1_hiera_small",
}) })
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["scores"] == [0.88] assert response.json()["scores"] == [0.88]
assert calls == { assert calls == {
"model": "sam2", "model": "sam2.1_hiera_small",
"box": [0.1, 0.1, 0.9, 0.9], "box": [0.1, 0.1, 0.9, 0.9],
"points": [[0.5, 0.5], [0.2, 0.2]], "points": [[0.5, 0.5], [0.2, 0.2]],
"labels": [1, 0], "labels": [1, 0],
@@ -173,7 +154,7 @@ def test_predict_interactive_combines_box_and_points(client, monkeypatch):
def test_model_status_reports_runtime(client, monkeypatch): def test_model_status_reports_runtime(client, monkeypatch):
monkeypatch.setattr("routers.ai.sam_registry.runtime_status", lambda selected_model=None: { monkeypatch.setattr("routers.ai.sam_registry.runtime_status", lambda selected_model=None: {
"selected_model": selected_model or "sam2", "selected_model": "sam2.1_hiera_tiny",
"gpu": { "gpu": {
"available": False, "available": False,
"device": "cpu", "device": "cpu",
@@ -184,8 +165,8 @@ def test_model_status_reports_runtime(client, monkeypatch):
}, },
"models": [ "models": [
{ {
"id": "sam2", "id": "sam2.1_hiera_tiny",
"label": "SAM 2", "label": "SAM 2.1 Tiny",
"available": True, "available": True,
"loaded": False, "loaded": False,
"device": "cpu", "device": "cpu",
@@ -198,31 +179,23 @@ def test_model_status_reports_runtime(client, monkeypatch):
"torch_ok": True, "torch_ok": True,
"cuda_required": False, "cuda_required": False,
}, },
{
"id": "sam3",
"label": "SAM 3",
"available": False,
"loaded": False,
"device": "unavailable",
"supports": ["semantic"],
"message": "missing Python 3.12+ runtime",
"package_available": False,
"checkpoint_exists": False,
"checkpoint_path": None,
"python_ok": False,
"torch_ok": True,
"cuda_required": True,
},
], ],
}) })
response = client.get("/api/ai/models/status?selected_model=sam3") response = client.get("/api/ai/models/status")
assert response.status_code == 200 assert response.status_code == 200
body = response.json() body = response.json()
assert body["selected_model"] == "sam3" assert body["selected_model"] == "sam2.1_hiera_tiny"
assert body["models"][1]["id"] == "sam3" assert len(body["models"]) == 1
assert body["models"][1]["available"] is False assert body["models"][0]["id"] == "sam2.1_hiera_tiny"
def test_model_status_rejects_disabled_sam3(client):
response = client.get("/api/ai/models/status?selected_model=sam3")
assert response.status_code == 400
assert "Unsupported model" in response.json()["detail"]
def test_propagate_saves_tracked_annotations(client, monkeypatch): def test_propagate_saves_tracked_annotations(client, monkeypatch):
@@ -267,7 +240,7 @@ def test_propagate_saves_tracked_annotations(client, monkeypatch):
response = client.post("/api/ai/propagate", json={ response = client.post("/api/ai/propagate", json={
"project_id": project["id"], "project_id": project["id"],
"frame_id": frames[0]["id"], "frame_id": frames[0]["id"],
"model": "sam2", "model": "sam2.1_hiera_tiny",
"direction": "forward", "direction": "forward",
"max_frames": 2, "max_frames": 2,
"include_source": False, "include_source": False,
@@ -285,13 +258,13 @@ def test_propagate_saves_tracked_annotations(client, monkeypatch):
body = response.json() body = response.json()
assert body["created_annotation_count"] == 1 assert body["created_annotation_count"] == 1
assert body["processed_frame_count"] == 2 assert body["processed_frame_count"] == 2
assert calls["model"] == "sam2" assert calls["model"] == "sam2.1_hiera_tiny"
assert calls["source_frame_index"] == 0 assert calls["source_frame_index"] == 0
assert calls["direction"] == "forward" assert calls["direction"] == "forward"
assert calls["frame_count"] == 2 assert calls["frame_count"] == 2
saved = body["annotations"][0] saved = body["annotations"][0]
assert saved["frame_id"] == frames[1]["id"] assert saved["frame_id"] == frames[1]["id"]
assert saved["mask_data"]["source"] == "sam2_propagation" assert saved["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation"
assert saved["mask_data"]["class"]["name"] == "胆囊" assert saved["mask_data"]["class"]["name"] == "胆囊"
assert saved["mask_data"]["score"] == 0.8 assert saved["mask_data"]["score"] == 0.8

View File

@@ -69,3 +69,44 @@ def test_dashboard_overview_uses_persisted_records(client, db_session):
assert any(item["kind"] == "annotation" for item in body["activity"]) assert any(item["kind"] == "annotation" for item in body["activity"])
assert any(item["kind"] == "template" for item in body["activity"]) assert any(item["kind"] == "template" for item in body["activity"])
assert all(item["name"] != "Ready Project" for item in body["tasks"]) assert all(item["name"] != "Ready Project" for item in body["tasks"])
def test_dashboard_overview_keeps_recent_success_tasks_in_progress_list(client, db_session):
from models import ProcessingTask
project = client.post("/api/projects", json={
"name": "Completed Project",
"status": "ready",
}).json()
task = ProcessingTask(
task_type="parse_video",
status="success",
progress=100,
message="解析完成",
project_id=project["id"],
payload={"source_type": "video"},
result={"frames_extracted": 120},
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
response = client.get("/api/dashboard/overview")
assert response.status_code == 200
body = response.json()
assert body["summary"]["parsing_task_count"] == 0
assert body["tasks"] == [
{
"id": f"task-{task.id}",
"task_id": task.id,
"project_id": project["id"],
"name": "Completed Project",
"progress": 100,
"status": "解析完成",
"raw_status": "success",
"frame_count": 120,
"error": None,
"updated_at": body["tasks"][0]["updated_at"],
},
]

View File

@@ -1,6 +1,6 @@
import numpy as np import numpy as np
from services.sam2_engine import SAM2Engine from services.sam2_engine import DEFAULT_SAM2_MODEL_ID, SAM2Engine
class _FakePredictor: class _FakePredictor:
@@ -26,8 +26,8 @@ def _mask(offset=0):
def _ready_engine(monkeypatch, predictor): def _ready_engine(monkeypatch, predictor):
monkeypatch.setattr("services.sam2_engine.SAM2_AVAILABLE", True) monkeypatch.setattr("services.sam2_engine.SAM2_AVAILABLE", True)
engine = SAM2Engine() engine = SAM2Engine()
engine._model_loaded = True engine._model_loaded[DEFAULT_SAM2_MODEL_ID] = True
engine._predictor = predictor engine._predictors[DEFAULT_SAM2_MODEL_ID] = predictor
return engine return engine
@@ -39,6 +39,7 @@ def test_sam2_point_prediction_requests_single_best_mask(monkeypatch):
engine = _ready_engine(monkeypatch, predictor) engine = _ready_engine(monkeypatch, predictor)
polygons, scores = engine.predict_points( polygons, scores = engine.predict_points(
DEFAULT_SAM2_MODEL_ID,
np.zeros((32, 32, 3), dtype=np.uint8), np.zeros((32, 32, 3), dtype=np.uint8),
[[0.5, 0.5]], [[0.5, 0.5]],
[1], [1],
@@ -56,8 +57,24 @@ def test_sam2_auto_prediction_keeps_single_best_mask(monkeypatch):
) )
engine = _ready_engine(monkeypatch, predictor) engine = _ready_engine(monkeypatch, predictor)
polygons, scores = engine.predict_auto(np.zeros((32, 32, 3), dtype=np.uint8)) polygons, scores = engine.predict_auto(DEFAULT_SAM2_MODEL_ID, np.zeros((32, 32, 3), dtype=np.uint8))
assert predictor.calls[0]["multimask_output"] is False assert predictor.calls[0]["multimask_output"] is False
assert len(polygons) == 1 assert len(polygons) == 1
assert scores == [0.800000011920929] assert scores == [0.800000011920929]
def test_sam2_status_exposes_selectable_variants(monkeypatch, tmp_path):
checkpoint = tmp_path / "sam2.1_hiera_small.pt"
checkpoint.write_bytes(b"model")
monkeypatch.setattr("services.sam2_engine.settings.sam_model_path", str(tmp_path / "sam2.1_hiera_tiny.pt"))
engine = SAM2Engine()
status = engine.status("sam2.1_hiera_small")
assert engine.normalize_model_id("sam2") == DEFAULT_SAM2_MODEL_ID
assert "sam2.1_hiera_small" in engine.variant_ids()
assert status["id"] == "sam2.1_hiera_small"
assert status["label"] == "SAM 2.1 Small"
assert status["checkpoint_exists"] is True
assert status["checkpoint_path"].endswith("sam2.1_hiera_small.pt")

View File

@@ -2,6 +2,12 @@ import json
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import pytest
pytest.skip(
"SAM 3 integration is disabled in the current SAM2-only product flow.",
allow_module_level=True,
)
from services.sam3_engine import SAM3Engine from services.sam3_engine import SAM3Engine
from services.sam3_external_worker import _prediction_to_response, _to_numpy from services.sam3_external_worker import _prediction_to_response, _to_numpy

View File

@@ -20,7 +20,7 @@ Word 方案描述的理想系统包含:
- FastAPI 后端,使用 WebSocket 处理实时交互与任务进度。 - FastAPI 后端,使用 WebSocket 处理实时交互与任务进度。
- Celery + Redis 处理视频拆帧等长任务。 - Celery + Redis 处理视频拆帧等长任务。
- FFmpeg/OpenCV 解析视频pydicom 解析医学影像。 - FFmpeg/OpenCV 解析视频pydicom 解析医学影像。
- 本地 CUDA 上的 SAM 3 推理。 - 本地 CUDA 上的 SAM 推理;当前产品实现启用可选 SAM 2.1 tiny/small/base+/largeSAM 3 因没有文本提示入口而暂时禁用
- GT mask 导入后通过距离变换、骨架提取、聚类等算法降维为点区域。 - GT mask 导入后通过距离变换、骨架提取、聚类等算法降维为点区域。
- 模板库管理分类、颜色和 z-index用于语义分割遮罩重叠裁决。 - 模板库管理分类、颜色和 z-index用于语义分割遮罩重叠裁决。
- PostgreSQL 存储项目、帧、模板和点区域数据。 - PostgreSQL 存储项目、帧、模板和点区域数据。
@@ -38,14 +38,14 @@ Word 方案描述的理想系统包含:
| 视频拆帧 | 已落地 | `backend/services/frame_parser.py``backend/routers/media.py` | | 视频拆帧 | 已落地 | `backend/services/frame_parser.py``backend/routers/media.py` |
| DICOM 批量导入 | 部分落地 | 上传和解析存在,项目级体验还需完善 | | DICOM 批量导入 | 部分落地 | 上传和解析存在,项目级体验还需完善 |
| WebSocket 进度 | 已落地 | 拆帧进度写入任务表后发布到 Redis `seg:progress`FastAPI 广播到 `/ws/progress` | | WebSocket 进度 | 已落地 | 拆帧进度写入任务表后发布到 Redis `seg:progress`FastAPI 广播到 `/ws/progress` |
| SAM 推理 | 部分落地 | 后端已有 SAM 2 / SAM 3 选择和真实模型状态接口SAM 2 已接 video predictor 片段传播SAM 3 通过独立 Python 3.12 环境桥接,支持文本/框提示和 official video tracker 入口,状态会检查 Python/CUDA/包/本地 checkpoint | | SAM 推理 | 部分落地 | 当前产品入口启用 SAM 2.1 tiny/small/base+/large 和真实 GPU/SAM2.1 状态接口SAM 2.1 已接 point/box/interactive 和 video predictor 片段传播SAM 3 桥接源码保留,但前端入口和后端 registry 已禁用 |
| 模板库 | 部分落地 | 分类、颜色、z-index 能存储和编辑PNG mask 导出时会按 zIndex 做语义融合裁决,前端预览裁决尚未落地 | | 模板库 | 部分落地 | 分类、颜色、z-index 能存储和编辑PNG mask 导出时会按 zIndex 做语义融合裁决,前端预览裁决尚未落地 |
| 标注持久化 | 部分落地 | 后端有 `Annotation`前端已接入新增、回显、分类更新、当前帧删除、手工绘制、GT mask 导入、seed point 编辑、polygon 顶点拖动/删除、边中点插点和多 polygon 子区域编辑;复杂洞结构编辑未落地 | | 标注持久化 | 部分落地 | 后端有 `Annotation`前端已接入新增、回显、分类更新、当前帧删除、手工绘制、GT mask 导入、seed point 编辑、polygon 顶点拖动/删除、边中点插点和多 polygon 子区域编辑;复杂洞结构编辑未落地 |
| COCO / Mask 导出 | 已落地基础能力 | `backend/routers/export.py`COCO JSON 和 PNG mask ZIP 前端按钮均已接入ZIP 包含单标注 mask、语义融合 mask 和类别映射 | | COCO / Mask 导出 | 已落地基础能力 | `backend/routers/export.py`COCO JSON 和 PNG mask ZIP 前端按钮均已接入ZIP 包含单标注 mask、语义融合 mask 和类别映射 |
## 当前代码尚未落地的目标 ## 当前代码尚未落地的目标
- SAM 3当前已提供 `sam3_engine.py` 外部环境桥接`sam3_external_worker.py``setup_sam3_env.sh`;本机 `sam3` 环境已满足 Python 3.12、PyTorch 2.10/cu128、CUDA/GPU、官方包导入和本地 `sam3权重/sam3.pt` checkpoint 状态检查。官方没有 SAM 3 tiny/small 权重,当前可选最小真实 SAM 权重仍是 SAM 2 tinyvideo tracker 入口已接入,真实效果取决于本地 checkpoint 是否兼容 video model - SAM 3`sam3_engine.py``sam3_external_worker.py``setup_sam3_env.sh` 作为历史实现保留;由于当前系统不给文本提示,前端不再展示 SAM 3后端 registry 也不暴露 `sam3`。官方没有 SAM 3 tiny/small 权重,当前可选最小真实 SAM 权重仍是 SAM 2.1 tiny。
- Celery 异步任务队列:已注册 Celery app 和拆帧 worker task`/api/media/parse` 会创建任务表记录并入队。 - Celery 异步任务队列:已注册 Celery app 和拆帧 worker task`/api/media/parse` 会创建任务表记录并入队。
- GT mask 导入:当前已支持二值/多类别 mask 导入,后端会按非零像素值拆分区域,生成 polygon 标注和距离变换 seed point骨架提取、HDBSCAN 和模板自动映射尚未实现。 - GT mask 导入:当前已支持二值/多类别 mask 导入,后端会按非零像素值拆分区域,生成 polygon 标注和距离变换 seed point骨架提取、HDBSCAN 和模板自动映射尚未实现。
- Mask 到点区域的拓扑降维:当前完成 distance transform seed point 和前端 seed point 拖拽编辑骨架提取、HDBSCAN 等增强尚未实现。 - Mask 到点区域的拓扑降维:当前完成 distance transform seed point 和前端 seed point 拖拽编辑骨架提取、HDBSCAN 等增强尚未实现。
@@ -55,4 +55,4 @@ Word 方案描述的理想系统包含:
## 结论 ## 结论
当前项目已经从 UI 原型推进到“可上传、可异步拆帧、可取消/重试任务、可查看失败详情、可实时查看任务进度、可浏览项目帧、可维护模板、可手工绘制、可逐点编辑 polygon、可边中点插点、可多 polygon 子区域编辑、可区域合并/去除、可点/框 AI 推理、可对点/框 prompt 做裁剪推理和背景过滤、可用 SAM 2 / SAM 3 进行视频片段传播、可导入多类别 GT mask、可编辑 seed point、可保存标注、可导出 COCO/语义 mask ZIP、可查看 Dashboard 后端概览”的全栈雏形。下一阶段最重要的是继续补齐 SAM 3 真实视频 tracker smoke test、复杂洞结构编辑GT mask 骨架/聚类增强。 当前项目已经从 UI 原型推进到“可上传、可异步拆帧、可取消/重试任务、可查看失败详情、可实时查看任务进度、可浏览项目帧、可维护模板、可手工绘制、可逐点编辑 polygon、可边中点插点、可多 polygon 子区域编辑、可区域合并/去除、可用可选 SAM 2.1 做点/框 AI 推理、可对点/框 prompt 做裁剪推理和背景过滤、可用 SAM 2.1 进行视频片段传播、可导入多类别 GT mask、可编辑 seed point、可保存标注、可导出 COCO/语义 mask ZIP、可查看 Dashboard 后端概览”的全栈雏形。下一阶段最重要的是继续补齐复杂洞结构编辑GT mask 骨架/聚类增强和传播任务异步化

View File

@@ -99,7 +99,7 @@
## 当前主要风险点 ## 当前主要风险点
- 前端 API/WS 地址虽然已支持环境变量和 hostname 推导,但部署时仍需要确认浏览器可访问 `:8000` 后端。 - 前端 API/WS 地址虽然已支持环境变量和 hostname 推导,但部署时仍需要确认浏览器可访问 `:8000` 后端。
- AI 语义文本提示在选择 SAM 3 且运行环境满足官方依赖、并具备 Hugging Face gated 权重访问时走 SAM 3当前状态接口会分别暴露外部 Python 环境、CUDA、包导入和 checkpoint access 是否满足 - AI 当前启用 SAM 2.1 tiny/small/base+/large 点/框/interactive 路径;语义文本提示和 SAM 3 产品入口已禁用,`model=sam3` 会被后端拒绝。SAM 3 源码保留但不计入当前可用功能
- 工作区顶部“导出 JSON 标注集”“导出 PNG Mask ZIP”“导入 GT Mask”和“结构化归档保存”已接入导出、GT 多类别导入、seed point 回显/编辑、标注新增和 dirty 标注更新清空当前帧遮罩会删除对应后端标注。手工绘制、polygon 顶点拖动/删除、区域合并/去除和撤销重做已经落到前端 mask 数据结构。 - 工作区顶部“导出 JSON 标注集”“导出 PNG Mask ZIP”“导入 GT Mask”和“结构化归档保存”已接入导出、GT 多类别导入、seed point 回显/编辑、标注新增和 dirty 标注更新清空当前帧遮罩会删除对应后端标注。手工绘制、polygon 顶点拖动/删除、区域合并/去除和撤销重做已经落到前端 mask 数据结构。
- Dashboard 初始统计、队列和活动日志来自后端聚合接口;解析队列来自 `processing_tasks`worker 进度通过 Redis `seg:progress` 转发到 WebSocket。任务取消、重试和失败详情已接入前后端。 - Dashboard 初始统计、队列和活动日志来自后端聚合接口;解析队列来自 `processing_tasks`worker 进度通过 Redis `seg:progress` 转发到 WebSocket。任务取消、重试和失败详情已接入前后端。
- 后端路由大多未做真实鉴权。 - 后端路由大多未做真实鉴权。

View File

@@ -30,7 +30,7 @@
| 元素 | 状态 | 说明 | | 元素 | 状态 | 说明 |
|------|------|------| |------|------|------|
| WebSocket 连接状态 | 真实可用 | 前端通过 `src/lib/config.ts` 推导或读取 `VITE_WS_PROGRESS_URL`,后端有 `/ws/progress` | | WebSocket 连接状态 | 真实可用 | 前端通过 `src/lib/config.ts` 推导或读取 `VITE_WS_PROGRESS_URL`,后端有 `/ws/progress` |
| 解析队列任务 | 真实可用 | 初始数据来自 `GET /api/dashboard/overview`,按 `processing_tasks` queued/running/failed/cancelled 任务生成 | | 任务进度 | 真实可用 | 初始数据来自 `GET /api/dashboard/overview`,按 `processing_tasks` queued/running/success/failed/cancelled 任务生成;统计卡片中的处理中任务数只计算 queued/running |
| 任务取消 | 真实可用 | queued/running 任务显示取消按钮,调用 `POST /api/tasks/{task_id}/cancel` | | 任务取消 | 真实可用 | queued/running 任务显示取消按钮,调用 `POST /api/tasks/{task_id}/cancel` |
| 任务重试 | 真实可用 | failed/cancelled 任务显示重试按钮,调用 `POST /api/tasks/{task_id}/retry` 创建新任务 | | 任务重试 | 真实可用 | failed/cancelled 任务显示重试按钮,调用 `POST /api/tasks/{task_id}/retry` 创建新任务 |
| 失败详情 | 真实可用 | 任务详情按钮调用 `GET /api/tasks/{task_id}`,弹窗展示 error、payload、result、Celery ID 和时间 | | 失败详情 | 真实可用 | 任务详情按钮调用 `GET /api/tasks/{task_id}`,弹窗展示 error、payload、result、Celery ID 和时间 |
@@ -61,13 +61,13 @@
| 当前项目名 | 真实可用 | 读取 `currentProject.name` | | 当前项目名 | 真实可用 | 读取 `currentProject.name` |
| 自动加载项目帧 | 真实可用 | 调用 `GET /api/projects/{id}/frames` | | 自动加载项目帧 | 真实可用 | 调用 `GET /api/projects/{id}/frames` |
| 无帧项目提示 | 真实可用 | 如果 `video_path` 存在但无帧,只提示回到项目库生成帧,不自动创建拆帧任务 | | 无帧项目提示 | 真实可用 | 如果 `video_path` 存在但无帧,只提示回到项目库生成帧,不自动创建拆帧任务 |
| SAM 模型状态徽标 | 真实可用 | 调用 `GET /api/ai/models/status`,显示当前选择的 SAM 2/SAM 3 是否可用 | | SAM 模型状态徽标 | 真实可用 | 调用 `GET /api/ai/models/status`,显示当前启用的 SAM 2 与 GPU 状态 |
| 已保存标注回显 | 真实可用 | 加载工作区帧后调用 `GET /api/ai/annotations` 并渲染已保存 mask | | 已保存标注回显 | 真实可用 | 加载工作区帧后调用 `GET /api/ai/annotations` 并渲染已保存 mask;回显时保留当前项目帧里尚未保存的 AI/手工 draft mask避免从 AI 页推送的候选被覆盖 |
| “导出 JSON 标注集”按钮 | 真实可用 | 导出前会保存未归档 mask然后调用 `exportCoco()` 下载 JSON | | “导出 JSON 标注集”按钮 | 真实可用 | 导出前会保存未归档 mask然后调用 `exportCoco()` 下载 JSON |
| “导出 PNG Mask ZIP”按钮 | 真实可用 | 导出前会保存未归档 mask然后调用 `GET /api/export/{project_id}/masks` 下载 ZIP后端同时包含单标注 mask、每帧语义融合 mask 和 `semantic_classes.json` | | “导出 PNG Mask ZIP”按钮 | 真实可用 | 导出前会保存未归档 mask然后调用 `GET /api/export/{project_id}/masks` 下载 ZIP后端同时包含单标注 mask、每帧语义融合 mask 和 `semantic_classes.json` |
| “导入 GT Mask”按钮 | 真实可用 | 选择图片后调用 `POST /api/ai/import-gt-mask`,后端按非零像素值和连通域生成 polygon 标注与距离变换 seed point再回显到工作区 | | “导入 GT Mask”按钮 | 真实可用 | 选择图片后调用 `POST /api/ai/import-gt-mask`,后端按非零像素值和连通域生成 polygon 标注与距离变换 seed point再回显到工作区 |
| “传播片段”按钮 | 真实可用 | 使用当前选中 mask 或当前帧第一个 mask 作为 seed调用 `POST /api/ai/propagate`SAM 2 video predictorSAM 3 走 external video tracker完成后刷新已保存标注 | | “传播片段”按钮 | 真实可用 | 使用当前选中 mask 或当前帧第一个 mask 作为 seed调用 `POST /api/ai/propagate`当前启用 SAM 2 video predictor完成后刷新已保存标注 |
| “结构化归档保存”按钮 | 真实可用 | 未保存 mask 写入 `POST /api/ai/annotate`dirty mask 写入 `PATCH /api/ai/annotations/{id}` | | “结构化归档保存”按钮 | 真实可用 | 未保存 mask 写入 `POST /api/ai/annotate`dirty mask 写入 `PATCH /api/ai/annotations/{id}`;保存成功后会重新拉取后端标注,并用 saved annotation 替换本次提交的 draft mask避免仍显示未保存 |
## CanvasArea 画布 ## CanvasArea 画布
@@ -109,6 +109,7 @@
| 点击缩略图跳帧 | 真实可用 | 调用 `setCurrentFrame(idx)` | | 点击缩略图跳帧 | 真实可用 | 调用 `setCurrentFrame(idx)` |
| 顶部 range 拖动 | 真实可用 | 改变当前帧 | | 顶部 range 拖动 | 真实可用 | 改变当前帧 |
| 具体时间显示 | 真实可用 | 根据项目 `parse_fps/original_fps` 显示当前时间和总时长,格式为 `mm:ss.cc` | | 具体时间显示 | 真实可用 | 根据项目 `parse_fps/original_fps` 显示当前时间和总时长,格式为 `mm:ss.cc` |
| 已编辑帧标记带 | 真实可用 | 根据当前项目帧内的 `masks` 计算有编辑/标注的帧,在顶部进度条和缩略图导航轴之间显示标记;点击标记可跳转到对应帧 |
| 播放/暂停 | 真实可用 | 当前代码按 `parse_fps/original_fps` 推进帧,最多 30fps | | 播放/暂停 | 真实可用 | 当前代码按 `parse_fps/original_fps` 推进帧,最多 30fps |
| 方向键切帧 | 真实可用 | 全局监听左右方向键切到上一帧/下一帧;焦点在 input、textarea、select 或 contentEditable 内时不会拦截 | | 方向键切帧 | 真实可用 | 全局监听左右方向键切到上一帧/下一帧;焦点在 input、textarea、select 或 contentEditable 内时不会拦截 |
@@ -127,16 +128,17 @@
| 元素 | 状态 | 说明 | | 元素 | 状态 | 说明 |
|------|------|------| |------|------|------|
| 模型选择 SAM2/SAM3 | 真实可用 | 选择写入 Zustand`predictMask()` 会把 `model` 传给后端 SAM registry | | SAM 2.1 变体选择 / 模型状态 | 真实可用 | AI 页可选 tiny/small/base+/large调用 `GET /api/ai/models/status?selected_model=<variant>` 展示所选变体和 GPU 状态;只有本地存在 checkpoint 的变体显示可用 |
| 正向/反向点 | 真实可用 | 可在当前项目帧上加点并调用 AI 推理接口SAM 2 框选后会携带原始框和累计正/反点细化同一个候选 maskSAM 3 选择后会提示点交互需切回 SAM 2 | | 正向/反向点 | 真实可用 | 可在当前项目帧上加点并调用 AI 推理接口;AI 页中点击已有候选 mask 时也会继续添加当前正/反向提示点;SAM 2.1 框选后会携带原始框和累计正/反点细化同一个候选 mask |
| SAM 3 框选 | 真实可用 | 工作区选择 SAM 3 后可使用框选工具;后端通过官方 `add_geometric_prompt()` 正框执行 SAM 3 几何提示推理 | | SAM 3 入口 | 当前禁用 | 因当前系统不提供文本提示,前端不再显示 SAM 3 模型选择、文本输入或 SAM 3 框选入口;后端 `model=sam3` 返回不支持 |
| 语义文本输入 | 部分可用 | 纯文本会以 `semantic` prompt 调用后端;选择 SAM 3 且独立 Python 3.12 环境、CUDA、官方包和本地 checkpoint 均满足时走 SAM 3 文本语义推理,否则状态接口会标明不可用;空文本、失败和 0 mask 返回会显示前端反馈 | | 语义文本输入 | 当前禁用 | AI 页不再提供文本语义输入;后端收到 `semantic` prompt 会返回 400 |
| 参数开关 | 真实可用 | `cropMode` 会随 `/api/ai/predict` 发送 `crop_to_prompt`,后端对点/框 prompt 裁剪推理区域并回映射 polygon`autoDeleteBg` 会发送 `auto_filter_background``min_score`,后端过滤低分结果和覆盖负向点的结果 | | 参数开关 | 真实可用 | `cropMode` 会随 `/api/ai/predict` 发送 `crop_to_prompt`,后端对点/框 prompt 裁剪推理区域并回映射 polygon`autoDeleteBg` 会发送 `auto_filter_background``min_score`,后端过滤低分结果和覆盖负向点的结果 |
| 执行高精度语义分割 | 真实可用 | 使用当前项目帧调用 `/api/ai/predict`SAM 2 需要点提示且只采用最高分候选SAM 3 需要文本语义提示;生成结果写入全局 masks 并自动选中,右侧分类树可立即换标签 | | 遮罩清晰度 | 真实可用 | 调节 AI 页候选 mask 的预览透明度,只影响本页显示,不改变 mask 几何、分类或保存数据 |
| 推送至工作区编辑 | 真实可用 | 切回工作区并把工具切到“调整多边形”,保留 AI 页选中的 mask便于继续调轮廓和归档 | | 执行高精度语义分割 | 真实可用 | 使用当前项目帧和所选 SAM 2.1 变体调用 `/api/ai/predict`SAM 2.1 需要点/框提示且只采用最高分候选AI 页只渲染本页新生成候选,不显示工作区已有 mask生成结果写入全局 masks 并自动选中,右侧分类树可立即换标签 |
| 推送至工作区编辑 | 真实可用 | 切回工作区并把工具切到“调整多边形”,保留 AI 页选中的未保存 mask工作区回显后端标注时不会覆盖这类 draft mask |
| 上传替换底图 | Mock / UI-only | 按钮无事件 | | 上传替换底图 | Mock / UI-only | 按钮无事件 |
| 撤销/重做 | 真实可用 | 绑定全局 mask 历史栈 | | 撤销/重做 | 真实可用 | 绑定全局 mask 历史栈 |
| 清空全体锚点 | 部分可用 | 清空前端 points 和 masks | | 清空全体锚点 | 真实可用 | 清空 AI 页提示点和本页生成的候选 mask不删除工作区已有 mask |
| 退档推送至工作区重组 | 部分可用 | 只切回工作区,共用 masks store但没有保存/确认流程 | | 退档推送至工作区重组 | 部分可用 | 只切回工作区,共用 masks store但没有保存/确认流程 |
| 背景图 | 部分可用 | 优先显示当前项目帧;没有项目帧时仍回退到 Unsplash 演示图 | | 背景图 | 部分可用 | 优先显示当前项目帧;没有项目帧时仍回退到 Unsplash 演示图 |
@@ -158,4 +160,4 @@
当前前端真实可用的主链路是登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、显式生成帧、浏览帧、播放帧、工作区手工绘制、点/框 AI 推理、视频片段传播、GT mask 导入、标注保存/回显、COCO 导出、PNG mask ZIP 导出、模板 CRUD。 当前前端真实可用的主链路是登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、显式生成帧、浏览帧、播放帧、工作区手工绘制、点/框 AI 推理、视频片段传播、GT mask 导入、标注保存/回显、COCO 导出、PNG mask ZIP 导出、模板 CRUD。
当前最主要的 Mock 或未打通链路是:polygon 插点/边编辑增强、真正的文本语义分割、骨架/HDBSCAN 级别的 mask 降维增强、任务历史筛选、项目更多菜单和若干检查面板指标。 当前最主要的 Mock 或未打通链路是:真正的文本语义分割已因无文本提示入口而暂时禁用;复杂洞结构编辑、骨架/HDBSCAN 级别的 mask 降维增强、任务历史筛选、项目更多菜单和若干检查面板指标仍未落地

View File

@@ -39,7 +39,7 @@ Authorization: Bearer <token>
| `getProjectFrames(projectId)` | `GET /api/projects/{id}/frames` | 对齐 | 后端返回预签名 image_url以及 `timestamp_ms``source_frame_number` | | `getProjectFrames(projectId)` | `GET /api/projects/{id}/frames` | 对齐 | 后端返回预签名 image_url以及 `timestamp_ms``source_frame_number` |
| `predictMask(payload)` | `POST /api/ai/predict` | 对齐 | 前端发送 `image_id/prompt_type/prompt_data/model`,并把后端 `polygons` 转为 `masks[].pathData` | | `predictMask(payload)` | `POST /api/ai/predict` | 对齐 | 前端发送 `image_id/prompt_type/prompt_data/model`,并把后端 `polygons` 转为 `masks[].pathData` |
| `propagateMasks(payload)` | `POST /api/ai/propagate` | 对齐 | 当前帧 seed mask 向视频片段传播,并保存后续帧标注 | | `propagateMasks(payload)` | `POST /api/ai/propagate` | 对齐 | 当前帧 seed mask 向视频片段传播,并保存后续帧标注 |
| `getAiModelStatus(selectedModel?)` | `GET /api/ai/models/status` | 对齐 | 返回 GPU、SAM 2、SAM 3 的真实运行状态 | | `getAiModelStatus(selectedModel?)` | `GET /api/ai/models/status` | 对齐 | 返回 GPU 和四个 SAM 2.1 变体状态;`selected_model=sam3` 返回不支持 |
| `getProjectAnnotations(projectId, frameId?)` | `GET /api/ai/annotations` | 对齐 | 前端加载工作区时用于回显已保存标注 | | `getProjectAnnotations(projectId, frameId?)` | `GET /api/ai/annotations` | 对齐 | 前端加载工作区时用于回显已保存标注 |
| `saveAnnotation(payload)` | `POST /api/ai/annotate` | 对齐 | 工作区归档保存当前项目未保存 mask | | `saveAnnotation(payload)` | `POST /api/ai/annotate` | 对齐 | 工作区归档保存当前项目未保存 mask |
| `updateAnnotation(annotationId, payload)` | `PATCH /api/ai/annotations/{annotation_id}` | 对齐 | 工作区归档保存 dirty mask | | `updateAnnotation(annotationId, payload)` | `PATCH /api/ai/annotations/{annotation_id}` | 对齐 | 工作区归档保存 dirty mask |
@@ -76,8 +76,8 @@ Authorization: Bearer <token>
| GET | `/api/tasks/{task_id}` | 查询单个后台任务 | | GET | `/api/tasks/{task_id}` | 查询单个后台任务 |
| POST | `/api/tasks/{task_id}/cancel` | 取消后台任务 | | POST | `/api/tasks/{task_id}/cancel` | 取消后台任务 |
| POST | `/api/tasks/{task_id}/retry` | 重试失败或取消的后台任务 | | POST | `/api/tasks/{task_id}/retry` | 重试失败或取消的后台任务 |
| POST | `/api/ai/predict` | SAM 2 / SAM 3 可选推理 | | POST | `/api/ai/predict` | 当前启用 SAM 2 点/框/interactive 推理 |
| POST | `/api/ai/propagate` | SAM 2 / SAM 3 视频片段传播并保存标注 | | POST | `/api/ai/propagate` | 当前启用 SAM 2 视频片段传播并保存标注 |
| GET | `/api/ai/models/status` | GPU 和 SAM 模型状态 | | GET | `/api/ai/models/status` | GPU 和 SAM 模型状态 |
| POST | `/api/ai/auto` | 自动分割 | | POST | `/api/ai/auto` | 自动分割 |
| POST | `/api/ai/annotate` | 保存 AI 标注 | | POST | `/api/ai/annotate` | 保存 AI 标注 |
@@ -173,7 +173,7 @@ POST /api/media/parse?project_id=1&parse_fps=15&max_frames=120&target_width=960
```json ```json
{ {
"image_id": 123, "image_id": 123,
"model": "sam2", "model": "sam2.1_hiera_tiny",
"prompt_type": "point", "prompt_type": "point",
"prompt_data": { "prompt_data": {
"points": [[0.5, 0.5]], "points": [[0.5, 0.5]],
@@ -187,19 +187,19 @@ POST /api/media/parse?project_id=1&parse_fps=15&max_frames=120&target_width=960
- `point` - `point`
- `box` - `box`
- `interactive`,用于 SAM 2 交互式细化,`prompt_data` 同时携带 `box`、累计 `points``labels` - `interactive`,用于 SAM 2 交互式细化,`prompt_data` 同时携带 `box`、累计 `points``labels`
- `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理。前端 AI 页面不会再用 SAM 2 发送纯文本 semanticSAM 2 的交互入口应使用点/框提示。SAM 3 真实可用性由 `/api/ai/models/status` 中的外部环境和本地 checkpoint 状态决定 - `semantic` 当前被禁用;由于产品不提供文本提示,前端不会显示语义文本入口,后端收到 semantic 会返回 400
SAM 2 点提示和 auto fallback 当前只采用最高分候选 mask避免同一提示下多个备选 mask 被前端叠加显示。 SAM 2 点提示和 auto fallback 当前只采用最高分候选 mask避免同一提示下多个备选 mask 被前端叠加显示。
工作区 SAM 2 请求包含反向点时,`CanvasArea` 会发送 `options.auto_filter_background=true``options.min_score=0.05`;如果负向点过滤后没有可用 polygon前端会移除当前旧候选 mask 并要求重新框选或添加正向点。 工作区 SAM 2 请求包含反向点时,`CanvasArea` 会发送 `options.auto_filter_background=true``options.min_score=0.05`;如果负向点过滤后没有可用 polygon前端会移除当前旧候选 mask 并要求重新框选或添加正向点。
选择 `sam3` 且发送 `box` 时,前端仍传 normalized `[x1, y1, x2, y2]`,后端适配层会转换成官方几何 prompt 的 `[center_x, center_y, width, height]` 正框;当前 SAM 3 不接正/反点修正 当前 registry 暴露 `sam2.1_hiera_tiny``sam2.1_hiera_small``sam2.1_hiera_base_plus``sam2.1_hiera_large`,并兼容 `sam2` 作为 tiny 别名;发送 `model=sam3` 会返回 400 Unsupported model。SAM 3 源码文件保留在仓库中,但没有接入当前运行时模型列表
可选 `options` 字段: 可选 `options` 字段:
- `crop_to_prompt`:对 point/box/interactive prompt 按锚点或框附近区域裁剪后推理,再把 polygon 回映射到原图坐标。 - `crop_to_prompt`:对 point/box/interactive prompt 按锚点或框附近区域裁剪后推理,再把 polygon 回映射到原图坐标。
- `auto_filter_background`:过滤低分结果,并移除包含负向点的 polygon。 - `auto_filter_background`:过滤低分结果,并移除包含负向点的 polygon。
- `min_score`:配合 `auto_filter_background` 使用的最低置信度阈值;对 SAM 3 semantic 请求也会作为 external worker 的 `confidence_threshold` 传入,避免本地 checkpoint 在默认高阈值下返回 0 个 mask - `min_score`:配合 `auto_filter_background` 使用的最低置信度阈值。
后端响应: 后端响应:
@@ -234,7 +234,7 @@ SAM 2 点提示和 auto fallback 当前只采用最高分候选 mask避免同
{ {
"project_id": 1, "project_id": 1,
"frame_id": 123, "frame_id": 123,
"model": "sam2", "model": "sam2.1_hiera_tiny",
"direction": "forward", "direction": "forward",
"max_frames": 30, "max_frames": 30,
"include_source": false, "include_source": false,
@@ -250,7 +250,7 @@ SAM 2 点提示和 auto fallback 当前只采用最高分候选 mask避免同
} }
``` ```
`model=sam2` 使用 SAM 2 video predictor 的 mask seed 传播;`model=sam3` 使用独立 Python 3.12 helper 中的 SAM 3 video tracker并以 seed bbox 作为初始提示。响应会返回已创建的 `annotations`,保存的 `mask_data.source``sam2_propagation``sam3_propagation` SAM 2.1 变体使用对应 video predictor 的 mask seed 传播;`model=sam2` 会兼容归一化为 tiny`model=sam3` 当前不支持。响应会返回已创建的 `annotations`,保存的 `mask_data.source``<model_id>_propagation`
## 已完成的接口对齐 ## 已完成的接口对齐
@@ -270,7 +270,7 @@ SAM 2 点提示和 auto fallback 当前只采用最高分候选 mask避免同
- `cancelTask()` 已接入 `POST /api/tasks/{taskId}/cancel` - `cancelTask()` 已接入 `POST /api/tasks/{taskId}/cancel`
- `retryTask()` 已接入 `POST /api/tasks/{taskId}/retry` - `retryTask()` 已接入 `POST /api/tasks/{taskId}/retry`
- `getDashboardOverview()` 已从 `processing_tasks` 聚合解析队列。 - `getDashboardOverview()` 已从 `processing_tasks` 聚合解析队列。
- Dashboard 任务列表已展示 queued/running/failed/cancelled 任务,并可通过 `getTask()` 查看失败详情。 - Dashboard 任务列表已展示 queued/running/success/failed/cancelled 任务,并可通过 `getTask()` 查看失败详情`summary.parsing_task_count` 仍只统计 queued/running
- 工作区导出按钮已调用 `exportCoco()` / `exportMasks()`,并会先保存未归档 mask。 - 工作区导出按钮已调用 `exportCoco()` / `exportMasks()`,并会先保存未归档 mask。
- PNG mask ZIP 已包含每帧 `semantic_frame_*.png``semantic_classes.json`,重叠区域按 zIndex 裁决。 - PNG mask ZIP 已包含每帧 `semantic_frame_*.png``semantic_classes.json`,重叠区域按 zIndex 裁决。

View File

@@ -16,7 +16,7 @@
剩余边界: 剩余边界:
1. SAM 3 已完成独立 Python 3.12 环境安装脚本、外部 worker 桥接、本地 `sam3权重/sam3.pt` checkpoint 状态检查、本地 checkpoint 加载参数接入、单帧文本/框提示和 video tracker API 入口;下一步需要基于真实业务帧验证语义召回质量和视频 tracker 稳定性 1. SAM 3 相关源码和安装脚本保留,但当前产品入口已禁用:前端不展示 SAM 3后端 registry 不暴露 `sam3``model=sam3` 请求返回不支持。若后续重新需要文本语义提示再恢复前端入口、registry、状态接口和对应测试
2. 标注删除/更新接口已打通基础能力;逐点几何编辑器已支持顶点拖动/删除、边中点插入和多 polygon 子区域选择编辑,复杂洞结构仍待增强。 2. 标注删除/更新接口已打通基础能力;逐点几何编辑器已支持顶点拖动/删除、边中点插入和多 polygon 子区域选择编辑,复杂洞结构仍待增强。
## 阶段 2打通标注保存已完成基础闭环 ## 阶段 2打通标注保存已完成基础闭环
@@ -136,14 +136,14 @@ Word 方案中的完整版本包含距离变换、骨架提取和聚类。当前
1. 前端 `propagateMasks()` 已接入 `POST /api/ai/propagate` 1. 前端 `propagateMasks()` 已接入 `POST /api/ai/propagate`
2. 工作区按钮会把 seed mask 的 normalized polygon、bbox、label、color 和 class 元数据传给后端。 2. 工作区按钮会把 seed mask 的 normalized polygon、bbox、label、color 和 class 元数据传给后端。
3. SAM 2 路径使用官方 `SAM2VideoPredictor.add_new_mask()``propagate_in_video()` 3. SAM 2 路径使用官方 `SAM2VideoPredictor.add_new_mask()``propagate_in_video()`
4. SAM 3 路径通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境中的官方 `build_sam3_video_predictor()` 4. SAM 3 video tracker 路径已从当前产品入口禁用,相关 helper 仅保留作后续恢复参考
5. 后端会跳过源帧,把传播结果保存到后续帧 `annotations`,并在完成后由前端刷新回显。 5. 后端会跳过源帧,把传播结果保存到后续帧 `annotations`,并在完成后由前端刷新回显。
剩余建议: 剩余建议:
1. 把传播任务改为异步任务,接入 Dashboard 和 WebSocket 进度。 1. 把传播任务改为异步任务,接入 Dashboard 和 WebSocket 进度。
2. 前端增加传播方向、帧数和覆盖已有标注策略设置。 2. 前端增加传播方向、帧数和覆盖已有标注策略设置。
3. 用真实长视频分别做 SAM 2 / SAM 3 tracker smoke test 和质量评估。 3. 用真实长视频做 SAM 2 tracker smoke test 和质量评估;如果未来恢复 SAM 3再单独补充 SAM 3 tracker 评估。
## 阶段 8清理 UI 文案与 Mock ## 阶段 8清理 UI 文案与 Mock

View File

@@ -65,7 +65,7 @@ FastAPI 会根据代码里的路由和 Pydantic schema 自动生成 OpenAPI 描
- Projects项目 CRUD、项目帧 CRUD - Projects项目 CRUD、项目帧 CRUD
- Templates模板 CRUD - Templates模板 CRUD
- Media上传视频/DICOM、触发拆帧 - Media上传视频/DICOM、触发拆帧
- AISAM 2 / SAM 3 可选推理、模型状态、自动分割、保存标注 - AI当前启用 SAM 2 推理、模型状态、自动分割、保存标注SAM 3 源码保留但产品入口禁用
- Export导出 COCO JSON、导出 PNG masks - Export导出 COCO JSON、导出 PNG masks
- Health健康检查 - Health健康检查

View File

@@ -29,7 +29,7 @@
- 未提供项目 ID 上传时,后端自动创建项目。 - 未提供项目 ID 上传时,后端自动创建项目。
- 提供项目 ID 上传时,后端把上传对象关联到该项目。 - 提供项目 ID 上传时,后端把上传对象关联到该项目。
- 拆帧接口根据项目 `source_type` 处理视频或 DICOM。 - 拆帧接口根据项目 `source_type` 处理视频或 DICOM。
- 拆帧接口支持 `parse_fps``max_frames``target_width` 参数,用于生成可被 SAM 2 / SAM 3 视频处理复用的标准帧序列。 - 拆帧接口支持 `parse_fps``max_frames``target_width` 参数,用于生成可被 SAM 2 视频处理复用的标准帧序列。
- 视频帧使用连续 `frame_%06d.jpg` 命名,默认从 `frame_000000.jpg` 开始,并按 `target_width` 缩放。 - 视频帧使用连续 `frame_%06d.jpg` 命名,默认从 `frame_000000.jpg` 开始,并按 `target_width` 缩放。
- 拆帧完成后写入 `frames` 记录,并把项目状态设为 `ready` - 拆帧完成后写入 `frames` 记录,并把项目状态设为 `ready`
- 每条帧记录包含 `frame_index``image_url``width``height``timestamp_ms``source_frame_number` - 每条帧记录包含 `frame_index``image_url``width``height``timestamp_ms``source_frame_number`
@@ -49,6 +49,7 @@
- 时间轴支持缩略图点击切帧、range 拖动切帧、键盘左右方向键切帧、播放/暂停顺序推进帧。 - 时间轴支持缩略图点击切帧、range 拖动切帧、键盘左右方向键切帧、播放/暂停顺序推进帧。
- 播放帧率使用项目 `parse_fps``original_fps`,限制在 1 到 30 FPS。 - 播放帧率使用项目 `parse_fps``original_fps`,限制在 1 到 30 FPS。
- 时间轴显示当前帧时间和总时长,时间基准使用项目 `parse_fps``original_fps`,格式为 `mm:ss.cc` - 时间轴显示当前帧时间和总时长,时间基准使用项目 `parse_fps``original_fps`,格式为 `mm:ss.cc`
- 时间轴在顶部进度条和底部缩略图导航轴之间显示“已编辑”标记带,基于当前项目帧内的 `masks` 标出已有编辑/标注的帧;点击标记可跳转到对应帧。
## R5 工具栏 ## R5 工具栏
@@ -71,26 +72,30 @@
## R6 AI 推理 ## R6 AI 推理
-端可以在 AI 页面选择 `sam2``sam3`,选择结果存放在全局 store - 前 AI 页面支持选择 `sam2.1_hiera_tiny``sam2.1_hiera_small``sam2.1_hiera_base_plus``sam2.1_hiera_large`SAM 3 选择、文本输入和相关状态展示已隐藏
- 前端和工作区通过 `GET /api/ai/models/status` 展示 GPU、SAM 2 和 SAM 3 的真实运行状态 - 前端和工作区通过 `GET /api/ai/models/status` 展示 GPU 和四个 SAM 2.1 变体的真实运行状态;`selected_model=sam3` 会被后端拒绝
- 前端 `predictMask()` 调用 `POST /api/ai/predict` - 前端 `predictMask()` 调用 `POST /api/ai/predict`
- 前端发送后端契约:`image_id``prompt_type``prompt_data``model` - 前端发送后端契约:`image_id``prompt_type``prompt_data``model`
- 点提示传 `{ points, labels }`,正向点 label 为 1反向点 label 为 0。 - 点提示传 `{ points, labels }`,正向点 label 为 1反向点 label 为 0。
- AI 页面在已有候选 mask 上点击正向/反向选点时,应继续添加提示点,不应被 mask 选择事件拦截。
- 框选提示传归一化 `[x1, y1, x2, y2]` - 框选提示传归一化 `[x1, y1, x2, y2]`
- 工作区 SAM 2 框选会建立一个候选 mask后续正向点/反向点会携带原始框和累计点,以 `interactive` prompt 细化并替换同一个候选 mask。 - 工作区 SAM 2.1 框选会建立一个候选 mask后续正向点/反向点会携带原始框和累计点,以 `interactive` prompt 细化并替换同一个候选 mask。
- 工作区 SAM 2 一旦包含反向点,会随请求启用 `auto_filter_background``min_score=0.05`;若后端判定反向点排除了当前候选区域并返回空结果,前端会移除旧候选 mask避免继续显示已被否定的区域。 - 工作区 SAM 2.1 一旦包含反向点,会随请求启用 `auto_filter_background``min_score=0.05`;若后端判定反向点排除了当前候选区域并返回空结果,前端会移除旧候选 mask避免继续显示已被否定的区域。
- SAM 2 不支持文本语义提示AI 页面在 SAM 2 下输入纯文本时会提示用户改用点提示或切换 SAM 3不再回退到自动分割 - SAM 2.1 不支持文本语义提示;当前 AI 页面不提供文本语义输入,必须使用点/框提示
- SAM 2 点提示和 auto fallback 默认只采用一个最高分候选 mask避免多个候选 mask 作为同一结果重叠显示。 - SAM 2.1 点提示和 auto fallback 默认只采用一个最高分候选 mask避免多个候选 mask 作为同一结果重叠显示。
- AI 页面生成的 SAM 2/SAM 3 mask 会写入全局 `masks`,自动同步到当前项目帧,并写入全局 `selectedMaskIds`;右侧语义分类树可以直接给新生成 mask 换标签 - AI 页面只渲染本页新生成的候选 mask工作区已有手工、保存、传播或 GT 导入 mask 不会自动进入 AI 画布
- AI 页面提供“遮罩清晰度”滑杆,调节本页候选 mask 的预览透明度,不改变 mask 几何、分类或保存数据。
- AI 页面生成的 SAM 2.1 mask 会写入全局 `masks`,自动同步到当前项目帧,并写入全局 `selectedMaskIds`;右侧语义分类树可以直接给新生成 mask 换标签。
- AI 页“清空全体锚点”只清空本页提示点和本页生成的候选 mask不删除工作区已有 mask。
- AI 页面“推送至工作区编辑”会切回工作区并把工具切到“调整多边形”,保留当前选中的 AI mask 以便继续编辑轮廓和归档保存。 - AI 页面“推送至工作区编辑”会切回工作区并把工具切到“调整多边形”,保留当前选中的 AI mask 以便继续编辑轮廓和归档保存。
- 语义文本提示传 `semantic`;选择 `sam3` 且独立 Python 3.12 环境、CUDA、官方包和本地 checkpoint 均满足时走 SAM 3 文本语义推理 - 工作区加载后端已保存标注时,必须保留当前项目帧里尚未保存的 AI/手工 draft mask避免 AI 页推送到工作区的候选 mask 被异步回显流程覆盖
- SAM 3 支持工作区框选提示;后端把 normalized `[x1, y1, x2, y2]` 转成官方 `add_geometric_prompt()` 需要的 `[center_x, center_y, width, height]` 正框 - 语义文本提示 `semantic` 当前被后端禁用并返回 400
- 当前 SAM 3 前端路径不支持正/反点修正;在工作区用 SAM 3 进行点交互时,前端会提示切回 SAM 2 - SAM 3 源码和历史测试保留,但不属于当前产品可用功能;前端不再展示 SAM 3 入口,后端 registry 不暴露 `sam3`
- 工作区“传播片段”会把当前选中区域或当前帧第一个区域作为 seed调用 `POST /api/ai/propagate`,默认从当前帧向后传播 30 帧并保存结果标注。 - 工作区“传播片段”会把当前选中区域或当前帧第一个区域作为 seed调用 `POST /api/ai/propagate`,默认从当前帧向后传播 30 帧并保存结果标注。
- `POST /api/ai/propagate` 支持 `model=sam2` `model=sam3`SAM 2 使用官方 `SAM2VideoPredictor.add_new_mask()``propagate_in_video()`SAM 3 通过独立 Python 3.12 helper 调用官方 `build_sam3_video_predictor()` video tracker - `POST /api/ai/propagate` 当前支持四个 SAM 2.1 变体;兼容 `model=sam2` 并归一化为 tiny。SAM 2.1 使用官方 `SAM2VideoPredictor.add_new_mask()``propagate_in_video()`
- 传播结果会写入后续帧 `annotations``mask_data.source` 分别标记为 `sam2_propagation``sam3_propagation`,并保留 label、color 和 class 元数据。 - 传播结果会写入后续帧 `annotations``mask_data.source` 标记为 `<model_id>_propagation`,并保留 label、color 和 class 元数据。
- AI 页面会对 SAM 3 空文本、推理失败和返回 0 个 mask 的情况显示明确反馈。 - AI 页面会对未放置点提示、后端错误和返回 0 个 mask 的情况显示明确反馈。
- AI 参数支持 `crop_to_prompt``auto_filter_background``min_score`;点/框 prompt 可以裁剪局部区域推理并回映射结果,背景过滤会移除低分结果和包含负向点的 polygonSAM 3 semantic 会用 `min_score` 控制 external worker 的置信度阈值 - AI 参数支持 `crop_to_prompt``auto_filter_background``min_score`;点/框 prompt 可以裁剪局部区域推理并回映射结果,背景过滤会移除低分结果和包含负向点的 polygon。
- 后端返回 `polygons``scores` - 后端返回 `polygons``scores`
- 前端把后端 `polygons` 转成 Konva `pathData``segmentation``bbox``area` - 前端把后端 `polygons` 转成 Konva `pathData``segmentation``bbox``area`
- AI 推理结果先存放在前端 store 的 `masks` 中,点击“结构化归档保存”后持久化到后端标注表。 - AI 推理结果先存放在前端 store 的 `masks` 中,点击“结构化归档保存”后持久化到后端标注表。
@@ -103,6 +108,7 @@
- 后端提供 `PATCH /api/ai/annotations/{annotation_id}` 更新已保存标注的 `mask_data``points``bbox``template_id` - 后端提供 `PATCH /api/ai/annotations/{annotation_id}` 更新已保存标注的 `mask_data``points``bbox``template_id`
- 后端提供 `DELETE /api/ai/annotations/{annotation_id}` 删除已保存标注。 - 后端提供 `DELETE /api/ai/annotations/{annotation_id}` 删除已保存标注。
- 当前前端“结构化归档保存”会保存当前项目未保存 mask并会更新已标记为 dirty 的已保存 mask。 - 当前前端“结构化归档保存”会保存当前项目未保存 mask并会更新已标记为 dirty 的已保存 mask。
- 保存成功后,前端会重新拉取后端标注,并用后端 saved annotation 替换本次提交的 draft mask未提交的其他 draft mask 仍保留。
- 工作区“清空遮罩”会删除当前帧已保存标注,并清空当前帧未保存 mask。 - 工作区“清空遮罩”会删除当前帧已保存标注,并清空当前帧未保存 mask。
- 工作区加载项目帧后会查询已保存标注并回显。 - 工作区加载项目帧后会查询已保存标注并回显。
- 工作区支持导入 GT mask 图片,前端调用 `POST /api/ai/import-gt-mask` - 工作区支持导入 GT mask 图片,前端调用 `POST /api/ai/import-gt-mask`
@@ -128,10 +134,10 @@
## R10 Dashboard 与 WebSocket ## R10 Dashboard 与 WebSocket
- Dashboard 显示基础统计、解析队列和活动日志。 - Dashboard 显示基础统计、任务进度和活动日志。
- Dashboard 初始数据来自 `GET /api/dashboard/overview` - Dashboard 初始数据来自 `GET /api/dashboard/overview`
- 后端聚合项目数、处理中任务数、标注数、帧数、模板数和主机 load average。 - 后端聚合项目数、处理中任务数、标注数、帧数、模板数和主机 load average。
- 解析队列`processing_tasks` 中的 queued/running/failed/cancelled 任务生成;活动日志由最近任务、项目、标注和模板记录生成。 - 任务进度`processing_tasks` 中的 queued/running/success/failed/cancelled 任务生成,避免刚完成任务从进度区立即消失;处理中任务数统计只计算 queued/running;活动日志由最近任务、项目、标注和模板记录生成。
- Dashboard 对 queued/running 任务提供取消按钮,对 failed/cancelled 任务提供重试按钮。 - Dashboard 对 queued/running 任务提供取消按钮,对 failed/cancelled 任务提供重试按钮。
- Dashboard 任务详情会读取 `GET /api/tasks/{task_id}` 并展示失败 error、payload、result、Celery ID 和时间信息。 - Dashboard 任务详情会读取 `GET /api/tasks/{task_id}` 并展示失败 error、payload、result、Celery ID 和时间信息。
- Dashboard 会连接 `/ws/progress` - Dashboard 会连接 `/ws/progress`

View File

@@ -10,7 +10,7 @@
- React + TypeScript 前端 SPA。 - React + TypeScript 前端 SPA。
- FastAPI 后端 API。 - FastAPI 后端 API。
- PostgreSQL、MinIO、Redis、SAM 2 / SAM 3 等外部基础设施。 - PostgreSQL、MinIO、Redis、SAM 2 等外部基础设施。SAM 3 相关源码保留,但当前产品入口禁用
开发时前端通过 `server.ts` 启动 Express + Vite middleware后端通过 `backend/main.py` 启动 FastAPI。前端业务接口主要访问 FastAPI不依赖 `server.ts` 中保留的旧 mock API。 开发时前端通过 `server.ts` 启动 Express + Vite middleware后端通过 `backend/main.py` 启动 FastAPI。前端业务接口主要访问 FastAPI不依赖 `server.ts` 中保留的旧 mock API。
@@ -30,7 +30,7 @@
| 工作区 | `src/components/VideoWorkspace.tsx` | 加载帧和模板组织工具栏、Canvas、本体面板、时间轴 | | 工作区 | `src/components/VideoWorkspace.tsx` | 加载帧和模板组织工具栏、Canvas、本体面板、时间轴 |
| Canvas | `src/components/CanvasArea.tsx` | 显示帧、缩放平移、点/框提示、渲染 mask | | Canvas | `src/components/CanvasArea.tsx` | 显示帧、缩放平移、点/框提示、渲染 mask |
| 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具、跳转 AI 页面、触发 mask 撤销/重做 | | 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具、跳转 AI 页面、触发 mask 撤销/重做 |
| 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航、左右方向键切帧、播放和当前/总时长显示 | | 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航、已编辑帧标记、左右方向键切帧、播放和当前/总时长显示 |
| 本体面板 | `src/components/OntologyInspector.tsx` | 模板选择、分类树、本地自定义分类 | | 本体面板 | `src/components/OntologyInspector.tsx` | 模板选择、分类树、本地自定义分类 |
| AI 页面 | `src/components/AISegmentation.tsx` | 独立 AI 推理视图,使用当前项目帧 | | AI 页面 | `src/components/AISegmentation.tsx` | 独立 AI 推理视图,使用当前项目帧 |
| 模板库 | `src/components/TemplateRegistry.tsx` | 模板 CRUD、分类编辑、导入、排序 | | 模板库 | `src/components/TemplateRegistry.tsx` | 模板 CRUD、分类编辑、导入、排序 |
@@ -48,11 +48,11 @@
| Projects | `backend/routers/projects.py` | 项目与帧 CRUD | | Projects | `backend/routers/projects.py` | 项目与帧 CRUD |
| Templates | `backend/routers/templates.py` | 模板 CRUD 和 mapping_rules 打包/解包 | | Templates | `backend/routers/templates.py` | 模板 CRUD 和 mapping_rules 打包/解包 |
| Media | `backend/routers/media.py` | 上传媒体和拆帧 | | Media | `backend/routers/media.py` | 上传媒体和拆帧 |
| AI | `backend/routers/ai.py` | SAM 2 / SAM 3 可选推理、视频传播、模型状态和标注保存 | | AI | `backend/routers/ai.py` | 当前启用 SAM 2 推理、视频传播、模型状态和标注保存 |
| Export | `backend/routers/export.py` | COCO 和 PNG mask 导出 | | Export | `backend/routers/export.py` | COCO 和 PNG mask 导出 |
| SAM 2 | `backend/services/sam2_engine.py` | SAM 2 懒加载、状态检测、点/框/自动推理和视频 mask 传播 | | SAM 2 | `backend/services/sam2_engine.py` | SAM 2 懒加载、状态检测、点/框/自动推理和视频 mask 传播 |
| SAM 3 | `backend/services/sam3_engine.py`, `backend/services/sam3_external_worker.py`, `backend/setup_sam3_env.sh` | SAM 3 状态检测、独立 Python 3.12 环境桥接、本地 checkpoint 加载、文本语义推理、正框几何提示和 video tracker 适配 | | SAM 3 | `backend/services/sam3_engine.py`, `backend/services/sam3_external_worker.py`, `backend/setup_sam3_env.sh` | 历史保留的 SAM 3 桥接源码和脚本;当前未接入 registry |
| SAM Registry | `backend/services/sam_registry.py` | 模型选择、GPU 状态和推理分发 | | SAM Registry | `backend/services/sam_registry.py` | 当前暴露 SAM 2.1 四个变体、GPU 状态和推理分发 |
## 状态模型 ## 状态模型
@@ -66,7 +66,7 @@
- `maskHistory` / `maskFuture`mask 编辑历史栈,用于撤销和重做。 - `maskHistory` / `maskFuture`mask 编辑历史栈,用于撤销和重做。
- `activeModule`:当前页面。 - `activeModule`:当前页面。
- `activeTool`:当前工具。 - `activeTool`:当前工具。
- `aiModel`:当前选择的 AI 模型,取值为 `sam2``sam3` - `aiModel`:当前启用的 AI 模型,取值为 `sam2.1_hiera_tiny``sam2.1_hiera_small``sam2.1_hiera_base_plus``sam2.1_hiera_large`,默认 `sam2.1_hiera_tiny`
## 关键数据流 ## 关键数据流
@@ -89,7 +89,7 @@
### 任务控制 ### 任务控制
1. Dashboard 从 `GET /api/dashboard/overview` 读取 queued/running/failed/cancelled 任务。 1. Dashboard 从 `GET /api/dashboard/overview` 读取 queued/running/success/failed/cancelled 任务queued/running 代表当前进度success/failed/cancelled 代表最近任务状态
2. 用户取消任务时,前端调用 `POST /api/tasks/{task_id}/cancel`;后端写入 `cancelled`、设置 `finished_at`,并尝试 `celery_app.control.revoke(..., terminate=True)` 2. 用户取消任务时,前端调用 `POST /api/tasks/{task_id}/cancel`;后端写入 `cancelled`、设置 `finished_at`,并尝试 `celery_app.control.revoke(..., terminate=True)`
3. worker 在下载、解析、上传、写帧等关键阶段刷新任务状态;如果发现 `cancelled`,停止后续写入并发布 cancelled 事件。 3. worker 在下载、解析、上传、写帧等关键阶段刷新任务状态;如果发现 `cancelled`,停止后续写入并发布 cancelled 事件。
4. 用户重试任务时,前端调用 `POST /api/tasks/{task_id}/retry`;后端基于原任务 `payload` 创建新任务,记录 `retry_of` 并重新投递 Celery。 4. 用户重试任务时,前端调用 `POST /api/tasks/{task_id}/retry`;后端基于原任务 `payload` 创建新任务,记录 `retry_of` 并重新投递 Celery。
@@ -101,23 +101,29 @@
1. `VideoWorkspace` 根据 `currentProject.id` 调用 `getProjectFrames()` 1. `VideoWorkspace` 根据 `currentProject.id` 调用 `getProjectFrames()`
2. 若无帧但项目有 `video_path`,显示“尚未生成帧”的状态提示,不自动触发 `parseMedia()` 2. 若无帧但项目有 `video_path`,显示“尚未生成帧”的状态提示,不自动触发 `parseMedia()`
3. 帧数据映射为 store `Frame[]`,包含 `timestampMs``sourceFrameNumber`,供时间轴和后续视频传播使用。 3. 帧数据映射为 store `Frame[]`,包含 `timestampMs``sourceFrameNumber`,供时间轴和后续视频传播使用。
4. 当前帧传入 `CanvasArea` 4. 工作区调用 `GET /api/ai/annotations` 回显已保存标注时,会替换当前项目帧中的已保存 mask但保留没有 `annotationId` 的未保存 draft mask这保证 AI 页推送到工作区的候选 mask 不会被异步回显覆盖,并会在合并完成后恢复仍然存在的已选 mask id
5. `CanvasArea` 会把全局 `selectedMaskIds` 中仍存在于当前帧的 id 同步回本地选区,避免帧初始化时的临时清空覆盖 AI 页推送过来的选中态。
6. `FrameTimeline` 根据当前项目 `frames` 和全局 `masks` 计算有编辑/标注的帧,在进度条与缩略图导航轴之间渲染可点击标记。
7. 当前帧传入 `CanvasArea`
### AI 点/框推理 ### AI 点/框推理
1. 用户在 Canvas 选择正向点、反向点或框选。 1. 用户在 Canvas 选择正向点、反向点或框选。
2. `CanvasArea` 读取当前帧 ID 和宽高。 2. `CanvasArea` 读取当前帧 ID 和宽高。
3. SAM 2 框选会创建一个候选 mask并记录原始框后续正向点/反向点会累计到同一候选上。 3. SAM 2.1 框选会创建一个候选 mask并记录原始框后续正向点/反向点会累计到同一候选上。
4. `predictMask()` 归一化坐标并携带当前 `model` 调用 `/api/ai/predict`;同时有框和点时发送 `interactive` prompt。 4. `predictMask()` 归一化坐标并携带当前 `model` 调用 `/api/ai/predict`;同时有框和点时发送 `interactive` prompt。
5. SAM 2 请求中只要存在反向点,`CanvasArea` 会额外发送 `options.auto_filter_background=true``options.min_score=0.05`,让后端移除低分结果和包含负向点的 polygon。 5. SAM 2.1 请求中只要存在反向点,`CanvasArea` 会额外发送 `options.auto_filter_background=true``options.min_score=0.05`,让后端移除低分结果和包含负向点的 polygon。
6. 后端加载帧图片并通过 SAM registry 分发到 SAM 2 或 SAM 3 6. 后端加载帧图片并通过 SAM registry 分发到所选 SAM 2.1 变体;`model=sam2` 会兼容归一化为 tiny`model=sam3` 会被拒绝
7. 前端把 `polygons` 转为 mask交互式细化会替换同一个候选 mask而不是新增多个 mask。 7. 前端把 `polygons` 转为 mask交互式细化会替换同一个候选 mask而不是新增多个 mask。
8. 若带反向点的 SAM 2 细化返回空结果,前端会删除当前旧候选 mask 并提示反向点已排除该区域。 8. 若带反向点的 SAM 2.1 细化返回空结果,前端会删除当前旧候选 mask 并提示反向点已排除该区域。
9. Canvas 按当前帧过滤并渲染 mask 9. AI 页面只按本页生成的候选 id 渲染 mask不把工作区已有 mask 带入 AI 画布
10. 新 mask 会带上当前选择的模板分类元数据,包括 `classId``className``classZIndex` 和保存状态 `draft` 10. AI 页面候选 mask 的 Path 点击事件会先判断当前工具;正向/反向选点工具下点击 mask 会继续追加提示点,其他工具下才选中 mask
11. 用户点击“结构化归档保存”后,前端将像素 `segmentation` 转成 normalized `mask_data.polygons`;未保存 mask 调用 `POST /api/ai/annotate`dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}` 11. AI 页面“遮罩清晰度”滑杆只调节候选 mask 的 Konva preview opacity不写入 `Mask.segmentation`、分类元数据或后端 payload
12. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。 12. Canvas 按当前帧过滤并渲染 mask。
13. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask 13. 新 mask 会带上当前选择的模板分类元数据,包括 `classId``className``classZIndex``metadata.source=ai_segmentation` 和保存状态 `draft`
14. 用户点击“结构化归档保存”后,前端将像素 `segmentation` 转成 normalized `mask_data.polygons`;未保存 mask 调用 `POST /api/ai/annotate`dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`;保存成功后本次提交的 draft mask id 会从本地保留列表中排除,并由后端 saved annotation 回显替换。
15. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。
16. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。
### 视频片段传播 ### 视频片段传播
@@ -125,8 +131,8 @@
2. `VideoWorkspace``buildAnnotationPayload()` 把 seed mask 转成 normalized polygon、bbox、label、color 和 class 元数据。 2. `VideoWorkspace``buildAnnotationPayload()` 把 seed mask 转成 normalized polygon、bbox、label、color 和 class 元数据。
3. 前端调用 `POST /api/ai/propagate`,默认 `direction=forward``max_frames=30``include_source=false` 3. 前端调用 `POST /api/ai/propagate`,默认 `direction=forward``max_frames=30``include_source=false`
4. 后端按项目帧序列截取片段,下载对应帧到临时 `frame_%06d.jpg` 目录,保持当前帧在片段中的相对索引。 4. 后端按项目帧序列截取片段,下载对应帧到临时 `frame_%06d.jpg` 目录,保持当前帧在片段中的相对索引。
5. `model=sam2` 时,`sam2_engine` 使用 `SAM2VideoPredictor.add_new_mask()` 注入 seed mask再用 `propagate_in_video()` 传播。 5. `model` 为任一 SAM 2.1 变体时,`sam2_engine` 使用对应 checkpoint/config 加载 `SAM2VideoPredictor.add_new_mask()` 注入 seed mask再用 `propagate_in_video()` 传播。
6. `model=sam3` 时,`sam3_engine` 将请求交给 `sam3_external_worker.py`,由独立 Python 3.12 环境调用官方 `build_sam3_video_predictor()`,以 seed bbox 走 video tracker 6. `model=sam3` 当前不支持SAM 3 video tracker 代码保留但没有接入产品路径
7. 后端把传播返回的 normalized polygon 保存为后续帧 `Annotation`,跳过源帧,`mask_data.source` 记录模型传播来源。 7. 后端把传播返回的 normalized polygon 保存为后续帧 `Annotation`,跳过源帧,`mask_data.source` 记录模型传播来源。
8. 前端传播完成后重新调用 `GET /api/ai/annotations` 并回显新标注。 8. 前端传播完成后重新调用 `GET /api/ai/annotations` 并回显新标注。
@@ -180,9 +186,10 @@
6. `CanvasArea` 把当前选中的 mask id 同步到全局 `selectedMaskIds`;切换工具、切换帧或卸载 Canvas 时会清空选择。 6. `CanvasArea` 把当前选中的 mask id 同步到全局 `selectedMaskIds`;切换工具、切换帧或卸载 Canvas 时会清空选择。
7. `AISegmentation` 生成 mask 后会写入全局 `masks` 并把生成的 mask id 写入 `selectedMaskIds`;点击 AI 页预览 mask 也会更新 `selectedMaskIds` 7. `AISegmentation` 生成 mask 后会写入全局 `masks` 并把生成的 mask id 写入 `selectedMaskIds`;点击 AI 页预览 mask 也会更新 `selectedMaskIds`
8. AI 页“推送至工作区编辑”会切换到工作区并把 `activeTool` 设为 `edit_polygon``CanvasArea` 初始读取全局 `selectedMaskIds`,让 AI 页选中的 mask 在工作区继续保持选中。 8. AI 页“推送至工作区编辑”会切换到工作区并把 `activeTool` 设为 `edit_polygon``CanvasArea` 初始读取全局 `selectedMaskIds`,让 AI 页选中的 mask 在工作区继续保持选中。
9. `OntologyInspector` 可以选择具体分类;选择结果进入全局 store`CanvasArea``AISegmentation` 新建/更新 mask 时使用 9. 工作区帧/标注异步加载完成后,`hydrateSavedAnnotations()` 会合并本地未保存 draft mask 和后端已保存 mask不会用后端回显结果直接覆盖整个 `masks` store
10. 如果 `selectedMaskIds` 中存在当前 store 的 mask点击分类时会立即更新这些 mask 的 `templateId``classId``className``classZIndex``label``color` 10. `OntologyInspector` 可以选择具体分类;选择结果进入全局 store`CanvasArea``AISegmentation` 新建/更新 mask 时使用
11. 已保存 mask 被重新分类后进入 `dirty``saved=false`,继续复用工作区归档保存的 PATCH 链路 11. 如果 `selectedMaskIds` 中存在当前 store 的 mask点击分类时会立即更新这些 mask 的 `templateId``classId``className``classZIndex``label``color`
12. 已保存 mask 被重新分类后进入 `dirty``saved=false`,继续复用工作区归档保存的 PATCH 链路。
### 导出 ### 导出
@@ -210,15 +217,13 @@
- `deleteAnnotation()` 使用 `DELETE /api/ai/annotations/{annotationId}` - `deleteAnnotation()` 使用 `DELETE /api/ai/annotations/{annotationId}`
- `parseMedia()` 使用 `POST /api/media/parse?project_id=...`,可选 `parse_fps``max_frames``target_width`,用于生成标准帧序列。 - `parseMedia()` 使用 `POST /api/media/parse?project_id=...`,可选 `parse_fps``max_frames``target_width`,用于生成标准帧序列。
- `getProjectFrames()` 返回帧图像 URL、宽高、`timestamp_ms``source_frame_number` - `getProjectFrames()` 返回帧图像 URL、宽高、`timestamp_ms``source_frame_number`
- 后端 `/api/ai/predict` 支持 point、box、interactivesemantic 四种 prompt_type并通过 `model` 选择 SAM 2 或 SAM 3 - 后端 `/api/ai/predict` 当前支持 SAM 2.1 的 point、box、interactive`semantic` 文本提示禁用并返回 400
- SAM 2 是点/框交互式分割模型不做文本语义分割AI 页面在 SAM 2 + 纯文本时直接提示用户改用点提示或切换 SAM 3 - SAM 2.1 是点/框交互式分割模型不做文本语义分割AI 页面已经移除纯文本输入
- SAM 2 点提示和 auto fallback 只返回一个最高分候选,避免同一提示产生多个重叠候选 mask。 - SAM 2.1 点提示和 auto fallback 只返回一个最高分候选,避免同一提示产生多个重叠候选 mask。
- 当前 SAM 3 暴露 semantic 文本语义推理和 box 几何提示;工作区 Canvas 的点交互会在选择 SAM 3 时显示提示,不再静默失败 - SAM 3 前端入口、后端 registry 入口和状态展示均已禁用;`model=sam3` 会返回不支持
- SAM 3 box prompt 复用后端 `/api/ai/predict` `box` prompt_type输入仍是 normalized `[x1, y1, x2, y2]`,引擎适配层会转换为官方 `add_geometric_prompt()` 使用的 `[center_x, center_y, width, height]` 正框 - 后端 `/api/ai/predict` 支持可选 `options``crop_to_prompt` 会对 point/box/interactive prompt 做局部裁剪推理并回映射 polygon`auto_filter_background` 会按 `min_score` 和负向点过滤结果
- AI 页面选择 SAM 3 时优先发送文本 semantic prompt不会把正/反点误发送为 SAM 3 point prompt空文本、后端错误和空结果都会显示反馈消息 - 后端 `/api/ai/propagate` 当前支持所选 SAM 2.1 mask seed 视频传播;当前前端默认向后传播 30 帧并保存结果标注
- 后端 `/api/ai/predict` 支持可选 `options``crop_to_prompt` 会对 point/box/interactive prompt 做局部裁剪推理并回映射 polygon`auto_filter_background` 会按 `min_score` 和负向点过滤结果SAM 3 semantic 会把正数 `min_score` 传给 external worker 作为 `confidence_threshold` - 后端 `/api/ai/models/status` 返回 GPU 和四个 SAM 2.1 变体的真实运行状态
- 后端 `/api/ai/propagate` 支持 SAM 2 mask seed 视频传播和 SAM 3 external video tracker当前前端默认向后传播 30 帧并保存结果标注。
- 后端 `/api/ai/models/status` 返回 GPU、SAM 2、SAM 3 的真实运行状态SAM 3 状态包含外部 Python 环境与 checkpoint access 的可用性。
- point prompt 支持旧数组形式和 `{ points, labels }` 对象形式。 - point prompt 支持旧数组形式和 `{ points, labels }` 对象形式。
## 外部依赖边界 ## 外部依赖边界
@@ -235,8 +240,8 @@
以下能力属于当前冻结版本的占位或半可用功能: 以下能力属于当前冻结版本的占位或半可用功能:
- Dashboard 初始快照来自 `GET /api/dashboard/overview`解析队列`processing_tasks` queued/running/failed/cancelled 任务生成。 - Dashboard 初始快照来自 `GET /api/dashboard/overview`任务进度区`processing_tasks` queued/running/success/failed/cancelled 任务生成,处理中统计只计算 queued/running
- 已保存标注支持通过“应用分类”、polygon 顶点拖动/删除、边中点插入、多 polygon 子区域编辑和区域合并/去除进入 dirty 状态并归档更新;选中整块 mask 可用 Delete/Backspace 删除并同步后端;复杂洞结构编辑尚未实现。 - 已保存标注支持通过“应用分类”、polygon 顶点拖动/删除、边中点插入、多 polygon 子区域编辑和区域合并/去除进入 dirty 状态并归档更新;选中整块 mask 可用 Delete/Backspace 删除并同步后端;复杂洞结构编辑尚未实现。
- SAM 3 文本语义分割取决于官方依赖、GPU 运行环境和本地 checkpoint状态接口会暴露真实可用性运行时缺失时 `available=false` - SAM 3 文本语义分割已从当前产品路径中禁用相关源码保留恢复时需要重新接入前端入口、registry、状态接口和测试
- 自定义分类只存在本地组件状态。 - 自定义分类只存在本地组件状态。
- GT mask 导入已完成多类别像素值拆分、contour、distance transform seed point 和前端 seed point 拖拽编辑骨架提取、HDBSCAN 聚类和模板自动映射尚未实现。 - GT mask 导入已完成多类别像素值拆分、contour、distance transform seed point 和前端 seed point 拖拽编辑骨架提取、HDBSCAN 聚类和模板自动映射尚未实现。

View File

@@ -17,13 +17,13 @@
| R1 登录与会话 | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | 成功登录、失败提示、后端 401 | | R1 登录与会话 | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | 成功登录、失败提示、后端 401 |
| R2 项目管理 | `src/lib/api.test.ts`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、后端 CRUD、帧列表 | | R2 项目管理 | `src/lib/api.test.ts`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、后端 CRUD、帧列表 |
| R3 媒体上传与拆帧 | `src/components/ProjectLibrary.test.tsx`, `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 视频导入不自动拆帧、显式生成帧 FPS 选择、扩展名校验、自动建项目、关联项目、创建异步任务、标准帧序列参数、帧时间戳/源帧号、任务序列元数据、worker 注册帧、取消任务、重试任务、取消后 worker 停止 | | R3 媒体上传与拆帧 | `src/components/ProjectLibrary.test.tsx`, `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 视频导入不自动拆帧、显式生成帧 FPS 选择、扩展名校验、自动建项目、关联项目、创建异步任务、标准帧序列参数、帧时间戳/源帧号、任务序列元数据、worker 注册帧、取消任务、重试任务、取消后 worker 停止 |
| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧项目不自动解析并提示生成帧、缩略图/range/左右方向键切帧、播放、按项目 FPS 显示当前/总时长 | | R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧项目不自动解析并提示生成帧、回显已保存标注时保留本地未保存 draft mask、缩略图/range/已编辑帧标记/左右方向键切帧、播放、按项目 FPS 显示当前/总时长 |
| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 工具切换、调整多边形工具、AI 跳转、矩形/圆/线/点/多边形手工 mask 绘制、点工具在已有 mask 上落点、多边形 Enter/首节点闭合、polygon 顶点拖动/删除、边中点插点、双击边界按位置插点、整块 mask 删除、区域合并/去除、内含去除 hole 渲染、合并模式隐藏编辑手柄、撤销/重做历史栈 | | R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 工具切换、调整多边形工具、AI 跳转、矩形/圆/线/点/多边形手工 mask 绘制、点工具在已有 mask 上落点、多边形 Enter/首节点闭合、polygon 顶点拖动/删除、边中点插点、双击边界按位置插点、整块 mask 删除、区域合并/去除、内含去除 hole 渲染、合并模式隐藏编辑手柄、撤销/重做历史栈 |
| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam2_engine.py`, `backend/tests/test_sam3_engine.py` | 点/框/interactive/semantic 契约、SAM 2 纯文本提示拦截、SAM 2 最高分候选去重、SAM 2 框选后正负点细化同一候选 mask、SAM 2 反向点启用背景过滤且空结果移除旧候选、AI 页生成 mask 自动选中并可通过分类树换标签、AI 页推送到工作区编辑保留选择、SAM 2 视频传播、SAM 3 语义文本前端执行路径、SAM 3 semantic 请求级阈值、SAM 3 worker 单 2D mask 转 polygon、SAM 3 工作区框选、SAM 3 video tracker 外部桥接、SAM 3 点交互不支持提示、空文本/空结果反馈、模型选择、GPU/SAM 状态、SAM 3 外部 worker 桥接、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry | | R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam2_engine.py` | SAM 2.1 变体选择、点/框/interactive 契约、semantic 禁用、SAM 3 入口隐藏和后端拒绝、SAM 2.1 最高分候选去重、SAM 2.1 框选后正负点细化同一候选 mask、SAM 2.1 反向点启用背景过滤且空结果移除旧候选、AI 页不渲染工作区已有 mask、AI 页可在候选 mask 上继续添加正/反点、AI 页清空只移除本页候选、AI 页遮罩清晰度只改预览 opacity、AI 页生成 mask 自动选中并可通过分类树换标签、AI 页推送到工作区编辑保留选择、SAM 2.1 视频传播、空提示/空结果反馈、GPU/SAM2.1 状态、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry |
| R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、加载回显、更新 dirty 标注、清空删除已保存标注、GT mask 多类别导入、seed point 回显/归一化、项目不存在、帧不存在 | | R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、保存后用后端 saved annotation 替换已提交 draft、加载回显、更新 dirty 标注、清空删除已保存标注、GT mask 多类别导入、seed point 回显/归一化、项目不存在、帧不存在 |
| R8 模板库 | `src/components/TemplateRegistry.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_templates.py` | 前端模板加载/新建/编辑/删除、JSON 分类导入、mapping_rules 解包/打包、后端模板 CRUD | | R8 模板库 | `src/components/TemplateRegistry.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_templates.py` | 前端模板加载/新建/编辑/删除、JSON 分类导入、mapping_rules 解包/打包、后端模板 CRUD |
| R9 本体检查面板 | `src/components/OntologyInspector.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 模板选择、分类展示、具体分类选择、Canvas 选区同步、点击分类给已选 mask 换标签、自定义分类本地添加 | | R9 本体检查面板 | `src/components/OntologyInspector.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 模板选择、分类展示、具体分类选择、Canvas 选区同步、点击分类给已选 mask 换标签、自定义分类本地添加 |
| R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py`, `backend/tests/test_tasks.py` | 后端概览接口、任务表驱动队列、任务取消/重试/详情、cancelled 事件、Redis 进度事件 payload/发布、地址推导、消息订阅、连接状态回调、队列更新、heartbeat | | R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py`, `backend/tests/test_tasks.py` | 后端概览接口、任务表驱动进度区、最近完成任务保留显示、任务取消/重试/详情、cancelled 事件、Redis 进度事件 payload/发布、地址推导、消息订阅、连接状态回调、队列更新、heartbeat |
| R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO/PNG 按钮下载、导出前自动保存、导出路径、JSON 结构、mask ZIP、zIndex 语义融合 | | R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO/PNG 按钮下载、导出前自动保存、导出路径、JSON 结构、mask ZIP、zIndex 语义融合 |
| R12 配置 | `src/lib/config.test.ts` | env 优先、hostname 推导、WS 推导 | | R12 配置 | `src/lib/config.test.ts` | env 优先、hostname 推导、WS 推导 |
| R13 文档与测试 | `doc/09-test-plan.md` | 测试覆盖矩阵 | | R13 文档与测试 | `doc/09-test-plan.md` | 测试覆盖矩阵 |
@@ -35,14 +35,14 @@
| R1 | 登录页、默认开发凭证、token 写入、失败提示、后端 401 | `Login.test.tsx`, `test_auth.py` | 已覆盖 | | R1 | 登录页、默认开发凭证、token 写入、失败提示、后端 401 | `Login.test.tsx`, `test_auth.py` | 已覆盖 |
| R2 | 项目列表/创建/选择、视频导入、DICOM 导入、后端项目和帧 CRUD | `ProjectLibrary.test.tsx`, `api.test.ts`, `test_projects.py` | 已覆盖 | | R2 | 项目列表/创建/选择、视频导入、DICOM 导入、后端项目和帧 CRUD | `ProjectLibrary.test.tsx`, `api.test.ts`, `test_projects.py` | 已覆盖 |
| R3 | 文件类型校验、自动/指定项目上传、视频导入与生成帧分离、显式 FPS 生成帧、视频/DICOM 拆帧任务、`parse_fps/max_frames/target_width`、标准帧序列 metadata、任务查询、取消、重试、worker 取消停止 | `ProjectLibrary.test.tsx`, `test_media.py`, `test_tasks.py` | 已覆盖 | | R3 | 文件类型校验、自动/指定项目上传、视频导入与生成帧分离、显式 FPS 生成帧、视频/DICOM 拆帧任务、`parse_fps/max_frames/target_width`、标准帧序列 metadata、任务查询、取消、重试、worker 取消停止 | `ProjectLibrary.test.tsx`, `test_media.py`, `test_tasks.py` | 已覆盖 |
| R4 | 工作区加载帧、无帧项目不自动解析、Canvas 底图、缩略图/range/左右方向键切帧、播放、按 FPS 显示时间 | `VideoWorkspace.test.tsx`, `FrameTimeline.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 | | R4 | 工作区加载帧、无帧项目不自动解析、后端标注回显保留本地未保存 draft mask、Canvas 底图、缩略图/range/已编辑帧标记/左右方向键切帧、播放、按 FPS 显示时间 | `VideoWorkspace.test.tsx`, `FrameTimeline.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 |
| R5 | 工具切换、调整多边形入口、AI 跳转、矩形/圆/线/点/多边形绘制、已有 mask 上继续绘制 | `ToolsPalette.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 | | R5 | 工具切换、调整多边形入口、AI 跳转、矩形/圆/线/点/多边形绘制、已有 mask 上继续绘制 | `ToolsPalette.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 |
| R5 | 顶点编辑、边中点插点、双击边界按位置插点、顶点删除、整块删除、撤销/重做、区域合并、区域去除、hole even-odd 渲染 | `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 | | R5 | 顶点编辑、边中点插点、双击边界按位置插点、顶点删除、整块删除、撤销/重做、区域合并、区域去除、hole even-odd 渲染 | `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 |
| R6 | SAM 2 点/框/interactive、SAM 2 纯文本提示拦截、SAM 2 最高分候选去重、AI 页生成 mask 自动选中并可换标签、AI 页推送到工作区编辑保留选择、SAM 2 视频传播、SAM 3 semantic、SAM 3 semantic 请求级阈值、SAM 3 worker 单 2D mask 转 polygon、SAM 3 box、SAM 3 video tracker、SAM 3 不支持点交互时的前端反馈、模型选择、GPU/模型状态、参数 options、polygons 转 mask | `api.test.ts`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx`, `VideoWorkspace.test.tsx`, `ModelStatusBadge.test.tsx`, `test_ai.py`, `test_sam2_engine.py`, `test_sam3_engine.py` | 已覆盖 | | R6 | SAM 2.1 变体选择、点/框/interactive、semantic 禁用、SAM 3 入口隐藏和后端拒绝、SAM 2.1 最高分候选去重、AI 页不渲染工作区已有 mask、AI 页可在候选 mask 上继续添加正/反点、AI 页清空只移除本页候选、AI 页遮罩清晰度只改预览 opacity、AI 页生成 mask 自动选中并可换标签、AI 页推送到工作区编辑保留选择、SAM 2.1 视频传播、GPU/模型状态、参数 options、polygons 转 mask | `api.test.ts`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx`, `VideoWorkspace.test.tsx`, `ModelStatusBadge.test.tsx`, `test_ai.py`, `test_sam2_engine.py` | 已覆盖 |
| R7 | 保存、查询、更新、删除标注、工作区回显、清空已保存标注、GT mask 导入和 seed point 回写 | `VideoWorkspace.test.tsx`, `CanvasArea.test.tsx`, `api.test.ts`, `test_ai.py` | 已覆盖 | | R7 | 保存、保存后替换已提交 draft、查询、更新、删除标注、工作区回显、清空已保存标注、GT mask 导入和 seed point 回写 | `VideoWorkspace.test.tsx`, `CanvasArea.test.tsx`, `api.test.ts`, `test_ai.py` | 已覆盖 |
| R8 | 模板加载、新建、编辑、删除、JSON 分类导入、mapping_rules 映射、后端 CRUD | `TemplateRegistry.test.tsx`, `api.test.ts`, `test_templates.py` | 已覆盖 | | R8 | 模板加载、新建、编辑、删除、JSON 分类导入、mapping_rules 映射、后端 CRUD | `TemplateRegistry.test.tsx`, `api.test.ts`, `test_templates.py` | 已覆盖 |
| R9 | 模板选择、分类展示、分类选择、已选 mask 换标签、自定义本地分类、占位状态 | `OntologyInspector.test.tsx`, `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 | | R9 | 模板选择、分类展示、分类选择、已选 mask 换标签、自定义本地分类、占位状态 | `OntologyInspector.test.tsx`, `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 |
| R10 | Dashboard 概览、队列、活动日志、WebSocket progress/complete/error/status/cancelled、取消/重试/详情、连接状态回调、heartbeat | `Dashboard.test.tsx`, `websocket.test.ts`, `test_dashboard.py`, `test_main.py`, `test_progress_events.py`, `test_tasks.py` | 已覆盖 | | R10 | Dashboard 概览、任务进度区、最近完成任务保留显示、活动日志、WebSocket progress/complete/error/status/cancelled、取消/重试/详情、连接状态回调、heartbeat | `Dashboard.test.tsx`, `websocket.test.ts`, `test_dashboard.py`, `test_main.py`, `test_progress_events.py`, `test_tasks.py` | 已覆盖 |
| R11 | COCO/PNG ZIP 导出、导出前保存、路径和 JSON/ZIP 结构、zIndex 融合 | `VideoWorkspace.test.tsx`, `api.test.ts`, `test_export.py` | 已覆盖 | | R11 | COCO/PNG ZIP 导出、导出前保存、路径和 JSON/ZIP 结构、zIndex 融合 | `VideoWorkspace.test.tsx`, `api.test.ts`, `test_export.py` | 已覆盖 |
| R12 | API/WS 地址 env 优先和 hostname 推导 | `config.test.ts` | 已覆盖 | | R12 | API/WS 地址 env 优先和 hostname 推导 | `config.test.ts` | 已覆盖 |
| R13 | 文档测试矩阵与功能点追踪 | `doc/09-test-plan.md` | 已覆盖 | | R13 | 文档测试矩阵与功能点追踪 | `doc/09-test-plan.md` | 已覆盖 |
@@ -50,14 +50,14 @@
## 本轮补齐记录 ## 本轮补齐记录
- R5补充 `CanvasArea.test.tsx` 中圆形和线段手工绘制测试,明确验证 metadata、segmentation、bbox/area 和草稿状态。 - R5补充 `CanvasArea.test.tsx` 中圆形和线段手工绘制测试,明确验证 metadata、segmentation、bbox/area 和草稿状态。
- R6补充 `AISegmentation.test.tsx` 中 SAM 3 semantic 文本推理测试,验证前端传参和返回 mask 绑定当前语义类别 - R6补充 `AISegmentation.test.tsx` 中 SAM 2.1 变体选择测试,验证前端不展示 SAM 3 入口、选择 small 后请求携带对应模型,且未放置点提示时不发起推理
- R6补充 SAM 2 纯文本提示拦截、SAM 2 多候选只保留最高分、SAM 2 engine 单候选请求测试,避免多个重叠候选 mask 被同时叠加。 - R6补充 SAM 2 纯文本提示拦截、SAM 2 多候选只保留最高分、SAM 2 engine 单候选请求测试,避免多个重叠候选 mask 被同时叠加。
- R6补充 Canvas 工作区 SAM 2 反向点背景过滤测试,覆盖请求 options 和过滤为空时清除旧候选 mask。 - R6补充 Canvas 工作区 SAM 2 反向点背景过滤测试,覆盖请求 options 和过滤为空时清除旧候选 mask。
- R6补充 SAM 3 空文本、空结果和工作区点交互不支持提示测试,避免前端静默失败 - R6补充 `ModelStatusBadge.test.tsx` 中 SAM 3 不展示测试,避免禁用入口重新出现在前端
- R6补充 SAM 3 工作区 box prompt 测试和外部 worker box prompt 测试,验证官方 `add_geometric_prompt()` 正框链路 - R6补充后端 `selected_model=sam3` 拒绝测试和 semantic 禁用测试,避免后端继续暴露 SAM 3 产品能力
- R6补充 `POST /api/ai/propagate` 后端测试,验证 seed mask 传播结果会保存为后续帧标注并保留 class 元数据。 - R6补充 `POST /api/ai/propagate` 后端测试,验证 seed mask 传播结果会保存为后续帧标注并保留 class 元数据。
- R6补充 `propagateMasks()` API 封装和 `VideoWorkspace` 传播按钮测试,验证当前选中区域会发送到后端视频传播接口。 - R6补充 `propagateMasks()` API 封装和 `VideoWorkspace` 传播按钮测试,验证当前选中区域会发送到后端视频传播接口。
- R6补充 SAM 3 external video tracker 请求测试验证主后端会把帧目录、源帧索引、seed bbox 和方向传给独立 Python helper - R6`backend/tests/test_sam3_engine.py` 已标记跳过,仅作为历史保留实现的参考测试,不计入当前产品功能覆盖
- R3补充 `parseMedia()` 查询参数和后端拆帧任务 payload 测试,验证 `parse_fps``max_frames``target_width` 会进入任务。 - R3补充 `parseMedia()` 查询参数和后端拆帧任务 payload 测试,验证 `parse_fps``max_frames``target_width` 会进入任务。
- R3补充 worker 注册标准帧序列测试,验证帧 `timestamp_ms``source_frame_number``result.frame_sequence` 元数据。 - R3补充 worker 注册标准帧序列测试,验证帧 `timestamp_ms``source_frame_number``result.frame_sequence` 元数据。
- R8补充 `TemplateRegistry.test.tsx` 中模板编辑、删除测试,验证前端调用真实 API 封装并更新全局 store。 - R8补充 `TemplateRegistry.test.tsx` 中模板编辑、删除测试,验证前端调用真实 API 封装并更新全局 store。

View File

@@ -22,23 +22,24 @@ describe('AISegmentation', () => {
frames: [{ id: 'frame-1', projectId: 'project-1', index: 0, url: '/frame.jpg', width: 640, height: 360 }], frames: [{ id: 'frame-1', projectId: 'project-1', index: 0, url: '/frame.jpg', width: 640, height: 360 }],
}); });
apiMock.getAiModelStatus.mockResolvedValue({ apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam2', selected_model: 'sam2.1_hiera_tiny',
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true }, gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
models: [ models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false }, { id: 'sam2.1_hiera_tiny', label: 'SAM 2.1 Tiny', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2.1 Tiny ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'cuda', supports: ['semantic'], message: 'SAM 3 missing runtime', package_available: false, checkpoint_exists: false, python_ok: false, torch_ok: true, cuda_required: true },
], ],
}); });
}); });
it('lets the user choose SAM3 for subsequent predictions', async () => { it('shows the SAM2.1 variant selector without exposing SAM3', async () => {
render(<AISegmentation onSendToWorkspace={vi.fn()} />); render(<AISegmentation onSendToWorkspace={vi.fn()} />);
const sam3Button = (await screen.findByText('SAM3')).closest('button')!; expect(await screen.findByText('SAM 2.1 Tiny')).toBeInTheDocument();
fireEvent.click(sam3Button); expect(screen.getByText('tiny')).toBeInTheDocument();
expect(screen.getByText('small')).toBeInTheDocument();
expect(useStore.getState().aiModel).toBe('sam3'); expect(screen.getByText('base+')).toBeInTheDocument();
expect(await screen.findByText('SAM 3 missing runtime')).toBeInTheDocument(); expect(screen.getByText('large')).toBeInTheDocument();
expect(screen.queryByText('SAM3')).not.toBeInTheDocument();
expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam2.1_hiera_tiny');
}); });
it('passes enabled inference parameters to the backend', async () => { it('passes enabled inference parameters to the backend', async () => {
@@ -53,7 +54,7 @@ describe('AISegmentation', () => {
imageId: 'frame-1', imageId: 'frame-1',
imageWidth: 640, imageWidth: 640,
imageHeight: 360, imageHeight: 360,
model: 'sam2', model: 'sam2.1_hiera_tiny',
points: [{ x: 120, y: 80, type: 'pos' }], points: [{ x: 120, y: 80, type: 'pos' }],
options: { options: {
crop_to_prompt: false, crop_to_prompt: false,
@@ -63,16 +64,50 @@ describe('AISegmentation', () => {
})); }));
}); });
it('does not run SAM2 text-only prompts as semantic segmentation', async () => { it('sends the selected SAM2.1 variant to prediction', async () => {
apiMock.predictMask.mockResolvedValueOnce({ masks: [] });
render(<AISegmentation onSendToWorkspace={vi.fn()} />); render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), { fireEvent.click(await screen.findByText('small'));
target: { value: '胆囊' }, fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'));
fireEvent.click(await screen.findByText('执行高精度语义分割'));
expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam2.1_hiera_small');
expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({
model: 'sam2.1_hiera_small',
}));
}); });
it('does not render masks that were created in the workspace', async () => {
useStore.setState({
masks: [
{
id: 'workspace-mask',
frameId: 'frame-1',
pathData: 'M 0 0 L 10 0 L 10 10 Z',
label: 'Manual Mask',
color: '#ff0000',
segmentation: [[0, 0, 10, 0, 10, 10]],
metadata: { source: 'manual' },
},
],
selectedMaskIds: ['workspace-mask'],
});
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
expect(screen.queryAllByTestId('konva-path')).toHaveLength(0);
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual([]));
});
it('requires point prompts before running SAM2 inference', async () => {
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.click(await screen.findByText('执行高精度语义分割')); fireEvent.click(await screen.findByText('执行高精度语义分割'));
expect(apiMock.predictMask).not.toHaveBeenCalled(); expect(apiMock.predictMask).not.toHaveBeenCalled();
expect(await screen.findByText('SAM2 不支持文本语义提示;请先放置正/反向点,或切换到 SAM3 使用文本语义。')).toBeInTheDocument(); expect(await screen.findByText('请先放置正/反向提示点。')).toBeInTheDocument();
}); });
it('keeps only the best SAM2 candidate when the backend returns overlapping alternatives', async () => { it('keeps only the best SAM2 candidate when the backend returns overlapping alternatives', async () => {
@@ -106,8 +141,116 @@ describe('AISegmentation', () => {
await waitFor(() => expect(useStore.getState().masks).toHaveLength(1)); await waitFor(() => expect(useStore.getState().masks).toHaveLength(1));
expect(useStore.getState().masks[0].id).toBe('sam2-best'); expect(useStore.getState().masks[0].id).toBe('sam2-best');
expect(useStore.getState().masks[0].metadata).toEqual({ source: 'ai_segmentation' });
expect(useStore.getState().selectedMaskIds).toEqual(['sam2-best']); expect(useStore.getState().selectedMaskIds).toEqual(['sam2-best']);
expect(await screen.findByText('SAM2 返回 2 个候选,已采用最高分区域。')).toBeInTheDocument(); expect(await screen.findByText('SAM 2.1 Tiny 返回 2 个候选,已采用最高分区域。')).toBeInTheDocument();
});
it('adjusts the AI mask preview opacity without changing mask data', async () => {
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'sam2-mask',
pathData: 'M 10 10 L 40 10 L 40 40 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[10, 10, 40, 10, 40, 40]],
bbox: [10, 10, 30, 30],
area: 900,
},
],
});
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'));
fireEvent.click(await screen.findByText('执行高精度语义分割'));
await waitFor(() => expect(screen.getByTestId('konva-path')).toBeInTheDocument());
const maskGroup = () => screen.getAllByTestId('konva-group').find((group) => group.getAttribute('data-opacity'));
expect(maskGroup()).toHaveAttribute('data-opacity', '0.72');
fireEvent.change(screen.getByLabelText('遮罩清晰度'), { target: { value: '35' } });
expect(maskGroup()).toHaveAttribute('data-opacity', '0.35');
expect(useStore.getState().masks[0].segmentation).toEqual([[10, 10, 40, 10, 40, 40]]);
});
it('lets positive and negative prompt points be added on top of an AI mask', async () => {
apiMock.predictMask
.mockResolvedValueOnce({
masks: [
{
id: 'sam2-mask',
pathData: 'M 10 10 L 40 10 L 40 40 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[10, 10, 40, 10, 40, 40]],
bbox: [10, 10, 30, 30],
area: 900,
},
],
})
.mockResolvedValueOnce({ masks: [] });
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'), { clientX: 120, clientY: 80 });
fireEvent.click(await screen.findByText('执行高精度语义分割'));
await waitFor(() => expect(screen.getByTestId('konva-path')).toBeInTheDocument());
fireEvent.click(screen.getByText('反向选点'));
fireEvent.click(screen.getByTestId('konva-path'), { clientX: 220, clientY: 140 });
await waitFor(() => expect(screen.getAllByTestId('konva-circle')).toHaveLength(4));
fireEvent.click(screen.getByText('执行高精度语义分割'));
expect(apiMock.predictMask).toHaveBeenLastCalledWith(expect.objectContaining({
points: [
{ x: 120, y: 80, type: 'pos' },
{ x: 220, y: 140, type: 'neg' },
],
}));
expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']);
});
it('clears only AI page candidates and keeps workspace masks in the store', async () => {
useStore.setState({
masks: [
{
id: 'workspace-mask',
frameId: 'frame-1',
pathData: 'M 0 0 L 10 0 L 10 10 Z',
label: 'Manual Mask',
color: '#ff0000',
segmentation: [[0, 0, 10, 0, 10, 10]],
metadata: { source: 'manual' },
},
],
});
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'sam2-mask',
pathData: 'M 10 10 L 40 10 L 40 40 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[10, 10, 40, 10, 40, 40]],
bbox: [10, 10, 30, 30],
area: 900,
},
],
});
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'));
await waitFor(() => expect(screen.getAllByTestId('konva-circle')).toHaveLength(2));
fireEvent.click(await screen.findByText('执行高精度语义分割'));
await waitFor(() => expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['workspace-mask', 'sam2-mask']));
fireEvent.click(screen.getByText('清空全体锚点'));
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['workspace-mask']);
expect(useStore.getState().selectedMaskIds).toEqual([]);
}); });
it('lets a SAM2 result be selected and relabeled from the ontology panel', async () => { it('lets a SAM2 result be selected and relabeled from the ontology panel', async () => {
@@ -186,111 +329,4 @@ describe('AISegmentation', () => {
expect(onSendToWorkspace).toHaveBeenCalled(); expect(onSendToWorkspace).toHaveBeenCalled();
}); });
it('prompts for semantic text before running SAM3 inference', async () => {
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam3',
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true },
],
});
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
fireEvent.click(sam3Button);
fireEvent.click(screen.getByText('执行高精度语义分割'));
expect(apiMock.predictMask).not.toHaveBeenCalled();
expect(await screen.findByText('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。')).toBeInTheDocument();
});
it('shows feedback when SAM3 semantic inference returns no masks', async () => {
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam3',
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true },
],
});
apiMock.predictMask.mockResolvedValueOnce({ masks: [] });
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
fireEvent.click(sam3Button);
fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), {
target: { value: '胆囊' },
});
fireEvent.click(screen.getByText('执行高精度语义分割'));
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({
model: 'sam3',
points: undefined,
text: '胆囊',
})));
expect(await screen.findByText('SAM3 已完成语义推理,但没有返回区域。请尝试英文目标描述,或换到包含该目标的帧。当前提示: 胆囊')).toBeInTheDocument();
});
it('runs SAM3 semantic text inference and assigns the active class to returned masks', async () => {
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam3',
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true },
],
});
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'semantic-1',
pathData: 'M 10 10 L 40 10 L 40 40 Z',
label: 'semantic result',
color: '#06b6d4',
segmentation: [[10, 10, 40, 10, 40, 40]],
bbox: [10, 10, 30, 30],
area: 900,
},
],
});
useStore.setState({
activeTemplateId: 'template-1',
activeClassId: 'class-1',
activeClass: { id: 'class-1', name: '胆囊', color: '#ff0000', zIndex: 30 },
});
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
fireEvent.click(sam3Button);
fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), {
target: { value: '胆囊' },
});
fireEvent.click(screen.getByText('执行高精度语义分割'));
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam3',
points: undefined,
text: '胆囊',
options: {
crop_to_prompt: false,
auto_filter_background: true,
min_score: 0.05,
},
})));
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'semantic-1',
frameId: 'frame-1',
templateId: 'template-1',
classId: 'class-1',
className: '胆囊',
classZIndex: 30,
label: '胆囊',
color: '#ff0000',
saveStatus: 'draft',
}));
});
}); });

View File

@@ -4,7 +4,7 @@ import { cn } from '../lib/utils';
import { Stage, Layer, Image as KonvaImage, Circle, Path, Group } from 'react-konva'; import { Stage, Layer, Image as KonvaImage, Circle, Path, Group } from 'react-konva';
import useImage from 'use-image'; import useImage from 'use-image';
import { OntologyInspector } from './OntologyInspector'; import { OntologyInspector } from './OntologyInspector';
import { useStore } from '../store/useStore'; import { SAM2_MODEL_OPTIONS, useStore } from '../store/useStore';
import { getAiModelStatus, predictMask, type AiRuntimeStatus } from '../lib/api'; import { getAiModelStatus, predictMask, type AiRuntimeStatus } from '../lib/api';
interface AISegmentationProps { interface AISegmentationProps {
@@ -16,7 +16,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const setActiveTool = useStore((state) => state.setActiveTool); const setActiveTool = useStore((state) => state.setActiveTool);
const masks = useStore((state) => state.masks); const masks = useStore((state) => state.masks);
const addMask = useStore((state) => state.addMask); const addMask = useStore((state) => state.addMask);
const clearMasks = useStore((state) => state.clearMasks); const setMasks = useStore((state) => state.setMasks);
const selectedMaskIds = useStore((state) => state.selectedMaskIds); const selectedMaskIds = useStore((state) => state.selectedMaskIds);
const setSelectedMaskIds = useStore((state) => state.setSelectedMaskIds); const setSelectedMaskIds = useStore((state) => state.setSelectedMaskIds);
const maskHistory = useStore((state) => state.maskHistory); const maskHistory = useStore((state) => state.maskHistory);
@@ -25,17 +25,18 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const redoMasks = useStore((state) => state.redoMasks); const redoMasks = useStore((state) => state.redoMasks);
const frames = useStore((state) => state.frames); const frames = useStore((state) => state.frames);
const currentFrameIndex = useStore((state) => state.currentFrameIndex); const currentFrameIndex = useStore((state) => state.currentFrameIndex);
const activeTemplateId = useStore((state) => state.activeTemplateId);
const activeClass = useStore((state) => state.activeClass);
const aiModel = useStore((state) => state.aiModel); const aiModel = useStore((state) => state.aiModel);
const setAiModel = useStore((state) => state.setAiModel); const setAiModel = useStore((state) => state.setAiModel);
const activeTemplateId = useStore((state) => state.activeTemplateId);
const activeClass = useStore((state) => state.activeClass);
const [semanticText, setSemanticText] = useState('');
const [modelStatus, setModelStatus] = useState<AiRuntimeStatus | null>(null); const [modelStatus, setModelStatus] = useState<AiRuntimeStatus | null>(null);
const [autoDeleteBg, setAutoDeleteBg] = useState(true); const [autoDeleteBg, setAutoDeleteBg] = useState(true);
const [cropMode, setCropMode] = useState(false); const [cropMode, setCropMode] = useState(false);
const [maskOpacity, setMaskOpacity] = useState(72);
const [isInferencing, setIsInferencing] = useState(false); const [isInferencing, setIsInferencing] = useState(false);
const [inferenceMessage, setInferenceMessage] = useState(''); const [inferenceMessage, setInferenceMessage] = useState('');
const [aiMaskIds, setAiMaskIds] = useState<string[]>([]);
// Canvas state // Canvas state
const [scale, setScale] = useState(1); const [scale, setScale] = useState(1);
@@ -45,7 +46,10 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const currentFrame = frames[currentFrameIndex] || null; const currentFrame = frames[currentFrameIndex] || null;
const previewUrl = currentFrame?.url || 'https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop'; const previewUrl = currentFrame?.url || 'https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop';
const [image] = useImage(previewUrl); const [image] = useImage(previewUrl);
const frameMasks = currentFrame ? masks.filter((mask) => mask.frameId === currentFrame.id) : masks; const aiMaskIdSet = new Set(aiMaskIds);
const frameMasks = currentFrame
? masks.filter((mask) => mask.frameId === currentFrame.id && aiMaskIdSet.has(mask.id))
: masks.filter((mask) => aiMaskIdSet.has(mask.id));
const selectedModelStatus = modelStatus?.models.find((model) => model.id === aiModel); const selectedModelStatus = modelStatus?.models.find((model) => model.id === aiModel);
const modelCanInfer = selectedModelStatus?.available ?? true; const modelCanInfer = selectedModelStatus?.available ?? true;
@@ -65,6 +69,16 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
}; };
}, [aiModel]); }, [aiModel]);
useEffect(() => {
const visibleIds = new Set(frameMasks.map((mask) => mask.id));
const nextSelectedMaskIds = selectedMaskIds.filter((id) => visibleIds.has(id));
const changed = nextSelectedMaskIds.length !== selectedMaskIds.length
|| nextSelectedMaskIds.some((id, index) => id !== selectedMaskIds[index]);
if (changed) {
setSelectedMaskIds(nextSelectedMaskIds);
}
}, [frameMasks, selectedMaskIds, setSelectedMaskIds]);
const handleWheel = (e: any) => { const handleWheel = (e: any) => {
e.evt.preventDefault(); e.evt.preventDefault();
const scaleBy = 1.1; const scaleBy = 1.1;
@@ -94,17 +108,8 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
}; };
const runInference = useCallback(async () => { const runInference = useCallback(async () => {
const textPrompt = semanticText.trim(); if (points.length === 0) {
if (aiModel === 'sam3' && !textPrompt) { setInferenceMessage('请先放置正/反向提示点。');
setInferenceMessage('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。');
return;
}
if (aiModel === 'sam2' && textPrompt && points.length === 0) {
setInferenceMessage('SAM2 不支持文本语义提示;请先放置正/反向点,或切换到 SAM3 使用文本语义。');
return;
}
if (points.length === 0 && !textPrompt) {
setInferenceMessage('请先放置正/反向提示点,或输入语义描述。');
return; return;
} }
if (!currentFrame?.id) { if (!currentFrame?.id) {
@@ -129,8 +134,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
imageWidth, imageWidth,
imageHeight, imageHeight,
model: aiModel, model: aiModel,
points: aiModel === 'sam3' ? undefined : points.map((p) => ({ x: p.x, y: p.y, type: p.type })), points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
text: textPrompt || undefined,
options: { options: {
crop_to_prompt: cropMode, crop_to_prompt: cropMode,
auto_filter_background: autoDeleteBg, auto_filter_background: autoDeleteBg,
@@ -138,15 +142,13 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
}, },
}); });
const masksToApply = aiModel === 'sam2' ? result.masks.slice(0, 1) : result.masks; const masksToApply = result.masks.slice(0, 1);
if (masksToApply.length === 0) { if (masksToApply.length === 0) {
setInferenceMessage(aiModel === 'sam3' setInferenceMessage('模型没有返回可用区域,请调整提示点后重试。');
? `SAM3 已完成语义推理,但没有返回区域。请尝试英文目标描述,或换到包含该目标的帧。当前提示: ${textPrompt}`
: '模型没有返回可用区域,请换一个更具体的描述或调整提示。');
} else { } else {
setInferenceMessage(aiModel === 'sam2' && result.masks.length > 1 setInferenceMessage(result.masks.length > 1
? `SAM2 返回 ${result.masks.length} 个候选,已采用最高分区域。` ? `${selectedModelStatus?.label || 'SAM 2.1'} 返回 ${result.masks.length} 个候选,已采用最高分区域。`
: `已生成 ${masksToApply.length} 个候选区域。`); : `已生成 ${masksToApply.length} 个候选区域。`);
} }
const generatedMaskIds: string[] = []; const generatedMaskIds: string[] = [];
@@ -169,9 +171,11 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
segmentation: m.segmentation, segmentation: m.segmentation,
bbox: m.bbox, bbox: m.bbox,
area: m.area, area: m.area,
metadata: { source: 'ai_segmentation' },
}); });
}); });
if (generatedMaskIds.length > 0) { if (generatedMaskIds.length > 0) {
setAiMaskIds((existingIds) => [...existingIds, ...generatedMaskIds]);
setSelectedMaskIds(generatedMaskIds); setSelectedMaskIds(generatedMaskIds);
} }
} catch (err) { } catch (err) {
@@ -181,17 +185,32 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
} finally { } finally {
setIsInferencing(false); setIsInferencing(false);
} }
}, [activeClass, activeTemplateId, addMask, aiModel, autoDeleteBg, cropMode, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText, setSelectedMaskIds]); }, [activeClass, activeTemplateId, addMask, aiModel, autoDeleteBg, cropMode, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, selectedModelStatus?.label, setSelectedMaskIds]);
const clearAiLayer = useCallback(() => {
setPoints([]);
if (aiMaskIds.length === 0) return;
const idsToRemove = new Set(aiMaskIds);
setMasks(masks.filter((mask) => !idsToRemove.has(mask.id)));
setSelectedMaskIds(selectedMaskIds.filter((id) => !idsToRemove.has(id)));
setAiMaskIds([]);
}, [aiMaskIds, masks, selectedMaskIds, setMasks, setSelectedMaskIds]);
const addPromptPointFromEvent = useCallback((event: any) => {
if (effectiveTool !== 'point_pos' && effectiveTool !== 'point_neg') return false;
const stage = event.target?.getStage?.();
const pos = stage?.getRelativePointerPosition?.();
if (!pos) return false;
setPoints((currentPoints) => [
...currentPoints,
{ x: pos.x, y: pos.y, type: effectiveTool === 'point_pos' ? 'pos' : 'neg' },
]);
return true;
}, [effectiveTool]);
const handleStageClick = (e: any) => { const handleStageClick = (e: any) => {
if (effectiveTool === 'move') return; if (effectiveTool === 'move') return;
if (effectiveTool === 'point_pos' || effectiveTool === 'point_neg') { addPromptPointFromEvent(e);
const stage = e.target.getStage();
const pos = stage.getRelativePointerPosition();
if (pos) {
setPoints([...points, { x: pos.x, y: pos.y, type: effectiveTool === 'point_pos' ? 'pos' : 'neg' }]);
}
}
}; };
return ( return (
@@ -206,24 +225,39 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
</div> </div>
<div className="flex-1 overflow-y-auto p-6 flex flex-col gap-8"> <div className="flex-1 overflow-y-auto p-6 flex flex-col gap-8">
{/* Model Select */} {/* Model Status */}
<div> <div>
<h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-3"></h3> <h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-3"></h3>
<div className="bg-[#111] border border-white/5 grid grid-cols-2 gap-1 p-1 rounded-lg"> <div className="bg-[#111] border border-white/5 p-3 rounded-lg">
{(modelStatus?.models || [ <div className="flex items-center justify-between">
{ id: 'sam2' as const, label: 'SAM 2', available: true, message: '正在读取 SAM 2 状态' }, <span className="text-xs uppercase tracking-wider font-mono text-white">{selectedModelStatus?.label || 'SAM 2.1'}</span>
{ id: 'sam3' as const, label: 'SAM 3', available: false, message: '正在读取 SAM 3 状态' }, <span className={cn("text-xs", modelCanInfer ? "text-emerald-400" : "text-amber-400")}>
]).map((m) => ( {modelCanInfer ? '可用' : '不可用'}
</span>
</div>
<div className="mt-3 grid grid-cols-2 gap-2">
{SAM2_MODEL_OPTIONS.map((option) => {
const status = modelStatus?.models.find((model) => model.id === option.id);
const available = status?.available ?? false;
const selected = aiModel === option.id;
return (
<button <button
key={m.id} key={option.id}
className={cn("text-xs py-2 rounded-md transition-colors text-center uppercase tracking-wider font-mono", aiModel === m.id ? "bg-white/10 text-white font-medium shadow-sm" : "text-gray-500 hover:text-gray-300 hover:bg-white/5")} type="button"
onClick={() => setAiModel(m.id)} onClick={() => setAiModel(option.id)}
title={m.message} className={cn(
"h-8 rounded border px-2 text-[10px] uppercase tracking-wider transition-colors flex items-center justify-between",
selected
? "bg-cyan-500/10 border-cyan-400/40 text-cyan-300"
: "bg-white/[0.03] border-white/5 text-gray-400 hover:bg-white/5 hover:text-gray-200"
)}
> >
{m.label.replace(' ', '')} <span>{option.shortLabel}</span>
<span className={cn("ml-1", m.available ? "text-emerald-400" : "text-amber-400")}></span> <span className={cn("h-1.5 w-1.5 rounded-full", available ? "bg-emerald-400" : "bg-amber-400")} />
</button> </button>
))} );
})}
</div>
</div> </div>
<div className="mt-2 text-[10px] text-gray-500 leading-relaxed"> <div className="mt-2 text-[10px] text-gray-500 leading-relaxed">
<div>{selectedModelStatus?.message || '正在读取模型状态...'}</div> <div>{selectedModelStatus?.message || '正在读取模型状态...'}</div>
@@ -269,20 +303,6 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
</div> </div>
</div> </div>
{/* Semantic Description */}
<div>
<div className="flex justify-between items-center mb-3">
<h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest"></h3>
<span className="text-[9px] bg-cyan-500/10 text-cyan-400 px-1.5 py-0.5 rounded border border-cyan-500/20 font-mono"></span>
</div>
<textarea
value={semanticText}
onChange={e => setSemanticText(e.target.value)}
placeholder="例如:'分割出左侧车道上行驶的所有红色汽车'..."
className="w-full bg-[#111] border border-white/5 rounded-lg p-3 text-sm text-white placeholder-gray-600 focus:outline-none focus:border-cyan-500/50 focus:ring-1 focus:ring-cyan-500/50 transition-all font-sans min-h-[100px] resize-none hover:border-white/10"
/>
</div>
{/* Parameters */} {/* Parameters */}
<div> <div>
<h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-3 flex items-center gap-2"></h3> <h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-3 flex items-center gap-2"></h3>
@@ -300,6 +320,23 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<div className={cn("absolute top-0.5 left-0.5 w-3 h-3 bg-white rounded-full transition-transform shadow-sm", autoDeleteBg ? "translate-x-4" : "")} /> <div className={cn("absolute top-0.5 left-0.5 w-3 h-3 bg-white rounded-full transition-transform shadow-sm", autoDeleteBg ? "translate-x-4" : "")} />
</button> </button>
</div> </div>
<div className="space-y-2">
<div className="flex items-center justify-between">
<label htmlFor="ai-mask-opacity" className="text-[11px] text-gray-400 uppercase tracking-wider font-medium"></label>
<span className="text-[10px] font-mono text-cyan-400">{maskOpacity}%</span>
</div>
<input
id="ai-mask-opacity"
type="range"
min="20"
max="100"
step="5"
value={maskOpacity}
onChange={(event) => setMaskOpacity(Number(event.target.value))}
className="w-full accent-cyan-400"
/>
</div>
</div> </div>
</div> </div>
</div> </div>
@@ -340,7 +377,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<header className="h-16 border-b border-white/5 bg-[#111] flex items-center justify-between px-6 shrink-0"> <header className="h-16 border-b border-white/5 bg-[#111] flex items-center justify-between px-6 shrink-0">
<div className="flex flex-col"> <div className="flex flex-col">
<h2 className="text-sm font-semibold tracking-wide text-white"> (Visualizer)</h2> <h2 className="text-sm font-semibold tracking-wide text-white"> (Visualizer)</h2>
<span className="text-[10px] text-gray-500 uppercase tracking-widest font-mono">{aiModel.toUpperCase()} </span> <span className="text-[10px] text-gray-500 uppercase tracking-widest font-mono">SAM 2.1 </span>
</div> </div>
<div className="flex items-center gap-4"> <div className="flex items-center gap-4">
<button <button
@@ -363,7 +400,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<button className="flex items-center gap-2 text-xs text-gray-400 hover:text-white transition-colors bg-white/5 hover:bg-white/10 px-3 py-1.5 rounded-md border border-white/5"> <button className="flex items-center gap-2 text-xs text-gray-400 hover:text-white transition-colors bg-white/5 hover:bg-white/10 px-3 py-1.5 rounded-md border border-white/5">
<ImageIcon size={14} /> <ImageIcon size={14} />
</button> </button>
<button className="text-xs text-gray-400 hover:text-white transition-colors px-3 py-1.5" onClick={() => { setPoints([]); clearMasks(); }}> <button className="text-xs text-gray-400 hover:text-white transition-colors px-3 py-1.5" onClick={clearAiLayer}>
</button> </button>
</div> </div>
@@ -395,24 +432,38 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
)} )}
{/* AI Returned Masks */} {/* AI Returned Masks */}
{frameMasks.map((mask) => ( {frameMasks.map((mask) => {
<Group key={mask.id} opacity={selectedMaskIds.includes(mask.id) ? 0.72 : 0.45}> const isSelected = selectedMaskIds.includes(mask.id);
const previewOpacity = isSelected
? maskOpacity / 100
: Math.max(0.18, (maskOpacity / 100) * 0.62);
return (
<Group key={mask.id} opacity={previewOpacity}>
<Path <Path
data={mask.pathData} data={mask.pathData}
fill={mask.color} fill={mask.color}
stroke={mask.color} stroke={mask.color}
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2.5 : 1) / scale} strokeWidth={(isSelected ? 2.5 : 1) / scale}
onClick={(event: any) => { onClick={(event: any) => {
if (addPromptPointFromEvent(event)) {
event.cancelBubble = true;
return;
}
event.cancelBubble = true; event.cancelBubble = true;
setSelectedMaskIds([mask.id]); setSelectedMaskIds([mask.id]);
}} }}
onTap={(event: any) => { onTap={(event: any) => {
if (addPromptPointFromEvent(event)) {
event.cancelBubble = true;
return;
}
event.cancelBubble = true; event.cancelBubble = true;
setSelectedMaskIds([mask.id]); setSelectedMaskIds([mask.id]);
}} }}
/> />
</Group> </Group>
))} );
})}
{/* Points */} {/* Points */}
{points.map((p, i) => ( {points.map((p, i) => (

View File

@@ -47,7 +47,7 @@ describe('CanvasArea', () => {
imageId: 'frame-1', imageId: 'frame-1',
imageWidth: 640, imageWidth: 640,
imageHeight: 360, imageHeight: 360,
model: 'sam2', model: 'sam2.1_hiera_tiny',
points: [{ x: 120, y: 80, type: 'pos' }], points: [{ x: 120, y: 80, type: 'pos' }],
box: undefined, box: undefined,
})); }));
@@ -65,55 +65,6 @@ describe('CanvasArea', () => {
})); }));
}); });
it('explains that SAM3 point prompts are not supported in the workspace', async () => {
useStore.setState({ aiModel: 'sam3' });
render(<CanvasArea activeTool="point_pos" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-stage'));
expect(apiMock.predictMask).not.toHaveBeenCalled();
expect(await screen.findByText(/SAM3 当前工作区只支持框选提示/)).toBeInTheDocument();
});
it('calls SAM3 prediction with a box prompt from the workspace', async () => {
useStore.setState({ aiModel: 'sam3' });
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'sam3-box-mask',
pathData: 'M 20 20 L 80 20 L 80 80 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[20, 20, 80, 20, 80, 80]],
bbox: [20, 20, 60, 60],
area: 3600,
},
],
});
render(<CanvasArea activeTool="box_select" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith({
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam3',
points: undefined,
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
}));
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'sam3-box-mask',
metadata: expect.objectContaining({
source: 'sam3_box',
promptBox: { x1: 120, y1: 80, x2: 260, y2: 200 },
}),
}));
});
it('refines one SAM2 candidate mask from an initial box with positive and negative points', async () => { it('refines one SAM2 candidate mask from an initial box with positive and negative points', async () => {
apiMock.predictMask apiMock.predictMask
.mockResolvedValueOnce({ .mockResolvedValueOnce({
@@ -166,7 +117,7 @@ describe('CanvasArea', () => {
imageId: 'frame-1', imageId: 'frame-1',
imageWidth: 640, imageWidth: 640,
imageHeight: 360, imageHeight: 360,
model: 'sam2', model: 'sam2.1_hiera_tiny',
points: undefined, points: undefined,
box: { x1: 120, y1: 80, x2: 260, y2: 200 }, box: { x1: 120, y1: 80, x2: 260, y2: 200 },
})); }));
@@ -179,7 +130,7 @@ describe('CanvasArea', () => {
imageId: 'frame-1', imageId: 'frame-1',
imageWidth: 640, imageWidth: 640,
imageHeight: 360, imageHeight: 360,
model: 'sam2', model: 'sam2.1_hiera_tiny',
points: [{ x: 150, y: 100, type: 'pos' }], points: [{ x: 150, y: 100, type: 'pos' }],
box: { x1: 120, y1: 80, x2: 260, y2: 200 }, box: { x1: 120, y1: 80, x2: 260, y2: 200 },
})); }));
@@ -200,7 +151,7 @@ describe('CanvasArea', () => {
imageId: 'frame-1', imageId: 'frame-1',
imageWidth: 640, imageWidth: 640,
imageHeight: 360, imageHeight: 360,
model: 'sam2', model: 'sam2.1_hiera_tiny',
points: [ points: [
{ x: 150, y: 100, type: 'pos' }, { x: 150, y: 100, type: 'pos' },
{ x: 300, y: 150, type: 'neg' }, { x: 300, y: 150, type: 'neg' },
@@ -249,7 +200,7 @@ describe('CanvasArea', () => {
imageId: 'frame-1', imageId: 'frame-1',
imageWidth: 640, imageWidth: 640,
imageHeight: 360, imageHeight: 360,
model: 'sam2', model: 'sam2.1_hiera_tiny',
points: [{ x: 180, y: 120, type: 'neg' }], points: [{ x: 180, y: 120, type: 'neg' }],
box: { x1: 120, y1: 80, x2: 260, y2: 200 }, box: { x1: 120, y1: 80, x2: 260, y2: 200 },
options: { auto_filter_background: true, min_score: 0.05 }, options: { auto_filter_background: true, min_score: 0.05 },

View File

@@ -326,12 +326,35 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
}, [frame?.id]); }, [frame?.id]);
useEffect(() => { useEffect(() => {
const currentGlobalSelectedIds = useStore.getState().selectedMaskIds;
if (selectedMaskIds.length === 0) {
const validGlobalSelectedIds = currentGlobalSelectedIds.filter((id) => (
frameMasks.some((mask) => mask.id === id)
));
if (validGlobalSelectedIds.length > 0) return;
}
const isSameSelection = currentGlobalSelectedIds.length === selectedMaskIds.length
&& currentGlobalSelectedIds.every((id, index) => id === selectedMaskIds[index]);
if (!isSameSelection) {
setGlobalSelectedMaskIds(selectedMaskIds); setGlobalSelectedMaskIds(selectedMaskIds);
}, [selectedMaskIds, setGlobalSelectedMaskIds]); }
}, [frameMasks, selectedMaskIds, setGlobalSelectedMaskIds]);
useEffect(() => () => setGlobalSelectedMaskIds([]), [setGlobalSelectedMaskIds]); useEffect(() => () => setGlobalSelectedMaskIds([]), [setGlobalSelectedMaskIds]);
useEffect(() => { useEffect(() => {
if (!selectedMaskId) {
const validGlobalSelectedIds = useStore.getState().selectedMaskIds.filter((id) => (
frameMasks.some((mask) => mask.id === id)
));
if (validGlobalSelectedIds.length > 0) {
setSelectedMaskId(validGlobalSelectedIds[0]);
setSelectedMaskIds(validGlobalSelectedIds);
setSelectedPolygonIndex(0);
setSelectedVertexIndex(null);
return;
}
}
if (selectedMaskId && !frameMasks.some((mask) => mask.id === selectedMaskId)) { if (selectedMaskId && !frameMasks.some((mask) => mask.id === selectedMaskId)) {
setSelectedMaskId(null); setSelectedMaskId(null);
setSelectedMaskIds([]); setSelectedMaskIds([]);
@@ -444,11 +467,6 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
setInferenceMessage('请先选择一帧图像。'); setInferenceMessage('请先选择一帧图像。');
return; return;
} }
if (aiModel === 'sam3' && (!promptBox || (promptPoints?.length ?? 0) > 0)) {
setInferenceMessage('SAM3 当前工作区只支持框选提示;正/反点修正请切回 SAM2。');
return;
}
const imageWidth = frame.width || image?.naturalWidth || image?.width || 0; const imageWidth = frame.width || image?.naturalWidth || image?.width || 0;
const imageHeight = frame.height || image?.naturalHeight || image?.height || 0; const imageHeight = frame.height || image?.naturalHeight || image?.height || 0;
if (imageWidth <= 0 || imageHeight <= 0) { if (imageWidth <= 0 || imageHeight <= 0) {
@@ -482,7 +500,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const color = activeClass?.color || existingCandidate?.color || m.color; const color = activeClass?.color || existingCandidate?.color || m.color;
const metadata = { const metadata = {
...(existingCandidate?.metadata || {}), ...(existingCandidate?.metadata || {}),
source: aiModel === 'sam3' ? 'sam3_box' : 'sam2_interactive', source: 'sam2_interactive',
promptBox: promptBox || null, promptBox: promptBox || null,
promptPointCount: promptPoints?.length || 0, promptPointCount: promptPoints?.length || 0,
promptNegativePointCount: promptPoints?.filter((point) => point.type === 'neg').length || 0, promptNegativePointCount: promptPoints?.filter((point) => point.type === 'neg').length || 0,

View File

@@ -97,6 +97,42 @@ describe('Dashboard', () => {
expect(screen.queryByText('City_Driving_Dataset_004.mp4')).not.toBeInTheDocument(); expect(screen.queryByText('City_Driving_Dataset_004.mp4')).not.toBeInTheDocument();
}); });
it('keeps a recently completed task visible in the progress panel', async () => {
apiMock.getDashboardOverview.mockResolvedValueOnce({
summary: {
project_count: 1,
parsing_task_count: 0,
annotation_count: 0,
frame_count: 120,
template_count: 1,
system_load_percent: 8,
},
tasks: [
{
id: 'task-20',
task_id: 20,
project_id: 1,
name: 'completed.mp4',
progress: 100,
status: '解析完成',
raw_status: 'success',
error: null,
frame_count: 120,
updated_at: '2026-05-01T00:00:00Z',
},
],
activity: [],
});
render(<Dashboard />);
expect(await screen.findByText('任务进度 (当前 / 最近)')).toBeInTheDocument();
expect(screen.getByText('completed.mp4')).toBeInTheDocument();
expect(screen.getByText('100%')).toBeInTheDocument();
expect(screen.getByText('解析完成')).toBeInTheDocument();
expect(screen.queryByText(/当前无处理任务/)).not.toBeInTheDocument();
});
it('connects to the progress stream and updates progress tasks', async () => { it('connects to the progress stream and updates progress tasks', async () => {
render(<Dashboard />); render(<Dashboard />);

View File

@@ -312,7 +312,7 @@ export function Dashboard() {
<div className="grid grid-cols-1 lg:grid-cols-3 gap-6"> <div className="grid grid-cols-1 lg:grid-cols-3 gap-6">
<div className="lg:col-span-2 bg-[#111] border border-white/5 rounded-xl p-6 min-h-[400px]"> <div className="lg:col-span-2 bg-[#111] border border-white/5 rounded-xl p-6 min-h-[400px]">
<h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6"> ()</h2> <h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6"> ( / )</h2>
<div className="space-y-4"> <div className="space-y-4">
{isLoading && ( {isLoading && (
<div className="text-sm text-gray-500 text-center py-12"> Dashboard ...</div> <div className="text-sm text-gray-500 text-center py-12"> Dashboard ...</div>
@@ -371,7 +371,7 @@ export function Dashboard() {
</div> </div>
))} ))}
{!isLoading && tasks.length === 0 && ( {!isLoading && tasks.length === 0 && (
<div className="text-sm text-gray-500 text-center py-12"></div> <div className="text-sm text-gray-500 text-center py-12"></div>
)} )}
</div> </div>
</div> </div>

View File

@@ -51,6 +51,28 @@ describe('FrameTimeline', () => {
expect(screen.getAllByText('00:00.20').length).toBeGreaterThan(0); expect(screen.getAllByText('00:00.20').length).toBeGreaterThan(0);
}); });
it('marks edited frames between the time progress bar and frame navigator', () => {
useStore.setState({
frames: [
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
],
masks: [
{ id: 'm1', frameId: 'f2', pathData: 'M 0 0 Z', label: 'Draft', color: '#06b6d4' },
{ id: 'm2', frameId: 'f3', annotationId: '9', pathData: 'M 0 0 Z', label: 'Saved', color: '#22c55e' },
{ id: 'outside', frameId: 'other-frame', pathData: 'M 0 0 Z', label: 'Other', color: '#fff' },
],
});
render(<FrameTimeline />);
expect(screen.getByText('已编辑')).toBeInTheDocument();
expect(screen.getByText('2 帧')).toBeInTheDocument();
fireEvent.click(screen.getByLabelText('跳转到已编辑帧 3'));
expect(useStore.getState().currentFrameIndex).toBe(2);
});
it('changes frames with left and right arrow keys without leaving bounds', () => { it('changes frames with left and right arrow keys without leaving bounds', () => {
useStore.setState({ useStore.setState({
currentFrameIndex: 1, currentFrameIndex: 1,

View File

@@ -7,6 +7,7 @@ export function FrameTimeline() {
const frames = useStore((state) => state.frames); const frames = useStore((state) => state.frames);
const currentProject = useStore((state) => state.currentProject); const currentProject = useStore((state) => state.currentProject);
const currentFrameIndex = useStore((state) => state.currentFrameIndex); const currentFrameIndex = useStore((state) => state.currentFrameIndex);
const masks = useStore((state) => state.masks);
const setCurrentFrame = useStore((state) => state.setCurrentFrame); const setCurrentFrame = useStore((state) => state.setCurrentFrame);
const [isPlaying, setIsPlaying] = useState(false); const [isPlaying, setIsPlaying] = useState(false);
@@ -22,6 +23,17 @@ export function FrameTimeline() {
}, [currentProject?.original_fps, currentProject?.parse_fps]); }, [currentProject?.original_fps, currentProject?.parse_fps]);
const currentSeconds = totalFrames > 0 ? currentFrameIndex / timeBaseFps : 0; const currentSeconds = totalFrames > 0 ? currentFrameIndex / timeBaseFps : 0;
const totalSeconds = totalFrames > 0 ? Math.max(totalFrames - 1, 0) / timeBaseFps : 0; const totalSeconds = totalFrames > 0 ? Math.max(totalFrames - 1, 0) / timeBaseFps : 0;
const editedFrameMarkers = useMemo(() => {
const frameIds = new Set(frames.map((frame) => frame.id));
const editedIds = new Set(
masks
.filter((mask) => frameIds.has(mask.frameId))
.map((mask) => mask.frameId),
);
return frames
.map((frame, index) => ({ frame, index }))
.filter(({ frame }) => editedIds.has(frame.id));
}, [frames, masks]);
const formatTime = (seconds: number) => { const formatTime = (seconds: number) => {
const safeSeconds = Math.max(0, seconds); const safeSeconds = Math.max(0, seconds);
@@ -83,7 +95,7 @@ export function FrameTimeline() {
: []; : [];
return ( return (
<div className="h-32 bg-[#111] border-t border-white/5 flex flex-col shrink-0 z-20"> <div className="h-36 bg-[#111] border-t border-white/5 flex flex-col shrink-0 z-20">
<div className="h-4 bg-[#0d0d0d] flex items-center group relative"> <div className="h-4 bg-[#0d0d0d] flex items-center group relative">
<div className="absolute left-3 -top-5 text-[10px] font-mono text-gray-500 pointer-events-none"> <div className="absolute left-3 -top-5 text-[10px] font-mono text-gray-500 pointer-events-none">
{formatTime(currentSeconds)} {formatTime(currentSeconds)}
@@ -118,6 +130,34 @@ export function FrameTimeline() {
</div> </div>
</div> </div>
<div className="h-5 bg-[#0f0f0f] border-y border-white/[0.03] px-4 flex items-center gap-3">
<div className="w-20 text-[9px] font-mono uppercase tracking-widest text-gray-500 shrink-0"></div>
<div className="relative h-3 flex-1">
<div className="absolute left-0 right-0 top-1/2 h-px -translate-y-1/2 bg-white/5" />
{editedFrameMarkers.map(({ frame, index }) => {
const isCurrent = index === currentFrameIndex;
const left = totalFrames > 0 ? ((index + 1) / totalFrames) * 100 : 0;
return (
<button
key={frame.id}
type="button"
aria-label={`跳转到已编辑帧 ${index + 1}`}
title={`已编辑帧 ${index + 1}`}
onClick={() => setCurrentFrame(index)}
className={cn(
"absolute top-1/2 -translate-x-1/2 -translate-y-1/2 rounded-full border transition-all",
isCurrent
? "h-3 w-3 bg-cyan-300 border-cyan-100 shadow-[0_0_12px_rgba(34,211,238,0.65)]"
: "h-2 w-2 bg-amber-300 border-amber-100/80 hover:h-3 hover:w-3 hover:bg-cyan-300 hover:border-cyan-100"
)}
style={{ left: `${left}%` }}
/>
);
})}
</div>
<div className="w-20 text-right text-[9px] font-mono text-gray-500 shrink-0">{editedFrameMarkers.length} </div>
</div>
<div className="flex-1 flex items-center px-4 gap-6"> <div className="flex-1 flex items-center px-4 gap-6">
<div className="flex flex-col items-center gap-2 px-4 border-r border-white/10 shrink-0"> <div className="flex flex-col items-center gap-2 px-4 border-r border-white/10 shrink-0">
<button <button

View File

@@ -17,11 +17,10 @@ describe('ModelStatusBadge', () => {
resetStore(); resetStore();
vi.clearAllMocks(); vi.clearAllMocks();
apiMock.getAiModelStatus.mockResolvedValue({ apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam2', selected_model: 'sam2.1_hiera_tiny',
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true }, gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
models: [ models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false }, { id: 'sam2.1_hiera_tiny', label: 'SAM 2.1 Tiny', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2.1 Tiny ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'cuda', supports: ['semantic'], message: 'SAM 3 missing runtime', package_available: false, checkpoint_exists: false, python_ok: false, torch_ok: true, cuda_required: true },
], ],
}); });
}); });
@@ -29,17 +28,14 @@ describe('ModelStatusBadge', () => {
it('loads real model status for the selected model', async () => { it('loads real model status for the selected model', async () => {
render(<ModelStatusBadge />); render(<ModelStatusBadge />);
expect(await screen.findByText('SAM 2 可用')).toBeInTheDocument(); expect(await screen.findByText('SAM 2.1 Tiny 可用')).toBeInTheDocument();
expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam2'); expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam2.1_hiera_tiny');
}); });
it('shows unavailable state when SAM3 is selected but not runnable', async () => { it('does not expose disabled SAM3 status in the badge', async () => {
useStore.getState().setAiModel('sam3');
render(<ModelStatusBadge />); render(<ModelStatusBadge />);
await waitFor(() => expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam3')); await waitFor(() => expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam2.1_hiera_tiny'));
expect(await screen.findByText('SAM 3 不可用')).toBeInTheDocument(); expect(screen.queryByText(/SAM 3/)).not.toBeInTheDocument();
expect(screen.getByTitle('SAM 3 missing runtime')).toBeInTheDocument();
}); });
}); });

View File

@@ -50,7 +50,7 @@ describe('VideoWorkspace', () => {
apiMock.annotationToMask.mockReturnValue(null); apiMock.annotationToMask.mockReturnValue(null);
apiMock.getTask.mockResolvedValue({ id: 1, status: 'success', progress: 100, message: '解析完成' }); apiMock.getTask.mockResolvedValue({ id: 1, status: 'success', progress: 100, message: '解析完成' });
apiMock.propagateMasks.mockResolvedValue({ apiMock.propagateMasks.mockResolvedValue({
model: 'sam2', model: 'sam2.1_hiera_tiny',
direction: 'forward', direction: 'forward',
source_frame_id: 10, source_frame_id: 10,
processed_frame_count: 3, processed_frame_count: 3,
@@ -58,11 +58,10 @@ describe('VideoWorkspace', () => {
annotations: [], annotations: [],
}); });
apiMock.getAiModelStatus.mockResolvedValue({ apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam2', selected_model: 'sam2.1_hiera_tiny',
gpu: { available: false, device: 'cpu', name: null, torch_available: true }, gpu: { available: false, device: 'cpu', name: null, torch_available: true },
models: [ models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cpu', supports: [], message: 'ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false }, { id: 'sam2.1_hiera_tiny', label: 'SAM 2.1 Tiny', available: true, loaded: false, device: 'cpu', supports: [], message: 'ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'unavailable', supports: [], message: 'missing', package_available: false, checkpoint_exists: false, python_ok: false, torch_ok: true, cuda_required: true },
], ],
}); });
}); });
@@ -116,12 +115,65 @@ describe('VideoWorkspace', () => {
])); ]));
}); });
it('preserves unsaved AI masks when hydrating saved annotations after entering the workspace', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
apiMock.getProjectAnnotations.mockResolvedValueOnce([{ id: 99, frame_id: 10 }]);
apiMock.annotationToMask.mockReturnValueOnce({
id: 'annotation-99',
annotationId: '99',
frameId: '10',
saved: true,
pathData: 'M 0 0 Z',
label: 'Saved',
color: '#06b6d4',
});
useStore.setState({
activeTool: 'edit_polygon',
selectedMaskIds: ['ai-mask'],
masks: [{
id: 'ai-mask',
frameId: '10',
pathData: 'M 10 10 L 40 10 L 40 40 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[10, 10, 40, 10, 40, 40]],
saveStatus: 'draft',
saved: false,
metadata: { source: 'ai_segmentation' },
}],
});
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().masks.map((mask) => mask.id)).toEqual([
'ai-mask',
'annotation-99',
]));
expect(useStore.getState().selectedMaskIds).toEqual(['ai-mask']);
expect(useStore.getState().activeTool).toBe('edit_polygon');
});
it('saves pending masks through the archive button', async () => { it('saves pending masks through the archive button', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([ apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 }, { id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]); ]);
apiMock.getProjectAnnotations
.mockResolvedValueOnce([])
.mockResolvedValueOnce([{ id: 5, frame_id: 10 }]);
apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } }); apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } });
apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 }); apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 });
apiMock.annotationToMask.mockReturnValueOnce({
id: 'annotation-5',
annotationId: '5',
frameId: '10',
saved: true,
saveStatus: 'saved',
pathData: 'M 0 0 Z',
label: 'Saved AI Mask',
color: '#06b6d4',
});
render(<VideoWorkspace />); render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1)); await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
@@ -153,6 +205,10 @@ describe('VideoWorkspace', () => {
expect.objectContaining({ id: '10' }), expect.objectContaining({ id: '10' }),
'2', '2',
); );
await waitFor(() => expect(useStore.getState().masks).toEqual([
expect.objectContaining({ id: 'annotation-5', saved: true, saveStatus: 'saved' }),
]));
expect(useStore.getState().masks.some((mask) => mask.id === 'mask-1')).toBe(false);
}); });
it('updates dirty saved masks through the archive button', async () => { it('updates dirty saved masks through the archive button', async () => {
@@ -346,7 +402,7 @@ describe('VideoWorkspace', () => {
await waitFor(() => expect(useStore.getState().frames).toHaveLength(2)); await waitFor(() => expect(useStore.getState().frames).toHaveLength(2));
act(() => { act(() => {
useStore.setState({ useStore.setState({
aiModel: 'sam2', aiModel: 'sam2.1_hiera_tiny',
activeTemplateId: '2', activeTemplateId: '2',
selectedMaskIds: ['mask-1'], selectedMaskIds: ['mask-1'],
masks: [{ masks: [{
@@ -366,7 +422,7 @@ describe('VideoWorkspace', () => {
await waitFor(() => expect(apiMock.propagateMasks).toHaveBeenCalledWith({ await waitFor(() => expect(apiMock.propagateMasks).toHaveBeenCalledWith({
project_id: 1, project_id: 1,
frame_id: 10, frame_id: 10,
model: 'sam2', model: 'sam2.1_hiera_tiny',
direction: 'forward', direction: 'forward',
max_frames: 30, max_frames: 30,
include_source: false, include_source: false,

View File

@@ -34,9 +34,14 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
const activeTemplateId = useStore((state) => state.activeTemplateId); const activeTemplateId = useStore((state) => state.activeTemplateId);
const aiModel = useStore((state) => state.aiModel); const aiModel = useStore((state) => state.aiModel);
const selectedMaskIds = useStore((state) => state.selectedMaskIds); const selectedMaskIds = useStore((state) => state.selectedMaskIds);
const latestSelectedMaskIdsRef = React.useRef<string[]>(selectedMaskIds);
if (selectedMaskIds.length > 0) {
latestSelectedMaskIdsRef.current = selectedMaskIds;
}
const setFrames = useStore((state) => state.setFrames); const setFrames = useStore((state) => state.setFrames);
const setCurrentFrame = useStore((state) => state.setCurrentFrame); const setCurrentFrame = useStore((state) => state.setCurrentFrame);
const setMasks = useStore((state) => state.setMasks); const setMasks = useStore((state) => state.setMasks);
const setSelectedMaskIds = useStore((state) => state.setSelectedMaskIds);
const undoMasks = useStore((state) => state.undoMasks); const undoMasks = useStore((state) => state.undoMasks);
const redoMasks = useStore((state) => state.redoMasks); const redoMasks = useStore((state) => state.redoMasks);
const [isSaving, setIsSaving] = useState(false); const [isSaving, setIsSaving] = useState(false);
@@ -45,8 +50,15 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
const [isPropagating, setIsPropagating] = useState(false); const [isPropagating, setIsPropagating] = useState(false);
const [statusMessage, setStatusMessage] = useState(''); const [statusMessage, setStatusMessage] = useState('');
const hydrateSavedAnnotations = useCallback(async (projectId: string, projectFrames: Frame[]) => { const hydrateSavedAnnotations = useCallback(async (
projectId: string,
projectFrames: Frame[],
preserveSelectedIds: string[] = [],
excludeUnsavedMaskIds: string[] = [],
) => {
const frameById = new Map(projectFrames.map((frame) => [frame.id, frame])); const frameById = new Map(projectFrames.map((frame) => [frame.id, frame]));
const projectFrameIds = new Set(projectFrames.map((frame) => frame.id));
const excludedDraftIds = new Set(excludeUnsavedMaskIds);
const annotations = await getProjectAnnotations(projectId); const annotations = await getProjectAnnotations(projectId);
const savedMasks = annotations const savedMasks = annotations
.map((annotation) => { .map((annotation) => {
@@ -54,14 +66,27 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
return frame ? annotationToMask(annotation, frame) : null; return frame ? annotationToMask(annotation, frame) : null;
}) })
.filter((mask): mask is NonNullable<typeof mask> => Boolean(mask)); .filter((mask): mask is NonNullable<typeof mask> => Boolean(mask));
setMasks(savedMasks); const currentMasks = useStore.getState().masks;
}, [setMasks]); const unsavedMasks = currentMasks.filter((mask) => (
!projectFrameIds.has(mask.frameId) || (!mask.annotationId && !excludedDraftIds.has(mask.id))
));
const mergedMasks = [...unsavedMasks, ...savedMasks];
setMasks(mergedMasks);
if (preserveSelectedIds.length > 0) {
const mergedMaskIds = new Set(mergedMasks.map((mask) => mask.id));
const nextSelectedIds = preserveSelectedIds.filter((id) => mergedMaskIds.has(id));
if (nextSelectedIds.length > 0) {
setSelectedMaskIds(nextSelectedIds);
}
}
}, [setMasks, setSelectedMaskIds]);
useEffect(() => { useEffect(() => {
if (!currentProject?.id) return; if (!currentProject?.id) return;
let cancelled = false; let cancelled = false;
const loadFrames = async () => { const loadFrames = async () => {
const selectedIdsBeforeLoad = latestSelectedMaskIdsRef.current;
try { try {
const data = await getProjectFrames(String(currentProject.id)); const data = await getProjectFrames(String(currentProject.id));
if (cancelled) return; if (cancelled) return;
@@ -90,7 +115,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
return; return;
} }
setStatusMessage(''); setStatusMessage('');
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames); await hydrateSavedAnnotations(String(currentProject.id), mappedFrames, selectedIdsBeforeLoad);
} catch (err) { } catch (err) {
console.error('Failed to load frames:', err); console.error('Failed to load frames:', err);
} }
@@ -126,12 +151,13 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
setIsSaving(true); setIsSaving(true);
setStatusMessage('正在保存标注...'); setStatusMessage('正在保存标注...');
try { try {
const createPayloads = pendingMasks const createItems = pendingMasks
.map((mask) => { .map((mask) => {
const frame = frameById.get(mask.frameId); const frame = frameById.get(mask.frameId);
return frame ? buildAnnotationPayload(currentProject.id, mask, frame, activeTemplateId) : null; const payload = frame ? buildAnnotationPayload(currentProject.id, mask, frame, activeTemplateId) : null;
return payload ? { maskId: mask.id, payload } : null;
}) })
.filter((payload): payload is NonNullable<typeof payload> => Boolean(payload)); .filter((item): item is NonNullable<typeof item> => Boolean(item));
const updatePayloads = dirtyMasks const updatePayloads = dirtyMasks
.map((mask) => { .map((mask) => {
@@ -148,17 +174,22 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
}) })
.filter((item): item is NonNullable<typeof item> => Boolean(item)); .filter((item): item is NonNullable<typeof item> => Boolean(item));
if (createPayloads.length === 0 && updatePayloads.length === 0) { if (createItems.length === 0 && updatePayloads.length === 0) {
setStatusMessage('没有可保存的标注数据'); setStatusMessage('没有可保存的标注数据');
return 0; return 0;
} }
await Promise.all([ await Promise.all([
...createPayloads.map((payload) => saveAnnotation(payload)), ...createItems.map(({ payload }) => saveAnnotation(payload)),
...updatePayloads.map(({ annotationId, payload }) => updateAnnotation(annotationId, payload)), ...updatePayloads.map(({ annotationId, payload }) => updateAnnotation(annotationId, payload)),
]); ]);
await hydrateSavedAnnotations(currentProject.id, frames); await hydrateSavedAnnotations(
const savedCount = createPayloads.length + updatePayloads.length; currentProject.id,
frames,
useStore.getState().selectedMaskIds,
createItems.map(({ maskId }) => maskId),
);
const savedCount = createItems.length + updatePayloads.length;
setStatusMessage(`已保存 ${savedCount} 个标注`); setStatusMessage(`已保存 ${savedCount} 个标注`);
return savedCount; return savedCount;
} catch (err) { } catch (err) {

View File

@@ -224,7 +224,7 @@ describe('api client contracts', () => {
axiosMock.client.post.mockResolvedValueOnce({ axiosMock.client.post.mockResolvedValueOnce({
data: { data: {
model: 'sam2', model: 'sam2.1_hiera_tiny',
direction: 'forward', direction: 'forward',
source_frame_id: 5, source_frame_id: 5,
processed_frame_count: 3, processed_frame_count: 3,
@@ -235,7 +235,7 @@ describe('api client contracts', () => {
await expect(propagateMasks({ await expect(propagateMasks({
project_id: 9, project_id: 9,
frame_id: 5, frame_id: 5,
model: 'sam2', model: 'sam2.1_hiera_tiny',
seed: { seed: {
polygons: [[[0, 0], [1, 0], [1, 1]]], polygons: [[[0, 0], [1, 0], [1, 1]]],
label: 'mask', label: 'mask',
@@ -247,7 +247,7 @@ describe('api client contracts', () => {
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/propagate', { expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/propagate', {
project_id: 9, project_id: 9,
frame_id: 5, frame_id: 5,
model: 'sam2', model: 'sam2.1_hiera_tiny',
seed: { seed: {
polygons: [[[0, 0], [1, 0], [1, 1]]], polygons: [[[0, 0], [1, 0], [1, 1]]],
label: 'mask', label: 'mask',
@@ -384,7 +384,7 @@ describe('api client contracts', () => {
points: [[0.5, 0.5], [0.1, 0.1]], points: [[0.5, 0.5], [0.1, 0.1]],
labels: [1, 0], labels: [1, 0],
}, },
model: 'sam2', model: 'sam2.1_hiera_tiny',
}); });
expect(result.masks[0]).toEqual(expect.objectContaining({ expect(result.masks[0]).toEqual(expect.objectContaining({
pathData: 'M 100 50 L 300 50 L 300 150 L 100 150 Z', pathData: 'M 100 50 L 300 50 L 300 150 L 100 150 Z',
@@ -410,7 +410,7 @@ describe('api client contracts', () => {
image_id: 5, image_id: 5,
prompt_type: 'box', prompt_type: 'box',
prompt_data: [0.1, 0.1, 0.5, 0.5], prompt_data: [0.1, 0.1, 0.5, 0.5],
model: 'sam2', model: 'sam2.1_hiera_tiny',
}); });
}); });
@@ -437,11 +437,11 @@ describe('api client contracts', () => {
points: [[0.2, 0.2], [0.4, 0.4]], points: [[0.2, 0.2], [0.4, 0.4]],
labels: [1, 0], labels: [1, 0],
}, },
model: 'sam2', model: 'sam2.1_hiera_tiny',
}); });
}); });
it('uses semantic prompt type for text-only AI prediction', async () => { it('serializes text-only prediction as semantic when called directly', async () => {
const { predictMask } = await import('./api'); const { predictMask } = await import('./api');
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } }); axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
@@ -449,7 +449,6 @@ describe('api client contracts', () => {
imageId: '6', imageId: '6',
imageWidth: 640, imageWidth: 640,
imageHeight: 360, imageHeight: 360,
model: 'sam3',
text: '分割胆囊', text: '分割胆囊',
}); });
@@ -457,7 +456,7 @@ describe('api client contracts', () => {
image_id: 6, image_id: 6,
prompt_type: 'semantic', prompt_type: 'semantic',
prompt_data: '分割胆囊', prompt_data: '分割胆囊',
model: 'sam3', model: 'sam2.1_hiera_tiny',
}); });
}); });
@@ -484,7 +483,7 @@ describe('api client contracts', () => {
points: [[0.5, 0.5]], points: [[0.5, 0.5]],
labels: [1], labels: [1],
}, },
model: 'sam2', model: 'sam2.1_hiera_tiny',
options: { options: {
crop_to_prompt: true, crop_to_prompt: true,
auto_filter_background: true, auto_filter_background: true,
@@ -496,18 +495,17 @@ describe('api client contracts', () => {
it('loads AI model and GPU runtime status', async () => { it('loads AI model and GPU runtime status', async () => {
const { getAiModelStatus } = await import('./api'); const { getAiModelStatus } = await import('./api');
const status = { const status = {
selected_model: 'sam2', selected_model: 'sam2.1_hiera_tiny',
gpu: { available: false, device: 'cpu', name: null, torch_available: true, torch_version: '2.x', cuda_version: null }, gpu: { available: false, device: 'cpu', name: null, torch_available: true, torch_version: '2.x', cuda_version: null },
models: [ models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cpu', supports: ['point'], message: 'ready', package_available: true, checkpoint_exists: true, checkpoint_path: 'model.pt', python_ok: true, torch_ok: true, cuda_required: false }, { id: 'sam2.1_hiera_tiny', label: 'SAM 2.1 Tiny', available: true, loaded: false, device: 'cpu', supports: ['point'], message: 'ready', package_available: true, checkpoint_exists: true, checkpoint_path: 'model.pt', python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'unavailable', supports: ['semantic'], message: 'missing runtime', package_available: false, checkpoint_exists: false, checkpoint_path: null, python_ok: false, torch_ok: true, cuda_required: true },
], ],
}; };
axiosMock.client.get.mockResolvedValueOnce({ data: status }); axiosMock.client.get.mockResolvedValueOnce({ data: status });
await expect(getAiModelStatus('sam3')).resolves.toEqual(status); await expect(getAiModelStatus('sam2.1_hiera_tiny')).resolves.toEqual(status);
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/ai/models/status', { expect(axiosMock.client.get).toHaveBeenCalledWith('/api/ai/models/status', {
params: { selected_model: 'sam3' }, params: { selected_model: 'sam2.1_hiera_tiny' },
}); });
}); });
}); });

View File

@@ -1,5 +1,5 @@
import axios, { AxiosError } from 'axios'; import axios, { AxiosError } from 'axios';
import type { AiModelId, Frame, Mask, Project, Template } from '../store/useStore'; import { DEFAULT_AI_MODEL_ID, type AiModelId, type Frame, type Mask, type Project, type Template } from '../store/useStore';
import { API_BASE_URL } from './config'; import { API_BASE_URL } from './config';
const apiClient = axios.create({ const apiClient = axios.create({
@@ -557,7 +557,7 @@ export async function predictMask(payload: PredictMaskPayload): Promise<PredictM
image_id: Number(payload.imageId), image_id: Number(payload.imageId),
prompt_type, prompt_type,
prompt_data, prompt_data,
model: payload.model || 'sam2', model: payload.model || DEFAULT_AI_MODEL_ID,
...(payload.options ? { options: payload.options } : {}), ...(payload.options ? { options: payload.options } : {}),
}); });

View File

@@ -17,7 +17,20 @@ export interface Project {
updatedAt?: string; updatedAt?: string;
} }
export type AiModelId = 'sam2' | 'sam3'; export type AiModelId =
| 'sam2.1_hiera_tiny'
| 'sam2.1_hiera_small'
| 'sam2.1_hiera_base_plus'
| 'sam2.1_hiera_large';
export const DEFAULT_AI_MODEL_ID: AiModelId = 'sam2.1_hiera_tiny';
export const SAM2_MODEL_OPTIONS: Array<{ id: AiModelId; label: string; shortLabel: string }> = [
{ id: 'sam2.1_hiera_tiny', label: 'SAM 2.1 Tiny', shortLabel: 'tiny' },
{ id: 'sam2.1_hiera_small', label: 'SAM 2.1 Small', shortLabel: 'small' },
{ id: 'sam2.1_hiera_base_plus', label: 'SAM 2.1 Base+', shortLabel: 'base+' },
{ id: 'sam2.1_hiera_large', label: 'SAM 2.1 Large', shortLabel: 'large' },
];
export interface Frame { export interface Frame {
id: string; id: string;
@@ -195,7 +208,7 @@ export const useStore = create<AppState>((set) => ({
// Workspace // Workspace
activeModule: 'workspace', activeModule: 'workspace',
activeTool: 'move', activeTool: 'move',
aiModel: 'sam2', aiModel: DEFAULT_AI_MODEL_ID,
frames: [], frames: [],
currentFrameIndex: 0, currentFrameIndex: 0,
annotations: [], annotations: [],

View File

@@ -63,7 +63,7 @@ vi.mock('react-konva', () => ({
); );
}, },
Layer: ({ children }: any) => <div data-testid="konva-layer">{children}</div>, Layer: ({ children }: any) => <div data-testid="konva-layer">{children}</div>,
Group: ({ children }: any) => <div data-testid="konva-group">{children}</div>, Group: ({ children, opacity }: any) => <div data-testid="konva-group" data-opacity={opacity}>{children}</div>,
Image: ({ image }: any) => <img data-testid="konva-image" alt="" src={image?.src || ''} />, Image: ({ image }: any) => <img data-testid="konva-image" alt="" src={image?.src || ''} />,
Circle: (props: any) => ( Circle: (props: any) => (
<span <span
@@ -72,7 +72,11 @@ vi.mock('react-konva', () => ({
data-x={props.x} data-x={props.x}
data-y={props.y} data-y={props.y}
onClick={(event) => { onClick={(event) => {
const konvaEvent = { cancelBubble: false }; const point = {
x: event.clientX || 120,
y: event.clientY || 80,
};
const konvaEvent = { ...makeStageEvent(point.x, point.y), cancelBubble: false };
props.onClick?.(konvaEvent); props.onClick?.(konvaEvent);
if (konvaEvent.cancelBubble) event.stopPropagation(); if (konvaEvent.cancelBubble) event.stopPropagation();
}} }}
@@ -98,7 +102,11 @@ vi.mock('react-konva', () => ({
data-fill={props.fill} data-fill={props.fill}
data-fill-rule={props.fillRule} data-fill-rule={props.fillRule}
onClick={(event) => { onClick={(event) => {
const konvaEvent = { cancelBubble: false }; const point = {
x: event.clientX || 120,
y: event.clientY || 80,
};
const konvaEvent = { ...makeStageEvent(point.x, point.y), cancelBubble: false };
props.onClick?.(konvaEvent); props.onClick?.(konvaEvent);
if (konvaEvent.cancelBubble) event.stopPropagation(); if (konvaEvent.cancelBubble) event.stopPropagation();
}} }}

View File

@@ -1,4 +1,4 @@
import { useStore } from '../store/useStore'; import { DEFAULT_AI_MODEL_ID, useStore } from '../store/useStore';
export function resetStore() { export function resetStore() {
useStore.setState({ useStore.setState({
@@ -8,7 +8,7 @@ export function resetStore() {
currentProject: null, currentProject: null,
activeModule: 'workspace', activeModule: 'workspace',
activeTool: 'move', activeTool: 'move',
aiModel: 'sam2', aiModel: DEFAULT_AI_MODEL_ID,
frames: [], frames: [],
currentFrameIndex: 0, currentFrameIndex: 0,
annotations: [], annotations: [],