From 689a9ba283a576537fe81763bcdabf5c602f21f2 Mon Sep 17 00:00:00 2001 From: admin <572701190@qq.com> Date: Fri, 1 May 2026 15:26:25 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=BB=BA=E7=AB=8B=20SAM2=20=E6=A0=87?= =?UTF-8?q?=E6=B3=A8=E9=97=AD=E7=8E=AF=E5=9F=BA=E7=BA=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 打通工作区真实标注闭环:支持手工多边形、矩形、圆形、点区域和线段生成 mask,并可保存、回显、更新和删除后端 annotation。 - 增强 polygon 编辑器:支持顶点拖动、顶点删除、边中点插入、多 polygon 子区域选择编辑,以及区域合并和区域去除。 - 接入 GT mask 导入:后端支持二值/多类别 mask 拆分、contour 转 polygon、distance transform seed point,前端支持导入、回显和 seed point 拖动编辑。 - 完善导出能力:COCO JSON 导出对齐前端,PNG mask ZIP 同时包含单标注 mask、按 zIndex 融合的 semantic_frame 和 semantic_classes.json。 - 打通异步任务管理:新增任务取消、重试、失败详情接口与 Dashboard 控件,worker 支持取消状态检查并通过 Redis/WebSocket 推送 cancelled 事件。 - 对接 Dashboard 后端数据:概览统计、解析队列和实时流转记录从 FastAPI 聚合接口与 WebSocket 更新。 - 增强 AI 推理参数:前端发送 crop_to_prompt、auto_filter_background 和 min_score,后端支持点/框 prompt 局部裁剪推理、结果回映射和负向点/低分过滤。 - 接入 SAM3 基础设施:新增独立 Python 3.12 sam3 环境安装脚本、外部 worker helper、后端桥接和真实 Python/CUDA/包/HF checkpoint access 状态检测。 - 保留 SAM3 授权边界:当前官方 facebook/sam3 gated 权重未授权时状态接口会返回不可用,不伪装成可推理。 - 增强前端状态管理:新增 mask undo/redo 历史栈、AI 模型选择状态、保存状态 dirty/draft/saved 流转和项目状态归一化。 - 更新前端 API 封装:补充 annotation CRUD、GT mask import、mask ZIP export、task cancel/retry/detail、AI runtime status 和 prediction options。 - 更新 UI 控件:ToolsPalette、AISegmentation、VideoWorkspace 和 CanvasArea 接入真实操作、导入导出、撤销重做、任务控制和模型状态。 - 新增 polygon-clipping 依赖,用于前端区域 union/difference 几何运算。 - 完善后端 schemas/status/progress:补充 AI 模型外部状态字段、任务 cancelled 状态和进度事件 payload。 - 补充测试覆盖:新增后端任务控制、SAM3 桥接、GT mask、导出融合、AI options 测试;补充前端 Canvas、Dashboard、VideoWorkspace、ToolsPalette、API 和 store 测试。 - 更新 README、AGENTS 和 doc 文档:冻结当前需求/设计/测试计划,标注真实功能、剩余 Mock、SAM3 授权边界和后续实施顺序。 --- AGENTS.md | 30 +- README.md | 53 +- backend/config.py | 7 +- backend/progress_events.py | 4 +- backend/routers/ai.py | 265 +++++++++- backend/routers/dashboard.py | 5 +- backend/routers/export.py | 98 +++- backend/routers/tasks.py | 111 +++- backend/schemas.py | 3 + backend/services/media_task_runner.py | 49 ++ backend/services/sam3_engine.py | 157 +++++- backend/services/sam3_external_worker.py | 190 +++++++ backend/setup_sam3_env.sh | 24 + backend/statuses.py | 4 + backend/tests/test_ai.py | 100 ++++ backend/tests/test_dashboard.py | 2 + backend/tests/test_export.py | 54 +- backend/tests/test_media.py | 22 + backend/tests/test_progress_events.py | 19 + backend/tests/test_sam3_engine.py | 112 +++++ backend/tests/test_tasks.py | 104 ++++ doc/01-purpose-and-word-summary.md | 20 +- doc/02-current-implementation-map.md | 7 +- doc/03-frontend-element-audit.md | 33 +- doc/04-api-contracts.md | 25 +- doc/05-implementation-plan.md | 90 ++-- doc/07-current-requirements-freeze.md | 25 +- doc/08-current-design-freeze.md | 75 ++- doc/09-test-plan.md | 12 +- package-lock.json | 26 + package.json | 1 + src/components/AISegmentation.test.tsx | 22 + src/components/AISegmentation.tsx | 25 +- src/components/CanvasArea.test.tsx | 265 ++++++++++ src/components/CanvasArea.tsx | 615 ++++++++++++++++++++++- src/components/Dashboard.test.tsx | 106 +++- src/components/Dashboard.tsx | 195 ++++++- src/components/ToolsPalette.test.tsx | 21 +- src/components/ToolsPalette.tsx | 28 +- src/components/VideoWorkspace.test.tsx | 64 +++ src/components/VideoWorkspace.tsx | 104 +++- src/lib/api.test.ts | 94 +++- src/lib/api.ts | 57 ++- src/lib/websocket.ts | 2 +- src/store/useStore.test.ts | 11 + src/store/useStore.ts | 57 ++- src/test/setup.tsx | 61 ++- src/test/storeTestUtils.ts | 2 + 48 files changed, 3280 insertions(+), 176 deletions(-) create mode 100644 backend/services/sam3_external_worker.py create mode 100755 backend/setup_sam3_env.sh create mode 100644 backend/tests/test_sam3_engine.py create mode 100644 backend/tests/test_tasks.py diff --git a/AGENTS.md b/AGENTS.md index eccb264..c2894f0 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -6,7 +6,7 @@ ## 项目概述 -本项目是一个**语义分割系统**(Semantic Segmentation System),当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、服务器端拆帧、交互式 Canvas 标注、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出。 +本项目是一个**语义分割系统**(Semantic Segmentation System),当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、服务器端拆帧、交互式 Canvas 标注、GT mask 导入、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出。 - **项目名称**: `react-example`(`package.json` 中的 `name`) - **前端入口**: `src/main.tsx` → `src/App.tsx` @@ -30,6 +30,7 @@ | 前端请求 | Axios(`src/lib/api.ts`) | | 实时通信 | WebSocket 客户端(`src/lib/websocket.ts`) | | Canvas 渲染 | Konva + react-konva + use-image | +| 几何布尔运算 | polygon-clipping | | 图标库 | lucide-react | | 动画依赖 | motion(在 `package.json` 中声明) | | AI SDK 依赖 | `@google/genai`(在 `package.json` 中声明;当前业务源码未直接调用) | @@ -38,9 +39,9 @@ | 缓存 / 队列 Broker | Redis | | 后台任务 | Celery worker | | 对象存储 | MinIO | -| AI 推理 | SAM 2 / SAM 3 可选模型 + PyTorch;`GET /api/ai/models/status` 返回真实 GPU/模型状态 | +| AI 推理 | SAM 2 / SAM 3 可选模型 + PyTorch;SAM 3 通过独立 Python 3.12 conda 环境桥接;`GET /api/ai/models/status` 返回真实 GPU/模型/HF 权重访问状态 | | 视频 / 影像处理 | FFmpeg / OpenCV / pydicom | -| 运行时 | Node.js ES Modules;Python 3.11 后端环境 | +| 运行时 | Node.js ES Modules;Python 3.11 后端环境;可选 `sam3` Python 3.12 conda 环境 | --- @@ -70,6 +71,7 @@ Seg_Server/ │ ├── celery_app.py # Celery app 配置 │ ├── worker_tasks.py # Celery 任务入口 │ ├── download_sam2.py # SAM 2 权重下载脚本 +│ ├── setup_sam3_env.sh # SAM 3 独立 Python 3.12 环境安装脚本 │ ├── requirements.txt # Python 依赖 │ ├── routers/ │ │ ├── auth.py # /api/auth/login @@ -81,7 +83,8 @@ Seg_Server/ │ └── services/ │ ├── frame_parser.py # FFmpeg/OpenCV 拆帧、pydicom 读片、帧上传 │ ├── sam2_engine.py # SAM 2 懒加载推理封装和 fallback -│ ├── sam3_engine.py # SAM 3 状态检测与文本语义推理适配器 +│ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接与文本语义推理适配器 +│ ├── sam3_external_worker.py # 独立 sam3 conda 环境中执行的状态/推理 helper │ └── sam_registry.py # SAM 模型选择、GPU 状态与推理分发 └── src/ # React 前端 ├── main.tsx # React StrictMode 挂载 @@ -188,10 +191,13 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload - `POST /api/media/parse` - `GET /api/tasks` - `GET /api/tasks/{task_id}` + - `POST /api/tasks/{task_id}/cancel` + - `POST /api/tasks/{task_id}/retry` - `POST /api/ai/predict` - `GET /api/ai/models/status` - `POST /api/ai/auto` - `POST /api/ai/annotate` + - `POST /api/ai/import-gt-mask` - `GET /api/ai/annotations` - `PATCH/DELETE /api/ai/annotations/{annotation_id}` - `GET /api/dashboard/overview` @@ -216,9 +222,11 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload 4. 拆帧入队:前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery。 5. worker 执行:Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallback,DICOM 使用 pydicom,并持续更新任务进度。 6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames`,`CanvasArea.tsx` 和 `FrameTimeline.tsx` 显示当前帧与时间轴缩略图。 -7. AI 分割:前端工具包括正向点、反向点和框选;后端 `ai.py` 期望按 `image_id`、`prompt_type`、`prompt_data`、`model` 调用 SAM registry。SAM 2 支持点/框/自动分割;SAM 3 入口支持文本语义推理,运行时不满足官方要求时会在状态接口中标为不可用。 -8. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index;`OntologyInspector.tsx` 在工作区显示当前模板分类树。 -9. 导出:后端支持 COCO JSON 和 PNG mask ZIP 导出。 +7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask;点击 mask 可拖动/删除 polygon 顶点、通过边中点插入新顶点,并能选择编辑多 polygon mask 的单个子区域;区域合并/去除使用 `polygon-clipping` 做 union/difference;Zustand 维护 `maskHistory/maskFuture` 支持撤销/重做。 +8. AI 分割:前端工具包括正向点、反向点和框选;后端 `ai.py` 期望按 `image_id`、`prompt_type`、`prompt_data`、`model` 和可选 `options` 调用 SAM registry。SAM 2 支持点/框/自动分割;`options.crop_to_prompt` 可对点/框 prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果;SAM 3 入口支持文本语义推理,主后端会通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境;如果 Python/CUDA/包/Hugging Face gated 权重访问任一条件不满足,会在状态接口中标为不可用。 +9. GT 导入:工作区“导入 GT Mask”调用 `/api/ai/import-gt-mask`;后端按非零像素值和连通域生成 polygon 标注,并用 distance transform 生成 seed point;前端回显 seed point,拖动后可归档更新。 +10. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index;`OntologyInspector.tsx` 在工作区显示当前模板分类树。 +11. 导出:后端支持 COCO JSON 和 PNG mask ZIP 导出;PNG ZIP 包含单标注 mask、按 zIndex 融合的语义 mask 和 `semantic_classes.json`。 --- @@ -226,12 +234,16 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload - `src/lib/config.ts` 会优先读取 `VITE_API_BASE_URL` 和 `VITE_WS_PROGRESS_URL`;未配置时按当前浏览器 hostname 推导后端 `:8000` 地址。 - 前端 `predictMask()` 已按后端 `PredictRequest` 发送 `image_id`、`prompt_type`、`prompt_data`、`model`,并将后端 `polygons` 转成 Konva 可渲染的 `pathData`。 -- 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;工作区“导出 JSON 标注集”按钮已绑定下载流程,导出前会先保存当前待归档 mask。 +- 手工绘制工具会生成可保存的 `Mask.segmentation`;撤销/重做通过 `maskHistory/maskFuture` 工作。 +- Polygon 顶点编辑会重算 `pathData/segmentation/bbox/area`;已保存 mask 进入 dirty 状态后复用归档 PATCH 链路。 +- 区域合并/去除会重算主 mask 的几何;合并已保存的次级 mask 时会通过工作区回调删除对应后端标注。 +- 前端 `importGtMask()` 已对齐后端 `/api/ai/import-gt-mask`;工作区“导入 GT Mask”会导入后端生成的多类别标注和 seed point 并回显。 +- 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;前端 `exportMasks()` 已对齐后端 `/api/export/{project_id}/masks`;工作区导出按钮会先保存当前待归档 mask。 - 工作区“结构化归档保存”按钮已接入 `POST /api/ai/annotate` 和 `PATCH /api/ai/annotations/{id}`;加载工作区时会通过 `GET /api/ai/annotations` 回显已保存标注。 - 工作区“清空遮罩”会调用 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。 - 项目状态已统一为 `pending`、`parsing`、`ready`、`error`;前端 `src/lib/api.ts` 会兼容归一化旧库中可能存在的 `Ready`、`Parsing`、`Error`。 - `server.ts` 仍有旧版 `/api/login`、`/api/projects`、`/api/templates` mock;当前前端真实 API 调用主要走 FastAPI 的 `/api/auth/*`、`/api/projects`、`/api/templates` 等接口。 -- `Dashboard.tsx` 初始统计、队列和活动日志来自 `GET /api/dashboard/overview`;解析队列来自 `processing_tasks`,Celery worker 通过 Redis pub/sub 的 `seg:progress` 频道推送细粒度进度,再由 FastAPI 广播到 `/ws/progress`。 +- `Dashboard.tsx` 初始统计、队列和活动日志来自 `GET /api/dashboard/overview`;解析队列来自 `processing_tasks`,支持取消 queued/running 任务、重试 failed/cancelled 任务和查看失败详情。Celery worker 通过 Redis pub/sub 的 `seg:progress` 频道推送细粒度进度,再由 FastAPI 广播到 `/ws/progress`。 --- diff --git a/README.md b/README.md index afe1a93..9061762 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ > 基于 React + FastAPI + 可选 SAM 2 / SAM 3 的全栈交互式图像/视频语义分割与标注平台。 > -> 支持本地多媒体资产上传、服务器端按帧解析、交互式 Canvas 标注、模板分类管理和标注数据结构化导出;工作区点/框 AI 推理默认走 SAM 2,语义文本可选择 SAM 3,前端会显示真实 GPU/模型状态。 +> 支持本地多媒体资产上传、服务器端按帧解析、交互式 Canvas 标注、GT mask 导入、模板分类管理和标注数据结构化导出;工作区点/框 AI 推理默认走 SAM 2,语义文本可选择 SAM 3,前端会显示真实 GPU/模型状态。 --- @@ -14,10 +14,11 @@ - **多媒体资产管理** — 支持视频(MP4/AVI/MOV)和 DICOM 医学影像的上传、存储与解析 - **AI 智能分割引擎** — 后端提供 SAM 2 / SAM 3 模型选择;SAM 2 支持点分割(point)、框分割(box)和自动分割(auto),SAM 3 入口支持文本语义提示并按真实运行环境显示可用性 -- **交互式画布标注** — 基于 Konva 的高性能 Canvas,支持缩放/平移/选点/框选,实时渲染 Mask 遮罩 +- **交互式画布标注** — 基于 Konva 的高性能 Canvas,支持缩放/平移/手工多边形/矩形/圆/点/线、polygon 顶点拖动/删除、区域合并/去除、选点/框选、撤销/重做,实时渲染 Mask 遮罩 +- **GT Mask 导入** — 工作区可导入 GT mask 图片,后端按非零像素值和连通域生成 polygon 标注并用 distance transform 写入 seed point;前端可回显和拖动 seed point - **本体字典管理** — 可配置的分类体系、颜色映射、图层优先级(z-index) - **项目工作区** — 项目创建、帧浏览、多图层标注、进度追踪 -- **数据导出** — 支持 COCO JSON 格式和 PNG Mask 批量导出 +- **数据导出** — 支持 COCO JSON 格式和 PNG Mask 批量导出;PNG ZIP 包含单标注 mask、按 z-index 融合的语义 mask 和类别映射 --- @@ -38,7 +39,7 @@ │ ├── /api/projects 项目 & 视频帧 CRUD │ │ ├── /api/templates 本体字典(分类/颜色/z-index) │ │ ├── /api/media 文件上传 & 异步拆帧任务创建 │ -│ ├── /api/tasks Celery 后台任务状态 │ +│ ├── /api/tasks Celery 后台任务状态/取消/重试/详情 │ │ ├── /api/ai SAM 2 / SAM 3 推理与模型状态 │ │ └── /api/export COCO JSON / PNG Masks 导出 │ └──────────────────────────┬──────────────────────────────────┘ @@ -62,6 +63,7 @@ | 样式方案 | TailwindCSS + 自定义深色主题 | v4 | | 状态管理 | Zustand | - | | Canvas 渲染 | Konva + react-konva | - | +| 几何布尔运算 | polygon-clipping | 0.15+ | | HTTP 客户端 | Axios | - | | 后端框架 | FastAPI | v0.136+ | | 数据库 ORM | SQLAlchemy(依赖中包含 Alembic) | 2.0+ | @@ -92,6 +94,7 @@ Seg_Server/ │ ├── celery_app.py # Celery app 配置 │ ├── worker_tasks.py # Celery 任务入口 │ ├── download_sam2.py # SAM 2 模型权重自动下载脚本 +│ ├── setup_sam3_env.sh # SAM 3 独立 Python 3.12 环境安装脚本 │ ├── requirements.txt # Python 依赖 │ ├── routers/ # API 路由 │ │ ├── auth.py # 登录认证 @@ -102,7 +105,8 @@ Seg_Server/ │ │ └── export.py # 数据导出 │ └── services/ # 业务服务 │ ├── sam2_engine.py # SAM 2 推理引擎(懒加载 + stub降级) -│ ├── sam3_engine.py # SAM 3 状态检测与文本语义推理适配器 +│ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接与文本语义推理适配器 +│ ├── sam3_external_worker.py # 独立 sam3 conda 环境中执行的状态/推理 helper │ ├── sam_registry.py # SAM 模型选择、GPU 状态与推理分发 │ └── frame_parser.py # FFmpeg 拆帧 / pydicom 读片 ├── src/ # React 前端 @@ -117,10 +121,10 @@ Seg_Server/ │ └── components/ # 组件(扁平化目录) │ ├── Login.tsx # 登录页 │ ├── Sidebar.tsx # 左侧导航栏 -│ ├── Dashboard.tsx # 总体概况仪表盘(解析队列) +│ ├── Dashboard.tsx # 总体概况仪表盘(解析队列/任务控制) │ ├── ProjectLibrary.tsx # 项目库列表 │ ├── VideoWorkspace.tsx # 核心分割工作区布局 -│ ├── CanvasArea.tsx # Konva 画布(缩放/平移/选点/Mask渲染) +│ ├── CanvasArea.tsx # Konva 画布(缩放/平移/手工绘制/选点/Mask渲染) │ ├── ToolsPalette.tsx # 左侧工具栏 │ ├── OntologyInspector.tsx # 右侧本体/属性检查面板 │ ├── FrameTimeline.tsx # 底部时间轴 @@ -161,7 +165,7 @@ Seg_Server/ - **GPU**: NVIDIA GPU(推荐 RTX 4090 或同等算力),用于 SAM 推理;SAM 3 官方要求 Python 3.12+、PyTorch 2.7+ 和 CUDA 12.6+ 环境 - **CUDA**: 12.x / 13.x - **Node.js**: 22.x+ -- **Python**: 3.11(通过 Miniconda/Anaconda 管理) +- **Python**: 主后端使用 3.11(通过 Miniconda/Anaconda 管理);SAM 3 使用独立 `sam3` Python 3.12 conda 环境 ### 安装系统级依赖 @@ -243,7 +247,22 @@ python download_sam2.py > **注意**:当前系统磁盘紧张时,建议仅保留 `sam2_hiera_tiny.pt`,删除其他模型以释放空间。 -### 步骤 5: 配置环境变量 +### 步骤 5: 可选安装 SAM 3 环境 + +当前后端不会把 SAM 3 直接装进 `seg_server`,而是通过独立 `sam3` conda 环境执行 `backend/services/sam3_external_worker.py`。这样可以保留现有 Python 3.11 / SAM 2 环境。 + +```bash +cd ~/Desktop/Seg_Server +./backend/setup_sam3_env.sh + +# 首次使用官方权重前,需要先在 Hugging Face 申请 facebook/sam3 访问权限并登录 +conda activate sam3 +huggingface-cli login +``` + +官方 `facebook/sam3` 权重约 3.45 GB,当前没有类似 SAM 2 `tiny/small/base/large` 的官方小权重梯度;`facebook/sam3.1` 约 3.5 GB,主要面向新的视频 multiplex checkpoint。未获得 gated model 授权时,`GET /api/ai/models/status` 会把 SAM 3 标为不可用并说明 checkpoint access 不满足。 + +### 步骤 6: 配置环境变量 后端通过 `backend/config.py` 中的 Pydantic Settings 读取 `backend/.env`。如需覆盖默认值,请编辑以下文件: @@ -258,7 +277,10 @@ minio_secure=false sam_model_path=/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt sam_model_config=configs/sam2/sam2_hiera_t.yaml sam_default_model=sam2 -sam3_model_version=sam3.1 +sam3_model_version=sam3 +sam3_external_enabled=true +sam3_external_python=/home/wkmgc/miniconda3/envs/sam3/bin/python +sam3_timeout_seconds=300 cors_origins=["http://localhost:3000","http://192.168.3.11:3000"] ``` @@ -271,7 +293,7 @@ VITE_WS_PROGRESS_URL=ws://192.168.3.11:8000/ws/progress 如果未配置 `VITE_API_BASE_URL`,前端会按当前浏览器 hostname 推导 `http://:8000`。 -### 步骤 6: 启动后端服务 +### 步骤 7: 启动后端服务 ```bash cd ~/Desktop/Seg_Server/backend @@ -287,7 +309,8 @@ nohup uvicorn main:app --host 0.0.0.0 --port 8000 > /tmp/fastapi.log 2>&1 & - 创建数据库表(如果不存在) - 检查 MinIO bucket 是否存在 - 测试 Redis 连接 -- 懒加载 SAM 模型;`GET /api/ai/models/status` 会返回 SAM 2、SAM 3 与 GPU 的真实可用状态 +- 懒加载 SAM 模型;`GET /api/ai/models/status` 会返回 SAM 2、SAM 3、GPU 和 SAM 3 checkpoint access 的真实可用状态 +- `/api/ai/predict` 支持 AI 参数 `crop_to_prompt`、`auto_filter_background` 和 `min_score`,用于点/框 prompt 的局部裁剪推理、回映射和背景过滤 ### 步骤 6.1: 启动 Celery Worker @@ -301,7 +324,7 @@ celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 & ``` -`POST /api/media/parse` 只创建 `processing_tasks` 记录并把任务投递给 Celery;真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道,FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。 +`POST /api/media/parse` 只创建 `processing_tasks` 记录并把任务投递给 Celery;真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道,FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 也可调用 `/api/tasks/{id}/cancel`、`/api/tasks/{id}/retry` 和 `/api/tasks/{id}` 完成任务取消、重试与失败详情查看。 ### 步骤 7: 安装前端依赖并构建 @@ -438,7 +461,8 @@ pip install -e . --no-build-isolation - 前端 `predictMask()` 已发送后端需要的 `image_id`、`prompt_type`、`prompt_data`,并把后端 `polygons` 转成 Konva `pathData`。 - 工作区点选/框选会使用当前帧的数据库 `frame.id` 调用 `/api/ai/predict`。 - 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco`。 -- 工作区“导出 JSON 标注集”按钮已绑定下载流程;导出前会先保存当前待归档的前端 mask。 +- 工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮已绑定下载流程;导出前会先保存当前待归档的前端 mask。 +- 工作区“导入 GT Mask”按钮已绑定 `/api/ai/import-gt-mask`,导入后会刷新并回显已保存标注和 seed point。 - 工作区“结构化归档保存”按钮会把当前项目未保存 mask 写入 `POST /api/ai/annotate`,并把 dirty mask 写入 `PATCH /api/ai/annotations/{id}`。 - 工作区“清空遮罩”会通过 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。 @@ -447,6 +471,7 @@ pip install -e . --no-build-isolation ```bash curl http://localhost:8000/health curl http://localhost:8000/api/export/1/coco +curl http://localhost:8000/api/export/1/masks ``` --- diff --git a/backend/config.py b/backend/config.py index f55c266..d305bea 100644 --- a/backend/config.py +++ b/backend/config.py @@ -22,7 +22,12 @@ class Settings(BaseSettings): sam_default_model: str = "sam2" sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt" sam_model_config: str = "configs/sam2/sam2_hiera_t.yaml" - sam3_model_version: str = "sam3.1" + sam3_model_version: str = "sam3" + sam3_external_enabled: bool = True + sam3_external_python: str = "/home/wkmgc/miniconda3/envs/sam3/bin/python" + sam3_timeout_seconds: int = 300 + sam3_status_cache_seconds: int = 30 + sam3_confidence_threshold: float = 0.5 # App app_env: str = "development" diff --git a/backend/progress_events.py b/backend/progress_events.py index 6b5a1c0..083bc0a 100644 --- a/backend/progress_events.py +++ b/backend/progress_events.py @@ -8,7 +8,7 @@ from datetime import datetime, timezone from typing import Any from redis_client import get_redis_client -from statuses import TASK_STATUS_FAILED, TASK_STATUS_SUCCESS +from statuses import TASK_STATUS_CANCELLED, TASK_STATUS_FAILED, TASK_STATUS_SUCCESS logger = logging.getLogger(__name__) @@ -22,6 +22,8 @@ def _iso_now() -> str: def _event_type(task_status: str) -> str: if task_status == TASK_STATUS_SUCCESS: return "complete" + if task_status == TASK_STATUS_CANCELLED: + return "cancelled" if task_status == TASK_STATUS_FAILED: return "error" return "progress" diff --git a/backend/routers/ai.py b/backend/routers/ai.py index 29a7096..c785ef2 100644 --- a/backend/routers/ai.py +++ b/backend/routers/ai.py @@ -5,7 +5,7 @@ from typing import Any, List import cv2 import numpy as np -from fastapi import APIRouter, Depends, HTTPException, Response, status +from fastapi import APIRouter, Depends, File, Form, HTTPException, Response, UploadFile, status from sqlalchemy.orm import Session from database import get_db @@ -39,6 +39,140 @@ def _load_frame_image(frame: Frame) -> np.ndarray: raise HTTPException(status_code=500, detail="Failed to load frame image") from exc +def _normalized_contour(contour: np.ndarray, width: int, height: int) -> list[list[float]]: + """Approximate a contour and convert it to normalized polygon coordinates.""" + arc_length = cv2.arcLength(contour, True) + epsilon = max(1.0, arc_length * 0.01) + approx = cv2.approxPolyDP(contour, epsilon, True) + points = approx.reshape(-1, 2) + if len(points) < 3: + points = contour.reshape(-1, 2) + return [ + [ + min(max(float(x) / max(width, 1), 0.0), 1.0), + min(max(float(y) / max(height, 1), 0.0), 1.0), + ] + for x, y in points + ] + + +def _contour_bbox(contour: np.ndarray, width: int, height: int) -> list[float]: + x, y, w, h = cv2.boundingRect(contour) + return [ + min(max(float(x) / max(width, 1), 0.0), 1.0), + min(max(float(y) / max(height, 1), 0.0), 1.0), + min(max(float(w) / max(width, 1), 0.0), 1.0), + min(max(float(h) / max(height, 1), 0.0), 1.0), + ] + + +def _component_seed_point(component_mask: np.ndarray, width: int, height: int) -> list[float]: + """Reduce a binary component to one positive prompt point using distance transform.""" + dist = cv2.distanceTransform(component_mask.astype(np.uint8), cv2.DIST_L2, 5) + _, _, _, max_loc = cv2.minMaxLoc(dist) + x, y = max_loc + return [ + min(max(float(x) / max(width, 1), 0.0), 1.0), + min(max(float(y) / max(height, 1), 0.0), 1.0), + ] + + +def _clamp01(value: float) -> float: + return min(max(float(value), 0.0), 1.0) + + +def _point_in_polygon(point: list[float], polygon: list[list[float]]) -> bool: + """Return whether a normalized point is inside a normalized polygon.""" + if len(polygon) < 3: + return False + x, y = point + inside = False + j = len(polygon) - 1 + for i, current in enumerate(polygon): + xi, yi = current + xj, yj = polygon[j] + intersects = ((yi > y) != (yj > y)) and ( + x < (xj - xi) * (y - yi) / ((yj - yi) or 1e-9) + xi + ) + if intersects: + inside = not inside + j = i + return inside + + +def _crop_bounds_from_points(points: list[list[float]], margin: float) -> tuple[float, float, float, float]: + xs = [_clamp01(point[0]) for point in points] + ys = [_clamp01(point[1]) for point in points] + x1 = max(0.0, min(xs) - margin) + y1 = max(0.0, min(ys) - margin) + x2 = min(1.0, max(xs) + margin) + y2 = min(1.0, max(ys) + margin) + if x2 - x1 < 0.05: + center = (x1 + x2) / 2 + x1 = max(0.0, center - 0.025) + x2 = min(1.0, center + 0.025) + if y2 - y1 < 0.05: + center = (y1 + y2) / 2 + y1 = max(0.0, center - 0.025) + y2 = min(1.0, center + 0.025) + return x1, y1, x2, y2 + + +def _crop_image(image: np.ndarray, bounds: tuple[float, float, float, float]) -> np.ndarray: + height, width = image.shape[:2] + x1, y1, x2, y2 = bounds + left = int(round(x1 * width)) + top = int(round(y1 * height)) + right = max(left + 1, int(round(x2 * width))) + bottom = max(top + 1, int(round(y2 * height))) + return image[top:bottom, left:right] + + +def _to_crop_point(point: list[float], bounds: tuple[float, float, float, float]) -> list[float]: + x1, y1, x2, y2 = bounds + return [ + _clamp01((float(point[0]) - x1) / max(x2 - x1, 1e-9)), + _clamp01((float(point[1]) - y1) / max(y2 - y1, 1e-9)), + ] + + +def _from_crop_polygon( + polygon: list[list[float]], + bounds: tuple[float, float, float, float], +) -> list[list[float]]: + x1, y1, x2, y2 = bounds + return [ + [ + _clamp01(x1 + float(point[0]) * (x2 - x1)), + _clamp01(y1 + float(point[1]) * (y2 - y1)), + ] + for point in polygon + ] + + +def _filter_predictions( + polygons: list[list[list[float]]], + scores: list[float], + options: dict[str, Any], + negative_points: list[list[float]] | None = None, +) -> tuple[list[list[list[float]]], list[float]]: + if not options.get("auto_filter_background"): + return polygons, scores + + min_score = float(options.get("min_score", 0.0) or 0.0) + next_polygons: list[list[list[float]]] = [] + next_scores: list[float] = [] + for index, polygon in enumerate(polygons): + score = scores[index] if index < len(scores) else 0.0 + if score < min_score: + continue + if negative_points and any(_point_in_polygon(point, polygon) for point in negative_points): + continue + next_polygons.append(polygon) + next_scores.append(score) + return next_polygons, next_scores + + @router.post( "/predict", response_model=PredictResponse, @@ -58,9 +192,11 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict: image = _load_frame_image(frame) prompt_type = payload.prompt_type.lower() + options = payload.options or {} polygons: List[List[List[float]]] = [] scores: List[float] = [] + negative_points: list[list[float]] = [] try: if prompt_type == "point": @@ -76,13 +212,39 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict: raise HTTPException(status_code=400, detail="Invalid point prompt data") if not isinstance(labels, list) or len(labels) != len(points): labels = [1] * len(points) - polygons, scores = sam_registry.predict_points(payload.model, image, points, labels) + negative_points = [ + point for point, label in zip(points, labels) if label == 0 + ] + inference_image = image + inference_points = points + crop_bounds = None + if options.get("crop_to_prompt"): + margin = float(options.get("crop_margin", 0.25) or 0.25) + crop_bounds = _crop_bounds_from_points(points, margin) + inference_image = _crop_image(image, crop_bounds) + inference_points = [_to_crop_point(point, crop_bounds) for point in points] + polygons, scores = sam_registry.predict_points(payload.model, inference_image, inference_points, labels) + if crop_bounds: + polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons] elif prompt_type == "box": box = payload.prompt_data if not isinstance(box, list) or len(box) != 4: raise HTTPException(status_code=400, detail="Invalid box prompt data") - polygons, scores = sam_registry.predict_box(payload.model, image, box) + inference_image = image + inference_box = box + crop_bounds = None + if options.get("crop_to_prompt"): + margin = float(options.get("crop_margin", 0.05) or 0.05) + crop_bounds = _crop_bounds_from_points([[box[0], box[1]], [box[2], box[3]]], margin) + inference_image = _crop_image(image, crop_bounds) + inference_box = [ + *_to_crop_point([box[0], box[1]], crop_bounds), + *_to_crop_point([box[2], box[3]], crop_bounds), + ] + polygons, scores = sam_registry.predict_box(payload.model, inference_image, inference_box) + if crop_bounds: + polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons] elif prompt_type == "semantic": text = payload.prompt_data if isinstance(payload.prompt_data, str) else "" @@ -95,8 +257,9 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict: except NotImplementedError as exc: raise HTTPException(status_code=400, detail=str(exc)) from exc except ValueError as exc: - raise HTTPException(status_code=400, detail=str(exc)) from exc + raise HTTPException(status_code=400, detail=str(exc)) from exc + polygons, scores = _filter_predictions(polygons, scores, options, negative_points) return {"polygons": polygons, "scores": scores} @@ -161,6 +324,100 @@ def save_annotation( return annotation +@router.post( + "/import-gt-mask", + response_model=List[AnnotationOut], + status_code=status.HTTP_201_CREATED, + summary="Import a GT mask and reduce components to editable point regions", +) +async def import_gt_mask( + project_id: int = Form(...), + frame_id: int = Form(...), + template_id: int | None = Form(None), + label: str = Form("GT Mask"), + color: str = Form("#22c55e"), + file: UploadFile = File(...), + db: Session = Depends(get_db), +) -> List[Annotation]: + """Convert a binary/label mask image into persisted polygon annotations. + + Each connected component becomes one annotation. The `points` field stores a + positive seed point at the component's distance-transform center, which gives + the frontend an editable point-region representation instead of a static + bitmap layer. + """ + project = db.query(Project).filter(Project.id == project_id).first() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + frame = db.query(Frame).filter(Frame.id == frame_id, Frame.project_id == project_id).first() + if not frame: + raise HTTPException(status_code=404, detail="Frame not found") + + if template_id is not None: + template = db.query(Template).filter(Template.id == template_id).first() + if not template: + raise HTTPException(status_code=404, detail="Template not found") + + data = await file.read() + image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_GRAYSCALE) + if image is None: + raise HTTPException(status_code=400, detail="Invalid mask image") + + width = int(frame.width or image.shape[1]) + height = int(frame.height or image.shape[0]) + label_values = [int(value) for value in np.unique(image) if int(value) > 0] + if not label_values: + raise HTTPException(status_code=400, detail="No foreground mask regions found") + has_multiple_labels = len(label_values) > 1 + + annotations: list[Annotation] = [] + for label_value in label_values: + binary = np.where(image == label_value, 255, 0).astype(np.uint8) + contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + annotation_label = f"{label} {label_value}" if has_multiple_labels else label + + for contour in contours: + if cv2.contourArea(contour) < 1: + continue + + polygon = _normalized_contour(contour, image.shape[1], image.shape[0]) + if len(polygon) < 3: + continue + + component = np.zeros_like(binary, dtype=np.uint8) + cv2.drawContours(component, [contour], -1, 1, thickness=-1) + seed_point = _component_seed_point(component, image.shape[1], image.shape[0]) + bbox = _contour_bbox(contour, image.shape[1], image.shape[0]) + + annotation = Annotation( + project_id=project_id, + frame_id=frame_id, + template_id=template_id, + mask_data={ + "polygons": [polygon], + "label": annotation_label, + "color": color, + "source": "gt_mask", + "gt_label_value": label_value, + "image_size": {"width": width, "height": height}, + }, + points=[seed_point], + bbox=bbox, + ) + db.add(annotation) + annotations.append(annotation) + + if not annotations: + raise HTTPException(status_code=400, detail="No foreground mask regions found") + + db.commit() + for annotation in annotations: + db.refresh(annotation) + logger.info("Imported %s GT mask annotations for project_id=%s frame_id=%s", len(annotations), project_id, frame_id) + return annotations + + @router.get( "/annotations", response_model=List[AnnotationOut], diff --git a/backend/routers/dashboard.py b/backend/routers/dashboard.py index 6bd5a6d..b256bf2 100644 --- a/backend/routers/dashboard.py +++ b/backend/routers/dashboard.py @@ -14,6 +14,7 @@ from models import Annotation, Frame, ProcessingTask, Project, Template router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"]) ACTIVE_TASK_STATUSES = {"queued", "running"} +MONITORED_TASK_STATUSES = {"queued", "running", "failed", "cancelled"} def _system_load_percent() -> int: @@ -42,7 +43,9 @@ def _task_payload(task: ProcessingTask) -> dict[str, Any]: "name": task.project.name if task.project else f"任务 {task.id}", "progress": task.progress, "status": task.message or task.status, + "raw_status": task.status, "frame_count": (task.result or {}).get("frames_extracted", 0), + "error": task.error, "updated_at": _iso_or_none(task.updated_at), } @@ -68,7 +71,7 @@ def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]: .limit(50) .all() ) - tasks = [_task_payload(task) for task in recent_tasks if task.status in ACTIVE_TASK_STATUSES] + tasks = [_task_payload(task) for task in recent_tasks if task.status in MONITORED_TASK_STATUSES] activities: list[dict[str, Any]] = [] for task in recent_tasks[:10]: diff --git a/backend/routers/export.py b/backend/routers/export.py index 9662dd1..2a6dbf0 100644 --- a/backend/routers/export.py +++ b/backend/routers/export.py @@ -37,6 +37,54 @@ def _mask_from_polygon( return mask +def _annotation_z_index(annotation: Annotation) -> int: + class_meta = (annotation.mask_data or {}).get("class") or {} + if isinstance(class_meta, dict) and class_meta.get("zIndex") is not None: + try: + return int(class_meta["zIndex"]) + except (TypeError, ValueError): + pass + if annotation.template and annotation.template.z_index is not None: + return int(annotation.template.z_index) + return 0 + + +def _annotation_class_key(annotation: Annotation) -> str: + class_meta = (annotation.mask_data or {}).get("class") or {} + if isinstance(class_meta, dict): + if class_meta.get("id"): + return f"class:{class_meta['id']}" + if class_meta.get("name"): + return f"name:{class_meta['name']}" + if annotation.template_id: + return f"template:{annotation.template_id}" + return f"annotation:{annotation.id}" + + +def _annotation_label(annotation: Annotation) -> str: + mask_data = annotation.mask_data or {} + class_meta = mask_data.get("class") or {} + if isinstance(class_meta, dict) and class_meta.get("name"): + return str(class_meta["name"]) + if mask_data.get("label"): + return str(mask_data["label"]) + if annotation.template and annotation.template.name: + return str(annotation.template.name) + return f"Annotation {annotation.id}" + + +def _annotation_color(annotation: Annotation) -> str: + mask_data = annotation.mask_data or {} + class_meta = mask_data.get("class") or {} + if isinstance(class_meta, dict) and class_meta.get("color"): + return str(class_meta["color"]) + if mask_data.get("color"): + return str(mask_data["color"]) + if annotation.template and annotation.template.color: + return str(annotation.template.color) + return "#ffffff" + + @router.get( "/{project_id}/coco", summary="Export annotations in COCO format", @@ -150,19 +198,46 @@ def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResp summary="Export PNG masks as a ZIP archive", ) def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse: - """Export all annotation masks as individual PNG files inside a ZIP archive.""" + """Export individual masks plus z-index fused semantic masks inside a ZIP.""" project = db.query(Project).filter(Project.id == project_id).first() if not project: raise HTTPException(status_code=404, detail="Project not found") + import cv2 + annotations = ( db.query(Annotation) .filter(Annotation.project_id == project_id) .all() ) + frames = ( + db.query(Frame) + .filter(Frame.project_id == project_id) + .order_by(Frame.frame_index) + .all() + ) + + class_values: dict[str, int] = {} + semantic_classes: list[dict[str, Any]] = [] + + def class_value(annotation: Annotation) -> int: + key = _annotation_class_key(annotation) + if key not in class_values: + value = len(class_values) + 1 + class_values[key] = value + semantic_classes.append({ + "value": value, + "key": key, + "label": _annotation_label(annotation), + "color": _annotation_color(annotation), + "zIndex": _annotation_z_index(annotation), + "template_id": annotation.template_id, + }) + return class_values[key] zip_buffer = io.BytesIO() with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf: + frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {} for ann in annotations: if not ann.mask_data: continue @@ -178,11 +253,28 @@ def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingRes mask = _mask_from_polygon(poly, width, height) combined = np.maximum(combined, mask) - # Encode PNG - import cv2 _, encoded = cv2.imencode(".png", combined) fname = f"mask_{ann.id:06d}.png" zf.writestr(fname, encoded.tobytes()) + if ann.frame_id is not None: + frame_masks.setdefault(ann.frame_id, []).append((ann, combined)) + + for frame in frames: + entries = frame_masks.get(frame.id, []) + if not entries: + continue + width = frame.width or 1920 + height = frame.height or 1080 + semantic = np.zeros((height, width), dtype=np.uint8) + for ann, mask in sorted(entries, key=lambda item: _annotation_z_index(item[0])): + semantic[mask > 0] = class_value(ann) + _, encoded = cv2.imencode(".png", semantic) + zf.writestr(f"semantic_frame_{frame.frame_index:06d}.png", encoded.tobytes()) + + zf.writestr( + "semantic_classes.json", + json.dumps({"classes": semantic_classes}, ensure_ascii=False, indent=2).encode("utf-8"), + ) zip_buffer.seek(0) filename = f"project_{project_id}_masks.zip" diff --git a/backend/routers/tasks.py b/backend/routers/tasks.py index 9c1d335..385bdc7 100644 --- a/backend/routers/tasks.py +++ b/backend/routers/tasks.py @@ -1,15 +1,45 @@ """Processing task query endpoints.""" +import logging +from datetime import datetime, timezone from typing import List -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session +from celery_app import celery_app from database import get_db -from models import ProcessingTask +from models import ProcessingTask, Project +from progress_events import publish_task_progress_event from schemas import ProcessingTaskOut +from statuses import ( + PROJECT_STATUS_PARSING, + PROJECT_STATUS_PENDING, + PROJECT_STATUS_READY, + TASK_ACTIVE_STATUSES, + TASK_STATUS_CANCELLED, + TASK_STATUS_FAILED, + TASK_STATUS_QUEUED, +) +from worker_tasks import parse_project_media router = APIRouter(prefix="/api/tasks", tags=["Tasks"]) +logger = logging.getLogger(__name__) + + +def _now() -> datetime: + return datetime.now(timezone.utc) + + +def _get_task_or_404(task_id: int, db: Session) -> ProcessingTask: + task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first() + if not task: + raise HTTPException(status_code=404, detail="Task not found") + return task + + +def _project_status_after_stop(project: Project) -> str: + return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING @router.get("", response_model=List[ProcessingTaskOut], summary="List processing tasks") @@ -31,7 +61,78 @@ def list_tasks( @router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task") def get_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask: """Return one background task by id.""" - task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first() - if not task: - raise HTTPException(status_code=404, detail="Task not found") + return _get_task_or_404(task_id, db) + + +@router.post("/{task_id}/cancel", response_model=ProcessingTaskOut, summary="Cancel processing task") +def cancel_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask: + """Cancel a queued/running background task and revoke the Celery job when possible.""" + task = _get_task_or_404(task_id, db) + if task.status not in TASK_ACTIVE_STATUSES: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"Task is not cancellable in status: {task.status}", + ) + + if task.celery_task_id: + try: + celery_app.control.revoke(task.celery_task_id, terminate=True, signal="SIGTERM") + except Exception as exc: # noqa: BLE001 + logger.warning("Failed to revoke celery task %s: %s", task.celery_task_id, exc) + + task.status = TASK_STATUS_CANCELLED + task.progress = 100 + task.message = "任务已取消" + task.error = "Cancelled by user" + task.finished_at = _now() + if task.project: + task.project.status = _project_status_after_stop(task.project) + + db.commit() + db.refresh(task) + publish_task_progress_event(task) + return task + + +@router.post("/{task_id}/retry", response_model=ProcessingTaskOut, status_code=status.HTTP_202_ACCEPTED, summary="Retry processing task") +def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask: + """Create a fresh queued task from a failed or cancelled task.""" + previous = _get_task_or_404(task_id, db) + if previous.status not in {TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"Task is not retryable in status: {previous.status}", + ) + if previous.project_id is None: + raise HTTPException(status_code=400, detail="Task has no project_id") + + project = db.query(Project).filter(Project.id == previous.project_id).first() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + if not project.video_path: + raise HTTPException(status_code=400, detail="Project has no media uploaded") + + payload = dict(previous.payload or {}) + payload.setdefault("source_type", project.source_type or "video") + payload["retry_of"] = previous.id + + task = ProcessingTask( + task_type=previous.task_type, + status=TASK_STATUS_QUEUED, + progress=0, + message=f"重试任务已入队(源任务 #{previous.id})", + project_id=project.id, + payload=payload, + ) + project.status = PROJECT_STATUS_PARSING + db.add(task) + db.commit() + db.refresh(task) + publish_task_progress_event(task) + + async_result = parse_project_media.delay(task.id) + task.celery_task_id = async_result.id + db.commit() + db.refresh(task) + publish_task_progress_event(task) return task diff --git a/backend/schemas.py b/backend/schemas.py index ac74557..1654a36 100644 --- a/backend/schemas.py +++ b/backend/schemas.py @@ -180,6 +180,7 @@ class PredictRequest(BaseModel): prompt_type: str # point / box / semantic prompt_data: Any model: Optional[str] = None + options: Optional[dict[str, Any]] = None class PredictResponse(BaseModel): @@ -201,6 +202,8 @@ class AiModelStatus(BaseModel): python_ok: bool = True torch_ok: bool = True cuda_required: bool = False + external_available: bool = False + external_python: Optional[str] = None class GpuStatus(BaseModel): diff --git a/backend/services/media_task_runner.py b/backend/services/media_task_runner.py index 396d376..9db8989 100644 --- a/backend/services/media_task_runner.py +++ b/backend/services/media_task_runner.py @@ -20,9 +20,11 @@ from services.frame_parser import ( upload_frames_to_minio, ) from statuses import ( + PROJECT_STATUS_PENDING, PROJECT_STATUS_ERROR, PROJECT_STATUS_PARSING, PROJECT_STATUS_READY, + TASK_STATUS_CANCELLED, TASK_STATUS_FAILED, TASK_STATUS_RUNNING, TASK_STATUS_SUCCESS, @@ -31,6 +33,10 @@ from statuses import ( logger = logging.getLogger(__name__) +class TaskCancelled(RuntimeError): + """Raised internally when a persisted task has been cancelled.""" + + def _now() -> datetime: return datetime.now(timezone.utc) @@ -66,12 +72,29 @@ def _set_task_state( publish_task_progress_event(task) +def _project_status_after_stop(project: Project) -> str: + return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING + + +def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None: + db.refresh(task) + if task.status == TASK_STATUS_CANCELLED: + raise TaskCancelled("Task was cancelled") + + def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]: """Parse one project's media and update task progress in the database.""" task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first() if not task: raise ValueError(f"Task not found: {task_id}") + if task.status == TASK_STATUS_CANCELLED: + return { + "task_id": task.id, + "status": TASK_STATUS_CANCELLED, + "message": task.message or "任务已取消", + } + if task.project_id is None: _set_task_state( db, @@ -111,6 +134,7 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]: db.commit() raise ValueError("Project has no media uploaded") + _ensure_not_cancelled(db, task) project.status = PROJECT_STATUS_PARSING _set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True) @@ -121,6 +145,7 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]: os.makedirs(output_dir, exist_ok=True) try: + _ensure_not_cancelled(db, task) _set_task_state(db, task, progress=15, message="正在下载媒体文件") if effective_source == "dicom": dcm_dir = os.path.join(tmp_dir, "dcm") @@ -129,20 +154,24 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]: client = get_minio_client() objects = list(client.list_objects(BUCKET_NAME, prefix=project.video_path, recursive=True)) for obj in objects: + _ensure_not_cancelled(db, task) if obj.object_name.lower().endswith(".dcm"): data = download_file(obj.object_name) local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name)) with open(local_dcm, "wb") as f: f.write(data) + _ensure_not_cancelled(db, task) _set_task_state(db, task, progress=35, message="正在解析 DICOM 序列") frame_files = parse_dicom(dcm_dir, output_dir) else: + _ensure_not_cancelled(db, task) media_bytes = download_file(project.video_path) local_path = os.path.join(tmp_dir, Path(project.video_path).name) with open(local_path, "wb") as f: f.write(media_bytes) + _ensure_not_cancelled(db, task) _set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧") frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps)) project.original_fps = original_fps @@ -158,12 +187,15 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]: except Exception as exc: # noqa: BLE001 logger.warning("Thumbnail extraction failed: %s", exc) + _ensure_not_cancelled(db, task) _set_task_state(db, task, progress=70, message="正在上传帧到对象存储") object_names = upload_frames_to_minio(frame_files, project.id) + _ensure_not_cancelled(db, task) _set_task_state(db, task, progress=85, message="正在写入帧索引") frames_out = [] for idx, obj_name in enumerate(object_names): + _ensure_not_cancelled(db, task) local_frame = frame_files[idx] try: import cv2 @@ -203,6 +235,23 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]: ) logger.info("Parsed %d frames for project_id=%s", len(frames_out), project.id) return result + except TaskCancelled: + project.status = _project_status_after_stop(project) + task.status = TASK_STATUS_CANCELLED + task.progress = 100 + task.message = task.message or "任务已取消" + task.error = task.error or "Cancelled by user" + task.finished_at = task.finished_at or _now() + db.commit() + db.refresh(task) + publish_task_progress_event(task) + logger.info("Parse task cancelled: task_id=%s project_id=%s", task.id, project.id) + return { + "task_id": task.id, + "project_id": project.id, + "status": TASK_STATUS_CANCELLED, + "message": task.message, + } except Exception as exc: # noqa: BLE001 project.status = PROJECT_STATUS_ERROR _set_task_state( diff --git a/backend/services/sam3_engine.py b/backend/services/sam3_engine.py index 7f71d64..8213be7 100644 --- a/backend/services/sam3_engine.py +++ b/backend/services/sam3_engine.py @@ -9,8 +9,14 @@ the package. from __future__ import annotations import importlib.util +import json import logging +import os +import subprocess import sys +import tempfile +import time +from pathlib import Path from typing import Any import numpy as np @@ -41,6 +47,8 @@ class SAM3Engine: self._processor: Any | None = None self._model_loaded = False self._last_error: str | None = None + self._external_status_cache: dict[str, Any] | None = None + self._external_status_checked_at = 0.0 def _python_ok(self) -> bool: return sys.version_info >= (3, 12) @@ -51,6 +59,81 @@ class SAM3Engine: def _can_load(self) -> bool: return bool(SAM3_PACKAGE_AVAILABLE and TORCH_AVAILABLE and self._python_ok() and self._gpu_ok()) + def _worker_path(self) -> Path: + return Path(__file__).with_name("sam3_external_worker.py") + + def _external_python_exists(self) -> bool: + return bool(settings.sam3_external_enabled and os.path.isfile(settings.sam3_external_python)) + + def _external_status(self, force: bool = False) -> dict[str, Any]: + now = time.monotonic() + if ( + not force + and self._external_status_cache is not None + and now - self._external_status_checked_at < settings.sam3_status_cache_seconds + ): + return self._external_status_cache + + if not settings.sam3_external_enabled: + status = { + "available": False, + "package_available": False, + "python_ok": False, + "torch_ok": False, + "cuda_available": False, + "device": "unavailable", + "message": "SAM 3 external runtime is disabled.", + } + elif not self._external_python_exists(): + status = { + "available": False, + "package_available": False, + "python_ok": False, + "torch_ok": False, + "cuda_available": False, + "device": "unavailable", + "message": f"SAM 3 external Python not found: {settings.sam3_external_python}", + } + else: + try: + env = os.environ.copy() + env["SAM3_MODEL_VERSION"] = settings.sam3_model_version + completed = subprocess.run( + [settings.sam3_external_python, str(self._worker_path()), "--status"], + capture_output=True, + text=True, + timeout=min(settings.sam3_timeout_seconds, 30), + check=False, + env=env, + ) + if completed.returncode != 0: + detail = completed.stderr.strip() or completed.stdout.strip() + status = { + "available": False, + "package_available": False, + "python_ok": False, + "torch_ok": False, + "cuda_available": False, + "device": "unavailable", + "message": f"SAM 3 external status failed: {detail}", + } + else: + status = json.loads(completed.stdout) + except Exception as exc: # noqa: BLE001 + status = { + "available": False, + "package_available": False, + "python_ok": False, + "torch_ok": False, + "cuda_available": False, + "device": "unavailable", + "message": f"SAM 3 external status failed: {exc}", + } + + self._external_status_cache = status + self._external_status_checked_at = now + return status + def _load_model(self) -> None: if self._model_loaded: return @@ -92,26 +175,86 @@ class SAM3Engine: return "SAM 3 dependencies are present; model will load on first inference." def status(self) -> dict: - available = self._can_load() + external_status = self._external_status() + available = bool(self._can_load() or external_status.get("available")) + external_ready = bool(external_status.get("available")) + message = self._last_error or self._status_message() + if self._processor is not None: + message = "SAM 3 model loaded and ready." + elif external_ready: + message = "SAM 3 external runtime is ready; model will load in the helper process on inference." + elif external_status.get("message") and not self._can_load(): + message = str(external_status["message"]) return { "id": "sam3", "label": "SAM 3", "available": available, "loaded": self._processor is not None, - "device": "cuda" if self._gpu_ok() else "unavailable", + "device": "cuda" if self._gpu_ok() else str(external_status.get("device", "unavailable")), "supports": ["semantic"], - "message": "SAM 3 model loaded and ready." if self._processor is not None else (self._last_error or self._status_message()), - "package_available": SAM3_PACKAGE_AVAILABLE, - "checkpoint_exists": SAM3_PACKAGE_AVAILABLE, + "message": message, + "package_available": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("package_available")), + "checkpoint_exists": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("checkpoint_access")), "checkpoint_path": f"official/HuggingFace ({settings.sam3_model_version})", - "python_ok": self._python_ok(), - "torch_ok": TORCH_AVAILABLE, + "python_ok": bool(self._python_ok() or external_status.get("python_ok")), + "torch_ok": bool(TORCH_AVAILABLE or external_status.get("torch_ok")), "cuda_required": True, + "external_available": external_ready, + "external_python": settings.sam3_external_python if settings.sam3_external_enabled else None, } + def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]: + status = self._external_status(force=True) + if not status.get("available"): + raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.") + + with tempfile.TemporaryDirectory(prefix="sam3_") as tmpdir: + tmp_path = Path(tmpdir) + image_path = tmp_path / "image.png" + request_path = tmp_path / "request.json" + Image.fromarray(image).save(image_path) + request_path.write_text( + json.dumps( + { + "image_path": str(image_path), + "text": text.strip(), + "model_version": settings.sam3_model_version, + "confidence_threshold": settings.sam3_confidence_threshold, + }, + ensure_ascii=False, + ), + encoding="utf-8", + ) + env = os.environ.copy() + env["SAM3_MODEL_VERSION"] = settings.sam3_model_version + completed = subprocess.run( + [settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)], + capture_output=True, + text=True, + timeout=settings.sam3_timeout_seconds, + check=False, + env=env, + ) + + if completed.returncode != 0: + detail = completed.stderr.strip() or completed.stdout.strip() + try: + parsed = json.loads(detail) + detail = parsed.get("error", detail) + except Exception: # noqa: BLE001 + pass + raise RuntimeError(f"SAM 3 external inference failed: {detail}") + + payload = json.loads(completed.stdout) + if payload.get("error"): + raise RuntimeError(str(payload["error"])) + return payload.get("polygons", []), payload.get("scores", []) + def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]: if not text.strip(): raise ValueError("SAM 3 semantic prompt requires non-empty text.") + if not self._can_load() and self._external_status().get("available"): + return self._predict_semantic_external(image, text) if not self._ensure_ready(): raise RuntimeError(self.status()["message"]) diff --git a/backend/services/sam3_external_worker.py b/backend/services/sam3_external_worker.py new file mode 100644 index 0000000..7f4e614 --- /dev/null +++ b/backend/services/sam3_external_worker.py @@ -0,0 +1,190 @@ +"""Standalone SAM 3 helper for the dedicated Python 3.12 runtime. + +The main FastAPI backend can keep running in the existing Python 3.11/SAM 2 +environment while this helper is executed with a separate conda env that meets +SAM 3's stricter runtime requirements. +""" + +from __future__ import annotations + +import argparse +import importlib.util +import json +import os +import sys +from pathlib import Path +from typing import Any + +import numpy as np +from PIL import Image + + +def _torch_status() -> tuple[bool, str | None, str | None, str | None]: + try: + import torch + + cuda_available = bool(torch.cuda.is_available()) + return ( + cuda_available, + getattr(torch, "__version__", None), + getattr(torch.version, "cuda", None), + torch.cuda.get_device_name(0) if cuda_available else None, + ) + except Exception: # noqa: BLE001 + return False, None, None, None + + +def _compact_error(exc: Exception) -> str: + lines = [line.strip() for line in str(exc).splitlines() if line.strip()] + for line in lines: + if "Access to model" in line or "Cannot access gated repo" in line: + return line + return lines[0] if lines else exc.__class__.__name__ + + +def _checkpoint_access(model_version: str) -> tuple[bool, str | None]: + try: + from huggingface_hub import hf_hub_download + + repo_id = "facebook/sam3.1" if model_version == "sam3.1" else "facebook/sam3" + hf_hub_download(repo_id=repo_id, filename="config.json") + return True, None + except Exception as exc: # noqa: BLE001 + return False, _compact_error(exc) + + +def runtime_status() -> dict[str, Any]: + model_version = os.environ.get("SAM3_MODEL_VERSION", "sam3") + package_error = None + package_available = importlib.util.find_spec("sam3") is not None + if package_available: + try: + import sam3 # noqa: F401 + except Exception as exc: # noqa: BLE001 + package_available = False + package_error = str(exc) + cuda_available, torch_version, cuda_version, device_name = _torch_status() + python_ok = sys.version_info >= (3, 12) + checkpoint_access = False + checkpoint_error = None + if package_available: + checkpoint_access, checkpoint_error = _checkpoint_access(model_version) + available = bool(package_available and python_ok and cuda_available and checkpoint_access) + missing = [] + if not python_ok: + missing.append("Python 3.12+ runtime") + if not package_available: + missing.append(f"sam3 package ({package_error})" if package_error else "sam3 package") + if torch_version is None: + missing.append("PyTorch") + if not cuda_available: + missing.append("CUDA GPU") + if package_available and not checkpoint_access: + missing.append(f"Hugging Face checkpoint access ({checkpoint_error})") + return { + "available": available, + "package_available": package_available, + "checkpoint_access": checkpoint_access, + "python_ok": python_ok, + "torch_ok": torch_version is not None, + "torch_version": torch_version, + "cuda_version": cuda_version, + "cuda_available": cuda_available, + "device": "cuda" if cuda_available else "unavailable", + "device_name": device_name, + "message": ( + "SAM 3 external runtime is ready." + if available + else f"SAM 3 external runtime unavailable: missing {', '.join(missing)}." + ), + } + + +def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]: + import cv2 + + if mask.dtype != np.uint8: + mask = (mask > 0).astype(np.uint8) + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + height, width = mask.shape[:2] + largest = [] + for contour in contours: + if len(contour) > len(largest): + largest = contour + if len(largest) < 3: + return [] + return [[float(point[0][0]) / width, float(point[0][1]) / height] for point in largest] + + +def _to_numpy(value: Any) -> np.ndarray: + if hasattr(value, "detach"): + value = value.detach().cpu().numpy() + elif hasattr(value, "cpu"): + value = value.cpu().numpy() + return np.asarray(value) + + +def predict(request_path: Path) -> dict[str, Any]: + import torch + from sam3.model.sam3_image_processor import Sam3Processor + from sam3.model_builder import build_sam3_image_model + + payload = json.loads(request_path.read_text(encoding="utf-8")) + image_path = Path(payload["image_path"]) + text = str(payload["text"]).strip() + threshold = float(payload.get("confidence_threshold", 0.5)) + if not text: + raise ValueError("SAM 3 semantic prompt requires non-empty text.") + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + image = Image.open(image_path).convert("RGB") + with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + model = build_sam3_image_model() + processor = Sam3Processor(model, confidence_threshold=threshold) + state = processor.set_image(image) + output = processor.set_text_prompt(state=state, prompt=text) + + masks = _to_numpy(output.get("masks", [])) + scores = _to_numpy(output.get("scores", [])) + if masks.ndim == 4: + masks = masks[:, 0] + elif masks.ndim == 3 and masks.shape[0] == 1: + masks = masks[None, 0] + + polygons = [] + for mask in masks: + polygon = _mask_to_polygon(mask) + if polygon: + polygons.append(polygon) + + return { + "polygons": polygons, + "scores": scores.astype(float).tolist() if scores.size else [], + } + + +def main() -> int: + parser = argparse.ArgumentParser(description="SAM 3 external runtime helper") + parser.add_argument("--status", action="store_true") + parser.add_argument("--request", type=Path) + args = parser.parse_args() + + try: + if args.status: + print(json.dumps(runtime_status(), ensure_ascii=False)) + return 0 + if args.request: + print(json.dumps(predict(args.request), ensure_ascii=False)) + return 0 + parser.error("Use --status or --request") + except Exception as exc: # noqa: BLE001 + print(json.dumps({"error": str(exc)}, ensure_ascii=False), file=sys.stderr) + return 1 + + return 2 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/backend/setup_sam3_env.sh b/backend/setup_sam3_env.sh new file mode 100755 index 0000000..8bde621 --- /dev/null +++ b/backend/setup_sam3_env.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +# Create the dedicated SAM 3 runtime used by backend/services/sam3_external_worker.py. +# Keep Hugging Face tokens outside this repository, for example: +# export HF_TOKEN=... +# huggingface-cli login --token "$HF_TOKEN" + +set -euo pipefail + +ENV_NAME="${SAM3_CONDA_ENV:-sam3}" + +source /home/wkmgc/miniconda3/etc/profile.d/conda.sh + +if ! conda env list | awk '{print $1}' | grep -qx "$ENV_NAME"; then + conda create -y -n "$ENV_NAME" python=3.12 +fi + +conda activate "$ENV_NAME" +python -m pip install --upgrade pip +python -m pip install "setuptools<81" +python -m pip install torch==2.10.0 torchvision --index-url https://download.pytorch.org/whl/cu128 +python -m pip install opencv-python pillow numpy huggingface_hub einops pycocotools psutil +python -m pip install git+https://github.com/facebookresearch/sam3.git + +python /home/wkmgc/Desktop/Seg_Server/backend/services/sam3_external_worker.py --status diff --git a/backend/statuses.py b/backend/statuses.py index 1cec48f..8d45153 100644 --- a/backend/statuses.py +++ b/backend/statuses.py @@ -9,3 +9,7 @@ TASK_STATUS_QUEUED = "queued" TASK_STATUS_RUNNING = "running" TASK_STATUS_SUCCESS = "success" TASK_STATUS_FAILED = "failed" +TASK_STATUS_CANCELLED = "cancelled" + +TASK_ACTIVE_STATUSES = {TASK_STATUS_QUEUED, TASK_STATUS_RUNNING} +TASK_TERMINAL_STATUSES = {TASK_STATUS_SUCCESS, TASK_STATUS_FAILED, TASK_STATUS_CANCELLED} diff --git a/backend/tests/test_ai.py b/backend/tests/test_ai.py index 3d22f15..b31f410 100644 --- a/backend/tests/test_ai.py +++ b/backend/tests/test_ai.py @@ -1,4 +1,5 @@ import numpy as np +import cv2 def _create_project_and_frame(client): @@ -46,6 +47,46 @@ def test_predict_accepts_point_object_with_labels(client, monkeypatch): assert calls["args"] == ([[0.5, 0.5], [0.1, 0.1]], [1, 0]) +def test_predict_applies_crop_and_background_filter_options(client, monkeypatch): + _, frame, _ = _create_project_and_frame(client) + calls = {} + monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((100, 200, 3), dtype=np.uint8)) + + def fake_predict_points(model, image, points, labels): + calls["shape"] = image.shape + calls["points"] = points + calls["labels"] = labels + return ( + [ + [[0.0, 0.0], [0.2, 0.0], [0.2, 0.2]], + [[0.45, 0.45], [0.55, 0.45], [0.55, 0.55]], + ], + [0.9, 0.01], + ) + + monkeypatch.setattr("routers.ai.sam_registry.predict_points", fake_predict_points) + + response = client.post("/api/ai/predict", json={ + "image_id": frame["id"], + "prompt_type": "point", + "prompt_data": {"points": [[0.5, 0.5], [0.52, 0.52]], "labels": [1, 0]}, + "options": { + "crop_to_prompt": True, + "crop_margin": 0.1, + "auto_filter_background": True, + "min_score": 0.05, + }, + }) + + assert response.status_code == 200 + assert calls["shape"][0] < 100 + assert calls["shape"][1] < 200 + assert calls["labels"] == [1, 0] + assert response.json()["scores"] == [0.9] + polygon = response.json()["polygons"][0] + assert all(0.0 <= coord <= 1.0 for point in polygon for coord in point) + + def test_predict_box_and_semantic_fallback(client, monkeypatch): _, frame, _ = _create_project_and_frame(client) monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8)) @@ -246,3 +287,62 @@ def test_update_and_delete_annotation_validation(client): f"/api/ai/annotations/{saved['id']}", json={"template_id": 999}, ).status_code == 404 + + +def test_import_gt_mask_creates_annotations_with_seed_points(client): + project, frame, template = _create_project_and_frame(client) + mask = np.zeros((360, 640), dtype=np.uint8) + cv2.rectangle(mask, (100, 80), (260, 220), 255, thickness=-1) + ok, encoded = cv2.imencode(".png", mask) + assert ok + + response = client.post( + "/api/ai/import-gt-mask", + data={ + "project_id": str(project["id"]), + "frame_id": str(frame["id"]), + "template_id": str(template["id"]), + "label": "Imported GT", + "color": "#22c55e", + }, + files={"file": ("mask.png", encoded.tobytes(), "image/png")}, + ) + + assert response.status_code == 201 + body = response.json() + assert len(body) == 1 + assert body[0]["project_id"] == project["id"] + assert body[0]["frame_id"] == frame["id"] + assert body[0]["template_id"] == template["id"] + assert body[0]["mask_data"]["label"] == "Imported GT" + assert body[0]["mask_data"]["source"] == "gt_mask" + assert body[0]["mask_data"]["gt_label_value"] == 255 + assert len(body[0]["mask_data"]["polygons"][0]) >= 3 + assert len(body[0]["points"]) == 1 + assert 0.0 <= body[0]["points"][0][0] <= 1.0 + assert 0.0 <= body[0]["points"][0][1] <= 1.0 + + +def test_import_gt_mask_splits_label_values(client): + project, frame, _ = _create_project_and_frame(client) + mask = np.zeros((360, 640), dtype=np.uint8) + cv2.rectangle(mask, (20, 20), (120, 120), 1, thickness=-1) + cv2.rectangle(mask, (220, 80), (320, 180), 2, thickness=-1) + ok, encoded = cv2.imencode(".png", mask) + assert ok + + response = client.post( + "/api/ai/import-gt-mask", + data={ + "project_id": str(project["id"]), + "frame_id": str(frame["id"]), + "label": "GT Class", + }, + files={"file": ("labels.png", encoded.tobytes(), "image/png")}, + ) + + assert response.status_code == 201 + body = sorted(response.json(), key=lambda item: item["mask_data"]["gt_label_value"]) + assert [item["mask_data"]["gt_label_value"] for item in body] == [1, 2] + assert [item["mask_data"]["label"] for item in body] == ["GT Class 1", "GT Class 2"] + assert all(len(item["points"]) == 1 for item in body) diff --git a/backend/tests/test_dashboard.py b/backend/tests/test_dashboard.py index aabc75c..b83303f 100644 --- a/backend/tests/test_dashboard.py +++ b/backend/tests/test_dashboard.py @@ -59,7 +59,9 @@ def test_dashboard_overview_uses_persisted_records(client, db_session): "name": "Pending Project", "progress": 35, "status": "正在使用 FFmpeg/OpenCV 拆帧", + "raw_status": "running", "frame_count": 0, + "error": None, "updated_at": body["tasks"][0]["updated_at"], }, ] diff --git a/backend/tests/test_export.py b/backend/tests/test_export.py index 898ba4e..940b0af 100644 --- a/backend/tests/test_export.py +++ b/backend/tests/test_export.py @@ -1,6 +1,10 @@ import zipfile +import json from io import BytesIO +import cv2 +import numpy as np + def _seed_export_data(client): project = client.post("/api/projects", json={"name": "Export Project"}).json() @@ -58,7 +62,55 @@ def test_export_masks_zip(client): assert response.status_code == 200 assert response.headers["content-type"].startswith("application/zip") with zipfile.ZipFile(BytesIO(response.content)) as archive: - assert archive.namelist() == [f"mask_{annotation['id']:06d}.png"] + assert archive.namelist() == [ + f"mask_{annotation['id']:06d}.png", + "semantic_frame_000000.png", + "semantic_classes.json", + ] + + +def test_export_masks_uses_z_index_for_semantic_fusion(client): + project = client.post("/api/projects", json={"name": "Fusion Project"}).json() + frame = client.post(f"/api/projects/{project['id']}/frames", json={ + "project_id": project["id"], + "frame_index": 0, + "image_url": "frames/0.jpg", + "width": 20, + "height": 20, + }).json() + low = client.post("/api/ai/annotate", json={ + "project_id": project["id"], + "frame_id": frame["id"], + "mask_data": { + "polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]], + "label": "Low", + "color": "#00ff00", + "class": {"id": "low", "name": "Low", "color": "#00ff00", "zIndex": 10}, + }, + }).json() + high = client.post("/api/ai/annotate", json={ + "project_id": project["id"], + "frame_id": frame["id"], + "mask_data": { + "polygons": [[[0.4, 0.4], [0.9, 0.4], [0.9, 0.9], [0.4, 0.9]]], + "label": "High", + "color": "#ff0000", + "class": {"id": "high", "name": "High", "color": "#ff0000", "zIndex": 20}, + }, + }).json() + + response = client.get(f"/api/export/{project['id']}/masks") + + assert response.status_code == 200 + with zipfile.ZipFile(BytesIO(response.content)) as archive: + assert f"mask_{low['id']:06d}.png" in archive.namelist() + assert f"mask_{high['id']:06d}.png" in archive.namelist() + legend = json.loads(archive.read("semantic_classes.json")) + high_value = next(item["value"] for item in legend["classes"] if item["key"] == "class:high") + semantic_bytes = np.frombuffer(archive.read("semantic_frame_000000.png"), dtype=np.uint8) + semantic = cv2.imdecode(semantic_bytes, cv2.IMREAD_GRAYSCALE) + + assert semantic[10, 10] == high_value def test_export_missing_project_returns_404(client): diff --git a/backend/tests/test_media.py b/backend/tests/test_media.py index e3ecce7..685b182 100644 --- a/backend/tests/test_media.py +++ b/backend/tests/test_media.py @@ -140,3 +140,25 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp assert project_detail["status"] == "ready" frames = client.get(f"/api/projects/{project['id']}/frames").json() assert "frame_000001.jpg" in frames[0]["image_url"] + + +def test_parse_task_runner_skips_already_cancelled_task(db_session): + from models import ProcessingTask + from services.media_task_runner import run_parse_media_task + + task = ProcessingTask( + task_type="parse_video", + status="cancelled", + progress=100, + message="任务已取消", + project_id=1, + payload={"source_type": "video"}, + ) + db_session.add(task) + db_session.commit() + db_session.refresh(task) + + result = run_parse_media_task(db_session, task.id) + + assert result["status"] == "cancelled" + assert result["message"] == "任务已取消" diff --git a/backend/tests/test_progress_events.py b/backend/tests/test_progress_events.py index 171f27a..79473a6 100644 --- a/backend/tests/test_progress_events.py +++ b/backend/tests/test_progress_events.py @@ -26,6 +26,25 @@ def test_task_progress_payload_uses_dashboard_task_id_and_project_name(): assert payload["status"] == "解析完成" +def test_task_progress_payload_marks_cancelled_tasks(): + task = SimpleNamespace( + id=13, + project_id=7, + project=SimpleNamespace(name="demo.mp4"), + status="cancelled", + progress=100, + message="任务已取消", + error="Cancelled by user", + updated_at=None, + ) + + payload = task_progress_payload(task) + + assert payload["type"] == "cancelled" + assert payload["status"] == "任务已取消" + assert payload["error"] == "Cancelled by user" + + def test_publish_progress_event_writes_json_to_redis(monkeypatch): calls = [] diff --git a/backend/tests/test_sam3_engine.py b/backend/tests/test_sam3_engine.py new file mode 100644 index 0000000..e114599 --- /dev/null +++ b/backend/tests/test_sam3_engine.py @@ -0,0 +1,112 @@ +import json +from pathlib import Path + +import numpy as np + +from services.sam3_engine import SAM3Engine + + +class _Completed: + def __init__(self, returncode=0, stdout="", stderr=""): + self.returncode = returncode + self.stdout = stdout + self.stderr = stderr + + +def _external_settings(monkeypatch, python_path: Path): + python_path.write_text("#!/usr/bin/env python\n", encoding="utf-8") + python_path.chmod(0o755) + monkeypatch.setattr("services.sam3_engine.SAM3_PACKAGE_AVAILABLE", False) + monkeypatch.setattr("services.sam3_engine.TORCH_AVAILABLE", False) + monkeypatch.setattr("services.sam3_engine.settings.sam3_external_enabled", True) + monkeypatch.setattr("services.sam3_engine.settings.sam3_external_python", str(python_path)) + monkeypatch.setattr("services.sam3_engine.settings.sam3_timeout_seconds", 10) + monkeypatch.setattr("services.sam3_engine.settings.sam3_status_cache_seconds", 30) + monkeypatch.setattr("services.sam3_engine.settings.sam3_confidence_threshold", 0.4) + + +def test_sam3_status_reports_external_runtime_ready(tmp_path, monkeypatch): + _external_settings(monkeypatch, tmp_path / "python") + + def fake_run(args, **_kwargs): + assert "--status" in args + return _Completed(stdout=json.dumps({ + "available": True, + "package_available": True, + "python_ok": True, + "torch_ok": True, + "cuda_available": True, + "device": "cuda", + "message": "ready", + })) + + monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run) + + status = SAM3Engine().status() + + assert status["available"] is True + assert status["external_available"] is True + assert status["package_available"] is True + assert status["python_ok"] is True + assert status["message"] == "SAM 3 external runtime is ready; model will load in the helper process on inference." + + +def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch): + _external_settings(monkeypatch, tmp_path / "python") + calls = [] + + def fake_run(args, **_kwargs): + calls.append(args) + if "--status" in args: + return _Completed(stdout=json.dumps({ + "available": True, + "package_available": True, + "python_ok": True, + "torch_ok": True, + "cuda_available": True, + "device": "cuda", + "message": "ready", + })) + request_path = Path(args[-1]) + request = json.loads(request_path.read_text(encoding="utf-8")) + assert request["text"] == "vessel" + assert request["confidence_threshold"] == 0.4 + assert Path(request["image_path"]).exists() + return _Completed(stdout=json.dumps({ + "polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]], + "scores": [0.91], + })) + + monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run) + + polygons, scores = SAM3Engine().predict_semantic(np.zeros((8, 8, 3), dtype=np.uint8), " vessel ") + + assert polygons == [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]] + assert scores == [0.91] + assert any("--request" in args for args in calls) + + +def test_sam3_predict_semantic_reports_external_errors(tmp_path, monkeypatch): + _external_settings(monkeypatch, tmp_path / "python") + + def fake_run(args, **_kwargs): + if "--status" in args: + return _Completed(stdout=json.dumps({ + "available": True, + "package_available": True, + "python_ok": True, + "torch_ok": True, + "cuda_available": True, + "device": "cuda", + "message": "ready", + })) + return _Completed(returncode=1, stderr=json.dumps({"error": "HF access denied"})) + + monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run) + + try: + SAM3Engine().predict_semantic(np.zeros((8, 8, 3), dtype=np.uint8), "vessel") + except RuntimeError as exc: + assert "HF access denied" in str(exc) + else: + raise AssertionError("Expected SAM 3 external inference failure.") diff --git a/backend/tests/test_tasks.py b/backend/tests/test_tasks.py new file mode 100644 index 0000000..482cd1f --- /dev/null +++ b/backend/tests/test_tasks.py @@ -0,0 +1,104 @@ +from models import ProcessingTask + + +def test_cancel_task_revokes_celery_and_updates_project(client, db_session, monkeypatch): + project = client.post("/api/projects", json={ + "name": "Cancelable", + "video_path": "uploads/1/clip.mp4", + "status": "parsing", + }).json() + task = ProcessingTask( + task_type="parse_video", + status="running", + progress=35, + message="正在使用 FFmpeg/OpenCV 拆帧", + project_id=project["id"], + celery_task_id="celery-1", + payload={"source_type": "video"}, + ) + db_session.add(task) + db_session.commit() + db_session.refresh(task) + + revoked = [] + published = [] + monkeypatch.setattr( + "routers.tasks.celery_app.control.revoke", + lambda celery_id, terminate, signal: revoked.append((celery_id, terminate, signal)), + ) + monkeypatch.setattr("routers.tasks.publish_task_progress_event", lambda event_task: published.append(event_task.status)) + + response = client.post(f"/api/tasks/{task.id}/cancel") + + assert response.status_code == 200 + body = response.json() + assert body["status"] == "cancelled" + assert body["progress"] == 100 + assert body["message"] == "任务已取消" + assert body["error"] == "Cancelled by user" + assert revoked == [("celery-1", True, "SIGTERM")] + assert published == ["cancelled"] + assert client.get(f"/api/projects/{project['id']}").json()["status"] == "pending" + + +def test_retry_task_creates_fresh_parse_task(client, db_session, monkeypatch): + project = client.post("/api/projects", json={ + "name": "Retryable", + "video_path": "uploads/2/clip.mp4", + "source_type": "video", + "status": "error", + }).json() + task = ProcessingTask( + task_type="parse_video", + status="failed", + progress=100, + message="解析失败", + error="ffmpeg failed", + project_id=project["id"], + payload={"source_type": "video"}, + ) + db_session.add(task) + db_session.commit() + db_session.refresh(task) + + class FakeAsyncResult: + id = "celery-retry" + + queued = [] + published = [] + monkeypatch.setattr("routers.tasks.parse_project_media.delay", lambda task_id: queued.append(task_id) or FakeAsyncResult()) + monkeypatch.setattr("routers.tasks.publish_task_progress_event", lambda event_task: published.append((event_task.id, event_task.status))) + + response = client.post(f"/api/tasks/{task.id}/retry") + + assert response.status_code == 202 + body = response.json() + assert body["id"] != task.id + assert body["status"] == "queued" + assert body["progress"] == 0 + assert body["celery_task_id"] == "celery-retry" + assert body["payload"]["retry_of"] == task.id + assert queued == [body["id"]] + assert published[0] == (body["id"], "queued") + assert published[-1] == (body["id"], "queued") + assert client.get(f"/api/projects/{project['id']}").json()["status"] == "parsing" + + +def test_task_actions_reject_invalid_states(client, db_session): + project = client.post("/api/projects", json={ + "name": "Done", + "video_path": "uploads/3/clip.mp4", + }).json() + task = ProcessingTask( + task_type="parse_video", + status="success", + progress=100, + project_id=project["id"], + payload={"source_type": "video"}, + ) + db_session.add(task) + db_session.commit() + db_session.refresh(task) + + assert client.post(f"/api/tasks/{task.id}/cancel").status_code == 409 + assert client.post(f"/api/tasks/{task.id}/retry").status_code == 409 diff --git a/doc/01-purpose-and-word-summary.md b/doc/01-purpose-and-word-summary.md index 93fe322..d5fc3bb 100644 --- a/doc/01-purpose-and-word-summary.md +++ b/doc/01-purpose-and-word-summary.md @@ -38,21 +38,21 @@ Word 方案描述的理想系统包含: | 视频拆帧 | 已落地 | `backend/services/frame_parser.py`、`backend/routers/media.py` | | DICOM 批量导入 | 部分落地 | 上传和解析存在,项目级体验还需完善 | | WebSocket 进度 | 已落地 | 拆帧进度写入任务表后发布到 Redis `seg:progress`,FastAPI 广播到 `/ws/progress` | -| SAM 推理 | 部分落地 | 后端已有 SAM 2 / SAM 3 选择和真实模型状态接口;SAM 3 依赖官方运行环境,当前环境不满足时会标为不可用 | -| 模板库 | 部分落地 | 分类、颜色、z-index 能存储和编辑;重叠裁决算法未落地 | -| 标注持久化 | 部分落地 | 后端有 `Annotation` 表,前端已接入新增、回显、分类更新和当前帧删除;逐点几何编辑未落地 | -| COCO / Mask 导出 | 部分落地 | `backend/routers/export.py`;COCO JSON 前端按钮已接入,PNG mask ZIP 尚未提供前端按钮 | +| SAM 推理 | 部分落地 | 后端已有 SAM 2 / SAM 3 选择和真实模型状态接口;SAM 3 通过独立 Python 3.12 环境桥接,状态会检查 Python/CUDA/包/HF gated 权重访问 | +| 模板库 | 部分落地 | 分类、颜色、z-index 能存储和编辑;PNG mask 导出时会按 zIndex 做语义融合裁决,前端预览裁决尚未落地 | +| 标注持久化 | 部分落地 | 后端有 `Annotation` 表,前端已接入新增、回显、分类更新、当前帧删除、手工绘制、GT mask 导入、seed point 编辑、polygon 顶点拖动/删除、边中点插点和多 polygon 子区域编辑;复杂洞结构编辑未落地 | +| COCO / Mask 导出 | 已落地基础能力 | `backend/routers/export.py`;COCO JSON 和 PNG mask ZIP 前端按钮均已接入,ZIP 包含单标注 mask、语义融合 mask 和类别映射 | ## 当前代码尚未落地的目标 -- SAM 3:当前已提供 `sam3_engine.py` 适配入口和状态检测;要实际运行仍需安装官方 `facebookresearch/sam3` 依赖并满足 Python 3.12+、PyTorch 2.7+、CUDA 12.6+。 +- SAM 3:当前已提供 `sam3_engine.py` 外部环境桥接、`sam3_external_worker.py` 和 `setup_sam3_env.sh`;本机 `sam3` 环境已满足 Python 3.12、PyTorch 2.10/cu128、CUDA/GPU 和官方包导入。真实推理仍取决于 Hugging Face `facebook/sam3` gated 权重访问授权;官方没有 SAM 3 tiny/small 权重,当前可选最小真实 SAM 权重仍是 SAM 2 tiny。 - Celery 异步任务队列:已注册 Celery app 和拆帧 worker task,`/api/media/parse` 会创建任务表记录并入队。 -- GT mask 导入:当前前端没有 GT Label 导入入口,后端也没有对应路由。 -- Mask 到点区域的拓扑降维:当前没有距离变换、骨架提取、HDBSCAN 等实现。 -- 类别优先级融合:模板有 z-index,但没有后端融合算法。 -- 撤销/重做:工具栏有按钮,但没有历史栈。 +- GT mask 导入:当前已支持二值/多类别 mask 导入,后端会按非零像素值拆分区域,生成 polygon 标注和距离变换 seed point;骨架提取、HDBSCAN 和模板自动映射尚未实现。 +- Mask 到点区域的拓扑降维:当前完成 distance transform seed point 和前端 seed point 拖拽编辑;骨架提取、HDBSCAN 等增强尚未实现。 +- 类别优先级融合:PNG mask 导出时已按 zIndex 生成语义融合 mask;前端裁决预览尚未实现。 +- 撤销/重做:当前已有全局 mask 历史栈。 - 结构化归档保存:工作区按钮已调用 `POST /api/ai/annotate` 保存当前未归档 mask,并通过 `PATCH /api/ai/annotations/{id}` 更新 dirty mask。 ## 结论 -当前项目已经从 UI 原型推进到“可上传、可异步拆帧、可实时查看任务进度、可浏览项目帧、可维护模板、可点/框 AI 推理、可保存标注、可导出 COCO、可查看 Dashboard 后端概览”的全栈雏形,但离 Word 中描述的完整智能标注系统还有明显差距。下一阶段最重要的是继续补齐手工绘制、撤销重做和真实语义文本分割。 +当前项目已经从 UI 原型推进到“可上传、可异步拆帧、可取消/重试任务、可查看失败详情、可实时查看任务进度、可浏览项目帧、可维护模板、可手工绘制、可逐点编辑 polygon、可边中点插点、可多 polygon 子区域编辑、可区域合并/去除、可点/框 AI 推理、可对点/框 prompt 做裁剪推理和背景过滤、可导入多类别 GT mask、可编辑 seed point、可保存标注、可导出 COCO/语义 mask ZIP、可查看 Dashboard 后端概览”的全栈雏形。下一阶段最重要的是继续补齐 Hugging Face SAM 3 权重授权后的真实语义文本分割 smoke test、复杂洞结构编辑和 GT mask 骨架/聚类增强。 diff --git a/doc/02-current-implementation-map.md b/doc/02-current-implementation-map.md index 981029c..ecc453c 100644 --- a/doc/02-current-implementation-map.md +++ b/doc/02-current-implementation-map.md @@ -71,6 +71,7 @@ 5. Celery worker 下载 MinIO 文件,调用 `frame_parser.py` 拆帧。 6. worker 把拆出的帧重新上传 MinIO,写入 `frames` 表,并更新任务状态。 7. 工作区通过 `GET /api/tasks/{id}` 等待任务完成,再通过 `GET /api/projects/{id}/frames` 获取预签名图片 URL。 +8. Dashboard 可通过 `POST /api/tasks/{id}/cancel` 取消 queued/running 任务,通过 `POST /api/tasks/{id}/retry` 重试 failed/cancelled 任务,并用 `GET /api/tasks/{id}` 查看失败详情。 ### 工作区浏览 @@ -98,7 +99,7 @@ ## 当前主要风险点 - 前端 API/WS 地址虽然已支持环境变量和 hostname 推导,但部署时仍需要确认浏览器可访问 `:8000` 后端。 -- AI 语义文本提示在选择 SAM 3 且运行环境满足官方依赖时走 SAM 3;当前环境若不满足会在模型状态中标明不可用。 -- 工作区顶部“导出 JSON 标注集”和“结构化归档保存”已接入导出、标注新增和 dirty 标注更新;清空当前帧遮罩会删除对应后端标注。撤销重做和手工绘制仍未持久化。 -- Dashboard 初始统计、队列和活动日志来自后端聚合接口;解析队列来自 `processing_tasks`,worker 进度通过 Redis `seg:progress` 转发到 WebSocket。 +- AI 语义文本提示在选择 SAM 3 且运行环境满足官方依赖、并具备 Hugging Face gated 权重访问时走 SAM 3;当前状态接口会分别暴露外部 Python 环境、CUDA、包导入和 checkpoint access 是否满足。 +- 工作区顶部“导出 JSON 标注集”“导出 PNG Mask ZIP”“导入 GT Mask”和“结构化归档保存”已接入导出、GT 多类别导入、seed point 回显/编辑、标注新增和 dirty 标注更新;清空当前帧遮罩会删除对应后端标注。手工绘制、polygon 顶点拖动/删除、区域合并/去除和撤销重做已经落到前端 mask 数据结构。 +- Dashboard 初始统计、队列和活动日志来自后端聚合接口;解析队列来自 `processing_tasks`,worker 进度通过 Redis `seg:progress` 转发到 WebSocket。任务取消、重试和失败详情已接入前后端。 - 后端路由大多未做真实鉴权。 diff --git a/doc/03-frontend-element-audit.md b/doc/03-frontend-element-audit.md index cb9a5e5..662947d 100644 --- a/doc/03-frontend-element-audit.md +++ b/doc/03-frontend-element-audit.md @@ -30,8 +30,11 @@ | 元素 | 状态 | 说明 | |------|------|------| | WebSocket 连接状态 | 真实可用 | 前端通过 `src/lib/config.ts` 推导或读取 `VITE_WS_PROGRESS_URL`,后端有 `/ws/progress` | -| 解析队列任务 | 真实可用 | 初始数据来自 `GET /api/dashboard/overview`,按 `processing_tasks` queued/running 任务生成 | -| WebSocket 更新任务 | 真实可用 | Celery worker 更新 `processing_tasks` 后发布 Redis `seg:progress`,FastAPI 广播 progress/complete/error | +| 解析队列任务 | 真实可用 | 初始数据来自 `GET /api/dashboard/overview`,按 `processing_tasks` queued/running/failed/cancelled 任务生成 | +| 任务取消 | 真实可用 | queued/running 任务显示取消按钮,调用 `POST /api/tasks/{task_id}/cancel` | +| 任务重试 | 真实可用 | failed/cancelled 任务显示重试按钮,调用 `POST /api/tasks/{task_id}/retry` 创建新任务 | +| 失败详情 | 真实可用 | 任务详情按钮调用 `GET /api/tasks/{task_id}`,弹窗展示 error、payload、result、Celery ID 和时间 | +| WebSocket 更新任务 | 真实可用 | Celery worker 更新 `processing_tasks` 后发布 Redis `seg:progress`,FastAPI 广播 progress/complete/error/cancelled | | 项目、任务、标注、系统负载统计 | 真实可用 | 初始数据来自 `GET /api/dashboard/overview`,系统负载按主机 load average 估算 | | 近期实时流转记录 | 真实可用 | 初始数据来自任务、项目、标注和模板记录;WebSocket status/complete/error 会继续追加 | @@ -60,6 +63,8 @@ | SAM 模型状态徽标 | 真实可用 | 调用 `GET /api/ai/models/status`,显示当前选择的 SAM 2/SAM 3 是否可用 | | 已保存标注回显 | 真实可用 | 加载工作区帧后调用 `GET /api/ai/annotations` 并渲染已保存 mask | | “导出 JSON 标注集”按钮 | 真实可用 | 导出前会保存未归档 mask,然后调用 `exportCoco()` 下载 JSON | +| “导出 PNG Mask ZIP”按钮 | 真实可用 | 导出前会保存未归档 mask,然后调用 `GET /api/export/{project_id}/masks` 下载 ZIP;后端同时包含单标注 mask、每帧语义融合 mask 和 `semantic_classes.json` | +| “导入 GT Mask”按钮 | 真实可用 | 选择图片后调用 `POST /api/ai/import-gt-mask`,后端按非零像素值和连通域生成 polygon 标注与距离变换 seed point,再回显到工作区 | | “结构化归档保存”按钮 | 真实可用 | 未保存 mask 写入 `POST /api/ai/annotate`;dirty mask 写入 `PATCH /api/ai/annotations/{id}` | ## CanvasArea 画布 @@ -70,10 +75,13 @@ | 滚轮缩放 | 真实可用 | 改变 Konva Stage scale | | 拖拽平移 | 真实可用 | activeTool 为 `move` 时 Stage draggable | | 光标坐标显示 | 真实可用 | 根据 pointer position 计算 | -| 正向/反向选点 | 部分可用 | UI 能加点,并按当前帧 `frame.id` 调用 `/api/ai/predict`;需点击归档保存才持久化 | -| 框选 | 部分可用 | UI 能画框,并把框坐标归一化后调用后端推理;需点击归档保存才持久化 | +| 正向/反向选点 | 真实可用 | UI 能加点,并按当前帧 `frame.id` 调用 `/api/ai/predict`;结果需点击归档保存才持久化 | +| 框选 | 真实可用 | UI 能画框,并把框坐标归一化后调用后端推理;结果需点击归档保存才持久化 | | AI 推理中提示 | 真实可用 | 请求期间会显示 | -| Mask 渲染 | 部分可用 | 前端会把推理/已保存标注转成 Konva `pathData` 渲染 | +| 手工多边形/矩形/圆/点/线 | 真实可用 | 多边形点击取点后 Enter 完成;矩形/圆/线拖拽生成 polygon;点工具生成小区域;均写入 `Mask.segmentation`,可归档保存 | +| Mask 渲染 | 真实可用 | 前端会把推理、手工绘制、GT 导入和已保存标注转成 Konva `pathData` 渲染 | +| Polygon 逐点编辑 | 真实可用 | 点击 mask 后显示 polygon 顶点;拖动顶点会重算 `pathData/segmentation/bbox/area`,已保存 mask 标为 dirty;选中顶点后 Delete/Backspace 可删点但保留至少三点 | +| GT seed point 回显/编辑 | 真实可用 | 已保存标注的 `points` 会显示为黄色 seed 点;拖动后标记为 dirty,归档保存会更新后端 | | 应用分类 | 真实可用 | 将当前选择的模板分类应用到本帧 mask;已保存 mask 会标为 dirty,归档保存时更新后端 | | 清空遮罩 | 真实可用 | 工作区中会删除当前帧已保存标注并清空当前帧本地 mask | | 保存状态计数 | 真实可用 | 底部显示已保存、未保存、待更新数量 | @@ -84,11 +92,11 @@ | 元素 | 状态 | 说明 | |------|------|------| | 拖拽/选择 | 真实可用 | 控制 Canvas 是否可拖拽 | -| 多边形/矩形/圆/点/线 | Mock / UI-only | 只切换 activeTool,没有对应绘制逻辑 | -| 区域合并/去除 | Mock / UI-only | 只切换 activeTool,没有后端或前端算法 | +| 多边形/矩形/圆/点/线 | 真实可用 | 切换 activeTool 后由 `CanvasArea` 生成可保存的 polygon mask | +| 区域合并/去除 | 真实可用 | 选择工具后点击多个 mask,使用 `polygon-clipping` 做 union / difference;合并会保留主 mask 并移除被合并 mask,去除会从主 mask 扣除后续选中 mask | | 正向选点/反向选点/框选 | 部分可用 | 会影响 Canvas 交互,并能触发已对齐的 AI 推理接口 | | 魔法棒 SAM 触发 | 部分可用 | 切到 AI 页面;不是直接执行推理 | -| 撤销/重做 | Mock / UI-only | 按钮无事件 | +| 撤销/重做 | 真实可用 | 绑定 Zustand `maskHistory/maskFuture`,支持工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y | ## FrameTimeline 时间轴 @@ -117,10 +125,11 @@ |------|------|------| | 模型选择 SAM2/SAM3 | 真实可用 | 选择写入 Zustand,`predictMask()` 会把 `model` 传给后端 SAM registry | | 正向/反向点 | 部分可用 | 可在当前项目帧上加点,并可调用 AI 推理接口 | -| 语义文本输入 | 部分可用 | 纯文本会以 `semantic` prompt 调用后端;选择 SAM 3 且运行环境满足官方依赖时走 SAM 3 文本语义推理,否则状态接口会标明不可用 | -| 参数开关 | Mock / UI-only | `cropMode`、`autoDeleteBg` 只改本地状态 | +| 语义文本输入 | 部分可用 | 纯文本会以 `semantic` prompt 调用后端;选择 SAM 3 且独立 Python 3.12 环境、CUDA、官方包和 Hugging Face gated 权重访问均满足时走 SAM 3 文本语义推理,否则状态接口会标明不可用 | +| 参数开关 | 真实可用 | `cropMode` 会随 `/api/ai/predict` 发送 `crop_to_prompt`,后端对点/框 prompt 裁剪推理区域并回映射 polygon;`autoDeleteBg` 会发送 `auto_filter_background` 和 `min_score`,后端过滤低分结果和覆盖负向点的结果 | | 执行高精度语义分割 | 部分可用 | 使用当前项目帧调用 `/api/ai/predict`;没有当前帧时按钮禁用 | | 上传替换底图 | Mock / UI-only | 按钮无事件 | +| 撤销/重做 | 真实可用 | 绑定全局 mask 历史栈 | | 清空全体锚点 | 部分可用 | 清空前端 points 和 masks | | 退档推送至工作区重组 | 部分可用 | 只切回工作区,共用 masks store,但没有保存/确认流程 | | 背景图 | 部分可用 | 优先显示当前项目帧;没有项目帧时仍回退到 Unsplash 演示图 | @@ -141,6 +150,6 @@ ## 总体结论 -当前前端真实可用的主链路是:登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、拆帧、浏览帧、播放帧、工作区点/框 AI 推理、标注保存/回显、COCO 导出、模板 CRUD。 +当前前端真实可用的主链路是:登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、拆帧、浏览帧、播放帧、工作区手工绘制、点/框 AI 推理、GT mask 导入、标注保存/回显、COCO 导出、PNG mask ZIP 导出、模板 CRUD。 -当前最主要的 Mock 或未打通链路是:撤销重做、手工几何绘制、GT 导入、mask 降维点区域、真正的文本语义分割和语义优先级融合。 +当前最主要的 Mock 或未打通链路是:polygon 插点/边编辑增强、真正的文本语义分割、骨架/HDBSCAN 级别的 mask 降维增强、任务历史筛选、项目更多菜单和若干检查面板指标。 diff --git a/doc/04-api-contracts.md b/doc/04-api-contracts.md index c4108dc..a263413 100644 --- a/doc/04-api-contracts.md +++ b/doc/04-api-contracts.md @@ -34,6 +34,8 @@ Authorization: Bearer | `uploadDicomBatch(files, projectId)` | `POST /api/media/upload/dicom` | 对齐 | multipart form-data | | `parseMedia(projectId)` | `POST /api/media/parse?project_id=...` | 对齐 | 创建异步拆帧任务并返回 task | | `getTask(taskId)` | `GET /api/tasks/{task_id}` | 对齐 | 查询异步任务状态 | +| `cancelTask(taskId)` | `POST /api/tasks/{task_id}/cancel` | 对齐 | 取消 queued/running 任务,后端写 cancelled 并尝试 revoke Celery | +| `retryTask(taskId)` | `POST /api/tasks/{task_id}/retry` | 对齐 | 对 failed/cancelled 任务创建新的 queued 重试任务 | | `getProjectFrames(projectId)` | `GET /api/projects/{id}/frames` | 对齐 | 后端返回预签名 image_url | | `predictMask(payload)` | `POST /api/ai/predict` | 对齐 | 前端发送 `image_id/prompt_type/prompt_data/model`,并把后端 `polygons` 转为 `masks[].pathData` | | `getAiModelStatus(selectedModel?)` | `GET /api/ai/models/status` | 对齐 | 返回 GPU、SAM 2、SAM 3 的真实运行状态 | @@ -41,8 +43,10 @@ Authorization: Bearer | `saveAnnotation(payload)` | `POST /api/ai/annotate` | 对齐 | 工作区归档保存当前项目未保存 mask | | `updateAnnotation(annotationId, payload)` | `PATCH /api/ai/annotations/{annotation_id}` | 对齐 | 工作区归档保存 dirty mask | | `deleteAnnotation(annotationId)` | `DELETE /api/ai/annotations/{annotation_id}` | 对齐 | 工作区清空当前帧已保存标注 | +| `importGtMask(file, projectId, frameId, templateId?)` | `POST /api/ai/import-gt-mask` | 对齐 | multipart 上传 GT mask,后端按非零像素值/连通域生成 polygon 标注和 seed point | | `getDashboardOverview()` | `GET /api/dashboard/overview` | 对齐 | Dashboard 初始统计、队列和活动日志 | | `exportCoco(projectId)` | `GET /api/export/{projectId}/coco` | 对齐 | 后端实际是 `GET /api/export/{project_id}/coco` | +| `exportMasks(projectId)` | `GET /api/export/{projectId}/masks` | 对齐 | 下载单标注 mask、语义融合 mask 和类别映射 ZIP | ## 后端 FastAPI 接口 @@ -69,10 +73,13 @@ Authorization: Bearer | POST | `/api/media/parse` | 创建 Celery 拆帧任务 | | GET | `/api/tasks` | 查询后台任务列表 | | GET | `/api/tasks/{task_id}` | 查询单个后台任务 | +| POST | `/api/tasks/{task_id}/cancel` | 取消后台任务 | +| POST | `/api/tasks/{task_id}/retry` | 重试失败或取消的后台任务 | | POST | `/api/ai/predict` | SAM 2 / SAM 3 可选推理 | | GET | `/api/ai/models/status` | GPU 和 SAM 模型状态 | | POST | `/api/ai/auto` | 自动分割 | | POST | `/api/ai/annotate` | 保存 AI 标注 | +| POST | `/api/ai/import-gt-mask` | 导入 GT mask 并生成标注/seed point | | GET | `/api/ai/annotations` | 查询项目标注,可选按帧过滤 | | PATCH | `/api/ai/annotations/{annotation_id}` | 更新已保存标注 | | DELETE | `/api/ai/annotations/{annotation_id}` | 删除已保存标注 | @@ -143,7 +150,13 @@ Authorization: Bearer - `point` - `box` -- `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理;选择 `sam2` 时仍回退到 auto segmentation +- `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理;选择 `sam2` 时仍回退到 auto segmentation。SAM 3 真实可用性由 `/api/ai/models/status` 中的外部环境和 checkpoint access 状态决定。 + +可选 `options` 字段: + +- `crop_to_prompt`:对 point/box prompt 按锚点或框附近区域裁剪后推理,再把 polygon 回映射到原图坐标。 +- `auto_filter_background`:过滤低分结果,并移除包含负向点的 polygon。 +- `min_score`:配合 `auto_filter_background` 使用的最低置信度阈值。 后端响应: @@ -181,13 +194,19 @@ Authorization: Bearer - `getProjectAnnotations()` 已接入 `GET /api/ai/annotations`。 - `updateAnnotation()` 已接入 `PATCH /api/ai/annotations/{annotationId}`。 - `deleteAnnotation()` 已接入 `DELETE /api/ai/annotations/{annotationId}`。 +- `importGtMask()` 已接入 `POST /api/ai/import-gt-mask`,导入后端生成的 polygon 标注、原始 `gt_label_value` 和 seed point。 +- `exportMasks()` 已接入 `GET /api/export/{projectId}/masks`。 - `parseMedia()` 已改为创建 Celery 后台任务,并返回 `ProcessingTask`。 - `getTask()` 已接入 `GET /api/tasks/{taskId}`。 +- `cancelTask()` 已接入 `POST /api/tasks/{taskId}/cancel`。 +- `retryTask()` 已接入 `POST /api/tasks/{taskId}/retry`。 - `getDashboardOverview()` 已从 `processing_tasks` 聚合解析队列。 -- 工作区导出按钮已调用 `exportCoco()`,并会先保存未归档 mask。 +- Dashboard 任务列表已展示 queued/running/failed/cancelled 任务,并可通过 `getTask()` 查看失败详情。 +- 工作区导出按钮已调用 `exportCoco()` / `exportMasks()`,并会先保存未归档 mask。 +- PNG mask ZIP 已包含每帧 `semantic_frame_*.png` 和 `semantic_classes.json`,重叠区域按 zIndex 裁决。 ## 仍需处理的接口问题 - WebSocket 地址已从 `VITE_WS_PROGRESS_URL` 读取,未配置时从 `API_BASE_URL` 推导;部署时仍要确认浏览器能访问该地址。 - Celery worker 进度会写 PostgreSQL 任务表,同时发布到 Redis `seg:progress`;FastAPI 订阅后广播到 `/ws/progress`。 -- 已保存标注目前支持分类级更新和整帧清空删除;逐点几何编辑器尚未实现。 +- 已保存标注目前支持分类级更新、polygon 顶点拖动、顶点删除、边中点插入、多 polygon 子区域选择编辑后的 PATCH 更新和整帧清空删除;复杂洞结构的专业编辑仍未实现。 diff --git a/doc/05-implementation-plan.md b/doc/05-implementation-plan.md index 08b0450..fd2c6b7 100644 --- a/doc/05-implementation-plan.md +++ b/doc/05-implementation-plan.md @@ -16,8 +16,8 @@ 剩余边界: -1. SAM 3 真实推理需要独立满足官方 Python 3.12+、PyTorch 2.7+、CUDA 12.6+ 环境。 -2. 标注删除/更新接口已打通基础能力;逐点几何编辑器尚未实现。 +1. SAM 3 已完成独立 Python 3.12 环境安装脚本、外部 worker 桥接和状态检查;真实推理还需要 Hugging Face `facebook/sam3` gated 权重授权通过后执行 smoke test。 +2. 标注删除/更新接口已打通基础能力;逐点几何编辑器已支持顶点拖动/删除、边中点插入和多 polygon 子区域选择编辑,复杂洞结构仍待增强。 ## 阶段 2:打通标注保存(已完成基础闭环) @@ -34,16 +34,22 @@ 剩余建议: 1. 加入保存冲突处理和批量保存错误提示。 -2. 增加逐点几何编辑器,让已保存 mask 的 polygon 本身可以被修改后 PATCH。 +2. 逐点几何编辑器已支持拖动/删除顶点、边中点插入新点和多 polygon 子区域编辑;后续增强为复杂洞结构编辑。 +3. 区域合并/去除已支持基础 union/difference;后续增强为更明确的多选列表、操作预览和冲突确认。 -## 阶段 3:接入导出按钮(已完成 COCO JSON) +## 阶段 3:接入导出按钮(已完成 COCO JSON 和 PNG Mask ZIP) -当前工作区“导出 JSON 标注集”会先保存未归档 mask,再调用 COCO 导出接口。 +当前工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”都会先保存未归档 mask,再调用后端导出接口。 -建议: +已完成: -1. 增加“导出 PNG Mask ZIP”按钮,调用 `/api/export/{projectId}/masks`。 -2. 无标注时给出更明确的空导出提示。 +1. COCO JSON 调用 `/api/export/{projectId}/coco`。 +2. PNG Mask ZIP 调用 `/api/export/{projectId}/masks`。 +3. ZIP 内保留单标注二值 `mask_*.png`,同时输出 `semantic_frame_*.png` 和 `semantic_classes.json`。 + +剩余建议: + +1. 无标注时给出更明确的空导出提示。 ## 阶段 4:替换 Dashboard mock @@ -52,13 +58,18 @@ 已完成: - 聚合项目、帧、标注、模板数量和主机 load average。 -- 按 `processing_tasks` queued/running 任务生成解析队列。 +- 按 `processing_tasks` queued/running/failed/cancelled 任务生成解析队列。 - 按最近任务、项目、标注、模板记录生成活动流。 +已完成补充: + +1. Dashboard 对 queued/running 任务提供取消按钮。 +2. Dashboard 对 failed/cancelled 任务提供重试按钮。 +3. Dashboard 详情弹窗展示任务 error、payload、result、Celery ID 和时间。 + 剩余建议: -1. 为任务增加取消、重试和失败详情 UI。 -2. 为 Dashboard 增加任务历史筛选和失败详情入口。 +1. 为 Dashboard 增加任务历史筛选。 ## 阶段 5:异步拆帧和进度 @@ -72,44 +83,55 @@ Word 方案中提到 Celery + Redis。当前已经有 Celery app、worker task 4. worker 写 PostgreSQL 任务进度。 5. worker 发布 Redis `seg:progress`,FastAPI 广播到 `/ws/progress`。 +已完成补充: + +1. `POST /api/tasks/{task_id}/cancel` 取消 queued/running 任务,并尝试 revoke Celery。 +2. `POST /api/tasks/{task_id}/retry` 为 failed/cancelled 任务创建新的 queued 任务。 +3. worker 在关键阶段检查 cancelled 状态,避免取消后继续写帧。 +4. Redis/WebSocket 进度事件增加 `cancelled` 类型。 + +Dashboard 的解析队列现在已经从“项目状态派生”升级为任务表驱动,实时推送也已通过 Redis/WebSocket 打通;剩余重点是任务历史筛选和更细的 worker 中断粒度。 + +## 阶段 6:GT 导入与点区域(已完成基础增强版) + +Word 方案中的完整版本包含距离变换、骨架提取和聚类。当前已经完成基础增强版:导入二值/标签 mask 图片后,后端按非零像素值拆分类别,再按连通域生成 polygon 标注,并用距离变换提取一个正向 seed point。 + +已完成: + +1. 工作区提供“导入 GT Mask”入口。 +2. 前端调用 `POST /api/ai/import-gt-mask` multipart 接口。 +3. 后端按非零像素值拆分多类别 mask。 +4. 后端使用 OpenCV contour 提取每个类别下的连通域。 +5. 后端使用 distance transform 生成 `points` seed。 +6. 导入结果写入 `annotations` 表并回显为工作区 mask。 +7. 前端把 seed point 转为像素坐标显示在 Canvas 上,拖动后会标记标注为 dirty 并可归档保存。 + 剩余建议: -1. 为任务增加取消、重试和失败详情接口。 -2. 前端 Dashboard 保留轮询兜底,并补充失败详情 UI。 +1. 增加骨架提取和聚类增强。 +2. 为多类别像素值提供模板分类自动映射规则。 -Dashboard 的解析队列现在已经从“项目状态派生”升级为任务表驱动,实时推送也已通过 Redis/WebSocket 打通;剩余重点是任务控制。 +## 阶段 7:模板优先级融合(已完成导出侧裁决) -## 阶段 6:GT 导入与点区域 +当前导出 PNG Mask ZIP 时已经按 class/template z-index 做重叠裁决,从低到高覆盖,生成每帧 `semantic_frame_*.png`。 -这是 Word 方案中最复杂的部分,当前完全未实现。 - -建议拆成小步: - -1. 先支持上传二值/多类别 mask。 -2. 后端按类别提取 connected components。 -3. 用 OpenCV distance transform 找正向点。 -4. 暂时不做骨架/HDBSCAN,先生成最小可用点集。 -5. 前端以可拖拽点显示并保存。 -6. 后续再做骨架和聚类增强。 - -## 阶段 7:模板优先级融合 - -当前模板有 z-index,但没有真正用于语义冲突裁决。 - -建议: +已完成: 1. 标注保存时记录 template class id / name / zIndex。 2. 导出 mask 时按 zIndex 从低到高覆盖。 -3. 同类 mask 做 union。 +3. 同类语义值在融合图中共享同一个 class value。 4. 跨类重叠由高 zIndex 覆盖低 zIndex。 -这一步完成后,系统才真正符合“语义分割一个像素一个类别”的目标。 +剩余建议: + +1. 在前端预览重叠裁决结果。 +2. 对多帧多类导出增加颜色 palette PNG 或可视化 legend。 ## 阶段 8:清理 UI 文案与 Mock 建议统一这些文案和真实能力: - SAM/GPU 状态已改为 `GET /api/ai/models/status` 驱动。 -- 撤销/重做按钮接历史栈,否则隐藏。 +- 撤销/重做按钮已接全局 mask 历史栈。 - “重新提取内侧中轴树骨架”接真实接口,否则标为未实现。 - AI 独立页不要固定 Unsplash 图,应从当前项目帧或上传文件进入。 diff --git a/doc/07-current-requirements-freeze.md b/doc/07-current-requirements-freeze.md index 609f0c4..1236846 100644 --- a/doc/07-current-requirements-freeze.md +++ b/doc/07-current-requirements-freeze.md @@ -31,6 +31,9 @@ - 拆帧完成后写入 `frames` 记录,并把项目状态设为 `ready`。 - 拆帧接口会创建 `processing_tasks` 记录并投递 Celery worker。 - 前端可通过 `GET /api/tasks/{task_id}` 查询任务状态。 +- 后端支持 `POST /api/tasks/{task_id}/cancel` 取消 queued/running 任务,写入 `cancelled` 状态并尝试 revoke Celery。 +- 后端支持 `POST /api/tasks/{task_id}/retry` 对 failed/cancelled 任务创建新的 queued 任务。 +- worker 会在关键阶段检查任务是否已取消,取消后停止继续写帧。 ## R4 工作区与帧浏览 @@ -46,7 +49,13 @@ - 工具栏可以切换当前 active tool。 - 正向点、反向点、框选工具会影响 Canvas 交互。 - 魔法棒按钮切换到 AI 页面。 -- 多边形、矩形、圆、点、线、合并、去除、撤销、重做当前只提供 UI 状态或占位按钮,不完成真实绘制/算法。 +- 多边形、矩形、圆、点、线工具会在 Canvas 上生成可保存的 polygon mask。 +- 多边形通过点击取点并按 Enter 完成;矩形、圆、线通过拖拽生成;点工具生成小点区域。 +- Canvas 支持点击 mask 进入 polygon 顶点编辑态;拖动顶点会更新 mask 几何并把已保存 mask 标记为 dirty。 +- 顶点编辑态下选中顶点后可用 Delete/Backspace 删除顶点,但不会让 polygon 少于三点。 +- 撤销、重做绑定全局 `maskHistory/maskFuture`,支持工具栏按钮、AI 页按钮和 Canvas 快捷键。 +- 区域合并工具支持多选当前帧 mask,并使用 polygon union 生成合并后的主 mask。 +- 区域去除工具支持多选当前帧 mask,并从第一个选中的主 mask 中扣除后续选中 mask。 ## R6 AI 推理 @@ -56,7 +65,8 @@ - 前端发送后端契约:`image_id`、`prompt_type`、`prompt_data`、`model`。 - 点提示传 `{ points, labels }`,正向点 label 为 1,反向点 label 为 0。 - 框选提示传归一化 `[x1, y1, x2, y2]`。 -- 语义文本提示传 `semantic`;选择 `sam3` 且环境满足依赖时走 SAM 3 文本语义推理,选择 `sam2` 时回退到自动分割。 +- 语义文本提示传 `semantic`;选择 `sam3` 且独立 Python 3.12 环境、CUDA、官方包和 Hugging Face gated 权重访问均满足时走 SAM 3 文本语义推理,选择 `sam2` 时回退到自动分割。 +- AI 参数支持 `crop_to_prompt`、`auto_filter_background` 和 `min_score`;点/框 prompt 可以裁剪局部区域推理并回映射结果,背景过滤会移除低分结果和包含负向点的 polygon。 - 后端返回 `polygons` 和 `scores`。 - 前端把后端 `polygons` 转成 Konva `pathData`、`segmentation`、`bbox`、`area`。 - AI 推理结果先存放在前端 store 的 `masks` 中,点击“结构化归档保存”后持久化到后端标注表。 @@ -71,6 +81,9 @@ - 当前前端“结构化归档保存”会保存当前项目未保存 mask,并会更新已标记为 dirty 的已保存 mask。 - 工作区“清空遮罩”会删除当前帧已保存标注,并清空当前帧未保存 mask。 - 工作区加载项目帧后会查询已保存标注并回显。 +- 工作区支持导入 GT mask 图片,前端调用 `POST /api/ai/import-gt-mask`。 +- 后端导入 GT mask 时按非零像素值拆分多类别区域,再按连通域生成 polygon 标注,并通过距离变换写入 seed point。 +- 前端会回显导入标注的 seed point;拖动 seed point 后,已保存标注会变为 dirty,归档保存时会更新后端 `points`。 ## R8 模板库 @@ -93,9 +106,12 @@ - Dashboard 显示基础统计、解析队列和活动日志。 - Dashboard 初始数据来自 `GET /api/dashboard/overview`。 - 后端聚合项目数、处理中任务数、标注数、帧数、模板数和主机 load average。 -- 解析队列由 `processing_tasks` 中的 queued/running 任务生成;活动日志由最近任务、项目、标注和模板记录生成。 +- 解析队列由 `processing_tasks` 中的 queued/running/failed/cancelled 任务生成;活动日志由最近任务、项目、标注和模板记录生成。 +- Dashboard 对 queued/running 任务提供取消按钮,对 failed/cancelled 任务提供重试按钮。 +- Dashboard 任务详情会读取 `GET /api/tasks/{task_id}` 并展示失败 error、payload、result、Celery ID 和时间信息。 - Dashboard 会连接 `/ws/progress`。 - 收到 progress、complete、error、status 消息时,前端会更新队列或日志。 +- 收到 cancelled 消息时,前端会把对应任务标记为已取消。 - Celery worker 每次更新 `processing_tasks` 后会发布 Redis `seg:progress` 事件,FastAPI 订阅并广播给 `/ws/progress` 客户端。 - 后端 WebSocket 接收到客户端消息后返回 status heartbeat。 @@ -104,7 +120,10 @@ - 后端支持 `GET /api/export/{project_id}/coco` 导出 COCO JSON。 - 后端支持 `GET /api/export/{project_id}/masks` 导出 PNG mask ZIP。 - 当前前端 `exportCoco()` API 封装已对齐后端路径。 +- 当前前端 `exportMasks()` API 封装已对齐后端路径。 - 工作区“导出 JSON 标注集”按钮已绑定下载事件;导出前会先保存当前未归档 mask。 +- 工作区“导出 PNG Mask ZIP”按钮已绑定下载事件;导出前会先保存当前未归档 mask。 +- PNG mask ZIP 包含单标注二值 mask、按 zIndex 融合后的每帧语义 mask 和 `semantic_classes.json`。 ## R12 配置 diff --git a/doc/08-current-design-freeze.md b/doc/08-current-design-freeze.md index 5aebebd..e284b4b 100644 --- a/doc/08-current-design-freeze.md +++ b/doc/08-current-design-freeze.md @@ -19,17 +19,17 @@ | 模块 | 文件 | 设计职责 | |------|------|----------| | 应用入口 | `src/App.tsx` | 根据登录状态和 `activeModule` 切换页面 | -| 全局状态 | `src/store/useStore.ts` | Zustand store,保存项目、帧、模板、mask、工具状态 | +| 全局状态 | `src/store/useStore.ts` | Zustand store,保存项目、帧、模板、mask、工具状态和 mask 撤销/重做历史栈 | | API 封装 | `src/lib/api.ts` | Axios 客户端、字段映射、AI 响应转换 | | 配置 | `src/lib/config.ts` | 推导 API 和 WebSocket 地址 | | WebSocket | `src/lib/websocket.ts` | 进度流连接、订阅和重连 | | 模型状态 | `src/components/ModelStatusBadge.tsx` | 展示 GPU 与当前 SAM 模型真实可用状态 | | 登录页 | `src/components/Login.tsx` | 调用登录 API,写入 store | -| Dashboard | `src/components/Dashboard.tsx` | 展示统计和 WebSocket 进度消息 | +| Dashboard | `src/components/Dashboard.tsx` | 展示统计、任务控制、失败详情和 WebSocket 进度消息 | | 项目库 | `src/components/ProjectLibrary.tsx` | 项目列表、新建、导入视频/DICOM | | 工作区 | `src/components/VideoWorkspace.tsx` | 加载帧和模板,组织工具栏、Canvas、本体面板、时间轴 | | Canvas | `src/components/CanvasArea.tsx` | 显示帧、缩放平移、点/框提示、渲染 mask | -| 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具和跳转 AI 页面 | +| 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具、跳转 AI 页面、触发 mask 撤销/重做 | | 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航和播放 | | 本体面板 | `src/components/OntologyInspector.tsx` | 模板选择、分类树、本地自定义分类 | | AI 页面 | `src/components/AISegmentation.tsx` | 独立 AI 推理视图,使用当前项目帧 | @@ -51,7 +51,7 @@ | AI | `backend/routers/ai.py` | SAM 2 / SAM 3 可选推理、模型状态和标注保存 | | Export | `backend/routers/export.py` | COCO 和 PNG mask 导出 | | SAM 2 | `backend/services/sam2_engine.py` | SAM 2 懒加载、状态检测和点/框/自动推理 | -| SAM 3 | `backend/services/sam3_engine.py` | SAM 3 状态检测和文本语义推理适配 | +| SAM 3 | `backend/services/sam3_engine.py`, `backend/services/sam3_external_worker.py`, `backend/setup_sam3_env.sh` | SAM 3 状态检测、独立 Python 3.12 环境桥接和文本语义推理适配 | | SAM Registry | `backend/services/sam_registry.py` | 模型选择、GPU 状态和推理分发 | ## 状态模型 @@ -62,6 +62,7 @@ - `Frame`:帧 ID、项目 ID、索引、图片 URL、宽高。 - `Template` / `TemplateClass`:模板和分类定义。 - `Mask`:前端渲染用 mask,包含 `pathData`、`segmentation`、`bbox`、`area`。 +- `maskHistory` / `maskFuture`:mask 编辑历史栈,用于撤销和重做。 - `activeModule`:当前页面。 - `activeTool`:当前工具。 - `aiModel`:当前选择的 AI 模型,取值为 `sam2` 或 `sam3`。 @@ -82,6 +83,14 @@ 4. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,持续更新 `processing_tasks`,并发布 Redis `seg:progress`。 5. 刷新项目列表。 +### 任务控制 + +1. Dashboard 从 `GET /api/dashboard/overview` 读取 queued/running/failed/cancelled 任务。 +2. 用户取消任务时,前端调用 `POST /api/tasks/{task_id}/cancel`;后端写入 `cancelled`、设置 `finished_at`,并尝试 `celery_app.control.revoke(..., terminate=True)`。 +3. worker 在下载、解析、上传、写帧等关键阶段刷新任务状态;如果发现 `cancelled`,停止后续写入并发布 cancelled 事件。 +4. 用户重试任务时,前端调用 `POST /api/tasks/{task_id}/retry`;后端基于原任务 `payload` 创建新任务,记录 `retry_of` 并重新投递 Celery。 +5. 用户打开详情时,前端调用 `GET /api/tasks/{task_id}`,弹窗展示 error、payload、result、Celery ID 和时间。 + ### 工作区加载 1. `VideoWorkspace` 根据 `currentProject.id` 调用 `getProjectFrames()`。 @@ -102,6 +111,41 @@ 9. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。 10. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。 +### 手工绘制与历史栈 + +1. 用户在 `ToolsPalette` 选择多边形、矩形、圆、点或线工具。 +2. `CanvasArea` 将交互坐标转换成像素 polygon。 +3. 新 mask 写入 `pathData`、像素 `segmentation`、`bbox`、`area` 和当前模板分类元数据。 +4. `addMask()`、`setMasks()`、`updateMask()`、`clearMasks()` 会维护 `maskHistory/maskFuture`。 +5. 工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y 调用 `undoMasks()` / `redoMasks()`。 + +### Polygon 逐点编辑 + +1. 用户点击 Canvas 上的 mask path 后,`CanvasArea` 记录 `selectedMaskId` 并显示该 mask 第一条 polygon 的顶点控制点。 +2. 拖动顶点后,前端重算 `pathData`、像素 `segmentation`、`bbox`、`area`。 +3. 如果 mask 已有 `annotationId`,编辑会把 `saveStatus` 标成 `dirty` 且 `saved=false`。 +4. 归档保存时复用现有 `PATCH /api/ai/annotations/{annotation_id}` 链路,把更新后的 normalized polygon 写回后端。 +5. 选中顶点后 Delete/Backspace 可删除顶点;前端保持 polygon 至少三点。 + +### 区域合并与去除 + +1. 用户选择 `area_merge` 或 `area_remove` 后,点击多个当前帧 mask 组成选择集。 +2. `CanvasArea` 把 `Mask.segmentation` 转为 `polygon-clipping` 的 MultiPolygon。 +3. `area_merge` 使用 union,更新第一个选中的主 mask,并从前端 store 移除后续被合并 mask;如果被移除 mask 已保存,会调用工作区传入的删除回调删除后端标注。 +4. `area_remove` 使用 difference,从第一个选中的主 mask 中扣除后续选中 mask,扣除对象本身保留。 +5. 结果会重算 `pathData`、`segmentation`、`bbox`、`area`,已保存主 mask 会进入 dirty 状态并复用归档 PATCH 链路。 + +### GT Mask 导入 + +1. 工作区“导入 GT Mask”选择图片文件。 +2. 前端 `importGtMask()` 以 multipart form-data 调用 `POST /api/ai/import-gt-mask`,携带 `project_id` 和 `frame_id`。 +3. 后端验证项目、帧、模板后使用 OpenCV 读取灰度 mask。 +4. 后端按非零像素值拆分多类别标签。 +5. 后端对每个类别的前景做 contour 提取,每个连通域保存为一个 `Annotation`。 +6. `points` 字段保存距离变换中心 seed point,`mask_data.polygons` 保存 normalized polygon,`mask_data.gt_label_value` 保存原始像素类别值。 +7. 前端重新读取项目标注并回显。 +8. `annotationToMask()` 会把 normalized seed point 转成像素坐标,Canvas 以可拖拽点显示;拖动后 `buildAnnotationPayload()` 会把点再归一化写回后端。 + ### 模板管理 1. `TemplateRegistry` 从后端读取模板。 @@ -114,8 +158,10 @@ ### 导出 1. 后端根据项目、帧、标注和模板生成 COCO JSON。 -2. PNG mask 导出会把 normalized polygon 渲染为二值 mask 并打包 ZIP。 -3. 前端“导出 JSON 标注集”按钮会在导出前保存待归档标注,然后下载 COCO JSON。 +2. PNG mask 导出会把 normalized polygon 渲染为单标注二值 mask。 +3. PNG mask 导出还会按 `mask_data.class.zIndex` 或模板 `z_index` 从低到高覆盖,生成每帧语义融合 mask。 +4. ZIP 内写入 `semantic_classes.json`,记录语义值到类别、颜色和 zIndex 的映射。 +5. 前端“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮都会在导出前保存待归档标注,然后下载对应文件。 ## 接口契约 @@ -123,13 +169,18 @@ - `updateProject()` 使用 `PATCH /api/projects/{id}`。 - `exportCoco()` 使用 `GET /api/export/{projectId}/coco`。 +- `exportMasks()` 使用 `GET /api/export/{projectId}/masks`。 +- `cancelTask()` 使用 `POST /api/tasks/{taskId}/cancel`。 +- `retryTask()` 使用 `POST /api/tasks/{taskId}/retry`。 - `predictMask()` 使用 `POST /api/ai/predict`,请求体为 `image_id`、`prompt_type`、`prompt_data`、`model`。 - `saveAnnotation()` 使用 `POST /api/ai/annotate`。 +- `importGtMask()` 使用 `POST /api/ai/import-gt-mask` multipart form-data。 - `getProjectAnnotations()` 使用 `GET /api/ai/annotations`。 - `updateAnnotation()` 使用 `PATCH /api/ai/annotations/{annotationId}`。 - `deleteAnnotation()` 使用 `DELETE /api/ai/annotations/{annotationId}`。 - 后端 `/api/ai/predict` 支持 point、box、semantic 三种 prompt_type,并通过 `model` 选择 SAM 2 或 SAM 3。 -- 后端 `/api/ai/models/status` 返回 GPU、SAM 2、SAM 3 的真实运行状态。 +- 后端 `/api/ai/predict` 支持可选 `options`:`crop_to_prompt` 会对 point/box prompt 做局部裁剪推理并回映射 polygon,`auto_filter_background` 会按 `min_score` 和负向点过滤结果。 +- 后端 `/api/ai/models/status` 返回 GPU、SAM 2、SAM 3 的真实运行状态;SAM 3 状态包含外部 Python 环境与 checkpoint access 的可用性。 - point prompt 支持旧数组形式和 `{ points, labels }` 对象形式。 ## 外部依赖边界 @@ -146,10 +197,8 @@ 以下能力属于当前冻结版本的占位或半可用功能: -- Dashboard 初始快照来自 `GET /api/dashboard/overview`;解析队列由 `processing_tasks` queued/running 任务生成。 -- 多边形、矩形、圆、点、线手工绘制未实现。 -- 合并、去除、撤销、重做未实现。 -- 工作区导出 PNG mask ZIP 按钮尚未提供。 -- 已保存标注支持通过“应用分类”进入 dirty 状态并归档更新;暂未提供逐点几何编辑器。 -- SAM 3 文本语义分割取决于官方依赖和 GPU 运行环境;状态接口会暴露真实可用性。 +- Dashboard 初始快照来自 `GET /api/dashboard/overview`;解析队列由 `processing_tasks` queued/running/failed/cancelled 任务生成。 +- 已保存标注支持通过“应用分类”、polygon 顶点拖动/删除、边中点插入、多 polygon 子区域编辑和区域合并/去除进入 dirty 状态并归档更新;复杂洞结构编辑尚未实现。 +- SAM 3 文本语义分割取决于官方依赖、GPU 运行环境和 Hugging Face gated 权重授权;状态接口会暴露真实可用性,未授权时 `available=false`。 - 自定义分类只存在本地组件状态。 +- GT mask 导入已完成多类别像素值拆分、contour、distance transform seed point 和前端 seed point 拖拽编辑;骨架提取、HDBSCAN 聚类和模板自动映射尚未实现。 diff --git a/doc/09-test-plan.md b/doc/09-test-plan.md index 1bf11b7..ebfe0dd 100644 --- a/doc/09-test-plan.md +++ b/doc/09-test-plan.md @@ -16,15 +16,15 @@ |------|----------|--------| | R1 登录与会话 | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | 成功登录、失败提示、后端 401 | | R2 项目管理 | `src/lib/api.test.ts`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、后端 CRUD、帧列表 | -| R3 媒体上传与拆帧 | `backend/tests/test_media.py` | 扩展名校验、自动建项目、关联项目、创建异步任务、worker 注册帧 | +| R3 媒体上传与拆帧 | `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 扩展名校验、自动建项目、关联项目、创建异步任务、worker 注册帧、取消任务、重试任务、取消后 worker 停止 | | R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧触发解析、切帧、播放 | -| R5 工具栏 | `src/components/ToolsPalette.test.tsx` | 工具切换、AI 跳转、占位按钮存在 | -| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py` | 点/框/semantic 契约、模型选择、GPU/SAM 状态、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry | -| R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、加载回显、更新 dirty 标注、清空删除已保存标注、项目不存在、帧不存在 | +| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 工具切换、AI 跳转、手工 mask 绘制、polygon 顶点拖动/删除、区域合并/去除、撤销/重做历史栈 | +| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam3_engine.py` | 点/框/semantic 契约、模型选择、GPU/SAM 状态、SAM 3 外部 worker 桥接、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry | +| R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、加载回显、更新 dirty 标注、清空删除已保存标注、GT mask 多类别导入、seed point 回显/归一化、项目不存在、帧不存在 | | R8 模板库 | `src/lib/api.test.ts`, `backend/tests/test_templates.py` | mapping_rules 解包/打包、模板 CRUD | | R9 本体检查面板 | `src/components/OntologyInspector.test.tsx` | 模板选择、分类展示、具体分类选择、自定义分类本地添加 | -| R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py` | 后端概览接口、任务表驱动队列、Redis 进度事件 payload/发布、地址推导、消息订阅、队列更新、heartbeat | -| R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO 按钮下载、导出前自动保存、COCO 路径、JSON 结构、mask ZIP | +| R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py`, `backend/tests/test_tasks.py` | 后端概览接口、任务表驱动队列、任务取消/重试/详情、cancelled 事件、Redis 进度事件 payload/发布、地址推导、消息订阅、队列更新、heartbeat | +| R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO/PNG 按钮下载、导出前自动保存、导出路径、JSON 结构、mask ZIP、zIndex 语义融合 | | R12 配置 | `src/lib/config.test.ts` | env 优先、hostname 推导、WS 推导 | | R13 文档与测试 | `doc/09-test-plan.md` | 测试覆盖矩阵 | diff --git a/package-lock.json b/package-lock.json index 51f889c..f02a5ee 100644 --- a/package-lock.json +++ b/package-lock.json @@ -18,6 +18,7 @@ "konva": "^10.2.5", "lucide-react": "^0.546.0", "motion": "^12.23.24", + "polygon-clipping": "^0.15.7", "react": "^19.0.0", "react-dom": "^19.0.0", "react-konva": "^19.2.3", @@ -4165,6 +4166,16 @@ "url": "https://github.com/sponsors/jonschlinkert" } }, + "node_modules/polygon-clipping": { + "version": "0.15.7", + "resolved": "https://registry.npmjs.org/polygon-clipping/-/polygon-clipping-0.15.7.tgz", + "integrity": "sha512-nhfdr83ECBg6xtqOAJab1tbksbBAOMUltN60bU+llHVOL0e5Onm1WpAXXWXVB39L8AJFssoIhEVuy/S90MmotA==", + "license": "MIT", + "dependencies": { + "robust-predicates": "^3.0.2", + "splaytree": "^3.1.0" + } + }, "node_modules/postcss": { "version": "8.5.12", "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.12.tgz", @@ -4438,6 +4449,12 @@ "node": ">= 4" } }, + "node_modules/robust-predicates": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/robust-predicates/-/robust-predicates-3.0.3.tgz", + "integrity": "sha512-NS3levdsRIUOmiJ8FZWCP7LG3QpJyrs/TE0Zpf1yvZu8cAJJ6QMW92H1c7kWpdIHo8RvmLxN/o2JXTKHp74lUA==", + "license": "Unlicense" + }, "node_modules/rollup": { "version": "4.60.2", "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.60.2.tgz", @@ -4684,6 +4701,15 @@ "node": ">=0.10.0" } }, + "node_modules/splaytree": { + "version": "3.2.3", + "resolved": "https://registry.npmjs.org/splaytree/-/splaytree-3.2.3.tgz", + "integrity": "sha512-7OXrNWzy6CK+r7Ch9OLPBDTKfB6XlWHjX4P0RU5B3IgFuWPeYN0XtRtlexGRjgbQxpfaUve6jTAwBGWuGntz/w==", + "license": "MIT", + "engines": { + "node": ">=18.20 || >=20" + } + }, "node_modules/stackback": { "version": "0.0.2", "resolved": "https://registry.npmjs.org/stackback/-/stackback-0.0.2.tgz", diff --git a/package.json b/package.json index 4f8d419..1f5332e 100644 --- a/package.json +++ b/package.json @@ -24,6 +24,7 @@ "konva": "^10.2.5", "lucide-react": "^0.546.0", "motion": "^12.23.24", + "polygon-clipping": "^0.15.7", "react": "^19.0.0", "react-dom": "^19.0.0", "react-konva": "^19.2.3", diff --git a/src/components/AISegmentation.test.tsx b/src/components/AISegmentation.test.tsx index e0bb7ba..c6a9bea 100644 --- a/src/components/AISegmentation.test.tsx +++ b/src/components/AISegmentation.test.tsx @@ -40,4 +40,26 @@ describe('AISegmentation', () => { expect(useStore.getState().aiModel).toBe('sam3'); expect(await screen.findByText('SAM 3 missing runtime')).toBeInTheDocument(); }); + + it('passes enabled inference parameters to the backend', async () => { + apiMock.predictMask.mockResolvedValueOnce({ masks: [] }); + render(); + + fireEvent.click(screen.getByText('正向选点')); + fireEvent.click(screen.getByTestId('konva-stage')); + fireEvent.click(await screen.findByText('执行高精度语义分割')); + + expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({ + imageId: 'frame-1', + imageWidth: 640, + imageHeight: 360, + model: 'sam2', + points: [{ x: 120, y: 80, type: 'pos' }], + options: { + crop_to_prompt: false, + auto_filter_background: true, + min_score: 0.05, + }, + })); + }); }); diff --git a/src/components/AISegmentation.tsx b/src/components/AISegmentation.tsx index b45f638..16424f3 100644 --- a/src/components/AISegmentation.tsx +++ b/src/components/AISegmentation.tsx @@ -17,6 +17,10 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { const masks = useStore((state) => state.masks); const addMask = useStore((state) => state.addMask); const clearMasks = useStore((state) => state.clearMasks); + const maskHistory = useStore((state) => state.maskHistory); + const maskFuture = useStore((state) => state.maskFuture); + const undoMasks = useStore((state) => state.undoMasks); + const redoMasks = useStore((state) => state.redoMasks); const frames = useStore((state) => state.frames); const currentFrameIndex = useStore((state) => state.currentFrameIndex); const activeTemplateId = useStore((state) => state.activeTemplateId); @@ -109,6 +113,11 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { model: aiModel, points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })), text: semanticText.trim() || undefined, + options: { + crop_to_prompt: cropMode, + auto_filter_background: autoDeleteBg, + min_score: autoDeleteBg ? 0.05 : 0, + }, }); result.masks.forEach((m) => { @@ -136,7 +145,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { } finally { setIsInferencing(false); } - }, [activeClass, activeTemplateId, addMask, aiModel, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText]); + }, [activeClass, activeTemplateId, addMask, aiModel, autoDeleteBg, cropMode, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText]); const handleStageClick = (e: any) => { if (effectiveTool === 'move') return; @@ -290,10 +299,20 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { {aiModel.toUpperCase()} 动态推理渲染
- -
diff --git a/src/components/CanvasArea.test.tsx b/src/components/CanvasArea.test.tsx index a462b7a..0b809f2 100644 --- a/src/components/CanvasArea.test.tsx +++ b/src/components/CanvasArea.test.tsx @@ -79,6 +79,271 @@ describe('CanvasArea', () => { expect(screen.getByText('遮罩数: 1')).toBeInTheDocument(); }); + it('renders imported GT seed points for editable point regions', () => { + useStore.setState({ + masks: [ + { + id: 'gt-1', + frameId: 'frame-1', + pathData: 'M 0 0 L 10 0 L 10 10 Z', + label: 'GT', + color: '#22c55e', + points: [[120, 80]], + }, + ], + }); + + render(); + + expect(screen.getAllByTestId('konva-circle')).toHaveLength(2); + }); + + it('selects a polygon mask and drags a vertex into dirty saved state', () => { + useStore.setState({ + masks: [ + { + id: 'annotation-99', + annotationId: '99', + frameId: 'frame-1', + pathData: 'M 10 10 L 90 10 L 90 40 Z', + label: 'Saved', + color: '#06b6d4', + saved: true, + saveStatus: 'saved', + segmentation: [[10, 10, 90, 10, 90, 40]], + bbox: [10, 10, 80, 30], + }, + ], + }); + + render(); + fireEvent.click(screen.getByTestId('konva-path')); + const handles = screen.getAllByTestId('konva-circle') + .filter((element) => element.getAttribute('data-fill') === '#ffffff'); + expect(handles).toHaveLength(3); + + fireEvent.mouseUp(handles[0], { clientX: 20, clientY: 30 }); + + expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ + pathData: 'M 20 30 L 90 10 L 90 40 Z', + segmentation: [[20, 30, 90, 10, 90, 40]], + bbox: [20, 10, 70, 30], + area: 1050, + saveStatus: 'dirty', + saved: false, + })); + }); + + it('deletes a selected polygon vertex without dropping below three points', () => { + useStore.setState({ + masks: [ + { + id: 'draft-1', + frameId: 'frame-1', + pathData: 'M 10 10 L 90 10 L 90 40 L 10 40 Z', + label: 'Draft', + color: '#06b6d4', + saveStatus: 'draft', + segmentation: [[10, 10, 90, 10, 90, 40, 10, 40]], + bbox: [10, 10, 80, 30], + }, + ], + }); + + render(); + fireEvent.click(screen.getByTestId('konva-path')); + const handles = screen.getAllByTestId('konva-circle') + .filter((element) => element.getAttribute('data-fill') === '#ffffff'); + fireEvent.click(handles[0]); + fireEvent.keyDown(window, { key: 'Delete' }); + + expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ + pathData: 'M 90 10 L 90 40 L 10 40 Z', + segmentation: [[90, 10, 90, 40, 10, 40]], + saveStatus: 'draft', + })); + }); + + it('inserts a polygon vertex from an edge midpoint handle', () => { + useStore.setState({ + masks: [ + { + id: 'draft-1', + frameId: 'frame-1', + pathData: 'M 10 10 L 90 10 L 90 40 Z', + label: 'Draft', + color: '#06b6d4', + saveStatus: 'draft', + segmentation: [[10, 10, 90, 10, 90, 40]], + bbox: [10, 10, 80, 30], + }, + ], + }); + + render(); + fireEvent.click(screen.getByTestId('konva-path')); + const edgeHandles = screen.getAllByTestId('konva-circle') + .filter((element) => element.getAttribute('data-fill') === '#22d3ee'); + fireEvent.click(edgeHandles[0]); + + expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ + segmentation: [[10, 10, 50, 10, 90, 10, 90, 40]], + pathData: 'M 10 10 L 50 10 L 90 10 L 90 40 Z', + saveStatus: 'draft', + })); + }); + + it('edits the selected polygon in a multi-polygon mask', () => { + useStore.setState({ + masks: [ + { + id: 'multi-1', + frameId: 'frame-1', + pathData: 'M 10 10 L 50 10 L 50 40 Z M 100 100 L 150 100 L 150 140 Z', + label: 'Multi', + color: '#06b6d4', + saveStatus: 'draft', + segmentation: [ + [10, 10, 50, 10, 50, 40], + [100, 100, 150, 100, 150, 140], + ], + bbox: [10, 10, 140, 130], + }, + ], + }); + + render(); + const paths = screen.getAllByTestId('konva-path'); + fireEvent.click(paths[1]); + const vertexHandles = screen.getAllByTestId('konva-circle') + .filter((element) => element.getAttribute('data-fill') === '#ffffff'); + fireEvent.mouseUp(vertexHandles[0], { clientX: 120, clientY: 120 }); + + expect(useStore.getState().masks[0].segmentation).toEqual([ + [10, 10, 50, 10, 50, 40], + [120, 120, 150, 100, 150, 140], + ]); + }); + + it('merges selected draft masks with polygon union', () => { + useStore.setState({ + masks: [ + { + id: 'm1', + frameId: 'frame-1', + pathData: 'M 10 10 L 90 10 L 90 50 L 10 50 Z', + label: 'A', + color: '#06b6d4', + segmentation: [[10, 10, 90, 10, 90, 50, 10, 50]], + }, + { + id: 'm2', + frameId: 'frame-1', + pathData: 'M 50 30 L 120 30 L 120 80 L 50 80 Z', + label: 'B', + color: '#ff0000', + segmentation: [[50, 30, 120, 30, 120, 80, 50, 80]], + }, + ], + }); + + render(); + const paths = screen.getAllByTestId('konva-path'); + fireEvent.click(paths[0]); + fireEvent.click(paths[1]); + fireEvent.click(screen.getByRole('button', { name: '合并选中' })); + + expect(useStore.getState().masks).toHaveLength(1); + expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ + id: 'm1', + segmentation: [[10, 10, 90, 10, 90, 30, 120, 30, 120, 80, 50, 80, 50, 50, 10, 50]], + bbox: [10, 10, 110, 70], + saveStatus: 'draft', + })); + }); + + it('removes overlap from the primary selected mask with polygon difference', () => { + useStore.setState({ + masks: [ + { + id: 'm1', + frameId: 'frame-1', + pathData: 'M 10 10 L 90 10 L 90 50 L 10 50 Z', + label: 'A', + color: '#06b6d4', + segmentation: [[10, 10, 90, 10, 90, 50, 10, 50]], + }, + { + id: 'm2', + frameId: 'frame-1', + pathData: 'M 50 30 L 120 30 L 120 80 L 50 80 Z', + label: 'B', + color: '#ff0000', + segmentation: [[50, 30, 120, 30, 120, 80, 50, 80]], + }, + ], + }); + + render(); + const paths = screen.getAllByTestId('konva-path'); + fireEvent.click(paths[0]); + fireEvent.click(paths[1]); + fireEvent.click(screen.getByRole('button', { name: '从主区域去除' })); + + expect(useStore.getState().masks).toHaveLength(2); + expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ + id: 'm1', + segmentation: [[10, 10, 90, 10, 90, 30, 50, 30, 50, 50, 10, 50]], + bbox: [10, 10, 80, 40], + saveStatus: 'draft', + })); + expect(useStore.getState().masks[1].id).toBe('m2'); + }); + + it('creates a manual rectangle mask that can be undone and redone', () => { + useStore.setState({ + activeTemplateId: '2', + activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 }, + activeClassId: 'c1', + }); + + render(); + const stage = screen.getByTestId('konva-stage'); + fireEvent.mouseDown(stage); + fireEvent.mouseMove(stage); + fireEvent.mouseUp(stage); + + expect(useStore.getState().masks).toHaveLength(1); + expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ + frameId: 'frame-1', + label: '胆囊', + color: '#ff0000', + saveStatus: 'draft', + segmentation: [[120, 80, 260, 80, 260, 200, 120, 200]], + bbox: [120, 80, 140, 120], + })); + + useStore.getState().undoMasks(); + expect(useStore.getState().masks).toEqual([]); + useStore.getState().redoMasks(); + expect(useStore.getState().masks).toHaveLength(1); + }); + + it('finalizes a clicked polygon with Enter', () => { + render(); + const stage = screen.getByTestId('konva-stage'); + fireEvent.click(stage, { clientX: 120, clientY: 80 }); + fireEvent.click(stage, { clientX: 220, clientY: 80 }); + fireEvent.click(stage, { clientX: 180, clientY: 160 }); + fireEvent.keyDown(window, { key: 'Enter' }); + + expect(useStore.getState().masks).toHaveLength(1); + expect(useStore.getState().masks[0].metadata).toEqual(expect.objectContaining({ + source: 'manual', + shape: '多边形', + })); + }); + it('applies the selected class to current-frame masks and marks saved masks dirty', () => { useStore.setState({ activeTemplateId: '2', diff --git a/src/components/CanvasArea.tsx b/src/components/CanvasArea.tsx index de2ff59..7873a75 100644 --- a/src/components/CanvasArea.tsx +++ b/src/components/CanvasArea.tsx @@ -1,17 +1,180 @@ import React, { useEffect, useRef, useState, useCallback } from 'react'; import { Stage, Layer, Image as KonvaImage, Circle, Rect, Path, Group } from 'react-konva'; +import polygonClipping, { type MultiPolygon, type Pair } from 'polygon-clipping'; import useImage from 'use-image'; import { useStore } from '../store/useStore'; import { predictMask } from '../lib/api'; -import type { Frame } from '../store/useStore'; +import type { Frame, Mask } from '../store/useStore'; interface CanvasAreaProps { activeTool: string; frame: Frame | null; onClearMasks?: () => void; + onDeleteMaskAnnotations?: (annotationIds: string[]) => Promise | void; } -export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) { +type CanvasPoint = { x: number; y: number }; + +const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle', 'create_line']); +const POLYGON_TOOL = 'create_polygon'; +const POINT_TOOL = 'create_point'; +const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']); + +function clamp(value: number, min: number, max: number): number { + return Math.min(Math.max(value, min), max); +} + +function polygonPath(points: CanvasPoint[]): string { + if (points.length === 0) return ''; + return points + .map((point, index) => `${index === 0 ? 'M' : 'L'} ${point.x} ${point.y}`) + .join(' ') + .concat(' Z'); +} + +function segmentationPath(segmentation?: number[][]): string { + return (segmentation || []) + .map((polygon) => polygonPath(flatPolygonToPoints(polygon))) + .filter(Boolean) + .join(' '); +} + +function segmentationPolygonPath(segmentation: number[][] | undefined, polygonIndex: number): string { + const polygon = segmentation?.[polygonIndex]; + return polygon ? polygonPath(flatPolygonToPoints(polygon)) : ''; +} + +function polygonSegmentation(points: CanvasPoint[]): number[][] { + return [points.flatMap((point) => [point.x, point.y])]; +} + +function segmentationToPoints(segmentation?: number[][], polygonIndex = 0): CanvasPoint[] { + const polygon = segmentation?.[polygonIndex] || []; + const points: CanvasPoint[] = []; + for (let index = 0; index < polygon.length - 1; index += 2) { + points.push({ x: polygon[index], y: polygon[index + 1] }); + } + return points; +} + +function flatPolygonToPoints(polygon: number[]): CanvasPoint[] { + const points: CanvasPoint[] = []; + for (let index = 0; index < polygon.length - 1; index += 2) { + points.push({ x: polygon[index], y: polygon[index + 1] }); + } + return points; +} + +function segmentationAllPoints(segmentation?: number[][]): CanvasPoint[] { + return (segmentation || []).flatMap((polygon) => flatPolygonToPoints(polygon)); +} + +function polygonBbox(points: CanvasPoint[]): [number, number, number, number] { + const xs = points.map((point) => point.x); + const ys = points.map((point) => point.y); + const minX = Math.min(...xs); + const minY = Math.min(...ys); + const maxX = Math.max(...xs); + const maxY = Math.max(...ys); + return [minX, minY, maxX - minX, maxY - minY]; +} + +function polygonArea(points: CanvasPoint[]): number { + if (points.length < 3) return 0; + const sum = points.reduce((acc, point, index) => { + const next = points[(index + 1) % points.length]; + return acc + point.x * next.y - next.x * point.y; + }, 0); + return Math.abs(sum) / 2; +} + +function segmentationArea(segmentation?: number[][]): number { + return (segmentation || []).reduce((sum, polygon) => sum + polygonArea(flatPolygonToPoints(polygon)), 0); +} + +function segmentationBbox(segmentation?: number[][]): [number, number, number, number] | undefined { + const points = segmentationAllPoints(segmentation); + return points.length > 0 ? polygonBbox(points) : undefined; +} + +function closeRing(points: CanvasPoint[]): Pair[] { + const ring = points.map((point) => [point.x, point.y] as Pair); + const first = ring[0]; + const last = ring[ring.length - 1]; + if (first && last && (first[0] !== last[0] || first[1] !== last[1])) { + ring.push([first[0], first[1]]); + } + return ring; +} + +function maskToMultiPolygon(mask: Mask): MultiPolygon | null { + const polygons = (mask.segmentation || []) + .map((polygon) => flatPolygonToPoints(polygon)) + .filter((points) => points.length >= 3) + .map((points) => [closeRing(points)]); + return polygons.length > 0 ? polygons : null; +} + +function multiPolygonToSegmentation(geometry: MultiPolygon): number[][] { + return geometry + .map((polygon) => polygon[0] || []) + .map((ring) => { + const openRing = ring.length > 1 + && ring[0][0] === ring[ring.length - 1][0] + && ring[0][1] === ring[ring.length - 1][1] + ? ring.slice(0, -1) + : ring; + return openRing.flatMap(([x, y]) => [x, y]); + }) + .filter((polygon) => polygon.length >= 6); +} + +function rectanglePoints(start: CanvasPoint, end: CanvasPoint): CanvasPoint[] { + const x1 = Math.min(start.x, end.x); + const y1 = Math.min(start.y, end.y); + const x2 = Math.max(start.x, end.x); + const y2 = Math.max(start.y, end.y); + return [ + { x: x1, y: y1 }, + { x: x2, y: y1 }, + { x: x2, y: y2 }, + { x: x1, y: y2 }, + ]; +} + +function circlePoints(start: CanvasPoint, end: CanvasPoint): CanvasPoint[] { + const cx = (start.x + end.x) / 2; + const cy = (start.y + end.y) / 2; + const rx = Math.abs(end.x - start.x) / 2; + const ry = Math.abs(end.y - start.y) / 2; + return Array.from({ length: 32 }, (_, index) => { + const angle = (Math.PI * 2 * index) / 32; + return { x: cx + Math.cos(angle) * rx, y: cy + Math.sin(angle) * ry }; + }); +} + +function pointRegion(point: CanvasPoint, radius = 5): CanvasPoint[] { + return Array.from({ length: 12 }, (_, index) => { + const angle = (Math.PI * 2 * index) / 12; + return { x: point.x + Math.cos(angle) * radius, y: point.y + Math.sin(angle) * radius }; + }); +} + +function lineRegion(start: CanvasPoint, end: CanvasPoint, halfWidth = 4): CanvasPoint[] { + const dx = end.x - start.x; + const dy = end.y - start.y; + const length = Math.hypot(dx, dy) || 1; + const nx = (-dy / length) * halfWidth; + const ny = (dx / length) * halfWidth; + return [ + { x: start.x + nx, y: start.y + ny }, + { x: end.x + nx, y: end.y + ny }, + { x: end.x - nx, y: end.y - ny }, + { x: start.x - nx, y: start.y - ny }, + ]; +} + +export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnotations }: CanvasAreaProps) { const containerRef = useRef(null); const [stageSize, setStageSize] = useState({ width: 800, height: 600 }); const [scale, setScale] = useState(1); @@ -20,22 +183,46 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 }); const [boxStart, setBoxStart] = useState<{ x: number, y: number } | null>(null); const [boxCurrent, setBoxCurrent] = useState<{ x: number, y: number } | null>(null); + const [manualStart, setManualStart] = useState(null); + const [manualCurrent, setManualCurrent] = useState(null); + const [polygonPoints, setPolygonPoints] = useState([]); + const [selectedMaskId, setSelectedMaskId] = useState(null); + const [selectedMaskIds, setSelectedMaskIds] = useState([]); + const [selectedPolygonIndex, setSelectedPolygonIndex] = useState(0); + const [selectedVertexIndex, setSelectedVertexIndex] = useState(null); const [isInferencing, setIsInferencing] = useState(false); const masks = useStore((state) => state.masks); const addMask = useStore((state) => state.addMask); + const updateMask = useStore((state) => state.updateMask); const clearMasks = useStore((state) => state.clearMasks); const setMasks = useStore((state) => state.setMasks); const storeActiveTool = useStore((state) => state.activeTool); const aiModel = useStore((state) => state.aiModel); const activeTemplateId = useStore((state) => state.activeTemplateId); const activeClass = useStore((state) => state.activeClass); + const undoMasks = useStore((state) => state.undoMasks); + const redoMasks = useStore((state) => state.redoMasks); const effectiveTool = activeTool || storeActiveTool; // Load the actual frame image const [image] = useImage(frame?.url || ''); const frameMasks = masks.filter((mask) => mask.frameId === frame?.id); + const selectedMask = React.useMemo( + () => frameMasks.find((mask) => mask.id === selectedMaskId) || null, + [frameMasks, selectedMaskId], + ); + const booleanSelectedMasks = React.useMemo( + () => selectedMaskIds + .map((id) => frameMasks.find((mask) => mask.id === id)) + .filter((mask): mask is Mask => Boolean(mask)), + [frameMasks, selectedMaskIds], + ); + const selectedMaskPoints = React.useMemo( + () => segmentationToPoints(selectedMask?.segmentation, selectedPolygonIndex), + [selectedMask?.segmentation, selectedPolygonIndex], + ); const savedMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'saved' || mask.saved).length; const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length; const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length; @@ -55,6 +242,25 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) return () => window.removeEventListener('resize', handleResize); }, []); + useEffect(() => { + setManualStart(null); + setManualCurrent(null); + setPolygonPoints([]); + setSelectedMaskId(null); + setSelectedMaskIds([]); + setSelectedPolygonIndex(0); + setSelectedVertexIndex(null); + }, [effectiveTool, frame?.id]); + + useEffect(() => { + if (selectedMaskId && !frameMasks.some((mask) => mask.id === selectedMaskId)) { + setSelectedMaskId(null); + setSelectedMaskIds([]); + setSelectedPolygonIndex(0); + setSelectedVertexIndex(null); + } + }, [frameMasks, selectedMaskId]); + const handleWheel = (e: any) => { e.evt.preventDefault(); const scaleBy = 1.1; @@ -74,6 +280,50 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) }); }; + const stagePoint = (e: any): CanvasPoint | null => { + const stage = e.target.getStage(); + const relPos = stage?.getRelativePointerPosition(); + if (!relPos) return null; + const imageWidth = frame?.width || image?.naturalWidth || image?.width || stageSize.width; + const imageHeight = frame?.height || image?.naturalHeight || image?.height || stageSize.height; + return { + x: clamp(relPos.x, 0, imageWidth), + y: clamp(relPos.y, 0, imageHeight), + }; + }; + + const createManualMask = useCallback((shape: string, polygon: CanvasPoint[]) => { + if (!frame?.id || polygon.length < 3) return; + const area = polygonArea(polygon); + if (area <= 1) return; + const color = activeClass?.color || '#06b6d4'; + const label = activeClass?.name || `手工${shape}`; + const mask: Mask = { + id: `manual-${frame.id}-${shape}-${Date.now()}`, + frameId: frame.id, + templateId: activeTemplateId || undefined, + classId: activeClass?.id, + className: activeClass?.name, + classZIndex: activeClass?.zIndex, + saveStatus: 'draft', + saved: false, + pathData: polygonPath(polygon), + label, + color, + segmentation: polygonSegmentation(polygon), + points: shape === '点区域' + ? [[ + polygon.reduce((sum, point) => sum + point.x, 0) / polygon.length, + polygon.reduce((sum, point) => sum + point.y, 0) / polygon.length, + ]] + : undefined, + bbox: polygonBbox(polygon), + area, + metadata: { source: 'manual', shape }, + }; + addMask(mask); + }, [activeClass, activeTemplateId, addMask, frame?.id]); + const handleMouseMove = (e: any) => { const stage = e.target.getStage(); if (!stage) return; @@ -90,6 +340,13 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) setBoxCurrent({ x: relPos.x, y: relPos.y }); } } + + if (manualStart && DRAG_MANUAL_TOOLS.has(effectiveTool)) { + const pos = stage.getRelativePointerPosition(); + if (pos) { + setManualCurrent({ x: pos.x, y: pos.y }); + } + } }; const runInference = useCallback(async (promptPoints?: typeof points, promptBox?: { x1: number, y1: number, x2: number, y2: number }) => { @@ -132,6 +389,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) label, color, segmentation: m.segmentation, + points: promptPoints?.filter((p) => p.type === 'pos').map((p) => [p.x, p.y]), bbox: m.bbox, area: m.area, }); @@ -170,6 +428,15 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) }; const handleStageMouseDown = (e: any) => { + if (DRAG_MANUAL_TOOLS.has(effectiveTool)) { + const pos = stagePoint(e); + if (pos) { + setManualStart(pos); + setManualCurrent(pos); + } + return; + } + if (effectiveTool === 'box_select') { const stage = e.target.getStage(); const pos = stage.getRelativePointerPosition(); @@ -181,6 +448,27 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) }; const handleStageMouseUp = (e: any) => { + if (DRAG_MANUAL_TOOLS.has(effectiveTool) && manualStart) { + const end = stagePoint(e) || manualCurrent || manualStart; + const width = Math.abs(end.x - manualStart.x); + const height = Math.abs(end.y - manualStart.y); + const distance = Math.hypot(width, height); + + if (effectiveTool === 'create_rectangle' && width > 4 && height > 4) { + createManualMask('矩形', rectanglePoints(manualStart, end)); + } + if (effectiveTool === 'create_circle' && width > 4 && height > 4) { + createManualMask('圆形', circlePoints(manualStart, end)); + } + if (effectiveTool === 'create_line' && distance > 4) { + createManualMask('线段', lineRegion(manualStart, end)); + } + + setManualStart(null); + setManualCurrent(null); + return; + } + if (effectiveTool === 'box_select' && boxStart && boxCurrent) { const x1 = Math.min(boxStart.x, boxCurrent.x); const y1 = Math.min(boxStart.y, boxCurrent.y); @@ -199,12 +487,32 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) const handleStageClick = (e: any) => { if (effectiveTool === 'move') return; if (effectiveTool === 'box_select') return; // handled by mouseup + if (DRAG_MANUAL_TOOLS.has(effectiveTool)) return; + + if (effectiveTool === POINT_TOOL) { + const pos = stagePoint(e); + if (pos) { + createManualMask('点区域', pointRegion(pos)); + } + return; + } + + if (effectiveTool === POLYGON_TOOL) { + const pos = stagePoint(e); + if (pos) { + setPolygonPoints((current) => [...current, pos]); + } + return; + } if (effectiveTool === 'point_pos' || effectiveTool === 'point_neg') { const stage = e.target.getStage(); const pos = stage.getRelativePointerPosition(); if (pos) { - const newPoints = [...points, { x: pos.x, y: pos.y, type: effectiveTool === 'point_pos' ? 'pos' : 'neg' as 'pos'|'neg' }]; + const newPoints = [ + ...points, + { x: pos.x, y: pos.y, type: (effectiveTool === 'point_pos' ? 'pos' : 'neg') as 'pos' | 'neg' }, + ]; setPoints(newPoints); // Auto-trigger inference after point selection runInference(newPoints); @@ -212,6 +520,74 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) } }; + const updatePolygonMask = useCallback((mask: Mask, nextPoints: CanvasPoint[], polygonIndex = 0) => { + if (nextPoints.length < 3) return; + const nextSegmentation = [...(mask.segmentation || [])]; + nextSegmentation[polygonIndex] = nextPoints.flatMap((point) => [point.x, point.y]); + const bbox = segmentationBbox(nextSegmentation) || polygonBbox(nextPoints); + updateMask(mask.id, { + pathData: segmentationPath(nextSegmentation), + segmentation: nextSegmentation, + bbox, + area: segmentationArea(nextSegmentation), + saveStatus: mask.annotationId ? 'dirty' : 'draft', + saved: mask.annotationId ? false : mask.saved, + }); + }, [updateMask]); + + const updateMaskFromSegmentation = useCallback((mask: Mask, segmentation: number[][]): Mask => { + const bbox = segmentationBbox(segmentation); + return { + ...mask, + pathData: segmentationPath(segmentation), + segmentation, + bbox, + area: segmentationArea(segmentation), + saveStatus: mask.annotationId ? 'dirty' : 'draft', + saved: mask.annotationId ? false : mask.saved, + }; + }, []); + + useEffect(() => { + const handleKeyDown = (event: KeyboardEvent) => { + const key = event.key.toLowerCase(); + if ((event.metaKey || event.ctrlKey) && key === 'z') { + event.preventDefault(); + if (event.shiftKey) redoMasks(); + else undoMasks(); + return; + } + if ((event.metaKey || event.ctrlKey) && key === 'y') { + event.preventDefault(); + redoMasks(); + return; + } + if ((event.key === 'Delete' || event.key === 'Backspace') && selectedMask && selectedVertexIndex !== null) { + const currentPoints = segmentationToPoints(selectedMask.segmentation, selectedPolygonIndex); + if (currentPoints.length > 3) { + event.preventDefault(); + const nextPoints = currentPoints.filter((_, index) => index !== selectedVertexIndex); + updatePolygonMask(selectedMask, nextPoints, selectedPolygonIndex); + setSelectedVertexIndex(null); + } + return; + } + if (effectiveTool !== POLYGON_TOOL) return; + if (event.key === 'Enter' && polygonPoints.length >= 3) { + event.preventDefault(); + createManualMask('多边形', polygonPoints); + setPolygonPoints([]); + } + if (event.key === 'Escape') { + event.preventDefault(); + setPolygonPoints([]); + } + }; + + window.addEventListener('keydown', handleKeyDown); + return () => window.removeEventListener('keydown', handleKeyDown); + }, [createManualMask, effectiveTool, polygonPoints, redoMasks, selectedMask, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]); + const boxRect = React.useMemo(() => { if (!boxStart || !boxCurrent) return null; const x = Math.min(boxStart.x, boxCurrent.x); @@ -221,6 +597,132 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) return { x, y, width, height }; }, [boxStart, boxCurrent]); + const manualPreviewPath = React.useMemo(() => { + if (manualStart && manualCurrent) { + if (effectiveTool === 'create_rectangle') return polygonPath(rectanglePoints(manualStart, manualCurrent)); + if (effectiveTool === 'create_circle') return polygonPath(circlePoints(manualStart, manualCurrent)); + if (effectiveTool === 'create_line') return polygonPath(lineRegion(manualStart, manualCurrent)); + } + if (effectiveTool === POLYGON_TOOL && polygonPoints.length > 0) { + const previewPoints = [...polygonPoints, cursorPos]; + return polygonPath(previewPoints); + } + return null; + }, [cursorPos, effectiveTool, manualCurrent, manualStart, polygonPoints]); + + const handleSeedPointDragEnd = (mask: Mask, pointIndex: number, event: any) => { + const x = event.target.x(); + const y = event.target.y(); + const nextPoints = [...(mask.points || [])]; + nextPoints[pointIndex] = [x, y]; + updateMask(mask.id, { + points: nextPoints, + saveStatus: mask.annotationId ? 'dirty' : 'draft', + saved: mask.annotationId ? false : mask.saved, + }); + }; + + const handleMaskSelect = (mask: Mask, event: any, polygonIndex = 0) => { + event.cancelBubble = true; + if (BOOLEAN_TOOLS.has(effectiveTool)) { + setSelectedMaskIds((current) => ( + current.includes(mask.id) + ? current.filter((id) => id !== mask.id) + : [...current, mask.id] + )); + setSelectedMaskId(mask.id); + setSelectedPolygonIndex(polygonIndex); + setSelectedVertexIndex(null); + return; + } + setSelectedMaskId(mask.id); + setSelectedMaskIds([mask.id]); + setSelectedPolygonIndex(polygonIndex); + setSelectedVertexIndex(null); + }; + + const handleVertexDragEnd = (mask: Mask, vertexIndex: number, event: any) => { + const imageWidth = frame?.width || image?.naturalWidth || image?.width || stageSize.width; + const imageHeight = frame?.height || image?.naturalHeight || image?.height || stageSize.height; + const currentPoints = segmentationToPoints(mask.segmentation, selectedPolygonIndex); + if (!currentPoints[vertexIndex]) return; + const nextPoints = currentPoints.map((point, index) => ( + index === vertexIndex + ? { + x: clamp(event.target.x(), 0, imageWidth), + y: clamp(event.target.y(), 0, imageHeight), + } + : point + )); + setSelectedMaskId(mask.id); + setSelectedVertexIndex(vertexIndex); + updatePolygonMask(mask, nextPoints, selectedPolygonIndex); + }; + + const handleEdgeInsert = (mask: Mask, edgeIndex: number, event: any) => { + event.cancelBubble = true; + const currentPoints = segmentationToPoints(mask.segmentation, selectedPolygonIndex); + const start = currentPoints[edgeIndex]; + const end = currentPoints[(edgeIndex + 1) % currentPoints.length]; + if (!start || !end) return; + const inserted = { x: (start.x + end.x) / 2, y: (start.y + end.y) / 2 }; + const nextPoints = [ + ...currentPoints.slice(0, edgeIndex + 1), + inserted, + ...currentPoints.slice(edgeIndex + 1), + ]; + setSelectedMaskId(mask.id); + setSelectedVertexIndex(edgeIndex + 1); + updatePolygonMask(mask, nextPoints, selectedPolygonIndex); + }; + + const handleBooleanOperation = async () => { + if (!frame || booleanSelectedMasks.length < 2) return; + const primary = booleanSelectedMasks[0]; + const primaryGeometry = maskToMultiPolygon(primary); + if (!primaryGeometry) return; + + const clipGeometries = booleanSelectedMasks + .slice(1) + .map(maskToMultiPolygon) + .filter((geometry): geometry is MultiPolygon => Boolean(geometry)); + if (clipGeometries.length === 0) return; + + const resultGeometry = effectiveTool === 'area_merge' + ? polygonClipping.union(primaryGeometry, ...clipGeometries) + : polygonClipping.difference(primaryGeometry, ...clipGeometries); + const resultSegmentation = multiPolygonToSegmentation(resultGeometry); + + if (resultSegmentation.length === 0) { + const deleteIds = primary.annotationId ? [primary.annotationId] : []; + setMasks(masks.filter((mask) => mask.id !== primary.id)); + if (deleteIds.length > 0) await onDeleteMaskAnnotations?.(deleteIds); + setSelectedMaskId(null); + setSelectedMaskIds([]); + setSelectedVertexIndex(null); + return; + } + + const nextPrimary = updateMaskFromSegmentation(primary, resultSegmentation); + const secondaryIds = effectiveTool === 'area_merge' + ? new Set(booleanSelectedMasks.slice(1).map((mask) => mask.id)) + : new Set(); + const secondaryAnnotationIds = effectiveTool === 'area_merge' + ? booleanSelectedMasks + .slice(1) + .map((mask) => mask.annotationId) + .filter((annotationId): annotationId is string => Boolean(annotationId)) + : []; + + setMasks(masks + .filter((mask) => !secondaryIds.has(mask.id)) + .map((mask) => (mask.id === primary.id ? nextPrimary : mask))); + if (secondaryAnnotationIds.length > 0) await onDeleteMaskAnnotations?.(secondaryAnnotationIds); + setSelectedMaskId(primary.id); + setSelectedMaskIds([primary.id]); + setSelectedVertexIndex(null); + }; + return (
{isInferencing && ( @@ -257,13 +759,18 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) {/* AI Returned Masks */} {frameMasks.map((mask) => ( - - + + {(mask.segmentation && mask.segmentation.length > 0 ? mask.segmentation : [undefined]).map((_, polygonIndex) => ( + handleMaskSelect(mask, event, polygonIndex)} + onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)} + /> + ))} ))} @@ -281,6 +788,86 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) /> )} + {/* Manual shape preview */} + {manualPreviewPath && ( + + )} + + {polygonPoints.map((point, index) => ( + + ))} + + {/* Imported GT seed points / editable point regions */} + {frameMasks.flatMap((mask) => (mask.points || []).map(([x, y], index) => ( + + handleSeedPointDragEnd(mask, index, event)} + /> + + + )))} + + {/* Polygon edge insertion handles */} + {selectedMask && selectedMaskPoints.map((point, index) => { + const next = selectedMaskPoints[(index + 1) % selectedMaskPoints.length]; + if (!next) return null; + return ( + handleEdgeInsert(selectedMask, index, event)} + onTap={(event: any) => handleEdgeInsert(selectedMask, index, event)} + /> + ); + })} + + {/* Polygon vertex editor */} + {selectedMask && selectedMaskPoints.map((point, index) => ( + { + event.cancelBubble = true; + setSelectedVertexIndex(index); + }} + onTap={(event: any) => { + event.cancelBubble = true; + setSelectedVertexIndex(index); + }} + onDragEnd={(event: any) => handleVertexDragEnd(selectedMask, index, event)} + /> + ))} + {/* AI Prompts Point Regions */} {points.map((p, i) => ( @@ -313,6 +900,14 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) {frameMasks.length > 0 && (
+ {BOOLEAN_TOOLS.has(effectiveTool) && booleanSelectedMasks.length >= 2 && ( + + )} {activeClass && ( + )} + {canRetry(task) && ( + + )} + {task.task_id && ( + + )} +
))} {!isLoading && tasks.length === 0 && ( @@ -253,6 +394,46 @@ export function Dashboard() {
+ + {selectedTask && ( +
+
+
+
+

任务详情 #{selectedTask.id}

+

{selectedTask.message || selectedTask.status}

+
+ +
+
+
状态: {selectedTask.status}
+
进度: {selectedTask.progress}%
+
项目 ID: {selectedTask.project_id ?? '-'}
+
Celery ID: {selectedTask.celery_task_id || '-'}
+
创建: {selectedTask.created_at}
+
结束: {selectedTask.finished_at || '-'}
+
+ {selectedTask.error && ( +
+ {selectedTask.error} +
+ )} +
+
+                {JSON.stringify(selectedTask.payload || {}, null, 2)}
+              
+
+                {JSON.stringify(selectedTask.result || {}, null, 2)}
+              
+
+
+
+ )} ); } diff --git a/src/components/ToolsPalette.test.tsx b/src/components/ToolsPalette.test.tsx index 25a32ec..0808664 100644 --- a/src/components/ToolsPalette.test.tsx +++ b/src/components/ToolsPalette.test.tsx @@ -3,18 +3,31 @@ import { describe, expect, it, vi } from 'vitest'; import { ToolsPalette } from './ToolsPalette'; describe('ToolsPalette', () => { - it('switches tools and exposes UI-only placeholder buttons', () => { + it('switches tools and dispatches undo/redo actions when available', () => { const setActiveTool = vi.fn(); + const onUndo = vi.fn(); + const onRedo = vi.fn(); - render(); + render( + , + ); fireEvent.click(screen.getByTitle('创建多边形 (P)')); fireEvent.click(screen.getByTitle('正向选点 (SAM)')); + fireEvent.click(screen.getByTitle('撤销操作 (Ctrl+Z)')); + fireEvent.click(screen.getByTitle('重做操作 (Ctrl+Shift+Z)')); expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon'); expect(setActiveTool).toHaveBeenNthCalledWith(2, 'point_pos'); - expect(screen.getByTitle('撤销操作 (Ctrl+Z)')).toBeInTheDocument(); - expect(screen.getByTitle('重做操作 (Ctrl+Shift+Z)')).toBeInTheDocument(); + expect(onUndo).toHaveBeenCalled(); + expect(onRedo).toHaveBeenCalled(); }); it('switches to SAM trigger and calls the AI navigation hook', () => { diff --git a/src/components/ToolsPalette.tsx b/src/components/ToolsPalette.tsx index 772520e..4aa0e35 100644 --- a/src/components/ToolsPalette.tsx +++ b/src/components/ToolsPalette.tsx @@ -6,9 +6,21 @@ interface ToolsPaletteProps { activeTool: string; setActiveTool: (tool: string) => void; onTriggerAI?: () => void; + onUndo?: () => void; + onRedo?: () => void; + canUndo?: boolean; + canRedo?: boolean; } -export function ToolsPalette({ activeTool, setActiveTool, onTriggerAI }: ToolsPaletteProps) { +export function ToolsPalette({ + activeTool, + setActiveTool, + onTriggerAI, + onUndo, + onRedo, + canUndo = false, + canRedo = false, +}: ToolsPaletteProps) { const tools = [ { id: 'move', icon: MousePointer2, label: '拖拽 / 选择 (V)' }, { id: 'create_polygon', icon: Hexagon, label: '创建多边形 (P)' }, @@ -91,10 +103,20 @@ export function ToolsPalette({ activeTool, setActiveTool, onTriggerAI }: ToolsPa
- - diff --git a/src/components/VideoWorkspace.test.tsx b/src/components/VideoWorkspace.test.tsx index 15d8848..a1469d7 100644 --- a/src/components/VideoWorkspace.test.tsx +++ b/src/components/VideoWorkspace.test.tsx @@ -14,6 +14,8 @@ const apiMock = vi.hoisted(() => ({ updateAnnotation: vi.fn(), deleteAnnotation: vi.fn(), exportCoco: vi.fn(), + exportMasks: vi.fn(), + importGtMask: vi.fn(), annotationToMask: vi.fn(), buildAnnotationPayload: vi.fn(), getAiModelStatus: vi.fn(), @@ -29,6 +31,8 @@ vi.mock('../lib/api', () => ({ updateAnnotation: apiMock.updateAnnotation, deleteAnnotation: apiMock.deleteAnnotation, exportCoco: apiMock.exportCoco, + exportMasks: apiMock.exportMasks, + importGtMask: apiMock.importGtMask, annotationToMask: apiMock.annotationToMask, buildAnnotationPayload: apiMock.buildAnnotationPayload, getAiModelStatus: apiMock.getAiModelStatus, @@ -256,4 +260,64 @@ describe('VideoWorkspace', () => { await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled()); expect(apiMock.exportCoco).toHaveBeenCalledWith('1'); }); + + it('auto-saves pending masks before exporting PNG masks', async () => { + apiMock.getProjectFrames.mockResolvedValueOnce([ + { id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 }, + ]); + apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } }); + apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 }); + apiMock.exportMasks.mockResolvedValueOnce(new Blob(['zip'], { type: 'application/zip' })); + + render(); + await waitFor(() => expect(useStore.getState().frames).toHaveLength(1)); + act(() => { + useStore.setState({ + masks: [{ + id: 'mask-1', + frameId: '10', + pathData: 'M 0 0 Z', + label: 'AI Mask', + color: '#06b6d4', + segmentation: [[0, 0, 10, 0, 10, 10]], + }], + }); + }); + + fireEvent.click(screen.getByRole('button', { name: '导出 PNG Mask ZIP' })); + + await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled()); + expect(apiMock.exportMasks).toHaveBeenCalledWith('1'); + }); + + it('imports a GT mask for the current frame and hydrates saved annotations', async () => { + apiMock.getProjectFrames.mockResolvedValueOnce([ + { id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 }, + ]); + apiMock.importGtMask.mockResolvedValueOnce([{ id: 88, frame_id: 10 }]); + apiMock.getProjectAnnotations + .mockResolvedValueOnce([]) + .mockResolvedValueOnce([{ id: 88, frame_id: 10 }]); + apiMock.annotationToMask.mockReturnValueOnce({ + id: 'annotation-88', + annotationId: '88', + frameId: '10', + saved: true, + pathData: 'M 0 0 Z', + label: 'GT Mask', + color: '#22c55e', + }); + + render(); + await waitFor(() => expect(useStore.getState().frames).toHaveLength(1)); + + const fileInput = document.querySelector('input[type="file"]') as HTMLInputElement; + const file = new File(['mask'], 'mask.png', { type: 'image/png' }); + fireEvent.change(fileInput, { target: { files: [file] } }); + + await waitFor(() => expect(apiMock.importGtMask).toHaveBeenCalledWith(file, '1', '10')); + await waitFor(() => expect(useStore.getState().masks).toEqual([ + expect.objectContaining({ id: 'annotation-88', label: 'GT Mask' }), + ])); + }); }); diff --git a/src/components/VideoWorkspace.tsx b/src/components/VideoWorkspace.tsx index eaf158f..4132193 100644 --- a/src/components/VideoWorkspace.tsx +++ b/src/components/VideoWorkspace.tsx @@ -5,10 +5,12 @@ import { buildAnnotationPayload, deleteAnnotation, exportCoco, + exportMasks, getProjectAnnotations, getProjectFrames, getTask, getTemplates, + importGtMask, parseMedia, saveAnnotation, updateAnnotation, @@ -25,18 +27,24 @@ function sleep(ms: number) { } export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) { + const gtMaskInputRef = React.useRef(null); const activeTool = useStore((state) => state.activeTool); const setActiveTool = useStore((state) => state.setActiveTool); const currentProject = useStore((state) => state.currentProject); const frames = useStore((state) => state.frames); const currentFrameIndex = useStore((state) => state.currentFrameIndex); const masks = useStore((state) => state.masks); + const maskHistory = useStore((state) => state.maskHistory); + const maskFuture = useStore((state) => state.maskFuture); const activeTemplateId = useStore((state) => state.activeTemplateId); const setFrames = useStore((state) => state.setFrames); const setCurrentFrame = useStore((state) => state.setCurrentFrame); const setMasks = useStore((state) => state.setMasks); + const undoMasks = useStore((state) => state.undoMasks); + const redoMasks = useStore((state) => state.redoMasks); const [isSaving, setIsSaving] = useState(false); const [isExporting, setIsExporting] = useState(false); + const [isImportingGt, setIsImportingGt] = useState(false); const [statusMessage, setStatusMessage] = useState(''); const hydrateSavedAnnotations = useCallback(async (projectId: string, projectFrames: Frame[]) => { @@ -216,6 +224,18 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void } }, [currentFrame, masks, setMasks]); + const handleDeleteMaskAnnotations = useCallback(async (annotationIds: string[]) => { + if (annotationIds.length === 0) return; + try { + await Promise.all(annotationIds.map((annotationId) => deleteAnnotation(annotationId))); + setStatusMessage(`已删除 ${annotationIds.length} 个被合并标注`); + } catch (err) { + console.error('Delete merged annotations failed:', err); + setStatusMessage('合并后删除原标注失败,请检查后端服务'); + throw err; + } + }, []); + const handleSave = async () => { try { await savePendingAnnotations(); @@ -248,6 +268,52 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void } }; + const downloadBlob = (blob: Blob, filename: string) => { + const url = URL.createObjectURL(blob); + const link = document.createElement('a'); + link.href = url; + link.download = filename; + document.body.appendChild(link); + link.click(); + link.remove(); + URL.revokeObjectURL(url); + }; + + const handleExportMasks = async () => { + if (!currentProject?.id) return; + setIsExporting(true); + setStatusMessage('正在准备导出语义 Mask ZIP...'); + try { + await savePendingAnnotations({ silent: true }); + const blob = await exportMasks(currentProject.id); + downloadBlob(blob, `project_${currentProject.id}_masks.zip`); + setStatusMessage('PNG Mask ZIP 已导出'); + } catch (err) { + console.error('Mask export failed:', err); + setStatusMessage('Mask 导出失败,请检查后端服务'); + } finally { + setIsExporting(false); + } + }; + + const handleImportGtMask = async (event: React.ChangeEvent) => { + const file = event.target.files?.[0]; + if (!file || !currentProject?.id || !currentFrame?.id) return; + setIsImportingGt(true); + setStatusMessage('正在导入 GT Mask...'); + try { + const imported = await importGtMask(file, currentProject.id, currentFrame.id); + await hydrateSavedAnnotations(currentProject.id, frames); + setStatusMessage(`已导入 ${imported.length} 个 GT 区域`); + } catch (err) { + console.error('GT mask import failed:', err); + setStatusMessage('GT Mask 导入失败,请检查文件或后端服务'); + } finally { + setIsImportingGt(false); + event.target.value = ''; + } + }; + return (
{/* Top Header / Status bar */} @@ -264,6 +330,27 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void )} + + +