feat: 完善视频传播、标注编辑和拆帧闭环

- 接入 SAM2 视频传播能力:新增 /api/ai/propagate,支持用当前帧 mask/polygon/bbox 作为 seed,通过 SAM2 video predictor 向前、向后或双向传播,并可保存为真实 annotation。
- 接入 SAM3 video tracker:通过独立 Python 3.12 external worker 调用 SAM3 video predictor/tracker,使用本地 checkpoint 与 bbox seed 执行视频级跟踪,并在模型状态中标记 video_track 能力。
- 完善 SAM 模型分发:sam_registry 按 model_id 明确区分 sam2 propagation 与 sam3 video_track,避免两个模型链路混用。
- 打通前端“传播片段”:VideoWorkspace 使用当前选中 mask 和当前 AI 模型调用后端传播接口,传播结果回写并刷新工作区已保存标注。
- 增强 SAM3 本地 checkpoint 配置:新增 sam3_checkpoint_path 配置和 .env.example 示例,状态检查改为基于本地 checkpoint/独立环境/模型包可用性。
- 完善视频拆帧参数:/api/media/parse 支持 parse_fps、max_frames、target_width,后端任务保存帧时间戳、源帧号和 frame_sequence 元数据。
- 增加运行时 schema 兼容处理:启动时为旧 frames 表补充 timestamp_ms 和 source_frame_number 列,避免旧库升级后缺字段。
- 强化 Canvas 标注编辑:补齐多边形闭合、点工具、顶点拖拽、边中点插入、Delete/Backspace 删除、区域合并和重叠去除等交互。
- 增强语义分类联动:选中 mask 后可通过右侧语义分类树更新标签、颜色和 class metadata,并同步到保存/导出链路。
- 增加关键帧时间轴体验:FrameTimeline 显示具体时间信息,并支持键盘左右方向键切换关键帧。
- 完善 AI 交互分割参数:前端保留正向点、反向点、框选和 interactive prompt 的调用状态,支持 SAM2 细化候选区域与 SAM3 bbox 入口。
- 扩展后端/前端 API 类型:新增 propagateMasks、传播请求/响应 schema,并补齐 annotation、导出、模型状态和任务接口的测试覆盖。
- 更新项目文档:同步 README、AGENTS、接口契约、需求冻结、设计冻结、前端元素审计、实施计划和测试计划,标明真实功能边界与剩余风险。
- 增加测试覆盖:补充 SAM2/SAM3 传播、SAM3 状态、媒体拆帧参数、Canvas 编辑、语义标签切换、时间轴、工作区传播和 API 合约测试。
- 加强仓库安全边界:将 sam3权重/ 加入 .gitignore,避免本地模型权重被误提交。

验证:npm run test:run;pytest backend/tests;npm run lint;npm run build;python -m py_compile;git diff --check。
This commit is contained in:
2026-05-01 20:27:33 +08:00
parent 689a9ba283
commit 5ab4602535
43 changed files with 2722 additions and 216 deletions

View File

@@ -19,4 +19,5 @@ VITE_WS_PROGRESS_URL="ws://192.168.3.11:8000/ws/progress"
sam_default_model="sam2"
sam_model_path="/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
sam_model_config="configs/sam2/sam2_hiera_t.yaml"
sam3_model_version="sam3.1"
sam3_model_version="sam3"
sam3_checkpoint_path="/home/wkmgc/Desktop/Seg_Server/sam3权重/sam3.pt"

1
.gitignore vendored
View File

@@ -8,6 +8,7 @@ coverage/
!.env.example
# Data & Models
models/
sam3权重/
uploads/
frames/
minio_data/

View File

