feat: 完善视频传播、标注编辑和拆帧闭环
- 接入 SAM2 视频传播能力:新增 /api/ai/propagate,支持用当前帧 mask/polygon/bbox 作为 seed,通过 SAM2 video predictor 向前、向后或双向传播,并可保存为真实 annotation。 - 接入 SAM3 video tracker:通过独立 Python 3.12 external worker 调用 SAM3 video predictor/tracker,使用本地 checkpoint 与 bbox seed 执行视频级跟踪,并在模型状态中标记 video_track 能力。 - 完善 SAM 模型分发:sam_registry 按 model_id 明确区分 sam2 propagation 与 sam3 video_track,避免两个模型链路混用。 - 打通前端“传播片段”:VideoWorkspace 使用当前选中 mask 和当前 AI 模型调用后端传播接口,传播结果回写并刷新工作区已保存标注。 - 增强 SAM3 本地 checkpoint 配置:新增 sam3_checkpoint_path 配置和 .env.example 示例,状态检查改为基于本地 checkpoint/独立环境/模型包可用性。 - 完善视频拆帧参数:/api/media/parse 支持 parse_fps、max_frames、target_width,后端任务保存帧时间戳、源帧号和 frame_sequence 元数据。 - 增加运行时 schema 兼容处理:启动时为旧 frames 表补充 timestamp_ms 和 source_frame_number 列,避免旧库升级后缺字段。 - 强化 Canvas 标注编辑:补齐多边形闭合、点工具、顶点拖拽、边中点插入、Delete/Backspace 删除、区域合并和重叠去除等交互。 - 增强语义分类联动:选中 mask 后可通过右侧语义分类树更新标签、颜色和 class metadata,并同步到保存/导出链路。 - 增加关键帧时间轴体验:FrameTimeline 显示具体时间信息,并支持键盘左右方向键切换关键帧。 - 完善 AI 交互分割参数:前端保留正向点、反向点、框选和 interactive prompt 的调用状态,支持 SAM2 细化候选区域与 SAM3 bbox 入口。 - 扩展后端/前端 API 类型:新增 propagateMasks、传播请求/响应 schema,并补齐 annotation、导出、模型状态和任务接口的测试覆盖。 - 更新项目文档:同步 README、AGENTS、接口契约、需求冻结、设计冻结、前端元素审计、实施计划和测试计划,标明真实功能边界与剩余风险。 - 增加测试覆盖:补充 SAM2/SAM3 传播、SAM3 状态、媒体拆帧参数、Canvas 编辑、语义标签切换、时间轴、工作区传播和 API 合约测试。 - 加强仓库安全边界:将 sam3权重/ 加入 .gitignore,避免本地模型权重被误提交。 验证:npm run test:run;pytest backend/tests;npm run lint;npm run build;python -m py_compile;git diff --check。
This commit is contained in:
@@ -19,4 +19,5 @@ VITE_WS_PROGRESS_URL="ws://192.168.3.11:8000/ws/progress"
|
||||
sam_default_model="sam2"
|
||||
sam_model_path="/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
|
||||
sam_model_config="configs/sam2/sam2_hiera_t.yaml"
|
||||
sam3_model_version="sam3.1"
|
||||
sam3_model_version="sam3"
|
||||
sam3_checkpoint_path="/home/wkmgc/Desktop/Seg_Server/sam3权重/sam3.pt"
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -8,6 +8,7 @@ coverage/
|
||||
!.env.example
|
||||
# Data & Models
|
||||
models/
|
||||
sam3权重/
|
||||
uploads/
|
||||
frames/
|
||||
minio_data/
|
||||
|
||||
29
AGENTS.md
29
AGENTS.md
@@ -6,7 +6,7 @@
|
||||
|
||||
## 项目概述
|
||||
|
||||
本项目是一个**语义分割系统**(Semantic Segmentation System),当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、服务器端拆帧、交互式 Canvas 标注、GT mask 导入、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出。
|
||||
本项目是一个**语义分割系统**(Semantic Segmentation System),当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、服务器端拆帧、交互式 Canvas 标注、视频片段传播、GT mask 导入、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出。
|
||||
|
||||
- **项目名称**: `react-example`(`package.json` 中的 `name`)
|
||||
- **前端入口**: `src/main.tsx` → `src/App.tsx`
|
||||
@@ -39,7 +39,7 @@
|
||||
| 缓存 / 队列 Broker | Redis |
|
||||
| 后台任务 | Celery worker |
|
||||
| 对象存储 | MinIO |
|
||||
| AI 推理 | SAM 2 / SAM 3 可选模型 + PyTorch;SAM 3 通过独立 Python 3.12 conda 环境桥接;`GET /api/ai/models/status` 返回真实 GPU/模型/HF 权重访问状态 |
|
||||
| AI 推理 | SAM 2 / SAM 3 可选模型 + PyTorch;SAM 3 通过独立 Python 3.12 conda 环境桥接;`GET /api/ai/models/status` 返回真实 GPU/模型/本地 checkpoint 状态 |
|
||||
| 视频 / 影像处理 | FFmpeg / OpenCV / pydicom |
|
||||
| 运行时 | Node.js ES Modules;Python 3.11 后端环境;可选 `sam3` Python 3.12 conda 环境 |
|
||||
|
||||
@@ -78,12 +78,12 @@ Seg_Server/
|
||||
│ │ ├── projects.py # /api/projects 与 /api/projects/{id}/frames
|
||||
│ │ ├── templates.py # /api/templates
|
||||
│ │ ├── media.py # /api/media/upload、/upload/dicom、/parse
|
||||
│ │ ├── ai.py # /api/ai/predict、/models/status、/auto、/annotate
|
||||
│ │ ├── ai.py # /api/ai/predict、/propagate、/models/status、/auto、/annotate
|
||||
│ │ └── export.py # /api/export/{project_id}/coco、/masks
|
||||
│ └── services/
|
||||
│ ├── frame_parser.py # FFmpeg/OpenCV 拆帧、pydicom 读片、帧上传
|
||||
│ ├── sam2_engine.py # SAM 2 懒加载推理封装和 fallback
|
||||
│ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接与文本语义推理适配器
|
||||
│ ├── sam2_engine.py # SAM 2 单帧推理和 video predictor 传播封装
|
||||
│ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接、文本语义推理、框选与 video tracker 适配器
|
||||
│ ├── sam3_external_worker.py # 独立 sam3 conda 环境中执行的状态/推理 helper
|
||||
│ └── sam_registry.py # SAM 模型选择、GPU 状态与推理分发
|
||||
└── src/ # React 前端
|
||||
@@ -194,6 +194,7 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
- `POST /api/tasks/{task_id}/cancel`
|
||||
- `POST /api/tasks/{task_id}/retry`
|
||||
- `POST /api/ai/predict`
|
||||
- `POST /api/ai/propagate`
|
||||
- `GET /api/ai/models/status`
|
||||
- `POST /api/ai/auto`
|
||||
- `POST /api/ai/annotate`
|
||||
@@ -219,14 +220,15 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
1. 登录:`Login.tsx` 调用 `POST /api/auth/login`,默认开发凭证为 `admin / 123456`。
|
||||
2. 项目管理:`ProjectLibrary.tsx` 调用项目 API 创建项目、拉取列表。
|
||||
3. 上传资源:视频走 `/api/media/upload`;DICOM 批量走 `/api/media/upload/dicom`。
|
||||
4. 拆帧入队:前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery。
|
||||
5. worker 执行:Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallback,DICOM 使用 pydicom,并持续更新任务进度。
|
||||
6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames`,`CanvasArea.tsx` 和 `FrameTimeline.tsx` 显示当前帧与时间轴缩略图。
|
||||
7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask;点击 mask 可拖动/删除 polygon 顶点、通过边中点插入新顶点,并能选择编辑多 polygon mask 的单个子区域;区域合并/去除使用 `polygon-clipping` 做 union/difference;Zustand 维护 `maskHistory/maskFuture` 支持撤销/重做。
|
||||
8. AI 分割:前端工具包括正向点、反向点和框选;后端 `ai.py` 期望按 `image_id`、`prompt_type`、`prompt_data`、`model` 和可选 `options` 调用 SAM registry。SAM 2 支持点/框/自动分割;`options.crop_to_prompt` 可对点/框 prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果;SAM 3 入口支持文本语义推理,主后端会通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境;如果 Python/CUDA/包/Hugging Face gated 权重访问任一条件不满足,会在状态接口中标为不可用。
|
||||
9. GT 导入:工作区“导入 GT Mask”调用 `/api/ai/import-gt-mask`;后端按非零像素值和连通域生成 polygon 标注,并用 distance transform 生成 seed point;前端回显 seed point,拖动后可归档更新。
|
||||
10. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index;`OntologyInspector.tsx` 在工作区显示当前模板分类树。
|
||||
11. 导出:后端支持 COCO JSON 和 PNG mask ZIP 导出;PNG ZIP 包含单标注 mask、按 zIndex 融合的语义 mask 和 `semantic_classes.json`。
|
||||
4. 拆帧入队:前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery,接口支持 `parse_fps`、`max_frames` 和 `target_width` 标准帧序列参数。
|
||||
5. worker 执行:Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallback,DICOM 使用 pydicom;视频帧按 `frame_%06d.jpg` 连续命名并记录 `timestamp_ms`、`source_frame_number` 和任务 `frame_sequence` 元数据。
|
||||
6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames`,`CanvasArea.tsx` 和 `FrameTimeline.tsx` 显示当前帧与时间轴缩略图;前端 `Frame` 会保留后端返回的帧序列时间戳和源帧号。
|
||||
7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask;多边形可按 Enter 或点击首节点闭合;绘制工具可在已有 mask 上继续落点;点击 mask 可拖动/删除 polygon 顶点、通过边中点插入新顶点,并能选择编辑多 polygon mask 的单个子区域;选中整块 mask 可用 Delete/Backspace 删除,已保存 mask 会同步后端删除;区域合并/去除会隐藏编辑手柄并显示已选数量,使用 `polygon-clipping` 做 union/difference,内含去除结果用 even-odd 规则渲染 hole;Zustand 维护 `maskHistory/maskFuture` 支持撤销/重做。
|
||||
8. AI 分割:前端工具包括正向点、反向点和框选;SAM 2 框选会建立候选 mask,后续正/反点通过 `interactive` prompt 携带原始框和累计点细化同一个候选 mask;后端 `ai.py` 期望按 `image_id`、`prompt_type`、`prompt_data`、`model` 和可选 `options` 调用 SAM registry。SAM 2 支持点/框/interactive/自动分割和 video predictor 传播;`options.crop_to_prompt` 可对点/框/interactive prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果;SAM 3 入口支持文本语义推理、框选提示和 external video tracker,主后端会通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境,并优先使用 `sam3_checkpoint_path` 指向的本地 `sam3权重/sam3.pt`;如果 Python/CUDA/包/本地 checkpoint 均满足,会在状态接口中标为可用。
|
||||
9. 视频片段传播:工作区“传播片段”把当前选中 mask 或当前帧第一个 mask 作为 seed,调用 `POST /api/ai/propagate`;后端按项目帧序列下载片段帧,SAM 2 用 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`,SAM 3 用独立 helper 的官方 `build_sam3_video_predictor()`,并把后续帧结果保存为 `Annotation`。
|
||||
10. GT 导入:工作区“导入 GT Mask”调用 `/api/ai/import-gt-mask`;后端按非零像素值和连通域生成 polygon 标注,并用 distance transform 生成 seed point;前端回显 seed point,拖动后可归档更新。
|
||||
11. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index;`OntologyInspector.tsx` 在工作区显示当前模板分类树。
|
||||
12. 导出:后端支持 COCO JSON 和 PNG mask ZIP 导出;PNG ZIP 包含单标注 mask、按 zIndex 融合的语义 mask 和 `semantic_classes.json`。
|
||||
|
||||
---
|
||||
|
||||
@@ -240,6 +242,7 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
- 前端 `importGtMask()` 已对齐后端 `/api/ai/import-gt-mask`;工作区“导入 GT Mask”会导入后端生成的多类别标注和 seed point 并回显。
|
||||
- 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;前端 `exportMasks()` 已对齐后端 `/api/export/{project_id}/masks`;工作区导出按钮会先保存当前待归档 mask。
|
||||
- 工作区“结构化归档保存”按钮已接入 `POST /api/ai/annotate` 和 `PATCH /api/ai/annotations/{id}`;加载工作区时会通过 `GET /api/ai/annotations` 回显已保存标注。
|
||||
- 工作区“传播片段”按钮已接入 `POST /api/ai/propagate`;SAM 2 路径使用视频 predictor,SAM 3 路径使用独立 Python helper 的官方 video tracker,完成后刷新后端已保存标注。
|
||||
- 工作区“清空遮罩”会调用 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。
|
||||
- 项目状态已统一为 `pending`、`parsing`、`ready`、`error`;前端 `src/lib/api.ts` 会兼容归一化旧库中可能存在的 `Ready`、`Parsing`、`Error`。
|
||||
- `server.ts` 仍有旧版 `/api/login`、`/api/projects`、`/api/templates` mock;当前前端真实 API 调用主要走 FastAPI 的 `/api/auth/*`、`/api/projects`、`/api/templates` 等接口。
|
||||
|
||||
20
README.md
20
README.md
@@ -6,14 +6,14 @@
|
||||
|
||||
> 基于 React + FastAPI + 可选 SAM 2 / SAM 3 的全栈交互式图像/视频语义分割与标注平台。
|
||||
>
|
||||
> 支持本地多媒体资产上传、服务器端按帧解析、交互式 Canvas 标注、GT mask 导入、模板分类管理和标注数据结构化导出;工作区点/框 AI 推理默认走 SAM 2,语义文本可选择 SAM 3,前端会显示真实 GPU/模型状态。
|
||||
> 支持本地多媒体资产上传、服务器端按帧解析、交互式 Canvas 标注、视频片段传播、GT mask 导入、模板分类管理和标注数据结构化导出;工作区点/框 AI 推理默认走 SAM 2,SAM 3 支持语义文本、框选提示和 video tracker,前端会显示真实 GPU/模型状态。
|
||||
|
||||
---
|
||||
|
||||
## 核心功能
|
||||
|
||||
- **多媒体资产管理** — 支持视频(MP4/AVI/MOV)和 DICOM 医学影像的上传、存储与解析
|
||||
- **AI 智能分割引擎** — 后端提供 SAM 2 / SAM 3 模型选择;SAM 2 支持点分割(point)、框分割(box)和自动分割(auto),SAM 3 入口支持文本语义提示并按真实运行环境显示可用性
|
||||
- **AI 智能分割引擎** — 后端提供 SAM 2 / SAM 3 模型选择;SAM 2 支持点分割(point)、框分割(box)、自动分割(auto)和 video predictor 传播,SAM 3 入口支持文本语义提示、框选提示和 external video tracker,并按真实运行环境显示可用性
|
||||
- **交互式画布标注** — 基于 Konva 的高性能 Canvas,支持缩放/平移/手工多边形/矩形/圆/点/线、polygon 顶点拖动/删除、区域合并/去除、选点/框选、撤销/重做,实时渲染 Mask 遮罩
|
||||
- **GT Mask 导入** — 工作区可导入 GT mask 图片,后端按非零像素值和连通域生成 polygon 标注并用 distance transform 写入 seed point;前端可回显和拖动 seed point
|
||||
- **本体字典管理** — 可配置的分类体系、颜色映射、图层优先级(z-index)
|
||||
@@ -104,8 +104,8 @@ Seg_Server/
|
||||
│ │ ├── ai.py # SAM 推理与模型状态接口
|
||||
│ │ └── export.py # 数据导出
|
||||
│ └── services/ # 业务服务
|
||||
│ ├── sam2_engine.py # SAM 2 推理引擎(懒加载 + stub降级)
|
||||
│ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接与文本语义推理适配器
|
||||
│ ├── sam2_engine.py # SAM 2 推理引擎(单帧推理 + video predictor 传播)
|
||||
│ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接、文本语义推理、框选与 video tracker 适配器
|
||||
│ ├── sam3_external_worker.py # 独立 sam3 conda 环境中执行的状态/推理 helper
|
||||
│ ├── sam_registry.py # SAM 模型选择、GPU 状态与推理分发
|
||||
│ └── frame_parser.py # FFmpeg 拆帧 / pydicom 读片
|
||||
@@ -255,12 +255,11 @@ python download_sam2.py
|
||||
cd ~/Desktop/Seg_Server
|
||||
./backend/setup_sam3_env.sh
|
||||
|
||||
# 首次使用官方权重前,需要先在 Hugging Face 申请 facebook/sam3 访问权限并登录
|
||||
conda activate sam3
|
||||
huggingface-cli login
|
||||
# 如果已把权重放在 sam3权重/sam3.pt,可直接走本地 checkpoint;
|
||||
# 未配置本地 checkpoint 时,才需要 Hugging Face gated repo 授权和登录。
|
||||
```
|
||||
|
||||
官方 `facebook/sam3` 权重约 3.45 GB,当前没有类似 SAM 2 `tiny/small/base/large` 的官方小权重梯度;`facebook/sam3.1` 约 3.5 GB,主要面向新的视频 multiplex checkpoint。未获得 gated model 授权时,`GET /api/ai/models/status` 会把 SAM 3 标为不可用并说明 checkpoint access 不满足。
|
||||
官方 `facebook/sam3` 权重约 3.45 GB,当前没有类似 SAM 2 `tiny/small/base/large` 的官方小权重梯度。当前仓库默认使用本机 `sam3权重/sam3.pt`,不会提交权重文件;未配置本地 checkpoint 且未获得 gated model 授权时,`GET /api/ai/models/status` 会把 SAM 3 标为不可用并说明 checkpoint access 不满足。
|
||||
|
||||
### 步骤 6: 配置环境变量
|
||||
|
||||
@@ -278,6 +277,7 @@ sam_model_path=/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt
|
||||
sam_model_config=configs/sam2/sam2_hiera_t.yaml
|
||||
sam_default_model=sam2
|
||||
sam3_model_version=sam3
|
||||
sam3_checkpoint_path=/home/wkmgc/Desktop/Seg_Server/sam3权重/sam3.pt
|
||||
sam3_external_enabled=true
|
||||
sam3_external_python=/home/wkmgc/miniconda3/envs/sam3/bin/python
|
||||
sam3_timeout_seconds=300
|
||||
@@ -311,6 +311,7 @@ nohup uvicorn main:app --host 0.0.0.0 --port 8000 > /tmp/fastapi.log 2>&1 &
|
||||
- 测试 Redis 连接
|
||||
- 懒加载 SAM 模型;`GET /api/ai/models/status` 会返回 SAM 2、SAM 3、GPU 和 SAM 3 checkpoint access 的真实可用状态
|
||||
- `/api/ai/predict` 支持 AI 参数 `crop_to_prompt`、`auto_filter_background` 和 `min_score`,用于点/框 prompt 的局部裁剪推理、回映射和背景过滤
|
||||
- `/api/ai/propagate` 支持从当前帧 seed 区域向视频片段传播:SAM 2 使用 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`,SAM 3 通过独立 Python 3.12 helper 调用官方 `build_sam3_video_predictor()` video tracker
|
||||
|
||||
### 步骤 6.1: 启动 Celery Worker
|
||||
|
||||
@@ -324,7 +325,7 @@ celery -A celery_app:celery_app worker --loglevel=info --concurrency=1
|
||||
nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 &
|
||||
```
|
||||
|
||||
`POST /api/media/parse` 只创建 `processing_tasks` 记录并把任务投递给 Celery;真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道,FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 也可调用 `/api/tasks/{id}/cancel`、`/api/tasks/{id}/retry` 和 `/api/tasks/{id}` 完成任务取消、重试与失败详情查看。
|
||||
`POST /api/media/parse` 只创建 `processing_tasks` 记录并把任务投递给 Celery;真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。接口支持 `parse_fps`、`max_frames` 和 `target_width`,用于生成后续 SAM 2 / SAM 3 视频处理可复用的标准帧序列;视频帧按 `frame_%06d.jpg` 连续命名,帧表会记录 `timestamp_ms` 和 `source_frame_number`,任务完成结果会返回 `frame_sequence` 元数据。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道,FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 也可调用 `/api/tasks/{id}/cancel`、`/api/tasks/{id}/retry` 和 `/api/tasks/{id}` 完成任务取消、重试与失败详情查看。
|
||||
|
||||
### 步骤 7: 安装前端依赖并构建
|
||||
|
||||
@@ -460,6 +461,7 @@ pip install -e . --no-build-isolation
|
||||
|
||||
- 前端 `predictMask()` 已发送后端需要的 `image_id`、`prompt_type`、`prompt_data`,并把后端 `polygons` 转成 Konva `pathData`。
|
||||
- 工作区点选/框选会使用当前帧的数据库 `frame.id` 调用 `/api/ai/predict`。
|
||||
- 工作区“传播片段”会使用当前选中区域或当前帧第一个区域作为 seed,调用 `/api/ai/propagate`,并在完成后刷新已保存标注。
|
||||
- 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco`。
|
||||
- 工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮已绑定下载流程;导出前会先保存当前待归档的前端 mask。
|
||||
- 工作区“导入 GT Mask”按钮已绑定 `/api/ai/import-gt-mask`,导入后会刷新并回显已保存标注和 seed point。
|
||||
|
||||
@@ -23,6 +23,7 @@ class Settings(BaseSettings):
|
||||
sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
|
||||
sam_model_config: str = "configs/sam2/sam2_hiera_t.yaml"
|
||||
sam3_model_version: str = "sam3"
|
||||
sam3_checkpoint_path: str = "/home/wkmgc/Desktop/Seg_Server/sam3权重/sam3.pt"
|
||||
sam3_external_enabled: bool = True
|
||||
sam3_external_python: str = "/home/wkmgc/miniconda3/envs/sam3/bin/python"
|
||||
sam3_timeout_seconds: int = 300
|
||||
|
||||
@@ -11,6 +11,7 @@ from datetime import datetime, timezone
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sqlalchemy import inspect, text
|
||||
|
||||
from config import settings
|
||||
from database import Base, engine, SessionLocal
|
||||
@@ -30,6 +31,20 @@ logger = logging.getLogger(__name__)
|
||||
DEFAULT_VIDEO_PATH = "/home/wkmgc/Desktop/Seg_Server/Data_MyVideo_1.mp4"
|
||||
|
||||
|
||||
def _ensure_runtime_schema_columns() -> None:
|
||||
"""Add nullable columns introduced after initial create_all deployments."""
|
||||
try:
|
||||
inspector = inspect(engine)
|
||||
frame_columns = {column["name"] for column in inspector.get_columns("frames")}
|
||||
with engine.begin() as connection:
|
||||
if "timestamp_ms" not in frame_columns:
|
||||
connection.execute(text("ALTER TABLE frames ADD COLUMN timestamp_ms FLOAT"))
|
||||
if "source_frame_number" not in frame_columns:
|
||||
connection.execute(text("ALTER TABLE frames ADD COLUMN source_frame_number INTEGER"))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Runtime schema column check failed: %s", exc)
|
||||
|
||||
|
||||
def _seed_default_project_sync() -> None:
|
||||
"""Synchronously seed the default video project on first startup."""
|
||||
import cv2
|
||||
@@ -93,12 +108,16 @@ def _seed_default_project_sync() -> None:
|
||||
for idx, obj_name in enumerate(object_names):
|
||||
img = cv2.imread(frame_files[idx])
|
||||
h, w = img.shape[:2] if img is not None else (None, None)
|
||||
timestamp_ms = idx * 1000.0 / 30.0
|
||||
source_frame_number = int(round(idx * original_fps / 30.0)) if original_fps else None
|
||||
frame = Frame(
|
||||
project_id=project.id,
|
||||
frame_index=idx,
|
||||
image_url=obj_name,
|
||||
width=w,
|
||||
height=h,
|
||||
timestamp_ms=timestamp_ms,
|
||||
source_frame_number=source_frame_number,
|
||||
)
|
||||
db.add(frame)
|
||||
|
||||
@@ -176,6 +195,7 @@ async def lifespan(app: FastAPI):
|
||||
# Initialize database tables
|
||||
try:
|
||||
Base.metadata.create_all(bind=engine)
|
||||
_ensure_runtime_schema_columns()
|
||||
logger.info("Database tables initialized.")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Database initialization failed: %s", exc)
|
||||
|
||||
@@ -56,6 +56,8 @@ class Frame(Base):
|
||||
image_url = Column(String(512), nullable=False)
|
||||
width = Column(Integer, nullable=True)
|
||||
height = Column(Integer, nullable=True)
|
||||
timestamp_ms = Column(Float, nullable=True)
|
||||
source_frame_number = Column(Integer, nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
project = relationship("Project", back_populates="frames")
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""AI inference endpoints using selectable SAM runtimes."""
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, List
|
||||
|
||||
import cv2
|
||||
@@ -15,6 +17,8 @@ from schemas import (
|
||||
AiRuntimeStatus,
|
||||
PredictRequest,
|
||||
PredictResponse,
|
||||
PropagateRequest,
|
||||
PropagateResponse,
|
||||
AnnotationOut,
|
||||
AnnotationCreate,
|
||||
AnnotationUpdate,
|
||||
@@ -66,6 +70,48 @@ def _contour_bbox(contour: np.ndarray, width: int, height: int) -> list[float]:
|
||||
]
|
||||
|
||||
|
||||
def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
|
||||
xs = [_clamp01(point[0]) for point in polygon]
|
||||
ys = [_clamp01(point[1]) for point in polygon]
|
||||
left, right = min(xs), max(xs)
|
||||
top, bottom = min(ys), max(ys)
|
||||
return [left, top, max(right - left, 0.0), max(bottom - top, 0.0)]
|
||||
|
||||
|
||||
def _frame_window(
|
||||
frames: list[Frame],
|
||||
source_position: int,
|
||||
direction: str,
|
||||
max_frames: int,
|
||||
) -> tuple[list[Frame], int]:
|
||||
count = max(1, min(max_frames, len(frames)))
|
||||
if direction == "backward":
|
||||
start = max(0, source_position - count + 1)
|
||||
return frames[start:source_position + 1], source_position - start
|
||||
if direction == "both":
|
||||
before = (count - 1) // 2
|
||||
after = count - 1 - before
|
||||
start = max(0, source_position - before)
|
||||
end = min(len(frames), source_position + after + 1)
|
||||
while end - start < count and start > 0:
|
||||
start -= 1
|
||||
while end - start < count and end < len(frames):
|
||||
end += 1
|
||||
return frames[start:end], source_position - start
|
||||
end = min(len(frames), source_position + count)
|
||||
return frames[source_position:end], 0
|
||||
|
||||
|
||||
def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]:
|
||||
paths = []
|
||||
for index, frame in enumerate(frames):
|
||||
data = download_file(frame.image_url)
|
||||
path = directory / f"frame_{index:06d}.jpg"
|
||||
path.write_bytes(data)
|
||||
paths.append(str(path))
|
||||
return paths
|
||||
|
||||
|
||||
def _component_seed_point(component_mask: np.ndarray, width: int, height: int) -> list[float]:
|
||||
"""Reduce a binary component to one positive prompt point using distance transform."""
|
||||
dist = cv2.distanceTransform(component_mask.astype(np.uint8), cv2.DIST_L2, 5)
|
||||
@@ -184,6 +230,7 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
- **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
|
||||
coordinates or `{ "points": [[x, y], ...], "labels": [1, 0, ...] }`.
|
||||
- **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
|
||||
- **interactive**: `prompt_data` is `{ "box": [...], "points": [[x, y]], "labels": [1, 0] }`.
|
||||
- **semantic**: SAM 3 text prompt when model=`sam3`; SAM 2 falls back to auto.
|
||||
"""
|
||||
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
|
||||
@@ -246,6 +293,51 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
if crop_bounds:
|
||||
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
|
||||
|
||||
elif prompt_type == "interactive":
|
||||
prompt = payload.prompt_data
|
||||
if not isinstance(prompt, dict):
|
||||
raise HTTPException(status_code=400, detail="Invalid interactive prompt data")
|
||||
box = prompt.get("box")
|
||||
points = prompt.get("points") or []
|
||||
labels = prompt.get("labels")
|
||||
if box is not None and (not isinstance(box, list) or len(box) != 4):
|
||||
raise HTTPException(status_code=400, detail="Invalid interactive box prompt data")
|
||||
if not isinstance(points, list):
|
||||
raise HTTPException(status_code=400, detail="Invalid interactive point prompt data")
|
||||
if not box and len(points) == 0:
|
||||
raise HTTPException(status_code=400, detail="Interactive prompt requires a box or points")
|
||||
if not isinstance(labels, list) or len(labels) != len(points):
|
||||
labels = [1] * len(points)
|
||||
negative_points = [
|
||||
point for point, label in zip(points, labels) if label == 0
|
||||
]
|
||||
inference_image = image
|
||||
inference_box = box
|
||||
inference_points = points
|
||||
crop_bounds = None
|
||||
if options.get("crop_to_prompt"):
|
||||
margin = float(options.get("crop_margin", 0.05) or 0.05)
|
||||
crop_points = list(points)
|
||||
if box:
|
||||
crop_points.extend([[box[0], box[1]], [box[2], box[3]]])
|
||||
crop_bounds = _crop_bounds_from_points(crop_points, margin)
|
||||
inference_image = _crop_image(image, crop_bounds)
|
||||
inference_points = [_to_crop_point(point, crop_bounds) for point in points]
|
||||
if box:
|
||||
inference_box = [
|
||||
*_to_crop_point([box[0], box[1]], crop_bounds),
|
||||
*_to_crop_point([box[2], box[3]], crop_bounds),
|
||||
]
|
||||
polygons, scores = sam_registry.predict_interactive(
|
||||
payload.model,
|
||||
inference_image,
|
||||
inference_box,
|
||||
inference_points,
|
||||
labels,
|
||||
)
|
||||
if crop_bounds:
|
||||
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
|
||||
|
||||
elif prompt_type == "semantic":
|
||||
text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
|
||||
polygons, scores = sam_registry.predict_semantic(payload.model, image, text)
|
||||
@@ -276,6 +368,124 @@ def model_status(selected_model: str | None = None) -> dict:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/propagate",
|
||||
response_model=PropagateResponse,
|
||||
summary="Propagate one current-frame region across a video frame segment",
|
||||
)
|
||||
def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
|
||||
"""Track one selected region from the current frame across nearby frames.
|
||||
|
||||
SAM 2 uses the official video predictor with the selected mask as the seed.
|
||||
SAM 3 uses the external Python 3.12 video tracker with the seed bbox.
|
||||
"""
|
||||
direction = payload.direction.lower()
|
||||
if direction not in {"forward", "backward", "both"}:
|
||||
raise HTTPException(status_code=400, detail="direction must be forward, backward, or both")
|
||||
max_frames = max(1, min(int(payload.max_frames or 30), 500))
|
||||
|
||||
project = db.query(Project).filter(Project.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
source_frame = db.query(Frame).filter(
|
||||
Frame.id == payload.frame_id,
|
||||
Frame.project_id == payload.project_id,
|
||||
).first()
|
||||
if not source_frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
|
||||
seed = payload.seed.model_dump(exclude_none=True)
|
||||
polygons = seed.get("polygons") or []
|
||||
bbox = seed.get("bbox")
|
||||
points = seed.get("points") or []
|
||||
if not polygons and not bbox and not points:
|
||||
raise HTTPException(status_code=400, detail="Propagation requires seed polygons, bbox, or points")
|
||||
|
||||
frames = db.query(Frame).filter(Frame.project_id == payload.project_id).order_by(Frame.frame_index).all()
|
||||
source_position = next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None)
|
||||
if source_position is None:
|
||||
raise HTTPException(status_code=404, detail="Source frame is not in project frame sequence")
|
||||
|
||||
selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames)
|
||||
if len(selected_frames) == 0:
|
||||
raise HTTPException(status_code=400, detail="No frames available for propagation")
|
||||
|
||||
try:
|
||||
with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload.project_id}_") as tmpdir:
|
||||
frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir))
|
||||
propagated = sam_registry.propagate_video(
|
||||
payload.model,
|
||||
frame_paths,
|
||||
source_relative_index,
|
||||
seed,
|
||||
direction,
|
||||
len(selected_frames),
|
||||
)
|
||||
except ModelUnavailableError as exc:
|
||||
raise HTTPException(status_code=503, detail=str(exc)) from exc
|
||||
except NotImplementedError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Video propagation failed: %s", exc)
|
||||
raise HTTPException(status_code=500, detail=f"Video propagation failed: {exc}") from exc
|
||||
|
||||
created: list[Annotation] = []
|
||||
if payload.save_annotations:
|
||||
class_metadata = seed.get("class_metadata")
|
||||
template_id = seed.get("template_id")
|
||||
label = seed.get("label") or "Propagated Mask"
|
||||
color = seed.get("color") or "#06b6d4"
|
||||
model_id = sam_registry.normalize_model_id(payload.model)
|
||||
|
||||
for frame_result in propagated:
|
||||
relative_index = int(frame_result.get("frame_index", -1))
|
||||
if relative_index < 0 or relative_index >= len(selected_frames):
|
||||
continue
|
||||
frame = selected_frames[relative_index]
|
||||
if not payload.include_source and frame.id == source_frame.id:
|
||||
continue
|
||||
result_polygons = frame_result.get("polygons") or []
|
||||
scores = frame_result.get("scores") or []
|
||||
for polygon_index, polygon in enumerate(result_polygons):
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
annotation = Annotation(
|
||||
project_id=payload.project_id,
|
||||
frame_id=frame.id,
|
||||
template_id=template_id,
|
||||
mask_data={
|
||||
"polygons": [polygon],
|
||||
"label": label,
|
||||
"color": color,
|
||||
"source": f"{model_id}_propagation",
|
||||
"propagated_from_frame_id": source_frame.id,
|
||||
"propagated_from_frame_index": source_frame.frame_index,
|
||||
"score": scores[polygon_index] if polygon_index < len(scores) else None,
|
||||
**({"class": class_metadata} if class_metadata else {}),
|
||||
},
|
||||
points=None,
|
||||
bbox=_polygon_bbox(polygon),
|
||||
)
|
||||
db.add(annotation)
|
||||
created.append(annotation)
|
||||
|
||||
db.commit()
|
||||
for annotation in created:
|
||||
db.refresh(annotation)
|
||||
|
||||
return {
|
||||
"model": sam_registry.normalize_model_id(payload.model),
|
||||
"direction": direction,
|
||||
"source_frame_id": source_frame.id,
|
||||
"processed_frame_count": len(selected_frames),
|
||||
"created_annotation_count": len(created),
|
||||
"annotations": created,
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/auto",
|
||||
response_model=PredictResponse,
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
@@ -169,6 +169,9 @@ async def upload_dicom_batch(
|
||||
def parse_media(
|
||||
project_id: int,
|
||||
source_type: Optional[str] = None,
|
||||
parse_fps: Optional[float] = Query(None, gt=0, le=120),
|
||||
max_frames: Optional[int] = Query(None, gt=0),
|
||||
target_width: int = Query(640, ge=64, le=4096),
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProcessingTask:
|
||||
"""Create a background task for media frame extraction.
|
||||
@@ -184,14 +187,21 @@ def parse_media(
|
||||
raise HTTPException(status_code=400, detail="Project has no media uploaded")
|
||||
|
||||
effective_source = source_type or project.source_type or "video"
|
||||
effective_parse_fps = parse_fps or project.parse_fps or 30.0
|
||||
task = ProcessingTask(
|
||||
task_type=f"parse_{effective_source}",
|
||||
status=TASK_STATUS_QUEUED,
|
||||
progress=0,
|
||||
message="解析任务已入队",
|
||||
project_id=project_id,
|
||||
payload={"source_type": effective_source},
|
||||
payload={
|
||||
"source_type": effective_source,
|
||||
"parse_fps": effective_parse_fps,
|
||||
"max_frames": max_frames,
|
||||
"target_width": target_width,
|
||||
},
|
||||
)
|
||||
project.parse_fps = effective_parse_fps
|
||||
project.status = PROJECT_STATUS_PARSING
|
||||
db.add(task)
|
||||
db.commit()
|
||||
|
||||
@@ -51,6 +51,8 @@ class FrameBase(BaseModel):
|
||||
image_url: str
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
timestamp_ms: Optional[float] = None
|
||||
source_frame_number: Optional[int] = None
|
||||
|
||||
|
||||
class FrameCreate(FrameBase):
|
||||
@@ -188,6 +190,37 @@ class PredictResponse(BaseModel):
|
||||
scores: Optional[list[float]] = None
|
||||
|
||||
|
||||
class PropagationSeed(BaseModel):
|
||||
polygons: Optional[list[list[list[float]]]] = None
|
||||
bbox: Optional[list[float]] = None
|
||||
points: Optional[list[list[float]]] = None
|
||||
labels: Optional[list[int]] = None
|
||||
label: Optional[str] = None
|
||||
color: Optional[str] = None
|
||||
class_metadata: Optional[dict[str, Any]] = None
|
||||
template_id: Optional[int] = None
|
||||
|
||||
|
||||
class PropagateRequest(BaseModel):
|
||||
project_id: int
|
||||
frame_id: int
|
||||
model: Optional[str] = "sam2"
|
||||
seed: PropagationSeed
|
||||
direction: str = "forward"
|
||||
max_frames: int = 30
|
||||
include_source: bool = False
|
||||
save_annotations: bool = True
|
||||
|
||||
|
||||
class PropagateResponse(BaseModel):
|
||||
model: str
|
||||
direction: str
|
||||
source_frame_id: int
|
||||
processed_frame_count: int
|
||||
created_annotation_count: int
|
||||
annotations: list[AnnotationOut]
|
||||
|
||||
|
||||
class AiModelStatus(BaseModel):
|
||||
id: str
|
||||
label: str
|
||||
|
||||
@@ -52,6 +52,7 @@ def parse_video(
|
||||
output_dir: str,
|
||||
fps: int = 30,
|
||||
max_frames: Optional[int] = None,
|
||||
target_width: int = 640,
|
||||
) -> Tuple[List[str], float]:
|
||||
"""Extract frames from a video file using FFmpeg or OpenCV fallback.
|
||||
|
||||
@@ -60,6 +61,7 @@ def parse_video(
|
||||
output_dir: Directory to save extracted frames.
|
||||
fps: Target frame extraction rate.
|
||||
max_frames: Optional maximum number of frames to extract.
|
||||
target_width: Output frame width for model-friendly frame sequences.
|
||||
|
||||
Returns:
|
||||
Tuple of (frame_paths, original_fps).
|
||||
@@ -67,6 +69,8 @@ def parse_video(
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
frame_paths: List[str] = []
|
||||
original_fps = get_video_fps(video_path)
|
||||
safe_fps = max(int(fps), 1)
|
||||
safe_width = max(int(target_width), 1)
|
||||
|
||||
# Try FFmpeg first
|
||||
if shutil.which("ffmpeg"):
|
||||
@@ -75,7 +79,8 @@ def parse_video(
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-i", video_path,
|
||||
"-vf", f"fps={fps},scale=640:-1",
|
||||
"-vf", f"fps={safe_fps},scale={safe_width}:-1",
|
||||
"-start_number", "0",
|
||||
"-q:v", "5",
|
||||
"-y",
|
||||
pattern,
|
||||
@@ -102,7 +107,7 @@ def parse_video(
|
||||
raise RuntimeError(f"Cannot open video: {video_path}")
|
||||
|
||||
video_fps = cap.get(cv2.CAP_PROP_FPS) or 30
|
||||
interval = max(1, int(round(video_fps / fps)))
|
||||
interval = max(1, int(round(video_fps / safe_fps)))
|
||||
count = 0
|
||||
saved = 0
|
||||
|
||||
@@ -112,6 +117,10 @@ def parse_video(
|
||||
break
|
||||
if count % interval == 0:
|
||||
path = os.path.join(output_dir, f"frame_{saved:06d}.jpg")
|
||||
h, w = frame.shape[:2]
|
||||
if safe_width > 0 and w != safe_width:
|
||||
scale = safe_width / max(w, 1)
|
||||
frame = cv2.resize(frame, (safe_width, max(1, int(round(h * scale)))), interpolation=cv2.INTER_AREA)
|
||||
cv2.imwrite(path, frame, [cv2.IMWRITE_JPEG_QUALITY, 80])
|
||||
frame_paths.append(path)
|
||||
saved += 1
|
||||
|
||||
@@ -76,6 +76,38 @@ def _project_status_after_stop(project: Project) -> str:
|
||||
return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING
|
||||
|
||||
|
||||
def _positive_int(value: Any, default: int | None = None) -> int | None:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
return parsed if parsed > 0 else default
|
||||
|
||||
|
||||
def _positive_float(value: Any, default: float) -> float:
|
||||
try:
|
||||
parsed = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
return parsed if parsed > 0 else default
|
||||
|
||||
|
||||
def _frame_sequence_metadata(
|
||||
index: int,
|
||||
parse_fps: float,
|
||||
original_fps: float | None,
|
||||
) -> dict[str, float | int | None]:
|
||||
safe_parse_fps = max(float(parse_fps or 1.0), 1e-6)
|
||||
timestamp_ms = index * 1000.0 / safe_parse_fps
|
||||
source_frame_number = None
|
||||
if original_fps and original_fps > 0:
|
||||
source_frame_number = int(round(index * original_fps / safe_parse_fps))
|
||||
return {
|
||||
"timestamp_ms": timestamp_ms,
|
||||
"source_frame_number": source_frame_number,
|
||||
}
|
||||
|
||||
|
||||
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
|
||||
db.refresh(task)
|
||||
if task.status == TASK_STATUS_CANCELLED:
|
||||
@@ -138,8 +170,12 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
project.status = PROJECT_STATUS_PARSING
|
||||
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True)
|
||||
|
||||
effective_source = (task.payload or {}).get("source_type") or project.source_type or "video"
|
||||
parse_fps = project.parse_fps or 30.0
|
||||
payload = task.payload or {}
|
||||
effective_source = payload.get("source_type") or project.source_type or "video"
|
||||
parse_fps = _positive_float(payload.get("parse_fps"), project.parse_fps or 30.0)
|
||||
max_frames = _positive_int(payload.get("max_frames"))
|
||||
target_width = _positive_int(payload.get("target_width"), 640) or 640
|
||||
project.parse_fps = parse_fps
|
||||
tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project.id}_")
|
||||
output_dir = os.path.join(tmp_dir, "frames")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
@@ -163,7 +199,7 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
|
||||
_ensure_not_cancelled(db, task)
|
||||
_set_task_state(db, task, progress=35, message="正在解析 DICOM 序列")
|
||||
frame_files = parse_dicom(dcm_dir, output_dir)
|
||||
frame_files = parse_dicom(dcm_dir, output_dir, max_frames=max_frames)
|
||||
else:
|
||||
_ensure_not_cancelled(db, task)
|
||||
media_bytes = download_file(project.video_path)
|
||||
@@ -173,7 +209,13 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
|
||||
_ensure_not_cancelled(db, task)
|
||||
_set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧")
|
||||
frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
|
||||
frame_files, original_fps = parse_video(
|
||||
local_path,
|
||||
output_dir,
|
||||
fps=int(parse_fps),
|
||||
max_frames=max_frames,
|
||||
target_width=target_width,
|
||||
)
|
||||
project.original_fps = original_fps
|
||||
|
||||
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
|
||||
@@ -205,12 +247,15 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
except Exception: # noqa: BLE001
|
||||
h, w = None, None
|
||||
|
||||
sequence_meta = _frame_sequence_metadata(idx, parse_fps, project.original_fps)
|
||||
frame = Frame(
|
||||
project_id=project.id,
|
||||
frame_index=idx,
|
||||
image_url=obj_name,
|
||||
width=w,
|
||||
height=h,
|
||||
timestamp_ms=sequence_meta["timestamp_ms"],
|
||||
source_frame_number=sequence_meta["source_frame_number"],
|
||||
)
|
||||
db.add(frame)
|
||||
frames_out.append(frame)
|
||||
@@ -223,6 +268,17 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
"frames_extracted": len(frames_out),
|
||||
"status": PROJECT_STATUS_READY,
|
||||
"message": "Frame extraction completed successfully.",
|
||||
"frame_sequence": {
|
||||
"original_fps": project.original_fps,
|
||||
"parse_fps": parse_fps,
|
||||
"frame_count": len(frames_out),
|
||||
"duration_ms": (len(frames_out) - 1) * 1000.0 / parse_fps if frames_out else 0,
|
||||
"target_width": target_width,
|
||||
"frame_width": frames_out[0].width if frames_out else None,
|
||||
"frame_height": frames_out[0].height if frames_out else None,
|
||||
"max_frames": max_frames,
|
||||
"object_prefix": f"projects/{project.id}/frames",
|
||||
},
|
||||
}
|
||||
_set_task_state(
|
||||
db,
|
||||
|
||||
@@ -24,6 +24,7 @@ except Exception as exc: # noqa: BLE001
|
||||
|
||||
try:
|
||||
from sam2.build_sam import build_sam2
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
|
||||
SAM2_AVAILABLE = True
|
||||
@@ -38,9 +39,12 @@ class SAM2Engine:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._predictor: Optional[SAM2ImagePredictor] = None
|
||||
self._video_predictor = None
|
||||
self._model_loaded = False
|
||||
self._video_model_loaded = False
|
||||
self._loaded_device: str | None = None
|
||||
self._last_error: str | None = None
|
||||
self._video_last_error: str | None = None
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
@@ -85,6 +89,40 @@ class SAM2Engine:
|
||||
logger.error("Failed to load SAM 2 model: %s", exc)
|
||||
self._model_loaded = True # Prevent repeated load attempts
|
||||
|
||||
def _load_video_model(self) -> None:
|
||||
"""Load the SAM 2 video predictor on first propagation use."""
|
||||
if self._video_model_loaded:
|
||||
return
|
||||
|
||||
if not TORCH_AVAILABLE:
|
||||
self._video_last_error = "PyTorch is not installed."
|
||||
self._video_model_loaded = True
|
||||
return
|
||||
if not SAM2_AVAILABLE:
|
||||
self._video_last_error = "sam2 package is not installed."
|
||||
self._video_model_loaded = True
|
||||
return
|
||||
if not os.path.isfile(settings.sam_model_path):
|
||||
self._video_last_error = f"SAM2 checkpoint not found: {settings.sam_model_path}"
|
||||
self._video_model_loaded = True
|
||||
return
|
||||
|
||||
try:
|
||||
device = self._best_device()
|
||||
self._video_predictor = build_sam2_video_predictor(
|
||||
settings.sam_model_config,
|
||||
settings.sam_model_path,
|
||||
device=device,
|
||||
)
|
||||
self._video_model_loaded = True
|
||||
self._loaded_device = device
|
||||
self._video_last_error = None
|
||||
logger.info("SAM 2 video predictor loaded from %s on %s", settings.sam_model_path, device)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self._video_last_error = str(exc)
|
||||
self._video_model_loaded = True
|
||||
logger.error("Failed to load SAM 2 video predictor: %s", exc)
|
||||
|
||||
def _best_device(self) -> str:
|
||||
if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available():
|
||||
return "cuda"
|
||||
@@ -95,6 +133,11 @@ class SAM2Engine:
|
||||
self._load_model()
|
||||
return SAM2_AVAILABLE and self._predictor is not None
|
||||
|
||||
def _ensure_video_ready(self) -> bool:
|
||||
"""Ensure the video predictor is loaded; return whether it is usable."""
|
||||
self._load_video_model()
|
||||
return SAM2_AVAILABLE and self._video_predictor is not None
|
||||
|
||||
def status(self) -> dict:
|
||||
"""Return lightweight, real runtime status without forcing model load."""
|
||||
checkpoint_exists = os.path.isfile(settings.sam_model_path)
|
||||
@@ -121,7 +164,7 @@ class SAM2Engine:
|
||||
"available": available,
|
||||
"loaded": self._predictor is not None,
|
||||
"device": device,
|
||||
"supports": ["point", "box", "auto"],
|
||||
"supports": ["point", "box", "interactive", "auto", "propagate"],
|
||||
"message": message,
|
||||
"package_available": SAM2_AVAILABLE,
|
||||
"checkpoint_exists": checkpoint_exists,
|
||||
@@ -221,6 +264,52 @@ class SAM2Engine:
|
||||
logger.error("SAM2 box prediction failed: %s", exc)
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
def predict_interactive(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
box: list[float] | None,
|
||||
points: list[list[float]],
|
||||
labels: list[int],
|
||||
) -> tuple[list[list[list[float]]], list[float]]:
|
||||
"""Run combined box and point prompt segmentation for refinement."""
|
||||
if not self._ensure_ready():
|
||||
logger.warning("SAM2 not ready; returning dummy masks.")
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
try:
|
||||
h, w = image.shape[:2]
|
||||
bbox = None
|
||||
if box:
|
||||
bbox = np.array(
|
||||
[box[0] * w, box[1] * h, box[2] * w, box[3] * h],
|
||||
dtype=np.float32,
|
||||
)
|
||||
pts = None
|
||||
lbls = None
|
||||
if points:
|
||||
pts = np.array([[p[0] * w, p[1] * h] for p in points], dtype=np.float32)
|
||||
lbls = np.array(labels, dtype=np.int32)
|
||||
|
||||
with torch.inference_mode(): # type: ignore[name-defined]
|
||||
self._predictor.set_image(image)
|
||||
masks, scores, _ = self._predictor.predict(
|
||||
point_coords=pts,
|
||||
point_labels=lbls,
|
||||
box=bbox,
|
||||
multimask_output=False,
|
||||
)
|
||||
|
||||
polygons = []
|
||||
for m in masks:
|
||||
poly = self._mask_to_polygon(m)
|
||||
if poly:
|
||||
polygons.append(poly)
|
||||
|
||||
return polygons, scores.tolist()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("SAM2 interactive prediction failed: %s", exc)
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
def predict_auto(self, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
|
||||
"""Run automatic mask generation (grid of points).
|
||||
|
||||
@@ -260,6 +349,89 @@ class SAM2Engine:
|
||||
logger.error("SAM2 auto prediction failed: %s", exc)
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
def propagate_video(
|
||||
self,
|
||||
frame_paths: list[str],
|
||||
source_frame_index: int,
|
||||
seed: dict,
|
||||
direction: str = "forward",
|
||||
max_frames: int | None = None,
|
||||
) -> list[dict]:
|
||||
"""Propagate one seed mask across a prepared frame directory with SAM 2 video."""
|
||||
if not self._ensure_video_ready():
|
||||
raise RuntimeError(self._video_last_error or self.status()["message"])
|
||||
if not frame_paths:
|
||||
return []
|
||||
if source_frame_index < 0 or source_frame_index >= len(frame_paths):
|
||||
raise ValueError("source_frame_index is outside the frame sequence.")
|
||||
|
||||
import cv2
|
||||
|
||||
source_image = cv2.imread(frame_paths[source_frame_index])
|
||||
if source_image is None:
|
||||
raise RuntimeError("Failed to decode source frame for SAM 2 propagation.")
|
||||
height, width = source_image.shape[:2]
|
||||
seed_mask = self._polygons_to_mask(seed.get("polygons") or [], width, height)
|
||||
if not seed_mask.any():
|
||||
bbox = seed.get("bbox")
|
||||
if isinstance(bbox, list) and len(bbox) == 4:
|
||||
seed_mask = self._bbox_to_mask(bbox, width, height)
|
||||
if not seed_mask.any():
|
||||
raise ValueError("SAM 2 propagation requires a non-empty seed polygon or bbox.")
|
||||
|
||||
inference_state = self._video_predictor.init_state(
|
||||
video_path=os.path.dirname(frame_paths[0]),
|
||||
offload_video_to_cpu=True,
|
||||
offload_state_to_cpu=True,
|
||||
)
|
||||
self._video_predictor.add_new_mask(
|
||||
inference_state,
|
||||
frame_idx=source_frame_index,
|
||||
obj_id=1,
|
||||
mask=seed_mask,
|
||||
)
|
||||
|
||||
results: dict[int, dict] = {}
|
||||
|
||||
def collect(reverse: bool) -> None:
|
||||
for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video(
|
||||
inference_state,
|
||||
start_frame_idx=source_frame_index,
|
||||
max_frame_num_to_track=max_frames,
|
||||
reverse=reverse,
|
||||
):
|
||||
masks = out_mask_logits
|
||||
if hasattr(masks, "detach"):
|
||||
masks = masks.detach().cpu().numpy()
|
||||
masks = np.asarray(masks)
|
||||
if masks.ndim == 4:
|
||||
masks = masks[:, 0]
|
||||
polygons = []
|
||||
scores = []
|
||||
for mask in masks:
|
||||
polygon = self._mask_to_polygon(mask > 0)
|
||||
if polygon:
|
||||
polygons.append(polygon)
|
||||
scores.append(1.0)
|
||||
results[int(out_frame_idx)] = {
|
||||
"frame_index": int(out_frame_idx),
|
||||
"polygons": polygons,
|
||||
"scores": scores,
|
||||
"object_ids": [int(obj_id) for obj_id in list(out_obj_ids)],
|
||||
}
|
||||
|
||||
normalized_direction = direction.lower()
|
||||
if normalized_direction in {"forward", "both"}:
|
||||
collect(reverse=False)
|
||||
if normalized_direction in {"backward", "both"}:
|
||||
collect(reverse=True)
|
||||
|
||||
try:
|
||||
self._video_predictor.reset_state(inference_state)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
return [results[index] for index in sorted(results)]
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Helpers
|
||||
# -----------------------------------------------------------------------
|
||||
@@ -292,6 +464,38 @@ class SAM2Engine:
|
||||
]
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _polygons_to_mask(polygons: list[list[list[float]]], width: int, height: int) -> np.ndarray:
|
||||
import cv2
|
||||
|
||||
mask = np.zeros((height, width), dtype=np.uint8)
|
||||
for polygon in polygons:
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
pts = np.array(
|
||||
[
|
||||
[
|
||||
int(round(min(max(float(x), 0.0), 1.0) * max(width - 1, 1))),
|
||||
int(round(min(max(float(y), 0.0), 1.0) * max(height - 1, 1))),
|
||||
]
|
||||
for x, y in polygon
|
||||
],
|
||||
dtype=np.int32,
|
||||
)
|
||||
cv2.fillPoly(mask, [pts], 1)
|
||||
return mask.astype(bool)
|
||||
|
||||
@staticmethod
|
||||
def _bbox_to_mask(bbox: list[float], width: int, height: int) -> np.ndarray:
|
||||
x, y, w, h = [min(max(float(value), 0.0), 1.0) for value in bbox]
|
||||
left = int(round(x * max(width - 1, 1)))
|
||||
top = int(round(y * max(height - 1, 1)))
|
||||
right = int(round(min(x + w, 1.0) * max(width - 1, 1)))
|
||||
bottom = int(round(min(y + h, 1.0) * max(height - 1, 1)))
|
||||
mask = np.zeros((height, width), dtype=bool)
|
||||
mask[top:max(bottom + 1, top + 1), left:max(right + 1, left + 1)] = True
|
||||
return mask
|
||||
|
||||
|
||||
# Singleton instance
|
||||
sam_engine = SAM2Engine()
|
||||
|
||||
@@ -56,8 +56,22 @@ class SAM3Engine:
|
||||
def _gpu_ok(self) -> bool:
|
||||
return bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available())
|
||||
|
||||
def _checkpoint_path(self) -> str | None:
|
||||
path = settings.sam3_checkpoint_path.strip()
|
||||
return path if path else None
|
||||
|
||||
def _checkpoint_exists(self) -> bool:
|
||||
path = self._checkpoint_path()
|
||||
return bool(path and os.path.isfile(path))
|
||||
|
||||
def _can_load(self) -> bool:
|
||||
return bool(SAM3_PACKAGE_AVAILABLE and TORCH_AVAILABLE and self._python_ok() and self._gpu_ok())
|
||||
return bool(
|
||||
SAM3_PACKAGE_AVAILABLE
|
||||
and TORCH_AVAILABLE
|
||||
and self._python_ok()
|
||||
and self._gpu_ok()
|
||||
and self._checkpoint_exists()
|
||||
)
|
||||
|
||||
def _worker_path(self) -> Path:
|
||||
return Path(__file__).with_name("sam3_external_worker.py")
|
||||
@@ -98,6 +112,8 @@ class SAM3Engine:
|
||||
try:
|
||||
env = os.environ.copy()
|
||||
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
|
||||
if self._checkpoint_path():
|
||||
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
|
||||
completed = subprocess.run(
|
||||
[settings.sam3_external_python, str(self._worker_path()), "--status"],
|
||||
capture_output=True,
|
||||
@@ -146,7 +162,10 @@ class SAM3Engine:
|
||||
from sam3.model.sam3_image_processor import Sam3Processor
|
||||
from sam3.model_builder import build_sam3_image_model
|
||||
|
||||
self._model = build_sam3_image_model()
|
||||
self._model = build_sam3_image_model(
|
||||
checkpoint_path=self._checkpoint_path(),
|
||||
load_from_HF=False,
|
||||
)
|
||||
self._processor = Sam3Processor(self._model)
|
||||
self._model_loaded = True
|
||||
self._last_error = None
|
||||
@@ -170,6 +189,8 @@ class SAM3Engine:
|
||||
missing.append("PyTorch")
|
||||
if not self._gpu_ok():
|
||||
missing.append("CUDA GPU")
|
||||
if not self._checkpoint_exists():
|
||||
missing.append(f"local checkpoint ({settings.sam3_checkpoint_path})")
|
||||
if missing:
|
||||
return f"SAM 3 unavailable: missing {', '.join(missing)}."
|
||||
return "SAM 3 dependencies are present; model will load on first inference."
|
||||
@@ -182,7 +203,7 @@ class SAM3Engine:
|
||||
if self._processor is not None:
|
||||
message = "SAM 3 model loaded and ready."
|
||||
elif external_ready:
|
||||
message = "SAM 3 external runtime is ready; model will load in the helper process on inference."
|
||||
message = "SAM 3 external runtime is ready; local checkpoint will load in the helper process on inference."
|
||||
elif external_status.get("message") and not self._can_load():
|
||||
message = str(external_status["message"])
|
||||
return {
|
||||
@@ -191,11 +212,11 @@ class SAM3Engine:
|
||||
"available": available,
|
||||
"loaded": self._processor is not None,
|
||||
"device": "cuda" if self._gpu_ok() else str(external_status.get("device", "unavailable")),
|
||||
"supports": ["semantic"],
|
||||
"supports": ["semantic", "box", "video_track"],
|
||||
"message": message,
|
||||
"package_available": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("package_available")),
|
||||
"checkpoint_exists": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("checkpoint_access")),
|
||||
"checkpoint_path": f"official/HuggingFace ({settings.sam3_model_version})",
|
||||
"checkpoint_exists": bool(self._checkpoint_exists() or external_status.get("checkpoint_access")),
|
||||
"checkpoint_path": self._checkpoint_path() or f"official/HuggingFace ({settings.sam3_model_version})",
|
||||
"python_ok": bool(self._python_ok() or external_status.get("python_ok")),
|
||||
"torch_ok": bool(TORCH_AVAILABLE or external_status.get("torch_ok")),
|
||||
"cuda_required": True,
|
||||
@@ -203,7 +224,43 @@ class SAM3Engine:
|
||||
"external_python": settings.sam3_external_python if settings.sam3_external_enabled else None,
|
||||
}
|
||||
|
||||
def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
|
||||
def _xyxy_to_cxcywh(self, box: list[float]) -> list[float]:
|
||||
if len(box) != 4:
|
||||
raise ValueError("SAM 3 box prompt requires [x1, y1, x2, y2].")
|
||||
x1, y1, x2, y2 = [min(max(float(value), 0.0), 1.0) for value in box]
|
||||
left, right = sorted([x1, x2])
|
||||
top, bottom = sorted([y1, y2])
|
||||
width = max(right - left, 1e-6)
|
||||
height = max(bottom - top, 1e-6)
|
||||
return [left + width / 2, top + height / 2, width, height]
|
||||
|
||||
def _prediction_to_polygons(self, output: Any) -> tuple[list[list[list[float]]], list[float]]:
|
||||
masks = output.get("masks", [])
|
||||
scores = output.get("scores", [])
|
||||
polygons = []
|
||||
for mask in masks:
|
||||
if hasattr(mask, "detach"):
|
||||
mask = mask.detach().cpu().numpy()
|
||||
if mask.ndim == 3:
|
||||
mask = mask[0]
|
||||
poly = SAM2Engine._mask_to_polygon(mask)
|
||||
if poly:
|
||||
polygons.append(poly)
|
||||
|
||||
if hasattr(scores, "detach"):
|
||||
scores = scores.detach().cpu().tolist()
|
||||
elif hasattr(scores, "tolist"):
|
||||
scores = scores.tolist()
|
||||
return polygons, list(scores)
|
||||
|
||||
def _predict_external(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
prompt_type: str,
|
||||
*,
|
||||
text: str = "",
|
||||
box: list[float] | None = None,
|
||||
) -> tuple[list[list[list[float]]], list[float]]:
|
||||
status = self._external_status(force=True)
|
||||
if not status.get("available"):
|
||||
raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")
|
||||
@@ -217,8 +274,11 @@ class SAM3Engine:
|
||||
json.dumps(
|
||||
{
|
||||
"image_path": str(image_path),
|
||||
"prompt_type": prompt_type,
|
||||
"text": text.strip(),
|
||||
"box": box,
|
||||
"model_version": settings.sam3_model_version,
|
||||
"checkpoint_path": self._checkpoint_path(),
|
||||
"confidence_threshold": settings.sam3_confidence_threshold,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
@@ -227,6 +287,8 @@ class SAM3Engine:
|
||||
)
|
||||
env = os.environ.copy()
|
||||
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
|
||||
if self._checkpoint_path():
|
||||
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
|
||||
completed = subprocess.run(
|
||||
[settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)],
|
||||
capture_output=True,
|
||||
@@ -250,6 +312,72 @@ class SAM3Engine:
|
||||
raise RuntimeError(str(payload["error"]))
|
||||
return payload.get("polygons", []), payload.get("scores", [])
|
||||
|
||||
def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
|
||||
return self._predict_external(image, "semantic", text=text)
|
||||
|
||||
def _predict_box_external(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]:
|
||||
return self._predict_external(image, "box", box=box)
|
||||
|
||||
def _propagate_video_external(
|
||||
self,
|
||||
frame_paths: list[str],
|
||||
source_frame_index: int,
|
||||
seed: dict[str, Any],
|
||||
direction: str,
|
||||
max_frames: int | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
status = self._external_status(force=True)
|
||||
if not status.get("available"):
|
||||
raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")
|
||||
if not frame_paths:
|
||||
return []
|
||||
|
||||
with tempfile.TemporaryDirectory(prefix="sam3_video_") as tmpdir:
|
||||
request_path = Path(tmpdir) / "request.json"
|
||||
request_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"prompt_type": "video_track",
|
||||
"frame_dir": str(Path(frame_paths[0]).parent),
|
||||
"source_frame_index": source_frame_index,
|
||||
"seed": seed,
|
||||
"direction": direction,
|
||||
"max_frames": max_frames,
|
||||
"model_version": settings.sam3_model_version,
|
||||
"checkpoint_path": self._checkpoint_path(),
|
||||
"confidence_threshold": settings.sam3_confidence_threshold,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
env = os.environ.copy()
|
||||
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
|
||||
if self._checkpoint_path():
|
||||
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
|
||||
completed = subprocess.run(
|
||||
[settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=settings.sam3_timeout_seconds,
|
||||
check=False,
|
||||
env=env,
|
||||
)
|
||||
|
||||
if completed.returncode != 0:
|
||||
detail = completed.stderr.strip() or completed.stdout.strip()
|
||||
try:
|
||||
parsed = json.loads(detail)
|
||||
detail = parsed.get("error", detail)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
raise RuntimeError(f"SAM 3 external video tracking failed: {detail}")
|
||||
|
||||
payload = json.loads(completed.stdout)
|
||||
if payload.get("error"):
|
||||
raise RuntimeError(str(payload["error"]))
|
||||
return payload.get("frames", [])
|
||||
|
||||
def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
|
||||
if not text.strip():
|
||||
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
|
||||
@@ -263,29 +391,37 @@ class SAM3Engine:
|
||||
state = self._processor.set_image(pil_image)
|
||||
output = self._processor.set_text_prompt(state=state, prompt=text.strip())
|
||||
|
||||
masks = output.get("masks", [])
|
||||
scores = output.get("scores", [])
|
||||
polygons = []
|
||||
for mask in masks:
|
||||
if hasattr(mask, "detach"):
|
||||
mask = mask.detach().cpu().numpy()
|
||||
if mask.ndim == 3:
|
||||
mask = mask[0]
|
||||
poly = SAM2Engine._mask_to_polygon(mask)
|
||||
if poly:
|
||||
polygons.append(poly)
|
||||
|
||||
if hasattr(scores, "detach"):
|
||||
scores = scores.detach().cpu().tolist()
|
||||
elif hasattr(scores, "tolist"):
|
||||
scores = scores.tolist()
|
||||
return polygons, list(scores)
|
||||
return self._prediction_to_polygons(output)
|
||||
|
||||
def predict_points(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
|
||||
raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for point prompts.")
|
||||
|
||||
def predict_box(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
|
||||
raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for box prompts.")
|
||||
def predict_box(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]:
|
||||
if not self._can_load() and self._external_status().get("available"):
|
||||
return self._predict_box_external(image, box)
|
||||
if not self._ensure_ready():
|
||||
raise RuntimeError(self.status()["message"])
|
||||
|
||||
pil_image = Image.fromarray(image)
|
||||
with torch.inference_mode(): # type: ignore[union-attr]
|
||||
state = self._processor.set_image(pil_image)
|
||||
output = self._processor.add_geometric_prompt(
|
||||
state=state,
|
||||
box=self._xyxy_to_cxcywh(box),
|
||||
label=True,
|
||||
)
|
||||
|
||||
return self._prediction_to_polygons(output)
|
||||
|
||||
def propagate_video(
|
||||
self,
|
||||
frame_paths: list[str],
|
||||
source_frame_index: int,
|
||||
seed: dict[str, Any],
|
||||
direction: str = "forward",
|
||||
max_frames: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
return self._propagate_video_external(frame_paths, source_frame_index, seed, direction, max_frames)
|
||||
|
||||
|
||||
sam3_engine = SAM3Engine()
|
||||
|
||||
@@ -43,6 +43,13 @@ def _compact_error(exc: Exception) -> str:
|
||||
|
||||
|
||||
def _checkpoint_access(model_version: str) -> tuple[bool, str | None]:
|
||||
checkpoint_path = os.environ.get("SAM3_CHECKPOINT_PATH", "").strip()
|
||||
if checkpoint_path:
|
||||
path = Path(checkpoint_path)
|
||||
if path.is_file():
|
||||
return True, None
|
||||
return False, f"local checkpoint not found: {checkpoint_path}"
|
||||
|
||||
try:
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
@@ -55,6 +62,7 @@ def _checkpoint_access(model_version: str) -> tuple[bool, str | None]:
|
||||
|
||||
def runtime_status() -> dict[str, Any]:
|
||||
model_version = os.environ.get("SAM3_MODEL_VERSION", "sam3")
|
||||
checkpoint_path = os.environ.get("SAM3_CHECKPOINT_PATH", "").strip() or None
|
||||
package_error = None
|
||||
package_available = importlib.util.find_spec("sam3") is not None
|
||||
if package_available:
|
||||
@@ -85,6 +93,7 @@ def runtime_status() -> dict[str, Any]:
|
||||
"available": available,
|
||||
"package_available": package_available,
|
||||
"checkpoint_access": checkpoint_access,
|
||||
"checkpoint_path": checkpoint_path or f"official/HuggingFace ({model_version})",
|
||||
"python_ok": python_ok,
|
||||
"torch_ok": torch_version is not None,
|
||||
"torch_version": torch_version,
|
||||
@@ -118,34 +127,67 @@ def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
|
||||
|
||||
def _to_numpy(value: Any) -> np.ndarray:
|
||||
if hasattr(value, "detach"):
|
||||
value = value.detach().cpu().numpy()
|
||||
elif hasattr(value, "cpu"):
|
||||
value = value.detach()
|
||||
if hasattr(value, "is_floating_point") and value.is_floating_point():
|
||||
value = value.float()
|
||||
value = value.cpu().numpy()
|
||||
elif hasattr(value, "cpu"):
|
||||
value = value.cpu()
|
||||
if hasattr(value, "is_floating_point") and value.is_floating_point():
|
||||
value = value.float()
|
||||
value = value.numpy()
|
||||
return np.asarray(value)
|
||||
|
||||
|
||||
def predict(request_path: Path) -> dict[str, Any]:
|
||||
import torch
|
||||
from sam3.model.sam3_image_processor import Sam3Processor
|
||||
from sam3.model_builder import build_sam3_image_model
|
||||
def _xyxy_to_cxcywh(box: list[float]) -> list[float]:
|
||||
if len(box) != 4:
|
||||
raise ValueError("SAM 3 box prompt requires [x1, y1, x2, y2].")
|
||||
x1, y1, x2, y2 = [min(max(float(value), 0.0), 1.0) for value in box]
|
||||
left, right = sorted([x1, x2])
|
||||
top, bottom = sorted([y1, y2])
|
||||
width = max(right - left, 1e-6)
|
||||
height = max(bottom - top, 1e-6)
|
||||
return [left + width / 2, top + height / 2, width, height]
|
||||
|
||||
payload = json.loads(request_path.read_text(encoding="utf-8"))
|
||||
image_path = Path(payload["image_path"])
|
||||
text = str(payload["text"]).strip()
|
||||
threshold = float(payload.get("confidence_threshold", 0.5))
|
||||
if not text:
|
||||
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
def _bbox_from_seed(seed: dict[str, Any]) -> list[float]:
|
||||
bbox = seed.get("bbox")
|
||||
if isinstance(bbox, list) and len(bbox) == 4:
|
||||
return [min(max(float(value), 0.0), 1.0) for value in bbox]
|
||||
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
model = build_sam3_image_model()
|
||||
processor = Sam3Processor(model, confidence_threshold=threshold)
|
||||
state = processor.set_image(image)
|
||||
output = processor.set_text_prompt(state=state, prompt=text)
|
||||
polygons = seed.get("polygons") or []
|
||||
points = [point for polygon in polygons for point in polygon if len(point) >= 2]
|
||||
if not points:
|
||||
raise ValueError("SAM 3 video tracking requires seed bbox or polygons.")
|
||||
xs = [min(max(float(point[0]), 0.0), 1.0) for point in points]
|
||||
ys = [min(max(float(point[1]), 0.0), 1.0) for point in points]
|
||||
left, right = min(xs), max(xs)
|
||||
top, bottom = min(ys), max(ys)
|
||||
return [left, top, max(right - left, 1e-6), max(bottom - top, 1e-6)]
|
||||
|
||||
|
||||
def _video_outputs_to_response(outputs: dict[str, Any]) -> dict[str, Any]:
|
||||
masks = _to_numpy(outputs.get("out_binary_masks", []))
|
||||
scores = _to_numpy(outputs.get("out_probs", []))
|
||||
obj_ids = _to_numpy(outputs.get("out_obj_ids", []))
|
||||
if masks.ndim == 4:
|
||||
masks = masks[:, 0]
|
||||
elif masks.ndim == 2:
|
||||
masks = masks[None, ...]
|
||||
|
||||
polygons = []
|
||||
out_scores = []
|
||||
out_ids = []
|
||||
for index, mask in enumerate(masks):
|
||||
polygon = _mask_to_polygon(mask)
|
||||
if polygon:
|
||||
polygons.append(polygon)
|
||||
out_scores.append(float(scores[index]) if scores.size > index else 1.0)
|
||||
out_ids.append(int(obj_ids[index]) if obj_ids.size > index else index + 1)
|
||||
return {"polygons": polygons, "scores": out_scores, "object_ids": out_ids}
|
||||
|
||||
|
||||
def _prediction_to_response(output: dict[str, Any]) -> dict[str, Any]:
|
||||
masks = _to_numpy(output.get("masks", []))
|
||||
scores = _to_numpy(output.get("scores", []))
|
||||
if masks.ndim == 4:
|
||||
@@ -165,6 +207,115 @@ def predict(request_path: Path) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def predict_video(request_path: Path) -> dict[str, Any]:
|
||||
import torch
|
||||
from sam3.model_builder import build_sam3_video_predictor
|
||||
|
||||
payload = json.loads(request_path.read_text(encoding="utf-8"))
|
||||
frame_dir = Path(payload["frame_dir"])
|
||||
source_frame_index = int(payload.get("source_frame_index", 0))
|
||||
seed = payload.get("seed") or {}
|
||||
direction = str(payload.get("direction") or "forward").lower()
|
||||
max_frames = payload.get("max_frames")
|
||||
max_frames = int(max_frames) if max_frames else None
|
||||
checkpoint_path = str(payload.get("checkpoint_path") or os.environ.get("SAM3_CHECKPOINT_PATH", "")).strip()
|
||||
threshold = float(payload.get("confidence_threshold", 0.5))
|
||||
if direction not in {"forward", "backward", "both"}:
|
||||
raise ValueError(f"Unsupported propagation direction: {direction}")
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
predictor = build_sam3_video_predictor(
|
||||
checkpoint_path=checkpoint_path or None,
|
||||
async_loading_frames=False,
|
||||
)
|
||||
session_id = None
|
||||
try:
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
session = predictor.handle_request(
|
||||
{
|
||||
"type": "start_session",
|
||||
"resource_path": str(frame_dir),
|
||||
"offload_video_to_cpu": True,
|
||||
"offload_state_to_cpu": True,
|
||||
}
|
||||
)
|
||||
session_id = session["session_id"]
|
||||
predictor.handle_request(
|
||||
{
|
||||
"type": "add_prompt",
|
||||
"session_id": session_id,
|
||||
"frame_index": source_frame_index,
|
||||
"bounding_boxes": [_bbox_from_seed(seed)],
|
||||
"bounding_box_labels": [1],
|
||||
"output_prob_thresh": threshold,
|
||||
"rel_coordinates": True,
|
||||
}
|
||||
)
|
||||
frames = []
|
||||
for item in predictor.handle_stream_request(
|
||||
{
|
||||
"type": "propagate_in_video",
|
||||
"session_id": session_id,
|
||||
"propagation_direction": direction,
|
||||
"start_frame_index": source_frame_index,
|
||||
"max_frame_num_to_track": max_frames,
|
||||
"output_prob_thresh": threshold,
|
||||
}
|
||||
):
|
||||
frame_response = _video_outputs_to_response(item.get("outputs") or {})
|
||||
frame_response["frame_index"] = int(item["frame_index"])
|
||||
frames.append(frame_response)
|
||||
finally:
|
||||
if session_id:
|
||||
predictor.handle_request({"type": "close_session", "session_id": session_id})
|
||||
|
||||
return {"frames": frames}
|
||||
|
||||
|
||||
def predict(request_path: Path) -> dict[str, Any]:
|
||||
import torch
|
||||
from sam3.model.sam3_image_processor import Sam3Processor
|
||||
from sam3.model_builder import build_sam3_image_model
|
||||
|
||||
payload = json.loads(request_path.read_text(encoding="utf-8"))
|
||||
if str(payload.get("prompt_type") or "").strip().lower() == "video_track":
|
||||
return predict_video(request_path)
|
||||
|
||||
image_path = Path(payload["image_path"])
|
||||
prompt_type = str(payload.get("prompt_type") or "semantic").strip().lower()
|
||||
text = str(payload.get("text") or "").strip()
|
||||
threshold = float(payload.get("confidence_threshold", 0.5))
|
||||
checkpoint_path = str(payload.get("checkpoint_path") or os.environ.get("SAM3_CHECKPOINT_PATH", "")).strip()
|
||||
if prompt_type == "semantic" and not text:
|
||||
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
|
||||
if prompt_type not in {"semantic", "box"}:
|
||||
raise ValueError(f"Unsupported SAM 3 prompt type: {prompt_type}")
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
model = build_sam3_image_model(
|
||||
checkpoint_path=checkpoint_path or None,
|
||||
load_from_HF=not bool(checkpoint_path),
|
||||
)
|
||||
processor = Sam3Processor(model, confidence_threshold=threshold)
|
||||
state = processor.set_image(image)
|
||||
if prompt_type == "box":
|
||||
output = processor.add_geometric_prompt(
|
||||
state=state,
|
||||
box=_xyxy_to_cxcywh(payload.get("box") or []),
|
||||
label=True,
|
||||
)
|
||||
else:
|
||||
output = processor.set_text_prompt(state=state, prompt=text)
|
||||
|
||||
return _prediction_to_response(output)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="SAM 3 external runtime helper")
|
||||
parser.add_argument("--status", action="store_true")
|
||||
|
||||
@@ -67,6 +67,19 @@ class SAMRegistry:
|
||||
def predict_box(self, model_id: str | None, image: Any, box: list[float]):
|
||||
return self._ensure_available(model_id).predict_box(image, box)
|
||||
|
||||
def predict_interactive(
|
||||
self,
|
||||
model_id: str | None,
|
||||
image: Any,
|
||||
box: list[float] | None,
|
||||
points: list[list[float]],
|
||||
labels: list[int],
|
||||
):
|
||||
model = self.normalize_model_id(model_id)
|
||||
if model != "sam2":
|
||||
raise NotImplementedError("Interactive box + point refinement is currently supported by SAM 2.")
|
||||
return self._ensure_available(model).predict_interactive(image, box, points, labels)
|
||||
|
||||
def predict_auto(self, model_id: str | None, image: Any):
|
||||
return self._ensure_available(model_id).predict_auto(image)
|
||||
|
||||
@@ -76,5 +89,22 @@ class SAMRegistry:
|
||||
return self._ensure_available(model).predict_semantic(image, text)
|
||||
return self._ensure_available(model).predict_auto(image)
|
||||
|
||||
def propagate_video(
|
||||
self,
|
||||
model_id: str | None,
|
||||
frame_paths: list[str],
|
||||
source_frame_index: int,
|
||||
seed: dict[str, Any],
|
||||
direction: str,
|
||||
max_frames: int | None,
|
||||
):
|
||||
return self._ensure_available(model_id).propagate_video(
|
||||
frame_paths,
|
||||
source_frame_index,
|
||||
seed,
|
||||
direction=direction,
|
||||
max_frames=max_frames,
|
||||
)
|
||||
|
||||
|
||||
sam_registry = SAMRegistry()
|
||||
|
||||
@@ -116,6 +116,44 @@ def test_predict_box_and_semantic_fallback(client, monkeypatch):
|
||||
assert semantic_response.json()["scores"] == [0.5]
|
||||
|
||||
|
||||
def test_predict_interactive_combines_box_and_points(client, monkeypatch):
|
||||
_, frame, _ = _create_project_and_frame(client)
|
||||
calls = {}
|
||||
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
|
||||
|
||||
def fake_predict_interactive(model, image, box, points, labels):
|
||||
calls["model"] = model
|
||||
calls["box"] = box
|
||||
calls["points"] = points
|
||||
calls["labels"] = labels
|
||||
return (
|
||||
[[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
|
||||
[0.88],
|
||||
)
|
||||
|
||||
monkeypatch.setattr("routers.ai.sam_registry.predict_interactive", fake_predict_interactive)
|
||||
|
||||
response = client.post("/api/ai/predict", json={
|
||||
"image_id": frame["id"],
|
||||
"prompt_type": "interactive",
|
||||
"prompt_data": {
|
||||
"box": [0.1, 0.1, 0.9, 0.9],
|
||||
"points": [[0.5, 0.5], [0.2, 0.2]],
|
||||
"labels": [1, 0],
|
||||
},
|
||||
"model": "sam2",
|
||||
})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["scores"] == [0.88]
|
||||
assert calls == {
|
||||
"model": "sam2",
|
||||
"box": [0.1, 0.1, 0.9, 0.9],
|
||||
"points": [[0.5, 0.5], [0.2, 0.2]],
|
||||
"labels": [1, 0],
|
||||
}
|
||||
|
||||
|
||||
def test_model_status_reports_runtime(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.ai.sam_registry.runtime_status", lambda selected_model=None: {
|
||||
"selected_model": selected_model or "sam2",
|
||||
@@ -170,6 +208,80 @@ def test_model_status_reports_runtime(client, monkeypatch):
|
||||
assert body["models"][1]["available"] is False
|
||||
|
||||
|
||||
def test_propagate_saves_tracked_annotations(client, monkeypatch):
|
||||
project = client.post("/api/projects", json={"name": "Video Project"}).json()
|
||||
frames = [
|
||||
client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": idx,
|
||||
"image_url": f"frames/{idx}.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
}).json()
|
||||
for idx in range(3)
|
||||
]
|
||||
calls = {}
|
||||
monkeypatch.setattr("routers.ai.download_file", lambda object_name: b"jpeg")
|
||||
|
||||
def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
|
||||
calls["model"] = model
|
||||
calls["source_frame_index"] = source_frame_index
|
||||
calls["seed"] = seed
|
||||
calls["direction"] = direction
|
||||
calls["max_frames"] = max_frames
|
||||
calls["frame_count"] = len(frame_paths)
|
||||
return [
|
||||
{
|
||||
"frame_index": 0,
|
||||
"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
|
||||
"scores": [0.9],
|
||||
"object_ids": [1],
|
||||
},
|
||||
{
|
||||
"frame_index": 1,
|
||||
"polygons": [[[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]],
|
||||
"scores": [0.8],
|
||||
"object_ids": [1],
|
||||
},
|
||||
]
|
||||
|
||||
monkeypatch.setattr("routers.ai.sam_registry.propagate_video", fake_propagate_video)
|
||||
|
||||
response = client.post("/api/ai/propagate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frames[0]["id"],
|
||||
"model": "sam2",
|
||||
"direction": "forward",
|
||||
"max_frames": 2,
|
||||
"include_source": False,
|
||||
"seed": {
|
||||
"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
|
||||
"bbox": [0.1, 0.1, 0.1, 0.1],
|
||||
"label": "胆囊",
|
||||
"color": "#ff0000",
|
||||
"class_metadata": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
|
||||
"template_id": None,
|
||||
},
|
||||
})
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["created_annotation_count"] == 1
|
||||
assert body["processed_frame_count"] == 2
|
||||
assert calls["model"] == "sam2"
|
||||
assert calls["source_frame_index"] == 0
|
||||
assert calls["direction"] == "forward"
|
||||
assert calls["frame_count"] == 2
|
||||
saved = body["annotations"][0]
|
||||
assert saved["frame_id"] == frames[1]["id"]
|
||||
assert saved["mask_data"]["source"] == "sam2_propagation"
|
||||
assert saved["mask_data"]["class"]["name"] == "胆囊"
|
||||
assert saved["mask_data"]["score"] == 0.8
|
||||
|
||||
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
|
||||
assert len(listing.json()) == 1
|
||||
|
||||
|
||||
def test_predict_validation_errors(client, monkeypatch):
|
||||
project, _, _ = _create_project_and_frame(client)
|
||||
|
||||
|
||||
@@ -84,6 +84,12 @@ def test_parse_media_queues_background_task(client, monkeypatch):
|
||||
assert data["progress"] == 0
|
||||
assert data["project_id"] == project["id"]
|
||||
assert data["celery_task_id"] == "celery-1"
|
||||
assert data["payload"] == {
|
||||
"source_type": "video",
|
||||
"parse_fps": 5.0,
|
||||
"max_frames": None,
|
||||
"target_width": 640,
|
||||
}
|
||||
assert queued == [data["id"]]
|
||||
assert published == [data["id"]]
|
||||
|
||||
@@ -94,6 +100,35 @@ def test_parse_media_queues_background_task(client, monkeypatch):
|
||||
assert project_detail["status"] == "parsing"
|
||||
|
||||
|
||||
def test_parse_media_accepts_frame_sequence_options(client, monkeypatch):
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Parse Options",
|
||||
"video_path": "uploads/1/clip.mp4",
|
||||
"source_type": "video",
|
||||
"parse_fps": 30,
|
||||
}).json()
|
||||
|
||||
class FakeAsyncResult:
|
||||
id = "celery-options"
|
||||
|
||||
monkeypatch.setattr("routers.media.parse_project_media.delay", lambda task_id: FakeAsyncResult())
|
||||
monkeypatch.setattr("routers.media.publish_task_progress_event", lambda task: None)
|
||||
|
||||
response = client.post(
|
||||
f"/api/media/parse?project_id={project['id']}&parse_fps=15&max_frames=120&target_width=960"
|
||||
)
|
||||
|
||||
assert response.status_code == 202
|
||||
data = response.json()
|
||||
assert data["payload"] == {
|
||||
"source_type": "video",
|
||||
"parse_fps": 15.0,
|
||||
"max_frames": 120,
|
||||
"target_width": 960,
|
||||
}
|
||||
assert client.get(f"/api/projects/{project['id']}").json()["parse_fps"] == 15.0
|
||||
|
||||
|
||||
def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp_path):
|
||||
from models import ProcessingTask
|
||||
from services.media_task_runner import run_parse_media_task
|
||||
@@ -118,10 +153,14 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp
|
||||
frame_file.write_bytes(b"fake image")
|
||||
|
||||
monkeypatch.setattr("services.media_task_runner.download_file", lambda object_name: b"video")
|
||||
monkeypatch.setattr("services.media_task_runner.parse_video", lambda local_path, output_dir, fps: ([str(frame_file)], 25.0))
|
||||
monkeypatch.setattr(
|
||||
"services.media_task_runner.parse_video",
|
||||
lambda local_path, output_dir, fps, max_frames=None, target_width=640: ([str(frame_file)], 25.0),
|
||||
)
|
||||
monkeypatch.setattr("services.media_task_runner.extract_thumbnail", lambda local_path, thumbnail_path: open(thumbnail_path, "wb").write(b"thumb"))
|
||||
monkeypatch.setattr("services.media_task_runner.upload_file", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr("services.media_task_runner.upload_frames_to_minio", lambda frame_files, project_id: [f"projects/{project_id}/frames/frame_000001.jpg"])
|
||||
monkeypatch.setattr("routers.projects.get_presigned_url", lambda object_name, expires=3600: f"http://storage/{object_name}")
|
||||
published = []
|
||||
monkeypatch.setattr(
|
||||
"services.media_task_runner.publish_task_progress_event",
|
||||
@@ -131,6 +170,17 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp
|
||||
result = run_parse_media_task(db_session, task.id)
|
||||
|
||||
assert result["frames_extracted"] == 1
|
||||
assert result["frame_sequence"] == {
|
||||
"original_fps": 25.0,
|
||||
"parse_fps": 5.0,
|
||||
"frame_count": 1,
|
||||
"duration_ms": 0.0,
|
||||
"target_width": 640,
|
||||
"frame_width": None,
|
||||
"frame_height": None,
|
||||
"max_frames": None,
|
||||
"object_prefix": f"projects/{project['id']}/frames",
|
||||
}
|
||||
db_session.refresh(task)
|
||||
assert task.status == "success"
|
||||
assert task.progress == 100
|
||||
@@ -140,6 +190,8 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp
|
||||
assert project_detail["status"] == "ready"
|
||||
frames = client.get(f"/api/projects/{project['id']}/frames").json()
|
||||
assert "frame_000001.jpg" in frames[0]["image_url"]
|
||||
assert frames[0]["timestamp_ms"] == 0.0
|
||||
assert frames[0]["source_frame_number"] == 0
|
||||
|
||||
|
||||
def test_parse_task_runner_skips_already_cancelled_task(db_session):
|
||||
|
||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
from services.sam3_engine import SAM3Engine
|
||||
from services.sam3_external_worker import _to_numpy
|
||||
|
||||
|
||||
class _Completed:
|
||||
@@ -14,6 +15,8 @@ class _Completed:
|
||||
|
||||
|
||||
def _external_settings(monkeypatch, python_path: Path):
|
||||
checkpoint_path = python_path.with_name("sam3.pt")
|
||||
checkpoint_path.write_bytes(b"checkpoint")
|
||||
python_path.write_text("#!/usr/bin/env python\n", encoding="utf-8")
|
||||
python_path.chmod(0o755)
|
||||
monkeypatch.setattr("services.sam3_engine.SAM3_PACKAGE_AVAILABLE", False)
|
||||
@@ -23,6 +26,7 @@ def _external_settings(monkeypatch, python_path: Path):
|
||||
monkeypatch.setattr("services.sam3_engine.settings.sam3_timeout_seconds", 10)
|
||||
monkeypatch.setattr("services.sam3_engine.settings.sam3_status_cache_seconds", 30)
|
||||
monkeypatch.setattr("services.sam3_engine.settings.sam3_confidence_threshold", 0.4)
|
||||
monkeypatch.setattr("services.sam3_engine.settings.sam3_checkpoint_path", str(checkpoint_path))
|
||||
|
||||
|
||||
def test_sam3_status_reports_external_runtime_ready(tmp_path, monkeypatch):
|
||||
@@ -30,9 +34,12 @@ def test_sam3_status_reports_external_runtime_ready(tmp_path, monkeypatch):
|
||||
|
||||
def fake_run(args, **_kwargs):
|
||||
assert "--status" in args
|
||||
assert _kwargs["env"]["SAM3_CHECKPOINT_PATH"].endswith("sam3.pt")
|
||||
return _Completed(stdout=json.dumps({
|
||||
"available": True,
|
||||
"package_available": True,
|
||||
"checkpoint_access": True,
|
||||
"checkpoint_path": _kwargs["env"]["SAM3_CHECKPOINT_PATH"],
|
||||
"python_ok": True,
|
||||
"torch_ok": True,
|
||||
"cuda_available": True,
|
||||
@@ -48,7 +55,10 @@ def test_sam3_status_reports_external_runtime_ready(tmp_path, monkeypatch):
|
||||
assert status["external_available"] is True
|
||||
assert status["package_available"] is True
|
||||
assert status["python_ok"] is True
|
||||
assert status["message"] == "SAM 3 external runtime is ready; model will load in the helper process on inference."
|
||||
assert status["checkpoint_exists"] is True
|
||||
assert status["checkpoint_path"].endswith("sam3.pt")
|
||||
assert status["supports"] == ["semantic", "box", "video_track"]
|
||||
assert status["message"] == "SAM 3 external runtime is ready; local checkpoint will load in the helper process on inference."
|
||||
|
||||
|
||||
def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch):
|
||||
@@ -61,6 +71,7 @@ def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch):
|
||||
return _Completed(stdout=json.dumps({
|
||||
"available": True,
|
||||
"package_available": True,
|
||||
"checkpoint_access": True,
|
||||
"python_ok": True,
|
||||
"torch_ok": True,
|
||||
"cuda_available": True,
|
||||
@@ -71,6 +82,7 @@ def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch):
|
||||
request = json.loads(request_path.read_text(encoding="utf-8"))
|
||||
assert request["text"] == "vessel"
|
||||
assert request["confidence_threshold"] == 0.4
|
||||
assert request["checkpoint_path"].endswith("sam3.pt")
|
||||
assert Path(request["image_path"]).exists()
|
||||
return _Completed(stdout=json.dumps({
|
||||
"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
|
||||
@@ -86,6 +98,97 @@ def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch):
|
||||
assert any("--request" in args for args in calls)
|
||||
|
||||
|
||||
def test_sam3_predict_box_uses_external_worker(tmp_path, monkeypatch):
|
||||
_external_settings(monkeypatch, tmp_path / "python")
|
||||
|
||||
def fake_run(args, **_kwargs):
|
||||
if "--status" in args:
|
||||
return _Completed(stdout=json.dumps({
|
||||
"available": True,
|
||||
"package_available": True,
|
||||
"checkpoint_access": True,
|
||||
"python_ok": True,
|
||||
"torch_ok": True,
|
||||
"cuda_available": True,
|
||||
"device": "cuda",
|
||||
"message": "ready",
|
||||
}))
|
||||
request_path = Path(args[-1])
|
||||
request = json.loads(request_path.read_text(encoding="utf-8"))
|
||||
assert request["prompt_type"] == "box"
|
||||
assert request["box"] == [0.1, 0.2, 0.7, 0.8]
|
||||
assert request["text"] == ""
|
||||
return _Completed(stdout=json.dumps({
|
||||
"polygons": [[[0.1, 0.2], [0.7, 0.2], [0.7, 0.8]]],
|
||||
"scores": [0.88],
|
||||
}))
|
||||
|
||||
monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)
|
||||
|
||||
polygons, scores = SAM3Engine().predict_box(
|
||||
np.zeros((8, 8, 3), dtype=np.uint8),
|
||||
[0.1, 0.2, 0.7, 0.8],
|
||||
)
|
||||
|
||||
assert polygons == [[[0.1, 0.2], [0.7, 0.2], [0.7, 0.8]]]
|
||||
assert scores == [0.88]
|
||||
|
||||
|
||||
def test_sam3_propagate_video_uses_external_worker(tmp_path, monkeypatch):
|
||||
_external_settings(monkeypatch, tmp_path / "python")
|
||||
frame_dir = tmp_path / "frames"
|
||||
frame_dir.mkdir()
|
||||
frame_paths = []
|
||||
for index in range(2):
|
||||
frame_path = frame_dir / f"frame_{index:06d}.jpg"
|
||||
frame_path.write_bytes(b"jpeg")
|
||||
frame_paths.append(str(frame_path))
|
||||
|
||||
def fake_run(args, **_kwargs):
|
||||
if "--status" in args:
|
||||
return _Completed(stdout=json.dumps({
|
||||
"available": True,
|
||||
"package_available": True,
|
||||
"checkpoint_access": True,
|
||||
"python_ok": True,
|
||||
"torch_ok": True,
|
||||
"cuda_available": True,
|
||||
"device": "cuda",
|
||||
"message": "ready",
|
||||
}))
|
||||
request_path = Path(args[-1])
|
||||
request = json.loads(request_path.read_text(encoding="utf-8"))
|
||||
assert request["prompt_type"] == "video_track"
|
||||
assert request["frame_dir"] == str(frame_dir)
|
||||
assert request["source_frame_index"] == 0
|
||||
assert request["direction"] == "forward"
|
||||
assert request["max_frames"] == 2
|
||||
assert request["seed"]["bbox"] == [0.1, 0.1, 0.2, 0.2]
|
||||
return _Completed(stdout=json.dumps({
|
||||
"frames": [
|
||||
{
|
||||
"frame_index": 1,
|
||||
"polygons": [[[0.2, 0.2], [0.4, 0.2], [0.4, 0.4]]],
|
||||
"scores": [0.7],
|
||||
"object_ids": [1],
|
||||
}
|
||||
]
|
||||
}))
|
||||
|
||||
monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)
|
||||
|
||||
frames = SAM3Engine().propagate_video(
|
||||
frame_paths,
|
||||
0,
|
||||
{"bbox": [0.1, 0.1, 0.2, 0.2]},
|
||||
direction="forward",
|
||||
max_frames=2,
|
||||
)
|
||||
|
||||
assert frames[0]["frame_index"] == 1
|
||||
assert frames[0]["scores"] == [0.7]
|
||||
|
||||
|
||||
def test_sam3_predict_semantic_reports_external_errors(tmp_path, monkeypatch):
|
||||
_external_settings(monkeypatch, tmp_path / "python")
|
||||
|
||||
@@ -94,6 +197,7 @@ def test_sam3_predict_semantic_reports_external_errors(tmp_path, monkeypatch):
|
||||
return _Completed(stdout=json.dumps({
|
||||
"available": True,
|
||||
"package_available": True,
|
||||
"checkpoint_access": True,
|
||||
"python_ok": True,
|
||||
"torch_ok": True,
|
||||
"cuda_available": True,
|
||||
@@ -110,3 +214,32 @@ def test_sam3_predict_semantic_reports_external_errors(tmp_path, monkeypatch):
|
||||
assert "HF access denied" in str(exc)
|
||||
else:
|
||||
raise AssertionError("Expected SAM 3 external inference failure.")
|
||||
|
||||
|
||||
def test_sam3_worker_casts_floating_tensors_before_numpy():
|
||||
class FakeTensor:
|
||||
def __init__(self):
|
||||
self.float_called = False
|
||||
|
||||
def detach(self):
|
||||
return self
|
||||
|
||||
def is_floating_point(self):
|
||||
return True
|
||||
|
||||
def float(self):
|
||||
self.float_called = True
|
||||
return self
|
||||
|
||||
def cpu(self):
|
||||
return self
|
||||
|
||||
def numpy(self):
|
||||
return np.array([1.0], dtype=np.float32)
|
||||
|
||||
tensor = FakeTensor()
|
||||
|
||||
result = _to_numpy(tensor)
|
||||
|
||||
assert tensor.float_called is True
|
||||
assert result.tolist() == [1.0]
|
||||
|
||||
@@ -38,14 +38,14 @@ Word 方案描述的理想系统包含:
|
||||
| 视频拆帧 | 已落地 | `backend/services/frame_parser.py`、`backend/routers/media.py` |
|
||||
| DICOM 批量导入 | 部分落地 | 上传和解析存在,项目级体验还需完善 |
|
||||
| WebSocket 进度 | 已落地 | 拆帧进度写入任务表后发布到 Redis `seg:progress`,FastAPI 广播到 `/ws/progress` |
|
||||
| SAM 推理 | 部分落地 | 后端已有 SAM 2 / SAM 3 选择和真实模型状态接口;SAM 3 通过独立 Python 3.12 环境桥接,状态会检查 Python/CUDA/包/HF gated 权重访问 |
|
||||
| SAM 推理 | 部分落地 | 后端已有 SAM 2 / SAM 3 选择和真实模型状态接口;SAM 2 已接 video predictor 片段传播;SAM 3 通过独立 Python 3.12 环境桥接,支持文本/框提示和 official video tracker 入口,状态会检查 Python/CUDA/包/本地 checkpoint |
|
||||
| 模板库 | 部分落地 | 分类、颜色、z-index 能存储和编辑;PNG mask 导出时会按 zIndex 做语义融合裁决,前端预览裁决尚未落地 |
|
||||
| 标注持久化 | 部分落地 | 后端有 `Annotation` 表,前端已接入新增、回显、分类更新、当前帧删除、手工绘制、GT mask 导入、seed point 编辑、polygon 顶点拖动/删除、边中点插点和多 polygon 子区域编辑;复杂洞结构编辑未落地 |
|
||||
| COCO / Mask 导出 | 已落地基础能力 | `backend/routers/export.py`;COCO JSON 和 PNG mask ZIP 前端按钮均已接入,ZIP 包含单标注 mask、语义融合 mask 和类别映射 |
|
||||
|
||||
## 当前代码尚未落地的目标
|
||||
|
||||
- SAM 3:当前已提供 `sam3_engine.py` 外部环境桥接、`sam3_external_worker.py` 和 `setup_sam3_env.sh`;本机 `sam3` 环境已满足 Python 3.12、PyTorch 2.10/cu128、CUDA/GPU 和官方包导入。真实推理仍取决于 Hugging Face `facebook/sam3` gated 权重访问授权;官方没有 SAM 3 tiny/small 权重,当前可选最小真实 SAM 权重仍是 SAM 2 tiny。
|
||||
- SAM 3:当前已提供 `sam3_engine.py` 外部环境桥接、`sam3_external_worker.py` 和 `setup_sam3_env.sh`;本机 `sam3` 环境已满足 Python 3.12、PyTorch 2.10/cu128、CUDA/GPU、官方包导入和本地 `sam3权重/sam3.pt` checkpoint 状态检查。官方没有 SAM 3 tiny/small 权重,当前可选最小真实 SAM 权重仍是 SAM 2 tiny;video tracker 入口已接入,真实效果取决于本地 checkpoint 是否兼容 video model。
|
||||
- Celery 异步任务队列:已注册 Celery app 和拆帧 worker task,`/api/media/parse` 会创建任务表记录并入队。
|
||||
- GT mask 导入:当前已支持二值/多类别 mask 导入,后端会按非零像素值拆分区域,生成 polygon 标注和距离变换 seed point;骨架提取、HDBSCAN 和模板自动映射尚未实现。
|
||||
- Mask 到点区域的拓扑降维:当前完成 distance transform seed point 和前端 seed point 拖拽编辑;骨架提取、HDBSCAN 等增强尚未实现。
|
||||
@@ -55,4 +55,4 @@ Word 方案描述的理想系统包含:
|
||||
|
||||
## 结论
|
||||
|
||||
当前项目已经从 UI 原型推进到“可上传、可异步拆帧、可取消/重试任务、可查看失败详情、可实时查看任务进度、可浏览项目帧、可维护模板、可手工绘制、可逐点编辑 polygon、可边中点插点、可多 polygon 子区域编辑、可区域合并/去除、可点/框 AI 推理、可对点/框 prompt 做裁剪推理和背景过滤、可导入多类别 GT mask、可编辑 seed point、可保存标注、可导出 COCO/语义 mask ZIP、可查看 Dashboard 后端概览”的全栈雏形。下一阶段最重要的是继续补齐 Hugging Face SAM 3 权重授权后的真实语义文本分割 smoke test、复杂洞结构编辑和 GT mask 骨架/聚类增强。
|
||||
当前项目已经从 UI 原型推进到“可上传、可异步拆帧、可取消/重试任务、可查看失败详情、可实时查看任务进度、可浏览项目帧、可维护模板、可手工绘制、可逐点编辑 polygon、可边中点插点、可多 polygon 子区域编辑、可区域合并/去除、可点/框 AI 推理、可对点/框 prompt 做裁剪推理和背景过滤、可用 SAM 2 / SAM 3 进行视频片段传播、可导入多类别 GT mask、可编辑 seed point、可保存标注、可导出 COCO/语义 mask ZIP、可查看 Dashboard 后端概览”的全栈雏形。下一阶段最重要的是继续补齐 SAM 3 真实视频 tracker smoke test、复杂洞结构编辑和 GT mask 骨架/聚类增强。
|
||||
|
||||
@@ -65,6 +65,7 @@
|
||||
| “导出 JSON 标注集”按钮 | 真实可用 | 导出前会保存未归档 mask,然后调用 `exportCoco()` 下载 JSON |
|
||||
| “导出 PNG Mask ZIP”按钮 | 真实可用 | 导出前会保存未归档 mask,然后调用 `GET /api/export/{project_id}/masks` 下载 ZIP;后端同时包含单标注 mask、每帧语义融合 mask 和 `semantic_classes.json` |
|
||||
| “导入 GT Mask”按钮 | 真实可用 | 选择图片后调用 `POST /api/ai/import-gt-mask`,后端按非零像素值和连通域生成 polygon 标注与距离变换 seed point,再回显到工作区 |
|
||||
| “传播片段”按钮 | 真实可用 | 使用当前选中 mask 或当前帧第一个 mask 作为 seed,调用 `POST /api/ai/propagate`;SAM 2 走 video predictor,SAM 3 走 external video tracker,完成后刷新已保存标注 |
|
||||
| “结构化归档保存”按钮 | 真实可用 | 未保存 mask 写入 `POST /api/ai/annotate`;dirty mask 写入 `PATCH /api/ai/annotations/{id}` |
|
||||
|
||||
## CanvasArea 画布
|
||||
@@ -78,11 +79,11 @@
|
||||
| 正向/反向选点 | 真实可用 | UI 能加点,并按当前帧 `frame.id` 调用 `/api/ai/predict`;结果需点击归档保存才持久化 |
|
||||
| 框选 | 真实可用 | UI 能画框,并把框坐标归一化后调用后端推理;结果需点击归档保存才持久化 |
|
||||
| AI 推理中提示 | 真实可用 | 请求期间会显示 |
|
||||
| 手工多边形/矩形/圆/点/线 | 真实可用 | 多边形点击取点后 Enter 完成;矩形/圆/线拖拽生成 polygon;点工具生成小区域;均写入 `Mask.segmentation`,可归档保存 |
|
||||
| 手工多边形/矩形/圆/点/线 | 真实可用 | 多边形点击取点后可按 Enter 完成,也可在三点后点击首节点闭合;矩形/圆/线拖拽生成 polygon;点工具生成小区域;绘制工具可在已有 mask 上继续落点;均写入 `Mask.segmentation`,可归档保存 |
|
||||
| Mask 渲染 | 真实可用 | 前端会把推理、手工绘制、GT 导入和已保存标注转成 Konva `pathData` 渲染 |
|
||||
| Polygon 逐点编辑 | 真实可用 | 点击 mask 后显示 polygon 顶点;拖动顶点会重算 `pathData/segmentation/bbox/area`,已保存 mask 标为 dirty;选中顶点后 Delete/Backspace 可删点但保留至少三点 |
|
||||
| Polygon 逐点编辑 / 删除 | 真实可用 | 点击 mask 后显示 polygon 顶点;拖动顶点会重算 `pathData/segmentation/bbox/area`,已保存 mask 标为 dirty;选中顶点后 Delete/Backspace 可删点但保留至少三点;选中 mask 但未选中顶点时 Delete/Backspace 删除整个 mask,已保存 mask 会同步调用后端删除 |
|
||||
| GT seed point 回显/编辑 | 真实可用 | 已保存标注的 `points` 会显示为黄色 seed 点;拖动后标记为 dirty,归档保存会更新后端 |
|
||||
| 应用分类 | 真实可用 | 将当前选择的模板分类应用到本帧 mask;已保存 mask 会标为 dirty,归档保存时更新后端 |
|
||||
| 应用分类 | 真实可用 | Canvas 右下角按钮可将当前选择的模板分类应用到本帧 mask;右侧语义分类树点击分类时会优先改当前已选 mask;已保存 mask 会标为 dirty,归档保存时更新后端 |
|
||||
| 清空遮罩 | 真实可用 | 工作区中会删除当前帧已保存标注并清空当前帧本地 mask |
|
||||
| 保存状态计数 | 真实可用 | 底部显示已保存、未保存、待更新数量 |
|
||||
| 当前图层树文字 | Mock / UI-only | 固定显示 `OBJECT_VEHICLE_01` |
|
||||
@@ -93,7 +94,7 @@
|
||||
|------|------|------|
|
||||
| 拖拽/选择 | 真实可用 | 控制 Canvas 是否可拖拽 |
|
||||
| 多边形/矩形/圆/点/线 | 真实可用 | 切换 activeTool 后由 `CanvasArea` 生成可保存的 polygon mask |
|
||||
| 区域合并/去除 | 真实可用 | 选择工具后点击多个 mask,使用 `polygon-clipping` 做 union / difference;合并会保留主 mask 并移除被合并 mask,去除会从主 mask 扣除后续选中 mask |
|
||||
| 区域合并/去除 | 真实可用 | 选择工具后点击多个 mask,右下角显示已选数量和操作按钮;合并/去除模式会隐藏 polygon 编辑手柄,避免手柄抢占多选点击;使用 `polygon-clipping` 做 union / difference;合并会保留主 mask 并移除被合并 mask,去除会从主 mask 扣除后续选中 mask;内含扣除会保留 hole ring 并用 even-odd 规则渲染 |
|
||||
| 正向选点/反向选点/框选 | 部分可用 | 会影响 Canvas 交互,并能触发已对齐的 AI 推理接口 |
|
||||
| 魔法棒 SAM 触发 | 部分可用 | 切到 AI 页面;不是直接执行推理 |
|
||||
| 撤销/重做 | 真实可用 | 绑定 Zustand `maskHistory/maskFuture`,支持工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y |
|
||||
@@ -105,15 +106,16 @@
|
||||
| 帧缩略图 | 真实可用 | 使用 `frames[].url` |
|
||||
| 点击缩略图跳帧 | 真实可用 | 调用 `setCurrentFrame(idx)` |
|
||||
| 顶部 range 拖动 | 真实可用 | 改变当前帧 |
|
||||
| 具体时间显示 | 真实可用 | 根据项目 `parse_fps/original_fps` 显示当前时间和总时长,格式为 `mm:ss.cc` |
|
||||
| 播放/暂停 | 真实可用 | 当前代码按 `parse_fps/original_fps` 推进帧,最多 30fps |
|
||||
| 方向键切帧 | Mock / UI-only | Word 提到,但当前没有键盘监听 |
|
||||
| 方向键切帧 | 真实可用 | 全局监听左右方向键切到上一帧/下一帧;焦点在 input、textarea、select 或 contentEditable 内时不会拦截 |
|
||||
|
||||
## OntologyInspector 本体面板
|
||||
|
||||
| 元素 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| 模板选择 | 部分可用 | 读取全局 templates,可切换 activeTemplateId |
|
||||
| 分类树展示 | 真实可用 | 显示模板 classes 和本地 customClasses |
|
||||
| 分类树展示 / 换标签 | 真实可用 | 显示模板 classes 和本地 customClasses;点击分类会设为后续新 mask 的 activeClass,如果 Canvas 已选 mask,则同步更新已选 mask 的标签、颜色和 class 元数据 |
|
||||
| 添加自定义分类 | 部分可用 | 只存在组件本地状态,不保存到后端 |
|
||||
| 置信度条 | Mock / UI-only | 固定 `0.9412` |
|
||||
| 拓扑锚点数量 | Mock / UI-only | 固定 `12 节点` |
|
||||
@@ -124,8 +126,9 @@
|
||||
| 元素 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| 模型选择 SAM2/SAM3 | 真实可用 | 选择写入 Zustand,`predictMask()` 会把 `model` 传给后端 SAM registry |
|
||||
| 正向/反向点 | 部分可用 | 可在当前项目帧上加点,并可调用 AI 推理接口 |
|
||||
| 语义文本输入 | 部分可用 | 纯文本会以 `semantic` prompt 调用后端;选择 SAM 3 且独立 Python 3.12 环境、CUDA、官方包和 Hugging Face gated 权重访问均满足时走 SAM 3 文本语义推理,否则状态接口会标明不可用 |
|
||||
| 正向/反向点 | 真实可用 | 可在当前项目帧上加点并调用 AI 推理接口;SAM 2 框选后会携带原始框和累计正/反点细化同一个候选 mask;SAM 3 选择后会提示点交互需切回 SAM 2 |
|
||||
| SAM 3 框选 | 真实可用 | 工作区选择 SAM 3 后可使用框选工具;后端通过官方 `add_geometric_prompt()` 正框执行 SAM 3 几何提示推理 |
|
||||
| 语义文本输入 | 部分可用 | 纯文本会以 `semantic` prompt 调用后端;选择 SAM 3 且独立 Python 3.12 环境、CUDA、官方包和本地 checkpoint 均满足时走 SAM 3 文本语义推理,否则状态接口会标明不可用;空文本、失败和 0 mask 返回会显示前端反馈 |
|
||||
| 参数开关 | 真实可用 | `cropMode` 会随 `/api/ai/predict` 发送 `crop_to_prompt`,后端对点/框 prompt 裁剪推理区域并回映射 polygon;`autoDeleteBg` 会发送 `auto_filter_background` 和 `min_score`,后端过滤低分结果和覆盖负向点的结果 |
|
||||
| 执行高精度语义分割 | 部分可用 | 使用当前项目帧调用 `/api/ai/predict`;没有当前帧时按钮禁用 |
|
||||
| 上传替换底图 | Mock / UI-only | 按钮无事件 |
|
||||
@@ -150,6 +153,6 @@
|
||||
|
||||
## 总体结论
|
||||
|
||||
当前前端真实可用的主链路是:登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、拆帧、浏览帧、播放帧、工作区手工绘制、点/框 AI 推理、GT mask 导入、标注保存/回显、COCO 导出、PNG mask ZIP 导出、模板 CRUD。
|
||||
当前前端真实可用的主链路是:登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、拆帧、浏览帧、播放帧、工作区手工绘制、点/框 AI 推理、视频片段传播、GT mask 导入、标注保存/回显、COCO 导出、PNG mask ZIP 导出、模板 CRUD。
|
||||
|
||||
当前最主要的 Mock 或未打通链路是:polygon 插点/边编辑增强、真正的文本语义分割、骨架/HDBSCAN 级别的 mask 降维增强、任务历史筛选、项目更多菜单和若干检查面板指标。
|
||||
|
||||
@@ -32,12 +32,13 @@ Authorization: Bearer <token>
|
||||
| `deleteTemplate(id)` | `DELETE /api/templates/{id}` | 对齐 | 模板编辑页使用 |
|
||||
| `uploadMedia(file, projectId)` | `POST /api/media/upload` | 对齐 | multipart form-data |
|
||||
| `uploadDicomBatch(files, projectId)` | `POST /api/media/upload/dicom` | 对齐 | multipart form-data |
|
||||
| `parseMedia(projectId)` | `POST /api/media/parse?project_id=...` | 对齐 | 创建异步拆帧任务并返回 task |
|
||||
| `parseMedia(projectId, options?)` | `POST /api/media/parse?project_id=...` | 对齐 | 创建异步拆帧任务并返回 task;支持 `parse_fps`、`max_frames`、`target_width` |
|
||||
| `getTask(taskId)` | `GET /api/tasks/{task_id}` | 对齐 | 查询异步任务状态 |
|
||||
| `cancelTask(taskId)` | `POST /api/tasks/{task_id}/cancel` | 对齐 | 取消 queued/running 任务,后端写 cancelled 并尝试 revoke Celery |
|
||||
| `retryTask(taskId)` | `POST /api/tasks/{task_id}/retry` | 对齐 | 对 failed/cancelled 任务创建新的 queued 重试任务 |
|
||||
| `getProjectFrames(projectId)` | `GET /api/projects/{id}/frames` | 对齐 | 后端返回预签名 image_url |
|
||||
| `getProjectFrames(projectId)` | `GET /api/projects/{id}/frames` | 对齐 | 后端返回预签名 image_url,以及 `timestamp_ms`、`source_frame_number` |
|
||||
| `predictMask(payload)` | `POST /api/ai/predict` | 对齐 | 前端发送 `image_id/prompt_type/prompt_data/model`,并把后端 `polygons` 转为 `masks[].pathData` |
|
||||
| `propagateMasks(payload)` | `POST /api/ai/propagate` | 对齐 | 当前帧 seed mask 向视频片段传播,并保存后续帧标注 |
|
||||
| `getAiModelStatus(selectedModel?)` | `GET /api/ai/models/status` | 对齐 | 返回 GPU、SAM 2、SAM 3 的真实运行状态 |
|
||||
| `getProjectAnnotations(projectId, frameId?)` | `GET /api/ai/annotations` | 对齐 | 前端加载工作区时用于回显已保存标注 |
|
||||
| `saveAnnotation(payload)` | `POST /api/ai/annotate` | 对齐 | 工作区归档保存当前项目未保存 mask |
|
||||
@@ -70,12 +71,13 @@ Authorization: Bearer <token>
|
||||
| DELETE | `/api/templates/{template_id}` | 删除模板 |
|
||||
| POST | `/api/media/upload` | 上传视频/图片/DICOM 单文件 |
|
||||
| POST | `/api/media/upload/dicom` | 批量上传 DICOM |
|
||||
| POST | `/api/media/parse` | 创建 Celery 拆帧任务 |
|
||||
| POST | `/api/media/parse` | 创建 Celery 拆帧任务;query 支持 `project_id`、`source_type`、`parse_fps`、`max_frames`、`target_width` |
|
||||
| GET | `/api/tasks` | 查询后台任务列表 |
|
||||
| GET | `/api/tasks/{task_id}` | 查询单个后台任务 |
|
||||
| POST | `/api/tasks/{task_id}/cancel` | 取消后台任务 |
|
||||
| POST | `/api/tasks/{task_id}/retry` | 重试失败或取消的后台任务 |
|
||||
| POST | `/api/ai/predict` | SAM 2 / SAM 3 可选推理 |
|
||||
| POST | `/api/ai/propagate` | SAM 2 / SAM 3 视频片段传播并保存标注 |
|
||||
| GET | `/api/ai/models/status` | GPU 和 SAM 模型状态 |
|
||||
| POST | `/api/ai/auto` | 自动分割 |
|
||||
| POST | `/api/ai/annotate` | 保存 AI 标注 |
|
||||
@@ -110,6 +112,25 @@ Authorization: Bearer <token>
|
||||
}
|
||||
```
|
||||
|
||||
### 创建标准帧序列拆帧任务
|
||||
|
||||
```text
|
||||
POST /api/media/parse?project_id=1&parse_fps=15&max_frames=120&target_width=960
|
||||
```
|
||||
|
||||
任务 `payload` 会记录本次拆帧参数;完成后的 `result.frame_sequence` 返回 `original_fps`、`parse_fps`、`frame_count`、`duration_ms`、`target_width`、帧宽高和 MinIO object prefix。每条 `FrameOut` 包含:
|
||||
|
||||
```json
|
||||
{
|
||||
"frame_index": 0,
|
||||
"image_url": "http://...",
|
||||
"width": 960,
|
||||
"height": 540,
|
||||
"timestamp_ms": 0,
|
||||
"source_frame_number": 0
|
||||
}
|
||||
```
|
||||
|
||||
### 创建/更新模板
|
||||
|
||||
```json
|
||||
@@ -150,11 +171,14 @@ Authorization: Bearer <token>
|
||||
|
||||
- `point`
|
||||
- `box`
|
||||
- `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理;选择 `sam2` 时仍回退到 auto segmentation。SAM 3 真实可用性由 `/api/ai/models/status` 中的外部环境和 checkpoint access 状态决定。
|
||||
- `interactive`,用于 SAM 2 交互式细化,`prompt_data` 同时携带 `box`、累计 `points` 和 `labels`。
|
||||
- `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理;选择 `sam2` 时仍回退到 auto segmentation。SAM 3 真实可用性由 `/api/ai/models/status` 中的外部环境和本地 checkpoint 状态决定。
|
||||
|
||||
选择 `sam3` 且发送 `box` 时,前端仍传 normalized `[x1, y1, x2, y2]`,后端适配层会转换成官方几何 prompt 的 `[center_x, center_y, width, height]` 正框;当前 SAM 3 不接正/反点修正。
|
||||
|
||||
可选 `options` 字段:
|
||||
|
||||
- `crop_to_prompt`:对 point/box prompt 按锚点或框附近区域裁剪后推理,再把 polygon 回映射到原图坐标。
|
||||
- `crop_to_prompt`:对 point/box/interactive prompt 按锚点或框附近区域裁剪后推理,再把 polygon 回映射到原图坐标。
|
||||
- `auto_filter_background`:过滤低分结果,并移除包含负向点的 polygon。
|
||||
- `min_score`:配合 `auto_filter_background` 使用的最低置信度阈值。
|
||||
|
||||
@@ -183,6 +207,32 @@ Authorization: Bearer <token>
|
||||
}
|
||||
```
|
||||
|
||||
### 视频片段传播请求体
|
||||
|
||||
工作区“传播片段”调用:
|
||||
|
||||
```json
|
||||
{
|
||||
"project_id": 1,
|
||||
"frame_id": 123,
|
||||
"model": "sam2",
|
||||
"direction": "forward",
|
||||
"max_frames": 30,
|
||||
"include_source": false,
|
||||
"save_annotations": true,
|
||||
"seed": {
|
||||
"polygons": [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]],
|
||||
"bbox": [0.1, 0.1, 0.2, 0.2],
|
||||
"label": "胆囊",
|
||||
"color": "#ff0000",
|
||||
"class_metadata": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
|
||||
"template_id": 2
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`model=sam2` 使用 SAM 2 video predictor 的 mask seed 传播;`model=sam3` 使用独立 Python 3.12 helper 中的 SAM 3 video tracker,并以 seed bbox 作为初始提示。响应会返回已创建的 `annotations`,保存的 `mask_data.source` 为 `sam2_propagation` 或 `sam3_propagation`。
|
||||
|
||||
## 已完成的接口对齐
|
||||
|
||||
- `updateProject()` 已从 `PUT` 改为 `PATCH`。
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
剩余边界:
|
||||
|
||||
1. SAM 3 已完成独立 Python 3.12 环境安装脚本、外部 worker 桥接和状态检查;真实推理还需要 Hugging Face `facebook/sam3` gated 权重授权通过后执行 smoke test。
|
||||
1. SAM 3 已完成独立 Python 3.12 环境安装脚本、外部 worker 桥接、本地 `sam3权重/sam3.pt` checkpoint 状态检查、本地 checkpoint 加载参数接入、单帧文本/框提示和 video tracker API 入口;下一步需要基于真实业务帧验证语义召回质量和视频 tracker 稳定性。
|
||||
2. 标注删除/更新接口已打通基础能力;逐点几何编辑器已支持顶点拖动/删除、边中点插入和多 polygon 子区域选择编辑,复杂洞结构仍待增强。
|
||||
|
||||
## 阶段 2:打通标注保存(已完成基础闭环)
|
||||
@@ -127,6 +127,24 @@ Word 方案中的完整版本包含距离变换、骨架提取和聚类。当前
|
||||
1. 在前端预览重叠裁决结果。
|
||||
2. 对多帧多类导出增加颜色 palette PNG 或可视化 legend。
|
||||
|
||||
## 阶段 7.5:视频片段传播(已完成基础闭环)
|
||||
|
||||
当前工作区“传播片段”会使用当前选中 mask 或当前帧第一个 mask 作为 seed,默认向后传播 30 帧并把结果写入后端标注表。
|
||||
|
||||
已完成:
|
||||
|
||||
1. 前端 `propagateMasks()` 已接入 `POST /api/ai/propagate`。
|
||||
2. 工作区按钮会把 seed mask 的 normalized polygon、bbox、label、color 和 class 元数据传给后端。
|
||||
3. SAM 2 路径使用官方 `SAM2VideoPredictor.add_new_mask()` 和 `propagate_in_video()`。
|
||||
4. SAM 3 路径通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境中的官方 `build_sam3_video_predictor()`。
|
||||
5. 后端会跳过源帧,把传播结果保存到后续帧 `annotations`,并在完成后由前端刷新回显。
|
||||
|
||||
剩余建议:
|
||||
|
||||
1. 把传播任务改为异步任务,接入 Dashboard 和 WebSocket 进度。
|
||||
2. 前端增加传播方向、帧数和覆盖已有标注策略设置。
|
||||
3. 用真实长视频分别做 SAM 2 / SAM 3 tracker smoke test 和质量评估。
|
||||
|
||||
## 阶段 8:清理 UI 文案与 Mock
|
||||
|
||||
建议统一这些文案和真实能力:
|
||||
|
||||
@@ -28,7 +28,11 @@
|
||||
- 未提供项目 ID 上传时,后端自动创建项目。
|
||||
- 提供项目 ID 上传时,后端把上传对象关联到该项目。
|
||||
- 拆帧接口根据项目 `source_type` 处理视频或 DICOM。
|
||||
- 拆帧接口支持 `parse_fps`、`max_frames` 和 `target_width` 参数,用于生成可被 SAM 2 / SAM 3 视频处理复用的标准帧序列。
|
||||
- 视频帧使用连续 `frame_%06d.jpg` 命名,默认从 `frame_000000.jpg` 开始,并按 `target_width` 缩放。
|
||||
- 拆帧完成后写入 `frames` 记录,并把项目状态设为 `ready`。
|
||||
- 每条帧记录包含 `frame_index`、`image_url`、`width`、`height`、`timestamp_ms` 和 `source_frame_number`。
|
||||
- 任务完成结果包含 `frame_sequence` 元数据:`original_fps`、`parse_fps`、`frame_count`、`duration_ms`、`target_width`、帧宽高和对象存储前缀。
|
||||
- 拆帧接口会创建 `processing_tasks` 记录并投递 Celery worker。
|
||||
- 前端可通过 `GET /api/tasks/{task_id}` 查询任务状态。
|
||||
- 后端支持 `POST /api/tasks/{task_id}/cancel` 取消 queued/running 任务,写入 `cancelled` 状态并尝试 revoke Celery。
|
||||
@@ -41,8 +45,9 @@
|
||||
- 若项目有媒体但无帧,工作区会尝试触发拆帧后重新加载。
|
||||
- Canvas 显示当前帧图片。
|
||||
- Canvas 支持滚轮缩放、移动工具拖拽、鼠标坐标显示。
|
||||
- 时间轴支持缩略图点击切帧、range 拖动切帧、播放/暂停顺序推进帧。
|
||||
- 时间轴支持缩略图点击切帧、range 拖动切帧、键盘左右方向键切帧、播放/暂停顺序推进帧。
|
||||
- 播放帧率使用项目 `parse_fps` 或 `original_fps`,限制在 1 到 30 FPS。
|
||||
- 时间轴显示当前帧时间和总时长,时间基准使用项目 `parse_fps` 或 `original_fps`,格式为 `mm:ss.cc`。
|
||||
|
||||
## R5 工具栏
|
||||
|
||||
@@ -50,12 +55,16 @@
|
||||
- 正向点、反向点、框选工具会影响 Canvas 交互。
|
||||
- 魔法棒按钮切换到 AI 页面。
|
||||
- 多边形、矩形、圆、点、线工具会在 Canvas 上生成可保存的 polygon mask。
|
||||
- 多边形通过点击取点并按 Enter 完成;矩形、圆、线通过拖拽生成;点工具生成小点区域。
|
||||
- 多边形通过点击取点并按 Enter 完成,也支持三点后点击首节点闭合;矩形、圆、线通过拖拽生成;点工具生成小点区域。
|
||||
- 绘制工具点击已有 mask 时应继续执行当前绘制动作,不应被 mask 选择逻辑吞掉。
|
||||
- Canvas 支持点击 mask 进入 polygon 顶点编辑态;拖动顶点会更新 mask 几何并把已保存 mask 标记为 dirty。
|
||||
- 顶点编辑态下选中顶点后可用 Delete/Backspace 删除顶点,但不会让 polygon 少于三点。
|
||||
- 选中整个 mask 且未选中具体顶点时,Delete/Backspace 删除该 mask;已保存 mask 同步调用后端删除接口。
|
||||
- 撤销、重做绑定全局 `maskHistory/maskFuture`,支持工具栏按钮、AI 页按钮和 Canvas 快捷键。
|
||||
- 区域合并工具支持多选当前帧 mask,并使用 polygon union 生成合并后的主 mask。
|
||||
- 区域去除工具支持多选当前帧 mask,并从第一个选中的主 mask 中扣除后续选中 mask。
|
||||
- 区域合并/去除模式显示已选数量,并隐藏 polygon 编辑手柄以避免手柄抢占多选点击。
|
||||
- 区域去除结果包含内洞时,前端保留 hole ring 并用 even-odd 规则渲染。
|
||||
|
||||
## R6 AI 推理
|
||||
|
||||
@@ -65,7 +74,14 @@
|
||||
- 前端发送后端契约:`image_id`、`prompt_type`、`prompt_data`、`model`。
|
||||
- 点提示传 `{ points, labels }`,正向点 label 为 1,反向点 label 为 0。
|
||||
- 框选提示传归一化 `[x1, y1, x2, y2]`。
|
||||
- 语义文本提示传 `semantic`;选择 `sam3` 且独立 Python 3.12 环境、CUDA、官方包和 Hugging Face gated 权重访问均满足时走 SAM 3 文本语义推理,选择 `sam2` 时回退到自动分割。
|
||||
- 工作区 SAM 2 框选会建立一个候选 mask;后续正向点/反向点会携带原始框和累计点,以 `interactive` prompt 细化并替换同一个候选 mask。
|
||||
- 语义文本提示传 `semantic`;选择 `sam3` 且独立 Python 3.12 环境、CUDA、官方包和本地 checkpoint 均满足时走 SAM 3 文本语义推理,选择 `sam2` 时回退到自动分割。
|
||||
- SAM 3 支持工作区框选提示;后端把 normalized `[x1, y1, x2, y2]` 转成官方 `add_geometric_prompt()` 需要的 `[center_x, center_y, width, height]` 正框。
|
||||
- 当前 SAM 3 前端路径不支持正/反点修正;在工作区用 SAM 3 进行点交互时,前端会提示切回 SAM 2。
|
||||
- 工作区“传播片段”会把当前选中区域或当前帧第一个区域作为 seed,调用 `POST /api/ai/propagate`,默认从当前帧向后传播 30 帧并保存结果标注。
|
||||
- `POST /api/ai/propagate` 支持 `model=sam2` 或 `model=sam3`;SAM 2 使用官方 `SAM2VideoPredictor.add_new_mask()` 和 `propagate_in_video()`,SAM 3 通过独立 Python 3.12 helper 调用官方 `build_sam3_video_predictor()` video tracker。
|
||||
- 传播结果会写入后续帧 `annotations`,`mask_data.source` 分别标记为 `sam2_propagation` 或 `sam3_propagation`,并保留 label、color 和 class 元数据。
|
||||
- AI 页面会对 SAM 3 空文本、推理失败和返回 0 个 mask 的情况显示明确反馈。
|
||||
- AI 参数支持 `crop_to_prompt`、`auto_filter_background` 和 `min_score`;点/框 prompt 可以裁剪局部区域推理并回映射结果,背景过滤会移除低分结果和包含负向点的 polygon。
|
||||
- 后端返回 `polygons` 和 `scores`。
|
||||
- 前端把后端 `polygons` 转成 Konva `pathData`、`segmentation`、`bbox`、`area`。
|
||||
@@ -98,6 +114,7 @@
|
||||
- 工作区右侧可以选择模板。
|
||||
- 面板显示模板分类和组件本地自定义分类。
|
||||
- 用户可以选择具体分类;新 AI mask 会记录 `classId`、`className`、`classZIndex`,并在保存时写入 `mask_data.class`。
|
||||
- 如果 Canvas 当前已经选中一个或多个 mask,点击语义分类树会把这些 mask 的 `label`、`color` 和 class 元数据改为该分类;已保存 mask 会进入 `dirty` 状态,归档保存时更新后端。
|
||||
- 添加自定义分类只存在组件本地状态,不保存到后端。
|
||||
- 置信度、拓扑锚点和重新提取骨架按钮当前为展示/占位。
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
| 模块 | 文件 | 设计职责 |
|
||||
|------|------|----------|
|
||||
| 应用入口 | `src/App.tsx` | 根据登录状态和 `activeModule` 切换页面 |
|
||||
| 全局状态 | `src/store/useStore.ts` | Zustand store,保存项目、帧、模板、mask、工具状态和 mask 撤销/重做历史栈 |
|
||||
| 全局状态 | `src/store/useStore.ts` | Zustand store,保存项目、帧、模板、mask、当前选中 mask ids、工具状态和 mask 撤销/重做历史栈 |
|
||||
| API 封装 | `src/lib/api.ts` | Axios 客户端、字段映射、AI 响应转换 |
|
||||
| 配置 | `src/lib/config.ts` | 推导 API 和 WebSocket 地址 |
|
||||
| WebSocket | `src/lib/websocket.ts` | 进度流连接、订阅和重连 |
|
||||
@@ -30,7 +30,7 @@
|
||||
| 工作区 | `src/components/VideoWorkspace.tsx` | 加载帧和模板,组织工具栏、Canvas、本体面板、时间轴 |
|
||||
| Canvas | `src/components/CanvasArea.tsx` | 显示帧、缩放平移、点/框提示、渲染 mask |
|
||||
| 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具、跳转 AI 页面、触发 mask 撤销/重做 |
|
||||
| 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航和播放 |
|
||||
| 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航、左右方向键切帧、播放和当前/总时长显示 |
|
||||
| 本体面板 | `src/components/OntologyInspector.tsx` | 模板选择、分类树、本地自定义分类 |
|
||||
| AI 页面 | `src/components/AISegmentation.tsx` | 独立 AI 推理视图,使用当前项目帧 |
|
||||
| 模板库 | `src/components/TemplateRegistry.tsx` | 模板 CRUD、分类编辑、导入、排序 |
|
||||
@@ -48,10 +48,10 @@
|
||||
| Projects | `backend/routers/projects.py` | 项目与帧 CRUD |
|
||||
| Templates | `backend/routers/templates.py` | 模板 CRUD 和 mapping_rules 打包/解包 |
|
||||
| Media | `backend/routers/media.py` | 上传媒体和拆帧 |
|
||||
| AI | `backend/routers/ai.py` | SAM 2 / SAM 3 可选推理、模型状态和标注保存 |
|
||||
| AI | `backend/routers/ai.py` | SAM 2 / SAM 3 可选推理、视频传播、模型状态和标注保存 |
|
||||
| Export | `backend/routers/export.py` | COCO 和 PNG mask 导出 |
|
||||
| SAM 2 | `backend/services/sam2_engine.py` | SAM 2 懒加载、状态检测和点/框/自动推理 |
|
||||
| SAM 3 | `backend/services/sam3_engine.py`, `backend/services/sam3_external_worker.py`, `backend/setup_sam3_env.sh` | SAM 3 状态检测、独立 Python 3.12 环境桥接和文本语义推理适配 |
|
||||
| SAM 2 | `backend/services/sam2_engine.py` | SAM 2 懒加载、状态检测、点/框/自动推理和视频 mask 传播 |
|
||||
| SAM 3 | `backend/services/sam3_engine.py`, `backend/services/sam3_external_worker.py`, `backend/setup_sam3_env.sh` | SAM 3 状态检测、独立 Python 3.12 环境桥接、本地 checkpoint 加载、文本语义推理、正框几何提示和 video tracker 适配 |
|
||||
| SAM Registry | `backend/services/sam_registry.py` | 模型选择、GPU 状态和推理分发 |
|
||||
|
||||
## 状态模型
|
||||
@@ -59,9 +59,10 @@
|
||||
前端 store 的核心对象:
|
||||
|
||||
- `Project`:项目基本信息、状态、帧数、fps、媒体路径。
|
||||
- `Frame`:帧 ID、项目 ID、索引、图片 URL、宽高。
|
||||
- `Frame`:帧 ID、项目 ID、索引、图片 URL、宽高、序列时间戳和原视频源帧号。
|
||||
- `Template` / `TemplateClass`:模板和分类定义。
|
||||
- `Mask`:前端渲染用 mask,包含 `pathData`、`segmentation`、`bbox`、`area`。
|
||||
- `selectedMaskIds`:Canvas 当前选中的 mask id 列表,供右侧本体面板对已选区域直接换标签。
|
||||
- `maskHistory` / `maskFuture`:mask 编辑历史栈,用于撤销和重做。
|
||||
- `activeModule`:当前页面。
|
||||
- `activeTool`:当前工具。
|
||||
@@ -79,9 +80,11 @@
|
||||
|
||||
1. `ProjectLibrary` 创建项目。
|
||||
2. 上传视频或 DICOM 到 `/api/media/upload` 或 `/api/media/upload/dicom`。
|
||||
3. 调用 `/api/media/parse` 创建异步拆帧任务。
|
||||
4. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,持续更新 `processing_tasks`,并发布 Redis `seg:progress`。
|
||||
5. 刷新项目列表。
|
||||
3. 调用 `/api/media/parse` 创建异步拆帧任务;可通过 `parse_fps`、`max_frames` 和 `target_width` 指定标准帧序列参数。
|
||||
4. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,视频帧按 `frame_%06d.jpg` 从 `frame_000000.jpg` 连续命名,并按目标宽度缩放。
|
||||
5. worker 写入 `frames.timestamp_ms` 和 `frames.source_frame_number`,并在任务 `result.frame_sequence` 中记录 FPS、帧数、时长、尺寸和对象存储前缀。
|
||||
6. worker 持续更新 `processing_tasks`,并发布 Redis `seg:progress`。
|
||||
7. 刷新项目列表。
|
||||
|
||||
### 任务控制
|
||||
|
||||
@@ -95,29 +98,43 @@
|
||||
|
||||
1. `VideoWorkspace` 根据 `currentProject.id` 调用 `getProjectFrames()`。
|
||||
2. 若无帧但项目有 `video_path`,触发 `parseMedia()`,通过 `getTask()` 轮询任务完成后重新取帧。
|
||||
3. 帧数据映射为 store `Frame[]`。
|
||||
3. 帧数据映射为 store `Frame[]`,包含 `timestampMs` 和 `sourceFrameNumber`,供时间轴和后续视频传播使用。
|
||||
4. 当前帧传入 `CanvasArea`。
|
||||
|
||||
### AI 点/框推理
|
||||
|
||||
1. 用户在 Canvas 选择正向点、反向点或框选。
|
||||
2. `CanvasArea` 读取当前帧 ID 和宽高。
|
||||
3. `predictMask()` 归一化坐标并携带当前 `model` 调用 `/api/ai/predict`。
|
||||
4. 后端加载帧图片并通过 SAM registry 分发到 SAM 2 或 SAM 3。
|
||||
5. 前端把 `polygons` 转为 mask,写入 store。
|
||||
6. Canvas 按当前帧过滤并渲染 mask。
|
||||
7. 新 mask 会带上当前选择的模板分类元数据,包括 `classId`、`className`、`classZIndex` 和保存状态 `draft`。
|
||||
8. 用户点击“结构化归档保存”后,前端将像素 `segmentation` 转成 normalized `mask_data.polygons`;未保存 mask 调用 `POST /api/ai/annotate`,dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`。
|
||||
9. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。
|
||||
10. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。
|
||||
3. SAM 2 框选会创建一个候选 mask,并记录原始框;后续正向点/反向点会累计到同一候选上。
|
||||
4. `predictMask()` 归一化坐标并携带当前 `model` 调用 `/api/ai/predict`;同时有框和点时发送 `interactive` prompt。
|
||||
5. 后端加载帧图片并通过 SAM registry 分发到 SAM 2 或 SAM 3。
|
||||
6. 前端把 `polygons` 转为 mask;交互式细化会替换同一个候选 mask,而不是新增多个 mask。
|
||||
7. Canvas 按当前帧过滤并渲染 mask。
|
||||
8. 新 mask 会带上当前选择的模板分类元数据,包括 `classId`、`className`、`classZIndex` 和保存状态 `draft`。
|
||||
9. 用户点击“结构化归档保存”后,前端将像素 `segmentation` 转成 normalized `mask_data.polygons`;未保存 mask 调用 `POST /api/ai/annotate`,dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`。
|
||||
10. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。
|
||||
11. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。
|
||||
|
||||
### 视频片段传播
|
||||
|
||||
1. 用户在工作区选中一个当前帧 mask;如果未显式选中,前端使用当前帧第一个 mask。
|
||||
2. `VideoWorkspace` 用 `buildAnnotationPayload()` 把 seed mask 转成 normalized polygon、bbox、label、color 和 class 元数据。
|
||||
3. 前端调用 `POST /api/ai/propagate`,默认 `direction=forward`、`max_frames=30`、`include_source=false`。
|
||||
4. 后端按项目帧序列截取片段,下载对应帧到临时 `frame_%06d.jpg` 目录,保持当前帧在片段中的相对索引。
|
||||
5. `model=sam2` 时,`sam2_engine` 使用 `SAM2VideoPredictor.add_new_mask()` 注入 seed mask,再用 `propagate_in_video()` 传播。
|
||||
6. `model=sam3` 时,`sam3_engine` 将请求交给 `sam3_external_worker.py`,由独立 Python 3.12 环境调用官方 `build_sam3_video_predictor()`,以 seed bbox 走 video tracker。
|
||||
7. 后端把传播返回的 normalized polygon 保存为后续帧 `Annotation`,跳过源帧,`mask_data.source` 记录模型传播来源。
|
||||
8. 前端传播完成后重新调用 `GET /api/ai/annotations` 并回显新标注。
|
||||
|
||||
### 手工绘制与历史栈
|
||||
|
||||
1. 用户在 `ToolsPalette` 选择多边形、矩形、圆、点或线工具。
|
||||
2. `CanvasArea` 将交互坐标转换成像素 polygon。
|
||||
3. 新 mask 写入 `pathData`、像素 `segmentation`、`bbox`、`area` 和当前模板分类元数据。
|
||||
4. `addMask()`、`setMasks()`、`updateMask()`、`clearMasks()` 会维护 `maskHistory/maskFuture`。
|
||||
5. 工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y 调用 `undoMasks()` / `redoMasks()`。
|
||||
3. 多边形工具逐次记录节点,三点后点击首节点或按 Enter 时生成闭合 polygon。
|
||||
4. mask path 只在 `move`、`area_merge` 和 `area_remove` 工具下拦截点击;绘制和 AI prompt 工具点击已有 mask 时继续冒泡给 Stage。
|
||||
5. 新 mask 写入 `pathData`、像素 `segmentation`、`bbox`、`area` 和当前模板分类元数据。
|
||||
6. `addMask()`、`setMasks()`、`updateMask()`、`clearMasks()` 会维护 `maskHistory/maskFuture`。
|
||||
7. 工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y 调用 `undoMasks()` / `redoMasks()`。
|
||||
|
||||
### Polygon 逐点编辑
|
||||
|
||||
@@ -126,14 +143,16 @@
|
||||
3. 如果 mask 已有 `annotationId`,编辑会把 `saveStatus` 标成 `dirty` 且 `saved=false`。
|
||||
4. 归档保存时复用现有 `PATCH /api/ai/annotations/{annotation_id}` 链路,把更新后的 normalized polygon 写回后端。
|
||||
5. 选中顶点后 Delete/Backspace 可删除顶点;前端保持 polygon 至少三点。
|
||||
6. 未选中具体顶点但选中了 mask 时,Delete/Backspace 从前端 store 删除该 mask;如果包含 `annotationId`,通过工作区回调调用后端删除接口。
|
||||
|
||||
### 区域合并与去除
|
||||
|
||||
1. 用户选择 `area_merge` 或 `area_remove` 后,点击多个当前帧 mask 组成选择集。
|
||||
2. `CanvasArea` 把 `Mask.segmentation` 转为 `polygon-clipping` 的 MultiPolygon。
|
||||
3. `area_merge` 使用 union,更新第一个选中的主 mask,并从前端 store 移除后续被合并 mask;如果被移除 mask 已保存,会调用工作区传入的删除回调删除后端标注。
|
||||
4. `area_remove` 使用 difference,从第一个选中的主 mask 中扣除后续选中 mask,扣除对象本身保留。
|
||||
5. 结果会重算 `pathData`、`segmentation`、`bbox`、`area`,已保存主 mask 会进入 dirty 状态并复用归档 PATCH 链路。
|
||||
2. 合并/去除模式隐藏 polygon 顶点和边中点编辑手柄,并在右下角显示已选数量;少于两个 mask 时操作按钮禁用。
|
||||
3. `CanvasArea` 把 `Mask.segmentation` 转为 `polygon-clipping` 的 MultiPolygon。
|
||||
4. `area_merge` 使用 union,更新第一个选中的主 mask,并从前端 store 移除后续被合并 mask;如果被移除 mask 已保存,会调用工作区传入的删除回调删除后端标注。
|
||||
5. `area_remove` 使用 difference,从第一个选中的主 mask 中扣除后续选中 mask,扣除对象本身保留;如果 difference 产生内洞,`segmentation` 保留外圈和 hole ring,渲染时使用 even-odd fill。
|
||||
6. 结果会重算 `pathData`、`segmentation`、`bbox`、`area`,已保存主 mask 会进入 dirty 状态并复用归档 PATCH 链路;带洞结果的面积按外圈减内洞计算。
|
||||
|
||||
### GT Mask 导入
|
||||
|
||||
@@ -153,7 +172,10 @@
|
||||
3. 保存时调用 `createTemplate()` 或 `updateTemplate()`。
|
||||
4. 后端把 `classes`、`rules` 打包进 `mapping_rules`。
|
||||
5. 返回时再解包给前端。
|
||||
6. `OntologyInspector` 可以选择具体分类;选择结果进入全局 store,供 `CanvasArea` 和 `AISegmentation` 新建/更新 mask 时使用。
|
||||
6. `CanvasArea` 把当前选中的 mask id 同步到全局 `selectedMaskIds`;切换工具、切换帧或卸载 Canvas 时会清空选择。
|
||||
7. `OntologyInspector` 可以选择具体分类;选择结果进入全局 store,供 `CanvasArea` 和 `AISegmentation` 新建/更新 mask 时使用。
|
||||
8. 如果 `selectedMaskIds` 中存在当前 store 的 mask,点击分类时会立即更新这些 mask 的 `templateId`、`classId`、`className`、`classZIndex`、`label` 和 `color`。
|
||||
9. 已保存 mask 被重新分类后进入 `dirty` 且 `saved=false`,继续复用工作区归档保存的 PATCH 链路。
|
||||
|
||||
### 导出
|
||||
|
||||
@@ -173,13 +195,20 @@
|
||||
- `cancelTask()` 使用 `POST /api/tasks/{taskId}/cancel`。
|
||||
- `retryTask()` 使用 `POST /api/tasks/{taskId}/retry`。
|
||||
- `predictMask()` 使用 `POST /api/ai/predict`,请求体为 `image_id`、`prompt_type`、`prompt_data`、`model`。
|
||||
- `propagateMasks()` 使用 `POST /api/ai/propagate`,请求体为 `project_id`、`frame_id`、`model`、`seed`、`direction`、`max_frames`。
|
||||
- `saveAnnotation()` 使用 `POST /api/ai/annotate`。
|
||||
- `importGtMask()` 使用 `POST /api/ai/import-gt-mask` multipart form-data。
|
||||
- `getProjectAnnotations()` 使用 `GET /api/ai/annotations`。
|
||||
- `updateAnnotation()` 使用 `PATCH /api/ai/annotations/{annotationId}`。
|
||||
- `deleteAnnotation()` 使用 `DELETE /api/ai/annotations/{annotationId}`。
|
||||
- 后端 `/api/ai/predict` 支持 point、box、semantic 三种 prompt_type,并通过 `model` 选择 SAM 2 或 SAM 3。
|
||||
- 后端 `/api/ai/predict` 支持可选 `options`:`crop_to_prompt` 会对 point/box prompt 做局部裁剪推理并回映射 polygon,`auto_filter_background` 会按 `min_score` 和负向点过滤结果。
|
||||
- `parseMedia()` 使用 `POST /api/media/parse?project_id=...`,可选 `parse_fps`、`max_frames`、`target_width`,用于生成标准帧序列。
|
||||
- `getProjectFrames()` 返回帧图像 URL、宽高、`timestamp_ms` 和 `source_frame_number`。
|
||||
- 后端 `/api/ai/predict` 支持 point、box、interactive、semantic 四种 prompt_type,并通过 `model` 选择 SAM 2 或 SAM 3。
|
||||
- 当前 SAM 3 暴露 semantic 文本语义推理和 box 几何提示;工作区 Canvas 的点交互会在选择 SAM 3 时显示提示,不再静默失败。
|
||||
- SAM 3 box prompt 复用后端 `/api/ai/predict` 的 `box` prompt_type,输入仍是 normalized `[x1, y1, x2, y2]`,引擎适配层会转换为官方 `add_geometric_prompt()` 使用的 `[center_x, center_y, width, height]` 正框。
|
||||
- AI 页面选择 SAM 3 时优先发送文本 semantic prompt,不会把正/反点误发送为 SAM 3 point prompt;空文本、后端错误和空结果都会显示反馈消息。
|
||||
- 后端 `/api/ai/predict` 支持可选 `options`:`crop_to_prompt` 会对 point/box/interactive prompt 做局部裁剪推理并回映射 polygon,`auto_filter_background` 会按 `min_score` 和负向点过滤结果。
|
||||
- 后端 `/api/ai/propagate` 支持 SAM 2 mask seed 视频传播和 SAM 3 external video tracker;当前前端默认向后传播 30 帧并保存结果标注。
|
||||
- 后端 `/api/ai/models/status` 返回 GPU、SAM 2、SAM 3 的真实运行状态;SAM 3 状态包含外部 Python 环境与 checkpoint access 的可用性。
|
||||
- point prompt 支持旧数组形式和 `{ points, labels }` 对象形式。
|
||||
|
||||
@@ -198,7 +227,7 @@
|
||||
以下能力属于当前冻结版本的占位或半可用功能:
|
||||
|
||||
- Dashboard 初始快照来自 `GET /api/dashboard/overview`;解析队列由 `processing_tasks` queued/running/failed/cancelled 任务生成。
|
||||
- 已保存标注支持通过“应用分类”、polygon 顶点拖动/删除、边中点插入、多 polygon 子区域编辑和区域合并/去除进入 dirty 状态并归档更新;复杂洞结构编辑尚未实现。
|
||||
- SAM 3 文本语义分割取决于官方依赖、GPU 运行环境和 Hugging Face gated 权重授权;状态接口会暴露真实可用性,未授权时 `available=false`。
|
||||
- 已保存标注支持通过“应用分类”、polygon 顶点拖动/删除、边中点插入、多 polygon 子区域编辑和区域合并/去除进入 dirty 状态并归档更新;选中整块 mask 可用 Delete/Backspace 删除并同步后端;复杂洞结构编辑尚未实现。
|
||||
- SAM 3 文本语义分割取决于官方依赖、GPU 运行环境和本地 checkpoint;状态接口会暴露真实可用性,运行时缺失时 `available=false`。
|
||||
- 自定义分类只存在本地组件状态。
|
||||
- GT mask 导入已完成多类别像素值拆分、contour、distance transform seed point 和前端 seed point 拖拽编辑;骨架提取、HDBSCAN 聚类和模板自动映射尚未实现。
|
||||
|
||||
@@ -16,18 +16,51 @@
|
||||
|------|----------|--------|
|
||||
| R1 登录与会话 | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | 成功登录、失败提示、后端 401 |
|
||||
| R2 项目管理 | `src/lib/api.test.ts`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、后端 CRUD、帧列表 |
|
||||
| R3 媒体上传与拆帧 | `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 扩展名校验、自动建项目、关联项目、创建异步任务、worker 注册帧、取消任务、重试任务、取消后 worker 停止 |
|
||||
| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧触发解析、切帧、播放 |
|
||||
| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 工具切换、AI 跳转、手工 mask 绘制、polygon 顶点拖动/删除、区域合并/去除、撤销/重做历史栈 |
|
||||
| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam3_engine.py` | 点/框/semantic 契约、模型选择、GPU/SAM 状态、SAM 3 外部 worker 桥接、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry |
|
||||
| R3 媒体上传与拆帧 | `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 扩展名校验、自动建项目、关联项目、创建异步任务、标准帧序列参数、帧时间戳/源帧号、任务序列元数据、worker 注册帧、取消任务、重试任务、取消后 worker 停止 |
|
||||
| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧触发解析、缩略图/range/左右方向键切帧、播放、按项目 FPS 显示当前/总时长 |
|
||||
| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 工具切换、AI 跳转、矩形/圆/线/点/多边形手工 mask 绘制、点工具在已有 mask 上落点、多边形 Enter/首节点闭合、polygon 顶点拖动/删除、整块 mask 删除、区域合并/去除、内含去除 hole 渲染、合并模式隐藏编辑手柄、撤销/重做历史栈 |
|
||||
| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam3_engine.py` | 点/框/interactive/semantic 契约、SAM 2 框选后正负点细化同一候选 mask、SAM 2 视频传播、SAM 3 语义文本前端执行路径、SAM 3 工作区框选、SAM 3 video tracker 外部桥接、SAM 3 点交互不支持提示、空文本/空结果反馈、模型选择、GPU/SAM 状态、SAM 3 外部 worker 桥接、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry |
|
||||
| R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、加载回显、更新 dirty 标注、清空删除已保存标注、GT mask 多类别导入、seed point 回显/归一化、项目不存在、帧不存在 |
|
||||
| R8 模板库 | `src/lib/api.test.ts`, `backend/tests/test_templates.py` | mapping_rules 解包/打包、模板 CRUD |
|
||||
| R9 本体检查面板 | `src/components/OntologyInspector.test.tsx` | 模板选择、分类展示、具体分类选择、自定义分类本地添加 |
|
||||
| R8 模板库 | `src/components/TemplateRegistry.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_templates.py` | 前端模板加载/新建/编辑/删除、JSON 分类导入、mapping_rules 解包/打包、后端模板 CRUD |
|
||||
| R9 本体检查面板 | `src/components/OntologyInspector.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 模板选择、分类展示、具体分类选择、Canvas 选区同步、点击分类给已选 mask 换标签、自定义分类本地添加 |
|
||||
| R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py`, `backend/tests/test_tasks.py` | 后端概览接口、任务表驱动队列、任务取消/重试/详情、cancelled 事件、Redis 进度事件 payload/发布、地址推导、消息订阅、队列更新、heartbeat |
|
||||
| R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO/PNG 按钮下载、导出前自动保存、导出路径、JSON 结构、mask ZIP、zIndex 语义融合 |
|
||||
| R12 配置 | `src/lib/config.test.ts` | env 优先、hostname 推导、WS 推导 |
|
||||
| R13 文档与测试 | `doc/09-test-plan.md` | 测试覆盖矩阵 |
|
||||
|
||||
## 逐功能点追踪
|
||||
|
||||
| 需求 | 功能点 | 对应测试 | 当前状态 |
|
||||
|------|--------|----------|----------|
|
||||
| R1 | 登录页、默认开发凭证、token 写入、失败提示、后端 401 | `Login.test.tsx`, `test_auth.py` | 已覆盖 |
|
||||
| R2 | 项目列表/创建/选择、视频导入、DICOM 导入、后端项目和帧 CRUD | `ProjectLibrary.test.tsx`, `api.test.ts`, `test_projects.py` | 已覆盖 |
|
||||
| R3 | 文件类型校验、自动/指定项目上传、视频/DICOM 拆帧任务、`parse_fps/max_frames/target_width`、标准帧序列 metadata、任务查询、取消、重试、worker 取消停止 | `test_media.py`, `test_tasks.py` | 已覆盖 |
|
||||
| R4 | 工作区加载帧、无帧自动解析、Canvas 底图、缩略图/range/左右方向键切帧、播放、按 FPS 显示时间 | `VideoWorkspace.test.tsx`, `FrameTimeline.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 |
|
||||
| R5 | 工具切换、AI 跳转、矩形/圆/线/点/多边形绘制、已有 mask 上继续绘制 | `ToolsPalette.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 |
|
||||
| R5 | 顶点编辑、顶点删除、整块删除、撤销/重做、区域合并、区域去除、hole even-odd 渲染 | `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 |
|
||||
| R6 | SAM 2 点/框/interactive、SAM 2 视频传播、SAM 3 semantic、SAM 3 box、SAM 3 video tracker、SAM 3 不支持点交互时的前端反馈、模型选择、GPU/模型状态、参数 options、polygons 转 mask | `api.test.ts`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx`, `VideoWorkspace.test.tsx`, `ModelStatusBadge.test.tsx`, `test_ai.py`, `test_sam3_engine.py` | 已覆盖 |
|
||||
| R7 | 保存、查询、更新、删除标注、工作区回显、清空已保存标注、GT mask 导入和 seed point 回写 | `VideoWorkspace.test.tsx`, `CanvasArea.test.tsx`, `api.test.ts`, `test_ai.py` | 已覆盖 |
|
||||
| R8 | 模板加载、新建、编辑、删除、JSON 分类导入、mapping_rules 映射、后端 CRUD | `TemplateRegistry.test.tsx`, `api.test.ts`, `test_templates.py` | 已覆盖 |
|
||||
| R9 | 模板选择、分类展示、分类选择、已选 mask 换标签、自定义本地分类、占位状态 | `OntologyInspector.test.tsx`, `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 |
|
||||
| R10 | Dashboard 概览、队列、活动日志、WebSocket progress/complete/error/status/cancelled、取消/重试/详情、heartbeat | `Dashboard.test.tsx`, `websocket.test.ts`, `test_dashboard.py`, `test_main.py`, `test_progress_events.py`, `test_tasks.py` | 已覆盖 |
|
||||
| R11 | COCO/PNG ZIP 导出、导出前保存、路径和 JSON/ZIP 结构、zIndex 融合 | `VideoWorkspace.test.tsx`, `api.test.ts`, `test_export.py` | 已覆盖 |
|
||||
| R12 | API/WS 地址 env 优先和 hostname 推导 | `config.test.ts` | 已覆盖 |
|
||||
| R13 | 文档测试矩阵与功能点追踪 | `doc/09-test-plan.md` | 已覆盖 |
|
||||
|
||||
## 本轮补齐记录
|
||||
|
||||
- R5:补充 `CanvasArea.test.tsx` 中圆形和线段手工绘制测试,明确验证 metadata、segmentation、bbox/area 和草稿状态。
|
||||
- R6:补充 `AISegmentation.test.tsx` 中 SAM 3 semantic 文本推理测试,验证前端传参和返回 mask 绑定当前语义类别。
|
||||
- R6:补充 SAM 3 空文本、空结果和工作区点交互不支持提示测试,避免前端静默失败。
|
||||
- R6:补充 SAM 3 工作区 box prompt 测试和外部 worker box prompt 测试,验证官方 `add_geometric_prompt()` 正框链路。
|
||||
- R6:补充 `POST /api/ai/propagate` 后端测试,验证 seed mask 传播结果会保存为后续帧标注并保留 class 元数据。
|
||||
- R6:补充 `propagateMasks()` API 封装和 `VideoWorkspace` 传播按钮测试,验证当前选中区域会发送到后端视频传播接口。
|
||||
- R6:补充 SAM 3 external video tracker 请求测试,验证主后端会把帧目录、源帧索引、seed bbox 和方向传给独立 Python helper。
|
||||
- R3:补充 `parseMedia()` 查询参数和后端拆帧任务 payload 测试,验证 `parse_fps`、`max_frames`、`target_width` 会进入任务。
|
||||
- R3:补充 worker 注册标准帧序列测试,验证帧 `timestamp_ms`、`source_frame_number` 和 `result.frame_sequence` 元数据。
|
||||
- R8:补充 `TemplateRegistry.test.tsx` 中模板编辑、删除测试,验证前端调用真实 API 封装并更新全局 store。
|
||||
- R9:补充 Canvas 选中 mask id 全局同步、本体树点击分类给已选 mask 换标签的测试,验证已保存 mask 会进入 dirty 状态。
|
||||
|
||||
## 运行命令
|
||||
|
||||
```bash
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react';
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
@@ -62,4 +62,112 @@ describe('AISegmentation', () => {
|
||||
},
|
||||
}));
|
||||
});
|
||||
|
||||
it('prompts for semantic text before running SAM3 inference', async () => {
|
||||
apiMock.getAiModelStatus.mockResolvedValue({
|
||||
selected_model: 'sam3',
|
||||
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
|
||||
models: [
|
||||
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
|
||||
{ id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true },
|
||||
],
|
||||
});
|
||||
|
||||
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
|
||||
fireEvent.click(sam3Button);
|
||||
fireEvent.click(screen.getByText('执行高精度语义分割'));
|
||||
|
||||
expect(apiMock.predictMask).not.toHaveBeenCalled();
|
||||
expect(await screen.findByText('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('shows feedback when SAM3 semantic inference returns no masks', async () => {
|
||||
apiMock.getAiModelStatus.mockResolvedValue({
|
||||
selected_model: 'sam3',
|
||||
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
|
||||
models: [
|
||||
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
|
||||
{ id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true },
|
||||
],
|
||||
});
|
||||
apiMock.predictMask.mockResolvedValueOnce({ masks: [] });
|
||||
|
||||
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
|
||||
fireEvent.click(sam3Button);
|
||||
fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), {
|
||||
target: { value: '胆囊' },
|
||||
});
|
||||
fireEvent.click(screen.getByText('执行高精度语义分割'));
|
||||
|
||||
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({
|
||||
model: 'sam3',
|
||||
points: undefined,
|
||||
text: '胆囊',
|
||||
})));
|
||||
expect(await screen.findByText('模型没有返回可用区域,请换一个更具体的描述或调整提示。')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('runs SAM3 semantic text inference and assigns the active class to returned masks', async () => {
|
||||
apiMock.getAiModelStatus.mockResolvedValue({
|
||||
selected_model: 'sam3',
|
||||
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
|
||||
models: [
|
||||
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
|
||||
{ id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true },
|
||||
],
|
||||
});
|
||||
apiMock.predictMask.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'semantic-1',
|
||||
pathData: 'M 10 10 L 40 10 L 40 40 Z',
|
||||
label: 'semantic result',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[10, 10, 40, 10, 40, 40]],
|
||||
bbox: [10, 10, 30, 30],
|
||||
area: 900,
|
||||
},
|
||||
],
|
||||
});
|
||||
useStore.setState({
|
||||
activeTemplateId: 'template-1',
|
||||
activeClassId: 'class-1',
|
||||
activeClass: { id: 'class-1', name: '胆囊', color: '#ff0000', zIndex: 30 },
|
||||
});
|
||||
|
||||
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
|
||||
fireEvent.click(sam3Button);
|
||||
fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), {
|
||||
target: { value: '胆囊' },
|
||||
});
|
||||
fireEvent.click(screen.getByText('执行高精度语义分割'));
|
||||
|
||||
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({
|
||||
imageId: 'frame-1',
|
||||
imageWidth: 640,
|
||||
imageHeight: 360,
|
||||
model: 'sam3',
|
||||
points: undefined,
|
||||
text: '胆囊',
|
||||
options: {
|
||||
crop_to_prompt: false,
|
||||
auto_filter_background: true,
|
||||
min_score: 0.05,
|
||||
},
|
||||
})));
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
id: 'semantic-1',
|
||||
frameId: 'frame-1',
|
||||
templateId: 'template-1',
|
||||
classId: 'class-1',
|
||||
className: '胆囊',
|
||||
classZIndex: 30,
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
saveStatus: 'draft',
|
||||
}));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -33,6 +33,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
const [autoDeleteBg, setAutoDeleteBg] = useState(true);
|
||||
const [cropMode, setCropMode] = useState(false);
|
||||
const [isInferencing, setIsInferencing] = useState(false);
|
||||
const [inferenceMessage, setInferenceMessage] = useState('');
|
||||
|
||||
// Canvas state
|
||||
const [scale, setScale] = useState(1);
|
||||
@@ -91,9 +92,18 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
};
|
||||
|
||||
const runInference = useCallback(async () => {
|
||||
if (points.length === 0 && !semanticText.trim()) return;
|
||||
const textPrompt = semanticText.trim();
|
||||
if (aiModel === 'sam3' && !textPrompt) {
|
||||
setInferenceMessage('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。');
|
||||
return;
|
||||
}
|
||||
if (points.length === 0 && !textPrompt) {
|
||||
setInferenceMessage('请先放置正/反向提示点,或输入语义描述。');
|
||||
return;
|
||||
}
|
||||
if (!currentFrame?.id) {
|
||||
console.warn('AI inference skipped: no project frame is selected');
|
||||
setInferenceMessage('请先在项目工作区选择一帧图像。');
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -101,18 +111,20 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
const imageHeight = currentFrame.height || image?.naturalHeight || image?.height || 0;
|
||||
if (imageWidth <= 0 || imageHeight <= 0) {
|
||||
console.warn('AI inference skipped: active frame dimensions are unavailable');
|
||||
setInferenceMessage('当前帧缺少宽高信息,无法推理。');
|
||||
return;
|
||||
}
|
||||
|
||||
setIsInferencing(true);
|
||||
setInferenceMessage('');
|
||||
try {
|
||||
const result = await predictMask({
|
||||
imageId: currentFrame.id,
|
||||
imageWidth,
|
||||
imageHeight,
|
||||
model: aiModel,
|
||||
points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
|
||||
text: semanticText.trim() || undefined,
|
||||
points: aiModel === 'sam3' ? undefined : points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
|
||||
text: textPrompt || undefined,
|
||||
options: {
|
||||
crop_to_prompt: cropMode,
|
||||
auto_filter_background: autoDeleteBg,
|
||||
@@ -120,6 +132,11 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
},
|
||||
});
|
||||
|
||||
if (result.masks.length === 0) {
|
||||
setInferenceMessage('模型没有返回可用区域,请换一个更具体的描述或调整提示。');
|
||||
} else {
|
||||
setInferenceMessage(`已生成 ${result.masks.length} 个候选区域。`);
|
||||
}
|
||||
result.masks.forEach((m) => {
|
||||
const label = activeClass?.name || m.label;
|
||||
const color = activeClass?.color || m.color;
|
||||
@@ -142,6 +159,8 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
});
|
||||
} catch (err) {
|
||||
console.error('AI inference failed:', err);
|
||||
const detail = (err as any)?.response?.data?.detail;
|
||||
setInferenceMessage(detail || 'AI 推理失败,请查看模型状态或后端日志。');
|
||||
} finally {
|
||||
setIsInferencing(false);
|
||||
}
|
||||
@@ -282,6 +301,11 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
{isInferencing ? <Loader2 size={16} className="animate-spin" /> : <Sparkles size={16} />}
|
||||
{isInferencing ? '推理中...' : modelCanInfer ? '执行高精度语义分割' : '当前模型不可用'}
|
||||
</button>
|
||||
{inferenceMessage && (
|
||||
<div className="rounded border border-white/10 bg-white/5 px-3 py-2 text-[11px] leading-relaxed text-gray-300">
|
||||
{inferenceMessage}
|
||||
</div>
|
||||
)}
|
||||
<button
|
||||
onClick={onSendToWorkspace}
|
||||
className="w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all font-medium tracking-wide text-xs uppercase bg-white/5 hover:bg-white/10 text-gray-300 border border-white/5 hover:border-white/10"
|
||||
|
||||
@@ -65,6 +65,157 @@ describe('CanvasArea', () => {
|
||||
}));
|
||||
});
|
||||
|
||||
it('explains that SAM3 point prompts are not supported in the workspace', async () => {
|
||||
useStore.setState({ aiModel: 'sam3' });
|
||||
|
||||
render(<CanvasArea activeTool="point_pos" frame={frame} />);
|
||||
fireEvent.click(screen.getByTestId('konva-stage'));
|
||||
|
||||
expect(apiMock.predictMask).not.toHaveBeenCalled();
|
||||
expect(await screen.findByText(/SAM3 当前工作区只支持框选提示/)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('calls SAM3 prediction with a box prompt from the workspace', async () => {
|
||||
useStore.setState({ aiModel: 'sam3' });
|
||||
apiMock.predictMask.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'sam3-box-mask',
|
||||
pathData: 'M 20 20 L 80 20 L 80 80 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[20, 20, 80, 20, 80, 80]],
|
||||
bbox: [20, 20, 60, 60],
|
||||
area: 3600,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="box_select" frame={frame} />);
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
|
||||
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
|
||||
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
|
||||
|
||||
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith({
|
||||
imageId: 'frame-1',
|
||||
imageWidth: 640,
|
||||
imageHeight: 360,
|
||||
model: 'sam3',
|
||||
points: undefined,
|
||||
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
|
||||
}));
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
id: 'sam3-box-mask',
|
||||
metadata: expect.objectContaining({
|
||||
source: 'sam3_box',
|
||||
promptBox: { x1: 120, y1: 80, x2: 260, y2: 200 },
|
||||
}),
|
||||
}));
|
||||
});
|
||||
|
||||
it('refines one SAM2 candidate mask from an initial box with positive and negative points', async () => {
|
||||
apiMock.predictMask
|
||||
.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'mask-box',
|
||||
pathData: 'M 10 10 L 90 10 L 90 90 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[10, 10, 90, 10, 90, 90]],
|
||||
bbox: [10, 10, 80, 80],
|
||||
area: 6400,
|
||||
},
|
||||
],
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'mask-refined-pos',
|
||||
pathData: 'M 20 20 L 80 20 L 80 80 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[20, 20, 80, 20, 80, 80]],
|
||||
bbox: [20, 20, 60, 60],
|
||||
area: 3600,
|
||||
},
|
||||
],
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'mask-refined-neg',
|
||||
pathData: 'M 30 30 L 70 30 L 70 70 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[30, 30, 70, 30, 70, 70]],
|
||||
bbox: [30, 30, 40, 40],
|
||||
area: 1600,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const { rerender } = render(<CanvasArea activeTool="box_select" frame={frame} />);
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
|
||||
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
|
||||
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
|
||||
|
||||
await waitFor(() => expect(apiMock.predictMask).toHaveBeenNthCalledWith(1, {
|
||||
imageId: 'frame-1',
|
||||
imageWidth: 640,
|
||||
imageHeight: 360,
|
||||
model: 'sam2',
|
||||
points: undefined,
|
||||
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
|
||||
}));
|
||||
await waitFor(() => expect(useStore.getState().masks).toHaveLength(1));
|
||||
|
||||
rerender(<CanvasArea activeTool="point_pos" frame={frame} />);
|
||||
fireEvent.click(stage, { clientX: 150, clientY: 100 });
|
||||
|
||||
await waitFor(() => expect(apiMock.predictMask).toHaveBeenNthCalledWith(2, {
|
||||
imageId: 'frame-1',
|
||||
imageWidth: 640,
|
||||
imageHeight: 360,
|
||||
model: 'sam2',
|
||||
points: [{ x: 150, y: 100, type: 'pos' }],
|
||||
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
|
||||
}));
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
id: 'mask-box',
|
||||
segmentation: [[20, 20, 80, 20, 80, 80]],
|
||||
metadata: expect.objectContaining({
|
||||
source: 'sam2_interactive',
|
||||
promptPointCount: 1,
|
||||
}),
|
||||
}));
|
||||
|
||||
rerender(<CanvasArea activeTool="point_neg" frame={frame} />);
|
||||
fireEvent.click(stage, { clientX: 300, clientY: 150 });
|
||||
|
||||
await waitFor(() => expect(apiMock.predictMask).toHaveBeenNthCalledWith(3, {
|
||||
imageId: 'frame-1',
|
||||
imageWidth: 640,
|
||||
imageHeight: 360,
|
||||
model: 'sam2',
|
||||
points: [
|
||||
{ x: 150, y: 100, type: 'pos' },
|
||||
{ x: 300, y: 150, type: 'neg' },
|
||||
],
|
||||
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
|
||||
}));
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
id: 'mask-box',
|
||||
segmentation: [[30, 30, 70, 30, 70, 70]],
|
||||
points: [[150, 100]],
|
||||
metadata: expect.objectContaining({ promptPointCount: 2 }),
|
||||
}));
|
||||
});
|
||||
|
||||
it('renders only masks that belong to the current frame', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
@@ -79,6 +230,26 @@ describe('CanvasArea', () => {
|
||||
expect(screen.getByText('遮罩数: 1')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('publishes the selected mask ids for the ontology panel', async () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'm1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 0 0 L 10 0 L 10 10 Z',
|
||||
label: 'A',
|
||||
color: '#fff',
|
||||
segmentation: [[0, 0, 10, 0, 10, 10]],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||
fireEvent.click(screen.getByTestId('konva-path'));
|
||||
|
||||
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['m1']));
|
||||
});
|
||||
|
||||
it('renders imported GT seed points for editable point regions', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
@@ -164,6 +335,57 @@ describe('CanvasArea', () => {
|
||||
}));
|
||||
});
|
||||
|
||||
it('deletes the selected draft mask with Delete when no vertex is selected', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'draft-1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||
label: 'Draft',
|
||||
color: '#06b6d4',
|
||||
saveStatus: 'draft',
|
||||
segmentation: [[10, 10, 90, 10, 90, 40]],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||
fireEvent.click(screen.getByTestId('konva-path'));
|
||||
fireEvent.keyDown(window, { key: 'Delete' });
|
||||
|
||||
expect(useStore.getState().masks).toEqual([]);
|
||||
expect(useStore.getState().maskHistory.at(-1)).toEqual([
|
||||
expect.objectContaining({ id: 'draft-1' }),
|
||||
]);
|
||||
});
|
||||
|
||||
it('deletes the selected saved mask locally and notifies the backend deletion callback', () => {
|
||||
const onDeleteMaskAnnotations = vi.fn();
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'annotation-99',
|
||||
annotationId: '99',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||
label: 'Saved',
|
||||
color: '#06b6d4',
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
segmentation: [[10, 10, 90, 10, 90, 40]],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="move" frame={frame} onDeleteMaskAnnotations={onDeleteMaskAnnotations} />);
|
||||
fireEvent.click(screen.getByTestId('konva-path'));
|
||||
fireEvent.keyDown(window, { key: 'Backspace' });
|
||||
|
||||
expect(useStore.getState().masks).toEqual([]);
|
||||
expect(onDeleteMaskAnnotations).toHaveBeenCalledWith(['99']);
|
||||
});
|
||||
|
||||
it('inserts a polygon vertex from an edge midpoint handle', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
@@ -248,9 +470,13 @@ describe('CanvasArea', () => {
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="area_merge" frame={frame} />);
|
||||
expect(screen.getByText('已选 0')).toBeInTheDocument();
|
||||
const paths = screen.getAllByTestId('konva-path');
|
||||
fireEvent.click(paths[0]);
|
||||
expect(screen.getByText('已选 1')).toBeInTheDocument();
|
||||
expect(screen.queryAllByTestId('konva-circle')).toHaveLength(0);
|
||||
fireEvent.click(paths[1]);
|
||||
expect(screen.getByText('已选 2')).toBeInTheDocument();
|
||||
fireEvent.click(screen.getByRole('button', { name: '合并选中' }));
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
@@ -300,6 +526,45 @@ describe('CanvasArea', () => {
|
||||
expect(useStore.getState().masks[1].id).toBe('m2');
|
||||
});
|
||||
|
||||
it('renders inner overlap removal as a hole in the primary mask', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'm1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 110 10 L 110 110 L 10 110 Z',
|
||||
label: 'A',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[10, 10, 110, 10, 110, 110, 10, 110]],
|
||||
},
|
||||
{
|
||||
id: 'm2',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 40 40 L 80 40 L 80 80 L 40 80 Z',
|
||||
label: 'B',
|
||||
color: '#ff0000',
|
||||
segmentation: [[40, 40, 80, 40, 80, 80, 40, 80]],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="area_remove" frame={frame} />);
|
||||
const paths = screen.getAllByTestId('konva-path');
|
||||
fireEvent.click(paths[0]);
|
||||
fireEvent.click(paths[1]);
|
||||
fireEvent.click(screen.getByRole('button', { name: '从主区域去除' }));
|
||||
|
||||
const [primary] = useStore.getState().masks;
|
||||
expect(primary).toEqual(expect.objectContaining({
|
||||
id: 'm1',
|
||||
area: 8400,
|
||||
bbox: [10, 10, 100, 100],
|
||||
metadata: expect.objectContaining({ hasHoles: true }),
|
||||
}));
|
||||
expect(primary.segmentation).toHaveLength(2);
|
||||
expect(screen.getAllByTestId('konva-path')[0]).toHaveAttribute('data-fill-rule', 'evenodd');
|
||||
});
|
||||
|
||||
it('creates a manual rectangle mask that can be undone and redone', () => {
|
||||
useStore.setState({
|
||||
activeTemplateId: '2',
|
||||
@@ -329,6 +594,93 @@ describe('CanvasArea', () => {
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
});
|
||||
|
||||
it('creates a manual circle mask from a drag gesture', () => {
|
||||
render(<CanvasArea activeTool="create_circle" frame={frame} />);
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
|
||||
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
|
||||
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
frameId: 'frame-1',
|
||||
label: '手工圆形',
|
||||
color: '#06b6d4',
|
||||
saveStatus: 'draft',
|
||||
bbox: [120, 80, 140, 120],
|
||||
metadata: expect.objectContaining({
|
||||
source: 'manual',
|
||||
shape: '圆形',
|
||||
}),
|
||||
}));
|
||||
expect(useStore.getState().masks[0].segmentation?.[0]).toHaveLength(64);
|
||||
});
|
||||
|
||||
it('creates a manual line region from a drag gesture', () => {
|
||||
render(<CanvasArea activeTool="create_line" frame={frame} />);
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
|
||||
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
|
||||
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
frameId: 'frame-1',
|
||||
label: '手工线段',
|
||||
color: '#06b6d4',
|
||||
saveStatus: 'draft',
|
||||
metadata: expect.objectContaining({
|
||||
source: 'manual',
|
||||
shape: '线段',
|
||||
}),
|
||||
}));
|
||||
expect(useStore.getState().masks[0].segmentation?.[0]).toHaveLength(8);
|
||||
expect(useStore.getState().masks[0].area).toBeGreaterThan(1000);
|
||||
});
|
||||
|
||||
it('creates an editable point region on click', () => {
|
||||
render(<CanvasArea activeTool="create_point" frame={frame} />);
|
||||
fireEvent.click(screen.getByTestId('konva-stage'), { clientX: 120, clientY: 80 });
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
frameId: 'frame-1',
|
||||
label: '手工点区域',
|
||||
color: '#06b6d4',
|
||||
saveStatus: 'draft',
|
||||
points: [[120, 80]],
|
||||
bbox: expect.arrayContaining([115, 75]),
|
||||
metadata: expect.objectContaining({
|
||||
source: 'manual',
|
||||
shape: '点区域',
|
||||
}),
|
||||
}));
|
||||
});
|
||||
|
||||
it('creates a point region when clicking over an existing mask', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'm1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 200 10 L 200 200 Z',
|
||||
label: 'Existing',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[10, 10, 200, 10, 200, 200]],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="create_point" frame={frame} />);
|
||||
fireEvent.click(screen.getByTestId('konva-path'), { clientX: 120, clientY: 80 });
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(2);
|
||||
expect(useStore.getState().masks[1]).toEqual(expect.objectContaining({
|
||||
metadata: expect.objectContaining({ shape: '点区域' }),
|
||||
points: [[120, 80]],
|
||||
}));
|
||||
});
|
||||
|
||||
it('finalizes a clicked polygon with Enter', () => {
|
||||
render(<CanvasArea activeTool="create_polygon" frame={frame} />);
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
@@ -344,6 +696,29 @@ describe('CanvasArea', () => {
|
||||
}));
|
||||
});
|
||||
|
||||
it('closes a clicked polygon by clicking the first node again', () => {
|
||||
render(<CanvasArea activeTool="create_polygon" frame={frame} />);
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
fireEvent.click(stage, { clientX: 120, clientY: 80 });
|
||||
fireEvent.click(stage, { clientX: 220, clientY: 80 });
|
||||
fireEvent.click(stage, { clientX: 180, clientY: 160 });
|
||||
|
||||
const handles = screen.getAllByTestId('konva-circle');
|
||||
expect(handles[0]).toHaveAttribute('data-fill', '#facc15');
|
||||
fireEvent.click(handles[0]);
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
pathData: 'M 120 80 L 220 80 L 180 160 Z',
|
||||
segmentation: [[120, 80, 220, 80, 180, 160]],
|
||||
metadata: expect.objectContaining({
|
||||
source: 'manual',
|
||||
shape: '多边形',
|
||||
}),
|
||||
}));
|
||||
expect(screen.queryAllByTestId('konva-circle')).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('applies the selected class to current-frame masks and marks saved masks dirty', () => {
|
||||
useStore.setState({
|
||||
activeTemplateId: '2',
|
||||
|
||||
@@ -14,11 +14,14 @@ interface CanvasAreaProps {
|
||||
}
|
||||
|
||||
type CanvasPoint = { x: number; y: number };
|
||||
type PromptPoint = CanvasPoint & { type: 'pos' | 'neg' };
|
||||
type PromptBox = { x1: number; y1: number; x2: number; y2: number };
|
||||
|
||||
const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle', 'create_line']);
|
||||
const POLYGON_TOOL = 'create_polygon';
|
||||
const POINT_TOOL = 'create_point';
|
||||
const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']);
|
||||
const POLYGON_CLOSE_RADIUS = 8;
|
||||
|
||||
function clamp(value: number, min: number, max: number): number {
|
||||
return Math.min(Math.max(value, min), max);
|
||||
@@ -88,6 +91,10 @@ function polygonArea(points: CanvasPoint[]): number {
|
||||
return Math.abs(sum) / 2;
|
||||
}
|
||||
|
||||
function pointDistance(a: CanvasPoint, b: CanvasPoint): number {
|
||||
return Math.hypot(a.x - b.x, a.y - b.y);
|
||||
}
|
||||
|
||||
function segmentationArea(segmentation?: number[][]): number {
|
||||
return (segmentation || []).reduce((sum, polygon) => sum + polygonArea(flatPolygonToPoints(polygon)), 0);
|
||||
}
|
||||
@@ -115,20 +122,35 @@ function maskToMultiPolygon(mask: Mask): MultiPolygon | null {
|
||||
return polygons.length > 0 ? polygons : null;
|
||||
}
|
||||
|
||||
function openRingPoints(ring: Pair[]): CanvasPoint[] {
|
||||
const openRing = ring.length > 1
|
||||
&& ring[0][0] === ring[ring.length - 1][0]
|
||||
&& ring[0][1] === ring[ring.length - 1][1]
|
||||
? ring.slice(0, -1)
|
||||
: ring;
|
||||
return openRing.map(([x, y]) => ({ x, y }));
|
||||
}
|
||||
|
||||
function multiPolygonToSegmentation(geometry: MultiPolygon): number[][] {
|
||||
return geometry
|
||||
.map((polygon) => polygon[0] || [])
|
||||
.map((ring) => {
|
||||
const openRing = ring.length > 1
|
||||
&& ring[0][0] === ring[ring.length - 1][0]
|
||||
&& ring[0][1] === ring[ring.length - 1][1]
|
||||
? ring.slice(0, -1)
|
||||
: ring;
|
||||
return openRing.flatMap(([x, y]) => [x, y]);
|
||||
})
|
||||
.flatMap((polygon) => polygon)
|
||||
.map((ring) => openRingPoints(ring).flatMap(({ x, y }) => [x, y]))
|
||||
.filter((polygon) => polygon.length >= 6);
|
||||
}
|
||||
|
||||
function multiPolygonArea(geometry: MultiPolygon): number {
|
||||
return geometry.reduce((sum, polygon) => {
|
||||
const [outerRing, ...holeRings] = polygon;
|
||||
const outerArea = outerRing ? polygonArea(openRingPoints(outerRing)) : 0;
|
||||
const holesArea = holeRings.reduce((holeSum, ring) => holeSum + polygonArea(openRingPoints(ring)), 0);
|
||||
return sum + Math.max(outerArea - holesArea, 0);
|
||||
}, 0);
|
||||
}
|
||||
|
||||
function multiPolygonHasHoles(geometry: MultiPolygon): boolean {
|
||||
return geometry.some((polygon) => polygon.length > 1);
|
||||
}
|
||||
|
||||
function rectanglePoints(start: CanvasPoint, end: CanvasPoint): CanvasPoint[] {
|
||||
const x1 = Math.min(start.x, end.x);
|
||||
const y1 = Math.min(start.y, end.y);
|
||||
@@ -179,10 +201,12 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
const [stageSize, setStageSize] = useState({ width: 800, height: 600 });
|
||||
const [scale, setScale] = useState(1);
|
||||
const [position, setPosition] = useState({ x: 0, y: 0 });
|
||||
const [points, setPoints] = useState<{ x: number, y: number, type: 'pos'|'neg' }[]>([]);
|
||||
const [points, setPoints] = useState<PromptPoint[]>([]);
|
||||
const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 });
|
||||
const [boxStart, setBoxStart] = useState<{ x: number, y: number } | null>(null);
|
||||
const [boxCurrent, setBoxCurrent] = useState<{ x: number, y: number } | null>(null);
|
||||
const [samPromptBox, setSamPromptBox] = useState<PromptBox | null>(null);
|
||||
const [samCandidateMaskId, setSamCandidateMaskId] = useState<string | null>(null);
|
||||
const [manualStart, setManualStart] = useState<CanvasPoint | null>(null);
|
||||
const [manualCurrent, setManualCurrent] = useState<CanvasPoint | null>(null);
|
||||
const [polygonPoints, setPolygonPoints] = useState<CanvasPoint[]>([]);
|
||||
@@ -191,12 +215,14 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
const [selectedPolygonIndex, setSelectedPolygonIndex] = useState(0);
|
||||
const [selectedVertexIndex, setSelectedVertexIndex] = useState<number | null>(null);
|
||||
const [isInferencing, setIsInferencing] = useState(false);
|
||||
const [inferenceMessage, setInferenceMessage] = useState('');
|
||||
|
||||
const masks = useStore((state) => state.masks);
|
||||
const addMask = useStore((state) => state.addMask);
|
||||
const updateMask = useStore((state) => state.updateMask);
|
||||
const clearMasks = useStore((state) => state.clearMasks);
|
||||
const setMasks = useStore((state) => state.setMasks);
|
||||
const setGlobalSelectedMaskIds = useStore((state) => state.setSelectedMaskIds);
|
||||
const storeActiveTool = useStore((state) => state.activeTool);
|
||||
const aiModel = useStore((state) => state.aiModel);
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
@@ -226,6 +252,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
const savedMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'saved' || mask.saved).length;
|
||||
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
|
||||
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
|
||||
const isBooleanTool = BOOLEAN_TOOLS.has(effectiveTool);
|
||||
|
||||
useEffect(() => {
|
||||
const handleResize = () => {
|
||||
@@ -252,6 +279,18 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
setSelectedVertexIndex(null);
|
||||
}, [effectiveTool, frame?.id]);
|
||||
|
||||
useEffect(() => {
|
||||
setPoints([]);
|
||||
setSamPromptBox(null);
|
||||
setSamCandidateMaskId(null);
|
||||
}, [frame?.id]);
|
||||
|
||||
useEffect(() => {
|
||||
setGlobalSelectedMaskIds(selectedMaskIds);
|
||||
}, [selectedMaskIds, setGlobalSelectedMaskIds]);
|
||||
|
||||
useEffect(() => () => setGlobalSelectedMaskIds([]), [setGlobalSelectedMaskIds]);
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedMaskId && !frameMasks.some((mask) => mask.id === selectedMaskId)) {
|
||||
setSelectedMaskId(null);
|
||||
@@ -324,6 +363,12 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
addMask(mask);
|
||||
}, [activeClass, activeTemplateId, addMask, frame?.id]);
|
||||
|
||||
const finishPolygon = useCallback(() => {
|
||||
if (polygonPoints.length < 3) return;
|
||||
createManualMask('多边形', polygonPoints);
|
||||
setPolygonPoints([]);
|
||||
}, [createManualMask, polygonPoints]);
|
||||
|
||||
const handleMouseMove = (e: any) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) return;
|
||||
@@ -349,9 +394,18 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
}
|
||||
};
|
||||
|
||||
const runInference = useCallback(async (promptPoints?: typeof points, promptBox?: { x1: number, y1: number, x2: number, y2: number }) => {
|
||||
const runInference = useCallback(async (
|
||||
promptPoints?: PromptPoint[],
|
||||
promptBox?: PromptBox,
|
||||
options: { resetCandidate?: boolean } = {},
|
||||
) => {
|
||||
if (!frame?.id) {
|
||||
console.warn('Inference skipped: no active frame');
|
||||
setInferenceMessage('请先选择一帧图像。');
|
||||
return;
|
||||
}
|
||||
if (aiModel === 'sam3' && (!promptBox || (promptPoints?.length ?? 0) > 0)) {
|
||||
setInferenceMessage('SAM3 当前工作区只支持框选提示;正/反点修正请切回 SAM2。');
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -359,31 +413,44 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
const imageHeight = frame.height || image?.naturalHeight || image?.height || 0;
|
||||
if (imageWidth <= 0 || imageHeight <= 0) {
|
||||
console.warn('Inference skipped: active frame dimensions are unavailable');
|
||||
setInferenceMessage('当前帧缺少宽高信息,无法推理。');
|
||||
return;
|
||||
}
|
||||
|
||||
setIsInferencing(true);
|
||||
setInferenceMessage('');
|
||||
try {
|
||||
const result = await predictMask({
|
||||
imageId: frame.id,
|
||||
imageWidth,
|
||||
imageHeight,
|
||||
model: aiModel,
|
||||
points: promptPoints?.map((p) => ({ x: p.x, y: p.y, type: p.type })),
|
||||
points: promptPoints && promptPoints.length > 0
|
||||
? promptPoints.map((p) => ({ x: p.x, y: p.y, type: p.type }))
|
||||
: undefined,
|
||||
box: promptBox,
|
||||
});
|
||||
|
||||
result.masks.forEach((m) => {
|
||||
const label = activeClass?.name || m.label;
|
||||
const color = activeClass?.color || m.color;
|
||||
addMask({
|
||||
id: m.id,
|
||||
const [m] = result.masks;
|
||||
if (m) {
|
||||
const existingCandidate = !options.resetCandidate && samCandidateMaskId
|
||||
? masks.find((mask) => mask.id === samCandidateMaskId)
|
||||
: null;
|
||||
const label = activeClass?.name || existingCandidate?.label || m.label;
|
||||
const color = activeClass?.color || existingCandidate?.color || m.color;
|
||||
const metadata = {
|
||||
...(existingCandidate?.metadata || {}),
|
||||
source: aiModel === 'sam3' ? 'sam3_box' : 'sam2_interactive',
|
||||
promptBox: promptBox || null,
|
||||
promptPointCount: promptPoints?.length || 0,
|
||||
};
|
||||
const nextMask = {
|
||||
frameId: frame.id,
|
||||
templateId: activeTemplateId || undefined,
|
||||
classId: activeClass?.id,
|
||||
className: activeClass?.name,
|
||||
classZIndex: activeClass?.zIndex,
|
||||
saveStatus: 'draft',
|
||||
templateId: activeTemplateId || existingCandidate?.templateId || undefined,
|
||||
classId: activeClass?.id || existingCandidate?.classId,
|
||||
className: activeClass?.name || existingCandidate?.className,
|
||||
classZIndex: activeClass?.zIndex ?? existingCandidate?.classZIndex,
|
||||
saveStatus: existingCandidate?.annotationId ? 'dirty' as const : 'draft' as const,
|
||||
saved: false,
|
||||
pathData: m.pathData,
|
||||
label,
|
||||
@@ -392,14 +459,33 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
points: promptPoints?.filter((p) => p.type === 'pos').map((p) => [p.x, p.y]),
|
||||
bbox: m.bbox,
|
||||
area: m.area,
|
||||
});
|
||||
});
|
||||
metadata,
|
||||
};
|
||||
if (existingCandidate) {
|
||||
updateMask(existingCandidate.id, nextMask);
|
||||
setSelectedMaskId(existingCandidate.id);
|
||||
setSelectedMaskIds([existingCandidate.id]);
|
||||
} else {
|
||||
const id = m.id;
|
||||
setSamCandidateMaskId(id);
|
||||
setSelectedMaskId(id);
|
||||
setSelectedMaskIds([id]);
|
||||
addMask({
|
||||
id,
|
||||
...nextMask,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
setInferenceMessage('模型没有返回可用区域,请调整点/框提示后重试。');
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Inference failed:', err);
|
||||
const detail = (err as any)?.response?.data?.detail;
|
||||
setInferenceMessage(detail || 'AI 推理失败,请查看模型状态或后端日志。');
|
||||
} finally {
|
||||
setIsInferencing(false);
|
||||
}
|
||||
}, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width]);
|
||||
}, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, masks, samCandidateMaskId, updateMask]);
|
||||
|
||||
const handleApplyActiveClass = () => {
|
||||
if (!frame?.id || !activeClass) return;
|
||||
@@ -427,6 +513,29 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
clearMasks();
|
||||
};
|
||||
|
||||
const deleteMasksById = useCallback((maskIds: string[]) => {
|
||||
if (maskIds.length === 0) return;
|
||||
const idSet = new Set(maskIds);
|
||||
const deletingMasks = masks.filter((mask) => idSet.has(mask.id));
|
||||
if (deletingMasks.length === 0) return;
|
||||
setMasks(masks.filter((mask) => !idSet.has(mask.id)));
|
||||
const annotationIds = deletingMasks
|
||||
.map((mask) => mask.annotationId)
|
||||
.filter((annotationId): annotationId is string => Boolean(annotationId));
|
||||
if (annotationIds.length > 0) {
|
||||
void onDeleteMaskAnnotations?.(annotationIds);
|
||||
}
|
||||
if (samCandidateMaskId && idSet.has(samCandidateMaskId)) {
|
||||
setSamCandidateMaskId(null);
|
||||
setSamPromptBox(null);
|
||||
setPoints([]);
|
||||
}
|
||||
setSelectedMaskId(null);
|
||||
setSelectedMaskIds([]);
|
||||
setSelectedPolygonIndex(0);
|
||||
setSelectedVertexIndex(null);
|
||||
}, [masks, onDeleteMaskAnnotations, samCandidateMaskId, setMasks]);
|
||||
|
||||
const handleStageMouseDown = (e: any) => {
|
||||
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) {
|
||||
const pos = stagePoint(e);
|
||||
@@ -476,7 +585,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
const y2 = Math.max(boxStart.y, boxCurrent.y);
|
||||
|
||||
if (Math.abs(x2 - x1) > 5 && Math.abs(y2 - y1) > 5) {
|
||||
runInference(undefined, { x1, y1, x2, y2 });
|
||||
const nextBox = { x1, y1, x2, y2 };
|
||||
setPoints([]);
|
||||
setSamPromptBox(nextBox);
|
||||
setSamCandidateMaskId(null);
|
||||
runInference([], nextBox, { resetCandidate: true });
|
||||
}
|
||||
|
||||
setBoxStart(null);
|
||||
@@ -500,6 +613,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
if (effectiveTool === POLYGON_TOOL) {
|
||||
const pos = stagePoint(e);
|
||||
if (pos) {
|
||||
const closeRadius = POLYGON_CLOSE_RADIUS / Math.max(scale, 0.1);
|
||||
if (polygonPoints.length >= 3 && pointDistance(pos, polygonPoints[0]) <= closeRadius) {
|
||||
finishPolygon();
|
||||
return;
|
||||
}
|
||||
setPolygonPoints((current) => [...current, pos]);
|
||||
}
|
||||
return;
|
||||
@@ -514,8 +632,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
{ x: pos.x, y: pos.y, type: (effectiveTool === 'point_pos' ? 'pos' : 'neg') as 'pos' | 'neg' },
|
||||
];
|
||||
setPoints(newPoints);
|
||||
// Auto-trigger inference after point selection
|
||||
runInference(newPoints);
|
||||
runInference(newPoints, samPromptBox || undefined);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -535,14 +652,22 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
});
|
||||
}, [updateMask]);
|
||||
|
||||
const updateMaskFromSegmentation = useCallback((mask: Mask, segmentation: number[][]): Mask => {
|
||||
const updateMaskFromSegmentation = useCallback((
|
||||
mask: Mask,
|
||||
segmentation: number[][],
|
||||
options: { area?: number; hasHoles?: boolean } = {},
|
||||
): Mask => {
|
||||
const bbox = segmentationBbox(segmentation);
|
||||
const metadata = { ...(mask.metadata || {}) };
|
||||
if (options.hasHoles === true) metadata.hasHoles = true;
|
||||
if (options.hasHoles === false) delete metadata.hasHoles;
|
||||
return {
|
||||
...mask,
|
||||
pathData: segmentationPath(segmentation),
|
||||
segmentation,
|
||||
bbox,
|
||||
area: segmentationArea(segmentation),
|
||||
area: options.area ?? segmentationArea(segmentation),
|
||||
metadata,
|
||||
saveStatus: mask.annotationId ? 'dirty' : 'draft',
|
||||
saved: mask.annotationId ? false : mask.saved,
|
||||
};
|
||||
@@ -572,11 +697,16 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
}
|
||||
return;
|
||||
}
|
||||
if ((event.key === 'Delete' || event.key === 'Backspace') && selectedMask) {
|
||||
event.preventDefault();
|
||||
const ids = selectedMaskIds.length > 0 ? selectedMaskIds : [selectedMask.id];
|
||||
deleteMasksById(ids);
|
||||
return;
|
||||
}
|
||||
if (effectiveTool !== POLYGON_TOOL) return;
|
||||
if (event.key === 'Enter' && polygonPoints.length >= 3) {
|
||||
event.preventDefault();
|
||||
createManualMask('多边形', polygonPoints);
|
||||
setPolygonPoints([]);
|
||||
finishPolygon();
|
||||
}
|
||||
if (event.key === 'Escape') {
|
||||
event.preventDefault();
|
||||
@@ -586,7 +716,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
|
||||
window.addEventListener('keydown', handleKeyDown);
|
||||
return () => window.removeEventListener('keydown', handleKeyDown);
|
||||
}, [createManualMask, effectiveTool, polygonPoints, redoMasks, selectedMask, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
|
||||
}, [deleteMasksById, effectiveTool, finishPolygon, polygonPoints, redoMasks, selectedMask, selectedMaskIds, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
|
||||
|
||||
const boxRect = React.useMemo(() => {
|
||||
if (!boxStart || !boxCurrent) return null;
|
||||
@@ -623,8 +753,9 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
};
|
||||
|
||||
const handleMaskSelect = (mask: Mask, event: any, polygonIndex = 0) => {
|
||||
if (effectiveTool !== 'move' && !isBooleanTool) return;
|
||||
event.cancelBubble = true;
|
||||
if (BOOLEAN_TOOLS.has(effectiveTool)) {
|
||||
if (isBooleanTool) {
|
||||
setSelectedMaskIds((current) => (
|
||||
current.includes(mask.id)
|
||||
? current.filter((id) => id !== mask.id)
|
||||
@@ -703,7 +834,10 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
return;
|
||||
}
|
||||
|
||||
const nextPrimary = updateMaskFromSegmentation(primary, resultSegmentation);
|
||||
const nextPrimary = updateMaskFromSegmentation(primary, resultSegmentation, {
|
||||
area: multiPolygonArea(resultGeometry),
|
||||
hasHoles: multiPolygonHasHoles(resultGeometry),
|
||||
});
|
||||
const secondaryIds = effectiveTool === 'area_merge'
|
||||
? new Set(booleanSelectedMasks.slice(1).map((mask) => mask.id))
|
||||
: new Set<string>();
|
||||
@@ -731,6 +865,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
<span className="text-xs text-cyan-400 font-mono">AI 推理中...</span>
|
||||
</div>
|
||||
)}
|
||||
{!isInferencing && inferenceMessage && (
|
||||
<div className="absolute top-4 right-4 z-20 max-w-xs bg-[#111] border border-white/10 px-3 py-2 rounded-lg shadow-xl text-xs leading-relaxed text-gray-300">
|
||||
{inferenceMessage}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Stage
|
||||
width={stageSize.width}
|
||||
@@ -758,21 +897,32 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
)}
|
||||
|
||||
{/* AI Returned Masks */}
|
||||
{frameMasks.map((mask) => (
|
||||
<Group key={mask.id} opacity={selectedMaskIds.includes(mask.id) ? 0.65 : 0.5}>
|
||||
{(mask.segmentation && mask.segmentation.length > 0 ? mask.segmentation : [undefined]).map((_, polygonIndex) => (
|
||||
<Path
|
||||
key={`${mask.id}-polygon-${polygonIndex}`}
|
||||
data={mask.segmentation ? segmentationPolygonPath(mask.segmentation, polygonIndex) : mask.pathData}
|
||||
fill={mask.color}
|
||||
stroke={mask.color}
|
||||
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale}
|
||||
onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
|
||||
onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
|
||||
/>
|
||||
))}
|
||||
</Group>
|
||||
))}
|
||||
{frameMasks.map((mask) => {
|
||||
const hasHoles = Boolean(mask.metadata?.hasHoles);
|
||||
const paths = hasHoles
|
||||
? [{ data: segmentationPath(mask.segmentation), polygonIndex: 0, fillRule: 'evenodd' }]
|
||||
: (mask.segmentation && mask.segmentation.length > 0 ? mask.segmentation : [undefined]).map((_, polygonIndex) => ({
|
||||
data: mask.segmentation ? segmentationPolygonPath(mask.segmentation, polygonIndex) : mask.pathData,
|
||||
polygonIndex,
|
||||
fillRule: undefined,
|
||||
}));
|
||||
return (
|
||||
<Group key={mask.id} opacity={selectedMaskIds.includes(mask.id) ? 0.65 : 0.5}>
|
||||
{paths.map(({ data, polygonIndex, fillRule }) => (
|
||||
<Path
|
||||
key={`${mask.id}-polygon-${polygonIndex}`}
|
||||
data={data}
|
||||
fill={mask.color}
|
||||
fillRule={fillRule}
|
||||
stroke={mask.color}
|
||||
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale}
|
||||
onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
|
||||
onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
|
||||
/>
|
||||
))}
|
||||
</Group>
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Box selection preview */}
|
||||
{boxRect && effectiveTool === 'box_select' && (
|
||||
@@ -804,10 +954,20 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
key={`poly-point-${index}`}
|
||||
x={point.x}
|
||||
y={point.y}
|
||||
radius={4 / scale}
|
||||
fill="#22d3ee"
|
||||
stroke="#ffffff"
|
||||
radius={(index === 0 && polygonPoints.length >= 3 ? 6 : 4) / scale}
|
||||
fill={index === 0 && polygonPoints.length >= 3 ? '#facc15' : '#22d3ee'}
|
||||
stroke={index === 0 && polygonPoints.length >= 3 ? '#fef3c7' : '#ffffff'}
|
||||
strokeWidth={1 / scale}
|
||||
onClick={(event: any) => {
|
||||
if (index !== 0 || polygonPoints.length < 3) return;
|
||||
event.cancelBubble = true;
|
||||
finishPolygon();
|
||||
}}
|
||||
onTap={(event: any) => {
|
||||
if (index !== 0 || polygonPoints.length < 3) return;
|
||||
event.cancelBubble = true;
|
||||
finishPolygon();
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
|
||||
@@ -827,7 +987,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
)))}
|
||||
|
||||
{/* Polygon edge insertion handles */}
|
||||
{selectedMask && selectedMaskPoints.map((point, index) => {
|
||||
{!isBooleanTool && selectedMask && selectedMaskPoints.map((point, index) => {
|
||||
const next = selectedMaskPoints[(index + 1) % selectedMaskPoints.length];
|
||||
if (!next) return null;
|
||||
return (
|
||||
@@ -846,7 +1006,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
})}
|
||||
|
||||
{/* Polygon vertex editor */}
|
||||
{selectedMask && selectedMaskPoints.map((point, index) => (
|
||||
{!isBooleanTool && selectedMask && selectedMaskPoints.map((point, index) => (
|
||||
<Circle
|
||||
key={`${selectedMask.id}-vertex-${selectedPolygonIndex}-${index}`}
|
||||
x={point.x}
|
||||
@@ -900,13 +1060,19 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
|
||||
{frameMasks.length > 0 && (
|
||||
<div className="absolute bottom-4 right-4 flex gap-2">
|
||||
{BOOLEAN_TOOLS.has(effectiveTool) && booleanSelectedMasks.length >= 2 && (
|
||||
<button
|
||||
onClick={handleBooleanOperation}
|
||||
className="text-xs bg-emerald-500/10 hover:bg-emerald-500/20 text-emerald-300 border border-emerald-500/20 px-3 py-1.5 rounded transition-colors"
|
||||
>
|
||||
{effectiveTool === 'area_merge' ? '合并选中' : '从主区域去除'}
|
||||
</button>
|
||||
{isBooleanTool && (
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-xs bg-white/5 text-gray-300 border border-white/10 px-2.5 py-1.5 rounded">
|
||||
已选 {booleanSelectedMasks.length}
|
||||
</span>
|
||||
<button
|
||||
onClick={handleBooleanOperation}
|
||||
disabled={booleanSelectedMasks.length < 2}
|
||||
className="text-xs bg-emerald-500/10 hover:bg-emerald-500/20 text-emerald-300 border border-emerald-500/20 px-3 py-1.5 rounded transition-colors disabled:opacity-40 disabled:cursor-not-allowed disabled:hover:bg-emerald-500/10"
|
||||
>
|
||||
{effectiveTool === 'area_merge' ? '合并选中' : '从主区域去除'}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
{activeClass && (
|
||||
<button
|
||||
|
||||
@@ -34,6 +34,65 @@ describe('FrameTimeline', () => {
|
||||
expect(useStore.getState().currentFrameIndex).toBe(2);
|
||||
});
|
||||
|
||||
it('shows current and total timeline time based on project fps', () => {
|
||||
useStore.setState({
|
||||
currentProject: { id: 'p1', name: 'P', status: 'ready', parse_fps: 10 },
|
||||
currentFrameIndex: 1,
|
||||
frames: [
|
||||
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
|
||||
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
|
||||
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
|
||||
],
|
||||
});
|
||||
|
||||
render(<FrameTimeline />);
|
||||
|
||||
expect(screen.getAllByText('00:00.10').length).toBeGreaterThan(0);
|
||||
expect(screen.getAllByText('00:00.20').length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('changes frames with left and right arrow keys without leaving bounds', () => {
|
||||
useStore.setState({
|
||||
currentFrameIndex: 1,
|
||||
frames: [
|
||||
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
|
||||
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
|
||||
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
|
||||
],
|
||||
});
|
||||
|
||||
render(<FrameTimeline />);
|
||||
fireEvent.keyDown(window, { key: 'ArrowRight' });
|
||||
expect(useStore.getState().currentFrameIndex).toBe(2);
|
||||
|
||||
fireEvent.keyDown(window, { key: 'ArrowRight' });
|
||||
expect(useStore.getState().currentFrameIndex).toBe(2);
|
||||
|
||||
fireEvent.keyDown(window, { key: 'ArrowLeft' });
|
||||
expect(useStore.getState().currentFrameIndex).toBe(1);
|
||||
});
|
||||
|
||||
it('does not change frames while typing in editable fields', () => {
|
||||
useStore.setState({
|
||||
currentFrameIndex: 1,
|
||||
frames: [
|
||||
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
|
||||
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
|
||||
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
|
||||
],
|
||||
});
|
||||
|
||||
render(
|
||||
<>
|
||||
<input aria-label="annotation-name" />
|
||||
<FrameTimeline />
|
||||
</>,
|
||||
);
|
||||
fireEvent.keyDown(screen.getByLabelText('annotation-name'), { key: 'ArrowRight' });
|
||||
|
||||
expect(useStore.getState().currentFrameIndex).toBe(1);
|
||||
});
|
||||
|
||||
it('plays forward using the project parse fps and stops at the end', () => {
|
||||
vi.useFakeTimers();
|
||||
useStore.setState({
|
||||
|
||||
@@ -16,6 +16,20 @@ export function FrameTimeline() {
|
||||
const fps = currentProject?.parse_fps || currentProject?.original_fps || 12;
|
||||
return Math.min(Math.max(fps, 1), 30);
|
||||
}, [currentProject?.original_fps, currentProject?.parse_fps]);
|
||||
const timeBaseFps = useMemo(() => {
|
||||
const fps = currentProject?.parse_fps || currentProject?.original_fps || 12;
|
||||
return Math.max(fps, 1);
|
||||
}, [currentProject?.original_fps, currentProject?.parse_fps]);
|
||||
const currentSeconds = totalFrames > 0 ? currentFrameIndex / timeBaseFps : 0;
|
||||
const totalSeconds = totalFrames > 0 ? Math.max(totalFrames - 1, 0) / timeBaseFps : 0;
|
||||
|
||||
const formatTime = (seconds: number) => {
|
||||
const safeSeconds = Math.max(0, seconds);
|
||||
const minutes = Math.floor(safeSeconds / 60);
|
||||
const wholeSeconds = Math.floor(safeSeconds % 60);
|
||||
const centiseconds = Math.floor((safeSeconds % 1) * 100);
|
||||
return `${minutes.toString().padStart(2, '0')}:${wholeSeconds.toString().padStart(2, '0')}.${centiseconds.toString().padStart(2, '0')}`;
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (!isPlaying || totalFrames <= 1) return;
|
||||
@@ -38,6 +52,30 @@ export function FrameTimeline() {
|
||||
}
|
||||
}, [totalFrames]);
|
||||
|
||||
useEffect(() => {
|
||||
const isEditableTarget = (target: EventTarget | null) => {
|
||||
if (!(target instanceof HTMLElement)) return false;
|
||||
const tagName = target.tagName.toLowerCase();
|
||||
return target.isContentEditable || ['input', 'textarea', 'select'].includes(tagName);
|
||||
};
|
||||
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
if (isEditableTarget(event.target) || totalFrames <= 1) return;
|
||||
if (event.key !== 'ArrowLeft' && event.key !== 'ArrowRight') return;
|
||||
|
||||
event.preventDefault();
|
||||
setIsPlaying(false);
|
||||
const direction = event.key === 'ArrowRight' ? 1 : -1;
|
||||
const nextIndex = Math.min(Math.max(currentFrameIndex + direction, 0), totalFrames - 1);
|
||||
if (nextIndex !== currentFrameIndex) {
|
||||
setCurrentFrame(nextIndex);
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener('keydown', handleKeyDown);
|
||||
return () => window.removeEventListener('keydown', handleKeyDown);
|
||||
}, [currentFrameIndex, setCurrentFrame, totalFrames]);
|
||||
|
||||
// show frames around current frame
|
||||
const frameWindow = 20;
|
||||
const displayIndices = totalFrames > 0
|
||||
@@ -47,6 +85,12 @@ export function FrameTimeline() {
|
||||
return (
|
||||
<div className="h-32 bg-[#111] border-t border-white/5 flex flex-col shrink-0 z-20">
|
||||
<div className="h-4 bg-[#0d0d0d] flex items-center group relative">
|
||||
<div className="absolute left-3 -top-5 text-[10px] font-mono text-gray-500 pointer-events-none">
|
||||
{formatTime(currentSeconds)}
|
||||
</div>
|
||||
<div className="absolute right-3 -top-5 text-[10px] font-mono text-gray-500 pointer-events-none">
|
||||
{formatTime(totalSeconds)}
|
||||
</div>
|
||||
<input
|
||||
type="range"
|
||||
min="1"
|
||||
@@ -65,6 +109,12 @@ export function FrameTimeline() {
|
||||
className="w-3 h-3 bg-white rounded-full absolute top-1/2 -translate-y-1/2 -ml-1.5 shadow-sm transform scale-0 group-hover:scale-100 transition-transform shadow-cyan-500/50"
|
||||
style={{ left: `${totalFrames > 0 ? (currentFrame / totalFrames) * 100 : 0}%` }}
|
||||
/>
|
||||
<div
|
||||
className="absolute -top-7 -translate-x-1/2 rounded bg-black/80 border border-white/10 px-2 py-0.5 text-[10px] font-mono text-cyan-300 opacity-0 group-hover:opacity-100 transition-opacity pointer-events-none"
|
||||
style={{ left: `${totalFrames > 0 ? (currentFrame / totalFrames) * 100 : 0}%` }}
|
||||
>
|
||||
{formatTime(currentSeconds)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -129,6 +179,9 @@ export function FrameTimeline() {
|
||||
|
||||
<div className="w-48 text-right shrink-0">
|
||||
<div className="text-2xl font-mono text-white">{currentFrame}<span className="text-xs text-gray-500"> / {totalFrames}</span></div>
|
||||
<div className="text-xs font-mono text-cyan-300 mt-1">
|
||||
{formatTime(currentSeconds)} <span className="text-gray-600">/</span> {formatTime(totalSeconds)}
|
||||
</div>
|
||||
<div className="text-[10px] text-gray-500 uppercase tracking-widest mt-1">底层时序视频图层截帧导航轴</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -45,6 +45,41 @@ describe('OntologyInspector', () => {
|
||||
}));
|
||||
});
|
||||
|
||||
it('applies the selected class to currently selected masks', () => {
|
||||
useStore.setState({
|
||||
selectedMaskIds: ['m1'],
|
||||
masks: [
|
||||
{
|
||||
id: 'm1',
|
||||
annotationId: '99',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: '旧标签',
|
||||
color: '#06b6d4',
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<OntologyInspector />);
|
||||
fireEvent.click(screen.getByText('肝脏'));
|
||||
|
||||
expect(useStore.getState().activeClassId).toBe('c2');
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
templateId: 't1',
|
||||
classId: 'c2',
|
||||
className: '肝脏',
|
||||
classZIndex: 10,
|
||||
label: '肝脏',
|
||||
color: '#00ff00',
|
||||
saveStatus: 'dirty',
|
||||
saved: false,
|
||||
}));
|
||||
expect(screen.getByText('当前选中区域:')).toBeInTheDocument();
|
||||
expect(screen.getByText('1')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('adds custom classes locally without backend persistence', () => {
|
||||
const { container } = render(<OntologyInspector />);
|
||||
const customSection = screen.getByText('自定义分类').parentElement!;
|
||||
|
||||
@@ -10,6 +10,9 @@ export function OntologyInspector() {
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
const activeClassId = useStore((state) => state.activeClassId);
|
||||
const activeClass = useStore((state) => state.activeClass);
|
||||
const masks = useStore((state) => state.masks);
|
||||
const selectedMaskIds = useStore((state) => state.selectedMaskIds);
|
||||
const setMasks = useStore((state) => state.setMasks);
|
||||
const setActiveTemplateId = useStore((state) => state.setActiveTemplateId);
|
||||
const setActiveClass = useStore((state) => state.setActiveClass);
|
||||
|
||||
@@ -28,6 +31,25 @@ export function OntologyInspector() {
|
||||
setActiveTemplateId(activeTemplate.id);
|
||||
}
|
||||
setActiveClass(templateClass);
|
||||
const selectedIdSet = new Set(selectedMaskIds);
|
||||
const hasSelectedMasks = masks.some((mask) => selectedIdSet.has(mask.id));
|
||||
if (!hasSelectedMasks) return;
|
||||
|
||||
const templateId = activeTemplate?.id || activeTemplateId || undefined;
|
||||
setMasks(masks.map((mask) => {
|
||||
if (!selectedIdSet.has(mask.id)) return mask;
|
||||
return {
|
||||
...mask,
|
||||
templateId: templateId || mask.templateId,
|
||||
classId: templateClass.id,
|
||||
className: templateClass.name,
|
||||
classZIndex: templateClass.zIndex,
|
||||
label: templateClass.name,
|
||||
color: templateClass.color,
|
||||
saveStatus: mask.annotationId ? 'dirty' : 'draft',
|
||||
saved: mask.annotationId ? false : mask.saved,
|
||||
};
|
||||
}));
|
||||
};
|
||||
|
||||
const handleAddCustom = () => {
|
||||
@@ -164,6 +186,10 @@ export function OntologyInspector() {
|
||||
</span>
|
||||
</div>
|
||||
<div className="space-y-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-[10px] text-gray-500 uppercase">当前选中区域:</span>
|
||||
<span className="text-xs font-mono text-gray-300">{selectedMaskIds.length}</span>
|
||||
</div>
|
||||
<div className="space-y-1">
|
||||
<label className="text-[10px] text-gray-500 uppercase">感知算法置信度</label>
|
||||
<div className="h-1.5 w-full bg-white/10 rounded-full overflow-hidden">
|
||||
|
||||
@@ -82,4 +82,65 @@ describe('TemplateRegistry', () => {
|
||||
|
||||
expect(screen.getByText('分类A')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('edits an existing template through the backend and store', async () => {
|
||||
apiMock.getTemplates.mockResolvedValueOnce([
|
||||
{
|
||||
id: 't1',
|
||||
name: '旧模板',
|
||||
description: 'old desc',
|
||||
classes: [{ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 10, category: '器官' }],
|
||||
rules: [],
|
||||
color: '#06b6d4',
|
||||
z_index: 3,
|
||||
},
|
||||
]);
|
||||
apiMock.updateTemplate.mockResolvedValueOnce({
|
||||
id: 't1',
|
||||
name: '新模板',
|
||||
description: 'new desc',
|
||||
classes: [{ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 10, category: '器官' }],
|
||||
rules: [],
|
||||
});
|
||||
|
||||
render(<TemplateRegistry />);
|
||||
fireEvent.click(await screen.findByRole('button', { name: /修改库视图结构/ }));
|
||||
fireEvent.change(screen.getAllByRole('textbox')[0], { target: { value: '新模板' } });
|
||||
fireEvent.change(screen.getAllByRole('textbox')[1], { target: { value: 'new desc' } });
|
||||
fireEvent.click(screen.getByRole('button', { name: '保存' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.updateTemplate).toHaveBeenCalledWith('t1', expect.objectContaining({
|
||||
name: '新模板',
|
||||
description: 'new desc',
|
||||
classes: [expect.objectContaining({ id: 'c1', name: '胆囊' })],
|
||||
rules: [],
|
||||
color: '#06b6d4',
|
||||
z_index: 3,
|
||||
})));
|
||||
expect(useStore.getState().templates[0]).toEqual(expect.objectContaining({
|
||||
id: 't1',
|
||||
name: '新模板',
|
||||
}));
|
||||
});
|
||||
|
||||
it('deletes an existing template after confirmation', async () => {
|
||||
apiMock.getTemplates.mockResolvedValueOnce([
|
||||
{
|
||||
id: 't1',
|
||||
name: '待删除模板',
|
||||
description: 'desc',
|
||||
classes: [],
|
||||
rules: [],
|
||||
},
|
||||
]);
|
||||
apiMock.deleteTemplate.mockResolvedValueOnce(undefined);
|
||||
const { container } = render(<TemplateRegistry />);
|
||||
|
||||
await screen.findAllByText('待删除模板');
|
||||
const buttons = Array.from(container.querySelectorAll('button'));
|
||||
fireEvent.click(buttons[2]);
|
||||
|
||||
await waitFor(() => expect(apiMock.deleteTemplate).toHaveBeenCalledWith('t1'));
|
||||
expect(useStore.getState().templates).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -7,6 +7,7 @@ import { VideoWorkspace } from './VideoWorkspace';
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
getProjectFrames: vi.fn(),
|
||||
parseMedia: vi.fn(),
|
||||
propagateMasks: vi.fn(),
|
||||
getTask: vi.fn(),
|
||||
getTemplates: vi.fn(),
|
||||
getProjectAnnotations: vi.fn(),
|
||||
@@ -24,6 +25,7 @@ const apiMock = vi.hoisted(() => ({
|
||||
vi.mock('../lib/api', () => ({
|
||||
getProjectFrames: apiMock.getProjectFrames,
|
||||
parseMedia: apiMock.parseMedia,
|
||||
propagateMasks: apiMock.propagateMasks,
|
||||
getTask: apiMock.getTask,
|
||||
getTemplates: apiMock.getTemplates,
|
||||
getProjectAnnotations: apiMock.getProjectAnnotations,
|
||||
@@ -47,6 +49,14 @@ describe('VideoWorkspace', () => {
|
||||
apiMock.getProjectAnnotations.mockResolvedValue([]);
|
||||
apiMock.annotationToMask.mockReturnValue(null);
|
||||
apiMock.getTask.mockResolvedValue({ id: 1, status: 'success', progress: 100, message: '解析完成' });
|
||||
apiMock.propagateMasks.mockResolvedValue({
|
||||
model: 'sam2',
|
||||
direction: 'forward',
|
||||
source_frame_id: 10,
|
||||
processed_frame_count: 3,
|
||||
created_annotation_count: 2,
|
||||
annotations: [],
|
||||
});
|
||||
apiMock.getAiModelStatus.mockResolvedValue({
|
||||
selected_model: 'sam2',
|
||||
gpu: { available: false, device: 'cpu', name: null, torch_available: true },
|
||||
@@ -320,4 +330,64 @@ describe('VideoWorkspace', () => {
|
||||
expect.objectContaining({ id: 'annotation-88', label: 'GT Mask' }),
|
||||
]));
|
||||
});
|
||||
|
||||
it('propagates the selected current-frame mask through the backend video tracker', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.buildAnnotationPayload.mockReturnValueOnce({
|
||||
project_id: 1,
|
||||
frame_id: 10,
|
||||
template_id: 2,
|
||||
mask_data: {
|
||||
polygons: [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]],
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
},
|
||||
bbox: [0.1, 0.1, 0.2, 0.2],
|
||||
});
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(2));
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
aiModel: 'sam2',
|
||||
activeTemplateId: '2',
|
||||
selectedMaskIds: ['mask-1'],
|
||||
masks: [{
|
||||
id: 'mask-1',
|
||||
frameId: '10',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
segmentation: [[64, 36, 192, 36, 192, 108]],
|
||||
bbox: [64, 36, 128, 72],
|
||||
}],
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '传播片段' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.propagateMasks).toHaveBeenCalledWith({
|
||||
project_id: 1,
|
||||
frame_id: 10,
|
||||
model: 'sam2',
|
||||
direction: 'forward',
|
||||
max_frames: 30,
|
||||
include_source: false,
|
||||
save_annotations: true,
|
||||
seed: {
|
||||
polygons: [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]],
|
||||
bbox: [0.1, 0.1, 0.2, 0.2],
|
||||
points: undefined,
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
class_metadata: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
template_id: 2,
|
||||
},
|
||||
}));
|
||||
await waitFor(() => expect(screen.getByText('已传播并保存 2 个区域')).toBeInTheDocument());
|
||||
});
|
||||
});
|
||||
|
||||
@@ -12,6 +12,7 @@ import {
|
||||
getTemplates,
|
||||
importGtMask,
|
||||
parseMedia,
|
||||
propagateMasks,
|
||||
saveAnnotation,
|
||||
updateAnnotation,
|
||||
} from '../lib/api';
|
||||
@@ -37,6 +38,8 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
const maskHistory = useStore((state) => state.maskHistory);
|
||||
const maskFuture = useStore((state) => state.maskFuture);
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
const aiModel = useStore((state) => state.aiModel);
|
||||
const selectedMaskIds = useStore((state) => state.selectedMaskIds);
|
||||
const setFrames = useStore((state) => state.setFrames);
|
||||
const setCurrentFrame = useStore((state) => state.setCurrentFrame);
|
||||
const setMasks = useStore((state) => state.setMasks);
|
||||
@@ -45,6 +48,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
const [isExporting, setIsExporting] = useState(false);
|
||||
const [isImportingGt, setIsImportingGt] = useState(false);
|
||||
const [isPropagating, setIsPropagating] = useState(false);
|
||||
const [statusMessage, setStatusMessage] = useState('');
|
||||
|
||||
const hydrateSavedAnnotations = useCallback(async (projectId: string, projectFrames: Frame[]) => {
|
||||
@@ -102,6 +106,8 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
url: f.image_url,
|
||||
width: f.width ?? 0,
|
||||
height: f.height ?? 0,
|
||||
timestampMs: f.timestamp_ms ?? undefined,
|
||||
sourceFrameNumber: f.source_frame_number ?? undefined,
|
||||
}));
|
||||
setFrames(mappedFrames);
|
||||
setCurrentFrame(0);
|
||||
@@ -117,6 +123,8 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
url: f.image_url,
|
||||
width: f.width ?? 0,
|
||||
height: f.height ?? 0,
|
||||
timestampMs: f.timestamp_ms ?? undefined,
|
||||
sourceFrameNumber: f.source_frame_number ?? undefined,
|
||||
}));
|
||||
setFrames(mappedFrames);
|
||||
setCurrentFrame(0);
|
||||
@@ -314,6 +322,55 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
}
|
||||
};
|
||||
|
||||
const handlePropagateSegment = async () => {
|
||||
if (!currentProject?.id || !currentFrame?.id) return;
|
||||
const currentFrameMasks = masks.filter((mask) => mask.frameId === currentFrame.id);
|
||||
const selectedMask = selectedMaskIds
|
||||
.map((id) => currentFrameMasks.find((mask) => mask.id === id))
|
||||
.find((mask): mask is NonNullable<typeof mask> => Boolean(mask));
|
||||
const seedMask = selectedMask || currentFrameMasks[0];
|
||||
if (!seedMask) {
|
||||
setStatusMessage('请先选择或创建一个当前帧区域');
|
||||
return;
|
||||
}
|
||||
|
||||
const seedPayload = buildAnnotationPayload(currentProject.id, seedMask, currentFrame, activeTemplateId);
|
||||
if (!seedPayload?.mask_data?.polygons?.length && !seedPayload?.bbox) {
|
||||
setStatusMessage('当前区域缺少可传播的 polygon 或 bbox');
|
||||
return;
|
||||
}
|
||||
|
||||
setIsPropagating(true);
|
||||
setStatusMessage(`${aiModel.toUpperCase()} 正在传播当前区域...`);
|
||||
try {
|
||||
const result = await propagateMasks({
|
||||
project_id: Number(currentProject.id),
|
||||
frame_id: Number(currentFrame.id),
|
||||
model: aiModel,
|
||||
direction: 'forward',
|
||||
max_frames: 30,
|
||||
include_source: false,
|
||||
save_annotations: true,
|
||||
seed: {
|
||||
polygons: seedPayload.mask_data?.polygons,
|
||||
bbox: seedPayload.bbox,
|
||||
points: seedPayload.points,
|
||||
label: seedPayload.mask_data?.label,
|
||||
color: seedPayload.mask_data?.color,
|
||||
class_metadata: seedPayload.mask_data?.class,
|
||||
template_id: seedPayload.template_id,
|
||||
},
|
||||
});
|
||||
await hydrateSavedAnnotations(currentProject.id, frames);
|
||||
setStatusMessage(`已传播并保存 ${result.created_annotation_count} 个区域`);
|
||||
} catch (err) {
|
||||
console.error('Propagation failed:', err);
|
||||
setStatusMessage('传播失败,请检查模型状态或后端日志');
|
||||
} finally {
|
||||
setIsPropagating(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="w-full h-full flex flex-col bg-[#0a0a0a]">
|
||||
{/* Top Header / Status bar */}
|
||||
@@ -339,28 +396,35 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
/>
|
||||
<button
|
||||
onClick={() => gtMaskInputRef.current?.click()}
|
||||
disabled={!currentProject?.id || !currentFrame?.id || isImportingGt || isSaving || isExporting}
|
||||
disabled={!currentProject?.id || !currentFrame?.id || isImportingGt || isSaving || isExporting || isPropagating}
|
||||
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isImportingGt ? '导入中...' : '导入 GT Mask'}
|
||||
</button>
|
||||
<button
|
||||
onClick={handlePropagateSegment}
|
||||
disabled={!currentProject?.id || !currentFrame?.id || isSaving || isExporting || isImportingGt || isPropagating}
|
||||
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isPropagating ? '传播中...' : '传播片段'}
|
||||
</button>
|
||||
<button
|
||||
onClick={handleExportMasks}
|
||||
disabled={!currentProject?.id || isExporting || isSaving}
|
||||
disabled={!currentProject?.id || isExporting || isSaving || isPropagating}
|
||||
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isExporting ? '导出中...' : '导出 PNG Mask ZIP'}
|
||||
</button>
|
||||
<button
|
||||
onClick={handleExport}
|
||||
disabled={!currentProject?.id || isExporting || isSaving}
|
||||
disabled={!currentProject?.id || isExporting || isSaving || isPropagating}
|
||||
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isExporting ? '导出中...' : '导出 JSON 标注集'}
|
||||
</button>
|
||||
<button
|
||||
onClick={handleSave}
|
||||
disabled={!currentProject?.id || isSaving || isExporting}
|
||||
disabled={!currentProject?.id || isSaving || isExporting || isPropagating}
|
||||
className="px-4 py-1.5 bg-cyan-600 hover:bg-cyan-500 text-white text-xs font-medium rounded-md transition-shadow shadow-lg shadow-cyan-900/20 disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isSaving ? '保存中...' : '结构化归档保存'}
|
||||
|
||||
@@ -159,9 +159,9 @@ describe('api client contracts', () => {
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: { ...task, status: 'cancelled', progress: 100 } });
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: { ...task, id: 13, status: 'queued', progress: 0 } });
|
||||
|
||||
await expect(parseMedia('9')).resolves.toEqual(task);
|
||||
await expect(parseMedia('9', { parseFps: 15, maxFrames: 120, targetWidth: 960 })).resolves.toEqual(task);
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/media/parse', null, {
|
||||
params: { project_id: '9' },
|
||||
params: { project_id: '9', parse_fps: 15, max_frames: 120, target_width: 960 },
|
||||
});
|
||||
|
||||
await expect(getTask(12)).resolves.toEqual(expect.objectContaining({ status: 'success', progress: 100 }));
|
||||
@@ -175,7 +175,7 @@ describe('api client contracts', () => {
|
||||
});
|
||||
|
||||
it('lists, saves, updates, and deletes annotations with the backend annotation contract', async () => {
|
||||
const { deleteAnnotation, getProjectAnnotations, saveAnnotation, updateAnnotation } = await import('./api');
|
||||
const { deleteAnnotation, getProjectAnnotations, propagateMasks, saveAnnotation, updateAnnotation } = await import('./api');
|
||||
const saved = {
|
||||
id: 1,
|
||||
project_id: 9,
|
||||
@@ -221,6 +221,43 @@ describe('api client contracts', () => {
|
||||
axiosMock.client.delete.mockResolvedValueOnce({ data: null });
|
||||
await expect(deleteAnnotation('1')).resolves.toBeUndefined();
|
||||
expect(axiosMock.client.delete).toHaveBeenCalledWith('/api/ai/annotations/1');
|
||||
|
||||
axiosMock.client.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
model: 'sam2',
|
||||
direction: 'forward',
|
||||
source_frame_id: 5,
|
||||
processed_frame_count: 3,
|
||||
created_annotation_count: 2,
|
||||
annotations: [saved],
|
||||
},
|
||||
});
|
||||
await expect(propagateMasks({
|
||||
project_id: 9,
|
||||
frame_id: 5,
|
||||
model: 'sam2',
|
||||
seed: {
|
||||
polygons: [[[0, 0], [1, 0], [1, 1]]],
|
||||
label: 'mask',
|
||||
color: '#06b6d4',
|
||||
},
|
||||
direction: 'forward',
|
||||
max_frames: 30,
|
||||
})).resolves.toEqual(expect.objectContaining({ created_annotation_count: 2 }));
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/propagate', {
|
||||
project_id: 9,
|
||||
frame_id: 5,
|
||||
model: 'sam2',
|
||||
seed: {
|
||||
polygons: [[[0, 0], [1, 0], [1, 1]]],
|
||||
label: 'mask',
|
||||
color: '#06b6d4',
|
||||
},
|
||||
direction: 'forward',
|
||||
max_frames: 30,
|
||||
}, {
|
||||
timeout: 600000,
|
||||
});
|
||||
});
|
||||
|
||||
it('imports GT masks through multipart form data', async () => {
|
||||
@@ -377,6 +414,33 @@ describe('api client contracts', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('normalizes combined box and point prompts for interactive SAM2 refinement', async () => {
|
||||
const { predictMask } = await import('./api');
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
|
||||
|
||||
await predictMask({
|
||||
imageId: '5',
|
||||
imageWidth: 640,
|
||||
imageHeight: 320,
|
||||
box: { x1: 64, y1: 32, x2: 320, y2: 160 },
|
||||
points: [
|
||||
{ x: 128, y: 64, type: 'pos' },
|
||||
{ x: 256, y: 128, type: 'neg' },
|
||||
],
|
||||
});
|
||||
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
|
||||
image_id: 5,
|
||||
prompt_type: 'interactive',
|
||||
prompt_data: {
|
||||
box: [0.1, 0.1, 0.5, 0.5],
|
||||
points: [[0.2, 0.2], [0.4, 0.4]],
|
||||
labels: [1, 0],
|
||||
},
|
||||
model: 'sam2',
|
||||
});
|
||||
});
|
||||
|
||||
it('uses semantic prompt type for text-only AI prediction', async () => {
|
||||
const { predictMask } = await import('./api');
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
|
||||
|
||||
@@ -153,6 +153,8 @@ export async function getProjectFrames(projectId: string): Promise<Array<{
|
||||
image_url: string;
|
||||
width: number | null;
|
||||
height: number | null;
|
||||
timestamp_ms?: number | null;
|
||||
source_frame_number?: number | null;
|
||||
}>> {
|
||||
const response = await apiClient.get(`/api/projects/${projectId}/frames`);
|
||||
return response.data;
|
||||
@@ -185,9 +187,18 @@ export interface ProcessingTask {
|
||||
updated_at: string;
|
||||
}
|
||||
|
||||
export async function parseMedia(projectId: string): Promise<ProcessingTask> {
|
||||
export async function parseMedia(projectId: string, options: {
|
||||
parseFps?: number;
|
||||
maxFrames?: number;
|
||||
targetWidth?: number;
|
||||
} = {}): Promise<ProcessingTask> {
|
||||
const response = await apiClient.post('/api/media/parse', null, {
|
||||
params: { project_id: projectId },
|
||||
params: {
|
||||
project_id: projectId,
|
||||
...(options.parseFps ? { parse_fps: options.parseFps } : {}),
|
||||
...(options.maxFrames ? { max_frames: options.maxFrames } : {}),
|
||||
...(options.targetWidth ? { target_width: options.targetWidth } : {}),
|
||||
},
|
||||
});
|
||||
return response.data;
|
||||
}
|
||||
@@ -312,6 +323,40 @@ export interface SaveAnnotationPayload {
|
||||
|
||||
export type UpdateAnnotationPayload = Omit<SaveAnnotationPayload, 'project_id' | 'frame_id'>;
|
||||
|
||||
export interface PropagateMasksPayload {
|
||||
project_id: number;
|
||||
frame_id: number;
|
||||
model?: AiModelId;
|
||||
seed: {
|
||||
polygons?: number[][][];
|
||||
bbox?: number[];
|
||||
points?: number[][];
|
||||
label?: string;
|
||||
color?: string;
|
||||
class_metadata?: {
|
||||
id?: string;
|
||||
name?: string;
|
||||
color?: string;
|
||||
zIndex?: number;
|
||||
category?: string;
|
||||
};
|
||||
template_id?: number;
|
||||
};
|
||||
direction?: 'forward' | 'backward' | 'both';
|
||||
max_frames?: number;
|
||||
include_source?: boolean;
|
||||
save_annotations?: boolean;
|
||||
}
|
||||
|
||||
export interface PropagateMasksResult {
|
||||
model: AiModelId;
|
||||
direction: string;
|
||||
source_frame_id: number;
|
||||
processed_frame_count: number;
|
||||
created_annotation_count: number;
|
||||
annotations: SavedAnnotation[];
|
||||
}
|
||||
|
||||
export interface DashboardTask {
|
||||
id: string;
|
||||
task_id?: number;
|
||||
@@ -474,10 +519,22 @@ export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mas
|
||||
}
|
||||
|
||||
export async function predictMask(payload: PredictMaskPayload): Promise<PredictMaskResult> {
|
||||
let prompt_type: 'point' | 'box' | 'semantic';
|
||||
let prompt_type: 'point' | 'box' | 'semantic' | 'interactive';
|
||||
let prompt_data: unknown;
|
||||
|
||||
if (payload.box) {
|
||||
if (payload.box && payload.points && payload.points.length > 0) {
|
||||
prompt_type = 'interactive';
|
||||
prompt_data = {
|
||||
box: [
|
||||
clamp01(payload.box.x1 / Math.max(payload.imageWidth, 1)),
|
||||
clamp01(payload.box.y1 / Math.max(payload.imageHeight, 1)),
|
||||
clamp01(payload.box.x2 / Math.max(payload.imageWidth, 1)),
|
||||
clamp01(payload.box.y2 / Math.max(payload.imageHeight, 1)),
|
||||
],
|
||||
points: payload.points.map((point) => normalizePoint(point, payload.imageWidth, payload.imageHeight)),
|
||||
labels: payload.points.map((point) => (point.type === 'neg' ? 0 : 1)),
|
||||
};
|
||||
} else if (payload.box) {
|
||||
prompt_type = 'box';
|
||||
prompt_data = [
|
||||
clamp01(payload.box.x1 / Math.max(payload.imageWidth, 1)),
|
||||
@@ -540,6 +597,13 @@ export async function getProjectAnnotations(projectId: string, frameId?: string)
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function propagateMasks(payload: PropagateMasksPayload): Promise<PropagateMasksResult> {
|
||||
const response = await apiClient.post('/api/ai/propagate', payload, {
|
||||
timeout: 600000,
|
||||
});
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function saveAnnotation(payload: SaveAnnotationPayload): Promise<SavedAnnotation> {
|
||||
const response = await apiClient.post('/api/ai/annotate', payload);
|
||||
return response.data;
|
||||
|
||||
@@ -30,6 +30,7 @@ describe('useStore', () => {
|
||||
useStore.getState().setFrames([{ id: 'f1', projectId: '1', index: 0, url: '/f1.jpg', width: 640, height: 360 }]);
|
||||
useStore.getState().setCurrentFrame(0);
|
||||
useStore.getState().addMask({ id: 'm1', frameId: 'f1', pathData: 'M 0 0 Z', label: 'mask', color: '#fff' });
|
||||
useStore.getState().setSelectedMaskIds(['m1']);
|
||||
useStore.getState().updateMask('m1', { label: 'updated mask', saveStatus: 'dirty' });
|
||||
useStore.getState().addAnnotation({ id: 'a1', frameId: 'f1', type: 'mask', points: [], label: 'ann', color: '#fff' });
|
||||
useStore.getState().addTemplate({ id: 't1', name: 'Template', classes: [], rules: [] });
|
||||
@@ -40,6 +41,7 @@ describe('useStore', () => {
|
||||
expect(useStore.getState().currentProject?.id).toBe('1');
|
||||
expect(useStore.getState().frames).toHaveLength(1);
|
||||
expect(useStore.getState().currentFrameIndex).toBe(0);
|
||||
expect(useStore.getState().selectedMaskIds).toEqual(['m1']);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ label: 'updated mask', saveStatus: 'dirty' }));
|
||||
expect(useStore.getState().annotations).toHaveLength(1);
|
||||
expect(useStore.getState().templates[0].name).toBe('Template 2');
|
||||
@@ -51,6 +53,7 @@ describe('useStore', () => {
|
||||
|
||||
expect(useStore.getState().annotations).toEqual([]);
|
||||
expect(useStore.getState().masks).toEqual([]);
|
||||
expect(useStore.getState().selectedMaskIds).toEqual([]);
|
||||
expect(useStore.getState().templates).toEqual([]);
|
||||
});
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@ export interface Frame {
|
||||
width: number;
|
||||
height: number;
|
||||
timestamp?: string;
|
||||
timestampMs?: number;
|
||||
sourceFrameNumber?: number;
|
||||
}
|
||||
|
||||
export interface Annotation {
|
||||
@@ -112,6 +114,7 @@ export interface AppState {
|
||||
currentFrameIndex: number;
|
||||
annotations: Annotation[];
|
||||
masks: Mask[];
|
||||
selectedMaskIds: string[];
|
||||
maskHistory: Mask[][];
|
||||
maskFuture: Mask[][];
|
||||
setActiveModule: (module: string) => void;
|
||||
@@ -123,6 +126,7 @@ export interface AppState {
|
||||
addMask: (mask: Mask) => void;
|
||||
updateMask: (id: string, updates: Partial<Mask>) => void;
|
||||
setMasks: (masks: Mask[]) => void;
|
||||
setSelectedMaskIds: (ids: string[]) => void;
|
||||
clearMasks: () => void;
|
||||
undoMasks: () => void;
|
||||
redoMasks: () => void;
|
||||
@@ -167,6 +171,7 @@ export const useStore = create<AppState>((set) => ({
|
||||
frames: [],
|
||||
annotations: [],
|
||||
masks: [],
|
||||
selectedMaskIds: [],
|
||||
maskHistory: [],
|
||||
maskFuture: [],
|
||||
activeTemplateId: null,
|
||||
@@ -195,6 +200,7 @@ export const useStore = create<AppState>((set) => ({
|
||||
currentFrameIndex: 0,
|
||||
annotations: [],
|
||||
masks: [],
|
||||
selectedMaskIds: [],
|
||||
maskHistory: [],
|
||||
maskFuture: [],
|
||||
setActiveModule: (activeModule: string) => set({ activeModule }),
|
||||
@@ -227,9 +233,11 @@ export const useStore = create<AppState>((set) => ({
|
||||
maskFuture: [],
|
||||
};
|
||||
}),
|
||||
setSelectedMaskIds: (selectedMaskIds: string[]) => set({ selectedMaskIds }),
|
||||
clearMasks: () =>
|
||||
set((state) => ({
|
||||
masks: [],
|
||||
selectedMaskIds: [],
|
||||
maskHistory: [...state.maskHistory, state.masks],
|
||||
maskFuture: [],
|
||||
})),
|
||||
|
||||
@@ -71,7 +71,11 @@ vi.mock('react-konva', () => ({
|
||||
data-fill={props.fill}
|
||||
data-x={props.x}
|
||||
data-y={props.y}
|
||||
onClick={() => props.onClick?.({ cancelBubble: false })}
|
||||
onClick={(event) => {
|
||||
const konvaEvent = { cancelBubble: false };
|
||||
props.onClick?.(konvaEvent);
|
||||
if (konvaEvent.cancelBubble) event.stopPropagation();
|
||||
}}
|
||||
onMouseUp={(event: React.MouseEvent<HTMLSpanElement>) => props.onDragEnd?.({
|
||||
target: {
|
||||
x: () => event.clientX || props.x || 0,
|
||||
@@ -92,7 +96,12 @@ vi.mock('react-konva', () => ({
|
||||
data-testid="konva-path"
|
||||
data-path={props.data}
|
||||
data-fill={props.fill}
|
||||
onClick={() => props.onClick?.({ cancelBubble: false })}
|
||||
data-fill-rule={props.fillRule}
|
||||
onClick={(event) => {
|
||||
const konvaEvent = { cancelBubble: false };
|
||||
props.onClick?.(konvaEvent);
|
||||
if (konvaEvent.cancelBubble) event.stopPropagation();
|
||||
}}
|
||||
/>
|
||||
),
|
||||
}));
|
||||
|
||||
@@ -13,6 +13,7 @@ export function resetStore() {
|
||||
currentFrameIndex: 0,
|
||||
annotations: [],
|
||||
masks: [],
|
||||
selectedMaskIds: [],
|
||||
maskHistory: [],
|
||||
maskFuture: [],
|
||||
templates: [],
|
||||
|
||||
Reference in New Issue
Block a user