From 5ab460253527690c23d3746ba603e5fa3a0e2fb3 Mon Sep 17 00:00:00 2001 From: admin <572701190@qq.com> Date: Fri, 1 May 2026 20:27:33 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=8C=E5=96=84=E8=A7=86=E9=A2=91?= =?UTF-8?q?=E4=BC=A0=E6=92=AD=E3=80=81=E6=A0=87=E6=B3=A8=E7=BC=96=E8=BE=91?= =?UTF-8?q?=E5=92=8C=E6=8B=86=E5=B8=A7=E9=97=AD=E7=8E=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 接入 SAM2 视频传播能力:新增 /api/ai/propagate,支持用当前帧 mask/polygon/bbox 作为 seed,通过 SAM2 video predictor 向前、向后或双向传播,并可保存为真实 annotation。 - 接入 SAM3 video tracker:通过独立 Python 3.12 external worker 调用 SAM3 video predictor/tracker,使用本地 checkpoint 与 bbox seed 执行视频级跟踪,并在模型状态中标记 video_track 能力。 - 完善 SAM 模型分发:sam_registry 按 model_id 明确区分 sam2 propagation 与 sam3 video_track,避免两个模型链路混用。 - 打通前端“传播片段”:VideoWorkspace 使用当前选中 mask 和当前 AI 模型调用后端传播接口,传播结果回写并刷新工作区已保存标注。 - 增强 SAM3 本地 checkpoint 配置:新增 sam3_checkpoint_path 配置和 .env.example 示例,状态检查改为基于本地 checkpoint/独立环境/模型包可用性。 - 完善视频拆帧参数:/api/media/parse 支持 parse_fps、max_frames、target_width,后端任务保存帧时间戳、源帧号和 frame_sequence 元数据。 - 增加运行时 schema 兼容处理:启动时为旧 frames 表补充 timestamp_ms 和 source_frame_number 列,避免旧库升级后缺字段。 - 强化 Canvas 标注编辑:补齐多边形闭合、点工具、顶点拖拽、边中点插入、Delete/Backspace 删除、区域合并和重叠去除等交互。 - 增强语义分类联动:选中 mask 后可通过右侧语义分类树更新标签、颜色和 class metadata,并同步到保存/导出链路。 - 增加关键帧时间轴体验:FrameTimeline 显示具体时间信息,并支持键盘左右方向键切换关键帧。 - 完善 AI 交互分割参数:前端保留正向点、反向点、框选和 interactive prompt 的调用状态,支持 SAM2 细化候选区域与 SAM3 bbox 入口。 - 扩展后端/前端 API 类型:新增 propagateMasks、传播请求/响应 schema,并补齐 annotation、导出、模型状态和任务接口的测试覆盖。 - 更新项目文档:同步 README、AGENTS、接口契约、需求冻结、设计冻结、前端元素审计、实施计划和测试计划,标明真实功能边界与剩余风险。 - 增加测试覆盖:补充 SAM2/SAM3 传播、SAM3 状态、媒体拆帧参数、Canvas 编辑、语义标签切换、时间轴、工作区传播和 API 合约测试。 - 加强仓库安全边界:将 sam3权重/ 加入 .gitignore,避免本地模型权重被误提交。 验证:npm run test:run;pytest backend/tests;npm run lint;npm run build;python -m py_compile;git diff --check。 --- .env.example | 3 +- .gitignore | 1 + AGENTS.md | 29 +- README.md | 20 +- backend/config.py | 1 + backend/main.py | 20 ++ backend/models.py | 2 + backend/routers/ai.py | 210 ++++++++++++ backend/routers/media.py | 14 +- backend/schemas.py | 33 ++ backend/services/frame_parser.py | 13 +- backend/services/media_task_runner.py | 64 +++- backend/services/sam2_engine.py | 206 +++++++++++- backend/services/sam3_engine.py | 188 +++++++++-- backend/services/sam3_external_worker.py | 191 +++++++++-- backend/services/sam_registry.py | 30 ++ backend/tests/test_ai.py | 112 +++++++ backend/tests/test_media.py | 54 +++- backend/tests/test_sam3_engine.py | 135 +++++++- doc/01-purpose-and-word-summary.md | 6 +- doc/03-frontend-element-audit.md | 21 +- doc/04-api-contracts.md | 60 +++- doc/05-implementation-plan.md | 20 +- doc/07-current-requirements-freeze.md | 23 +- doc/08-current-design-freeze.md | 89 +++-- doc/09-test-plan.md | 45 ++- src/components/AISegmentation.test.tsx | 110 ++++++- src/components/AISegmentation.tsx | 30 +- src/components/CanvasArea.test.tsx | 375 ++++++++++++++++++++++ src/components/CanvasArea.tsx | 290 +++++++++++++---- src/components/FrameTimeline.test.tsx | 59 ++++ src/components/FrameTimeline.tsx | 53 +++ src/components/OntologyInspector.test.tsx | 35 ++ src/components/OntologyInspector.tsx | 26 ++ src/components/TemplateRegistry.test.tsx | 61 ++++ src/components/VideoWorkspace.test.tsx | 70 ++++ src/components/VideoWorkspace.tsx | 72 ++++- src/lib/api.test.ts | 70 +++- src/lib/api.ts | 72 ++++- src/store/useStore.test.ts | 3 + src/store/useStore.ts | 8 + src/test/setup.tsx | 13 +- src/test/storeTestUtils.ts | 1 + 43 files changed, 2722 insertions(+), 216 deletions(-) diff --git a/.env.example b/.env.example index 5de43e1..231f2f5 100644 --- a/.env.example +++ b/.env.example @@ -19,4 +19,5 @@ VITE_WS_PROGRESS_URL="ws://192.168.3.11:8000/ws/progress" sam_default_model="sam2" sam_model_path="/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt" sam_model_config="configs/sam2/sam2_hiera_t.yaml" -sam3_model_version="sam3.1" +sam3_model_version="sam3" +sam3_checkpoint_path="/home/wkmgc/Desktop/Seg_Server/sam3权重/sam3.pt" diff --git a/.gitignore b/.gitignore index c399003..ae1883c 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ coverage/ !.env.example # Data & Models models/ +sam3权重/ uploads/ frames/ minio_data/ diff --git a/AGENTS.md b/AGENTS.md index c2894f0..e7e8cc7 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -6,7 +6,7 @@ ## 项目概述 -本项目是一个**语义分割系统**(Semantic Segmentation System),当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、服务器端拆帧、交互式 Canvas 标注、GT mask 导入、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出。 +本项目是一个**语义分割系统**(Semantic Segmentation System),当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、服务器端拆帧、交互式 Canvas 标注、视频片段传播、GT mask 导入、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出。 - **项目名称**: `react-example`(`package.json` 中的 `name`) - **前端入口**: `src/main.tsx` → `src/App.tsx` @@ -39,7 +39,7 @@ | 缓存 / 队列 Broker | Redis | | 后台任务 | Celery worker | | 对象存储 | MinIO | -| AI 推理 | SAM 2 / SAM 3 可选模型 + PyTorch;SAM 3 通过独立 Python 3.12 conda 环境桥接;`GET /api/ai/models/status` 返回真实 GPU/模型/HF 权重访问状态 | +| AI 推理 | SAM 2 / SAM 3 可选模型 + PyTorch;SAM 3 通过独立 Python 3.12 conda 环境桥接;`GET /api/ai/models/status` 返回真实 GPU/模型/本地 checkpoint 状态 | | 视频 / 影像处理 | FFmpeg / OpenCV / pydicom | | 运行时 | Node.js ES Modules;Python 3.11 后端环境;可选 `sam3` Python 3.12 conda 环境 | @@ -78,12 +78,12 @@ Seg_Server/ │ │ ├── projects.py # /api/projects 与 /api/projects/{id}/frames │ │ ├── templates.py # /api/templates │ │ ├── media.py # /api/media/upload、/upload/dicom、/parse -│ │ ├── ai.py # /api/ai/predict、/models/status、/auto、/annotate +│ │ ├── ai.py # /api/ai/predict、/propagate、/models/status、/auto、/annotate │ │ └── export.py # /api/export/{project_id}/coco、/masks │ └── services/ │ ├── frame_parser.py # FFmpeg/OpenCV 拆帧、pydicom 读片、帧上传 -│ ├── sam2_engine.py # SAM 2 懒加载推理封装和 fallback -│ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接与文本语义推理适配器 +│ ├── sam2_engine.py # SAM 2 单帧推理和 video predictor 传播封装 +│ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接、文本语义推理、框选与 video tracker 适配器 │ ├── sam3_external_worker.py # 独立 sam3 conda 环境中执行的状态/推理 helper │ └── sam_registry.py # SAM 模型选择、GPU 状态与推理分发 └── src/ # React 前端 @@ -194,6 +194,7 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload - `POST /api/tasks/{task_id}/cancel` - `POST /api/tasks/{task_id}/retry` - `POST /api/ai/predict` + - `POST /api/ai/propagate` - `GET /api/ai/models/status` - `POST /api/ai/auto` - `POST /api/ai/annotate` @@ -219,14 +220,15 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload 1. 登录:`Login.tsx` 调用 `POST /api/auth/login`,默认开发凭证为 `admin / 123456`。 2. 项目管理:`ProjectLibrary.tsx` 调用项目 API 创建项目、拉取列表。 3. 上传资源:视频走 `/api/media/upload`;DICOM 批量走 `/api/media/upload/dicom`。 -4. 拆帧入队:前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery。 -5. worker 执行:Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallback,DICOM 使用 pydicom,并持续更新任务进度。 -6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames`,`CanvasArea.tsx` 和 `FrameTimeline.tsx` 显示当前帧与时间轴缩略图。 -7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask;点击 mask 可拖动/删除 polygon 顶点、通过边中点插入新顶点,并能选择编辑多 polygon mask 的单个子区域;区域合并/去除使用 `polygon-clipping` 做 union/difference;Zustand 维护 `maskHistory/maskFuture` 支持撤销/重做。 -8. AI 分割:前端工具包括正向点、反向点和框选;后端 `ai.py` 期望按 `image_id`、`prompt_type`、`prompt_data`、`model` 和可选 `options` 调用 SAM registry。SAM 2 支持点/框/自动分割;`options.crop_to_prompt` 可对点/框 prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果;SAM 3 入口支持文本语义推理,主后端会通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境;如果 Python/CUDA/包/Hugging Face gated 权重访问任一条件不满足,会在状态接口中标为不可用。 -9. GT 导入:工作区“导入 GT Mask”调用 `/api/ai/import-gt-mask`;后端按非零像素值和连通域生成 polygon 标注,并用 distance transform 生成 seed point;前端回显 seed point,拖动后可归档更新。 -10. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index;`OntologyInspector.tsx` 在工作区显示当前模板分类树。 -11. 导出:后端支持 COCO JSON 和 PNG mask ZIP 导出;PNG ZIP 包含单标注 mask、按 zIndex 融合的语义 mask 和 `semantic_classes.json`。 +4. 拆帧入队:前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery,接口支持 `parse_fps`、`max_frames` 和 `target_width` 标准帧序列参数。 +5. worker 执行:Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallback,DICOM 使用 pydicom;视频帧按 `frame_%06d.jpg` 连续命名并记录 `timestamp_ms`、`source_frame_number` 和任务 `frame_sequence` 元数据。 +6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames`,`CanvasArea.tsx` 和 `FrameTimeline.tsx` 显示当前帧与时间轴缩略图;前端 `Frame` 会保留后端返回的帧序列时间戳和源帧号。 +7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask;多边形可按 Enter 或点击首节点闭合;绘制工具可在已有 mask 上继续落点;点击 mask 可拖动/删除 polygon 顶点、通过边中点插入新顶点,并能选择编辑多 polygon mask 的单个子区域;选中整块 mask 可用 Delete/Backspace 删除,已保存 mask 会同步后端删除;区域合并/去除会隐藏编辑手柄并显示已选数量,使用 `polygon-clipping` 做 union/difference,内含去除结果用 even-odd 规则渲染 hole;Zustand 维护 `maskHistory/maskFuture` 支持撤销/重做。 +8. AI 分割:前端工具包括正向点、反向点和框选;SAM 2 框选会建立候选 mask,后续正/反点通过 `interactive` prompt 携带原始框和累计点细化同一个候选 mask;后端 `ai.py` 期望按 `image_id`、`prompt_type`、`prompt_data`、`model` 和可选 `options` 调用 SAM registry。SAM 2 支持点/框/interactive/自动分割和 video predictor 传播;`options.crop_to_prompt` 可对点/框/interactive prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果;SAM 3 入口支持文本语义推理、框选提示和 external video tracker,主后端会通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境,并优先使用 `sam3_checkpoint_path` 指向的本地 `sam3权重/sam3.pt`;如果 Python/CUDA/包/本地 checkpoint 均满足,会在状态接口中标为可用。 +9. 视频片段传播:工作区“传播片段”把当前选中 mask 或当前帧第一个 mask 作为 seed,调用 `POST /api/ai/propagate`;后端按项目帧序列下载片段帧,SAM 2 用 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`,SAM 3 用独立 helper 的官方 `build_sam3_video_predictor()`,并把后续帧结果保存为 `Annotation`。 +10. GT 导入:工作区“导入 GT Mask”调用 `/api/ai/import-gt-mask`;后端按非零像素值和连通域生成 polygon 标注,并用 distance transform 生成 seed point;前端回显 seed point,拖动后可归档更新。 +11. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index;`OntologyInspector.tsx` 在工作区显示当前模板分类树。 +12. 导出:后端支持 COCO JSON 和 PNG mask ZIP 导出;PNG ZIP 包含单标注 mask、按 zIndex 融合的语义 mask 和 `semantic_classes.json`。 --- @@ -240,6 +242,7 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload - 前端 `importGtMask()` 已对齐后端 `/api/ai/import-gt-mask`;工作区“导入 GT Mask”会导入后端生成的多类别标注和 seed point 并回显。 - 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;前端 `exportMasks()` 已对齐后端 `/api/export/{project_id}/masks`;工作区导出按钮会先保存当前待归档 mask。 - 工作区“结构化归档保存”按钮已接入 `POST /api/ai/annotate` 和 `PATCH /api/ai/annotations/{id}`;加载工作区时会通过 `GET /api/ai/annotations` 回显已保存标注。 +- 工作区“传播片段”按钮已接入 `POST /api/ai/propagate`;SAM 2 路径使用视频 predictor,SAM 3 路径使用独立 Python helper 的官方 video tracker,完成后刷新后端已保存标注。 - 工作区“清空遮罩”会调用 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。 - 项目状态已统一为 `pending`、`parsing`、`ready`、`error`;前端 `src/lib/api.ts` 会兼容归一化旧库中可能存在的 `Ready`、`Parsing`、`Error`。 - `server.ts` 仍有旧版 `/api/login`、`/api/projects`、`/api/templates` mock;当前前端真实 API 调用主要走 FastAPI 的 `/api/auth/*`、`/api/projects`、`/api/templates` 等接口。 diff --git a/README.md b/README.md index 9061762..ce1f5a9 100644 --- a/README.md +++ b/README.md @@ -6,14 +6,14 @@ > 基于 React + FastAPI + 可选 SAM 2 / SAM 3 的全栈交互式图像/视频语义分割与标注平台。 > -> 支持本地多媒体资产上传、服务器端按帧解析、交互式 Canvas 标注、GT mask 导入、模板分类管理和标注数据结构化导出;工作区点/框 AI 推理默认走 SAM 2,语义文本可选择 SAM 3,前端会显示真实 GPU/模型状态。 +> 支持本地多媒体资产上传、服务器端按帧解析、交互式 Canvas 标注、视频片段传播、GT mask 导入、模板分类管理和标注数据结构化导出;工作区点/框 AI 推理默认走 SAM 2,SAM 3 支持语义文本、框选提示和 video tracker,前端会显示真实 GPU/模型状态。 --- ## 核心功能 - **多媒体资产管理** — 支持视频(MP4/AVI/MOV)和 DICOM 医学影像的上传、存储与解析 -- **AI 智能分割引擎** — 后端提供 SAM 2 / SAM 3 模型选择;SAM 2 支持点分割(point)、框分割(box)和自动分割(auto),SAM 3 入口支持文本语义提示并按真实运行环境显示可用性 +- **AI 智能分割引擎** — 后端提供 SAM 2 / SAM 3 模型选择;SAM 2 支持点分割(point)、框分割(box)、自动分割(auto)和 video predictor 传播,SAM 3 入口支持文本语义提示、框选提示和 external video tracker,并按真实运行环境显示可用性 - **交互式画布标注** — 基于 Konva 的高性能 Canvas,支持缩放/平移/手工多边形/矩形/圆/点/线、polygon 顶点拖动/删除、区域合并/去除、选点/框选、撤销/重做,实时渲染 Mask 遮罩 - **GT Mask 导入** — 工作区可导入 GT mask 图片,后端按非零像素值和连通域生成 polygon 标注并用 distance transform 写入 seed point;前端可回显和拖动 seed point - **本体字典管理** — 可配置的分类体系、颜色映射、图层优先级(z-index) @@ -104,8 +104,8 @@ Seg_Server/ │ │ ├── ai.py # SAM 推理与模型状态接口 │ │ └── export.py # 数据导出 │ └── services/ # 业务服务 -│ ├── sam2_engine.py # SAM 2 推理引擎(懒加载 + stub降级) -│ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接与文本语义推理适配器 +│ ├── sam2_engine.py # SAM 2 推理引擎(单帧推理 + video predictor 传播) +│ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接、文本语义推理、框选与 video tracker 适配器 │ ├── sam3_external_worker.py # 独立 sam3 conda 环境中执行的状态/推理 helper │ ├── sam_registry.py # SAM 模型选择、GPU 状态与推理分发 │ └── frame_parser.py # FFmpeg 拆帧 / pydicom 读片 @@ -255,12 +255,11 @@ python download_sam2.py cd ~/Desktop/Seg_Server ./backend/setup_sam3_env.sh -# 首次使用官方权重前,需要先在 Hugging Face 申请 facebook/sam3 访问权限并登录 -conda activate sam3 -huggingface-cli login +# 如果已把权重放在 sam3权重/sam3.pt,可直接走本地 checkpoint; +# 未配置本地 checkpoint 时,才需要 Hugging Face gated repo 授权和登录。 ``` -官方 `facebook/sam3` 权重约 3.45 GB,当前没有类似 SAM 2 `tiny/small/base/large` 的官方小权重梯度;`facebook/sam3.1` 约 3.5 GB,主要面向新的视频 multiplex checkpoint。未获得 gated model 授权时,`GET /api/ai/models/status` 会把 SAM 3 标为不可用并说明 checkpoint access 不满足。 +官方 `facebook/sam3` 权重约 3.45 GB,当前没有类似 SAM 2 `tiny/small/base/large` 的官方小权重梯度。当前仓库默认使用本机 `sam3权重/sam3.pt`,不会提交权重文件;未配置本地 checkpoint 且未获得 gated model 授权时,`GET /api/ai/models/status` 会把 SAM 3 标为不可用并说明 checkpoint access 不满足。 ### 步骤 6: 配置环境变量 @@ -278,6 +277,7 @@ sam_model_path=/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt sam_model_config=configs/sam2/sam2_hiera_t.yaml sam_default_model=sam2 sam3_model_version=sam3 +sam3_checkpoint_path=/home/wkmgc/Desktop/Seg_Server/sam3权重/sam3.pt sam3_external_enabled=true sam3_external_python=/home/wkmgc/miniconda3/envs/sam3/bin/python sam3_timeout_seconds=300 @@ -311,6 +311,7 @@ nohup uvicorn main:app --host 0.0.0.0 --port 8000 > /tmp/fastapi.log 2>&1 & - 测试 Redis 连接 - 懒加载 SAM 模型;`GET /api/ai/models/status` 会返回 SAM 2、SAM 3、GPU 和 SAM 3 checkpoint access 的真实可用状态 - `/api/ai/predict` 支持 AI 参数 `crop_to_prompt`、`auto_filter_background` 和 `min_score`,用于点/框 prompt 的局部裁剪推理、回映射和背景过滤 +- `/api/ai/propagate` 支持从当前帧 seed 区域向视频片段传播:SAM 2 使用 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`,SAM 3 通过独立 Python 3.12 helper 调用官方 `build_sam3_video_predictor()` video tracker ### 步骤 6.1: 启动 Celery Worker @@ -324,7 +325,7 @@ celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 & ``` -`POST /api/media/parse` 只创建 `processing_tasks` 记录并把任务投递给 Celery;真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道,FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 也可调用 `/api/tasks/{id}/cancel`、`/api/tasks/{id}/retry` 和 `/api/tasks/{id}` 完成任务取消、重试与失败详情查看。 +`POST /api/media/parse` 只创建 `processing_tasks` 记录并把任务投递给 Celery;真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。接口支持 `parse_fps`、`max_frames` 和 `target_width`,用于生成后续 SAM 2 / SAM 3 视频处理可复用的标准帧序列;视频帧按 `frame_%06d.jpg` 连续命名,帧表会记录 `timestamp_ms` 和 `source_frame_number`,任务完成结果会返回 `frame_sequence` 元数据。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道,FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 也可调用 `/api/tasks/{id}/cancel`、`/api/tasks/{id}/retry` 和 `/api/tasks/{id}` 完成任务取消、重试与失败详情查看。 ### 步骤 7: 安装前端依赖并构建 @@ -460,6 +461,7 @@ pip install -e . --no-build-isolation - 前端 `predictMask()` 已发送后端需要的 `image_id`、`prompt_type`、`prompt_data`,并把后端 `polygons` 转成 Konva `pathData`。 - 工作区点选/框选会使用当前帧的数据库 `frame.id` 调用 `/api/ai/predict`。 +- 工作区“传播片段”会使用当前选中区域或当前帧第一个区域作为 seed,调用 `/api/ai/propagate`,并在完成后刷新已保存标注。 - 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco`。 - 工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮已绑定下载流程;导出前会先保存当前待归档的前端 mask。 - 工作区“导入 GT Mask”按钮已绑定 `/api/ai/import-gt-mask`,导入后会刷新并回显已保存标注和 seed point。 diff --git a/backend/config.py b/backend/config.py index d305bea..ca6d457 100644 --- a/backend/config.py +++ b/backend/config.py @@ -23,6 +23,7 @@ class Settings(BaseSettings): sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt" sam_model_config: str = "configs/sam2/sam2_hiera_t.yaml" sam3_model_version: str = "sam3" + sam3_checkpoint_path: str = "/home/wkmgc/Desktop/Seg_Server/sam3权重/sam3.pt" sam3_external_enabled: bool = True sam3_external_python: str = "/home/wkmgc/miniconda3/envs/sam3/bin/python" sam3_timeout_seconds: int = 300 diff --git a/backend/main.py b/backend/main.py index 9ddc5fc..c2d9244 100644 --- a/backend/main.py +++ b/backend/main.py @@ -11,6 +11,7 @@ from datetime import datetime, timezone from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware +from sqlalchemy import inspect, text from config import settings from database import Base, engine, SessionLocal @@ -30,6 +31,20 @@ logger = logging.getLogger(__name__) DEFAULT_VIDEO_PATH = "/home/wkmgc/Desktop/Seg_Server/Data_MyVideo_1.mp4" +def _ensure_runtime_schema_columns() -> None: + """Add nullable columns introduced after initial create_all deployments.""" + try: + inspector = inspect(engine) + frame_columns = {column["name"] for column in inspector.get_columns("frames")} + with engine.begin() as connection: + if "timestamp_ms" not in frame_columns: + connection.execute(text("ALTER TABLE frames ADD COLUMN timestamp_ms FLOAT")) + if "source_frame_number" not in frame_columns: + connection.execute(text("ALTER TABLE frames ADD COLUMN source_frame_number INTEGER")) + except Exception as exc: # noqa: BLE001 + logger.warning("Runtime schema column check failed: %s", exc) + + def _seed_default_project_sync() -> None: """Synchronously seed the default video project on first startup.""" import cv2 @@ -93,12 +108,16 @@ def _seed_default_project_sync() -> None: for idx, obj_name in enumerate(object_names): img = cv2.imread(frame_files[idx]) h, w = img.shape[:2] if img is not None else (None, None) + timestamp_ms = idx * 1000.0 / 30.0 + source_frame_number = int(round(idx * original_fps / 30.0)) if original_fps else None frame = Frame( project_id=project.id, frame_index=idx, image_url=obj_name, width=w, height=h, + timestamp_ms=timestamp_ms, + source_frame_number=source_frame_number, ) db.add(frame) @@ -176,6 +195,7 @@ async def lifespan(app: FastAPI): # Initialize database tables try: Base.metadata.create_all(bind=engine) + _ensure_runtime_schema_columns() logger.info("Database tables initialized.") except Exception as exc: # noqa: BLE001 logger.error("Database initialization failed: %s", exc) diff --git a/backend/models.py b/backend/models.py index 9dbf08d..c84dc70 100644 --- a/backend/models.py +++ b/backend/models.py @@ -56,6 +56,8 @@ class Frame(Base): image_url = Column(String(512), nullable=False) width = Column(Integer, nullable=True) height = Column(Integer, nullable=True) + timestamp_ms = Column(Float, nullable=True) + source_frame_number = Column(Integer, nullable=True) created_at = Column(DateTime(timezone=True), server_default=func.now()) project = relationship("Project", back_populates="frames") diff --git a/backend/routers/ai.py b/backend/routers/ai.py index c785ef2..e7d51ae 100644 --- a/backend/routers/ai.py +++ b/backend/routers/ai.py @@ -1,6 +1,8 @@ """AI inference endpoints using selectable SAM runtimes.""" import logging +import tempfile +from pathlib import Path from typing import Any, List import cv2 @@ -15,6 +17,8 @@ from schemas import ( AiRuntimeStatus, PredictRequest, PredictResponse, + PropagateRequest, + PropagateResponse, AnnotationOut, AnnotationCreate, AnnotationUpdate, @@ -66,6 +70,48 @@ def _contour_bbox(contour: np.ndarray, width: int, height: int) -> list[float]: ] +def _polygon_bbox(polygon: list[list[float]]) -> list[float]: + xs = [_clamp01(point[0]) for point in polygon] + ys = [_clamp01(point[1]) for point in polygon] + left, right = min(xs), max(xs) + top, bottom = min(ys), max(ys) + return [left, top, max(right - left, 0.0), max(bottom - top, 0.0)] + + +def _frame_window( + frames: list[Frame], + source_position: int, + direction: str, + max_frames: int, +) -> tuple[list[Frame], int]: + count = max(1, min(max_frames, len(frames))) + if direction == "backward": + start = max(0, source_position - count + 1) + return frames[start:source_position + 1], source_position - start + if direction == "both": + before = (count - 1) // 2 + after = count - 1 - before + start = max(0, source_position - before) + end = min(len(frames), source_position + after + 1) + while end - start < count and start > 0: + start -= 1 + while end - start < count and end < len(frames): + end += 1 + return frames[start:end], source_position - start + end = min(len(frames), source_position + count) + return frames[source_position:end], 0 + + +def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]: + paths = [] + for index, frame in enumerate(frames): + data = download_file(frame.image_url) + path = directory / f"frame_{index:06d}.jpg" + path.write_bytes(data) + paths.append(str(path)) + return paths + + def _component_seed_point(component_mask: np.ndarray, width: int, height: int) -> list[float]: """Reduce a binary component to one positive prompt point using distance transform.""" dist = cv2.distanceTransform(component_mask.astype(np.uint8), cv2.DIST_L2, 5) @@ -184,6 +230,7 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict: - **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized coordinates or `{ "points": [[x, y], ...], "labels": [1, 0, ...] }`. - **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates. + - **interactive**: `prompt_data` is `{ "box": [...], "points": [[x, y]], "labels": [1, 0] }`. - **semantic**: SAM 3 text prompt when model=`sam3`; SAM 2 falls back to auto. """ frame = db.query(Frame).filter(Frame.id == payload.image_id).first() @@ -246,6 +293,51 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict: if crop_bounds: polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons] + elif prompt_type == "interactive": + prompt = payload.prompt_data + if not isinstance(prompt, dict): + raise HTTPException(status_code=400, detail="Invalid interactive prompt data") + box = prompt.get("box") + points = prompt.get("points") or [] + labels = prompt.get("labels") + if box is not None and (not isinstance(box, list) or len(box) != 4): + raise HTTPException(status_code=400, detail="Invalid interactive box prompt data") + if not isinstance(points, list): + raise HTTPException(status_code=400, detail="Invalid interactive point prompt data") + if not box and len(points) == 0: + raise HTTPException(status_code=400, detail="Interactive prompt requires a box or points") + if not isinstance(labels, list) or len(labels) != len(points): + labels = [1] * len(points) + negative_points = [ + point for point, label in zip(points, labels) if label == 0 + ] + inference_image = image + inference_box = box + inference_points = points + crop_bounds = None + if options.get("crop_to_prompt"): + margin = float(options.get("crop_margin", 0.05) or 0.05) + crop_points = list(points) + if box: + crop_points.extend([[box[0], box[1]], [box[2], box[3]]]) + crop_bounds = _crop_bounds_from_points(crop_points, margin) + inference_image = _crop_image(image, crop_bounds) + inference_points = [_to_crop_point(point, crop_bounds) for point in points] + if box: + inference_box = [ + *_to_crop_point([box[0], box[1]], crop_bounds), + *_to_crop_point([box[2], box[3]], crop_bounds), + ] + polygons, scores = sam_registry.predict_interactive( + payload.model, + inference_image, + inference_box, + inference_points, + labels, + ) + if crop_bounds: + polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons] + elif prompt_type == "semantic": text = payload.prompt_data if isinstance(payload.prompt_data, str) else "" polygons, scores = sam_registry.predict_semantic(payload.model, image, text) @@ -276,6 +368,124 @@ def model_status(selected_model: str | None = None) -> dict: raise HTTPException(status_code=400, detail=str(exc)) from exc +@router.post( + "/propagate", + response_model=PropagateResponse, + summary="Propagate one current-frame region across a video frame segment", +) +def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict: + """Track one selected region from the current frame across nearby frames. + + SAM 2 uses the official video predictor with the selected mask as the seed. + SAM 3 uses the external Python 3.12 video tracker with the seed bbox. + """ + direction = payload.direction.lower() + if direction not in {"forward", "backward", "both"}: + raise HTTPException(status_code=400, detail="direction must be forward, backward, or both") + max_frames = max(1, min(int(payload.max_frames or 30), 500)) + + project = db.query(Project).filter(Project.id == payload.project_id).first() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + source_frame = db.query(Frame).filter( + Frame.id == payload.frame_id, + Frame.project_id == payload.project_id, + ).first() + if not source_frame: + raise HTTPException(status_code=404, detail="Frame not found") + + seed = payload.seed.model_dump(exclude_none=True) + polygons = seed.get("polygons") or [] + bbox = seed.get("bbox") + points = seed.get("points") or [] + if not polygons and not bbox and not points: + raise HTTPException(status_code=400, detail="Propagation requires seed polygons, bbox, or points") + + frames = db.query(Frame).filter(Frame.project_id == payload.project_id).order_by(Frame.frame_index).all() + source_position = next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None) + if source_position is None: + raise HTTPException(status_code=404, detail="Source frame is not in project frame sequence") + + selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames) + if len(selected_frames) == 0: + raise HTTPException(status_code=400, detail="No frames available for propagation") + + try: + with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload.project_id}_") as tmpdir: + frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir)) + propagated = sam_registry.propagate_video( + payload.model, + frame_paths, + source_relative_index, + seed, + direction, + len(selected_frames), + ) + except ModelUnavailableError as exc: + raise HTTPException(status_code=503, detail=str(exc)) from exc + except NotImplementedError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + except Exception as exc: # noqa: BLE001 + logger.error("Video propagation failed: %s", exc) + raise HTTPException(status_code=500, detail=f"Video propagation failed: {exc}") from exc + + created: list[Annotation] = [] + if payload.save_annotations: + class_metadata = seed.get("class_metadata") + template_id = seed.get("template_id") + label = seed.get("label") or "Propagated Mask" + color = seed.get("color") or "#06b6d4" + model_id = sam_registry.normalize_model_id(payload.model) + + for frame_result in propagated: + relative_index = int(frame_result.get("frame_index", -1)) + if relative_index < 0 or relative_index >= len(selected_frames): + continue + frame = selected_frames[relative_index] + if not payload.include_source and frame.id == source_frame.id: + continue + result_polygons = frame_result.get("polygons") or [] + scores = frame_result.get("scores") or [] + for polygon_index, polygon in enumerate(result_polygons): + if len(polygon) < 3: + continue + annotation = Annotation( + project_id=payload.project_id, + frame_id=frame.id, + template_id=template_id, + mask_data={ + "polygons": [polygon], + "label": label, + "color": color, + "source": f"{model_id}_propagation", + "propagated_from_frame_id": source_frame.id, + "propagated_from_frame_index": source_frame.frame_index, + "score": scores[polygon_index] if polygon_index < len(scores) else None, + **({"class": class_metadata} if class_metadata else {}), + }, + points=None, + bbox=_polygon_bbox(polygon), + ) + db.add(annotation) + created.append(annotation) + + db.commit() + for annotation in created: + db.refresh(annotation) + + return { + "model": sam_registry.normalize_model_id(payload.model), + "direction": direction, + "source_frame_id": source_frame.id, + "processed_frame_count": len(selected_frames), + "created_annotation_count": len(created), + "annotations": created, + } + + @router.post( "/auto", response_model=PredictResponse, diff --git a/backend/routers/media.py b/backend/routers/media.py index 3ec77ff..fbf8ffc 100644 --- a/backend/routers/media.py +++ b/backend/routers/media.py @@ -4,7 +4,7 @@ import logging from pathlib import Path from typing import List, Optional -from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status +from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile, status from sqlalchemy.orm import Session from database import get_db @@ -169,6 +169,9 @@ async def upload_dicom_batch( def parse_media( project_id: int, source_type: Optional[str] = None, + parse_fps: Optional[float] = Query(None, gt=0, le=120), + max_frames: Optional[int] = Query(None, gt=0), + target_width: int = Query(640, ge=64, le=4096), db: Session = Depends(get_db), ) -> ProcessingTask: """Create a background task for media frame extraction. @@ -184,14 +187,21 @@ def parse_media( raise HTTPException(status_code=400, detail="Project has no media uploaded") effective_source = source_type or project.source_type or "video" + effective_parse_fps = parse_fps or project.parse_fps or 30.0 task = ProcessingTask( task_type=f"parse_{effective_source}", status=TASK_STATUS_QUEUED, progress=0, message="解析任务已入队", project_id=project_id, - payload={"source_type": effective_source}, + payload={ + "source_type": effective_source, + "parse_fps": effective_parse_fps, + "max_frames": max_frames, + "target_width": target_width, + }, ) + project.parse_fps = effective_parse_fps project.status = PROJECT_STATUS_PARSING db.add(task) db.commit() diff --git a/backend/schemas.py b/backend/schemas.py index 1654a36..06f6966 100644 --- a/backend/schemas.py +++ b/backend/schemas.py @@ -51,6 +51,8 @@ class FrameBase(BaseModel): image_url: str width: Optional[int] = None height: Optional[int] = None + timestamp_ms: Optional[float] = None + source_frame_number: Optional[int] = None class FrameCreate(FrameBase): @@ -188,6 +190,37 @@ class PredictResponse(BaseModel): scores: Optional[list[float]] = None +class PropagationSeed(BaseModel): + polygons: Optional[list[list[list[float]]]] = None + bbox: Optional[list[float]] = None + points: Optional[list[list[float]]] = None + labels: Optional[list[int]] = None + label: Optional[str] = None + color: Optional[str] = None + class_metadata: Optional[dict[str, Any]] = None + template_id: Optional[int] = None + + +class PropagateRequest(BaseModel): + project_id: int + frame_id: int + model: Optional[str] = "sam2" + seed: PropagationSeed + direction: str = "forward" + max_frames: int = 30 + include_source: bool = False + save_annotations: bool = True + + +class PropagateResponse(BaseModel): + model: str + direction: str + source_frame_id: int + processed_frame_count: int + created_annotation_count: int + annotations: list[AnnotationOut] + + class AiModelStatus(BaseModel): id: str label: str diff --git a/backend/services/frame_parser.py b/backend/services/frame_parser.py index 349a0d4..de521c8 100644 --- a/backend/services/frame_parser.py +++ b/backend/services/frame_parser.py @@ -52,6 +52,7 @@ def parse_video( output_dir: str, fps: int = 30, max_frames: Optional[int] = None, + target_width: int = 640, ) -> Tuple[List[str], float]: """Extract frames from a video file using FFmpeg or OpenCV fallback. @@ -60,6 +61,7 @@ def parse_video( output_dir: Directory to save extracted frames. fps: Target frame extraction rate. max_frames: Optional maximum number of frames to extract. + target_width: Output frame width for model-friendly frame sequences. Returns: Tuple of (frame_paths, original_fps). @@ -67,6 +69,8 @@ def parse_video( os.makedirs(output_dir, exist_ok=True) frame_paths: List[str] = [] original_fps = get_video_fps(video_path) + safe_fps = max(int(fps), 1) + safe_width = max(int(target_width), 1) # Try FFmpeg first if shutil.which("ffmpeg"): @@ -75,7 +79,8 @@ def parse_video( cmd = [ "ffmpeg", "-i", video_path, - "-vf", f"fps={fps},scale=640:-1", + "-vf", f"fps={safe_fps},scale={safe_width}:-1", + "-start_number", "0", "-q:v", "5", "-y", pattern, @@ -102,7 +107,7 @@ def parse_video( raise RuntimeError(f"Cannot open video: {video_path}") video_fps = cap.get(cv2.CAP_PROP_FPS) or 30 - interval = max(1, int(round(video_fps / fps))) + interval = max(1, int(round(video_fps / safe_fps))) count = 0 saved = 0 @@ -112,6 +117,10 @@ def parse_video( break if count % interval == 0: path = os.path.join(output_dir, f"frame_{saved:06d}.jpg") + h, w = frame.shape[:2] + if safe_width > 0 and w != safe_width: + scale = safe_width / max(w, 1) + frame = cv2.resize(frame, (safe_width, max(1, int(round(h * scale)))), interpolation=cv2.INTER_AREA) cv2.imwrite(path, frame, [cv2.IMWRITE_JPEG_QUALITY, 80]) frame_paths.append(path) saved += 1 diff --git a/backend/services/media_task_runner.py b/backend/services/media_task_runner.py index 9db8989..f40b498 100644 --- a/backend/services/media_task_runner.py +++ b/backend/services/media_task_runner.py @@ -76,6 +76,38 @@ def _project_status_after_stop(project: Project) -> str: return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING +def _positive_int(value: Any, default: int | None = None) -> int | None: + try: + parsed = int(value) + except (TypeError, ValueError): + return default + return parsed if parsed > 0 else default + + +def _positive_float(value: Any, default: float) -> float: + try: + parsed = float(value) + except (TypeError, ValueError): + return default + return parsed if parsed > 0 else default + + +def _frame_sequence_metadata( + index: int, + parse_fps: float, + original_fps: float | None, +) -> dict[str, float | int | None]: + safe_parse_fps = max(float(parse_fps or 1.0), 1e-6) + timestamp_ms = index * 1000.0 / safe_parse_fps + source_frame_number = None + if original_fps and original_fps > 0: + source_frame_number = int(round(index * original_fps / safe_parse_fps)) + return { + "timestamp_ms": timestamp_ms, + "source_frame_number": source_frame_number, + } + + def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None: db.refresh(task) if task.status == TASK_STATUS_CANCELLED: @@ -138,8 +170,12 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]: project.status = PROJECT_STATUS_PARSING _set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True) - effective_source = (task.payload or {}).get("source_type") or project.source_type or "video" - parse_fps = project.parse_fps or 30.0 + payload = task.payload or {} + effective_source = payload.get("source_type") or project.source_type or "video" + parse_fps = _positive_float(payload.get("parse_fps"), project.parse_fps or 30.0) + max_frames = _positive_int(payload.get("max_frames")) + target_width = _positive_int(payload.get("target_width"), 640) or 640 + project.parse_fps = parse_fps tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project.id}_") output_dir = os.path.join(tmp_dir, "frames") os.makedirs(output_dir, exist_ok=True) @@ -163,7 +199,7 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]: _ensure_not_cancelled(db, task) _set_task_state(db, task, progress=35, message="正在解析 DICOM 序列") - frame_files = parse_dicom(dcm_dir, output_dir) + frame_files = parse_dicom(dcm_dir, output_dir, max_frames=max_frames) else: _ensure_not_cancelled(db, task) media_bytes = download_file(project.video_path) @@ -173,7 +209,13 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]: _ensure_not_cancelled(db, task) _set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧") - frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps)) + frame_files, original_fps = parse_video( + local_path, + output_dir, + fps=int(parse_fps), + max_frames=max_frames, + target_width=target_width, + ) project.original_fps = original_fps thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg") @@ -205,12 +247,15 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]: except Exception: # noqa: BLE001 h, w = None, None + sequence_meta = _frame_sequence_metadata(idx, parse_fps, project.original_fps) frame = Frame( project_id=project.id, frame_index=idx, image_url=obj_name, width=w, height=h, + timestamp_ms=sequence_meta["timestamp_ms"], + source_frame_number=sequence_meta["source_frame_number"], ) db.add(frame) frames_out.append(frame) @@ -223,6 +268,17 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]: "frames_extracted": len(frames_out), "status": PROJECT_STATUS_READY, "message": "Frame extraction completed successfully.", + "frame_sequence": { + "original_fps": project.original_fps, + "parse_fps": parse_fps, + "frame_count": len(frames_out), + "duration_ms": (len(frames_out) - 1) * 1000.0 / parse_fps if frames_out else 0, + "target_width": target_width, + "frame_width": frames_out[0].width if frames_out else None, + "frame_height": frames_out[0].height if frames_out else None, + "max_frames": max_frames, + "object_prefix": f"projects/{project.id}/frames", + }, } _set_task_state( db, diff --git a/backend/services/sam2_engine.py b/backend/services/sam2_engine.py index fb25a8a..a527b6e 100644 --- a/backend/services/sam2_engine.py +++ b/backend/services/sam2_engine.py @@ -24,6 +24,7 @@ except Exception as exc: # noqa: BLE001 try: from sam2.build_sam import build_sam2 + from sam2.build_sam import build_sam2_video_predictor from sam2.sam2_image_predictor import SAM2ImagePredictor SAM2_AVAILABLE = True @@ -38,9 +39,12 @@ class SAM2Engine: def __init__(self) -> None: self._predictor: Optional[SAM2ImagePredictor] = None + self._video_predictor = None self._model_loaded = False + self._video_model_loaded = False self._loaded_device: str | None = None self._last_error: str | None = None + self._video_last_error: str | None = None # ----------------------------------------------------------------------- # Internal helpers @@ -85,6 +89,40 @@ class SAM2Engine: logger.error("Failed to load SAM 2 model: %s", exc) self._model_loaded = True # Prevent repeated load attempts + def _load_video_model(self) -> None: + """Load the SAM 2 video predictor on first propagation use.""" + if self._video_model_loaded: + return + + if not TORCH_AVAILABLE: + self._video_last_error = "PyTorch is not installed." + self._video_model_loaded = True + return + if not SAM2_AVAILABLE: + self._video_last_error = "sam2 package is not installed." + self._video_model_loaded = True + return + if not os.path.isfile(settings.sam_model_path): + self._video_last_error = f"SAM2 checkpoint not found: {settings.sam_model_path}" + self._video_model_loaded = True + return + + try: + device = self._best_device() + self._video_predictor = build_sam2_video_predictor( + settings.sam_model_config, + settings.sam_model_path, + device=device, + ) + self._video_model_loaded = True + self._loaded_device = device + self._video_last_error = None + logger.info("SAM 2 video predictor loaded from %s on %s", settings.sam_model_path, device) + except Exception as exc: # noqa: BLE001 + self._video_last_error = str(exc) + self._video_model_loaded = True + logger.error("Failed to load SAM 2 video predictor: %s", exc) + def _best_device(self) -> str: if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available(): return "cuda" @@ -95,6 +133,11 @@ class SAM2Engine: self._load_model() return SAM2_AVAILABLE and self._predictor is not None + def _ensure_video_ready(self) -> bool: + """Ensure the video predictor is loaded; return whether it is usable.""" + self._load_video_model() + return SAM2_AVAILABLE and self._video_predictor is not None + def status(self) -> dict: """Return lightweight, real runtime status without forcing model load.""" checkpoint_exists = os.path.isfile(settings.sam_model_path) @@ -121,7 +164,7 @@ class SAM2Engine: "available": available, "loaded": self._predictor is not None, "device": device, - "supports": ["point", "box", "auto"], + "supports": ["point", "box", "interactive", "auto", "propagate"], "message": message, "package_available": SAM2_AVAILABLE, "checkpoint_exists": checkpoint_exists, @@ -221,6 +264,52 @@ class SAM2Engine: logger.error("SAM2 box prediction failed: %s", exc) return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5] + def predict_interactive( + self, + image: np.ndarray, + box: list[float] | None, + points: list[list[float]], + labels: list[int], + ) -> tuple[list[list[list[float]]], list[float]]: + """Run combined box and point prompt segmentation for refinement.""" + if not self._ensure_ready(): + logger.warning("SAM2 not ready; returning dummy masks.") + return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5] + + try: + h, w = image.shape[:2] + bbox = None + if box: + bbox = np.array( + [box[0] * w, box[1] * h, box[2] * w, box[3] * h], + dtype=np.float32, + ) + pts = None + lbls = None + if points: + pts = np.array([[p[0] * w, p[1] * h] for p in points], dtype=np.float32) + lbls = np.array(labels, dtype=np.int32) + + with torch.inference_mode(): # type: ignore[name-defined] + self._predictor.set_image(image) + masks, scores, _ = self._predictor.predict( + point_coords=pts, + point_labels=lbls, + box=bbox, + multimask_output=False, + ) + + polygons = [] + for m in masks: + poly = self._mask_to_polygon(m) + if poly: + polygons.append(poly) + + return polygons, scores.tolist() + except Exception as exc: # noqa: BLE001 + logger.error("SAM2 interactive prediction failed: %s", exc) + return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5] + def predict_auto(self, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]: """Run automatic mask generation (grid of points). @@ -260,6 +349,89 @@ class SAM2Engine: logger.error("SAM2 auto prediction failed: %s", exc) return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5] + def propagate_video( + self, + frame_paths: list[str], + source_frame_index: int, + seed: dict, + direction: str = "forward", + max_frames: int | None = None, + ) -> list[dict]: + """Propagate one seed mask across a prepared frame directory with SAM 2 video.""" + if not self._ensure_video_ready(): + raise RuntimeError(self._video_last_error or self.status()["message"]) + if not frame_paths: + return [] + if source_frame_index < 0 or source_frame_index >= len(frame_paths): + raise ValueError("source_frame_index is outside the frame sequence.") + + import cv2 + + source_image = cv2.imread(frame_paths[source_frame_index]) + if source_image is None: + raise RuntimeError("Failed to decode source frame for SAM 2 propagation.") + height, width = source_image.shape[:2] + seed_mask = self._polygons_to_mask(seed.get("polygons") or [], width, height) + if not seed_mask.any(): + bbox = seed.get("bbox") + if isinstance(bbox, list) and len(bbox) == 4: + seed_mask = self._bbox_to_mask(bbox, width, height) + if not seed_mask.any(): + raise ValueError("SAM 2 propagation requires a non-empty seed polygon or bbox.") + + inference_state = self._video_predictor.init_state( + video_path=os.path.dirname(frame_paths[0]), + offload_video_to_cpu=True, + offload_state_to_cpu=True, + ) + self._video_predictor.add_new_mask( + inference_state, + frame_idx=source_frame_index, + obj_id=1, + mask=seed_mask, + ) + + results: dict[int, dict] = {} + + def collect(reverse: bool) -> None: + for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video( + inference_state, + start_frame_idx=source_frame_index, + max_frame_num_to_track=max_frames, + reverse=reverse, + ): + masks = out_mask_logits + if hasattr(masks, "detach"): + masks = masks.detach().cpu().numpy() + masks = np.asarray(masks) + if masks.ndim == 4: + masks = masks[:, 0] + polygons = [] + scores = [] + for mask in masks: + polygon = self._mask_to_polygon(mask > 0) + if polygon: + polygons.append(polygon) + scores.append(1.0) + results[int(out_frame_idx)] = { + "frame_index": int(out_frame_idx), + "polygons": polygons, + "scores": scores, + "object_ids": [int(obj_id) for obj_id in list(out_obj_ids)], + } + + normalized_direction = direction.lower() + if normalized_direction in {"forward", "both"}: + collect(reverse=False) + if normalized_direction in {"backward", "both"}: + collect(reverse=True) + + try: + self._video_predictor.reset_state(inference_state) + except Exception: # noqa: BLE001 + pass + return [results[index] for index in sorted(results)] + # ----------------------------------------------------------------------- # Helpers # ----------------------------------------------------------------------- @@ -292,6 +464,38 @@ class SAM2Engine: ] ] + @staticmethod + def _polygons_to_mask(polygons: list[list[list[float]]], width: int, height: int) -> np.ndarray: + import cv2 + + mask = np.zeros((height, width), dtype=np.uint8) + for polygon in polygons: + if len(polygon) < 3: + continue + pts = np.array( + [ + [ + int(round(min(max(float(x), 0.0), 1.0) * max(width - 1, 1))), + int(round(min(max(float(y), 0.0), 1.0) * max(height - 1, 1))), + ] + for x, y in polygon + ], + dtype=np.int32, + ) + cv2.fillPoly(mask, [pts], 1) + return mask.astype(bool) + + @staticmethod + def _bbox_to_mask(bbox: list[float], width: int, height: int) -> np.ndarray: + x, y, w, h = [min(max(float(value), 0.0), 1.0) for value in bbox] + left = int(round(x * max(width - 1, 1))) + top = int(round(y * max(height - 1, 1))) + right = int(round(min(x + w, 1.0) * max(width - 1, 1))) + bottom = int(round(min(y + h, 1.0) * max(height - 1, 1))) + mask = np.zeros((height, width), dtype=bool) + mask[top:max(bottom + 1, top + 1), left:max(right + 1, left + 1)] = True + return mask + # Singleton instance sam_engine = SAM2Engine() diff --git a/backend/services/sam3_engine.py b/backend/services/sam3_engine.py index 8213be7..25d7741 100644 --- a/backend/services/sam3_engine.py +++ b/backend/services/sam3_engine.py @@ -56,8 +56,22 @@ class SAM3Engine: def _gpu_ok(self) -> bool: return bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available()) + def _checkpoint_path(self) -> str | None: + path = settings.sam3_checkpoint_path.strip() + return path if path else None + + def _checkpoint_exists(self) -> bool: + path = self._checkpoint_path() + return bool(path and os.path.isfile(path)) + def _can_load(self) -> bool: - return bool(SAM3_PACKAGE_AVAILABLE and TORCH_AVAILABLE and self._python_ok() and self._gpu_ok()) + return bool( + SAM3_PACKAGE_AVAILABLE + and TORCH_AVAILABLE + and self._python_ok() + and self._gpu_ok() + and self._checkpoint_exists() + ) def _worker_path(self) -> Path: return Path(__file__).with_name("sam3_external_worker.py") @@ -98,6 +112,8 @@ class SAM3Engine: try: env = os.environ.copy() env["SAM3_MODEL_VERSION"] = settings.sam3_model_version + if self._checkpoint_path(): + env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or "" completed = subprocess.run( [settings.sam3_external_python, str(self._worker_path()), "--status"], capture_output=True, @@ -146,7 +162,10 @@ class SAM3Engine: from sam3.model.sam3_image_processor import Sam3Processor from sam3.model_builder import build_sam3_image_model - self._model = build_sam3_image_model() + self._model = build_sam3_image_model( + checkpoint_path=self._checkpoint_path(), + load_from_HF=False, + ) self._processor = Sam3Processor(self._model) self._model_loaded = True self._last_error = None @@ -170,6 +189,8 @@ class SAM3Engine: missing.append("PyTorch") if not self._gpu_ok(): missing.append("CUDA GPU") + if not self._checkpoint_exists(): + missing.append(f"local checkpoint ({settings.sam3_checkpoint_path})") if missing: return f"SAM 3 unavailable: missing {', '.join(missing)}." return "SAM 3 dependencies are present; model will load on first inference." @@ -182,7 +203,7 @@ class SAM3Engine: if self._processor is not None: message = "SAM 3 model loaded and ready." elif external_ready: - message = "SAM 3 external runtime is ready; model will load in the helper process on inference." + message = "SAM 3 external runtime is ready; local checkpoint will load in the helper process on inference." elif external_status.get("message") and not self._can_load(): message = str(external_status["message"]) return { @@ -191,11 +212,11 @@ class SAM3Engine: "available": available, "loaded": self._processor is not None, "device": "cuda" if self._gpu_ok() else str(external_status.get("device", "unavailable")), - "supports": ["semantic"], + "supports": ["semantic", "box", "video_track"], "message": message, "package_available": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("package_available")), - "checkpoint_exists": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("checkpoint_access")), - "checkpoint_path": f"official/HuggingFace ({settings.sam3_model_version})", + "checkpoint_exists": bool(self._checkpoint_exists() or external_status.get("checkpoint_access")), + "checkpoint_path": self._checkpoint_path() or f"official/HuggingFace ({settings.sam3_model_version})", "python_ok": bool(self._python_ok() or external_status.get("python_ok")), "torch_ok": bool(TORCH_AVAILABLE or external_status.get("torch_ok")), "cuda_required": True, @@ -203,7 +224,43 @@ class SAM3Engine: "external_python": settings.sam3_external_python if settings.sam3_external_enabled else None, } - def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]: + def _xyxy_to_cxcywh(self, box: list[float]) -> list[float]: + if len(box) != 4: + raise ValueError("SAM 3 box prompt requires [x1, y1, x2, y2].") + x1, y1, x2, y2 = [min(max(float(value), 0.0), 1.0) for value in box] + left, right = sorted([x1, x2]) + top, bottom = sorted([y1, y2]) + width = max(right - left, 1e-6) + height = max(bottom - top, 1e-6) + return [left + width / 2, top + height / 2, width, height] + + def _prediction_to_polygons(self, output: Any) -> tuple[list[list[list[float]]], list[float]]: + masks = output.get("masks", []) + scores = output.get("scores", []) + polygons = [] + for mask in masks: + if hasattr(mask, "detach"): + mask = mask.detach().cpu().numpy() + if mask.ndim == 3: + mask = mask[0] + poly = SAM2Engine._mask_to_polygon(mask) + if poly: + polygons.append(poly) + + if hasattr(scores, "detach"): + scores = scores.detach().cpu().tolist() + elif hasattr(scores, "tolist"): + scores = scores.tolist() + return polygons, list(scores) + + def _predict_external( + self, + image: np.ndarray, + prompt_type: str, + *, + text: str = "", + box: list[float] | None = None, + ) -> tuple[list[list[list[float]]], list[float]]: status = self._external_status(force=True) if not status.get("available"): raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.") @@ -217,8 +274,11 @@ class SAM3Engine: json.dumps( { "image_path": str(image_path), + "prompt_type": prompt_type, "text": text.strip(), + "box": box, "model_version": settings.sam3_model_version, + "checkpoint_path": self._checkpoint_path(), "confidence_threshold": settings.sam3_confidence_threshold, }, ensure_ascii=False, @@ -227,6 +287,8 @@ class SAM3Engine: ) env = os.environ.copy() env["SAM3_MODEL_VERSION"] = settings.sam3_model_version + if self._checkpoint_path(): + env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or "" completed = subprocess.run( [settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)], capture_output=True, @@ -250,6 +312,72 @@ class SAM3Engine: raise RuntimeError(str(payload["error"])) return payload.get("polygons", []), payload.get("scores", []) + def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]: + return self._predict_external(image, "semantic", text=text) + + def _predict_box_external(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]: + return self._predict_external(image, "box", box=box) + + def _propagate_video_external( + self, + frame_paths: list[str], + source_frame_index: int, + seed: dict[str, Any], + direction: str, + max_frames: int | None, + ) -> list[dict[str, Any]]: + status = self._external_status(force=True) + if not status.get("available"): + raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.") + if not frame_paths: + return [] + + with tempfile.TemporaryDirectory(prefix="sam3_video_") as tmpdir: + request_path = Path(tmpdir) / "request.json" + request_path.write_text( + json.dumps( + { + "prompt_type": "video_track", + "frame_dir": str(Path(frame_paths[0]).parent), + "source_frame_index": source_frame_index, + "seed": seed, + "direction": direction, + "max_frames": max_frames, + "model_version": settings.sam3_model_version, + "checkpoint_path": self._checkpoint_path(), + "confidence_threshold": settings.sam3_confidence_threshold, + }, + ensure_ascii=False, + ), + encoding="utf-8", + ) + env = os.environ.copy() + env["SAM3_MODEL_VERSION"] = settings.sam3_model_version + if self._checkpoint_path(): + env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or "" + completed = subprocess.run( + [settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)], + capture_output=True, + text=True, + timeout=settings.sam3_timeout_seconds, + check=False, + env=env, + ) + + if completed.returncode != 0: + detail = completed.stderr.strip() or completed.stdout.strip() + try: + parsed = json.loads(detail) + detail = parsed.get("error", detail) + except Exception: # noqa: BLE001 + pass + raise RuntimeError(f"SAM 3 external video tracking failed: {detail}") + + payload = json.loads(completed.stdout) + if payload.get("error"): + raise RuntimeError(str(payload["error"])) + return payload.get("frames", []) + def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]: if not text.strip(): raise ValueError("SAM 3 semantic prompt requires non-empty text.") @@ -263,29 +391,37 @@ class SAM3Engine: state = self._processor.set_image(pil_image) output = self._processor.set_text_prompt(state=state, prompt=text.strip()) - masks = output.get("masks", []) - scores = output.get("scores", []) - polygons = [] - for mask in masks: - if hasattr(mask, "detach"): - mask = mask.detach().cpu().numpy() - if mask.ndim == 3: - mask = mask[0] - poly = SAM2Engine._mask_to_polygon(mask) - if poly: - polygons.append(poly) - - if hasattr(scores, "detach"): - scores = scores.detach().cpu().tolist() - elif hasattr(scores, "tolist"): - scores = scores.tolist() - return polygons, list(scores) + return self._prediction_to_polygons(output) def predict_points(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]: raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for point prompts.") - def predict_box(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]: - raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for box prompts.") + def predict_box(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]: + if not self._can_load() and self._external_status().get("available"): + return self._predict_box_external(image, box) + if not self._ensure_ready(): + raise RuntimeError(self.status()["message"]) + + pil_image = Image.fromarray(image) + with torch.inference_mode(): # type: ignore[union-attr] + state = self._processor.set_image(pil_image) + output = self._processor.add_geometric_prompt( + state=state, + box=self._xyxy_to_cxcywh(box), + label=True, + ) + + return self._prediction_to_polygons(output) + + def propagate_video( + self, + frame_paths: list[str], + source_frame_index: int, + seed: dict[str, Any], + direction: str = "forward", + max_frames: int | None = None, + ) -> list[dict[str, Any]]: + return self._propagate_video_external(frame_paths, source_frame_index, seed, direction, max_frames) sam3_engine = SAM3Engine() diff --git a/backend/services/sam3_external_worker.py b/backend/services/sam3_external_worker.py index 7f4e614..9e3a64d 100644 --- a/backend/services/sam3_external_worker.py +++ b/backend/services/sam3_external_worker.py @@ -43,6 +43,13 @@ def _compact_error(exc: Exception) -> str: def _checkpoint_access(model_version: str) -> tuple[bool, str | None]: + checkpoint_path = os.environ.get("SAM3_CHECKPOINT_PATH", "").strip() + if checkpoint_path: + path = Path(checkpoint_path) + if path.is_file(): + return True, None + return False, f"local checkpoint not found: {checkpoint_path}" + try: from huggingface_hub import hf_hub_download @@ -55,6 +62,7 @@ def _checkpoint_access(model_version: str) -> tuple[bool, str | None]: def runtime_status() -> dict[str, Any]: model_version = os.environ.get("SAM3_MODEL_VERSION", "sam3") + checkpoint_path = os.environ.get("SAM3_CHECKPOINT_PATH", "").strip() or None package_error = None package_available = importlib.util.find_spec("sam3") is not None if package_available: @@ -85,6 +93,7 @@ def runtime_status() -> dict[str, Any]: "available": available, "package_available": package_available, "checkpoint_access": checkpoint_access, + "checkpoint_path": checkpoint_path or f"official/HuggingFace ({model_version})", "python_ok": python_ok, "torch_ok": torch_version is not None, "torch_version": torch_version, @@ -118,34 +127,67 @@ def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]: def _to_numpy(value: Any) -> np.ndarray: if hasattr(value, "detach"): - value = value.detach().cpu().numpy() - elif hasattr(value, "cpu"): + value = value.detach() + if hasattr(value, "is_floating_point") and value.is_floating_point(): + value = value.float() value = value.cpu().numpy() + elif hasattr(value, "cpu"): + value = value.cpu() + if hasattr(value, "is_floating_point") and value.is_floating_point(): + value = value.float() + value = value.numpy() return np.asarray(value) -def predict(request_path: Path) -> dict[str, Any]: - import torch - from sam3.model.sam3_image_processor import Sam3Processor - from sam3.model_builder import build_sam3_image_model +def _xyxy_to_cxcywh(box: list[float]) -> list[float]: + if len(box) != 4: + raise ValueError("SAM 3 box prompt requires [x1, y1, x2, y2].") + x1, y1, x2, y2 = [min(max(float(value), 0.0), 1.0) for value in box] + left, right = sorted([x1, x2]) + top, bottom = sorted([y1, y2]) + width = max(right - left, 1e-6) + height = max(bottom - top, 1e-6) + return [left + width / 2, top + height / 2, width, height] - payload = json.loads(request_path.read_text(encoding="utf-8")) - image_path = Path(payload["image_path"]) - text = str(payload["text"]).strip() - threshold = float(payload.get("confidence_threshold", 0.5)) - if not text: - raise ValueError("SAM 3 semantic prompt requires non-empty text.") - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True +def _bbox_from_seed(seed: dict[str, Any]) -> list[float]: + bbox = seed.get("bbox") + if isinstance(bbox, list) and len(bbox) == 4: + return [min(max(float(value), 0.0), 1.0) for value in bbox] - image = Image.open(image_path).convert("RGB") - with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): - model = build_sam3_image_model() - processor = Sam3Processor(model, confidence_threshold=threshold) - state = processor.set_image(image) - output = processor.set_text_prompt(state=state, prompt=text) + polygons = seed.get("polygons") or [] + points = [point for polygon in polygons for point in polygon if len(point) >= 2] + if not points: + raise ValueError("SAM 3 video tracking requires seed bbox or polygons.") + xs = [min(max(float(point[0]), 0.0), 1.0) for point in points] + ys = [min(max(float(point[1]), 0.0), 1.0) for point in points] + left, right = min(xs), max(xs) + top, bottom = min(ys), max(ys) + return [left, top, max(right - left, 1e-6), max(bottom - top, 1e-6)] + +def _video_outputs_to_response(outputs: dict[str, Any]) -> dict[str, Any]: + masks = _to_numpy(outputs.get("out_binary_masks", [])) + scores = _to_numpy(outputs.get("out_probs", [])) + obj_ids = _to_numpy(outputs.get("out_obj_ids", [])) + if masks.ndim == 4: + masks = masks[:, 0] + elif masks.ndim == 2: + masks = masks[None, ...] + + polygons = [] + out_scores = [] + out_ids = [] + for index, mask in enumerate(masks): + polygon = _mask_to_polygon(mask) + if polygon: + polygons.append(polygon) + out_scores.append(float(scores[index]) if scores.size > index else 1.0) + out_ids.append(int(obj_ids[index]) if obj_ids.size > index else index + 1) + return {"polygons": polygons, "scores": out_scores, "object_ids": out_ids} + + +def _prediction_to_response(output: dict[str, Any]) -> dict[str, Any]: masks = _to_numpy(output.get("masks", [])) scores = _to_numpy(output.get("scores", [])) if masks.ndim == 4: @@ -165,6 +207,115 @@ def predict(request_path: Path) -> dict[str, Any]: } +def predict_video(request_path: Path) -> dict[str, Any]: + import torch + from sam3.model_builder import build_sam3_video_predictor + + payload = json.loads(request_path.read_text(encoding="utf-8")) + frame_dir = Path(payload["frame_dir"]) + source_frame_index = int(payload.get("source_frame_index", 0)) + seed = payload.get("seed") or {} + direction = str(payload.get("direction") or "forward").lower() + max_frames = payload.get("max_frames") + max_frames = int(max_frames) if max_frames else None + checkpoint_path = str(payload.get("checkpoint_path") or os.environ.get("SAM3_CHECKPOINT_PATH", "")).strip() + threshold = float(payload.get("confidence_threshold", 0.5)) + if direction not in {"forward", "backward", "both"}: + raise ValueError(f"Unsupported propagation direction: {direction}") + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + predictor = build_sam3_video_predictor( + checkpoint_path=checkpoint_path or None, + async_loading_frames=False, + ) + session_id = None + try: + with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + session = predictor.handle_request( + { + "type": "start_session", + "resource_path": str(frame_dir), + "offload_video_to_cpu": True, + "offload_state_to_cpu": True, + } + ) + session_id = session["session_id"] + predictor.handle_request( + { + "type": "add_prompt", + "session_id": session_id, + "frame_index": source_frame_index, + "bounding_boxes": [_bbox_from_seed(seed)], + "bounding_box_labels": [1], + "output_prob_thresh": threshold, + "rel_coordinates": True, + } + ) + frames = [] + for item in predictor.handle_stream_request( + { + "type": "propagate_in_video", + "session_id": session_id, + "propagation_direction": direction, + "start_frame_index": source_frame_index, + "max_frame_num_to_track": max_frames, + "output_prob_thresh": threshold, + } + ): + frame_response = _video_outputs_to_response(item.get("outputs") or {}) + frame_response["frame_index"] = int(item["frame_index"]) + frames.append(frame_response) + finally: + if session_id: + predictor.handle_request({"type": "close_session", "session_id": session_id}) + + return {"frames": frames} + + +def predict(request_path: Path) -> dict[str, Any]: + import torch + from sam3.model.sam3_image_processor import Sam3Processor + from sam3.model_builder import build_sam3_image_model + + payload = json.loads(request_path.read_text(encoding="utf-8")) + if str(payload.get("prompt_type") or "").strip().lower() == "video_track": + return predict_video(request_path) + + image_path = Path(payload["image_path"]) + prompt_type = str(payload.get("prompt_type") or "semantic").strip().lower() + text = str(payload.get("text") or "").strip() + threshold = float(payload.get("confidence_threshold", 0.5)) + checkpoint_path = str(payload.get("checkpoint_path") or os.environ.get("SAM3_CHECKPOINT_PATH", "")).strip() + if prompt_type == "semantic" and not text: + raise ValueError("SAM 3 semantic prompt requires non-empty text.") + if prompt_type not in {"semantic", "box"}: + raise ValueError(f"Unsupported SAM 3 prompt type: {prompt_type}") + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + image = Image.open(image_path).convert("RGB") + with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + model = build_sam3_image_model( + checkpoint_path=checkpoint_path or None, + load_from_HF=not bool(checkpoint_path), + ) + processor = Sam3Processor(model, confidence_threshold=threshold) + state = processor.set_image(image) + if prompt_type == "box": + output = processor.add_geometric_prompt( + state=state, + box=_xyxy_to_cxcywh(payload.get("box") or []), + label=True, + ) + else: + output = processor.set_text_prompt(state=state, prompt=text) + + return _prediction_to_response(output) + + def main() -> int: parser = argparse.ArgumentParser(description="SAM 3 external runtime helper") parser.add_argument("--status", action="store_true") diff --git a/backend/services/sam_registry.py b/backend/services/sam_registry.py index 8907991..57e16ea 100644 --- a/backend/services/sam_registry.py +++ b/backend/services/sam_registry.py @@ -67,6 +67,19 @@ class SAMRegistry: def predict_box(self, model_id: str | None, image: Any, box: list[float]): return self._ensure_available(model_id).predict_box(image, box) + def predict_interactive( + self, + model_id: str | None, + image: Any, + box: list[float] | None, + points: list[list[float]], + labels: list[int], + ): + model = self.normalize_model_id(model_id) + if model != "sam2": + raise NotImplementedError("Interactive box + point refinement is currently supported by SAM 2.") + return self._ensure_available(model).predict_interactive(image, box, points, labels) + def predict_auto(self, model_id: str | None, image: Any): return self._ensure_available(model_id).predict_auto(image) @@ -76,5 +89,22 @@ class SAMRegistry: return self._ensure_available(model).predict_semantic(image, text) return self._ensure_available(model).predict_auto(image) + def propagate_video( + self, + model_id: str | None, + frame_paths: list[str], + source_frame_index: int, + seed: dict[str, Any], + direction: str, + max_frames: int | None, + ): + return self._ensure_available(model_id).propagate_video( + frame_paths, + source_frame_index, + seed, + direction=direction, + max_frames=max_frames, + ) + sam_registry = SAMRegistry() diff --git a/backend/tests/test_ai.py b/backend/tests/test_ai.py index b31f410..5a02d15 100644 --- a/backend/tests/test_ai.py +++ b/backend/tests/test_ai.py @@ -116,6 +116,44 @@ def test_predict_box_and_semantic_fallback(client, monkeypatch): assert semantic_response.json()["scores"] == [0.5] +def test_predict_interactive_combines_box_and_points(client, monkeypatch): + _, frame, _ = _create_project_and_frame(client) + calls = {} + monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8)) + + def fake_predict_interactive(model, image, box, points, labels): + calls["model"] = model + calls["box"] = box + calls["points"] = points + calls["labels"] = labels + return ( + [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]], + [0.88], + ) + + monkeypatch.setattr("routers.ai.sam_registry.predict_interactive", fake_predict_interactive) + + response = client.post("/api/ai/predict", json={ + "image_id": frame["id"], + "prompt_type": "interactive", + "prompt_data": { + "box": [0.1, 0.1, 0.9, 0.9], + "points": [[0.5, 0.5], [0.2, 0.2]], + "labels": [1, 0], + }, + "model": "sam2", + }) + + assert response.status_code == 200 + assert response.json()["scores"] == [0.88] + assert calls == { + "model": "sam2", + "box": [0.1, 0.1, 0.9, 0.9], + "points": [[0.5, 0.5], [0.2, 0.2]], + "labels": [1, 0], + } + + def test_model_status_reports_runtime(client, monkeypatch): monkeypatch.setattr("routers.ai.sam_registry.runtime_status", lambda selected_model=None: { "selected_model": selected_model or "sam2", @@ -170,6 +208,80 @@ def test_model_status_reports_runtime(client, monkeypatch): assert body["models"][1]["available"] is False +def test_propagate_saves_tracked_annotations(client, monkeypatch): + project = client.post("/api/projects", json={"name": "Video Project"}).json() + frames = [ + client.post(f"/api/projects/{project['id']}/frames", json={ + "project_id": project["id"], + "frame_index": idx, + "image_url": f"frames/{idx}.jpg", + "width": 640, + "height": 360, + }).json() + for idx in range(3) + ] + calls = {} + monkeypatch.setattr("routers.ai.download_file", lambda object_name: b"jpeg") + + def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames): + calls["model"] = model + calls["source_frame_index"] = source_frame_index + calls["seed"] = seed + calls["direction"] = direction + calls["max_frames"] = max_frames + calls["frame_count"] = len(frame_paths) + return [ + { + "frame_index": 0, + "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]], + "scores": [0.9], + "object_ids": [1], + }, + { + "frame_index": 1, + "polygons": [[[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]], + "scores": [0.8], + "object_ids": [1], + }, + ] + + monkeypatch.setattr("routers.ai.sam_registry.propagate_video", fake_propagate_video) + + response = client.post("/api/ai/propagate", json={ + "project_id": project["id"], + "frame_id": frames[0]["id"], + "model": "sam2", + "direction": "forward", + "max_frames": 2, + "include_source": False, + "seed": { + "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]], + "bbox": [0.1, 0.1, 0.1, 0.1], + "label": "胆囊", + "color": "#ff0000", + "class_metadata": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20}, + "template_id": None, + }, + }) + + assert response.status_code == 200 + body = response.json() + assert body["created_annotation_count"] == 1 + assert body["processed_frame_count"] == 2 + assert calls["model"] == "sam2" + assert calls["source_frame_index"] == 0 + assert calls["direction"] == "forward" + assert calls["frame_count"] == 2 + saved = body["annotations"][0] + assert saved["frame_id"] == frames[1]["id"] + assert saved["mask_data"]["source"] == "sam2_propagation" + assert saved["mask_data"]["class"]["name"] == "胆囊" + assert saved["mask_data"]["score"] == 0.8 + + listing = client.get(f"/api/ai/annotations?project_id={project['id']}") + assert len(listing.json()) == 1 + + def test_predict_validation_errors(client, monkeypatch): project, _, _ = _create_project_and_frame(client) diff --git a/backend/tests/test_media.py b/backend/tests/test_media.py index 685b182..f913c63 100644 --- a/backend/tests/test_media.py +++ b/backend/tests/test_media.py @@ -84,6 +84,12 @@ def test_parse_media_queues_background_task(client, monkeypatch): assert data["progress"] == 0 assert data["project_id"] == project["id"] assert data["celery_task_id"] == "celery-1" + assert data["payload"] == { + "source_type": "video", + "parse_fps": 5.0, + "max_frames": None, + "target_width": 640, + } assert queued == [data["id"]] assert published == [data["id"]] @@ -94,6 +100,35 @@ def test_parse_media_queues_background_task(client, monkeypatch): assert project_detail["status"] == "parsing" +def test_parse_media_accepts_frame_sequence_options(client, monkeypatch): + project = client.post("/api/projects", json={ + "name": "Parse Options", + "video_path": "uploads/1/clip.mp4", + "source_type": "video", + "parse_fps": 30, + }).json() + + class FakeAsyncResult: + id = "celery-options" + + monkeypatch.setattr("routers.media.parse_project_media.delay", lambda task_id: FakeAsyncResult()) + monkeypatch.setattr("routers.media.publish_task_progress_event", lambda task: None) + + response = client.post( + f"/api/media/parse?project_id={project['id']}&parse_fps=15&max_frames=120&target_width=960" + ) + + assert response.status_code == 202 + data = response.json() + assert data["payload"] == { + "source_type": "video", + "parse_fps": 15.0, + "max_frames": 120, + "target_width": 960, + } + assert client.get(f"/api/projects/{project['id']}").json()["parse_fps"] == 15.0 + + def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp_path): from models import ProcessingTask from services.media_task_runner import run_parse_media_task @@ -118,10 +153,14 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp frame_file.write_bytes(b"fake image") monkeypatch.setattr("services.media_task_runner.download_file", lambda object_name: b"video") - monkeypatch.setattr("services.media_task_runner.parse_video", lambda local_path, output_dir, fps: ([str(frame_file)], 25.0)) + monkeypatch.setattr( + "services.media_task_runner.parse_video", + lambda local_path, output_dir, fps, max_frames=None, target_width=640: ([str(frame_file)], 25.0), + ) monkeypatch.setattr("services.media_task_runner.extract_thumbnail", lambda local_path, thumbnail_path: open(thumbnail_path, "wb").write(b"thumb")) monkeypatch.setattr("services.media_task_runner.upload_file", lambda *args, **kwargs: None) monkeypatch.setattr("services.media_task_runner.upload_frames_to_minio", lambda frame_files, project_id: [f"projects/{project_id}/frames/frame_000001.jpg"]) + monkeypatch.setattr("routers.projects.get_presigned_url", lambda object_name, expires=3600: f"http://storage/{object_name}") published = [] monkeypatch.setattr( "services.media_task_runner.publish_task_progress_event", @@ -131,6 +170,17 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp result = run_parse_media_task(db_session, task.id) assert result["frames_extracted"] == 1 + assert result["frame_sequence"] == { + "original_fps": 25.0, + "parse_fps": 5.0, + "frame_count": 1, + "duration_ms": 0.0, + "target_width": 640, + "frame_width": None, + "frame_height": None, + "max_frames": None, + "object_prefix": f"projects/{project['id']}/frames", + } db_session.refresh(task) assert task.status == "success" assert task.progress == 100 @@ -140,6 +190,8 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp assert project_detail["status"] == "ready" frames = client.get(f"/api/projects/{project['id']}/frames").json() assert "frame_000001.jpg" in frames[0]["image_url"] + assert frames[0]["timestamp_ms"] == 0.0 + assert frames[0]["source_frame_number"] == 0 def test_parse_task_runner_skips_already_cancelled_task(db_session): diff --git a/backend/tests/test_sam3_engine.py b/backend/tests/test_sam3_engine.py index e114599..3ea1303 100644 --- a/backend/tests/test_sam3_engine.py +++ b/backend/tests/test_sam3_engine.py @@ -4,6 +4,7 @@ from pathlib import Path import numpy as np from services.sam3_engine import SAM3Engine +from services.sam3_external_worker import _to_numpy class _Completed: @@ -14,6 +15,8 @@ class _Completed: def _external_settings(monkeypatch, python_path: Path): + checkpoint_path = python_path.with_name("sam3.pt") + checkpoint_path.write_bytes(b"checkpoint") python_path.write_text("#!/usr/bin/env python\n", encoding="utf-8") python_path.chmod(0o755) monkeypatch.setattr("services.sam3_engine.SAM3_PACKAGE_AVAILABLE", False) @@ -23,6 +26,7 @@ def _external_settings(monkeypatch, python_path: Path): monkeypatch.setattr("services.sam3_engine.settings.sam3_timeout_seconds", 10) monkeypatch.setattr("services.sam3_engine.settings.sam3_status_cache_seconds", 30) monkeypatch.setattr("services.sam3_engine.settings.sam3_confidence_threshold", 0.4) + monkeypatch.setattr("services.sam3_engine.settings.sam3_checkpoint_path", str(checkpoint_path)) def test_sam3_status_reports_external_runtime_ready(tmp_path, monkeypatch): @@ -30,9 +34,12 @@ def test_sam3_status_reports_external_runtime_ready(tmp_path, monkeypatch): def fake_run(args, **_kwargs): assert "--status" in args + assert _kwargs["env"]["SAM3_CHECKPOINT_PATH"].endswith("sam3.pt") return _Completed(stdout=json.dumps({ "available": True, "package_available": True, + "checkpoint_access": True, + "checkpoint_path": _kwargs["env"]["SAM3_CHECKPOINT_PATH"], "python_ok": True, "torch_ok": True, "cuda_available": True, @@ -48,7 +55,10 @@ def test_sam3_status_reports_external_runtime_ready(tmp_path, monkeypatch): assert status["external_available"] is True assert status["package_available"] is True assert status["python_ok"] is True - assert status["message"] == "SAM 3 external runtime is ready; model will load in the helper process on inference." + assert status["checkpoint_exists"] is True + assert status["checkpoint_path"].endswith("sam3.pt") + assert status["supports"] == ["semantic", "box", "video_track"] + assert status["message"] == "SAM 3 external runtime is ready; local checkpoint will load in the helper process on inference." def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch): @@ -61,6 +71,7 @@ def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch): return _Completed(stdout=json.dumps({ "available": True, "package_available": True, + "checkpoint_access": True, "python_ok": True, "torch_ok": True, "cuda_available": True, @@ -71,6 +82,7 @@ def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch): request = json.loads(request_path.read_text(encoding="utf-8")) assert request["text"] == "vessel" assert request["confidence_threshold"] == 0.4 + assert request["checkpoint_path"].endswith("sam3.pt") assert Path(request["image_path"]).exists() return _Completed(stdout=json.dumps({ "polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]], @@ -86,6 +98,97 @@ def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch): assert any("--request" in args for args in calls) +def test_sam3_predict_box_uses_external_worker(tmp_path, monkeypatch): + _external_settings(monkeypatch, tmp_path / "python") + + def fake_run(args, **_kwargs): + if "--status" in args: + return _Completed(stdout=json.dumps({ + "available": True, + "package_available": True, + "checkpoint_access": True, + "python_ok": True, + "torch_ok": True, + "cuda_available": True, + "device": "cuda", + "message": "ready", + })) + request_path = Path(args[-1]) + request = json.loads(request_path.read_text(encoding="utf-8")) + assert request["prompt_type"] == "box" + assert request["box"] == [0.1, 0.2, 0.7, 0.8] + assert request["text"] == "" + return _Completed(stdout=json.dumps({ + "polygons": [[[0.1, 0.2], [0.7, 0.2], [0.7, 0.8]]], + "scores": [0.88], + })) + + monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run) + + polygons, scores = SAM3Engine().predict_box( + np.zeros((8, 8, 3), dtype=np.uint8), + [0.1, 0.2, 0.7, 0.8], + ) + + assert polygons == [[[0.1, 0.2], [0.7, 0.2], [0.7, 0.8]]] + assert scores == [0.88] + + +def test_sam3_propagate_video_uses_external_worker(tmp_path, monkeypatch): + _external_settings(monkeypatch, tmp_path / "python") + frame_dir = tmp_path / "frames" + frame_dir.mkdir() + frame_paths = [] + for index in range(2): + frame_path = frame_dir / f"frame_{index:06d}.jpg" + frame_path.write_bytes(b"jpeg") + frame_paths.append(str(frame_path)) + + def fake_run(args, **_kwargs): + if "--status" in args: + return _Completed(stdout=json.dumps({ + "available": True, + "package_available": True, + "checkpoint_access": True, + "python_ok": True, + "torch_ok": True, + "cuda_available": True, + "device": "cuda", + "message": "ready", + })) + request_path = Path(args[-1]) + request = json.loads(request_path.read_text(encoding="utf-8")) + assert request["prompt_type"] == "video_track" + assert request["frame_dir"] == str(frame_dir) + assert request["source_frame_index"] == 0 + assert request["direction"] == "forward" + assert request["max_frames"] == 2 + assert request["seed"]["bbox"] == [0.1, 0.1, 0.2, 0.2] + return _Completed(stdout=json.dumps({ + "frames": [ + { + "frame_index": 1, + "polygons": [[[0.2, 0.2], [0.4, 0.2], [0.4, 0.4]]], + "scores": [0.7], + "object_ids": [1], + } + ] + })) + + monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run) + + frames = SAM3Engine().propagate_video( + frame_paths, + 0, + {"bbox": [0.1, 0.1, 0.2, 0.2]}, + direction="forward", + max_frames=2, + ) + + assert frames[0]["frame_index"] == 1 + assert frames[0]["scores"] == [0.7] + + def test_sam3_predict_semantic_reports_external_errors(tmp_path, monkeypatch): _external_settings(monkeypatch, tmp_path / "python") @@ -94,6 +197,7 @@ def test_sam3_predict_semantic_reports_external_errors(tmp_path, monkeypatch): return _Completed(stdout=json.dumps({ "available": True, "package_available": True, + "checkpoint_access": True, "python_ok": True, "torch_ok": True, "cuda_available": True, @@ -110,3 +214,32 @@ def test_sam3_predict_semantic_reports_external_errors(tmp_path, monkeypatch): assert "HF access denied" in str(exc) else: raise AssertionError("Expected SAM 3 external inference failure.") + + +def test_sam3_worker_casts_floating_tensors_before_numpy(): + class FakeTensor: + def __init__(self): + self.float_called = False + + def detach(self): + return self + + def is_floating_point(self): + return True + + def float(self): + self.float_called = True + return self + + def cpu(self): + return self + + def numpy(self): + return np.array([1.0], dtype=np.float32) + + tensor = FakeTensor() + + result = _to_numpy(tensor) + + assert tensor.float_called is True + assert result.tolist() == [1.0] diff --git a/doc/01-purpose-and-word-summary.md b/doc/01-purpose-and-word-summary.md index d5fc3bb..372518d 100644 --- a/doc/01-purpose-and-word-summary.md +++ b/doc/01-purpose-and-word-summary.md @@ -38,14 +38,14 @@ Word 方案描述的理想系统包含: | 视频拆帧 | 已落地 | `backend/services/frame_parser.py`、`backend/routers/media.py` | | DICOM 批量导入 | 部分落地 | 上传和解析存在,项目级体验还需完善 | | WebSocket 进度 | 已落地 | 拆帧进度写入任务表后发布到 Redis `seg:progress`,FastAPI 广播到 `/ws/progress` | -| SAM 推理 | 部分落地 | 后端已有 SAM 2 / SAM 3 选择和真实模型状态接口;SAM 3 通过独立 Python 3.12 环境桥接,状态会检查 Python/CUDA/包/HF gated 权重访问 | +| SAM 推理 | 部分落地 | 后端已有 SAM 2 / SAM 3 选择和真实模型状态接口;SAM 2 已接 video predictor 片段传播;SAM 3 通过独立 Python 3.12 环境桥接,支持文本/框提示和 official video tracker 入口,状态会检查 Python/CUDA/包/本地 checkpoint | | 模板库 | 部分落地 | 分类、颜色、z-index 能存储和编辑;PNG mask 导出时会按 zIndex 做语义融合裁决,前端预览裁决尚未落地 | | 标注持久化 | 部分落地 | 后端有 `Annotation` 表,前端已接入新增、回显、分类更新、当前帧删除、手工绘制、GT mask 导入、seed point 编辑、polygon 顶点拖动/删除、边中点插点和多 polygon 子区域编辑;复杂洞结构编辑未落地 | | COCO / Mask 导出 | 已落地基础能力 | `backend/routers/export.py`;COCO JSON 和 PNG mask ZIP 前端按钮均已接入,ZIP 包含单标注 mask、语义融合 mask 和类别映射 | ## 当前代码尚未落地的目标 -- SAM 3:当前已提供 `sam3_engine.py` 外部环境桥接、`sam3_external_worker.py` 和 `setup_sam3_env.sh`;本机 `sam3` 环境已满足 Python 3.12、PyTorch 2.10/cu128、CUDA/GPU 和官方包导入。真实推理仍取决于 Hugging Face `facebook/sam3` gated 权重访问授权;官方没有 SAM 3 tiny/small 权重,当前可选最小真实 SAM 权重仍是 SAM 2 tiny。 +- SAM 3:当前已提供 `sam3_engine.py` 外部环境桥接、`sam3_external_worker.py` 和 `setup_sam3_env.sh`;本机 `sam3` 环境已满足 Python 3.12、PyTorch 2.10/cu128、CUDA/GPU、官方包导入和本地 `sam3权重/sam3.pt` checkpoint 状态检查。官方没有 SAM 3 tiny/small 权重,当前可选最小真实 SAM 权重仍是 SAM 2 tiny;video tracker 入口已接入,真实效果取决于本地 checkpoint 是否兼容 video model。 - Celery 异步任务队列:已注册 Celery app 和拆帧 worker task,`/api/media/parse` 会创建任务表记录并入队。 - GT mask 导入:当前已支持二值/多类别 mask 导入,后端会按非零像素值拆分区域,生成 polygon 标注和距离变换 seed point;骨架提取、HDBSCAN 和模板自动映射尚未实现。 - Mask 到点区域的拓扑降维:当前完成 distance transform seed point 和前端 seed point 拖拽编辑;骨架提取、HDBSCAN 等增强尚未实现。 @@ -55,4 +55,4 @@ Word 方案描述的理想系统包含: ## 结论 -当前项目已经从 UI 原型推进到“可上传、可异步拆帧、可取消/重试任务、可查看失败详情、可实时查看任务进度、可浏览项目帧、可维护模板、可手工绘制、可逐点编辑 polygon、可边中点插点、可多 polygon 子区域编辑、可区域合并/去除、可点/框 AI 推理、可对点/框 prompt 做裁剪推理和背景过滤、可导入多类别 GT mask、可编辑 seed point、可保存标注、可导出 COCO/语义 mask ZIP、可查看 Dashboard 后端概览”的全栈雏形。下一阶段最重要的是继续补齐 Hugging Face SAM 3 权重授权后的真实语义文本分割 smoke test、复杂洞结构编辑和 GT mask 骨架/聚类增强。 +当前项目已经从 UI 原型推进到“可上传、可异步拆帧、可取消/重试任务、可查看失败详情、可实时查看任务进度、可浏览项目帧、可维护模板、可手工绘制、可逐点编辑 polygon、可边中点插点、可多 polygon 子区域编辑、可区域合并/去除、可点/框 AI 推理、可对点/框 prompt 做裁剪推理和背景过滤、可用 SAM 2 / SAM 3 进行视频片段传播、可导入多类别 GT mask、可编辑 seed point、可保存标注、可导出 COCO/语义 mask ZIP、可查看 Dashboard 后端概览”的全栈雏形。下一阶段最重要的是继续补齐 SAM 3 真实视频 tracker smoke test、复杂洞结构编辑和 GT mask 骨架/聚类增强。 diff --git a/doc/03-frontend-element-audit.md b/doc/03-frontend-element-audit.md index 662947d..edbdcac 100644 --- a/doc/03-frontend-element-audit.md +++ b/doc/03-frontend-element-audit.md @@ -65,6 +65,7 @@ | “导出 JSON 标注集”按钮 | 真实可用 | 导出前会保存未归档 mask,然后调用 `exportCoco()` 下载 JSON | | “导出 PNG Mask ZIP”按钮 | 真实可用 | 导出前会保存未归档 mask,然后调用 `GET /api/export/{project_id}/masks` 下载 ZIP;后端同时包含单标注 mask、每帧语义融合 mask 和 `semantic_classes.json` | | “导入 GT Mask”按钮 | 真实可用 | 选择图片后调用 `POST /api/ai/import-gt-mask`,后端按非零像素值和连通域生成 polygon 标注与距离变换 seed point,再回显到工作区 | +| “传播片段”按钮 | 真实可用 | 使用当前选中 mask 或当前帧第一个 mask 作为 seed,调用 `POST /api/ai/propagate`;SAM 2 走 video predictor,SAM 3 走 external video tracker,完成后刷新已保存标注 | | “结构化归档保存”按钮 | 真实可用 | 未保存 mask 写入 `POST /api/ai/annotate`;dirty mask 写入 `PATCH /api/ai/annotations/{id}` | ## CanvasArea 画布 @@ -78,11 +79,11 @@ | 正向/反向选点 | 真实可用 | UI 能加点,并按当前帧 `frame.id` 调用 `/api/ai/predict`;结果需点击归档保存才持久化 | | 框选 | 真实可用 | UI 能画框,并把框坐标归一化后调用后端推理;结果需点击归档保存才持久化 | | AI 推理中提示 | 真实可用 | 请求期间会显示 | -| 手工多边形/矩形/圆/点/线 | 真实可用 | 多边形点击取点后 Enter 完成;矩形/圆/线拖拽生成 polygon;点工具生成小区域;均写入 `Mask.segmentation`,可归档保存 | +| 手工多边形/矩形/圆/点/线 | 真实可用 | 多边形点击取点后可按 Enter 完成,也可在三点后点击首节点闭合;矩形/圆/线拖拽生成 polygon;点工具生成小区域;绘制工具可在已有 mask 上继续落点;均写入 `Mask.segmentation`,可归档保存 | | Mask 渲染 | 真实可用 | 前端会把推理、手工绘制、GT 导入和已保存标注转成 Konva `pathData` 渲染 | -| Polygon 逐点编辑 | 真实可用 | 点击 mask 后显示 polygon 顶点;拖动顶点会重算 `pathData/segmentation/bbox/area`,已保存 mask 标为 dirty;选中顶点后 Delete/Backspace 可删点但保留至少三点 | +| Polygon 逐点编辑 / 删除 | 真实可用 | 点击 mask 后显示 polygon 顶点;拖动顶点会重算 `pathData/segmentation/bbox/area`,已保存 mask 标为 dirty;选中顶点后 Delete/Backspace 可删点但保留至少三点;选中 mask 但未选中顶点时 Delete/Backspace 删除整个 mask,已保存 mask 会同步调用后端删除 | | GT seed point 回显/编辑 | 真实可用 | 已保存标注的 `points` 会显示为黄色 seed 点;拖动后标记为 dirty,归档保存会更新后端 | -| 应用分类 | 真实可用 | 将当前选择的模板分类应用到本帧 mask;已保存 mask 会标为 dirty,归档保存时更新后端 | +| 应用分类 | 真实可用 | Canvas 右下角按钮可将当前选择的模板分类应用到本帧 mask;右侧语义分类树点击分类时会优先改当前已选 mask;已保存 mask 会标为 dirty,归档保存时更新后端 | | 清空遮罩 | 真实可用 | 工作区中会删除当前帧已保存标注并清空当前帧本地 mask | | 保存状态计数 | 真实可用 | 底部显示已保存、未保存、待更新数量 | | 当前图层树文字 | Mock / UI-only | 固定显示 `OBJECT_VEHICLE_01` | @@ -93,7 +94,7 @@ |------|------|------| | 拖拽/选择 | 真实可用 | 控制 Canvas 是否可拖拽 | | 多边形/矩形/圆/点/线 | 真实可用 | 切换 activeTool 后由 `CanvasArea` 生成可保存的 polygon mask | -| 区域合并/去除 | 真实可用 | 选择工具后点击多个 mask,使用 `polygon-clipping` 做 union / difference;合并会保留主 mask 并移除被合并 mask,去除会从主 mask 扣除后续选中 mask | +| 区域合并/去除 | 真实可用 | 选择工具后点击多个 mask,右下角显示已选数量和操作按钮;合并/去除模式会隐藏 polygon 编辑手柄,避免手柄抢占多选点击;使用 `polygon-clipping` 做 union / difference;合并会保留主 mask 并移除被合并 mask,去除会从主 mask 扣除后续选中 mask;内含扣除会保留 hole ring 并用 even-odd 规则渲染 | | 正向选点/反向选点/框选 | 部分可用 | 会影响 Canvas 交互,并能触发已对齐的 AI 推理接口 | | 魔法棒 SAM 触发 | 部分可用 | 切到 AI 页面;不是直接执行推理 | | 撤销/重做 | 真实可用 | 绑定 Zustand `maskHistory/maskFuture`,支持工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y | @@ -105,15 +106,16 @@ | 帧缩略图 | 真实可用 | 使用 `frames[].url` | | 点击缩略图跳帧 | 真实可用 | 调用 `setCurrentFrame(idx)` | | 顶部 range 拖动 | 真实可用 | 改变当前帧 | +| 具体时间显示 | 真实可用 | 根据项目 `parse_fps/original_fps` 显示当前时间和总时长,格式为 `mm:ss.cc` | | 播放/暂停 | 真实可用 | 当前代码按 `parse_fps/original_fps` 推进帧,最多 30fps | -| 方向键切帧 | Mock / UI-only | Word 提到,但当前没有键盘监听 | +| 方向键切帧 | 真实可用 | 全局监听左右方向键切到上一帧/下一帧;焦点在 input、textarea、select 或 contentEditable 内时不会拦截 | ## OntologyInspector 本体面板 | 元素 | 状态 | 说明 | |------|------|------| | 模板选择 | 部分可用 | 读取全局 templates,可切换 activeTemplateId | -| 分类树展示 | 真实可用 | 显示模板 classes 和本地 customClasses | +| 分类树展示 / 换标签 | 真实可用 | 显示模板 classes 和本地 customClasses;点击分类会设为后续新 mask 的 activeClass,如果 Canvas 已选 mask,则同步更新已选 mask 的标签、颜色和 class 元数据 | | 添加自定义分类 | 部分可用 | 只存在组件本地状态,不保存到后端 | | 置信度条 | Mock / UI-only | 固定 `0.9412` | | 拓扑锚点数量 | Mock / UI-only | 固定 `12 节点` | @@ -124,8 +126,9 @@ | 元素 | 状态 | 说明 | |------|------|------| | 模型选择 SAM2/SAM3 | 真实可用 | 选择写入 Zustand,`predictMask()` 会把 `model` 传给后端 SAM registry | -| 正向/反向点 | 部分可用 | 可在当前项目帧上加点,并可调用 AI 推理接口 | -| 语义文本输入 | 部分可用 | 纯文本会以 `semantic` prompt 调用后端;选择 SAM 3 且独立 Python 3.12 环境、CUDA、官方包和 Hugging Face gated 权重访问均满足时走 SAM 3 文本语义推理,否则状态接口会标明不可用 | +| 正向/反向点 | 真实可用 | 可在当前项目帧上加点并调用 AI 推理接口;SAM 2 框选后会携带原始框和累计正/反点细化同一个候选 mask;SAM 3 选择后会提示点交互需切回 SAM 2 | +| SAM 3 框选 | 真实可用 | 工作区选择 SAM 3 后可使用框选工具;后端通过官方 `add_geometric_prompt()` 正框执行 SAM 3 几何提示推理 | +| 语义文本输入 | 部分可用 | 纯文本会以 `semantic` prompt 调用后端;选择 SAM 3 且独立 Python 3.12 环境、CUDA、官方包和本地 checkpoint 均满足时走 SAM 3 文本语义推理,否则状态接口会标明不可用;空文本、失败和 0 mask 返回会显示前端反馈 | | 参数开关 | 真实可用 | `cropMode` 会随 `/api/ai/predict` 发送 `crop_to_prompt`,后端对点/框 prompt 裁剪推理区域并回映射 polygon;`autoDeleteBg` 会发送 `auto_filter_background` 和 `min_score`,后端过滤低分结果和覆盖负向点的结果 | | 执行高精度语义分割 | 部分可用 | 使用当前项目帧调用 `/api/ai/predict`;没有当前帧时按钮禁用 | | 上传替换底图 | Mock / UI-only | 按钮无事件 | @@ -150,6 +153,6 @@ ## 总体结论 -当前前端真实可用的主链路是:登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、拆帧、浏览帧、播放帧、工作区手工绘制、点/框 AI 推理、GT mask 导入、标注保存/回显、COCO 导出、PNG mask ZIP 导出、模板 CRUD。 +当前前端真实可用的主链路是:登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、拆帧、浏览帧、播放帧、工作区手工绘制、点/框 AI 推理、视频片段传播、GT mask 导入、标注保存/回显、COCO 导出、PNG mask ZIP 导出、模板 CRUD。 当前最主要的 Mock 或未打通链路是:polygon 插点/边编辑增强、真正的文本语义分割、骨架/HDBSCAN 级别的 mask 降维增强、任务历史筛选、项目更多菜单和若干检查面板指标。 diff --git a/doc/04-api-contracts.md b/doc/04-api-contracts.md index a263413..a45e853 100644 --- a/doc/04-api-contracts.md +++ b/doc/04-api-contracts.md @@ -32,12 +32,13 @@ Authorization: Bearer | `deleteTemplate(id)` | `DELETE /api/templates/{id}` | 对齐 | 模板编辑页使用 | | `uploadMedia(file, projectId)` | `POST /api/media/upload` | 对齐 | multipart form-data | | `uploadDicomBatch(files, projectId)` | `POST /api/media/upload/dicom` | 对齐 | multipart form-data | -| `parseMedia(projectId)` | `POST /api/media/parse?project_id=...` | 对齐 | 创建异步拆帧任务并返回 task | +| `parseMedia(projectId, options?)` | `POST /api/media/parse?project_id=...` | 对齐 | 创建异步拆帧任务并返回 task;支持 `parse_fps`、`max_frames`、`target_width` | | `getTask(taskId)` | `GET /api/tasks/{task_id}` | 对齐 | 查询异步任务状态 | | `cancelTask(taskId)` | `POST /api/tasks/{task_id}/cancel` | 对齐 | 取消 queued/running 任务,后端写 cancelled 并尝试 revoke Celery | | `retryTask(taskId)` | `POST /api/tasks/{task_id}/retry` | 对齐 | 对 failed/cancelled 任务创建新的 queued 重试任务 | -| `getProjectFrames(projectId)` | `GET /api/projects/{id}/frames` | 对齐 | 后端返回预签名 image_url | +| `getProjectFrames(projectId)` | `GET /api/projects/{id}/frames` | 对齐 | 后端返回预签名 image_url,以及 `timestamp_ms`、`source_frame_number` | | `predictMask(payload)` | `POST /api/ai/predict` | 对齐 | 前端发送 `image_id/prompt_type/prompt_data/model`,并把后端 `polygons` 转为 `masks[].pathData` | +| `propagateMasks(payload)` | `POST /api/ai/propagate` | 对齐 | 当前帧 seed mask 向视频片段传播,并保存后续帧标注 | | `getAiModelStatus(selectedModel?)` | `GET /api/ai/models/status` | 对齐 | 返回 GPU、SAM 2、SAM 3 的真实运行状态 | | `getProjectAnnotations(projectId, frameId?)` | `GET /api/ai/annotations` | 对齐 | 前端加载工作区时用于回显已保存标注 | | `saveAnnotation(payload)` | `POST /api/ai/annotate` | 对齐 | 工作区归档保存当前项目未保存 mask | @@ -70,12 +71,13 @@ Authorization: Bearer | DELETE | `/api/templates/{template_id}` | 删除模板 | | POST | `/api/media/upload` | 上传视频/图片/DICOM 单文件 | | POST | `/api/media/upload/dicom` | 批量上传 DICOM | -| POST | `/api/media/parse` | 创建 Celery 拆帧任务 | +| POST | `/api/media/parse` | 创建 Celery 拆帧任务;query 支持 `project_id`、`source_type`、`parse_fps`、`max_frames`、`target_width` | | GET | `/api/tasks` | 查询后台任务列表 | | GET | `/api/tasks/{task_id}` | 查询单个后台任务 | | POST | `/api/tasks/{task_id}/cancel` | 取消后台任务 | | POST | `/api/tasks/{task_id}/retry` | 重试失败或取消的后台任务 | | POST | `/api/ai/predict` | SAM 2 / SAM 3 可选推理 | +| POST | `/api/ai/propagate` | SAM 2 / SAM 3 视频片段传播并保存标注 | | GET | `/api/ai/models/status` | GPU 和 SAM 模型状态 | | POST | `/api/ai/auto` | 自动分割 | | POST | `/api/ai/annotate` | 保存 AI 标注 | @@ -110,6 +112,25 @@ Authorization: Bearer } ``` +### 创建标准帧序列拆帧任务 + +```text +POST /api/media/parse?project_id=1&parse_fps=15&max_frames=120&target_width=960 +``` + +任务 `payload` 会记录本次拆帧参数;完成后的 `result.frame_sequence` 返回 `original_fps`、`parse_fps`、`frame_count`、`duration_ms`、`target_width`、帧宽高和 MinIO object prefix。每条 `FrameOut` 包含: + +```json +{ + "frame_index": 0, + "image_url": "http://...", + "width": 960, + "height": 540, + "timestamp_ms": 0, + "source_frame_number": 0 +} +``` + ### 创建/更新模板 ```json @@ -150,11 +171,14 @@ Authorization: Bearer - `point` - `box` -- `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理;选择 `sam2` 时仍回退到 auto segmentation。SAM 3 真实可用性由 `/api/ai/models/status` 中的外部环境和 checkpoint access 状态决定。 +- `interactive`,用于 SAM 2 交互式细化,`prompt_data` 同时携带 `box`、累计 `points` 和 `labels`。 +- `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理;选择 `sam2` 时仍回退到 auto segmentation。SAM 3 真实可用性由 `/api/ai/models/status` 中的外部环境和本地 checkpoint 状态决定。 + +选择 `sam3` 且发送 `box` 时,前端仍传 normalized `[x1, y1, x2, y2]`,后端适配层会转换成官方几何 prompt 的 `[center_x, center_y, width, height]` 正框;当前 SAM 3 不接正/反点修正。 可选 `options` 字段: -- `crop_to_prompt`:对 point/box prompt 按锚点或框附近区域裁剪后推理,再把 polygon 回映射到原图坐标。 +- `crop_to_prompt`:对 point/box/interactive prompt 按锚点或框附近区域裁剪后推理,再把 polygon 回映射到原图坐标。 - `auto_filter_background`:过滤低分结果,并移除包含负向点的 polygon。 - `min_score`:配合 `auto_filter_background` 使用的最低置信度阈值。 @@ -183,6 +207,32 @@ Authorization: Bearer } ``` +### 视频片段传播请求体 + +工作区“传播片段”调用: + +```json +{ + "project_id": 1, + "frame_id": 123, + "model": "sam2", + "direction": "forward", + "max_frames": 30, + "include_source": false, + "save_annotations": true, + "seed": { + "polygons": [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]], + "bbox": [0.1, 0.1, 0.2, 0.2], + "label": "胆囊", + "color": "#ff0000", + "class_metadata": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20}, + "template_id": 2 + } +} +``` + +`model=sam2` 使用 SAM 2 video predictor 的 mask seed 传播;`model=sam3` 使用独立 Python 3.12 helper 中的 SAM 3 video tracker,并以 seed bbox 作为初始提示。响应会返回已创建的 `annotations`,保存的 `mask_data.source` 为 `sam2_propagation` 或 `sam3_propagation`。 + ## 已完成的接口对齐 - `updateProject()` 已从 `PUT` 改为 `PATCH`。 diff --git a/doc/05-implementation-plan.md b/doc/05-implementation-plan.md index fd2c6b7..a8494e6 100644 --- a/doc/05-implementation-plan.md +++ b/doc/05-implementation-plan.md @@ -16,7 +16,7 @@ 剩余边界: -1. SAM 3 已完成独立 Python 3.12 环境安装脚本、外部 worker 桥接和状态检查;真实推理还需要 Hugging Face `facebook/sam3` gated 权重授权通过后执行 smoke test。 +1. SAM 3 已完成独立 Python 3.12 环境安装脚本、外部 worker 桥接、本地 `sam3权重/sam3.pt` checkpoint 状态检查、本地 checkpoint 加载参数接入、单帧文本/框提示和 video tracker API 入口;下一步需要基于真实业务帧验证语义召回质量和视频 tracker 稳定性。 2. 标注删除/更新接口已打通基础能力;逐点几何编辑器已支持顶点拖动/删除、边中点插入和多 polygon 子区域选择编辑,复杂洞结构仍待增强。 ## 阶段 2:打通标注保存(已完成基础闭环) @@ -127,6 +127,24 @@ Word 方案中的完整版本包含距离变换、骨架提取和聚类。当前 1. 在前端预览重叠裁决结果。 2. 对多帧多类导出增加颜色 palette PNG 或可视化 legend。 +## 阶段 7.5:视频片段传播(已完成基础闭环) + +当前工作区“传播片段”会使用当前选中 mask 或当前帧第一个 mask 作为 seed,默认向后传播 30 帧并把结果写入后端标注表。 + +已完成: + +1. 前端 `propagateMasks()` 已接入 `POST /api/ai/propagate`。 +2. 工作区按钮会把 seed mask 的 normalized polygon、bbox、label、color 和 class 元数据传给后端。 +3. SAM 2 路径使用官方 `SAM2VideoPredictor.add_new_mask()` 和 `propagate_in_video()`。 +4. SAM 3 路径通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境中的官方 `build_sam3_video_predictor()`。 +5. 后端会跳过源帧,把传播结果保存到后续帧 `annotations`,并在完成后由前端刷新回显。 + +剩余建议: + +1. 把传播任务改为异步任务,接入 Dashboard 和 WebSocket 进度。 +2. 前端增加传播方向、帧数和覆盖已有标注策略设置。 +3. 用真实长视频分别做 SAM 2 / SAM 3 tracker smoke test 和质量评估。 + ## 阶段 8:清理 UI 文案与 Mock 建议统一这些文案和真实能力: diff --git a/doc/07-current-requirements-freeze.md b/doc/07-current-requirements-freeze.md index 1236846..b2ed627 100644 --- a/doc/07-current-requirements-freeze.md +++ b/doc/07-current-requirements-freeze.md @@ -28,7 +28,11 @@ - 未提供项目 ID 上传时,后端自动创建项目。 - 提供项目 ID 上传时,后端把上传对象关联到该项目。 - 拆帧接口根据项目 `source_type` 处理视频或 DICOM。 +- 拆帧接口支持 `parse_fps`、`max_frames` 和 `target_width` 参数,用于生成可被 SAM 2 / SAM 3 视频处理复用的标准帧序列。 +- 视频帧使用连续 `frame_%06d.jpg` 命名,默认从 `frame_000000.jpg` 开始,并按 `target_width` 缩放。 - 拆帧完成后写入 `frames` 记录,并把项目状态设为 `ready`。 +- 每条帧记录包含 `frame_index`、`image_url`、`width`、`height`、`timestamp_ms` 和 `source_frame_number`。 +- 任务完成结果包含 `frame_sequence` 元数据:`original_fps`、`parse_fps`、`frame_count`、`duration_ms`、`target_width`、帧宽高和对象存储前缀。 - 拆帧接口会创建 `processing_tasks` 记录并投递 Celery worker。 - 前端可通过 `GET /api/tasks/{task_id}` 查询任务状态。 - 后端支持 `POST /api/tasks/{task_id}/cancel` 取消 queued/running 任务,写入 `cancelled` 状态并尝试 revoke Celery。 @@ -41,8 +45,9 @@ - 若项目有媒体但无帧,工作区会尝试触发拆帧后重新加载。 - Canvas 显示当前帧图片。 - Canvas 支持滚轮缩放、移动工具拖拽、鼠标坐标显示。 -- 时间轴支持缩略图点击切帧、range 拖动切帧、播放/暂停顺序推进帧。 +- 时间轴支持缩略图点击切帧、range 拖动切帧、键盘左右方向键切帧、播放/暂停顺序推进帧。 - 播放帧率使用项目 `parse_fps` 或 `original_fps`,限制在 1 到 30 FPS。 +- 时间轴显示当前帧时间和总时长,时间基准使用项目 `parse_fps` 或 `original_fps`,格式为 `mm:ss.cc`。 ## R5 工具栏 @@ -50,12 +55,16 @@ - 正向点、反向点、框选工具会影响 Canvas 交互。 - 魔法棒按钮切换到 AI 页面。 - 多边形、矩形、圆、点、线工具会在 Canvas 上生成可保存的 polygon mask。 -- 多边形通过点击取点并按 Enter 完成;矩形、圆、线通过拖拽生成;点工具生成小点区域。 +- 多边形通过点击取点并按 Enter 完成,也支持三点后点击首节点闭合;矩形、圆、线通过拖拽生成;点工具生成小点区域。 +- 绘制工具点击已有 mask 时应继续执行当前绘制动作,不应被 mask 选择逻辑吞掉。 - Canvas 支持点击 mask 进入 polygon 顶点编辑态;拖动顶点会更新 mask 几何并把已保存 mask 标记为 dirty。 - 顶点编辑态下选中顶点后可用 Delete/Backspace 删除顶点,但不会让 polygon 少于三点。 +- 选中整个 mask 且未选中具体顶点时,Delete/Backspace 删除该 mask;已保存 mask 同步调用后端删除接口。 - 撤销、重做绑定全局 `maskHistory/maskFuture`,支持工具栏按钮、AI 页按钮和 Canvas 快捷键。 - 区域合并工具支持多选当前帧 mask,并使用 polygon union 生成合并后的主 mask。 - 区域去除工具支持多选当前帧 mask,并从第一个选中的主 mask 中扣除后续选中 mask。 +- 区域合并/去除模式显示已选数量,并隐藏 polygon 编辑手柄以避免手柄抢占多选点击。 +- 区域去除结果包含内洞时,前端保留 hole ring 并用 even-odd 规则渲染。 ## R6 AI 推理 @@ -65,7 +74,14 @@ - 前端发送后端契约:`image_id`、`prompt_type`、`prompt_data`、`model`。 - 点提示传 `{ points, labels }`,正向点 label 为 1,反向点 label 为 0。 - 框选提示传归一化 `[x1, y1, x2, y2]`。 -- 语义文本提示传 `semantic`;选择 `sam3` 且独立 Python 3.12 环境、CUDA、官方包和 Hugging Face gated 权重访问均满足时走 SAM 3 文本语义推理,选择 `sam2` 时回退到自动分割。 +- 工作区 SAM 2 框选会建立一个候选 mask;后续正向点/反向点会携带原始框和累计点,以 `interactive` prompt 细化并替换同一个候选 mask。 +- 语义文本提示传 `semantic`;选择 `sam3` 且独立 Python 3.12 环境、CUDA、官方包和本地 checkpoint 均满足时走 SAM 3 文本语义推理,选择 `sam2` 时回退到自动分割。 +- SAM 3 支持工作区框选提示;后端把 normalized `[x1, y1, x2, y2]` 转成官方 `add_geometric_prompt()` 需要的 `[center_x, center_y, width, height]` 正框。 +- 当前 SAM 3 前端路径不支持正/反点修正;在工作区用 SAM 3 进行点交互时,前端会提示切回 SAM 2。 +- 工作区“传播片段”会把当前选中区域或当前帧第一个区域作为 seed,调用 `POST /api/ai/propagate`,默认从当前帧向后传播 30 帧并保存结果标注。 +- `POST /api/ai/propagate` 支持 `model=sam2` 或 `model=sam3`;SAM 2 使用官方 `SAM2VideoPredictor.add_new_mask()` 和 `propagate_in_video()`,SAM 3 通过独立 Python 3.12 helper 调用官方 `build_sam3_video_predictor()` video tracker。 +- 传播结果会写入后续帧 `annotations`,`mask_data.source` 分别标记为 `sam2_propagation` 或 `sam3_propagation`,并保留 label、color 和 class 元数据。 +- AI 页面会对 SAM 3 空文本、推理失败和返回 0 个 mask 的情况显示明确反馈。 - AI 参数支持 `crop_to_prompt`、`auto_filter_background` 和 `min_score`;点/框 prompt 可以裁剪局部区域推理并回映射结果,背景过滤会移除低分结果和包含负向点的 polygon。 - 后端返回 `polygons` 和 `scores`。 - 前端把后端 `polygons` 转成 Konva `pathData`、`segmentation`、`bbox`、`area`。 @@ -98,6 +114,7 @@ - 工作区右侧可以选择模板。 - 面板显示模板分类和组件本地自定义分类。 - 用户可以选择具体分类;新 AI mask 会记录 `classId`、`className`、`classZIndex`,并在保存时写入 `mask_data.class`。 +- 如果 Canvas 当前已经选中一个或多个 mask,点击语义分类树会把这些 mask 的 `label`、`color` 和 class 元数据改为该分类;已保存 mask 会进入 `dirty` 状态,归档保存时更新后端。 - 添加自定义分类只存在组件本地状态,不保存到后端。 - 置信度、拓扑锚点和重新提取骨架按钮当前为展示/占位。 diff --git a/doc/08-current-design-freeze.md b/doc/08-current-design-freeze.md index e284b4b..2d3bb34 100644 --- a/doc/08-current-design-freeze.md +++ b/doc/08-current-design-freeze.md @@ -19,7 +19,7 @@ | 模块 | 文件 | 设计职责 | |------|------|----------| | 应用入口 | `src/App.tsx` | 根据登录状态和 `activeModule` 切换页面 | -| 全局状态 | `src/store/useStore.ts` | Zustand store,保存项目、帧、模板、mask、工具状态和 mask 撤销/重做历史栈 | +| 全局状态 | `src/store/useStore.ts` | Zustand store,保存项目、帧、模板、mask、当前选中 mask ids、工具状态和 mask 撤销/重做历史栈 | | API 封装 | `src/lib/api.ts` | Axios 客户端、字段映射、AI 响应转换 | | 配置 | `src/lib/config.ts` | 推导 API 和 WebSocket 地址 | | WebSocket | `src/lib/websocket.ts` | 进度流连接、订阅和重连 | @@ -30,7 +30,7 @@ | 工作区 | `src/components/VideoWorkspace.tsx` | 加载帧和模板,组织工具栏、Canvas、本体面板、时间轴 | | Canvas | `src/components/CanvasArea.tsx` | 显示帧、缩放平移、点/框提示、渲染 mask | | 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具、跳转 AI 页面、触发 mask 撤销/重做 | -| 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航和播放 | +| 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航、左右方向键切帧、播放和当前/总时长显示 | | 本体面板 | `src/components/OntologyInspector.tsx` | 模板选择、分类树、本地自定义分类 | | AI 页面 | `src/components/AISegmentation.tsx` | 独立 AI 推理视图,使用当前项目帧 | | 模板库 | `src/components/TemplateRegistry.tsx` | 模板 CRUD、分类编辑、导入、排序 | @@ -48,10 +48,10 @@ | Projects | `backend/routers/projects.py` | 项目与帧 CRUD | | Templates | `backend/routers/templates.py` | 模板 CRUD 和 mapping_rules 打包/解包 | | Media | `backend/routers/media.py` | 上传媒体和拆帧 | -| AI | `backend/routers/ai.py` | SAM 2 / SAM 3 可选推理、模型状态和标注保存 | +| AI | `backend/routers/ai.py` | SAM 2 / SAM 3 可选推理、视频传播、模型状态和标注保存 | | Export | `backend/routers/export.py` | COCO 和 PNG mask 导出 | -| SAM 2 | `backend/services/sam2_engine.py` | SAM 2 懒加载、状态检测和点/框/自动推理 | -| SAM 3 | `backend/services/sam3_engine.py`, `backend/services/sam3_external_worker.py`, `backend/setup_sam3_env.sh` | SAM 3 状态检测、独立 Python 3.12 环境桥接和文本语义推理适配 | +| SAM 2 | `backend/services/sam2_engine.py` | SAM 2 懒加载、状态检测、点/框/自动推理和视频 mask 传播 | +| SAM 3 | `backend/services/sam3_engine.py`, `backend/services/sam3_external_worker.py`, `backend/setup_sam3_env.sh` | SAM 3 状态检测、独立 Python 3.12 环境桥接、本地 checkpoint 加载、文本语义推理、正框几何提示和 video tracker 适配 | | SAM Registry | `backend/services/sam_registry.py` | 模型选择、GPU 状态和推理分发 | ## 状态模型 @@ -59,9 +59,10 @@ 前端 store 的核心对象: - `Project`:项目基本信息、状态、帧数、fps、媒体路径。 -- `Frame`:帧 ID、项目 ID、索引、图片 URL、宽高。 +- `Frame`:帧 ID、项目 ID、索引、图片 URL、宽高、序列时间戳和原视频源帧号。 - `Template` / `TemplateClass`:模板和分类定义。 - `Mask`:前端渲染用 mask,包含 `pathData`、`segmentation`、`bbox`、`area`。 +- `selectedMaskIds`:Canvas 当前选中的 mask id 列表,供右侧本体面板对已选区域直接换标签。 - `maskHistory` / `maskFuture`:mask 编辑历史栈,用于撤销和重做。 - `activeModule`:当前页面。 - `activeTool`:当前工具。 @@ -79,9 +80,11 @@ 1. `ProjectLibrary` 创建项目。 2. 上传视频或 DICOM 到 `/api/media/upload` 或 `/api/media/upload/dicom`。 -3. 调用 `/api/media/parse` 创建异步拆帧任务。 -4. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,持续更新 `processing_tasks`,并发布 Redis `seg:progress`。 -5. 刷新项目列表。 +3. 调用 `/api/media/parse` 创建异步拆帧任务;可通过 `parse_fps`、`max_frames` 和 `target_width` 指定标准帧序列参数。 +4. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,视频帧按 `frame_%06d.jpg` 从 `frame_000000.jpg` 连续命名,并按目标宽度缩放。 +5. worker 写入 `frames.timestamp_ms` 和 `frames.source_frame_number`,并在任务 `result.frame_sequence` 中记录 FPS、帧数、时长、尺寸和对象存储前缀。 +6. worker 持续更新 `processing_tasks`,并发布 Redis `seg:progress`。 +7. 刷新项目列表。 ### 任务控制 @@ -95,29 +98,43 @@ 1. `VideoWorkspace` 根据 `currentProject.id` 调用 `getProjectFrames()`。 2. 若无帧但项目有 `video_path`,触发 `parseMedia()`,通过 `getTask()` 轮询任务完成后重新取帧。 -3. 帧数据映射为 store `Frame[]`。 +3. 帧数据映射为 store `Frame[]`,包含 `timestampMs` 和 `sourceFrameNumber`,供时间轴和后续视频传播使用。 4. 当前帧传入 `CanvasArea`。 ### AI 点/框推理 1. 用户在 Canvas 选择正向点、反向点或框选。 2. `CanvasArea` 读取当前帧 ID 和宽高。 -3. `predictMask()` 归一化坐标并携带当前 `model` 调用 `/api/ai/predict`。 -4. 后端加载帧图片并通过 SAM registry 分发到 SAM 2 或 SAM 3。 -5. 前端把 `polygons` 转为 mask,写入 store。 -6. Canvas 按当前帧过滤并渲染 mask。 -7. 新 mask 会带上当前选择的模板分类元数据,包括 `classId`、`className`、`classZIndex` 和保存状态 `draft`。 -8. 用户点击“结构化归档保存”后,前端将像素 `segmentation` 转成 normalized `mask_data.polygons`;未保存 mask 调用 `POST /api/ai/annotate`,dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`。 -9. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。 -10. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。 +3. SAM 2 框选会创建一个候选 mask,并记录原始框;后续正向点/反向点会累计到同一候选上。 +4. `predictMask()` 归一化坐标并携带当前 `model` 调用 `/api/ai/predict`;同时有框和点时发送 `interactive` prompt。 +5. 后端加载帧图片并通过 SAM registry 分发到 SAM 2 或 SAM 3。 +6. 前端把 `polygons` 转为 mask;交互式细化会替换同一个候选 mask,而不是新增多个 mask。 +7. Canvas 按当前帧过滤并渲染 mask。 +8. 新 mask 会带上当前选择的模板分类元数据,包括 `classId`、`className`、`classZIndex` 和保存状态 `draft`。 +9. 用户点击“结构化归档保存”后,前端将像素 `segmentation` 转成 normalized `mask_data.polygons`;未保存 mask 调用 `POST /api/ai/annotate`,dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`。 +10. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。 +11. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。 + +### 视频片段传播 + +1. 用户在工作区选中一个当前帧 mask;如果未显式选中,前端使用当前帧第一个 mask。 +2. `VideoWorkspace` 用 `buildAnnotationPayload()` 把 seed mask 转成 normalized polygon、bbox、label、color 和 class 元数据。 +3. 前端调用 `POST /api/ai/propagate`,默认 `direction=forward`、`max_frames=30`、`include_source=false`。 +4. 后端按项目帧序列截取片段,下载对应帧到临时 `frame_%06d.jpg` 目录,保持当前帧在片段中的相对索引。 +5. `model=sam2` 时,`sam2_engine` 使用 `SAM2VideoPredictor.add_new_mask()` 注入 seed mask,再用 `propagate_in_video()` 传播。 +6. `model=sam3` 时,`sam3_engine` 将请求交给 `sam3_external_worker.py`,由独立 Python 3.12 环境调用官方 `build_sam3_video_predictor()`,以 seed bbox 走 video tracker。 +7. 后端把传播返回的 normalized polygon 保存为后续帧 `Annotation`,跳过源帧,`mask_data.source` 记录模型传播来源。 +8. 前端传播完成后重新调用 `GET /api/ai/annotations` 并回显新标注。 ### 手工绘制与历史栈 1. 用户在 `ToolsPalette` 选择多边形、矩形、圆、点或线工具。 2. `CanvasArea` 将交互坐标转换成像素 polygon。 -3. 新 mask 写入 `pathData`、像素 `segmentation`、`bbox`、`area` 和当前模板分类元数据。 -4. `addMask()`、`setMasks()`、`updateMask()`、`clearMasks()` 会维护 `maskHistory/maskFuture`。 -5. 工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y 调用 `undoMasks()` / `redoMasks()`。 +3. 多边形工具逐次记录节点,三点后点击首节点或按 Enter 时生成闭合 polygon。 +4. mask path 只在 `move`、`area_merge` 和 `area_remove` 工具下拦截点击;绘制和 AI prompt 工具点击已有 mask 时继续冒泡给 Stage。 +5. 新 mask 写入 `pathData`、像素 `segmentation`、`bbox`、`area` 和当前模板分类元数据。 +6. `addMask()`、`setMasks()`、`updateMask()`、`clearMasks()` 会维护 `maskHistory/maskFuture`。 +7. 工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y 调用 `undoMasks()` / `redoMasks()`。 ### Polygon 逐点编辑 @@ -126,14 +143,16 @@ 3. 如果 mask 已有 `annotationId`,编辑会把 `saveStatus` 标成 `dirty` 且 `saved=false`。 4. 归档保存时复用现有 `PATCH /api/ai/annotations/{annotation_id}` 链路,把更新后的 normalized polygon 写回后端。 5. 选中顶点后 Delete/Backspace 可删除顶点;前端保持 polygon 至少三点。 +6. 未选中具体顶点但选中了 mask 时,Delete/Backspace 从前端 store 删除该 mask;如果包含 `annotationId`,通过工作区回调调用后端删除接口。 ### 区域合并与去除 1. 用户选择 `area_merge` 或 `area_remove` 后,点击多个当前帧 mask 组成选择集。 -2. `CanvasArea` 把 `Mask.segmentation` 转为 `polygon-clipping` 的 MultiPolygon。 -3. `area_merge` 使用 union,更新第一个选中的主 mask,并从前端 store 移除后续被合并 mask;如果被移除 mask 已保存,会调用工作区传入的删除回调删除后端标注。 -4. `area_remove` 使用 difference,从第一个选中的主 mask 中扣除后续选中 mask,扣除对象本身保留。 -5. 结果会重算 `pathData`、`segmentation`、`bbox`、`area`,已保存主 mask 会进入 dirty 状态并复用归档 PATCH 链路。 +2. 合并/去除模式隐藏 polygon 顶点和边中点编辑手柄,并在右下角显示已选数量;少于两个 mask 时操作按钮禁用。 +3. `CanvasArea` 把 `Mask.segmentation` 转为 `polygon-clipping` 的 MultiPolygon。 +4. `area_merge` 使用 union,更新第一个选中的主 mask,并从前端 store 移除后续被合并 mask;如果被移除 mask 已保存,会调用工作区传入的删除回调删除后端标注。 +5. `area_remove` 使用 difference,从第一个选中的主 mask 中扣除后续选中 mask,扣除对象本身保留;如果 difference 产生内洞,`segmentation` 保留外圈和 hole ring,渲染时使用 even-odd fill。 +6. 结果会重算 `pathData`、`segmentation`、`bbox`、`area`,已保存主 mask 会进入 dirty 状态并复用归档 PATCH 链路;带洞结果的面积按外圈减内洞计算。 ### GT Mask 导入 @@ -153,7 +172,10 @@ 3. 保存时调用 `createTemplate()` 或 `updateTemplate()`。 4. 后端把 `classes`、`rules` 打包进 `mapping_rules`。 5. 返回时再解包给前端。 -6. `OntologyInspector` 可以选择具体分类;选择结果进入全局 store,供 `CanvasArea` 和 `AISegmentation` 新建/更新 mask 时使用。 +6. `CanvasArea` 把当前选中的 mask id 同步到全局 `selectedMaskIds`;切换工具、切换帧或卸载 Canvas 时会清空选择。 +7. `OntologyInspector` 可以选择具体分类;选择结果进入全局 store,供 `CanvasArea` 和 `AISegmentation` 新建/更新 mask 时使用。 +8. 如果 `selectedMaskIds` 中存在当前 store 的 mask,点击分类时会立即更新这些 mask 的 `templateId`、`classId`、`className`、`classZIndex`、`label` 和 `color`。 +9. 已保存 mask 被重新分类后进入 `dirty` 且 `saved=false`,继续复用工作区归档保存的 PATCH 链路。 ### 导出 @@ -173,13 +195,20 @@ - `cancelTask()` 使用 `POST /api/tasks/{taskId}/cancel`。 - `retryTask()` 使用 `POST /api/tasks/{taskId}/retry`。 - `predictMask()` 使用 `POST /api/ai/predict`,请求体为 `image_id`、`prompt_type`、`prompt_data`、`model`。 +- `propagateMasks()` 使用 `POST /api/ai/propagate`,请求体为 `project_id`、`frame_id`、`model`、`seed`、`direction`、`max_frames`。 - `saveAnnotation()` 使用 `POST /api/ai/annotate`。 - `importGtMask()` 使用 `POST /api/ai/import-gt-mask` multipart form-data。 - `getProjectAnnotations()` 使用 `GET /api/ai/annotations`。 - `updateAnnotation()` 使用 `PATCH /api/ai/annotations/{annotationId}`。 - `deleteAnnotation()` 使用 `DELETE /api/ai/annotations/{annotationId}`。 -- 后端 `/api/ai/predict` 支持 point、box、semantic 三种 prompt_type,并通过 `model` 选择 SAM 2 或 SAM 3。 -- 后端 `/api/ai/predict` 支持可选 `options`:`crop_to_prompt` 会对 point/box prompt 做局部裁剪推理并回映射 polygon,`auto_filter_background` 会按 `min_score` 和负向点过滤结果。 +- `parseMedia()` 使用 `POST /api/media/parse?project_id=...`,可选 `parse_fps`、`max_frames`、`target_width`,用于生成标准帧序列。 +- `getProjectFrames()` 返回帧图像 URL、宽高、`timestamp_ms` 和 `source_frame_number`。 +- 后端 `/api/ai/predict` 支持 point、box、interactive、semantic 四种 prompt_type,并通过 `model` 选择 SAM 2 或 SAM 3。 +- 当前 SAM 3 暴露 semantic 文本语义推理和 box 几何提示;工作区 Canvas 的点交互会在选择 SAM 3 时显示提示,不再静默失败。 +- SAM 3 box prompt 复用后端 `/api/ai/predict` 的 `box` prompt_type,输入仍是 normalized `[x1, y1, x2, y2]`,引擎适配层会转换为官方 `add_geometric_prompt()` 使用的 `[center_x, center_y, width, height]` 正框。 +- AI 页面选择 SAM 3 时优先发送文本 semantic prompt,不会把正/反点误发送为 SAM 3 point prompt;空文本、后端错误和空结果都会显示反馈消息。 +- 后端 `/api/ai/predict` 支持可选 `options`:`crop_to_prompt` 会对 point/box/interactive prompt 做局部裁剪推理并回映射 polygon,`auto_filter_background` 会按 `min_score` 和负向点过滤结果。 +- 后端 `/api/ai/propagate` 支持 SAM 2 mask seed 视频传播和 SAM 3 external video tracker;当前前端默认向后传播 30 帧并保存结果标注。 - 后端 `/api/ai/models/status` 返回 GPU、SAM 2、SAM 3 的真实运行状态;SAM 3 状态包含外部 Python 环境与 checkpoint access 的可用性。 - point prompt 支持旧数组形式和 `{ points, labels }` 对象形式。 @@ -198,7 +227,7 @@ 以下能力属于当前冻结版本的占位或半可用功能: - Dashboard 初始快照来自 `GET /api/dashboard/overview`;解析队列由 `processing_tasks` queued/running/failed/cancelled 任务生成。 -- 已保存标注支持通过“应用分类”、polygon 顶点拖动/删除、边中点插入、多 polygon 子区域编辑和区域合并/去除进入 dirty 状态并归档更新;复杂洞结构编辑尚未实现。 -- SAM 3 文本语义分割取决于官方依赖、GPU 运行环境和 Hugging Face gated 权重授权;状态接口会暴露真实可用性,未授权时 `available=false`。 +- 已保存标注支持通过“应用分类”、polygon 顶点拖动/删除、边中点插入、多 polygon 子区域编辑和区域合并/去除进入 dirty 状态并归档更新;选中整块 mask 可用 Delete/Backspace 删除并同步后端;复杂洞结构编辑尚未实现。 +- SAM 3 文本语义分割取决于官方依赖、GPU 运行环境和本地 checkpoint;状态接口会暴露真实可用性,运行时缺失时 `available=false`。 - 自定义分类只存在本地组件状态。 - GT mask 导入已完成多类别像素值拆分、contour、distance transform seed point 和前端 seed point 拖拽编辑;骨架提取、HDBSCAN 聚类和模板自动映射尚未实现。 diff --git a/doc/09-test-plan.md b/doc/09-test-plan.md index ebfe0dd..19c3764 100644 --- a/doc/09-test-plan.md +++ b/doc/09-test-plan.md @@ -16,18 +16,51 @@ |------|----------|--------| | R1 登录与会话 | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | 成功登录、失败提示、后端 401 | | R2 项目管理 | `src/lib/api.test.ts`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、后端 CRUD、帧列表 | -| R3 媒体上传与拆帧 | `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 扩展名校验、自动建项目、关联项目、创建异步任务、worker 注册帧、取消任务、重试任务、取消后 worker 停止 | -| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧触发解析、切帧、播放 | -| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 工具切换、AI 跳转、手工 mask 绘制、polygon 顶点拖动/删除、区域合并/去除、撤销/重做历史栈 | -| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam3_engine.py` | 点/框/semantic 契约、模型选择、GPU/SAM 状态、SAM 3 外部 worker 桥接、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry | +| R3 媒体上传与拆帧 | `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 扩展名校验、自动建项目、关联项目、创建异步任务、标准帧序列参数、帧时间戳/源帧号、任务序列元数据、worker 注册帧、取消任务、重试任务、取消后 worker 停止 | +| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧触发解析、缩略图/range/左右方向键切帧、播放、按项目 FPS 显示当前/总时长 | +| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 工具切换、AI 跳转、矩形/圆/线/点/多边形手工 mask 绘制、点工具在已有 mask 上落点、多边形 Enter/首节点闭合、polygon 顶点拖动/删除、整块 mask 删除、区域合并/去除、内含去除 hole 渲染、合并模式隐藏编辑手柄、撤销/重做历史栈 | +| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam3_engine.py` | 点/框/interactive/semantic 契约、SAM 2 框选后正负点细化同一候选 mask、SAM 2 视频传播、SAM 3 语义文本前端执行路径、SAM 3 工作区框选、SAM 3 video tracker 外部桥接、SAM 3 点交互不支持提示、空文本/空结果反馈、模型选择、GPU/SAM 状态、SAM 3 外部 worker 桥接、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry | | R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、加载回显、更新 dirty 标注、清空删除已保存标注、GT mask 多类别导入、seed point 回显/归一化、项目不存在、帧不存在 | -| R8 模板库 | `src/lib/api.test.ts`, `backend/tests/test_templates.py` | mapping_rules 解包/打包、模板 CRUD | -| R9 本体检查面板 | `src/components/OntologyInspector.test.tsx` | 模板选择、分类展示、具体分类选择、自定义分类本地添加 | +| R8 模板库 | `src/components/TemplateRegistry.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_templates.py` | 前端模板加载/新建/编辑/删除、JSON 分类导入、mapping_rules 解包/打包、后端模板 CRUD | +| R9 本体检查面板 | `src/components/OntologyInspector.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 模板选择、分类展示、具体分类选择、Canvas 选区同步、点击分类给已选 mask 换标签、自定义分类本地添加 | | R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py`, `backend/tests/test_tasks.py` | 后端概览接口、任务表驱动队列、任务取消/重试/详情、cancelled 事件、Redis 进度事件 payload/发布、地址推导、消息订阅、队列更新、heartbeat | | R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO/PNG 按钮下载、导出前自动保存、导出路径、JSON 结构、mask ZIP、zIndex 语义融合 | | R12 配置 | `src/lib/config.test.ts` | env 优先、hostname 推导、WS 推导 | | R13 文档与测试 | `doc/09-test-plan.md` | 测试覆盖矩阵 | +## 逐功能点追踪 + +| 需求 | 功能点 | 对应测试 | 当前状态 | +|------|--------|----------|----------| +| R1 | 登录页、默认开发凭证、token 写入、失败提示、后端 401 | `Login.test.tsx`, `test_auth.py` | 已覆盖 | +| R2 | 项目列表/创建/选择、视频导入、DICOM 导入、后端项目和帧 CRUD | `ProjectLibrary.test.tsx`, `api.test.ts`, `test_projects.py` | 已覆盖 | +| R3 | 文件类型校验、自动/指定项目上传、视频/DICOM 拆帧任务、`parse_fps/max_frames/target_width`、标准帧序列 metadata、任务查询、取消、重试、worker 取消停止 | `test_media.py`, `test_tasks.py` | 已覆盖 | +| R4 | 工作区加载帧、无帧自动解析、Canvas 底图、缩略图/range/左右方向键切帧、播放、按 FPS 显示时间 | `VideoWorkspace.test.tsx`, `FrameTimeline.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 | +| R5 | 工具切换、AI 跳转、矩形/圆/线/点/多边形绘制、已有 mask 上继续绘制 | `ToolsPalette.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 | +| R5 | 顶点编辑、顶点删除、整块删除、撤销/重做、区域合并、区域去除、hole even-odd 渲染 | `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 | +| R6 | SAM 2 点/框/interactive、SAM 2 视频传播、SAM 3 semantic、SAM 3 box、SAM 3 video tracker、SAM 3 不支持点交互时的前端反馈、模型选择、GPU/模型状态、参数 options、polygons 转 mask | `api.test.ts`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx`, `VideoWorkspace.test.tsx`, `ModelStatusBadge.test.tsx`, `test_ai.py`, `test_sam3_engine.py` | 已覆盖 | +| R7 | 保存、查询、更新、删除标注、工作区回显、清空已保存标注、GT mask 导入和 seed point 回写 | `VideoWorkspace.test.tsx`, `CanvasArea.test.tsx`, `api.test.ts`, `test_ai.py` | 已覆盖 | +| R8 | 模板加载、新建、编辑、删除、JSON 分类导入、mapping_rules 映射、后端 CRUD | `TemplateRegistry.test.tsx`, `api.test.ts`, `test_templates.py` | 已覆盖 | +| R9 | 模板选择、分类展示、分类选择、已选 mask 换标签、自定义本地分类、占位状态 | `OntologyInspector.test.tsx`, `CanvasArea.test.tsx`, `useStore.test.ts` | 已覆盖 | +| R10 | Dashboard 概览、队列、活动日志、WebSocket progress/complete/error/status/cancelled、取消/重试/详情、heartbeat | `Dashboard.test.tsx`, `websocket.test.ts`, `test_dashboard.py`, `test_main.py`, `test_progress_events.py`, `test_tasks.py` | 已覆盖 | +| R11 | COCO/PNG ZIP 导出、导出前保存、路径和 JSON/ZIP 结构、zIndex 融合 | `VideoWorkspace.test.tsx`, `api.test.ts`, `test_export.py` | 已覆盖 | +| R12 | API/WS 地址 env 优先和 hostname 推导 | `config.test.ts` | 已覆盖 | +| R13 | 文档测试矩阵与功能点追踪 | `doc/09-test-plan.md` | 已覆盖 | + +## 本轮补齐记录 + +- R5:补充 `CanvasArea.test.tsx` 中圆形和线段手工绘制测试,明确验证 metadata、segmentation、bbox/area 和草稿状态。 +- R6:补充 `AISegmentation.test.tsx` 中 SAM 3 semantic 文本推理测试,验证前端传参和返回 mask 绑定当前语义类别。 +- R6:补充 SAM 3 空文本、空结果和工作区点交互不支持提示测试,避免前端静默失败。 +- R6:补充 SAM 3 工作区 box prompt 测试和外部 worker box prompt 测试,验证官方 `add_geometric_prompt()` 正框链路。 +- R6:补充 `POST /api/ai/propagate` 后端测试,验证 seed mask 传播结果会保存为后续帧标注并保留 class 元数据。 +- R6:补充 `propagateMasks()` API 封装和 `VideoWorkspace` 传播按钮测试,验证当前选中区域会发送到后端视频传播接口。 +- R6:补充 SAM 3 external video tracker 请求测试,验证主后端会把帧目录、源帧索引、seed bbox 和方向传给独立 Python helper。 +- R3:补充 `parseMedia()` 查询参数和后端拆帧任务 payload 测试,验证 `parse_fps`、`max_frames`、`target_width` 会进入任务。 +- R3:补充 worker 注册标准帧序列测试,验证帧 `timestamp_ms`、`source_frame_number` 和 `result.frame_sequence` 元数据。 +- R8:补充 `TemplateRegistry.test.tsx` 中模板编辑、删除测试,验证前端调用真实 API 封装并更新全局 store。 +- R9:补充 Canvas 选中 mask id 全局同步、本体树点击分类给已选 mask 换标签的测试,验证已保存 mask 会进入 dirty 状态。 + ## 运行命令 ```bash diff --git a/src/components/AISegmentation.test.tsx b/src/components/AISegmentation.test.tsx index c6a9bea..25ee8e9 100644 --- a/src/components/AISegmentation.test.tsx +++ b/src/components/AISegmentation.test.tsx @@ -1,4 +1,4 @@ -import { fireEvent, render, screen } from '@testing-library/react'; +import { fireEvent, render, screen, waitFor } from '@testing-library/react'; import { beforeEach, describe, expect, it, vi } from 'vitest'; import { resetStore } from '../test/storeTestUtils'; import { useStore } from '../store/useStore'; @@ -62,4 +62,112 @@ describe('AISegmentation', () => { }, })); }); + + it('prompts for semantic text before running SAM3 inference', async () => { + apiMock.getAiModelStatus.mockResolvedValue({ + selected_model: 'sam3', + gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true }, + models: [ + { id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false }, + { id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true }, + ], + }); + + render(); + const sam3Button = (await screen.findByText('SAM3')).closest('button')!; + fireEvent.click(sam3Button); + fireEvent.click(screen.getByText('执行高精度语义分割')); + + expect(apiMock.predictMask).not.toHaveBeenCalled(); + expect(await screen.findByText('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。')).toBeInTheDocument(); + }); + + it('shows feedback when SAM3 semantic inference returns no masks', async () => { + apiMock.getAiModelStatus.mockResolvedValue({ + selected_model: 'sam3', + gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true }, + models: [ + { id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false }, + { id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true }, + ], + }); + apiMock.predictMask.mockResolvedValueOnce({ masks: [] }); + + render(); + const sam3Button = (await screen.findByText('SAM3')).closest('button')!; + fireEvent.click(sam3Button); + fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), { + target: { value: '胆囊' }, + }); + fireEvent.click(screen.getByText('执行高精度语义分割')); + + await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({ + model: 'sam3', + points: undefined, + text: '胆囊', + }))); + expect(await screen.findByText('模型没有返回可用区域,请换一个更具体的描述或调整提示。')).toBeInTheDocument(); + }); + + it('runs SAM3 semantic text inference and assigns the active class to returned masks', async () => { + apiMock.getAiModelStatus.mockResolvedValue({ + selected_model: 'sam3', + gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true }, + models: [ + { id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false }, + { id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true }, + ], + }); + apiMock.predictMask.mockResolvedValueOnce({ + masks: [ + { + id: 'semantic-1', + pathData: 'M 10 10 L 40 10 L 40 40 Z', + label: 'semantic result', + color: '#06b6d4', + segmentation: [[10, 10, 40, 10, 40, 40]], + bbox: [10, 10, 30, 30], + area: 900, + }, + ], + }); + useStore.setState({ + activeTemplateId: 'template-1', + activeClassId: 'class-1', + activeClass: { id: 'class-1', name: '胆囊', color: '#ff0000', zIndex: 30 }, + }); + + render(); + const sam3Button = (await screen.findByText('SAM3')).closest('button')!; + fireEvent.click(sam3Button); + fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), { + target: { value: '胆囊' }, + }); + fireEvent.click(screen.getByText('执行高精度语义分割')); + + await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({ + imageId: 'frame-1', + imageWidth: 640, + imageHeight: 360, + model: 'sam3', + points: undefined, + text: '胆囊', + options: { + crop_to_prompt: false, + auto_filter_background: true, + min_score: 0.05, + }, + }))); + expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ + id: 'semantic-1', + frameId: 'frame-1', + templateId: 'template-1', + classId: 'class-1', + className: '胆囊', + classZIndex: 30, + label: '胆囊', + color: '#ff0000', + saveStatus: 'draft', + })); + }); }); diff --git a/src/components/AISegmentation.tsx b/src/components/AISegmentation.tsx index 16424f3..5fef8c7 100644 --- a/src/components/AISegmentation.tsx +++ b/src/components/AISegmentation.tsx @@ -33,6 +33,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { const [autoDeleteBg, setAutoDeleteBg] = useState(true); const [cropMode, setCropMode] = useState(false); const [isInferencing, setIsInferencing] = useState(false); + const [inferenceMessage, setInferenceMessage] = useState(''); // Canvas state const [scale, setScale] = useState(1); @@ -91,9 +92,18 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { }; const runInference = useCallback(async () => { - if (points.length === 0 && !semanticText.trim()) return; + const textPrompt = semanticText.trim(); + if (aiModel === 'sam3' && !textPrompt) { + setInferenceMessage('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。'); + return; + } + if (points.length === 0 && !textPrompt) { + setInferenceMessage('请先放置正/反向提示点,或输入语义描述。'); + return; + } if (!currentFrame?.id) { console.warn('AI inference skipped: no project frame is selected'); + setInferenceMessage('请先在项目工作区选择一帧图像。'); return; } @@ -101,18 +111,20 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { const imageHeight = currentFrame.height || image?.naturalHeight || image?.height || 0; if (imageWidth <= 0 || imageHeight <= 0) { console.warn('AI inference skipped: active frame dimensions are unavailable'); + setInferenceMessage('当前帧缺少宽高信息,无法推理。'); return; } setIsInferencing(true); + setInferenceMessage(''); try { const result = await predictMask({ imageId: currentFrame.id, imageWidth, imageHeight, model: aiModel, - points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })), - text: semanticText.trim() || undefined, + points: aiModel === 'sam3' ? undefined : points.map((p) => ({ x: p.x, y: p.y, type: p.type })), + text: textPrompt || undefined, options: { crop_to_prompt: cropMode, auto_filter_background: autoDeleteBg, @@ -120,6 +132,11 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { }, }); + if (result.masks.length === 0) { + setInferenceMessage('模型没有返回可用区域,请换一个更具体的描述或调整提示。'); + } else { + setInferenceMessage(`已生成 ${result.masks.length} 个候选区域。`); + } result.masks.forEach((m) => { const label = activeClass?.name || m.label; const color = activeClass?.color || m.color; @@ -142,6 +159,8 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { }); } catch (err) { console.error('AI inference failed:', err); + const detail = (err as any)?.response?.data?.detail; + setInferenceMessage(detail || 'AI 推理失败,请查看模型状态或后端日志。'); } finally { setIsInferencing(false); } @@ -282,6 +301,11 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { {isInferencing ? : } {isInferencing ? '推理中...' : modelCanInfer ? '执行高精度语义分割' : '当前模型不可用'} + {inferenceMessage && ( +
+ {inferenceMessage} +
+ )} + {isBooleanTool && ( +
+ + 已选 {booleanSelectedMasks.length} + + +
)} {activeClass && ( +