@@ -6,7 +6,7 @@
## 项目概述
本项目是一个**语义分割系统**Semantic Segmentation System当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、服务器端拆帧、交互式 Canvas 标注、GT mask 导入、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出。
本项目是一个**语义分割系统**Semantic Segmentation System当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、服务器端拆帧、交互式 Canvas 标注、视频片段传播、GT mask 导入、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出。
- **项目名称**: `react-example``package.json` 中的 `name`
- **前端入口**: `src/main.tsx``src/App.tsx`
@@ -39,7 +39,7 @@
| 缓存 / 队列 Broker | Redis |
| 后台任务 | Celery worker |
| 对象存储 | MinIO |
| AI 推理 | SAM 2 / SAM 3 可选模型 + PyTorchSAM 3 通过独立 Python 3.12 conda 环境桥接;`GET /api/ai/models/status` 返回真实 GPU/模型/HF 权重访问状态 |
| AI 推理 | SAM 2 / SAM 3 可选模型 + PyTorchSAM 3 通过独立 Python 3.12 conda 环境桥接;`GET /api/ai/models/status` 返回真实 GPU/模型/本地 checkpoint 状态 |
| 视频 / 影像处理 | FFmpeg / OpenCV / pydicom |
| 运行时 | Node.js ES ModulesPython 3.11 后端环境;可选 `sam3` Python 3.12 conda 环境 |
@@ -78,12 +78,12 @@ Seg_Server/
│ │ ├── projects.py # /api/projects 与 /api/projects/{id}/frames
│ │ ├── templates.py # /api/templates
│ │ ├── media.py # /api/media/upload、/upload/dicom、/parse
│ │ ├── ai.py # /api/ai/predict、/models/status、/auto、/annotate
│ │ ├── ai.py # /api/ai/predict、/propagate、/models/status、/auto、/annotate
│ │ └── export.py # /api/export/{project_id}/coco、/masks
│ └── services/
│ ├── frame_parser.py # FFmpeg/OpenCV 拆帧、pydicom 读片、帧上传
│ ├── sam2_engine.py # SAM 2 懒加载推理封装和 fallback
│ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接文本语义推理适配器
│ ├── sam2_engine.py # SAM 2 单帧推理和 video predictor 传播封装
│ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接文本语义推理、框选与 video tracker 适配器
│ ├── sam3_external_worker.py # 独立 sam3 conda 环境中执行的状态/推理 helper
│ └── sam_registry.py # SAM 模型选择、GPU 状态与推理分发
└── src/ # React 前端
@@ -194,6 +194,7 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
- `POST /api/tasks/{task_id}/cancel`
- `POST /api/tasks/{task_id}/retry`
- `POST /api/ai/predict`
- `POST /api/ai/propagate`
- `GET /api/ai/models/status`
- `POST /api/ai/auto`
- `POST /api/ai/annotate`
@@ -219,14 +220,15 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
1. 登录:`Login.tsx` 调用 `POST /api/auth/login`,默认开发凭证为 `admin / 123456`
2. 项目管理:`ProjectLibrary.tsx` 调用项目 API 创建项目、拉取列表。
3. 上传资源:视频走 `/api/media/upload`DICOM 批量走 `/api/media/upload/dicom`
4. 拆帧入队:前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery。
5. worker 执行Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallbackDICOM 使用 pydicom,并持续更新任务进度
6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames``CanvasArea.tsx``FrameTimeline.tsx` 显示当前帧与时间轴缩略图。
7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask点击 mask 可拖动/删除 polygon 顶点、通过边中点插入新顶点,并能选择编辑多 polygon mask 的单个子区域;区域合并/去除使用 `polygon-clipping` 做 union/differenceZustand 维护 `maskHistory/maskFuture` 支持撤销/重做。
8. AI 分割:前端工具包括正向点、反向点和框选;后端 `ai.py` 期望按 `image_id``prompt_type``prompt_data``model` 和可选 `options` 调用 SAM registry。SAM 2 支持点/框/自动分割`options.crop_to_prompt` 可对点/框 prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果SAM 3 入口支持文本语义推理,主后端会通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境;如果 Python/CUDA/包/Hugging Face gated 权重访问任一条件不满足,会在状态接口中标为可用。
9. GT 导入:工作区“导入 GT Mask”调用 `/api/ai/import-gt-mask`;后端按非零像素值和连通域生成 polygon 标注,并用 distance transform 生成 seed point前端回显 seed point拖动后可归档更新
10. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index`OntologyInspector.tsx` 在工作区显示当前模板分类树
11. 导出:后端支持 COCO JSON 和 PNG mask ZIP 导出PNG ZIP 包含单标注 mask、按 zIndex 融合的语义 mask 和 `semantic_classes.json`
4. 拆帧入队:前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery,接口支持 `parse_fps``max_frames``target_width` 标准帧序列参数
5. worker 执行Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallbackDICOM 使用 pydicom;视频帧按 `frame_%06d.jpg` 连续命名并记录 `timestamp_ms``source_frame_number` 和任务 `frame_sequence` 元数据
6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames``CanvasArea.tsx``FrameTimeline.tsx` 显示当前帧与时间轴缩略图;前端 `Frame` 会保留后端返回的帧序列时间戳和源帧号
7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask多边形可按 Enter 或点击首节点闭合;绘制工具可在已有 mask 上继续落点;点击 mask 可拖动/删除 polygon 顶点、通过边中点插入新顶点,并能选择编辑多 polygon mask 的单个子区域;选中整块 mask 可用 Delete/Backspace 删除,已保存 mask 会同步后端删除;区域合并/去除会隐藏编辑手柄并显示已选数量,使用 `polygon-clipping` 做 union/difference,内含去除结果用 even-odd 规则渲染 holeZustand 维护 `maskHistory/maskFuture` 支持撤销/重做。
8. AI 分割:前端工具包括正向点、反向点和框选;SAM 2 框选会建立候选 mask后续正/反点通过 `interactive` prompt 携带原始框和累计点细化同一个候选 mask后端 `ai.py` 期望按 `image_id``prompt_type``prompt_data``model` 和可选 `options` 调用 SAM registry。SAM 2 支持点/框/interactive/自动分割和 video predictor 传播`options.crop_to_prompt` 可对点/框/interactive prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果SAM 3 入口支持文本语义推理、框选提示和 external video tracker,主后端会通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境,并优先使用 `sam3_checkpoint_path` 指向的本地 `sam3权重/sam3.pt`;如果 Python/CUDA/包/本地 checkpoint 均满足,会在状态接口中标为可用。
9. 视频片段传播:工作区“传播片段”把当前选中 mask 或当前帧第一个 mask 作为 seed调用 `POST /api/ai/propagate`后端按项目帧序列下载片段帧SAM 2 用 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`SAM 3 用独立 helper 的官方 `build_sam3_video_predictor()`,并把后续帧结果保存为 `Annotation`
10. GT 导入:工作区“导入 GT Mask”调用 `/api/ai/import-gt-mask`;后端按非零像素值和连通域生成 polygon 标注,并用 distance transform 生成 seed point前端回显 seed point拖动后可归档更新
11. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index`OntologyInspector.tsx` 在工作区显示当前模板分类树
12. 导出:后端支持 COCO JSON 和 PNG mask ZIP 导出PNG ZIP 包含单标注 mask、按 zIndex 融合的语义 mask 和 `semantic_classes.json`
---
@@ -240,6 +242,7 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
- 前端 `importGtMask()` 已对齐后端 `/api/ai/import-gt-mask`;工作区“导入 GT Mask”会导入后端生成的多类别标注和 seed point 并回显。
- 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;前端 `exportMasks()` 已对齐后端 `/api/export/{project_id}/masks`;工作区导出按钮会先保存当前待归档 mask。
- 工作区“结构化归档保存”按钮已接入 `POST /api/ai/annotate``PATCH /api/ai/annotations/{id}`;加载工作区时会通过 `GET /api/ai/annotations` 回显已保存标注。
- 工作区“传播片段”按钮已接入 `POST /api/ai/propagate`SAM 2 路径使用视频 predictorSAM 3 路径使用独立 Python helper 的官方 video tracker完成后刷新后端已保存标注。
- 工作区“清空遮罩”会调用 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。
- 项目状态已统一为 `pending``parsing``ready``error`;前端 `src/lib/api.ts` 会兼容归一化旧库中可能存在的 `Ready``Parsing``Error`
- `server.ts` 仍有旧版 `/api/login``/api/projects``/api/templates` mock当前前端真实 API 调用主要走 FastAPI 的 `/api/auth/*``/api/projects``/api/templates` 等接口。

View File

@@ -6,14 +6,14 @@
> 基于 React + FastAPI + 可选 SAM 2 / SAM 3 的全栈交互式图像/视频语义分割与标注平台。
>
> 支持本地多媒体资产上传、服务器端按帧解析、交互式 Canvas 标注、GT mask 导入、模板分类管理和标注数据结构化导出;工作区点/框 AI 推理默认走 SAM 2语义文本可选择 SAM 3,前端会显示真实 GPU/模型状态。
> 支持本地多媒体资产上传、服务器端按帧解析、交互式 Canvas 标注、视频片段传播、GT mask 导入、模板分类管理和标注数据结构化导出;工作区点/框 AI 推理默认走 SAM 2SAM 3 支持语义文本、框选提示和 video tracker,前端会显示真实 GPU/模型状态。
---
## 核心功能
- **多媒体资产管理** — 支持视频MP4/AVI/MOV和 DICOM 医学影像的上传、存储与解析
- **AI 智能分割引擎** — 后端提供 SAM 2 / SAM 3 模型选择SAM 2 支持点分割point、框分割box自动分割autoSAM 3 入口支持文本语义提示并按真实运行环境显示可用性
- **AI 智能分割引擎** — 后端提供 SAM 2 / SAM 3 模型选择SAM 2 支持点分割point、框分割box自动分割auto和 video predictor 传播SAM 3 入口支持文本语义提示、框选提示和 external video tracker并按真实运行环境显示可用性
- **交互式画布标注** — 基于 Konva 的高性能 Canvas支持缩放/平移/手工多边形/矩形/圆/点/线、polygon 顶点拖动/删除、区域合并/去除、选点/框选、撤销/重做,实时渲染 Mask 遮罩
- **GT Mask 导入** — 工作区可导入 GT mask 图片,后端按非零像素值和连通域生成 polygon 标注并用 distance transform 写入 seed point前端可回显和拖动 seed point
- **本体字典管理** — 可配置的分类体系、颜色映射、图层优先级z-index
@@ -104,8 +104,8 @@ Seg_Server/
│ │ ├── ai.py # SAM 推理与模型状态接口
│ │ └── export.py # 数据导出
│ └── services/ # 业务服务
│ ├── sam2_engine.py # SAM 2 推理引擎(懒加载 + stub降级
│ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接文本语义推理适配器
│ ├── sam2_engine.py # SAM 2 推理引擎(单帧推理 + video predictor 传播
│ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接文本语义推理、框选与 video tracker 适配器
│ ├── sam3_external_worker.py # 独立 sam3 conda 环境中执行的状态/推理 helper
│ ├── sam_registry.py # SAM 模型选择、GPU 状态与推理分发
│ └── frame_parser.py # FFmpeg 拆帧 / pydicom 读片
@@ -255,12 +255,11 @@ python download_sam2.py
cd ~/Desktop/Seg_Server
./backend/setup_sam3_env.sh
# 首次使用官方权重前,需要先在 Hugging Face 申请 facebook/sam3 访问权限并登录
conda activate sam3
huggingface-cli login
# 如果已把权重放在 sam3权重/sam3.pt可直接走本地 checkpoint
# 未配置本地 checkpoint 时,才需要 Hugging Face gated repo 授权和登录。
```
官方 `facebook/sam3` 权重约 3.45 GB当前没有类似 SAM 2 `tiny/small/base/large` 的官方小权重梯度`facebook/sam3.1` 约 3.5 GB主要面向新的视频 multiplex checkpoint未获得 gated model 授权时,`GET /api/ai/models/status` 会把 SAM 3 标为不可用并说明 checkpoint access 不满足。
官方 `facebook/sam3` 权重约 3.45 GB当前没有类似 SAM 2 `tiny/small/base/large` 的官方小权重梯度。当前仓库默认使用本机 `sam3权重/sam3.pt`,不会提交权重文件;未配置本地 checkpoint未获得 gated model 授权时,`GET /api/ai/models/status` 会把 SAM 3 标为不可用并说明 checkpoint access 不满足。
### 步骤 6: 配置环境变量
@@ -278,6 +277,7 @@ sam_model_path=/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt
sam_model_config=configs/sam2/sam2_hiera_t.yaml
sam_default_model=sam2
sam3_model_version=sam3
sam3_checkpoint_path=/home/wkmgc/Desktop/Seg_Server/sam3权重/sam3.pt
sam3_external_enabled=true
sam3_external_python=/home/wkmgc/miniconda3/envs/sam3/bin/python
sam3_timeout_seconds=300
@@ -311,6 +311,7 @@ nohup uvicorn main:app --host 0.0.0.0 --port 8000 > /tmp/fastapi.log 2>&1 &
- 测试 Redis 连接
- 懒加载 SAM 模型;`GET /api/ai/models/status` 会返回 SAM 2、SAM 3、GPU 和 SAM 3 checkpoint access 的真实可用状态
- `/api/ai/predict` 支持 AI 参数 `crop_to_prompt``auto_filter_background``min_score`,用于点/框 prompt 的局部裁剪推理、回映射和背景过滤
- `/api/ai/propagate` 支持从当前帧 seed 区域向视频片段传播SAM 2 使用 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`SAM 3 通过独立 Python 3.12 helper 调用官方 `build_sam3_video_predictor()` video tracker
### 步骤 6.1: 启动 Celery Worker
@@ -324,7 +325,7 @@ celery -A celery_app:celery_app worker --loglevel=info --concurrency=1
nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 &
```
`POST /api/media/parse` 只创建 `processing_tasks` 记录并把任务投递给 Celery真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 也可调用 `/api/tasks/{id}/cancel``/api/tasks/{id}/retry``/api/tasks/{id}` 完成任务取消、重试与失败详情查看。
`POST /api/media/parse` 只创建 `processing_tasks` 记录并把任务投递给 Celery真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。接口支持 `parse_fps``max_frames``target_width`,用于生成后续 SAM 2 / SAM 3 视频处理可复用的标准帧序列;视频帧按 `frame_%06d.jpg` 连续命名,帧表会记录 `timestamp_ms``source_frame_number`,任务完成结果会返回 `frame_sequence` 元数据。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 也可调用 `/api/tasks/{id}/cancel``/api/tasks/{id}/retry``/api/tasks/{id}` 完成任务取消、重试与失败详情查看。
### 步骤 7: 安装前端依赖并构建
@@ -460,6 +461,7 @@ pip install -e . --no-build-isolation
- 前端 `predictMask()` 已发送后端需要的 `image_id``prompt_type``prompt_data`,并把后端 `polygons` 转成 Konva `pathData`
- 工作区点选/框选会使用当前帧的数据库 `frame.id` 调用 `/api/ai/predict`
- 工作区“传播片段”会使用当前选中区域或当前帧第一个区域作为 seed调用 `/api/ai/propagate`,并在完成后刷新已保存标注。
- 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco`
- 工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮已绑定下载流程导出前会先保存当前待归档的前端 mask。
- 工作区“导入 GT Mask”按钮已绑定 `/api/ai/import-gt-mask`,导入后会刷新并回显已保存标注和 seed point。

View File

@@ -23,6 +23,7 @@ class Settings(BaseSettings):
sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
sam_model_config: str = "configs/sam2/sam2_hiera_t.yaml"
sam3_model_version: str = "sam3"
sam3_checkpoint_path: str = "/home/wkmgc/Desktop/Seg_Server/sam3权重/sam3.pt"
sam3_external_enabled: bool = True
sam3_external_python: str = "/home/wkmgc/miniconda3/envs/sam3/bin/python"
sam3_timeout_seconds: int = 300

View File

@@ -11,6 +11,7 @@ from datetime import datetime, timezone
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import inspect, text
from config import settings
from database import Base, engine, SessionLocal
@@ -30,6 +31,20 @@ logger = logging.getLogger(__name__)
DEFAULT_VIDEO_PATH = "/home/wkmgc/Desktop/Seg_Server/Data_MyVideo_1.mp4"
def _ensure_runtime_schema_columns() -> None:
"""Add nullable columns introduced after initial create_all deployments."""
try:
inspector = inspect(engine)
frame_columns = {column["name"] for column in inspector.get_columns("frames")}
with engine.begin() as connection:
if "timestamp_ms" not in frame_columns:
connection.execute(text("ALTER TABLE frames ADD COLUMN timestamp_ms FLOAT"))
if "source_frame_number" not in frame_columns:
connection.execute(text("ALTER TABLE frames ADD COLUMN source_frame_number INTEGER"))
except Exception as exc: # noqa: BLE001
logger.warning("Runtime schema column check failed: %s", exc)
def _seed_default_project_sync() -> None:
"""Synchronously seed the default video project on first startup."""
import cv2
@@ -93,12 +108,16 @@ def _seed_default_project_sync() -> None:
for idx, obj_name in enumerate(object_names):
img = cv2.imread(frame_files[idx])
h, w = img.shape[:2] if img is not None else (None, None)
timestamp_ms = idx * 1000.0 / 30.0
source_frame_number = int(round(idx * original_fps / 30.0)) if original_fps else None
frame = Frame(
project_id=project.id,
frame_index=idx,
image_url=obj_name,
width=w,
height=h,
timestamp_ms=timestamp_ms,
source_frame_number=source_frame_number,
)
db.add(frame)
@@ -176,6 +195,7 @@ async def lifespan(app: FastAPI):
# Initialize database tables
try:
Base.metadata.create_all(bind=engine)
_ensure_runtime_schema_columns()
logger.info("Database tables initialized.")
except Exception as exc: # noqa: BLE001
logger.error("Database initialization failed: %s", exc)

View File

@@ -56,6 +56,8 @@ class Frame(Base):
image_url = Column(String(512), nullable=False)
width = Column(Integer, nullable=True)
height = Column(Integer, nullable=True)
timestamp_ms = Column(Float, nullable=True)
source_frame_number = Column(Integer, nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
project = relationship("Project", back_populates="frames")

View File

@@ -1,6 +1,8 @@
"""AI inference endpoints using selectable SAM runtimes."""
import logging
import tempfile
from pathlib import Path
from typing import Any, List
import cv2
@@ -15,6 +17,8 @@ from schemas import (
AiRuntimeStatus,
PredictRequest,
PredictResponse,
PropagateRequest,
PropagateResponse,
AnnotationOut,
AnnotationCreate,
AnnotationUpdate,
@@ -66,6 +70,48 @@ def _contour_bbox(contour: np.ndarray, width: int, height: int) -> list[float]:
]
def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
xs = [_clamp01(point[0]) for point in polygon]
ys = [_clamp01(point[1]) for point in polygon]
left, right = min(xs), max(xs)
top, bottom = min(ys), max(ys)
return [left, top, max(right - left, 0.0), max(bottom - top, 0.0)]
def _frame_window(
frames: list[Frame],
source_position: int,
direction: str,
max_frames: int,
) -> tuple[list[Frame], int]:
count = max(1, min(max_frames, len(frames)))
if direction == "backward":
start = max(0, source_position - count + 1)
return frames[start:source_position + 1], source_position - start
if direction == "both":
before = (count - 1) // 2
after = count - 1 - before
start = max(0, source_position - before)
end = min(len(frames), source_position + after + 1)
while end - start < count and start > 0:
start -= 1
while end - start < count and end < len(frames):
end += 1
return frames[start:end], source_position - start
end = min(len(frames), source_position + count)
return frames[source_position:end], 0
def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]:
paths = []
for index, frame in enumerate(frames):
data = download_file(frame.image_url)
path = directory / f"frame_{index:06d}.jpg"
path.write_bytes(data)
paths.append(str(path))
return paths
def _component_seed_point(component_mask: np.ndarray, width: int, height: int) -> list[float]:
"""Reduce a binary component to one positive prompt point using distance transform."""
dist = cv2.distanceTransform(component_mask.astype(np.uint8), cv2.DIST_L2, 5)
@@ -184,6 +230,7 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
- **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
coordinates or `{ "points": [[x, y], ...], "labels": [1, 0, ...] }`.
- **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
- **interactive**: `prompt_data` is `{ "box": [...], "points": [[x, y]], "labels": [1, 0] }`.
- **semantic**: SAM 3 text prompt when model=`sam3`; SAM 2 falls back to auto.
"""
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
@@ -246,6 +293,51 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
if crop_bounds:
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
elif prompt_type == "interactive":
prompt = payload.prompt_data
if not isinstance(prompt, dict):
raise HTTPException(status_code=400, detail="Invalid interactive prompt data")
box = prompt.get("box")
points = prompt.get("points") or []
labels = prompt.get("labels")
if box is not None and (not isinstance(box, list) or len(box) != 4):
raise HTTPException(status_code=400, detail="Invalid interactive box prompt data")
if not isinstance(points, list):
raise HTTPException(status_code=400, detail="Invalid interactive point prompt data")
if not box and len(points) == 0:
raise HTTPException(status_code=400, detail="Interactive prompt requires a box or points")
if not isinstance(labels, list) or len(labels) != len(points):
labels = [1] * len(points)
negative_points = [
point for point, label in zip(points, labels) if label == 0
]
inference_image = image
inference_box = box
inference_points = points
crop_bounds = None
if options.get("crop_to_prompt"):
margin = float(options.get("crop_margin", 0.05) or 0.05)
crop_points = list(points)
if box:
crop_points.extend([[box[0], box[1]], [box[2], box[3]]])
crop_bounds = _crop_bounds_from_points(crop_points, margin)
inference_image = _crop_image(image, crop_bounds)
inference_points = [_to_crop_point(point, crop_bounds) for point in points]
if box:
inference_box = [
*_to_crop_point([box[0], box[1]], crop_bounds),
*_to_crop_point([box[2], box[3]], crop_bounds),
]
polygons, scores = sam_registry.predict_interactive(
payload.model,
inference_image,
inference_box,
inference_points,
labels,
)
if crop_bounds:
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
elif prompt_type == "semantic":
text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
polygons, scores = sam_registry.predict_semantic(payload.model, image, text)
@@ -276,6 +368,124 @@ def model_status(selected_model: str | None = None) -> dict:
raise HTTPException(status_code=400, detail=str(exc)) from exc
@router.post(
"/propagate",
response_model=PropagateResponse,
summary="Propagate one current-frame region across a video frame segment",
)
def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
"""Track one selected region from the current frame across nearby frames.
SAM 2 uses the official video predictor with the selected mask as the seed.
SAM 3 uses the external Python 3.12 video tracker with the seed bbox.
"""
direction = payload.direction.lower()
if direction not in {"forward", "backward", "both"}:
raise HTTPException(status_code=400, detail="direction must be forward, backward, or both")
max_frames = max(1, min(int(payload.max_frames or 30), 500))
project = db.query(Project).filter(Project.id == payload.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
source_frame = db.query(Frame).filter(
Frame.id == payload.frame_id,
Frame.project_id == payload.project_id,
).first()
if not source_frame:
raise HTTPException(status_code=404, detail="Frame not found")
seed = payload.seed.model_dump(exclude_none=True)
polygons = seed.get("polygons") or []
bbox = seed.get("bbox")
points = seed.get("points") or []
if not polygons and not bbox and not points:
raise HTTPException(status_code=400, detail="Propagation requires seed polygons, bbox, or points")
frames = db.query(Frame).filter(Frame.project_id == payload.project_id).order_by(Frame.frame_index).all()
source_position = next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None)
if source_position is None:
raise HTTPException(status_code=404, detail="Source frame is not in project frame sequence")
selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames)
if len(selected_frames) == 0:
raise HTTPException(status_code=400, detail="No frames available for propagation")
try:
with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload.project_id}_") as tmpdir:
frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir))
propagated = sam_registry.propagate_video(
payload.model,
frame_paths,
source_relative_index,
seed,
direction,
len(selected_frames),
)
except ModelUnavailableError as exc:
raise HTTPException(status_code=503, detail=str(exc)) from exc
except NotImplementedError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc: # noqa: BLE001
logger.error("Video propagation failed: %s", exc)
raise HTTPException(status_code=500, detail=f"Video propagation failed: {exc}") from exc
created: list[Annotation] = []
if payload.save_annotations:
class_metadata = seed.get("class_metadata")
template_id = seed.get("template_id")
label = seed.get("label") or "Propagated Mask"
color = seed.get("color") or "#06b6d4"
model_id = sam_registry.normalize_model_id(payload.model)
for frame_result in propagated:
relative_index = int(frame_result.get("frame_index", -1))
if relative_index < 0 or relative_index >= len(selected_frames):
continue
frame = selected_frames[relative_index]
if not payload.include_source and frame.id == source_frame.id:
continue
result_polygons = frame_result.get("polygons") or []
scores = frame_result.get("scores") or []
for polygon_index, polygon in enumerate(result_polygons):
if len(polygon) < 3:
continue
annotation = Annotation(
project_id=payload.project_id,
frame_id=frame.id,
template_id=template_id,
mask_data={
"polygons": [polygon],
"label": label,
"color": color,
"source": f"{model_id}_propagation",
"propagated_from_frame_id": source_frame.id,
"propagated_from_frame_index": source_frame.frame_index,
"score": scores[polygon_index] if polygon_index < len(scores) else None,
**({"class": class_metadata} if class_metadata else {}),
},
points=None,
bbox=_polygon_bbox(polygon),
)
db.add(annotation)
created.append(annotation)
db.commit()
for annotation in created:
db.refresh(annotation)
return {
"model": sam_registry.normalize_model_id(payload.model),
"direction": direction,
"source_frame_id": source_frame.id,
"processed_frame_count": len(selected_frames),
"created_annotation_count": len(created),
"annotations": created,
}
@router.post(
"/auto",
response_model=PredictResponse,

View File

@@ -4,7 +4,7 @@ import logging
from pathlib import Path
from typing import List, Optional
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile, status
from sqlalchemy.orm import Session
from database import get_db
@@ -169,6 +169,9 @@ async def upload_dicom_batch(
def parse_media(
project_id: int,
source_type: Optional[str] = None,
parse_fps: Optional[float] = Query(None, gt=0, le=120),
max_frames: Optional[int] = Query(None, gt=0),
target_width: int = Query(640, ge=64, le=4096),
db: Session = Depends(get_db),
) -> ProcessingTask:
"""Create a background task for media frame extraction.
@@ -184,14 +187,21 @@ def parse_media(
raise HTTPException(status_code=400, detail="Project has no media uploaded")
effective_source = source_type or project.source_type or "video"
effective_parse_fps = parse_fps or project.parse_fps or 30.0
task = ProcessingTask(
task_type=f"parse_{effective_source}",
status=TASK_STATUS_QUEUED,
progress=0,
message="解析任务已入队",
project_id=project_id,
payload={"source_type": effective_source},
payload={
"source_type": effective_source,
"parse_fps": effective_parse_fps,
"max_frames": max_frames,
"target_width": target_width,
},
)
project.parse_fps = effective_parse_fps
project.status = PROJECT_STATUS_PARSING
db.add(task)
db.commit()

View File

@@ -51,6 +51,8 @@ class FrameBase(BaseModel):
image_url: str
width: Optional[int] = None
height: Optional[int] = None
timestamp_ms: Optional[float] = None
source_frame_number: Optional[int] = None
class FrameCreate(FrameBase):
@@ -188,6 +190,37 @@ class PredictResponse(BaseModel):
scores: Optional[list[float]] = None
class PropagationSeed(BaseModel):
polygons: Optional[list[list[list[float]]]] = None
bbox: Optional[list[float]] = None
points: Optional[list[list[float]]] = None
labels: Optional[list[int]] = None
label: Optional[str] = None
color: Optional[str] = None
class_metadata: Optional[dict[str, Any]] = None
template_id: Optional[int] = None
class PropagateRequest(BaseModel):
project_id: int
frame_id: int
model: Optional[str] = "sam2"
seed: PropagationSeed
direction: str = "forward"
max_frames: int = 30
include_source: bool = False
save_annotations: bool = True
class PropagateResponse(BaseModel):
model: str
direction: str
source_frame_id: int
processed_frame_count: int
created_annotation_count: int
annotations: list[AnnotationOut]
class AiModelStatus(BaseModel):
id: str
label: str

View File

@@ -52,6 +52,7 @@ def parse_video(
output_dir: str,
fps: int = 30,
max_frames: Optional[int] = None,
target_width: int = 640,
) -> Tuple[List[str], float]:
"""Extract frames from a video file using FFmpeg or OpenCV fallback.
@@ -60,6 +61,7 @@ def parse_video(
output_dir: Directory to save extracted frames.
fps: Target frame extraction rate.
max_frames: Optional maximum number of frames to extract.
target_width: Output frame width for model-friendly frame sequences.
Returns:
Tuple of (frame_paths, original_fps).
@@ -67,6 +69,8 @@ def parse_video(
os.makedirs(output_dir, exist_ok=True)
frame_paths: List[str] = []
original_fps = get_video_fps(video_path)
safe_fps = max(int(fps), 1)
safe_width = max(int(target_width), 1)
# Try FFmpeg first
if shutil.which("ffmpeg"):
@@ -75,7 +79,8 @@ def parse_video(
cmd = [
"ffmpeg",
"-i", video_path,
"-vf", f"fps={fps},scale=640:-1",
"-vf", f"fps={safe_fps},scale={safe_width}:-1",
"-start_number", "0",
"-q:v", "5",
"-y",
pattern,
@@ -102,7 +107,7 @@ def parse_video(
raise RuntimeError(f"Cannot open video: {video_path}")
video_fps = cap.get(cv2.CAP_PROP_FPS) or 30
interval = max(1, int(round(video_fps / fps)))
interval = max(1, int(round(video_fps / safe_fps)))
count = 0
saved = 0
@@ -112,6 +117,10 @@ def parse_video(
break
if count % interval == 0:
path = os.path.join(output_dir, f"frame_{saved:06d}.jpg")
h, w = frame.shape[:2]
if safe_width > 0 and w != safe_width:
scale = safe_width / max(w, 1)
frame = cv2.resize(frame, (safe_width, max(1, int(round(h * scale)))), interpolation=cv2.INTER_AREA)
cv2.imwrite(path, frame, [cv2.IMWRITE_JPEG_QUALITY, 80])
frame_paths.append(path)
saved += 1

View File

@@ -76,6 +76,38 @@ def _project_status_after_stop(project: Project) -> str:
return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING
def _positive_int(value: Any, default: int | None = None) -> int | None:
try:
parsed = int(value)
except (TypeError, ValueError):
return default
return parsed if parsed > 0 else default
def _positive_float(value: Any, default: float) -> float:
try:
parsed = float(value)
except (TypeError, ValueError):
return default
return parsed if parsed > 0 else default
def _frame_sequence_metadata(
index: int,
parse_fps: float,
original_fps: float | None,
) -> dict[str, float | int | None]:
safe_parse_fps = max(float(parse_fps or 1.0), 1e-6)
timestamp_ms = index * 1000.0 / safe_parse_fps
source_frame_number = None
if original_fps and original_fps > 0:
source_frame_number = int(round(index * original_fps / safe_parse_fps))
return {
"timestamp_ms": timestamp_ms,
"source_frame_number": source_frame_number,
}
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
db.refresh(task)
if task.status == TASK_STATUS_CANCELLED:
@@ -138,8 +170,12 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
project.status = PROJECT_STATUS_PARSING
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True)
effective_source = (task.payload or {}).get("source_type") or project.source_type or "video"
parse_fps = project.parse_fps or 30.0
payload = task.payload or {}
effective_source = payload.get("source_type") or project.source_type or "video"
parse_fps = _positive_float(payload.get("parse_fps"), project.parse_fps or 30.0)
max_frames = _positive_int(payload.get("max_frames"))
target_width = _positive_int(payload.get("target_width"), 640) or 640
project.parse_fps = parse_fps
tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project.id}_")
output_dir = os.path.join(tmp_dir, "frames")
os.makedirs(output_dir, exist_ok=True)
@@ -163,7 +199,7 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
_ensure_not_cancelled(db, task)
_set_task_state(db, task, progress=35, message="正在解析 DICOM 序列")
frame_files = parse_dicom(dcm_dir, output_dir)
frame_files = parse_dicom(dcm_dir, output_dir, max_frames=max_frames)
else:
_ensure_not_cancelled(db, task)
media_bytes = download_file(project.video_path)
@@ -173,7 +209,13 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
_ensure_not_cancelled(db, task)
_set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧")
frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
frame_files, original_fps = parse_video(
local_path,
output_dir,
fps=int(parse_fps),
max_frames=max_frames,
target_width=target_width,
)
project.original_fps = original_fps
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
@@ -205,12 +247,15 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
except Exception: # noqa: BLE001
h, w = None, None
sequence_meta = _frame_sequence_metadata(idx, parse_fps, project.original_fps)
frame = Frame(
project_id=project.id,
frame_index=idx,
image_url=obj_name,
width=w,
height=h,
timestamp_ms=sequence_meta["timestamp_ms"],
source_frame_number=sequence_meta["source_frame_number"],
)
db.add(frame)
frames_out.append(frame)
@@ -223,6 +268,17 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
"frames_extracted": len(frames_out),
"status": PROJECT_STATUS_READY,
"message": "Frame extraction completed successfully.",
"frame_sequence": {
"original_fps": project.original_fps,
"parse_fps": parse_fps,
"frame_count": len(frames_out),
"duration_ms": (len(frames_out) - 1) * 1000.0 / parse_fps if frames_out else 0,
"target_width": target_width,
"frame_width": frames_out[0].width if frames_out else None,
"frame_height": frames_out[0].height if frames_out else None,
"max_frames": max_frames,
"object_prefix": f"projects/{project.id}/frames",
},
}
_set_task_state(
db,

View File

@@ -24,6 +24,7 @@ except Exception as exc: # noqa: BLE001
try:
from sam2.build_sam import build_sam2
from sam2.build_sam import build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
SAM2_AVAILABLE = True
@@ -38,9 +39,12 @@ class SAM2Engine:
def __init__(self) -> None:
self._predictor: Optional[SAM2ImagePredictor] = None
self._video_predictor = None
self._model_loaded = False
self._video_model_loaded = False
self._loaded_device: str | None = None
self._last_error: str | None = None
self._video_last_error: str | None = None
# -----------------------------------------------------------------------
# Internal helpers
@@ -85,6 +89,40 @@ class SAM2Engine:
logger.error("Failed to load SAM 2 model: %s", exc)
self._model_loaded = True # Prevent repeated load attempts
def _load_video_model(self) -> None:
"""Load the SAM 2 video predictor on first propagation use."""
if self._video_model_loaded:
return
if not TORCH_AVAILABLE:
self._video_last_error = "PyTorch is not installed."
self._video_model_loaded = True
return
if not SAM2_AVAILABLE:
self._video_last_error = "sam2 package is not installed."
self._video_model_loaded = True
return
if not os.path.isfile(settings.sam_model_path):
self._video_last_error = f"SAM2 checkpoint not found: {settings.sam_model_path}"
self._video_model_loaded = True
return
try:
device = self._best_device()
self._video_predictor = build_sam2_video_predictor(
settings.sam_model_config,
settings.sam_model_path,
device=device,
)
self._video_model_loaded = True
self._loaded_device = device
self._video_last_error = None
logger.info("SAM 2 video predictor loaded from %s on %s", settings.sam_model_path, device)
except Exception as exc: # noqa: BLE001
self._video_last_error = str(exc)
self._video_model_loaded = True
logger.error("Failed to load SAM 2 video predictor: %s", exc)
def _best_device(self) -> str:
if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available():
return "cuda"
@@ -95,6 +133,11 @@ class SAM2Engine:
self._load_model()
return SAM2_AVAILABLE and self._predictor is not None
def _ensure_video_ready(self) -> bool:
"""Ensure the video predictor is loaded; return whether it is usable."""
self._load_video_model()
return SAM2_AVAILABLE and self._video_predictor is not None
def status(self) -> dict:
"""Return lightweight, real runtime status without forcing model load."""
checkpoint_exists = os.path.isfile(settings.sam_model_path)
@@ -121,7 +164,7 @@ class SAM2Engine:
"available": available,
"loaded": self._predictor is not None,
"device": device,
"supports": ["point", "box", "auto"],
"supports": ["point", "box", "interactive", "auto", "propagate"],
"message": message,
"package_available": SAM2_AVAILABLE,
"checkpoint_exists": checkpoint_exists,
@@ -221,6 +264,52 @@ class SAM2Engine:
logger.error("SAM2 box prediction failed: %s", exc)
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
def predict_interactive(
self,
image: np.ndarray,
box: list[float] | None,
points: list[list[float]],
labels: list[int],
) -> tuple[list[list[list[float]]], list[float]]:
"""Run combined box and point prompt segmentation for refinement."""
if not self._ensure_ready():
logger.warning("SAM2 not ready; returning dummy masks.")
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
try:
h, w = image.shape[:2]
bbox = None
if box:
bbox = np.array(
[box[0] * w, box[1] * h, box[2] * w, box[3] * h],
dtype=np.float32,
)
pts = None
lbls = None
if points:
pts = np.array([[p[0] * w, p[1] * h] for p in points], dtype=np.float32)
lbls = np.array(labels, dtype=np.int32)
with torch.inference_mode(): # type: ignore[name-defined]
self._predictor.set_image(image)
masks, scores, _ = self._predictor.predict(
point_coords=pts,
point_labels=lbls,
box=bbox,
multimask_output=False,
)
polygons = []
for m in masks:
poly = self._mask_to_polygon(m)
if poly:
polygons.append(poly)
return polygons, scores.tolist()
except Exception as exc: # noqa: BLE001
logger.error("SAM2 interactive prediction failed: %s", exc)
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
def predict_auto(self, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
"""Run automatic mask generation (grid of points).
@@ -260,6 +349,89 @@ class SAM2Engine:
logger.error("SAM2 auto prediction failed: %s", exc)
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
def propagate_video(
self,
frame_paths: list[str],
source_frame_index: int,
seed: dict,
direction: str = "forward",
max_frames: int | None = None,
) -> list[dict]:
"""Propagate one seed mask across a prepared frame directory with SAM 2 video."""
if not self._ensure_video_ready():
raise RuntimeError(self._video_last_error or self.status()["message"])
if not frame_paths:
return []
if source_frame_index < 0 or source_frame_index >= len(frame_paths):
raise ValueError("source_frame_index is outside the frame sequence.")
import cv2
source_image = cv2.imread(frame_paths[source_frame_index])
if source_image is None:
raise RuntimeError("Failed to decode source frame for SAM 2 propagation.")
height, width = source_image.shape[:2]
seed_mask = self._polygons_to_mask(seed.get("polygons") or [], width, height)
if not seed_mask.any():
bbox = seed.get("bbox")
if isinstance(bbox, list) and len(bbox) == 4:
seed_mask = self._bbox_to_mask(bbox, width, height)
if not seed_mask.any():
raise ValueError("SAM 2 propagation requires a non-empty seed polygon or bbox.")
inference_state = self._video_predictor.init_state(
video_path=os.path.dirname(frame_paths[0]),
offload_video_to_cpu=True,
offload_state_to_cpu=True,
)
self._video_predictor.add_new_mask(
inference_state,
frame_idx=source_frame_index,
obj_id=1,
mask=seed_mask,
)
results: dict[int, dict] = {}
def collect(reverse: bool) -> None:
for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video(
inference_state,
start_frame_idx=source_frame_index,
max_frame_num_to_track=max_frames,
reverse=reverse,
):
masks = out_mask_logits
if hasattr(masks, "detach"):
masks = masks.detach().cpu().numpy()
masks = np.asarray(masks)
if masks.ndim == 4:
masks = masks[:, 0]
polygons = []
scores = []
for mask in masks:
polygon = self._mask_to_polygon(mask > 0)
if polygon:
polygons.append(polygon)
scores.append(1.0)
results[int(out_frame_idx)] = {
"frame_index": int(out_frame_idx),
"polygons": polygons,
"scores": scores,
"object_ids": [int(obj_id) for obj_id in list(out_obj_ids)],
}
normalized_direction = direction.lower()
if normalized_direction in {"forward", "both"}:
collect(reverse=False)
if normalized_direction in {"backward", "both"}:
collect(reverse=True)
try:
self._video_predictor.reset_state(inference_state)
except Exception: # noqa: BLE001
pass
return [results[index] for index in sorted(results)]
# -----------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------
@@ -292,6 +464,38 @@ class SAM2Engine:
]
]
@staticmethod
def _polygons_to_mask(polygons: list[list[list[float]]], width: int, height: int) -> np.ndarray:
import cv2
mask = np.zeros((height, width), dtype=np.uint8)
for polygon in polygons:
if len(polygon) < 3:
continue
pts = np.array(
[
[
int(round(min(max(float(x), 0.0), 1.0) * max(width - 1, 1))),
int(round(min(max(float(y), 0.0), 1.0) * max(height - 1, 1))),
]
for x, y in polygon
],
dtype=np.int32,
)
cv2.fillPoly(mask, [pts], 1)
return mask.astype(bool)
@staticmethod
def _bbox_to_mask(bbox: list[float], width: int, height: int) -> np.ndarray:
x, y, w, h = [min(max(float(value), 0.0), 1.0) for value in bbox]
left = int(round(x * max(width - 1, 1)))
top = int(round(y * max(height - 1, 1)))
right = int(round(min(x + w, 1.0) * max(width - 1, 1)))
bottom = int(round(min(y + h, 1.0) * max(height - 1, 1)))
mask = np.zeros((height, width), dtype=bool)
mask[top:max(bottom + 1, top + 1), left:max(right + 1, left + 1)] = True
return mask
# Singleton instance
sam_engine = SAM2Engine()

View File

@@ -56,8 +56,22 @@ class SAM3Engine:
def _gpu_ok(self) -> bool:
return bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available())
def _checkpoint_path(self) -> str | None:
path = settings.sam3_checkpoint_path.strip()
return path if path else None
def _checkpoint_exists(self) -> bool:
path = self._checkpoint_path()
return bool(path and os.path.isfile(path))
def _can_load(self) -> bool:
return bool(SAM3_PACKAGE_AVAILABLE and TORCH_AVAILABLE and self._python_ok() and self._gpu_ok())
return bool(
SAM3_PACKAGE_AVAILABLE
and TORCH_AVAILABLE
and self._python_ok()
and self._gpu_ok()
and self._checkpoint_exists()
)
def _worker_path(self) -> Path:
return Path(__file__).with_name("sam3_external_worker.py")
@@ -98,6 +112,8 @@ class SAM3Engine:
try:
env = os.environ.copy()
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
if self._checkpoint_path():
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
completed = subprocess.run(
[settings.sam3_external_python, str(self._worker_path()), "--status"],
capture_output=True,
@@ -146,7 +162,10 @@ class SAM3Engine:
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.model_builder import build_sam3_image_model
self._model = build_sam3_image_model()
self._model = build_sam3_image_model(
checkpoint_path=self._checkpoint_path(),
load_from_HF=False,
)
self._processor = Sam3Processor(self._model)
self._model_loaded = True
self._last_error = None
@@ -170,6 +189,8 @@ class SAM3Engine:
missing.append("PyTorch")
if not self._gpu_ok():
missing.append("CUDA GPU")
if not self._checkpoint_exists():
missing.append(f"local checkpoint ({settings.sam3_checkpoint_path})")
if missing:
return f"SAM 3 unavailable: missing {', '.join(missing)}."
return "SAM 3 dependencies are present; model will load on first inference."
@@ -182,7 +203,7 @@ class SAM3Engine:
if self._processor is not None:
message = "SAM 3 model loaded and ready."
elif external_ready:
message = "SAM 3 external runtime is ready; model will load in the helper process on inference."
message = "SAM 3 external runtime is ready; local checkpoint will load in the helper process on inference."
elif external_status.get("message") and not self._can_load():
message = str(external_status["message"])
return {
@@ -191,11 +212,11 @@ class SAM3Engine:
"available": available,
"loaded": self._processor is not None,
"device": "cuda" if self._gpu_ok() else str(external_status.get("device", "unavailable")),
"supports": ["semantic"],
"supports": ["semantic", "box", "video_track"],
"message": message,
"package_available": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("package_available")),
"checkpoint_exists": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("checkpoint_access")),
"checkpoint_path": f"official/HuggingFace ({settings.sam3_model_version})",
"checkpoint_exists": bool(self._checkpoint_exists() or external_status.get("checkpoint_access")),
"checkpoint_path": self._checkpoint_path() or f"official/HuggingFace ({settings.sam3_model_version})",
"python_ok": bool(self._python_ok() or external_status.get("python_ok")),
"torch_ok": bool(TORCH_AVAILABLE or external_status.get("torch_ok")),
"cuda_required": True,
@@ -203,7 +224,43 @@ class SAM3Engine:
"external_python": settings.sam3_external_python if settings.sam3_external_enabled else None,
}
def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
def _xyxy_to_cxcywh(self, box: list[float]) -> list[float]:
if len(box) != 4:
raise ValueError("SAM 3 box prompt requires [x1, y1, x2, y2].")
x1, y1, x2, y2 = [min(max(float(value), 0.0), 1.0) for value in box]
left, right = sorted([x1, x2])
top, bottom = sorted([y1, y2])
width = max(right - left, 1e-6)
height = max(bottom - top, 1e-6)
return [left + width / 2, top + height / 2, width, height]
def _prediction_to_polygons(self, output: Any) -> tuple[list[list[list[float]]], list[float]]:
masks = output.get("masks", [])
scores = output.get("scores", [])
polygons = []
for mask in masks:
if hasattr(mask, "detach"):
mask = mask.detach().cpu().numpy()
if mask.ndim == 3:
mask = mask[0]
poly = SAM2Engine._mask_to_polygon(mask)
if poly:
polygons.append(poly)
if hasattr(scores, "detach"):
scores = scores.detach().cpu().tolist()
elif hasattr(scores, "tolist"):
scores = scores.tolist()
return polygons, list(scores)
def _predict_external(
self,
image: np.ndarray,
prompt_type: str,
*,
text: str = "",
box: list[float] | None = None,
) -> tuple[list[list[list[float]]], list[float]]:
status = self._external_status(force=True)
if not status.get("available"):
raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")
@@ -217,8 +274,11 @@ class SAM3Engine:
json.dumps(
{
"image_path": str(image_path),
"prompt_type": prompt_type,
"text": text.strip(),
"box": box,
"model_version": settings.sam3_model_version,
"checkpoint_path": self._checkpoint_path(),
"confidence_threshold": settings.sam3_confidence_threshold,
},
ensure_ascii=False,
@@ -227,6 +287,8 @@ class SAM3Engine:
)
env = os.environ.copy()
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
if self._checkpoint_path():
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
completed = subprocess.run(
[settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)],
capture_output=True,
@@ -250,6 +312,72 @@ class SAM3Engine:
raise RuntimeError(str(payload["error"]))
return payload.get("polygons", []), payload.get("scores", [])
def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
return self._predict_external(image, "semantic", text=text)
def _predict_box_external(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]:
return self._predict_external(image, "box", box=box)
def _propagate_video_external(
self,
frame_paths: list[str],
source_frame_index: int,
seed: dict[str, Any],
direction: str,
max_frames: int | None,
) -> list[dict[str, Any]]:
status = self._external_status(force=True)
if not status.get("available"):
raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")
if not frame_paths:
return []
with tempfile.TemporaryDirectory(prefix="sam3_video_") as tmpdir:
request_path = Path(tmpdir) / "request.json"
request_path.write_text(
json.dumps(
{
"prompt_type": "video_track",
"frame_dir": str(Path(frame_paths[0]).parent),
"source_frame_index": source_frame_index,
"seed": seed,
"direction": direction,
"max_frames": max_frames,
"model_version": settings.sam3_model_version,
"checkpoint_path": self._checkpoint_path(),
"confidence_threshold": settings.sam3_confidence_threshold,
},
ensure_ascii=False,
),
encoding="utf-8",
)
env = os.environ.copy()
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
if self._checkpoint_path():
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
completed = subprocess.run(
[settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)],
capture_output=True,
text=True,
timeout=settings.sam3_timeout_seconds,
check=False,
env=env,
)
if completed.returncode != 0:
detail = completed.stderr.strip() or completed.stdout.strip()
try:
parsed = json.loads(detail)
detail = parsed.get("error", detail)
except Exception: # noqa: BLE001
pass
raise RuntimeError(f"SAM 3 external video tracking failed: {detail}")
payload = json.loads(completed.stdout)
if payload.get("error"):
raise RuntimeError(str(payload["error"]))
return payload.get("frames", [])
def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
if not text.strip():
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
@@ -263,29 +391,37 @@ class SAM3Engine:
state = self._processor.set_image(pil_image)
output = self._processor.set_text_prompt(state=state, prompt=text.strip())
masks = output.get("masks", [])
scores = output.get("scores", [])
polygons = []
for mask in masks:
if hasattr(mask, "detach"):
mask = mask.detach().cpu().numpy()
if mask.ndim == 3:
mask = mask[0]
poly = SAM2Engine._mask_to_polygon(mask)
if poly:
polygons.append(poly)
if hasattr(scores, "detach"):
scores = scores.detach().cpu().tolist()
elif hasattr(scores, "tolist"):
scores = scores.tolist()
return polygons, list(scores)
return self._prediction_to_polygons(output)
def predict_points(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for point prompts.")
def predict_box(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for box prompts.")
def predict_box(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]:
if not self._can_load() and self._external_status().get("available"):
return self._predict_box_external(image, box)
if not self._ensure_ready():
raise RuntimeError(self.status()["message"])
pil_image = Image.fromarray(image)
with torch.inference_mode(): # type: ignore[union-attr]
state = self._processor.set_image(pil_image)
output = self._processor.add_geometric_prompt(
state=state,
box=self._xyxy_to_cxcywh(box),
label=True,
)
return self._prediction_to_polygons(output)
def propagate_video(
self,
frame_paths: list[str],
source_frame_index: int,
seed: dict[str, Any],
direction: str = "forward",
max_frames: int | None = None,
) -> list[dict[str, Any]]:
return self._propagate_video_external(frame_paths, source_frame_index, seed, direction, max_frames)
sam3_engine = SAM3Engine()

View File

@@ -43,6 +43,13 @@ def _compact_error(exc: Exception) -> str:
def _checkpoint_access(model_version: str) -> tuple[bool, str | None]:
checkpoint_path = os.environ.get("SAM3_CHECKPOINT_PATH", "").strip()
if checkpoint_path:
path = Path(checkpoint_path)
if path.is_file():
return True, None
return False, f"local checkpoint not found: {checkpoint_path}"
try:
from huggingface_hub import hf_hub_download
@@ -55,6 +62,7 @@ def _checkpoint_access(model_version: str) -> tuple[bool, str | None]:
def runtime_status() -> dict[str, Any]:
model_version = os.environ.get("SAM3_MODEL_VERSION", "sam3")
checkpoint_path = os.environ.get("SAM3_CHECKPOINT_PATH", "").strip() or None
package_error = None
package_available = importlib.util.find_spec("sam3") is not None
if package_available:
@@ -85,6 +93,7 @@ def runtime_status() -> dict[str, Any]:
"available": available,
"package_available": package_available,
"checkpoint_access": checkpoint_access,
"checkpoint_path": checkpoint_path or f"official/HuggingFace ({model_version})",
"python_ok": python_ok,
"torch_ok": torch_version is not None,
"torch_version": torch_version,
@@ -118,34 +127,67 @@ def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
def _to_numpy(value: Any) -> np.ndarray:
if hasattr(value, "detach"):
value = value.detach().cpu().numpy()
elif hasattr(value, "cpu"):
value = value.detach()
if hasattr(value, "is_floating_point") and value.is_floating_point():
value = value.float()
value = value.cpu().numpy()
elif hasattr(value, "cpu"):
value = value.cpu()
if hasattr(value, "is_floating_point") and value.is_floating_point():
value = value.float()
value = value.numpy()
return np.asarray(value)
def predict(request_path: Path) -> dict[str, Any]:
import torch
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.model_builder import build_sam3_image_model
def _xyxy_to_cxcywh(box: list[float]) -> list[float]:
if len(box) != 4:
raise ValueError("SAM 3 box prompt requires [x1, y1, x2, y2].")
x1, y1, x2, y2 = [min(max(float(value), 0.0), 1.0) for value in box]
left, right = sorted([x1, x2])
top, bottom = sorted([y1, y2])
width = max(right - left, 1e-6)
height = max(bottom - top, 1e-6)
return [left + width / 2, top + height / 2, width, height]
payload = json.loads(request_path.read_text(encoding="utf-8"))
image_path = Path(payload["image_path"])
text = str(payload["text"]).strip()
threshold = float(payload.get("confidence_threshold", 0.5))
if not text:
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
def _bbox_from_seed(seed: dict[str, Any]) -> list[float]:
bbox = seed.get("bbox")
if isinstance(bbox, list) and len(bbox) == 4:
return [min(max(float(value), 0.0), 1.0) for value in bbox]
image = Image.open(image_path).convert("RGB")
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
model = build_sam3_image_model()
processor = Sam3Processor(model, confidence_threshold=threshold)
state = processor.set_image(image)
output = processor.set_text_prompt(state=state, prompt=text)
polygons = seed.get("polygons") or []
points = [point for polygon in polygons for point in polygon if len(point) >= 2]
if not points:
raise ValueError("SAM 3 video tracking requires seed bbox or polygons.")
xs = [min(max(float(point[0]), 0.0), 1.0) for point in points]
ys = [min(max(float(point[1]), 0.0), 1.0) for point in points]
left, right = min(xs), max(xs)
top, bottom = min(ys), max(ys)
return [left, top, max(right - left, 1e-6), max(bottom - top, 1e-6)]
def _video_outputs_to_response(outputs: dict[str, Any]) -> dict[str, Any]:
masks = _to_numpy(outputs.get("out_binary_masks", []))
scores = _to_numpy(outputs.get("out_probs", []))
obj_ids = _to_numpy(outputs.get("out_obj_ids", []))
if masks.ndim == 4:
masks = masks[:, 0]
elif masks.ndim == 2:
masks = masks[None, ...]
polygons = []
out_scores = []
out_ids = []
for index, mask in enumerate(masks):
polygon = _mask_to_polygon(mask)
if polygon:
polygons.append(polygon)
out_scores.append(float(scores[index]) if scores.size > index else 1.0)
out_ids.append(int(obj_ids[index]) if obj_ids.size > index else index + 1)
return {"polygons": polygons, "scores": out_scores, "object_ids": out_ids}
def _prediction_to_response(output: dict[str, Any]) -> dict[str, Any]:
masks = _to_numpy(output.get("masks", []))
scores = _to_numpy(output.get("scores", []))
if masks.ndim == 4:
@@ -165,6 +207,115 @@ def predict(request_path: Path) -> dict[str, Any]:
}
def predict_video(request_path: Path) -> dict[str, Any]:
import torch
from sam3.model_builder import build_sam3_video_predictor
payload = json.loads(request_path.read_text(encoding="utf-8"))
frame_dir = Path(payload["frame_dir"])
source_frame_index = int(payload.get("source_frame_index", 0))
seed = payload.get("seed") or {}
direction = str(payload.get("direction") or "forward").lower()
max_frames = payload.get("max_frames")
max_frames = int(max_frames) if max_frames else None
checkpoint_path = str(payload.get("checkpoint_path") or os.environ.get("SAM3_CHECKPOINT_PATH", "")).strip()
threshold = float(payload.get("confidence_threshold", 0.5))
if direction not in {"forward", "backward", "both"}:
raise ValueError(f"Unsupported propagation direction: {direction}")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
predictor = build_sam3_video_predictor(
checkpoint_path=checkpoint_path or None,
async_loading_frames=False,
)
session_id = None
try:
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
session = predictor.handle_request(
{
"type": "start_session",
"resource_path": str(frame_dir),
"offload_video_to_cpu": True,
"offload_state_to_cpu": True,
}
)
session_id = session["session_id"]
predictor.handle_request(
{
"type": "add_prompt",
"session_id": session_id,
"frame_index": source_frame_index,
"bounding_boxes": [_bbox_from_seed(seed)],
"bounding_box_labels": [1],
"output_prob_thresh": threshold,
"rel_coordinates": True,
}
)
frames = []
for item in predictor.handle_stream_request(
{
"type": "propagate_in_video",
"session_id": session_id,
"propagation_direction": direction,
"start_frame_index": source_frame_index,
"max_frame_num_to_track": max_frames,
"output_prob_thresh": threshold,
}
):
frame_response = _video_outputs_to_response(item.get("outputs") or {})
frame_response["frame_index"] = int(item["frame_index"])
frames.append(frame_response)
finally:
if session_id:
predictor.handle_request({"type": "close_session", "session_id": session_id})
return {"frames": frames}
def predict(request_path: Path) -> dict[str, Any]:
import torch
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.model_builder import build_sam3_image_model
payload = json.loads(request_path.read_text(encoding="utf-8"))
if str(payload.get("prompt_type") or "").strip().lower() == "video_track":
return predict_video(request_path)
image_path = Path(payload["image_path"])
prompt_type = str(payload.get("prompt_type") or "semantic").strip().lower()
text = str(payload.get("text") or "").strip()
threshold = float(payload.get("confidence_threshold", 0.5))
checkpoint_path = str(payload.get("checkpoint_path") or os.environ.get("SAM3_CHECKPOINT_PATH", "")).strip()
if prompt_type == "semantic" and not text:
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
if prompt_type not in {"semantic", "box"}:
raise ValueError(f"Unsupported SAM 3 prompt type: {prompt_type}")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
image = Image.open(image_path).convert("RGB")
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
model = build_sam3_image_model(
checkpoint_path=checkpoint_path or None,
load_from_HF=not bool(checkpoint_path),
)
processor = Sam3Processor(model, confidence_threshold=threshold)
state = processor.set_image(image)
if prompt_type == "box":
output = processor.add_geometric_prompt(
state=state,
box=_xyxy_to_cxcywh(payload.get("box") or []),
label=True,
)
else:
output = processor.set_text_prompt(state=state, prompt=text)
return _prediction_to_response(output)
def main() -> int:
parser = argparse.ArgumentParser(description="SAM 3 external runtime helper")
parser.add_argument("--status", action="store_true")

View File

@@ -67,6 +67,19 @@ class SAMRegistry:
def predict_box(self, model_id: str | None, image: Any, box: list[float]):
return self._ensure_available(model_id).predict_box(image, box)
def predict_interactive(
self,
model_id: str | None,
image: Any,
box: list[float] | None,
points: list[list[float]],
labels: list[int],
):
model = self.normalize_model_id(model_id)
if model != "sam2":
raise NotImplementedError("Interactive box + point refinement is currently supported by SAM 2.")
return self._ensure_available(model).predict_interactive(image, box, points, labels)
def predict_auto(self, model_id: str | None, image: Any):
return self._ensure_available(model_id).predict_auto(image)
@@ -76,5 +89,22 @@ class SAMRegistry:
return self._ensure_available(model).predict_semantic(image, text)
return self._ensure_available(model).predict_auto(image)
def propagate_video(
self,
model_id: str | None,
frame_paths: list[str],
source_frame_index: int,
seed: dict[str, Any],
direction: str,
max_frames: int | None,
):
return self._ensure_available(model_id).propagate_video(
frame_paths,
source_frame_index,
seed,
direction=direction,
max_frames=max_frames,
)
sam_registry = SAMRegistry()

View File

@@ -116,6 +116,44 @@ def test_predict_box_and_semantic_fallback(client, monkeypatch):
assert semantic_response.json()["scores"] == [0.5]
def test_predict_interactive_combines_box_and_points(client, monkeypatch):
_, frame, _ = _create_project_and_frame(client)
calls = {}
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
def fake_predict_interactive(model, image, box, points, labels):
calls["model"] = model
calls["box"] = box
calls["points"] = points
calls["labels"] = labels
return (
[[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
[0.88],
)
monkeypatch.setattr("routers.ai.sam_registry.predict_interactive", fake_predict_interactive)
response = client.post("/api/ai/predict", json={
"image_id": frame["id"],
"prompt_type": "interactive",
"prompt_data": {
"box": [0.1, 0.1, 0.9, 0.9],
"points": [[0.5, 0.5], [0.2, 0.2]],
"labels": [1, 0],
},
"model": "sam2",
})
assert response.status_code == 200
assert response.json()["scores"] == [0.88]
assert calls == {
"model": "sam2",
"box": [0.1, 0.1, 0.9, 0.9],
"points": [[0.5, 0.5], [0.2, 0.2]],
"labels": [1, 0],
}
def test_model_status_reports_runtime(client, monkeypatch):
monkeypatch.setattr("routers.ai.sam_registry.runtime_status", lambda selected_model=None: {
"selected_model": selected_model or "sam2",
@@ -170,6 +208,80 @@ def test_model_status_reports_runtime(client, monkeypatch):
assert body["models"][1]["available"] is False
def test_propagate_saves_tracked_annotations(client, monkeypatch):
project = client.post("/api/projects", json={"name": "Video Project"}).json()
frames = [
client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": idx,
"image_url": f"frames/{idx}.jpg",
"width": 640,
"height": 360,
}).json()
for idx in range(3)
]
calls = {}
monkeypatch.setattr("routers.ai.download_file", lambda object_name: b"jpeg")
def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
calls["model"] = model
calls["source_frame_index"] = source_frame_index
calls["seed"] = seed
calls["direction"] = direction
calls["max_frames"] = max_frames
calls["frame_count"] = len(frame_paths)
return [
{
"frame_index": 0,
"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
"scores": [0.9],
"object_ids": [1],
},
{
"frame_index": 1,
"polygons": [[[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]],
"scores": [0.8],
"object_ids": [1],
},
]
monkeypatch.setattr("routers.ai.sam_registry.propagate_video", fake_propagate_video)
response = client.post("/api/ai/propagate", json={
"project_id": project["id"],
"frame_id": frames[0]["id"],
"model": "sam2",
"direction": "forward",
"max_frames": 2,
"include_source": False,
"seed": {
"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
"bbox": [0.1, 0.1, 0.1, 0.1],
"label": "胆囊",
"color": "#ff0000",
"class_metadata": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
"template_id": None,
},
})
assert response.status_code == 200
body = response.json()
assert body["created_annotation_count"] == 1
assert body["processed_frame_count"] == 2
assert calls["model"] == "sam2"
assert calls["source_frame_index"] == 0
assert calls["direction"] == "forward"
assert calls["frame_count"] == 2
saved = body["annotations"][0]
assert saved["frame_id"] == frames[1]["id"]
assert saved["mask_data"]["source"] == "sam2_propagation"
assert saved["mask_data"]["class"]["name"] == "胆囊"
assert saved["mask_data"]["score"] == 0.8
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
assert len(listing.json()) == 1
def test_predict_validation_errors(client, monkeypatch):
project, _, _ = _create_project_and_frame(client)

View File

@@ -84,6 +84,12 @@ def test_parse_media_queues_background_task(client, monkeypatch):
assert data["progress"] == 0
assert data["project_id"] == project["id"]
assert data["celery_task_id"] == "celery-1"
assert data["payload"] == {
"source_type": "video",
"parse_fps": 5.0,
"max_frames": None,
"target_width": 640,
}
assert queued == [data["id"]]
assert published == [data["id"]]
@@ -94,6 +100,35 @@ def test_parse_media_queues_background_task(client, monkeypatch):
assert project_detail["status"] == "parsing"
def test_parse_media_accepts_frame_sequence_options(client, monkeypatch):
project = client.post("/api/projects", json={
"name": "Parse Options",
"video_path": "uploads/1/clip.mp4",
"source_type": "video",
"parse_fps": 30,
}).json()
class FakeAsyncResult:
id = "celery-options"
monkeypatch.setattr("routers.media.parse_project_media.delay", lambda task_id: FakeAsyncResult())
monkeypatch.setattr("routers.media.publish_task_progress_event", lambda task: None)
response = client.post(
f"/api/media/parse?project_id={project['id']}&parse_fps=15&max_frames=120&target_width=960"
)
assert response.status_code == 202
data = response.json()
assert data["payload"] == {
"source_type": "video",
"parse_fps": 15.0,
"max_frames": 120,
"target_width": 960,
}
assert client.get(f"/api/projects/{project['id']}").json()["parse_fps"] == 15.0
def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp_path):
from models import ProcessingTask
from services.media_task_runner import run_parse_media_task
@@ -118,10 +153,14 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp
frame_file.write_bytes(b"fake image")
monkeypatch.setattr("services.media_task_runner.download_file", lambda object_name: b"video")
monkeypatch.setattr("services.media_task_runner.parse_video", lambda local_path, output_dir, fps: ([str(frame_file)], 25.0))
monkeypatch.setattr(
"services.media_task_runner.parse_video",
lambda local_path, output_dir, fps, max_frames=None, target_width=640: ([str(frame_file)], 25.0),
)
monkeypatch.setattr("services.media_task_runner.extract_thumbnail", lambda local_path, thumbnail_path: open(thumbnail_path, "wb").write(b"thumb"))
monkeypatch.setattr("services.media_task_runner.upload_file", lambda *args, **kwargs: None)
monkeypatch.setattr("services.media_task_runner.upload_frames_to_minio", lambda frame_files, project_id: [f"projects/{project_id}/frames/frame_000001.jpg"])
monkeypatch.setattr("routers.projects.get_presigned_url", lambda object_name, expires=3600: f"http://storage/{object_name}")
published = []
monkeypatch.setattr(
"services.media_task_runner.publish_task_progress_event",
@@ -131,6 +170,17 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp
result = run_parse_media_task(db_session, task.id)
assert result["frames_extracted"] == 1
assert result["frame_sequence"] == {
"original_fps": 25.0,
"parse_fps": 5.0,
"frame_count": 1,
"duration_ms": 0.0,
"target_width": 640,
"frame_width": None,
"frame_height": None,
"max_frames": None,
"object_prefix": f"projects/{project['id']}/frames",
}
db_session.refresh(task)
assert task.status == "success"
assert task.progress == 100
@@ -140,6 +190,8 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp
assert project_detail["status"] == "ready"
frames = client.get(f"/api/projects/{project['id']}/frames").json()
assert "frame_000001.jpg" in frames[0]["image_url"]
assert frames[0]["timestamp_ms"] == 0.0
assert frames[0]["source_frame_number"] == 0
def test_parse_task_runner_skips_already_cancelled_task(db_session):

View File

@@ -4,6 +4,7 @@ from pathlib import Path
import numpy as np
from services.sam3_engine import SAM3Engine
from services.sam3_external_worker import _to_numpy
class _Completed:
@@ -14,6 +15,8 @@ class _Completed:
def _external_settings(monkeypatch, python_path: Path):
checkpoint_path = python_path.with_name("sam3.pt")
checkpoint_path.write_bytes(b"checkpoint")
python_path.write_text("#!/usr/bin/env python\n", encoding="utf-8")
python_path.chmod(0o755)
monkeypatch.setattr("services.sam3_engine.SAM3_PACKAGE_AVAILABLE", False)
@@ -23,6 +26,7 @@ def _external_settings(monkeypatch, python_path: Path):
monkeypatch.setattr("services.sam3_engine.settings.sam3_timeout_seconds", 10)
monkeypatch.setattr("services.sam3_engine.settings.sam3_status_cache_seconds", 30)
monkeypatch.setattr("services.sam3_engine.settings.sam3_confidence_threshold", 0.4)
monkeypatch.setattr("services.sam3_engine.settings.sam3_checkpoint_path", str(checkpoint_path))
def test_sam3_status_reports_external_runtime_ready(tmp_path, monkeypatch):
@@ -30,9 +34,12 @@ def test_sam3_status_reports_external_runtime_ready(tmp_path, monkeypatch):
def fake_run(args, **_kwargs):
assert "--status" in args
assert _kwargs["env"]["SAM3_CHECKPOINT_PATH"].endswith("sam3.pt")
return _Completed(stdout=json.dumps({
"available": True,
"package_available": True,
"checkpoint_access": True,
"checkpoint_path": _kwargs["env"]["SAM3_CHECKPOINT_PATH"],
"python_ok": True,
"torch_ok": True,
"cuda_available": True,
@@ -48,7 +55,10 @@ def test_sam3_status_reports_external_runtime_ready(tmp_path, monkeypatch):
assert status["external_available"] is True
assert status["package_available"] is True
assert status["python_ok"] is True
assert status["message"] == "SAM 3 external runtime is ready; model will load in the helper process on inference."
assert status["checkpoint_exists"] is True
assert status["checkpoint_path"].endswith("sam3.pt")
assert status["supports"] == ["semantic", "box", "video_track"]
assert status["message"] == "SAM 3 external runtime is ready; local checkpoint will load in the helper process on inference."
def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch):
@@ -61,6 +71,7 @@ def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch):
return _Completed(stdout=json.dumps({
"available": True,
"package_available": True,
"checkpoint_access": True,
"python_ok": True,
"torch_ok": True,
"cuda_available": True,
@@ -71,6 +82,7 @@ def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch):
request = json.loads(request_path.read_text(encoding="utf-8"))
assert request["text"] == "vessel"
assert request["confidence_threshold"] == 0.4
assert request["checkpoint_path"].endswith("sam3.pt")
assert Path(request["image_path"]).exists()
return _Completed(stdout=json.dumps({
"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
@@ -86,6 +98,97 @@ def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch):
assert any("--request" in args for args in calls)
def test_sam3_predict_box_uses_external_worker(tmp_path, monkeypatch):
_external_settings(monkeypatch, tmp_path / "python")
def fake_run(args, **_kwargs):
if "--status" in args:
return _Completed(stdout=json.dumps({
"available": True,
"package_available": True,
"checkpoint_access": True,
"python_ok": True,
"torch_ok": True,
"cuda_available": True,
"device": "cuda",
"message": "ready",
}))
request_path = Path(args[-1])
request = json.loads(request_path.read_text(encoding="utf-8"))
assert request["prompt_type"] == "box"
assert request["box"] == [0.1, 0.2, 0.7, 0.8]
assert request["text"] == ""
return _Completed(stdout=json.dumps({
"polygons": [[[0.1, 0.2], [0.7, 0.2], [0.7, 0.8]]],
"scores": [0.88],
}))
monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)
polygons, scores = SAM3Engine().predict_box(
np.zeros((8, 8, 3), dtype=np.uint8),
[0.1, 0.2, 0.7, 0.8],
)
assert polygons == [[[0.1, 0.2], [0.7, 0.2], [0.7, 0.8]]]
assert scores == [0.88]
def test_sam3_propagate_video_uses_external_worker(tmp_path, monkeypatch):
_external_settings(monkeypatch, tmp_path / "python")
frame_dir = tmp_path / "frames"
frame_dir.mkdir()
frame_paths = []
for index in range(2):
frame_path = frame_dir / f"frame_{index:06d}.jpg"
frame_path.write_bytes(b"jpeg")
frame_paths.append(str(frame_path))
def fake_run(args, **_kwargs):
if "--status" in args:
return _Completed(stdout=json.dumps({
"available": True,
"package_available": True,
"checkpoint_access": True,
"python_ok": True,
"torch_ok": True,
"cuda_available": True,
"device": "cuda",
"message": "ready",
}))
request_path = Path(args[-1])
request = json.loads(request_path.read_text(encoding="utf-8"))
assert request["prompt_type"] == "video_track"
assert request["frame_dir"] == str(frame_dir)
assert request["source_frame_index"] == 0
assert request["direction"] == "forward"
assert request["max_frames"] == 2
assert request["seed"]["bbox"] == [0.1, 0.1, 0.2, 0.2]
return _Completed(stdout=json.dumps({
"frames": [
{
"frame_index": 1,
"polygons": [[[0.2, 0.2], [0.4, 0.2], [0.4, 0.4]]],
"scores": [0.7],
"object_ids": [1],
}
]
}))
monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)
frames = SAM3Engine().propagate_video(
frame_paths,
0,
{"bbox": [0.1, 0.1, 0.2, 0.2]},
direction="forward",
max_frames=2,
)
assert frames[0]["frame_index"] == 1
assert frames[0]["scores"] == [0.7]
def test_sam3_predict_semantic_reports_external_errors(tmp_path, monkeypatch):
_external_settings(monkeypatch, tmp_path / "python")
@@ -94,6 +197,7 @@ def test_sam3_predict_semantic_reports_external_errors(tmp_path, monkeypatch):
return _Completed(stdout=json.dumps({
"available": True,
"package_available": True,
"checkpoint_access": True,
"python_ok": True,
"torch_ok": True,
"cuda_available": True,
@@ -110,3 +214,32 @@ def test_sam3_predict_semantic_reports_external_errors(tmp_path, monkeypatch):
assert "HF access denied" in str(exc)
else:
raise AssertionError("Expected SAM 3 external inference failure.")
def test_sam3_worker_casts_floating_tensors_before_numpy():
class FakeTensor:
def __init__(self):
self.float_called = False
def detach(self):
return self
def is_floating_point(self):
return True
def float(self):
self.float_called = True
return self
def cpu(self):
return self
def numpy(self):
return np.array([1.0], dtype=np.float32)
tensor = FakeTensor()
result = _to_numpy(tensor)
assert tensor.float_called is True
assert result.tolist() == [1.0]

View File

@@ -38,14 +38,14 @@ Word 方案描述的理想系统包含:
| 视频拆帧 | 已落地 | `backend/services/frame_parser.py``backend/routers/media.py` |
| DICOM 批量导入 | 部分落地 | 上传和解析存在,项目级体验还需完善 |
| WebSocket 进度 | 已落地 | 拆帧进度写入任务表后发布到 Redis `seg:progress`FastAPI 广播到 `/ws/progress` |
| SAM 推理 | 部分落地 | 后端已有 SAM 2 / SAM 3 选择和真实模型状态接口SAM 3 通过独立 Python 3.12 环境桥接,状态会检查 Python/CUDA/包/HF gated 权重访问 |
| SAM 推理 | 部分落地 | 后端已有 SAM 2 / SAM 3 选择和真实模型状态接口SAM 2 已接 video predictor 片段传播SAM 3 通过独立 Python 3.12 环境桥接,支持文本/框提示和 official video tracker 入口,状态会检查 Python/CUDA/包/本地 checkpoint |
| 模板库 | 部分落地 | 分类、颜色、z-index 能存储和编辑PNG mask 导出时会按 zIndex 做语义融合裁决,前端预览裁决尚未落地 |
| 标注持久化 | 部分落地 | 后端有 `Annotation`前端已接入新增、回显、分类更新、当前帧删除、手工绘制、GT mask 导入、seed point 编辑、polygon 顶点拖动/删除、边中点插点和多 polygon 子区域编辑;复杂洞结构编辑未落地 |
| COCO / Mask 导出 | 已落地基础能力 | `backend/routers/export.py`COCO JSON 和 PNG mask ZIP 前端按钮均已接入ZIP 包含单标注 mask、语义融合 mask 和类别映射 |
## 当前代码尚未落地的目标
- SAM 3当前已提供 `sam3_engine.py` 外部环境桥接、`sam3_external_worker.py``setup_sam3_env.sh`;本机 `sam3` 环境已满足 Python 3.12、PyTorch 2.10/cu128、CUDA/GPU官方包导入。真实推理仍取决于 Hugging Face `facebook/sam3` gated 权重访问授权;官方没有 SAM 3 tiny/small 权重,当前可选最小真实 SAM 权重仍是 SAM 2 tiny
- SAM 3当前已提供 `sam3_engine.py` 外部环境桥接、`sam3_external_worker.py``setup_sam3_env.sh`;本机 `sam3` 环境已满足 Python 3.12、PyTorch 2.10/cu128、CUDA/GPU官方包导入和本地 `sam3权重/sam3.pt` checkpoint 状态检查。官方没有 SAM 3 tiny/small 权重,当前可选最小真实 SAM 权重仍是 SAM 2 tinyvideo tracker 入口已接入,真实效果取决于本地 checkpoint 是否兼容 video model
- Celery 异步任务队列:已注册 Celery app 和拆帧 worker task`/api/media/parse` 会创建任务表记录并入队。
- GT mask 导入:当前已支持二值/多类别 mask 导入,后端会按非零像素值拆分区域,生成 polygon 标注和距离变换 seed point骨架提取、HDBSCAN 和模板自动映射尚未实现。
- Mask 到点区域的拓扑降维:当前完成 distance transform seed point 和前端 seed point 拖拽编辑骨架提取、HDBSCAN 等增强尚未实现。
@@ -55,4 +55,4 @@ Word 方案描述的理想系统包含:
## 结论
当前项目已经从 UI 原型推进到“可上传、可异步拆帧、可取消/重试任务、可查看失败详情、可实时查看任务进度、可浏览项目帧、可维护模板、可手工绘制、可逐点编辑 polygon、可边中点插点、可多 polygon 子区域编辑、可区域合并/去除、可点/框 AI 推理、可对点/框 prompt 做裁剪推理和背景过滤、可导入多类别 GT mask、可编辑 seed point、可保存标注、可导出 COCO/语义 mask ZIP、可查看 Dashboard 后端概览”的全栈雏形。下一阶段最重要的是继续补齐 Hugging Face SAM 3 权重授权后的真实语义文本分割 smoke test、复杂洞结构编辑和 GT mask 骨架/聚类增强。
当前项目已经从 UI 原型推进到“可上传、可异步拆帧、可取消/重试任务、可查看失败详情、可实时查看任务进度、可浏览项目帧、可维护模板、可手工绘制、可逐点编辑 polygon、可边中点插点、可多 polygon 子区域编辑、可区域合并/去除、可点/框 AI 推理、可对点/框 prompt 做裁剪推理和背景过滤、可用 SAM 2 / SAM 3 进行视频片段传播、可导入多类别 GT mask、可编辑 seed point、可保存标注、可导出 COCO/语义 mask ZIP、可查看 Dashboard 后端概览”的全栈雏形。下一阶段最重要的是继续补齐 SAM 3 真实视频 tracker smoke test、复杂洞结构编辑和 GT mask 骨架/聚类增强。

View File

@@ -65,6 +65,7 @@
| “导出 JSON 标注集”按钮 | 真实可用 | 导出前会保存未归档 mask然后调用 `exportCoco()` 下载 JSON |
| “导出 PNG Mask ZIP”按钮 | 真实可用 | 导出前会保存未归档 mask然后调用 `GET /api/export/{project_id}/masks` 下载 ZIP后端同时包含单标注 mask、每帧语义融合 mask 和 `semantic_classes.json` |
| “导入 GT Mask”按钮 | 真实可用 | 选择图片后调用 `POST /api/ai/import-gt-mask`,后端按非零像素值和连通域生成 polygon 标注与距离变换 seed point再回显到工作区 |
| “传播片段”按钮 | 真实可用 | 使用当前选中 mask 或当前帧第一个 mask 作为 seed调用 `POST /api/ai/propagate`SAM 2 走 video predictorSAM 3 走 external video tracker完成后刷新已保存标注 |
| “结构化归档保存”按钮 | 真实可用 | 未保存 mask 写入 `POST /api/ai/annotate`dirty mask 写入 `PATCH /api/ai/annotations/{id}` |
## CanvasArea 画布
@@ -78,11 +79,11 @@
| 正向/反向选点 | 真实可用 | UI 能加点,并按当前帧 `frame.id` 调用 `/api/ai/predict`;结果需点击归档保存才持久化 |
| 框选 | 真实可用 | UI 能画框,并把框坐标归一化后调用后端推理;结果需点击归档保存才持久化 |
| AI 推理中提示 | 真实可用 | 请求期间会显示 |
| 手工多边形/矩形/圆/点/线 | 真实可用 | 多边形点击取点后 Enter 完成;矩形/圆/线拖拽生成 polygon点工具生成小区域均写入 `Mask.segmentation`,可归档保存 |
| 手工多边形/矩形/圆/点/线 | 真实可用 | 多边形点击取点后可按 Enter 完成,也可在三点后点击首节点闭合;矩形/圆/线拖拽生成 polygon点工具生成小区域绘制工具可在已有 mask 上继续落点;均写入 `Mask.segmentation`,可归档保存 |
| Mask 渲染 | 真实可用 | 前端会把推理、手工绘制、GT 导入和已保存标注转成 Konva `pathData` 渲染 |
| Polygon 逐点编辑 | 真实可用 | 点击 mask 后显示 polygon 顶点;拖动顶点会重算 `pathData/segmentation/bbox/area`,已保存 mask 标为 dirty选中顶点后 Delete/Backspace 可删点但保留至少三点 |
| Polygon 逐点编辑 / 删除 | 真实可用 | 点击 mask 后显示 polygon 顶点;拖动顶点会重算 `pathData/segmentation/bbox/area`,已保存 mask 标为 dirty选中顶点后 Delete/Backspace 可删点但保留至少三点;选中 mask 但未选中顶点时 Delete/Backspace 删除整个 mask已保存 mask 会同步调用后端删除 |
| GT seed point 回显/编辑 | 真实可用 | 已保存标注的 `points` 会显示为黄色 seed 点;拖动后标记为 dirty归档保存会更新后端 |
| 应用分类 | 真实可用 | 将当前选择的模板分类应用到本帧 mask已保存 mask 会标为 dirty归档保存时更新后端 |
| 应用分类 | 真实可用 | Canvas 右下角按钮可将当前选择的模板分类应用到本帧 mask右侧语义分类树点击分类时会优先改当前已选 mask已保存 mask 会标为 dirty归档保存时更新后端 |
| 清空遮罩 | 真实可用 | 工作区中会删除当前帧已保存标注并清空当前帧本地 mask |
| 保存状态计数 | 真实可用 | 底部显示已保存、未保存、待更新数量 |
| 当前图层树文字 | Mock / UI-only | 固定显示 `OBJECT_VEHICLE_01` |
@@ -93,7 +94,7 @@
|------|------|------|
| 拖拽/选择 | 真实可用 | 控制 Canvas 是否可拖拽 |
| 多边形/矩形/圆/点/线 | 真实可用 | 切换 activeTool 后由 `CanvasArea` 生成可保存的 polygon mask |
| 区域合并/去除 | 真实可用 | 选择工具后点击多个 mask使用 `polygon-clipping` 做 union / difference合并会保留主 mask 并移除被合并 mask去除会从主 mask 扣除后续选中 mask |
| 区域合并/去除 | 真实可用 | 选择工具后点击多个 mask右下角显示已选数量和操作按钮;合并/去除模式会隐藏 polygon 编辑手柄,避免手柄抢占多选点击;使用 `polygon-clipping` 做 union / difference合并会保留主 mask 并移除被合并 mask去除会从主 mask 扣除后续选中 mask;内含扣除会保留 hole ring 并用 even-odd 规则渲染 |
| 正向选点/反向选点/框选 | 部分可用 | 会影响 Canvas 交互,并能触发已对齐的 AI 推理接口 |
| 魔法棒 SAM 触发 | 部分可用 | 切到 AI 页面;不是直接执行推理 |
| 撤销/重做 | 真实可用 | 绑定 Zustand `maskHistory/maskFuture`支持工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y |
@@ -105,15 +106,16 @@
| 帧缩略图 | 真实可用 | 使用 `frames[].url` |
| 点击缩略图跳帧 | 真实可用 | 调用 `setCurrentFrame(idx)` |
| 顶部 range 拖动 | 真实可用 | 改变当前帧 |
| 具体时间显示 | 真实可用 | 根据项目 `parse_fps/original_fps` 显示当前时间和总时长,格式为 `mm:ss.cc` |
| 播放/暂停 | 真实可用 | 当前代码按 `parse_fps/original_fps` 推进帧,最多 30fps |
| 方向键切帧 | Mock / UI-only | Word 提到,但当前没有键盘监听 |
| 方向键切帧 | 真实可用 | 全局监听左右方向键切到上一帧/下一帧;焦点在 input、textarea、select 或 contentEditable 内时不会拦截 |
## OntologyInspector 本体面板
| 元素 | 状态 | 说明 |
|------|------|------|
| 模板选择 | 部分可用 | 读取全局 templates可切换 activeTemplateId |
| 分类树展示 | 真实可用 | 显示模板 classes 和本地 customClasses |
| 分类树展示 / 换标签 | 真实可用 | 显示模板 classes 和本地 customClasses;点击分类会设为后续新 mask 的 activeClass如果 Canvas 已选 mask则同步更新已选 mask 的标签、颜色和 class 元数据 |
| 添加自定义分类 | 部分可用 | 只存在组件本地状态,不保存到后端 |
| 置信度条 | Mock / UI-only | 固定 `0.9412` |
| 拓扑锚点数量 | Mock / UI-only | 固定 `12 节点` |
@@ -124,8 +126,9 @@
| 元素 | 状态 | 说明 |
|------|------|------|
| 模型选择 SAM2/SAM3 | 真实可用 | 选择写入 Zustand`predictMask()` 会把 `model` 传给后端 SAM registry |
| 正向/反向点 | 部分可用 | 可在当前项目帧上加点,并可调用 AI 推理接口 |
| 语义文本输入 | 部分可用 | 纯文本会以 `semantic` prompt 调用后端;选择 SAM 3 且独立 Python 3.12 环境、CUDA、官方包和 Hugging Face gated 权重访问均满足时走 SAM 3 文本语义推理,否则状态接口会标明不可用 |
| 正向/反向点 | 真实可用 | 可在当前项目帧上加点调用 AI 推理接口SAM 2 框选后会携带原始框和累计正/反点细化同一个候选 maskSAM 3 选择后会提示点交互需切回 SAM 2 |
| SAM 3 框选 | 真实可用 | 工作区选择 SAM 3 后可使用框选工具;后端通过官方 `add_geometric_prompt()` 正框执行 SAM 3 几何提示推理 |
| 语义文本输入 | 部分可用 | 纯文本会以 `semantic` prompt 调用后端;选择 SAM 3 且独立 Python 3.12 环境、CUDA、官方包和本地 checkpoint 均满足时走 SAM 3 文本语义推理,否则状态接口会标明不可用;空文本、失败和 0 mask 返回会显示前端反馈 |
| 参数开关 | 真实可用 | `cropMode` 会随 `/api/ai/predict` 发送 `crop_to_prompt`,后端对点/框 prompt 裁剪推理区域并回映射 polygon`autoDeleteBg` 会发送 `auto_filter_background``min_score`,后端过滤低分结果和覆盖负向点的结果 |
| 执行高精度语义分割 | 部分可用 | 使用当前项目帧调用 `/api/ai/predict`;没有当前帧时按钮禁用 |
| 上传替换底图 | Mock / UI-only | 按钮无事件 |
@@ -150,6 +153,6 @@
## 总体结论
当前前端真实可用的主链路是登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、拆帧、浏览帧、播放帧、工作区手工绘制、点/框 AI 推理、GT mask 导入、标注保存/回显、COCO 导出、PNG mask ZIP 导出、模板 CRUD。
当前前端真实可用的主链路是登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、拆帧、浏览帧、播放帧、工作区手工绘制、点/框 AI 推理、视频片段传播、GT mask 导入、标注保存/回显、COCO 导出、PNG mask ZIP 导出、模板 CRUD。
当前最主要的 Mock 或未打通链路是polygon 插点/边编辑增强、真正的文本语义分割、骨架/HDBSCAN 级别的 mask 降维增强、任务历史筛选、项目更多菜单和若干检查面板指标。

View File

@@ -32,12 +32,13 @@ Authorization: Bearer <token>
| `deleteTemplate(id)` | `DELETE /api/templates/{id}` | 对齐 | 模板编辑页使用 |
| `uploadMedia(file, projectId)` | `POST /api/media/upload` | 对齐 | multipart form-data |
| `uploadDicomBatch(files, projectId)` | `POST /api/media/upload/dicom` | 对齐 | multipart form-data |
| `parseMedia(projectId)` | `POST /api/media/parse?project_id=...` | 对齐 | 创建异步拆帧任务并返回 task |
| `parseMedia(projectId, options?)` | `POST /api/media/parse?project_id=...` | 对齐 | 创建异步拆帧任务并返回 task;支持 `parse_fps``max_frames``target_width` |
| `getTask(taskId)` | `GET /api/tasks/{task_id}` | 对齐 | 查询异步任务状态 |
| `cancelTask(taskId)` | `POST /api/tasks/{task_id}/cancel` | 对齐 | 取消 queued/running 任务,后端写 cancelled 并尝试 revoke Celery |
| `retryTask(taskId)` | `POST /api/tasks/{task_id}/retry` | 对齐 | 对 failed/cancelled 任务创建新的 queued 重试任务 |
| `getProjectFrames(projectId)` | `GET /api/projects/{id}/frames` | 对齐 | 后端返回预签名 image_url |
| `getProjectFrames(projectId)` | `GET /api/projects/{id}/frames` | 对齐 | 后端返回预签名 image_url,以及 `timestamp_ms``source_frame_number` |
| `predictMask(payload)` | `POST /api/ai/predict` | 对齐 | 前端发送 `image_id/prompt_type/prompt_data/model`,并把后端 `polygons` 转为 `masks[].pathData` |
| `propagateMasks(payload)` | `POST /api/ai/propagate` | 对齐 | 当前帧 seed mask 向视频片段传播,并保存后续帧标注 |
| `getAiModelStatus(selectedModel?)` | `GET /api/ai/models/status` | 对齐 | 返回 GPU、SAM 2、SAM 3 的真实运行状态 |
| `getProjectAnnotations(projectId, frameId?)` | `GET /api/ai/annotations` | 对齐 | 前端加载工作区时用于回显已保存标注 |
| `saveAnnotation(payload)` | `POST /api/ai/annotate` | 对齐 | 工作区归档保存当前项目未保存 mask |
@@ -70,12 +71,13 @@ Authorization: Bearer <token>
| DELETE | `/api/templates/{template_id}` | 删除模板 |
| POST | `/api/media/upload` | 上传视频/图片/DICOM 单文件 |
| POST | `/api/media/upload/dicom` | 批量上传 DICOM |
| POST | `/api/media/parse` | 创建 Celery 拆帧任务 |
| POST | `/api/media/parse` | 创建 Celery 拆帧任务query 支持 `project_id``source_type``parse_fps``max_frames``target_width` |
| GET | `/api/tasks` | 查询后台任务列表 |
| GET | `/api/tasks/{task_id}` | 查询单个后台任务 |
| POST | `/api/tasks/{task_id}/cancel` | 取消后台任务 |
| POST | `/api/tasks/{task_id}/retry` | 重试失败或取消的后台任务 |
| POST | `/api/ai/predict` | SAM 2 / SAM 3 可选推理 |
| POST | `/api/ai/propagate` | SAM 2 / SAM 3 视频片段传播并保存标注 |
| GET | `/api/ai/models/status` | GPU 和 SAM 模型状态 |
| POST | `/api/ai/auto` | 自动分割 |
| POST | `/api/ai/annotate` | 保存 AI 标注 |
@@ -110,6 +112,25 @@ Authorization: Bearer <token>
}
```
### 创建标准帧序列拆帧任务
```text
POST /api/media/parse?project_id=1&parse_fps=15&max_frames=120&target_width=960
```
任务 `payload` 会记录本次拆帧参数;完成后的 `result.frame_sequence` 返回 `original_fps``parse_fps``frame_count``duration_ms``target_width`、帧宽高和 MinIO object prefix。每条 `FrameOut` 包含:
```json
{
"frame_index": 0,
"image_url": "http://...",
"width": 960,
"height": 540,
"timestamp_ms": 0,
"source_frame_number": 0
}
```
### 创建/更新模板
```json
@@ -150,11 +171,14 @@ Authorization: Bearer <token>
- `point`
- `box`
- `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理;选择 `sam2` 时仍回退到 auto segmentation。SAM 3 真实可用性由 `/api/ai/models/status` 中的外部环境和 checkpoint access 状态决定
- `interactive`,用于 SAM 2 交互式细化,`prompt_data` 同时携带 `box`、累计 `points``labels`
- `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理;选择 `sam2` 时仍回退到 auto segmentation。SAM 3 真实可用性由 `/api/ai/models/status` 中的外部环境和本地 checkpoint 状态决定。
选择 `sam3` 且发送 `box` 时,前端仍传 normalized `[x1, y1, x2, y2]`,后端适配层会转换成官方几何 prompt 的 `[center_x, center_y, width, height]` 正框;当前 SAM 3 不接正/反点修正。
可选 `options` 字段:
- `crop_to_prompt`:对 point/box prompt 按锚点或框附近区域裁剪后推理,再把 polygon 回映射到原图坐标。
- `crop_to_prompt`:对 point/box/interactive prompt 按锚点或框附近区域裁剪后推理,再把 polygon 回映射到原图坐标。
- `auto_filter_background`:过滤低分结果,并移除包含负向点的 polygon。
- `min_score`:配合 `auto_filter_background` 使用的最低置信度阈值。
@@ -183,6 +207,32 @@ Authorization: Bearer <token>
}
```
### 视频片段传播请求体
工作区“传播片段”调用:
```json
{
"project_id": 1,
"frame_id": 123,
"model": "sam2",
"direction": "forward",
"max_frames": 30,
"include_source": false,
"save_annotations": true,
"seed": {
"polygons": [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]],
"bbox": [0.1, 0.1, 0.2, 0.2],
"label": "胆囊",
"color": "#ff0000",
"class_metadata": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
"template_id": 2
}
}
```
`model=sam2` 使用 SAM 2 video predictor 的 mask seed 传播;`model=sam3` 使用独立 Python 3.12 helper 中的 SAM 3 video tracker并以 seed bbox 作为初始提示。响应会返回已创建的 `annotations`,保存的 `mask_data.source``sam2_propagation``sam3_propagation`
## 已完成的接口对齐
- `updateProject()` 已从 `PUT` 改为 `PATCH`

View File

@@ -16,7 +16,7 @@
剩余边界:
1. SAM 3 已完成独立 Python 3.12 环境安装脚本、外部 worker 桥接和状态检查;真实推理还需要 Hugging Face `facebook/sam3` gated 权重授权通过后执行 smoke test
1. SAM 3 已完成独立 Python 3.12 环境安装脚本、外部 worker 桥接、本地 `sam3权重/sam3.pt` checkpoint 状态检查、本地 checkpoint 加载参数接入、单帧文本/框提示和 video tracker API 入口;下一步需要基于真实业务帧验证语义召回质量和视频 tracker 稳定性
2. 标注删除/更新接口已打通基础能力;逐点几何编辑器已支持顶点拖动/删除、边中点插入和多 polygon 子区域选择编辑,复杂洞结构仍待增强。
## 阶段 2打通标注保存已完成基础闭环
@@ -127,6 +127,24 @@ Word 方案中的完整版本包含距离变换、骨架提取和聚类。当前
1. 在前端预览重叠裁决结果。
2. 对多帧多类导出增加颜色 palette PNG 或可视化 legend。
## 阶段 7.5:视频片段传播(已完成基础闭环)
当前工作区“传播片段”会使用当前选中 mask 或当前帧第一个 mask 作为 seed默认向后传播 30 帧并把结果写入后端标注表。
已完成:
1. 前端 `propagateMasks()` 已接入 `POST /api/ai/propagate`
2. 工作区按钮会把 seed mask 的 normalized polygon、bbox、label、color 和 class 元数据传给后端。
3. SAM 2 路径使用官方 `SAM2VideoPredictor.add_new_mask()``propagate_in_video()`
4. SAM 3 路径通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境中的官方 `build_sam3_video_predictor()`
5. 后端会跳过源帧,把传播结果保存到后续帧 `annotations`,并在完成后由前端刷新回显。
剩余建议:
1. 把传播任务改为异步任务,接入 Dashboard 和 WebSocket 进度。
2. 前端增加传播方向、帧数和覆盖已有标注策略设置。
3. 用真实长视频分别做 SAM 2 / SAM 3 tracker smoke test 和质量评估。
## 阶段 8清理 UI 文案与 Mock
建议统一这些文案和真实能力:

View File

@@ -28,7 +28,11 @@
- 未提供项目 ID 上传时,后端自动创建项目。
- 提供项目 ID 上传时,后端把上传对象关联到该项目。
- 拆帧接口根据项目 `source_type` 处理视频或 DICOM。
- 拆帧接口支持 `parse_fps``max_frames``target_width` 参数,用于生成可被 SAM 2 / SAM 3 视频处理复用的标准帧序列。
- 视频帧使用连续 `frame_%06d.jpg` 命名,默认从 `frame_000000.jpg` 开始,并按 `target_width` 缩放。
- 拆帧完成后写入 `frames` 记录,并把项目状态设为 `ready`
- 每条帧记录包含 `frame_index``image_url``width``height``timestamp_ms``source_frame_number`
- 任务完成结果包含 `frame_sequence` 元数据:`original_fps``parse_fps``frame_count``duration_ms``target_width`、帧宽高和对象存储前缀。
- 拆帧接口会创建 `processing_tasks` 记录并投递 Celery worker。
- 前端可通过 `GET /api/tasks/{task_id}` 查询任务状态。
- 后端支持 `POST /api/tasks/{task_id}/cancel` 取消 queued/running 任务,写入 `cancelled` 状态并尝试 revoke Celery。
@@ -41,8 +45,9 @@
- 若项目有媒体但无帧,工作区会尝试触发拆帧后重新加载。
- Canvas 显示当前帧图片。
- Canvas 支持滚轮缩放、移动工具拖拽、鼠标坐标显示。
- 时间轴支持缩略图点击切帧、range 拖动切帧、播放/暂停顺序推进帧。
- 时间轴支持缩略图点击切帧、range 拖动切帧、键盘左右方向键切帧、播放/暂停顺序推进帧。
- 播放帧率使用项目 `parse_fps``original_fps`,限制在 1 到 30 FPS。
- 时间轴显示当前帧时间和总时长,时间基准使用项目 `parse_fps``original_fps`,格式为 `mm:ss.cc`
## R5 工具栏
@@ -50,12 +55,16 @@
- 正向点、反向点、框选工具会影响 Canvas 交互。
- 魔法棒按钮切换到 AI 页面。
- 多边形、矩形、圆、点、线工具会在 Canvas 上生成可保存的 polygon mask。
- 多边形通过点击取点并按 Enter 完成;矩形、圆、线通过拖拽生成;点工具生成小点区域。
- 多边形通过点击取点并按 Enter 完成,也支持三点后点击首节点闭合;矩形、圆、线通过拖拽生成;点工具生成小点区域。
- 绘制工具点击已有 mask 时应继续执行当前绘制动作,不应被 mask 选择逻辑吞掉。
- Canvas 支持点击 mask 进入 polygon 顶点编辑态;拖动顶点会更新 mask 几何并把已保存 mask 标记为 dirty。
- 顶点编辑态下选中顶点后可用 Delete/Backspace 删除顶点,但不会让 polygon 少于三点。
- 选中整个 mask 且未选中具体顶点时Delete/Backspace 删除该 mask已保存 mask 同步调用后端删除接口。
- 撤销、重做绑定全局 `maskHistory/maskFuture`支持工具栏按钮、AI 页按钮和 Canvas 快捷键。
- 区域合并工具支持多选当前帧 mask并使用 polygon union 生成合并后的主 mask。
- 区域去除工具支持多选当前帧 mask并从第一个选中的主 mask 中扣除后续选中 mask。
- 区域合并/去除模式显示已选数量,并隐藏 polygon 编辑手柄以避免手柄抢占多选点击。
- 区域去除结果包含内洞时,前端保留 hole ring 并用 even-odd 规则渲染。
## R6 AI 推理
@@ -65,7 +74,14 @@
- 前端发送后端契约:`image_id``prompt_type``prompt_data``model`
- 点提示传 `{ points, labels }`,正向点 label 为 1反向点 label 为 0。
- 框选提示传归一化 `[x1, y1, x2, y2]`
- 语义文本提示传 `semantic`;选择 `sam3` 且独立 Python 3.12 环境、CUDA、官方包和 Hugging Face gated 权重访问均满足时走 SAM 3 文本语义推理,选择 `sam2` 时回退到自动分割
- 工作区 SAM 2 框选会建立一个候选 mask后续正向点/反向点会携带原始框和累计点,以 `interactive` prompt 细化并替换同一个候选 mask
- 语义文本提示传 `semantic`;选择 `sam3` 且独立 Python 3.12 环境、CUDA、官方包和本地 checkpoint 均满足时走 SAM 3 文本语义推理,选择 `sam2` 时回退到自动分割。
- SAM 3 支持工作区框选提示;后端把 normalized `[x1, y1, x2, y2]` 转成官方 `add_geometric_prompt()` 需要的 `[center_x, center_y, width, height]` 正框。
- 当前 SAM 3 前端路径不支持正/反点修正;在工作区用 SAM 3 进行点交互时,前端会提示切回 SAM 2。
- 工作区“传播片段”会把当前选中区域或当前帧第一个区域作为 seed调用 `POST /api/ai/propagate`,默认从当前帧向后传播 30 帧并保存结果标注。
- `POST /api/ai/propagate` 支持 `model=sam2``model=sam3`SAM 2 使用官方 `SAM2VideoPredictor.add_new_mask()``propagate_in_video()`SAM 3 通过独立 Python 3.12 helper 调用官方 `build_sam3_video_predictor()` video tracker。
- 传播结果会写入后续帧 `annotations``mask_data.source` 分别标记为 `sam2_propagation``sam3_propagation`,并保留 label、color 和 class 元数据。
- AI 页面会对 SAM 3 空文本、推理失败和返回 0 个 mask 的情况显示明确反馈。
- AI 参数支持 `crop_to_prompt``auto_filter_background``min_score`;点/框 prompt 可以裁剪局部区域推理并回映射结果,背景过滤会移除低分结果和包含负向点的 polygon。
- 后端返回 `polygons``scores`
- 前端把后端 `polygons` 转成 Konva `pathData``segmentation``bbox``area`
@@ -98,6 +114,7 @@
- 工作区右侧可以选择模板。
- 面板显示模板分类和组件本地自定义分类。
- 用户可以选择具体分类;新 AI mask 会记录 `classId``className``classZIndex`,并在保存时写入 `mask_data.class`
- 如果 Canvas 当前已经选中一个或多个 mask点击语义分类树会把这些 mask 的 `label``color` 和 class 元数据改为该分类;已保存 mask 会进入 `dirty` 状态,归档保存时更新后端。
- 添加自定义分类只存在组件本地状态,不保存到后端。
- 置信度、拓扑锚点和重新提取骨架按钮当前为展示/占位。

View File

@@ -19,7 +19,7 @@
| 模块 | 文件 | 设计职责 |
|------|------|----------|
| 应用入口 | `src/App.tsx` | 根据登录状态和 `activeModule` 切换页面 |
| 全局状态 | `src/store/useStore.ts` | Zustand store保存项目、帧、模板、mask、工具状态和 mask 撤销/重做历史栈 |
| 全局状态 | `src/store/useStore.ts` | Zustand store保存项目、帧、模板、mask、当前选中 mask ids、工具状态和 mask 撤销/重做历史栈 |
| API 封装 | `src/lib/api.ts` | Axios 客户端、字段映射、AI 响应转换 |
| 配置 | `src/lib/config.ts` | 推导 API 和 WebSocket 地址 |
| WebSocket | `src/lib/websocket.ts` | 进度流连接、订阅和重连 |
@@ -30,7 +30,7 @@
| 工作区 | `src/components/VideoWorkspace.tsx` | 加载帧和模板组织工具栏、Canvas、本体面板、时间轴 |
| Canvas | `src/components/CanvasArea.tsx` | 显示帧、缩放平移、点/框提示、渲染 mask |
| 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具、跳转 AI 页面、触发 mask 撤销/重做 |
| 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航和播放 |
| 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航、左右方向键切帧、播放和当前/总时长显示 |
| 本体面板 | `src/components/OntologyInspector.tsx` | 模板选择、分类树、本地自定义分类 |
| AI 页面 | `src/components/AISegmentation.tsx` | 独立 AI 推理视图,使用当前项目帧 |
| 模板库 | `src/components/TemplateRegistry.tsx` | 模板 CRUD、分类编辑、导入、排序 |
@@ -48,10 +48,10 @@
| Projects | `backend/routers/projects.py` | 项目与帧 CRUD |
| Templates | `backend/routers/templates.py` | 模板 CRUD 和 mapping_rules 打包/解包 |
| Media | `backend/routers/media.py` | 上传媒体和拆帧 |
| AI | `backend/routers/ai.py` | SAM 2 / SAM 3 可选推理、模型状态和标注保存 |
| AI | `backend/routers/ai.py` | SAM 2 / SAM 3 可选推理、视频传播、模型状态和标注保存 |
| Export | `backend/routers/export.py` | COCO 和 PNG mask 导出 |
| SAM 2 | `backend/services/sam2_engine.py` | SAM 2 懒加载、状态检测点/框/自动推理 |
| SAM 3 | `backend/services/sam3_engine.py`, `backend/services/sam3_external_worker.py`, `backend/setup_sam3_env.sh` | SAM 3 状态检测、独立 Python 3.12 环境桥接和文本语义推理适配 |
| SAM 2 | `backend/services/sam2_engine.py` | SAM 2 懒加载、状态检测点/框/自动推理和视频 mask 传播 |
| SAM 3 | `backend/services/sam3_engine.py`, `backend/services/sam3_external_worker.py`, `backend/setup_sam3_env.sh` | SAM 3 状态检测、独立 Python 3.12 环境桥接、本地 checkpoint 加载、文本语义推理、正框几何提示和 video tracker 适配 |
| SAM Registry | `backend/services/sam_registry.py` | 模型选择、GPU 状态和推理分发 |
## 状态模型
@@ -59,9 +59,10 @@
前端 store 的核心对象:
- `Project`项目基本信息、状态、帧数、fps、媒体路径。
- `Frame`:帧 ID、项目 ID、索引、图片 URL、宽高。
- `Frame`:帧 ID、项目 ID、索引、图片 URL、宽高、序列时间戳和原视频源帧号
- `Template` / `TemplateClass`:模板和分类定义。
- `Mask`:前端渲染用 mask包含 `pathData``segmentation``bbox``area`
- `selectedMaskIds`Canvas 当前选中的 mask id 列表,供右侧本体面板对已选区域直接换标签。
- `maskHistory` / `maskFuture`mask 编辑历史栈,用于撤销和重做。
- `activeModule`:当前页面。
- `activeTool`:当前工具。
@@ -79,9 +80,11 @@
1. `ProjectLibrary` 创建项目。
2. 上传视频或 DICOM 到 `/api/media/upload``/api/media/upload/dicom`
3. 调用 `/api/media/parse` 创建异步拆帧任务。
4. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,持续更新 `processing_tasks`,并发布 Redis `seg:progress`
5. 刷新项目列表
3. 调用 `/api/media/parse` 创建异步拆帧任务;可通过 `parse_fps``max_frames``target_width` 指定标准帧序列参数
4. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,视频帧按 `frame_%06d.jpg``frame_000000.jpg` 连续命名,并按目标宽度缩放
5. worker 写入 `frames.timestamp_ms``frames.source_frame_number`,并在任务 `result.frame_sequence` 中记录 FPS、帧数、时长、尺寸和对象存储前缀
6. worker 持续更新 `processing_tasks`,并发布 Redis `seg:progress`
7. 刷新项目列表。
### 任务控制
@@ -95,29 +98,43 @@
1. `VideoWorkspace` 根据 `currentProject.id` 调用 `getProjectFrames()`
2. 若无帧但项目有 `video_path`,触发 `parseMedia()`,通过 `getTask()` 轮询任务完成后重新取帧。
3. 帧数据映射为 store `Frame[]`
3. 帧数据映射为 store `Frame[]`,包含 `timestampMs``sourceFrameNumber`,供时间轴和后续视频传播使用
4. 当前帧传入 `CanvasArea`
### AI 点/框推理
1. 用户在 Canvas 选择正向点、反向点或框选。
2. `CanvasArea` 读取当前帧 ID 和宽高。
3. `predictMask()` 归一化坐标并携带当前 `model` 调用 `/api/ai/predict`
4. 后端加载帧图片并通过 SAM registry 分发到 SAM 2 或 SAM 3
5. 前端把 `polygons` 转为 mask写入 store
6. Canvas 按当前帧过滤并渲染 mask。
7. 新 mask 会带上当前选择的模板分类元数据,包括 `classId``className``classZIndex` 和保存状态 `draft`
8. 用户点击“结构化归档保存”后,前端将像素 `segmentation` 转成 normalized `mask_data.polygons`;未保存 mask 调用 `POST /api/ai/annotate`dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`
9. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask
10. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。
3. SAM 2 框选会创建一个候选 mask并记录原始框后续正向点/反向点会累计到同一候选上
4. `predictMask()` 归一化坐标并携带当前 `model` 调用 `/api/ai/predict`;同时有框和点时发送 `interactive` prompt
5. 后端加载帧图片并通过 SAM registry 分发到 SAM 2 或 SAM 3
6. 前端把 `polygons` 转为 mask交互式细化会替换同一个候选 mask而不是新增多个 mask。
7. Canvas 按当前帧过滤并渲染 mask
8. 新 mask 会带上当前选择的模板分类元数据,包括 `classId``className``classZIndex` 和保存状态 `draft`
9. 用户点击“结构化归档保存”后,前端将像素 `segmentation` 转成 normalized `mask_data.polygons`;未保存 mask 调用 `POST /api/ai/annotate`dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`
10. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。
11. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。
### 视频片段传播
1. 用户在工作区选中一个当前帧 mask如果未显式选中前端使用当前帧第一个 mask。
2. `VideoWorkspace``buildAnnotationPayload()` 把 seed mask 转成 normalized polygon、bbox、label、color 和 class 元数据。
3. 前端调用 `POST /api/ai/propagate`,默认 `direction=forward``max_frames=30``include_source=false`
4. 后端按项目帧序列截取片段,下载对应帧到临时 `frame_%06d.jpg` 目录,保持当前帧在片段中的相对索引。
5. `model=sam2` 时,`sam2_engine` 使用 `SAM2VideoPredictor.add_new_mask()` 注入 seed mask再用 `propagate_in_video()` 传播。
6. `model=sam3` 时,`sam3_engine` 将请求交给 `sam3_external_worker.py`,由独立 Python 3.12 环境调用官方 `build_sam3_video_predictor()`,以 seed bbox 走 video tracker。
7. 后端把传播返回的 normalized polygon 保存为后续帧 `Annotation`,跳过源帧,`mask_data.source` 记录模型传播来源。
8. 前端传播完成后重新调用 `GET /api/ai/annotations` 并回显新标注。
### 手工绘制与历史栈
1. 用户在 `ToolsPalette` 选择多边形、矩形、圆、点或线工具。
2. `CanvasArea` 将交互坐标转换成像素 polygon。
3. 新 mask 写入 `pathData`、像素 `segmentation``bbox``area` 和当前模板分类元数据
4. `addMask()``setMasks()``updateMask()``clearMasks()` 会维护 `maskHistory/maskFuture`
5. 工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y 调用 `undoMasks()` / `redoMasks()`
3. 多边形工具逐次记录节点,三点后点击首节点或按 Enter 时生成闭合 polygon
4. mask path 只在 `move``area_merge``area_remove` 工具下拦截点击;绘制和 AI prompt 工具点击已有 mask 时继续冒泡给 Stage
5. 新 mask 写入 `pathData`、像素 `segmentation``bbox``area` 和当前模板分类元数据
6. `addMask()``setMasks()``updateMask()``clearMasks()` 会维护 `maskHistory/maskFuture`
7. 工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y 调用 `undoMasks()` / `redoMasks()`
### Polygon 逐点编辑
@@ -126,14 +143,16 @@
3. 如果 mask 已有 `annotationId`,编辑会把 `saveStatus` 标成 `dirty``saved=false`
4. 归档保存时复用现有 `PATCH /api/ai/annotations/{annotation_id}` 链路,把更新后的 normalized polygon 写回后端。
5. 选中顶点后 Delete/Backspace 可删除顶点;前端保持 polygon 至少三点。
6. 未选中具体顶点但选中了 mask 时Delete/Backspace 从前端 store 删除该 mask如果包含 `annotationId`,通过工作区回调调用后端删除接口。
### 区域合并与去除
1. 用户选择 `area_merge``area_remove` 后,点击多个当前帧 mask 组成选择集。
2. `CanvasArea``Mask.segmentation` 转为 `polygon-clipping` 的 MultiPolygon
3. `area_merge` 使用 union更新第一个选中的主 mask并从前端 store 移除后续被合并 mask如果被移除 mask 已保存,会调用工作区传入的删除回调删除后端标注
4. `area_remove` 使用 difference从第一个选中的主 mask 中扣除后续选中 mask扣除对象本身保留
5. 结果会重算 `pathData``segmentation``bbox``area`,已保存主 mask 会进入 dirty 状态并复用归档 PATCH 链路
2. 合并/去除模式隐藏 polygon 顶点和边中点编辑手柄,并在右下角显示已选数量;少于两个 mask 时操作按钮禁用
3. `CanvasArea``Mask.segmentation` 转为 `polygon-clipping` 的 MultiPolygon
4. `area_merge` 使用 union更新第一个选中的主 mask并从前端 store 移除后续被合并 mask如果被移除 mask 已保存,会调用工作区传入的删除回调删除后端标注
5. `area_remove` 使用 difference从第一个选中的主 mask 中扣除后续选中 mask扣除对象本身保留如果 difference 产生内洞,`segmentation` 保留外圈和 hole ring渲染时使用 even-odd fill
6. 结果会重算 `pathData``segmentation``bbox``area`,已保存主 mask 会进入 dirty 状态并复用归档 PATCH 链路;带洞结果的面积按外圈减内洞计算。
### GT Mask 导入
@@ -153,7 +172,10 @@
3. 保存时调用 `createTemplate()``updateTemplate()`
4. 后端把 `classes``rules` 打包进 `mapping_rules`
5. 返回时再解包给前端。
6. `OntologyInspector` 可以选择具体分类;选择结果进入全局 store`CanvasArea``AISegmentation` 新建/更新 mask使用
6. `CanvasArea` 把当前选中的 mask id 同步到全局 `selectedMaskIds`;切换工具、切换帧或卸载 Canvas 时会清空选择
7. `OntologyInspector` 可以选择具体分类;选择结果进入全局 store`CanvasArea``AISegmentation` 新建/更新 mask 时使用。
8. 如果 `selectedMaskIds` 中存在当前 store 的 mask点击分类时会立即更新这些 mask 的 `templateId``classId``className``classZIndex``label``color`
9. 已保存 mask 被重新分类后进入 `dirty``saved=false`,继续复用工作区归档保存的 PATCH 链路。
### 导出
@@ -173,13 +195,20 @@
- `cancelTask()` 使用 `POST /api/tasks/{taskId}/cancel`
- `retryTask()` 使用 `POST /api/tasks/{taskId}/retry`
- `predictMask()` 使用 `POST /api/ai/predict`,请求体为 `image_id``prompt_type``prompt_data``model`
- `propagateMasks()` 使用 `POST /api/ai/propagate`,请求体为 `project_id``frame_id``model``seed``direction``max_frames`
- `saveAnnotation()` 使用 `POST /api/ai/annotate`
- `importGtMask()` 使用 `POST /api/ai/import-gt-mask` multipart form-data。
- `getProjectAnnotations()` 使用 `GET /api/ai/annotations`
- `updateAnnotation()` 使用 `PATCH /api/ai/annotations/{annotationId}`
- `deleteAnnotation()` 使用 `DELETE /api/ai/annotations/{annotationId}`
- 后端 `/api/ai/predict` 支持 point、box、semantic 三种 prompt_type并通过 `model` 选择 SAM 2 或 SAM 3
- 后端 `/api/ai/predict` 支持可选 `options``crop_to_prompt` 会对 point/box prompt 做局部裁剪推理并回映射 polygon`auto_filter_background` 会按 `min_score` 和负向点过滤结果
- `parseMedia()` 使用 `POST /api/media/parse?project_id=...`,可选 `parse_fps``max_frames``target_width`,用于生成标准帧序列
- `getProjectFrames()` 返回帧图像 URL、宽高、`timestamp_ms``source_frame_number`
- 后端 `/api/ai/predict` 支持 point、box、interactive、semantic 四种 prompt_type并通过 `model` 选择 SAM 2 或 SAM 3。
- 当前 SAM 3 暴露 semantic 文本语义推理和 box 几何提示;工作区 Canvas 的点交互会在选择 SAM 3 时显示提示,不再静默失败。
- SAM 3 box prompt 复用后端 `/api/ai/predict``box` prompt_type输入仍是 normalized `[x1, y1, x2, y2]`,引擎适配层会转换为官方 `add_geometric_prompt()` 使用的 `[center_x, center_y, width, height]` 正框。
- AI 页面选择 SAM 3 时优先发送文本 semantic prompt不会把正/反点误发送为 SAM 3 point prompt空文本、后端错误和空结果都会显示反馈消息。
- 后端 `/api/ai/predict` 支持可选 `options``crop_to_prompt` 会对 point/box/interactive prompt 做局部裁剪推理并回映射 polygon`auto_filter_background` 会按 `min_score` 和负向点过滤结果。
- 后端 `/api/ai/propagate` 支持 SAM 2 mask seed 视频传播和 SAM 3 external video tracker当前前端默认向后传播 30 帧并保存结果标注。
- 后端 `/api/ai/models/status` 返回 GPU、SAM 2、SAM 3 的真实运行状态SAM 3 状态包含外部 Python 环境与 checkpoint access 的可用性。
- point prompt 支持旧数组形式和 `{ points, labels }` 对象形式。
@@ -198,7 +227,7 @@
以下能力属于当前冻结版本的占位或半可用功能:
- Dashboard 初始快照来自 `GET /api/dashboard/overview`;解析队列由 `processing_tasks` queued/running/failed/cancelled 任务生成。
- 已保存标注支持通过“应用分类”、polygon 顶点拖动/删除、边中点插入、多 polygon 子区域编辑和区域合并/去除进入 dirty 状态并归档更新;复杂洞结构编辑尚未实现。
- SAM 3 文本语义分割取决于官方依赖、GPU 运行环境和 Hugging Face gated 权重授权;状态接口会暴露真实可用性,未授权`available=false`
- 已保存标注支持通过“应用分类”、polygon 顶点拖动/删除、边中点插入、多 polygon 子区域编辑和区域合并/去除进入 dirty 状态并归档更新;选中整块 mask 可用 Delete/Backspace 删除并同步后端;复杂洞结构编辑尚未实现。
- SAM 3 文本语义分割取决于官方依赖、GPU 运行环境和本地 checkpoint;状态接口会暴露真实可用性,运行时缺失`available=false`
- 自定义分类只存在本地组件状态。
- GT mask 导入已完成多类别像素值拆分、contour、distance transform seed point 和前端 seed point 拖拽编辑骨架提取、HDBSCAN 聚类和模板自动映射尚未实现。

View File

@@ -16,18 +16,51 @@
|------|----------|--------|
| R1 登录与会话 | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | 成功登录、失败提示、后端 401 |
| R2 项目管理 | `src/lib/api.test.ts`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、后端 CRUD、帧列表 |
| R3 媒体上传与拆帧 | `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 扩展名校验、自动建项目、关联项目、创建异步任务、worker 注册帧、取消任务、重试任务、取消后 worker 停止 |
| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧触发解析、切帧、播放 |
| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 工具切换、AI 跳转、手工 mask 绘制、polygon 顶点拖动/删除、区域合并/去除、撤销/重做历史栈 |
| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam3_engine.py` | 点/框/semantic 契约、模型选择、GPU/SAM 状态、SAM 3 外部 worker 桥接、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry |
| R3 媒体上传与拆帧 | `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 扩展名校验、自动建项目、关联项目、创建异步任务、标准帧序列参数、帧时间戳/源帧号、任务序列元数据、worker 注册帧、取消任务、重试任务、取消后 worker 停止 |
| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧触发解析、缩略图/range/左右方向键切帧、播放、按项目 FPS 显示当前/总时长 |
| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 工具切换、AI 跳转、矩形/圆/线/点/多边形手工 mask 绘制、点工具在已有 mask 上落点、多边形 Enter/首节点闭合、polygon 顶点拖动/删除、整块 mask 删除、区域合并/去除、内含去除 hole 渲染、合并模式隐藏编辑手柄、撤销/重做历史栈 |
| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam3_engine.py` | 点/框/interactive/semantic 契约、SAM 2 框选后正负点细化同一候选 mask、SAM 2 视频传播、SAM 3 语义文本前端执行路径、SAM 3 工作区框选、SAM 3 video tracker 外部桥接、SAM 3 点交互不支持提示、空文本/空结果反馈、模型选择、GPU/SAM 状态、SAM 3 外部 worker 桥接、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry |
| R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、加载回显、更新 dirty 标注、清空删除已保存标注、GT mask 多类别导入、seed point 回显/归一化、项目不存在、帧不存在 |
| R8 模板库 | `src/lib/api.test.ts`, `backend/tests/test_templates.py` | mapping_rules 解包/打包、模板 CRUD |
| R9 本体检查面板 | `src/components/OntologyInspector.test.tsx` | 模板选择、分类展示、具体分类选择、自定义分类本地添加 |
| R8 模板库 | `src/components/TemplateRegistry.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_templates.py` | 前端模板加载/新建/编辑/删除、JSON 分类导入、mapping_rules 解包/打包、后端模板 CRUD |
| R9 本体检查面板 | `src/components/OntologyInspector.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 模板选择、分类展示、具体分类选择、Canvas 选区同步、点击分类给已选 mask 换标签、自定义分类本地添加 |
| R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py`, `backend/tests/test_tasks.py` | 后端概览接口、任务表驱动队列、任务取消/重试/详情、cancelled 事件、Redis 进度事件 payload/发布、地址推导、消息订阅、队列更新、heartbeat |
| R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO/PNG 按钮下载、导出前自动保存、导出路径、JSON 结构、mask ZIP、zIndex 语义融合 |
| R12 配置 | `src/lib/config.test.ts` | env 优先、hostname 推导、WS 推导 |
| R13 文档与测试 | `doc/09-test-plan.md` | 测试覆盖矩阵 |
## 逐功能点追踪
| 需求 | 功能点 | 对应测试 | 当前状态 |
|------|--------|----------|----------|
| R1 | 登录页、默认开发凭证、token 写入、失败提示、后端 401 | `Login.test.tsx`, `test_auth.py` | 已覆盖 |
| R2 | 项目列表/创建/选择、视频导入、DICOM 导入、后端项目和帧 CRUD | `ProjectLibrary.test.tsx`, `api.test.ts`, `test_projects.py` | 已覆盖 |
| R3 | 文件类型校验、自动/指定项目上传、视频/DICOM 拆帧任务、`parse_fps/max_frames/target_width`、标准帧序列 metadata、任务查询、取消、重试、worker 取消停止 | `test_media.py`, `test_tasks.py` | 已覆盖 |
| R4 | 工作区加载帧、无帧自动解析、Canvas 底图、缩略图/range/左右方向键切帧、播放、按 FPS 显示时间 | `VideoWorkspace.test.tsx`, `FrameTimeline.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 |
| R5 | 工具切换、AI 跳转、矩形/圆/线/点/多边形绘制、已有 mask 上继续绘制 | `ToolsPalette.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 |
| R5 | 顶点编辑、顶点删除、整块删除、撤销/重做、区域合并、区域去除、hole even-odd 渲染 | `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 |
| R6 | SAM 2 点/框/interactive、SAM 2 视频传播、SAM 3 semantic、SAM 3 box、SAM 3 video tracker、SAM 3 不支持点交互时的前端反馈、模型选择、GPU/模型状态、参数 options、polygons 转 mask | `api.test.ts`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx`, `VideoWorkspace.test.tsx`, `ModelStatusBadge.test.tsx`, `test_ai.py`, `test_sam3_engine.py` | 已覆盖 |
| R7 | 保存、查询、更新、删除标注、工作区回显、清空已保存标注、GT mask 导入和 seed point 回写 | `VideoWorkspace.test.tsx`, `CanvasArea.test.tsx`, `api.test.ts`, `test_ai.py` | 已覆盖 |
| R8 | 模板加载、新建、编辑、删除、JSON 分类导入、mapping_rules 映射、后端 CRUD | `TemplateRegistry.test.tsx`, `api.test.ts`, `test_templates.py` | 已覆盖 |
| R9 | 模板选择、分类展示、分类选择、已选 mask 换标签、自定义本地分类、占位状态 | `OntologyInspector.test.tsx`, `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 |
| R10 | Dashboard 概览、队列、活动日志、WebSocket progress/complete/error/status/cancelled、取消/重试/详情、heartbeat | `Dashboard.test.tsx`, `websocket.test.ts`, `test_dashboard.py`, `test_main.py`, `test_progress_events.py`, `test_tasks.py` | 已覆盖 |
| R11 | COCO/PNG ZIP 导出、导出前保存、路径和 JSON/ZIP 结构、zIndex 融合 | `VideoWorkspace.test.tsx`, `api.test.ts`, `test_export.py` | 已覆盖 |
| R12 | API/WS 地址 env 优先和 hostname 推导 | `config.test.ts` | 已覆盖 |
| R13 | 文档测试矩阵与功能点追踪 | `doc/09-test-plan.md` | 已覆盖 |
## 本轮补齐记录
- R5补充 `CanvasArea.test.tsx` 中圆形和线段手工绘制测试,明确验证 metadata、segmentation、bbox/area 和草稿状态。
- R6补充 `AISegmentation.test.tsx` 中 SAM 3 semantic 文本推理测试,验证前端传参和返回 mask 绑定当前语义类别。
- R6补充 SAM 3 空文本、空结果和工作区点交互不支持提示测试,避免前端静默失败。
- R6补充 SAM 3 工作区 box prompt 测试和外部 worker box prompt 测试,验证官方 `add_geometric_prompt()` 正框链路。
- R6补充 `POST /api/ai/propagate` 后端测试,验证 seed mask 传播结果会保存为后续帧标注并保留 class 元数据。
- R6补充 `propagateMasks()` API 封装和 `VideoWorkspace` 传播按钮测试,验证当前选中区域会发送到后端视频传播接口。
- R6补充 SAM 3 external video tracker 请求测试验证主后端会把帧目录、源帧索引、seed bbox 和方向传给独立 Python helper。
- R3补充 `parseMedia()` 查询参数和后端拆帧任务 payload 测试,验证 `parse_fps``max_frames``target_width` 会进入任务。
- R3补充 worker 注册标准帧序列测试,验证帧 `timestamp_ms``source_frame_number``result.frame_sequence` 元数据。
- R8补充 `TemplateRegistry.test.tsx` 中模板编辑、删除测试,验证前端调用真实 API 封装并更新全局 store。
- R9补充 Canvas 选中 mask id 全局同步、本体树点击分类给已选 mask 换标签的测试,验证已保存 mask 会进入 dirty 状态。
## 运行命令
```bash

View File

@@ -1,4 +1,4 @@
import { fireEvent, render, screen } from '@testing-library/react';
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
@@ -62,4 +62,112 @@ describe('AISegmentation', () => {
},
}));
});
it('prompts for semantic text before running SAM3 inference', async () => {
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam3',
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true },
],
});
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
fireEvent.click(sam3Button);
fireEvent.click(screen.getByText('执行高精度语义分割'));
expect(apiMock.predictMask).not.toHaveBeenCalled();
expect(await screen.findByText('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。')).toBeInTheDocument();
});
it('shows feedback when SAM3 semantic inference returns no masks', async () => {
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam3',
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true },
],
});
apiMock.predictMask.mockResolvedValueOnce({ masks: [] });
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
fireEvent.click(sam3Button);
fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), {
target: { value: '胆囊' },
});
fireEvent.click(screen.getByText('执行高精度语义分割'));
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({
model: 'sam3',
points: undefined,
text: '胆囊',
})));
expect(await screen.findByText('模型没有返回可用区域,请换一个更具体的描述或调整提示。')).toBeInTheDocument();
});
it('runs SAM3 semantic text inference and assigns the active class to returned masks', async () => {
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam3',
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true },
],
});
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'semantic-1',
pathData: 'M 10 10 L 40 10 L 40 40 Z',
label: 'semantic result',
color: '#06b6d4',
segmentation: [[10, 10, 40, 10, 40, 40]],
bbox: [10, 10, 30, 30],
area: 900,
},
],
});
useStore.setState({
activeTemplateId: 'template-1',
activeClassId: 'class-1',
activeClass: { id: 'class-1', name: '胆囊', color: '#ff0000', zIndex: 30 },
});
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
fireEvent.click(sam3Button);
fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), {
target: { value: '胆囊' },
});
fireEvent.click(screen.getByText('执行高精度语义分割'));
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam3',
points: undefined,
text: '胆囊',
options: {
crop_to_prompt: false,
auto_filter_background: true,
min_score: 0.05,
},
})));
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'semantic-1',
frameId: 'frame-1',
templateId: 'template-1',
classId: 'class-1',
className: '胆囊',
classZIndex: 30,
label: '胆囊',
color: '#ff0000',
saveStatus: 'draft',
}));
});
});

View File

@@ -33,6 +33,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const [autoDeleteBg, setAutoDeleteBg] = useState(true);
const [cropMode, setCropMode] = useState(false);
const [isInferencing, setIsInferencing] = useState(false);
const [inferenceMessage, setInferenceMessage] = useState('');
// Canvas state
const [scale, setScale] = useState(1);
@@ -91,9 +92,18 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
};
const runInference = useCallback(async () => {
if (points.length === 0 && !semanticText.trim()) return;
const textPrompt = semanticText.trim();
if (aiModel === 'sam3' && !textPrompt) {
setInferenceMessage('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。');
return;
}
if (points.length === 0 && !textPrompt) {
setInferenceMessage('请先放置正/反向提示点,或输入语义描述。');
return;
}
if (!currentFrame?.id) {
console.warn('AI inference skipped: no project frame is selected');
setInferenceMessage('请先在项目工作区选择一帧图像。');
return;
}
@@ -101,18 +111,20 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const imageHeight = currentFrame.height || image?.naturalHeight || image?.height || 0;
if (imageWidth <= 0 || imageHeight <= 0) {
console.warn('AI inference skipped: active frame dimensions are unavailable');
setInferenceMessage('当前帧缺少宽高信息,无法推理。');
return;
}
setIsInferencing(true);
setInferenceMessage('');
try {
const result = await predictMask({
imageId: currentFrame.id,
imageWidth,
imageHeight,
model: aiModel,
points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
text: semanticText.trim() || undefined,
points: aiModel === 'sam3' ? undefined : points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
text: textPrompt || undefined,
options: {
crop_to_prompt: cropMode,
auto_filter_background: autoDeleteBg,
@@ -120,6 +132,11 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
},
});
if (result.masks.length === 0) {
setInferenceMessage('模型没有返回可用区域,请换一个更具体的描述或调整提示。');
} else {
setInferenceMessage(`已生成 ${result.masks.length} 个候选区域。`);
}
result.masks.forEach((m) => {
const label = activeClass?.name || m.label;
const color = activeClass?.color || m.color;
@@ -142,6 +159,8 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
});
} catch (err) {
console.error('AI inference failed:', err);
const detail = (err as any)?.response?.data?.detail;
setInferenceMessage(detail || 'AI 推理失败,请查看模型状态或后端日志。');
} finally {
setIsInferencing(false);
}
@@ -282,6 +301,11 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
{isInferencing ? <Loader2 size={16} className="animate-spin" /> : <Sparkles size={16} />}
{isInferencing ? '推理中...' : modelCanInfer ? '执行高精度语义分割' : '当前模型不可用'}
</button>
{inferenceMessage && (
<div className="rounded border border-white/10 bg-white/5 px-3 py-2 text-[11px] leading-relaxed text-gray-300">
{inferenceMessage}
</div>
)}
<button
onClick={onSendToWorkspace}
className="w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all font-medium tracking-wide text-xs uppercase bg-white/5 hover:bg-white/10 text-gray-300 border border-white/5 hover:border-white/10"

View File

@@ -65,6 +65,157 @@ describe('CanvasArea', () => {
}));
});
it('explains that SAM3 point prompts are not supported in the workspace', async () => {
useStore.setState({ aiModel: 'sam3' });
render(<CanvasArea activeTool="point_pos" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-stage'));
expect(apiMock.predictMask).not.toHaveBeenCalled();
expect(await screen.findByText(/SAM3 当前工作区只支持框选提示/)).toBeInTheDocument();
});
it('calls SAM3 prediction with a box prompt from the workspace', async () => {
useStore.setState({ aiModel: 'sam3' });
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'sam3-box-mask',
pathData: 'M 20 20 L 80 20 L 80 80 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[20, 20, 80, 20, 80, 80]],
bbox: [20, 20, 60, 60],
area: 3600,
},
],
});
render(<CanvasArea activeTool="box_select" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith({
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam3',
points: undefined,
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
}));
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'sam3-box-mask',
metadata: expect.objectContaining({
source: 'sam3_box',
promptBox: { x1: 120, y1: 80, x2: 260, y2: 200 },
}),
}));
});
it('refines one SAM2 candidate mask from an initial box with positive and negative points', async () => {
apiMock.predictMask
.mockResolvedValueOnce({
masks: [
{
id: 'mask-box',
pathData: 'M 10 10 L 90 10 L 90 90 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[10, 10, 90, 10, 90, 90]],
bbox: [10, 10, 80, 80],
area: 6400,
},
],
})
.mockResolvedValueOnce({
masks: [
{
id: 'mask-refined-pos',
pathData: 'M 20 20 L 80 20 L 80 80 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[20, 20, 80, 20, 80, 80]],
bbox: [20, 20, 60, 60],
area: 3600,
},
],
})
.mockResolvedValueOnce({
masks: [
{
id: 'mask-refined-neg',
pathData: 'M 30 30 L 70 30 L 70 70 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[30, 30, 70, 30, 70, 70]],
bbox: [30, 30, 40, 40],
area: 1600,
},
],
});
const { rerender } = render(<CanvasArea activeTool="box_select" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
await waitFor(() => expect(apiMock.predictMask).toHaveBeenNthCalledWith(1, {
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam2',
points: undefined,
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
}));
await waitFor(() => expect(useStore.getState().masks).toHaveLength(1));
rerender(<CanvasArea activeTool="point_pos" frame={frame} />);
fireEvent.click(stage, { clientX: 150, clientY: 100 });
await waitFor(() => expect(apiMock.predictMask).toHaveBeenNthCalledWith(2, {
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam2',
points: [{ x: 150, y: 100, type: 'pos' }],
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
}));
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'mask-box',
segmentation: [[20, 20, 80, 20, 80, 80]],
metadata: expect.objectContaining({
source: 'sam2_interactive',
promptPointCount: 1,
}),
}));
rerender(<CanvasArea activeTool="point_neg" frame={frame} />);
fireEvent.click(stage, { clientX: 300, clientY: 150 });
await waitFor(() => expect(apiMock.predictMask).toHaveBeenNthCalledWith(3, {
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam2',
points: [
{ x: 150, y: 100, type: 'pos' },
{ x: 300, y: 150, type: 'neg' },
],
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
}));
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'mask-box',
segmentation: [[30, 30, 70, 30, 70, 70]],
points: [[150, 100]],
metadata: expect.objectContaining({ promptPointCount: 2 }),
}));
});
it('renders only masks that belong to the current frame', () => {
useStore.setState({
masks: [
@@ -79,6 +230,26 @@ describe('CanvasArea', () => {
expect(screen.getByText('遮罩数: 1')).toBeInTheDocument();
});
it('publishes the selected mask ids for the ontology panel', async () => {
useStore.setState({
masks: [
{
id: 'm1',
frameId: 'frame-1',
pathData: 'M 0 0 L 10 0 L 10 10 Z',
label: 'A',
color: '#fff',
segmentation: [[0, 0, 10, 0, 10, 10]],
},
],
});
render(<CanvasArea activeTool="move" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-path'));
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['m1']));
});
it('renders imported GT seed points for editable point regions', () => {
useStore.setState({
masks: [
@@ -164,6 +335,57 @@ describe('CanvasArea', () => {
}));
});
it('deletes the selected draft mask with Delete when no vertex is selected', () => {
useStore.setState({
masks: [
{
id: 'draft-1',
frameId: 'frame-1',
pathData: 'M 10 10 L 90 10 L 90 40 Z',
label: 'Draft',
color: '#06b6d4',
saveStatus: 'draft',
segmentation: [[10, 10, 90, 10, 90, 40]],
},
],
});
render(<CanvasArea activeTool="move" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-path'));
fireEvent.keyDown(window, { key: 'Delete' });
expect(useStore.getState().masks).toEqual([]);
expect(useStore.getState().maskHistory.at(-1)).toEqual([
expect.objectContaining({ id: 'draft-1' }),
]);
});
it('deletes the selected saved mask locally and notifies the backend deletion callback', () => {
const onDeleteMaskAnnotations = vi.fn();
useStore.setState({
masks: [
{
id: 'annotation-99',
annotationId: '99',
frameId: 'frame-1',
pathData: 'M 10 10 L 90 10 L 90 40 Z',
label: 'Saved',
color: '#06b6d4',
saveStatus: 'saved',
saved: true,
segmentation: [[10, 10, 90, 10, 90, 40]],
},
],
});
render(<CanvasArea activeTool="move" frame={frame} onDeleteMaskAnnotations={onDeleteMaskAnnotations} />);
fireEvent.click(screen.getByTestId('konva-path'));
fireEvent.keyDown(window, { key: 'Backspace' });
expect(useStore.getState().masks).toEqual([]);
expect(onDeleteMaskAnnotations).toHaveBeenCalledWith(['99']);
});
it('inserts a polygon vertex from an edge midpoint handle', () => {
useStore.setState({
masks: [
@@ -248,9 +470,13 @@ describe('CanvasArea', () => {
});
render(<CanvasArea activeTool="area_merge" frame={frame} />);
expect(screen.getByText('已选 0')).toBeInTheDocument();
const paths = screen.getAllByTestId('konva-path');
fireEvent.click(paths[0]);
expect(screen.getByText('已选 1')).toBeInTheDocument();
expect(screen.queryAllByTestId('konva-circle')).toHaveLength(0);
fireEvent.click(paths[1]);
expect(screen.getByText('已选 2')).toBeInTheDocument();
fireEvent.click(screen.getByRole('button', { name: '合并选中' }));
expect(useStore.getState().masks).toHaveLength(1);
@@ -300,6 +526,45 @@ describe('CanvasArea', () => {
expect(useStore.getState().masks[1].id).toBe('m2');
});
it('renders inner overlap removal as a hole in the primary mask', () => {
useStore.setState({
masks: [
{
id: 'm1',
frameId: 'frame-1',
pathData: 'M 10 10 L 110 10 L 110 110 L 10 110 Z',
label: 'A',
color: '#06b6d4',
segmentation: [[10, 10, 110, 10, 110, 110, 10, 110]],
},
{
id: 'm2',
frameId: 'frame-1',
pathData: 'M 40 40 L 80 40 L 80 80 L 40 80 Z',
label: 'B',
color: '#ff0000',
segmentation: [[40, 40, 80, 40, 80, 80, 40, 80]],
},
],
});
render(<CanvasArea activeTool="area_remove" frame={frame} />);
const paths = screen.getAllByTestId('konva-path');
fireEvent.click(paths[0]);
fireEvent.click(paths[1]);
fireEvent.click(screen.getByRole('button', { name: '从主区域去除' }));
const [primary] = useStore.getState().masks;
expect(primary).toEqual(expect.objectContaining({
id: 'm1',
area: 8400,
bbox: [10, 10, 100, 100],
metadata: expect.objectContaining({ hasHoles: true }),
}));
expect(primary.segmentation).toHaveLength(2);
expect(screen.getAllByTestId('konva-path')[0]).toHaveAttribute('data-fill-rule', 'evenodd');
});
it('creates a manual rectangle mask that can be undone and redone', () => {
useStore.setState({
activeTemplateId: '2',
@@ -329,6 +594,93 @@ describe('CanvasArea', () => {
expect(useStore.getState().masks).toHaveLength(1);
});
it('creates a manual circle mask from a drag gesture', () => {
render(<CanvasArea activeTool="create_circle" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
frameId: 'frame-1',
label: '手工圆形',
color: '#06b6d4',
saveStatus: 'draft',
bbox: [120, 80, 140, 120],
metadata: expect.objectContaining({
source: 'manual',
shape: '圆形',
}),
}));
expect(useStore.getState().masks[0].segmentation?.[0]).toHaveLength(64);
});
it('creates a manual line region from a drag gesture', () => {
render(<CanvasArea activeTool="create_line" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
frameId: 'frame-1',
label: '手工线段',
color: '#06b6d4',
saveStatus: 'draft',
metadata: expect.objectContaining({
source: 'manual',
shape: '线段',
}),
}));
expect(useStore.getState().masks[0].segmentation?.[0]).toHaveLength(8);
expect(useStore.getState().masks[0].area).toBeGreaterThan(1000);
});
it('creates an editable point region on click', () => {
render(<CanvasArea activeTool="create_point" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-stage'), { clientX: 120, clientY: 80 });
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
frameId: 'frame-1',
label: '手工点区域',
color: '#06b6d4',
saveStatus: 'draft',
points: [[120, 80]],
bbox: expect.arrayContaining([115, 75]),
metadata: expect.objectContaining({
source: 'manual',
shape: '点区域',
}),
}));
});
it('creates a point region when clicking over an existing mask', () => {
useStore.setState({
masks: [
{
id: 'm1',
frameId: 'frame-1',
pathData: 'M 10 10 L 200 10 L 200 200 Z',
label: 'Existing',
color: '#06b6d4',
segmentation: [[10, 10, 200, 10, 200, 200]],
},
],
});
render(<CanvasArea activeTool="create_point" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-path'), { clientX: 120, clientY: 80 });
expect(useStore.getState().masks).toHaveLength(2);
expect(useStore.getState().masks[1]).toEqual(expect.objectContaining({
metadata: expect.objectContaining({ shape: '点区域' }),
points: [[120, 80]],
}));
});
it('finalizes a clicked polygon with Enter', () => {
render(<CanvasArea activeTool="create_polygon" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
@@ -344,6 +696,29 @@ describe('CanvasArea', () => {
}));
});
it('closes a clicked polygon by clicking the first node again', () => {
render(<CanvasArea activeTool="create_polygon" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.click(stage, { clientX: 120, clientY: 80 });
fireEvent.click(stage, { clientX: 220, clientY: 80 });
fireEvent.click(stage, { clientX: 180, clientY: 160 });
const handles = screen.getAllByTestId('konva-circle');
expect(handles[0]).toHaveAttribute('data-fill', '#facc15');
fireEvent.click(handles[0]);
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
pathData: 'M 120 80 L 220 80 L 180 160 Z',
segmentation: [[120, 80, 220, 80, 180, 160]],
metadata: expect.objectContaining({
source: 'manual',
shape: '多边形',
}),
}));
expect(screen.queryAllByTestId('konva-circle')).toHaveLength(0);
});
it('applies the selected class to current-frame masks and marks saved masks dirty', () => {
useStore.setState({
activeTemplateId: '2',

View File

@@ -14,11 +14,14 @@ interface CanvasAreaProps {
}
type CanvasPoint = { x: number; y: number };
type PromptPoint = CanvasPoint & { type: 'pos' | 'neg' };
type PromptBox = { x1: number; y1: number; x2: number; y2: number };
const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle', 'create_line']);
const POLYGON_TOOL = 'create_polygon';
const POINT_TOOL = 'create_point';
const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']);
const POLYGON_CLOSE_RADIUS = 8;
function clamp(value: number, min: number, max: number): number {
return Math.min(Math.max(value, min), max);
@@ -88,6 +91,10 @@ function polygonArea(points: CanvasPoint[]): number {
return Math.abs(sum) / 2;
}
function pointDistance(a: CanvasPoint, b: CanvasPoint): number {
return Math.hypot(a.x - b.x, a.y - b.y);
}
function segmentationArea(segmentation?: number[][]): number {
return (segmentation || []).reduce((sum, polygon) => sum + polygonArea(flatPolygonToPoints(polygon)), 0);
}
@@ -115,20 +122,35 @@ function maskToMultiPolygon(mask: Mask): MultiPolygon | null {
return polygons.length > 0 ? polygons : null;
}
function openRingPoints(ring: Pair[]): CanvasPoint[] {
const openRing = ring.length > 1
&& ring[0][0] === ring[ring.length - 1][0]
&& ring[0][1] === ring[ring.length - 1][1]
? ring.slice(0, -1)
: ring;
return openRing.map(([x, y]) => ({ x, y }));
}
function multiPolygonToSegmentation(geometry: MultiPolygon): number[][] {
return geometry
.map((polygon) => polygon[0] || [])
.map((ring) => {
const openRing = ring.length > 1
&& ring[0][0] === ring[ring.length - 1][0]
&& ring[0][1] === ring[ring.length - 1][1]
? ring.slice(0, -1)
: ring;
return openRing.flatMap(([x, y]) => [x, y]);
})
.flatMap((polygon) => polygon)
.map((ring) => openRingPoints(ring).flatMap(({ x, y }) => [x, y]))
.filter((polygon) => polygon.length >= 6);
}
function multiPolygonArea(geometry: MultiPolygon): number {
return geometry.reduce((sum, polygon) => {
const [outerRing, ...holeRings] = polygon;
const outerArea = outerRing ? polygonArea(openRingPoints(outerRing)) : 0;
const holesArea = holeRings.reduce((holeSum, ring) => holeSum + polygonArea(openRingPoints(ring)), 0);
return sum + Math.max(outerArea - holesArea, 0);
}, 0);
}
function multiPolygonHasHoles(geometry: MultiPolygon): boolean {
return geometry.some((polygon) => polygon.length > 1);
}
function rectanglePoints(start: CanvasPoint, end: CanvasPoint): CanvasPoint[] {
const x1 = Math.min(start.x, end.x);
const y1 = Math.min(start.y, end.y);
@@ -179,10 +201,12 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const [stageSize, setStageSize] = useState({ width: 800, height: 600 });
const [scale, setScale] = useState(1);
const [position, setPosition] = useState({ x: 0, y: 0 });
const [points, setPoints] = useState<{ x: number, y: number, type: 'pos'|'neg' }[]>([]);
const [points, setPoints] = useState<PromptPoint[]>([]);
const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 });
const [boxStart, setBoxStart] = useState<{ x: number, y: number } | null>(null);
const [boxCurrent, setBoxCurrent] = useState<{ x: number, y: number } | null>(null);
const [samPromptBox, setSamPromptBox] = useState<PromptBox | null>(null);
const [samCandidateMaskId, setSamCandidateMaskId] = useState<string | null>(null);
const [manualStart, setManualStart] = useState<CanvasPoint | null>(null);
const [manualCurrent, setManualCurrent] = useState<CanvasPoint | null>(null);
const [polygonPoints, setPolygonPoints] = useState<CanvasPoint[]>([]);
@@ -191,12 +215,14 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const [selectedPolygonIndex, setSelectedPolygonIndex] = useState(0);
const [selectedVertexIndex, setSelectedVertexIndex] = useState<number | null>(null);
const [isInferencing, setIsInferencing] = useState(false);
const [inferenceMessage, setInferenceMessage] = useState('');
const masks = useStore((state) => state.masks);
const addMask = useStore((state) => state.addMask);
const updateMask = useStore((state) => state.updateMask);
const clearMasks = useStore((state) => state.clearMasks);
const setMasks = useStore((state) => state.setMasks);
const setGlobalSelectedMaskIds = useStore((state) => state.setSelectedMaskIds);
const storeActiveTool = useStore((state) => state.activeTool);
const aiModel = useStore((state) => state.aiModel);
const activeTemplateId = useStore((state) => state.activeTemplateId);
@@ -226,6 +252,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const savedMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'saved' || mask.saved).length;
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
const isBooleanTool = BOOLEAN_TOOLS.has(effectiveTool);
useEffect(() => {
const handleResize = () => {
@@ -252,6 +279,18 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
setSelectedVertexIndex(null);
}, [effectiveTool, frame?.id]);
useEffect(() => {
setPoints([]);
setSamPromptBox(null);
setSamCandidateMaskId(null);
}, [frame?.id]);
useEffect(() => {
setGlobalSelectedMaskIds(selectedMaskIds);
}, [selectedMaskIds, setGlobalSelectedMaskIds]);
useEffect(() => () => setGlobalSelectedMaskIds([]), [setGlobalSelectedMaskIds]);
useEffect(() => {
if (selectedMaskId && !frameMasks.some((mask) => mask.id === selectedMaskId)) {
setSelectedMaskId(null);
@@ -324,6 +363,12 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
addMask(mask);
}, [activeClass, activeTemplateId, addMask, frame?.id]);
const finishPolygon = useCallback(() => {
if (polygonPoints.length < 3) return;
createManualMask('多边形', polygonPoints);
setPolygonPoints([]);
}, [createManualMask, polygonPoints]);
const handleMouseMove = (e: any) => {
const stage = e.target.getStage();
if (!stage) return;
@@ -349,9 +394,18 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
}
};
const runInference = useCallback(async (promptPoints?: typeof points, promptBox?: { x1: number, y1: number, x2: number, y2: number }) => {
const runInference = useCallback(async (
promptPoints?: PromptPoint[],
promptBox?: PromptBox,
options: { resetCandidate?: boolean } = {},
) => {
if (!frame?.id) {
console.warn('Inference skipped: no active frame');
setInferenceMessage('请先选择一帧图像。');
return;
}
if (aiModel === 'sam3' && (!promptBox || (promptPoints?.length ?? 0) > 0)) {
setInferenceMessage('SAM3 当前工作区只支持框选提示;正/反点修正请切回 SAM2。');
return;
}
@@ -359,31 +413,44 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const imageHeight = frame.height || image?.naturalHeight || image?.height || 0;
if (imageWidth <= 0 || imageHeight <= 0) {
console.warn('Inference skipped: active frame dimensions are unavailable');
setInferenceMessage('当前帧缺少宽高信息,无法推理。');
return;
}
setIsInferencing(true);
setInferenceMessage('');
try {
const result = await predictMask({
imageId: frame.id,
imageWidth,
imageHeight,
model: aiModel,
points: promptPoints?.map((p) => ({ x: p.x, y: p.y, type: p.type })),
points: promptPoints && promptPoints.length > 0
? promptPoints.map((p) => ({ x: p.x, y: p.y, type: p.type }))
: undefined,
box: promptBox,
});
result.masks.forEach((m) => {
const label = activeClass?.name || m.label;
const color = activeClass?.color || m.color;
addMask({
id: m.id,
const [m] = result.masks;
if (m) {
const existingCandidate = !options.resetCandidate && samCandidateMaskId
? masks.find((mask) => mask.id === samCandidateMaskId)
: null;
const label = activeClass?.name || existingCandidate?.label || m.label;
const color = activeClass?.color || existingCandidate?.color || m.color;
const metadata = {
...(existingCandidate?.metadata || {}),
source: aiModel === 'sam3' ? 'sam3_box' : 'sam2_interactive',
promptBox: promptBox || null,
promptPointCount: promptPoints?.length || 0,
};
const nextMask = {
frameId: frame.id,
templateId: activeTemplateId || undefined,
classId: activeClass?.id,
className: activeClass?.name,
classZIndex: activeClass?.zIndex,
saveStatus: 'draft',
templateId: activeTemplateId || existingCandidate?.templateId || undefined,
classId: activeClass?.id || existingCandidate?.classId,
className: activeClass?.name || existingCandidate?.className,
classZIndex: activeClass?.zIndex ?? existingCandidate?.classZIndex,
saveStatus: existingCandidate?.annotationId ? 'dirty' as const : 'draft' as const,
saved: false,
pathData: m.pathData,
label,
@@ -392,14 +459,33 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
points: promptPoints?.filter((p) => p.type === 'pos').map((p) => [p.x, p.y]),
bbox: m.bbox,
area: m.area,
});
});
metadata,
};
if (existingCandidate) {
updateMask(existingCandidate.id, nextMask);
setSelectedMaskId(existingCandidate.id);
setSelectedMaskIds([existingCandidate.id]);
} else {
const id = m.id;
setSamCandidateMaskId(id);
setSelectedMaskId(id);
setSelectedMaskIds([id]);
addMask({
id,
...nextMask,
});
}
} else {
setInferenceMessage('模型没有返回可用区域,请调整点/框提示后重试。');
}
} catch (err) {
console.error('Inference failed:', err);
const detail = (err as any)?.response?.data?.detail;
setInferenceMessage(detail || 'AI 推理失败,请查看模型状态或后端日志。');
} finally {
setIsInferencing(false);
}
}, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width]);
}, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, masks, samCandidateMaskId, updateMask]);
const handleApplyActiveClass = () => {
if (!frame?.id || !activeClass) return;
@@ -427,6 +513,29 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
clearMasks();
};
const deleteMasksById = useCallback((maskIds: string[]) => {
if (maskIds.length === 0) return;
const idSet = new Set(maskIds);
const deletingMasks = masks.filter((mask) => idSet.has(mask.id));
if (deletingMasks.length === 0) return;
setMasks(masks.filter((mask) => !idSet.has(mask.id)));
const annotationIds = deletingMasks
.map((mask) => mask.annotationId)
.filter((annotationId): annotationId is string => Boolean(annotationId));
if (annotationIds.length > 0) {
void onDeleteMaskAnnotations?.(annotationIds);
}
if (samCandidateMaskId && idSet.has(samCandidateMaskId)) {
setSamCandidateMaskId(null);
setSamPromptBox(null);
setPoints([]);
}
setSelectedMaskId(null);
setSelectedMaskIds([]);
setSelectedPolygonIndex(0);
setSelectedVertexIndex(null);
}, [masks, onDeleteMaskAnnotations, samCandidateMaskId, setMasks]);
const handleStageMouseDown = (e: any) => {
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) {
const pos = stagePoint(e);
@@ -476,7 +585,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const y2 = Math.max(boxStart.y, boxCurrent.y);
if (Math.abs(x2 - x1) > 5 && Math.abs(y2 - y1) > 5) {
runInference(undefined, { x1, y1, x2, y2 });
const nextBox = { x1, y1, x2, y2 };
setPoints([]);
setSamPromptBox(nextBox);
setSamCandidateMaskId(null);
runInference([], nextBox, { resetCandidate: true });
}
setBoxStart(null);
@@ -500,6 +613,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
if (effectiveTool === POLYGON_TOOL) {
const pos = stagePoint(e);
if (pos) {
const closeRadius = POLYGON_CLOSE_RADIUS / Math.max(scale, 0.1);
if (polygonPoints.length >= 3 && pointDistance(pos, polygonPoints[0]) <= closeRadius) {
finishPolygon();
return;
}
setPolygonPoints((current) => [...current, pos]);
}
return;
@@ -514,8 +632,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
{ x: pos.x, y: pos.y, type: (effectiveTool === 'point_pos' ? 'pos' : 'neg') as 'pos' | 'neg' },
];
setPoints(newPoints);
// Auto-trigger inference after point selection
runInference(newPoints);
runInference(newPoints, samPromptBox || undefined);
}
}
};
@@ -535,14 +652,22 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
});
}, [updateMask]);
const updateMaskFromSegmentation = useCallback((mask: Mask, segmentation: number[][]): Mask => {
const updateMaskFromSegmentation = useCallback((
mask: Mask,
segmentation: number[][],
options: { area?: number; hasHoles?: boolean } = {},
): Mask => {
const bbox = segmentationBbox(segmentation);
const metadata = { ...(mask.metadata || {}) };
if (options.hasHoles === true) metadata.hasHoles = true;
if (options.hasHoles === false) delete metadata.hasHoles;
return {
...mask,
pathData: segmentationPath(segmentation),
segmentation,
bbox,
area: segmentationArea(segmentation),
area: options.area ?? segmentationArea(segmentation),
metadata,
saveStatus: mask.annotationId ? 'dirty' : 'draft',
saved: mask.annotationId ? false : mask.saved,
};
@@ -572,11 +697,16 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
}
return;
}
if ((event.key === 'Delete' || event.key === 'Backspace') && selectedMask) {
event.preventDefault();
const ids = selectedMaskIds.length > 0 ? selectedMaskIds : [selectedMask.id];
deleteMasksById(ids);
return;
}
if (effectiveTool !== POLYGON_TOOL) return;
if (event.key === 'Enter' && polygonPoints.length >= 3) {
event.preventDefault();
createManualMask('多边形', polygonPoints);
setPolygonPoints([]);
finishPolygon();
}
if (event.key === 'Escape') {
event.preventDefault();
@@ -586,7 +716,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
window.addEventListener('keydown', handleKeyDown);
return () => window.removeEventListener('keydown', handleKeyDown);
}, [createManualMask, effectiveTool, polygonPoints, redoMasks, selectedMask, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
}, [deleteMasksById, effectiveTool, finishPolygon, polygonPoints, redoMasks, selectedMask, selectedMaskIds, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
const boxRect = React.useMemo(() => {
if (!boxStart || !boxCurrent) return null;
@@ -623,8 +753,9 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
};
const handleMaskSelect = (mask: Mask, event: any, polygonIndex = 0) => {
if (effectiveTool !== 'move' && !isBooleanTool) return;
event.cancelBubble = true;
if (BOOLEAN_TOOLS.has(effectiveTool)) {
if (isBooleanTool) {
setSelectedMaskIds((current) => (
current.includes(mask.id)
? current.filter((id) => id !== mask.id)
@@ -703,7 +834,10 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
return;
}
const nextPrimary = updateMaskFromSegmentation(primary, resultSegmentation);
const nextPrimary = updateMaskFromSegmentation(primary, resultSegmentation, {
area: multiPolygonArea(resultGeometry),
hasHoles: multiPolygonHasHoles(resultGeometry),
});
const secondaryIds = effectiveTool === 'area_merge'
? new Set(booleanSelectedMasks.slice(1).map((mask) => mask.id))
: new Set<string>();
@@ -731,6 +865,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
<span className="text-xs text-cyan-400 font-mono">AI ...</span>
</div>
)}
{!isInferencing && inferenceMessage && (
<div className="absolute top-4 right-4 z-20 max-w-xs bg-[#111] border border-white/10 px-3 py-2 rounded-lg shadow-xl text-xs leading-relaxed text-gray-300">
{inferenceMessage}
</div>
)}
<Stage
width={stageSize.width}
@@ -758,21 +897,32 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
)}
{/* AI Returned Masks */}
{frameMasks.map((mask) => (
<Group key={mask.id} opacity={selectedMaskIds.includes(mask.id) ? 0.65 : 0.5}>
{(mask.segmentation && mask.segmentation.length > 0 ? mask.segmentation : [undefined]).map((_, polygonIndex) => (
<Path
key={`${mask.id}-polygon-${polygonIndex}`}
data={mask.segmentation ? segmentationPolygonPath(mask.segmentation, polygonIndex) : mask.pathData}
fill={mask.color}
stroke={mask.color}
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale}
onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
/>
))}
</Group>
))}
{frameMasks.map((mask) => {
const hasHoles = Boolean(mask.metadata?.hasHoles);
const paths = hasHoles
? [{ data: segmentationPath(mask.segmentation), polygonIndex: 0, fillRule: 'evenodd' }]
: (mask.segmentation && mask.segmentation.length > 0 ? mask.segmentation : [undefined]).map((_, polygonIndex) => ({
data: mask.segmentation ? segmentationPolygonPath(mask.segmentation, polygonIndex) : mask.pathData,
polygonIndex,
fillRule: undefined,
}));
return (
<Group key={mask.id} opacity={selectedMaskIds.includes(mask.id) ? 0.65 : 0.5}>
{paths.map(({ data, polygonIndex, fillRule }) => (
<Path
key={`${mask.id}-polygon-${polygonIndex}`}
data={data}
fill={mask.color}
fillRule={fillRule}
stroke={mask.color}
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale}
onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
/>
))}
</Group>
);
})}
{/* Box selection preview */}
{boxRect && effectiveTool === 'box_select' && (
@@ -804,10 +954,20 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
key={`poly-point-${index}`}
x={point.x}
y={point.y}
radius={4 / scale}
fill="#22d3ee"
stroke="#ffffff"
radius={(index === 0 && polygonPoints.length >= 3 ? 6 : 4) / scale}
fill={index === 0 && polygonPoints.length >= 3 ? '#facc15' : '#22d3ee'}
stroke={index === 0 && polygonPoints.length >= 3 ? '#fef3c7' : '#ffffff'}
strokeWidth={1 / scale}
onClick={(event: any) => {
if (index !== 0 || polygonPoints.length < 3) return;
event.cancelBubble = true;
finishPolygon();
}}
onTap={(event: any) => {
if (index !== 0 || polygonPoints.length < 3) return;
event.cancelBubble = true;
finishPolygon();
}}
/>
))}
@@ -827,7 +987,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
)))}
{/* Polygon edge insertion handles */}
{selectedMask && selectedMaskPoints.map((point, index) => {
{!isBooleanTool && selectedMask && selectedMaskPoints.map((point, index) => {
const next = selectedMaskPoints[(index + 1) % selectedMaskPoints.length];
if (!next) return null;
return (
@@ -846,7 +1006,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
})}
{/* Polygon vertex editor */}
{selectedMask && selectedMaskPoints.map((point, index) => (
{!isBooleanTool && selectedMask && selectedMaskPoints.map((point, index) => (
<Circle
key={`${selectedMask.id}-vertex-${selectedPolygonIndex}-${index}`}
x={point.x}
@@ -900,13 +1060,19 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
{frameMasks.length > 0 && (
<div className="absolute bottom-4 right-4 flex gap-2">
{BOOLEAN_TOOLS.has(effectiveTool) && booleanSelectedMasks.length >= 2 && (
<button
onClick={handleBooleanOperation}
className="text-xs bg-emerald-500/10 hover:bg-emerald-500/20 text-emerald-300 border border-emerald-500/20 px-3 py-1.5 rounded transition-colors"
>
{effectiveTool === 'area_merge' ? '合并选中' : '从主区域去除'}
</button>
{isBooleanTool && (
<div className="flex items-center gap-2">
<span className="text-xs bg-white/5 text-gray-300 border border-white/10 px-2.5 py-1.5 rounded">
{booleanSelectedMasks.length}
</span>
<button
onClick={handleBooleanOperation}
disabled={booleanSelectedMasks.length < 2}
className="text-xs bg-emerald-500/10 hover:bg-emerald-500/20 text-emerald-300 border border-emerald-500/20 px-3 py-1.5 rounded transition-colors disabled:opacity-40 disabled:cursor-not-allowed disabled:hover:bg-emerald-500/10"
>
{effectiveTool === 'area_merge' ? '合并选中' : '从主区域去除'}
</button>
</div>
)}
{activeClass && (
<button

View File

@@ -34,6 +34,65 @@ describe('FrameTimeline', () => {
expect(useStore.getState().currentFrameIndex).toBe(2);
});
it('shows current and total timeline time based on project fps', () => {
useStore.setState({
currentProject: { id: 'p1', name: 'P', status: 'ready', parse_fps: 10 },
currentFrameIndex: 1,
frames: [
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
],
});
render(<FrameTimeline />);
expect(screen.getAllByText('00:00.10').length).toBeGreaterThan(0);
expect(screen.getAllByText('00:00.20').length).toBeGreaterThan(0);
});
it('changes frames with left and right arrow keys without leaving bounds', () => {
useStore.setState({
currentFrameIndex: 1,
frames: [
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
],
});
render(<FrameTimeline />);
fireEvent.keyDown(window, { key: 'ArrowRight' });
expect(useStore.getState().currentFrameIndex).toBe(2);
fireEvent.keyDown(window, { key: 'ArrowRight' });
expect(useStore.getState().currentFrameIndex).toBe(2);
fireEvent.keyDown(window, { key: 'ArrowLeft' });
expect(useStore.getState().currentFrameIndex).toBe(1);
});
it('does not change frames while typing in editable fields', () => {
useStore.setState({
currentFrameIndex: 1,
frames: [
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
],
});
render(
<>
<input aria-label="annotation-name" />
<FrameTimeline />
</>,
);
fireEvent.keyDown(screen.getByLabelText('annotation-name'), { key: 'ArrowRight' });
expect(useStore.getState().currentFrameIndex).toBe(1);
});
it('plays forward using the project parse fps and stops at the end', () => {
vi.useFakeTimers();
useStore.setState({

View File

@@ -16,6 +16,20 @@ export function FrameTimeline() {
const fps = currentProject?.parse_fps || currentProject?.original_fps || 12;
return Math.min(Math.max(fps, 1), 30);
}, [currentProject?.original_fps, currentProject?.parse_fps]);
const timeBaseFps = useMemo(() => {
const fps = currentProject?.parse_fps || currentProject?.original_fps || 12;
return Math.max(fps, 1);
}, [currentProject?.original_fps, currentProject?.parse_fps]);
const currentSeconds = totalFrames > 0 ? currentFrameIndex / timeBaseFps : 0;
const totalSeconds = totalFrames > 0 ? Math.max(totalFrames - 1, 0) / timeBaseFps : 0;
const formatTime = (seconds: number) => {
const safeSeconds = Math.max(0, seconds);
const minutes = Math.floor(safeSeconds / 60);
const wholeSeconds = Math.floor(safeSeconds % 60);
const centiseconds = Math.floor((safeSeconds % 1) * 100);
return `${minutes.toString().padStart(2, '0')}:${wholeSeconds.toString().padStart(2, '0')}.${centiseconds.toString().padStart(2, '0')}`;
};
useEffect(() => {
if (!isPlaying || totalFrames <= 1) return;
@@ -38,6 +52,30 @@ export function FrameTimeline() {
}
}, [totalFrames]);
useEffect(() => {
const isEditableTarget = (target: EventTarget | null) => {
if (!(target instanceof HTMLElement)) return false;
const tagName = target.tagName.toLowerCase();
return target.isContentEditable || ['input', 'textarea', 'select'].includes(tagName);
};
const handleKeyDown = (event: KeyboardEvent) => {
if (isEditableTarget(event.target) || totalFrames <= 1) return;
if (event.key !== 'ArrowLeft' && event.key !== 'ArrowRight') return;
event.preventDefault();
setIsPlaying(false);
const direction = event.key === 'ArrowRight' ? 1 : -1;
const nextIndex = Math.min(Math.max(currentFrameIndex + direction, 0), totalFrames - 1);
if (nextIndex !== currentFrameIndex) {
setCurrentFrame(nextIndex);
}
};
window.addEventListener('keydown', handleKeyDown);
return () => window.removeEventListener('keydown', handleKeyDown);
}, [currentFrameIndex, setCurrentFrame, totalFrames]);
// show frames around current frame
const frameWindow = 20;
const displayIndices = totalFrames > 0
@@ -47,6 +85,12 @@ export function FrameTimeline() {
return (
<div className="h-32 bg-[#111] border-t border-white/5 flex flex-col shrink-0 z-20">
<div className="h-4 bg-[#0d0d0d] flex items-center group relative">
<div className="absolute left-3 -top-5 text-[10px] font-mono text-gray-500 pointer-events-none">
{formatTime(currentSeconds)}
</div>
<div className="absolute right-3 -top-5 text-[10px] font-mono text-gray-500 pointer-events-none">
{formatTime(totalSeconds)}
</div>
<input
type="range"
min="1"
@@ -65,6 +109,12 @@ export function FrameTimeline() {
className="w-3 h-3 bg-white rounded-full absolute top-1/2 -translate-y-1/2 -ml-1.5 shadow-sm transform scale-0 group-hover:scale-100 transition-transform shadow-cyan-500/50"
style={{ left: `${totalFrames > 0 ? (currentFrame / totalFrames) * 100 : 0}%` }}
/>
<div
className="absolute -top-7 -translate-x-1/2 rounded bg-black/80 border border-white/10 px-2 py-0.5 text-[10px] font-mono text-cyan-300 opacity-0 group-hover:opacity-100 transition-opacity pointer-events-none"
style={{ left: `${totalFrames > 0 ? (currentFrame / totalFrames) * 100 : 0}%` }}
>
{formatTime(currentSeconds)}
</div>
</div>
</div>
@@ -129,6 +179,9 @@ export function FrameTimeline() {
<div className="w-48 text-right shrink-0">
<div className="text-2xl font-mono text-white">{currentFrame}<span className="text-xs text-gray-500"> / {totalFrames}</span></div>
<div className="text-xs font-mono text-cyan-300 mt-1">
{formatTime(currentSeconds)} <span className="text-gray-600">/</span> {formatTime(totalSeconds)}
</div>
<div className="text-[10px] text-gray-500 uppercase tracking-widest mt-1"></div>
</div>
</div>

View File

@@ -45,6 +45,41 @@ describe('OntologyInspector', () => {
}));
});
it('applies the selected class to currently selected masks', () => {
useStore.setState({
selectedMaskIds: ['m1'],
masks: [
{
id: 'm1',
annotationId: '99',
frameId: 'frame-1',
pathData: 'M 0 0 Z',
label: '旧标签',
color: '#06b6d4',
saveStatus: 'saved',
saved: true,
},
],
});
render(<OntologyInspector />);
fireEvent.click(screen.getByText('肝脏'));
expect(useStore.getState().activeClassId).toBe('c2');
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
templateId: 't1',
classId: 'c2',
className: '肝脏',
classZIndex: 10,
label: '肝脏',
color: '#00ff00',
saveStatus: 'dirty',
saved: false,
}));
expect(screen.getByText('当前选中区域:')).toBeInTheDocument();
expect(screen.getByText('1')).toBeInTheDocument();
});
it('adds custom classes locally without backend persistence', () => {
const { container } = render(<OntologyInspector />);
const customSection = screen.getByText('自定义分类').parentElement!;

View File

@@ -10,6 +10,9 @@ export function OntologyInspector() {
const activeTemplateId = useStore((state) => state.activeTemplateId);
const activeClassId = useStore((state) => state.activeClassId);
const activeClass = useStore((state) => state.activeClass);
const masks = useStore((state) => state.masks);
const selectedMaskIds = useStore((state) => state.selectedMaskIds);
const setMasks = useStore((state) => state.setMasks);
const setActiveTemplateId = useStore((state) => state.setActiveTemplateId);
const setActiveClass = useStore((state) => state.setActiveClass);
@@ -28,6 +31,25 @@ export function OntologyInspector() {
setActiveTemplateId(activeTemplate.id);
}
setActiveClass(templateClass);
const selectedIdSet = new Set(selectedMaskIds);
const hasSelectedMasks = masks.some((mask) => selectedIdSet.has(mask.id));
if (!hasSelectedMasks) return;
const templateId = activeTemplate?.id || activeTemplateId || undefined;
setMasks(masks.map((mask) => {
if (!selectedIdSet.has(mask.id)) return mask;
return {
...mask,
templateId: templateId || mask.templateId,
classId: templateClass.id,
className: templateClass.name,
classZIndex: templateClass.zIndex,
label: templateClass.name,
color: templateClass.color,
saveStatus: mask.annotationId ? 'dirty' : 'draft',
saved: mask.annotationId ? false : mask.saved,
};
}));
};
const handleAddCustom = () => {
@@ -164,6 +186,10 @@ export function OntologyInspector() {
</span>
</div>
<div className="space-y-3">
<div className="flex items-center justify-between">
<span className="text-[10px] text-gray-500 uppercase">:</span>
<span className="text-xs font-mono text-gray-300">{selectedMaskIds.length}</span>
</div>
<div className="space-y-1">
<label className="text-[10px] text-gray-500 uppercase"></label>
<div className="h-1.5 w-full bg-white/10 rounded-full overflow-hidden">

View File

@@ -82,4 +82,65 @@ describe('TemplateRegistry', () => {
expect(screen.getByText('分类A')).toBeInTheDocument();
});
it('edits an existing template through the backend and store', async () => {
apiMock.getTemplates.mockResolvedValueOnce([
{
id: 't1',
name: '旧模板',
description: 'old desc',
classes: [{ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 10, category: '器官' }],
rules: [],
color: '#06b6d4',
z_index: 3,
},
]);
apiMock.updateTemplate.mockResolvedValueOnce({
id: 't1',
name: '新模板',
description: 'new desc',
classes: [{ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 10, category: '器官' }],
rules: [],
});
render(<TemplateRegistry />);
fireEvent.click(await screen.findByRole('button', { name: /修改库视图结构/ }));
fireEvent.change(screen.getAllByRole('textbox')[0], { target: { value: '新模板' } });
fireEvent.change(screen.getAllByRole('textbox')[1], { target: { value: 'new desc' } });
fireEvent.click(screen.getByRole('button', { name: '保存' }));
await waitFor(() => expect(apiMock.updateTemplate).toHaveBeenCalledWith('t1', expect.objectContaining({
name: '新模板',
description: 'new desc',
classes: [expect.objectContaining({ id: 'c1', name: '胆囊' })],
rules: [],
color: '#06b6d4',
z_index: 3,
})));
expect(useStore.getState().templates[0]).toEqual(expect.objectContaining({
id: 't1',
name: '新模板',
}));
});
it('deletes an existing template after confirmation', async () => {
apiMock.getTemplates.mockResolvedValueOnce([
{
id: 't1',
name: '待删除模板',
description: 'desc',
classes: [],
rules: [],
},
]);
apiMock.deleteTemplate.mockResolvedValueOnce(undefined);
const { container } = render(<TemplateRegistry />);
await screen.findAllByText('待删除模板');
const buttons = Array.from(container.querySelectorAll('button'));
fireEvent.click(buttons[2]);
await waitFor(() => expect(apiMock.deleteTemplate).toHaveBeenCalledWith('t1'));
expect(useStore.getState().templates).toEqual([]);
});
});

View File

@@ -7,6 +7,7 @@ import { VideoWorkspace } from './VideoWorkspace';
const apiMock = vi.hoisted(() => ({
getProjectFrames: vi.fn(),
parseMedia: vi.fn(),
propagateMasks: vi.fn(),
getTask: vi.fn(),
getTemplates: vi.fn(),
getProjectAnnotations: vi.fn(),
@@ -24,6 +25,7 @@ const apiMock = vi.hoisted(() => ({
vi.mock('../lib/api', () => ({
getProjectFrames: apiMock.getProjectFrames,
parseMedia: apiMock.parseMedia,
propagateMasks: apiMock.propagateMasks,
getTask: apiMock.getTask,
getTemplates: apiMock.getTemplates,
getProjectAnnotations: apiMock.getProjectAnnotations,
@@ -47,6 +49,14 @@ describe('VideoWorkspace', () => {
apiMock.getProjectAnnotations.mockResolvedValue([]);
apiMock.annotationToMask.mockReturnValue(null);
apiMock.getTask.mockResolvedValue({ id: 1, status: 'success', progress: 100, message: '解析完成' });
apiMock.propagateMasks.mockResolvedValue({
model: 'sam2',
direction: 'forward',
source_frame_id: 10,
processed_frame_count: 3,
created_annotation_count: 2,
annotations: [],
});
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam2',
gpu: { available: false, device: 'cpu', name: null, torch_available: true },
@@ -320,4 +330,64 @@ describe('VideoWorkspace', () => {
expect.objectContaining({ id: 'annotation-88', label: 'GT Mask' }),
]));
});
it('propagates the selected current-frame mask through the backend video tracker', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
]);
apiMock.buildAnnotationPayload.mockReturnValueOnce({
project_id: 1,
frame_id: 10,
template_id: 2,
mask_data: {
polygons: [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]],
label: '胆囊',
color: '#ff0000',
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
},
bbox: [0.1, 0.1, 0.2, 0.2],
});
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(2));
act(() => {
useStore.setState({
aiModel: 'sam2',
activeTemplateId: '2',
selectedMaskIds: ['mask-1'],
masks: [{
id: 'mask-1',
frameId: '10',
pathData: 'M 0 0 Z',
label: '胆囊',
color: '#ff0000',
segmentation: [[64, 36, 192, 36, 192, 108]],
bbox: [64, 36, 128, 72],
}],
});
});
fireEvent.click(screen.getByRole('button', { name: '传播片段' }));
await waitFor(() => expect(apiMock.propagateMasks).toHaveBeenCalledWith({
project_id: 1,
frame_id: 10,
model: 'sam2',
direction: 'forward',
max_frames: 30,
include_source: false,
save_annotations: true,
seed: {
polygons: [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]],
bbox: [0.1, 0.1, 0.2, 0.2],
points: undefined,
label: '胆囊',
color: '#ff0000',
class_metadata: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
template_id: 2,
},
}));
await waitFor(() => expect(screen.getByText('已传播并保存 2 个区域')).toBeInTheDocument());
});
});

View File

@@ -12,6 +12,7 @@ import {
getTemplates,
importGtMask,
parseMedia,
propagateMasks,
saveAnnotation,
updateAnnotation,
} from '../lib/api';
@@ -37,6 +38,8 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
const maskHistory = useStore((state) => state.maskHistory);
const maskFuture = useStore((state) => state.maskFuture);
const activeTemplateId = useStore((state) => state.activeTemplateId);
const aiModel = useStore((state) => state.aiModel);
const selectedMaskIds = useStore((state) => state.selectedMaskIds);
const setFrames = useStore((state) => state.setFrames);
const setCurrentFrame = useStore((state) => state.setCurrentFrame);
const setMasks = useStore((state) => state.setMasks);
@@ -45,6 +48,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
const [isSaving, setIsSaving] = useState(false);
const [isExporting, setIsExporting] = useState(false);
const [isImportingGt, setIsImportingGt] = useState(false);
const [isPropagating, setIsPropagating] = useState(false);
const [statusMessage, setStatusMessage] = useState('');
const hydrateSavedAnnotations = useCallback(async (projectId: string, projectFrames: Frame[]) => {
@@ -102,6 +106,8 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
url: f.image_url,
width: f.width ?? 0,
height: f.height ?? 0,
timestampMs: f.timestamp_ms ?? undefined,
sourceFrameNumber: f.source_frame_number ?? undefined,
}));
setFrames(mappedFrames);
setCurrentFrame(0);
@@ -117,6 +123,8 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
url: f.image_url,
width: f.width ?? 0,
height: f.height ?? 0,
timestampMs: f.timestamp_ms ?? undefined,
sourceFrameNumber: f.source_frame_number ?? undefined,
}));
setFrames(mappedFrames);
setCurrentFrame(0);
@@ -314,6 +322,55 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
}
};
const handlePropagateSegment = async () => {
if (!currentProject?.id || !currentFrame?.id) return;
const currentFrameMasks = masks.filter((mask) => mask.frameId === currentFrame.id);
const selectedMask = selectedMaskIds
.map((id) => currentFrameMasks.find((mask) => mask.id === id))
.find((mask): mask is NonNullable<typeof mask> => Boolean(mask));
const seedMask = selectedMask || currentFrameMasks[0];
if (!seedMask) {
setStatusMessage('请先选择或创建一个当前帧区域');
return;
}
const seedPayload = buildAnnotationPayload(currentProject.id, seedMask, currentFrame, activeTemplateId);
if (!seedPayload?.mask_data?.polygons?.length && !seedPayload?.bbox) {
setStatusMessage('当前区域缺少可传播的 polygon 或 bbox');
return;
}
setIsPropagating(true);
setStatusMessage(`${aiModel.toUpperCase()} 正在传播当前区域...`);
try {
const result = await propagateMasks({
project_id: Number(currentProject.id),
frame_id: Number(currentFrame.id),
model: aiModel,
direction: 'forward',
max_frames: 30,
include_source: false,
save_annotations: true,
seed: {
polygons: seedPayload.mask_data?.polygons,
bbox: seedPayload.bbox,
points: seedPayload.points,
label: seedPayload.mask_data?.label,
color: seedPayload.mask_data?.color,
class_metadata: seedPayload.mask_data?.class,
template_id: seedPayload.template_id,
},
});
await hydrateSavedAnnotations(currentProject.id, frames);
setStatusMessage(`已传播并保存 ${result.created_annotation_count} 个区域`);
} catch (err) {
console.error('Propagation failed:', err);
setStatusMessage('传播失败,请检查模型状态或后端日志');
} finally {
setIsPropagating(false);
}
};
return (
<div className="w-full h-full flex flex-col bg-[#0a0a0a]">
{/* Top Header / Status bar */}
@@ -339,28 +396,35 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
/>
<button
onClick={() => gtMaskInputRef.current?.click()}
disabled={!currentProject?.id || !currentFrame?.id || isImportingGt || isSaving || isExporting}
disabled={!currentProject?.id || !currentFrame?.id || isImportingGt || isSaving || isExporting || isPropagating}
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
>
{isImportingGt ? '导入中...' : '导入 GT Mask'}
</button>
<button
onClick={handlePropagateSegment}
disabled={!currentProject?.id || !currentFrame?.id || isSaving || isExporting || isImportingGt || isPropagating}
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
>
{isPropagating ? '传播中...' : '传播片段'}
</button>
<button
onClick={handleExportMasks}
disabled={!currentProject?.id || isExporting || isSaving}
disabled={!currentProject?.id || isExporting || isSaving || isPropagating}
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
>
{isExporting ? '导出中...' : '导出 PNG Mask ZIP'}
</button>
<button
onClick={handleExport}
disabled={!currentProject?.id || isExporting || isSaving}
disabled={!currentProject?.id || isExporting || isSaving || isPropagating}
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
>
{isExporting ? '导出中...' : '导出 JSON 标注集'}
</button>
<button
onClick={handleSave}
disabled={!currentProject?.id || isSaving || isExporting}
disabled={!currentProject?.id || isSaving || isExporting || isPropagating}
className="px-4 py-1.5 bg-cyan-600 hover:bg-cyan-500 text-white text-xs font-medium rounded-md transition-shadow shadow-lg shadow-cyan-900/20 disabled:opacity-40 disabled:cursor-not-allowed"
>
{isSaving ? '保存中...' : '结构化归档保存'}

View File

@@ -159,9 +159,9 @@ describe('api client contracts', () => {
axiosMock.client.post.mockResolvedValueOnce({ data: { ...task, status: 'cancelled', progress: 100 } });
axiosMock.client.post.mockResolvedValueOnce({ data: { ...task, id: 13, status: 'queued', progress: 0 } });
await expect(parseMedia('9')).resolves.toEqual(task);
await expect(parseMedia('9', { parseFps: 15, maxFrames: 120, targetWidth: 960 })).resolves.toEqual(task);
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/media/parse', null, {
params: { project_id: '9' },
params: { project_id: '9', parse_fps: 15, max_frames: 120, target_width: 960 },
});
await expect(getTask(12)).resolves.toEqual(expect.objectContaining({ status: 'success', progress: 100 }));
@@ -175,7 +175,7 @@ describe('api client contracts', () => {
});
it('lists, saves, updates, and deletes annotations with the backend annotation contract', async () => {
const { deleteAnnotation, getProjectAnnotations, saveAnnotation, updateAnnotation } = await import('./api');
const { deleteAnnotation, getProjectAnnotations, propagateMasks, saveAnnotation, updateAnnotation } = await import('./api');
const saved = {
id: 1,
project_id: 9,
@@ -221,6 +221,43 @@ describe('api client contracts', () => {
axiosMock.client.delete.mockResolvedValueOnce({ data: null });
await expect(deleteAnnotation('1')).resolves.toBeUndefined();
expect(axiosMock.client.delete).toHaveBeenCalledWith('/api/ai/annotations/1');
axiosMock.client.post.mockResolvedValueOnce({
data: {
model: 'sam2',
direction: 'forward',
source_frame_id: 5,
processed_frame_count: 3,
created_annotation_count: 2,
annotations: [saved],
},
});
await expect(propagateMasks({
project_id: 9,
frame_id: 5,
model: 'sam2',
seed: {
polygons: [[[0, 0], [1, 0], [1, 1]]],
label: 'mask',
color: '#06b6d4',
},
direction: 'forward',
max_frames: 30,
})).resolves.toEqual(expect.objectContaining({ created_annotation_count: 2 }));
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/propagate', {
project_id: 9,
frame_id: 5,
model: 'sam2',
seed: {
polygons: [[[0, 0], [1, 0], [1, 1]]],
label: 'mask',
color: '#06b6d4',
},
direction: 'forward',
max_frames: 30,
}, {
timeout: 600000,
});
});
it('imports GT masks through multipart form data', async () => {
@@ -377,6 +414,33 @@ describe('api client contracts', () => {
});
});
it('normalizes combined box and point prompts for interactive SAM2 refinement', async () => {
const { predictMask } = await import('./api');
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
await predictMask({
imageId: '5',
imageWidth: 640,
imageHeight: 320,
box: { x1: 64, y1: 32, x2: 320, y2: 160 },
points: [
{ x: 128, y: 64, type: 'pos' },
{ x: 256, y: 128, type: 'neg' },
],
});
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
image_id: 5,
prompt_type: 'interactive',
prompt_data: {
box: [0.1, 0.1, 0.5, 0.5],
points: [[0.2, 0.2], [0.4, 0.4]],
labels: [1, 0],
},
model: 'sam2',
});
});
it('uses semantic prompt type for text-only AI prediction', async () => {
const { predictMask } = await import('./api');
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });

View File

@@ -153,6 +153,8 @@ export async function getProjectFrames(projectId: string): Promise<Array<{
image_url: string;
width: number | null;
height: number | null;
timestamp_ms?: number | null;
source_frame_number?: number | null;
}>> {
const response = await apiClient.get(`/api/projects/${projectId}/frames`);
return response.data;
@@ -185,9 +187,18 @@ export interface ProcessingTask {
updated_at: string;
}
export async function parseMedia(projectId: string): Promise<ProcessingTask> {
export async function parseMedia(projectId: string, options: {
parseFps?: number;
maxFrames?: number;
targetWidth?: number;
} = {}): Promise<ProcessingTask> {
const response = await apiClient.post('/api/media/parse', null, {
params: { project_id: projectId },
params: {
project_id: projectId,
...(options.parseFps ? { parse_fps: options.parseFps } : {}),
...(options.maxFrames ? { max_frames: options.maxFrames } : {}),
...(options.targetWidth ? { target_width: options.targetWidth } : {}),
},
});
return response.data;
}
@@ -312,6 +323,40 @@ export interface SaveAnnotationPayload {
export type UpdateAnnotationPayload = Omit<SaveAnnotationPayload, 'project_id' | 'frame_id'>;
export interface PropagateMasksPayload {
project_id: number;
frame_id: number;
model?: AiModelId;
seed: {
polygons?: number[][][];
bbox?: number[];
points?: number[][];
label?: string;
color?: string;
class_metadata?: {
id?: string;
name?: string;
color?: string;
zIndex?: number;
category?: string;
};
template_id?: number;
};
direction?: 'forward' | 'backward' | 'both';
max_frames?: number;
include_source?: boolean;
save_annotations?: boolean;
}
export interface PropagateMasksResult {
model: AiModelId;
direction: string;
source_frame_id: number;
processed_frame_count: number;
created_annotation_count: number;
annotations: SavedAnnotation[];
}
export interface DashboardTask {
id: string;
task_id?: number;
@@ -474,10 +519,22 @@ export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mas
}
export async function predictMask(payload: PredictMaskPayload): Promise<PredictMaskResult> {
let prompt_type: 'point' | 'box' | 'semantic';
let prompt_type: 'point' | 'box' | 'semantic' | 'interactive';
let prompt_data: unknown;
if (payload.box) {
if (payload.box && payload.points && payload.points.length > 0) {
prompt_type = 'interactive';
prompt_data = {
box: [
clamp01(payload.box.x1 / Math.max(payload.imageWidth, 1)),
clamp01(payload.box.y1 / Math.max(payload.imageHeight, 1)),
clamp01(payload.box.x2 / Math.max(payload.imageWidth, 1)),
clamp01(payload.box.y2 / Math.max(payload.imageHeight, 1)),
],
points: payload.points.map((point) => normalizePoint(point, payload.imageWidth, payload.imageHeight)),
labels: payload.points.map((point) => (point.type === 'neg' ? 0 : 1)),
};
} else if (payload.box) {
prompt_type = 'box';
prompt_data = [
clamp01(payload.box.x1 / Math.max(payload.imageWidth, 1)),
@@ -540,6 +597,13 @@ export async function getProjectAnnotations(projectId: string, frameId?: string)
return response.data;
}
export async function propagateMasks(payload: PropagateMasksPayload): Promise<PropagateMasksResult> {
const response = await apiClient.post('/api/ai/propagate', payload, {
timeout: 600000,
});
return response.data;
}
export async function saveAnnotation(payload: SaveAnnotationPayload): Promise<SavedAnnotation> {
const response = await apiClient.post('/api/ai/annotate', payload);
return response.data;

View File

@@ -30,6 +30,7 @@ describe('useStore', () => {
useStore.getState().setFrames([{ id: 'f1', projectId: '1', index: 0, url: '/f1.jpg', width: 640, height: 360 }]);
useStore.getState().setCurrentFrame(0);
useStore.getState().addMask({ id: 'm1', frameId: 'f1', pathData: 'M 0 0 Z', label: 'mask', color: '#fff' });
useStore.getState().setSelectedMaskIds(['m1']);
useStore.getState().updateMask('m1', { label: 'updated mask', saveStatus: 'dirty' });
useStore.getState().addAnnotation({ id: 'a1', frameId: 'f1', type: 'mask', points: [], label: 'ann', color: '#fff' });
useStore.getState().addTemplate({ id: 't1', name: 'Template', classes: [], rules: [] });
@@ -40,6 +41,7 @@ describe('useStore', () => {
expect(useStore.getState().currentProject?.id).toBe('1');
expect(useStore.getState().frames).toHaveLength(1);
expect(useStore.getState().currentFrameIndex).toBe(0);
expect(useStore.getState().selectedMaskIds).toEqual(['m1']);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ label: 'updated mask', saveStatus: 'dirty' }));
expect(useStore.getState().annotations).toHaveLength(1);
expect(useStore.getState().templates[0].name).toBe('Template 2');
@@ -51,6 +53,7 @@ describe('useStore', () => {
expect(useStore.getState().annotations).toEqual([]);
expect(useStore.getState().masks).toEqual([]);
expect(useStore.getState().selectedMaskIds).toEqual([]);
expect(useStore.getState().templates).toEqual([]);
});

View File

@@ -27,6 +27,8 @@ export interface Frame {
width: number;
height: number;
timestamp?: string;
timestampMs?: number;
sourceFrameNumber?: number;
}
export interface Annotation {
@@ -112,6 +114,7 @@ export interface AppState {
currentFrameIndex: number;
annotations: Annotation[];
masks: Mask[];
selectedMaskIds: string[];
maskHistory: Mask[][];
maskFuture: Mask[][];
setActiveModule: (module: string) => void;
@@ -123,6 +126,7 @@ export interface AppState {
addMask: (mask: Mask) => void;
updateMask: (id: string, updates: Partial<Mask>) => void;
setMasks: (masks: Mask[]) => void;
setSelectedMaskIds: (ids: string[]) => void;
clearMasks: () => void;
undoMasks: () => void;
redoMasks: () => void;
@@ -167,6 +171,7 @@ export const useStore = create<AppState>((set) => ({
frames: [],
annotations: [],
masks: [],
selectedMaskIds: [],
maskHistory: [],
maskFuture: [],
activeTemplateId: null,
@@ -195,6 +200,7 @@ export const useStore = create<AppState>((set) => ({
currentFrameIndex: 0,
annotations: [],
masks: [],
selectedMaskIds: [],
maskHistory: [],
maskFuture: [],
setActiveModule: (activeModule: string) => set({ activeModule }),
@@ -227,9 +233,11 @@ export const useStore = create<AppState>((set) => ({
maskFuture: [],
};
}),
setSelectedMaskIds: (selectedMaskIds: string[]) => set({ selectedMaskIds }),
clearMasks: () =>
set((state) => ({
masks: [],
selectedMaskIds: [],
maskHistory: [...state.maskHistory, state.masks],
maskFuture: [],
})),

View File

@@ -71,7 +71,11 @@ vi.mock('react-konva', () => ({
data-fill={props.fill}
data-x={props.x}
data-y={props.y}
onClick={() => props.onClick?.({ cancelBubble: false })}
onClick={(event) => {
const konvaEvent = { cancelBubble: false };
props.onClick?.(konvaEvent);
if (konvaEvent.cancelBubble) event.stopPropagation();
}}
onMouseUp={(event: React.MouseEvent<HTMLSpanElement>) => props.onDragEnd?.({
target: {
x: () => event.clientX || props.x || 0,
@@ -92,7 +96,12 @@ vi.mock('react-konva', () => ({
data-testid="konva-path"
data-path={props.data}
data-fill={props.fill}
onClick={() => props.onClick?.({ cancelBubble: false })}
data-fill-rule={props.fillRule}
onClick={(event) => {
const konvaEvent = { cancelBubble: false };
props.onClick?.(konvaEvent);
if (konvaEvent.cancelBubble) event.stopPropagation();
}}
/>
),
}));

View File

@@ -13,6 +13,7 @@ export function resetStore() {
currentFrameIndex: 0,
annotations: [],
masks: [],
selectedMaskIds: [],
maskHistory: [],
maskFuture: [],
templates: [],