feat: 建立 SAM2 标注闭环基线

- 打通工作区真实标注闭环:支持手工多边形、矩形、圆形、点区域和线段生成 mask,并可保存、回显、更新和删除后端 annotation。

- 增强 polygon 编辑器:支持顶点拖动、顶点删除、边中点插入、多 polygon 子区域选择编辑,以及区域合并和区域去除。

- 接入 GT mask 导入:后端支持二值/多类别 mask 拆分、contour 转 polygon、distance transform seed point,前端支持导入、回显和 seed point 拖动编辑。

- 完善导出能力:COCO JSON 导出对齐前端,PNG mask ZIP 同时包含单标注 mask、按 zIndex 融合的 semantic_frame 和 semantic_classes.json。

- 打通异步任务管理:新增任务取消、重试、失败详情接口与 Dashboard 控件,worker 支持取消状态检查并通过 Redis/WebSocket 推送 cancelled 事件。

- 对接 Dashboard 后端数据:概览统计、解析队列和实时流转记录从 FastAPI 聚合接口与 WebSocket 更新。

- 增强 AI 推理参数:前端发送 crop_to_prompt、auto_filter_background 和 min_score,后端支持点/框 prompt 局部裁剪推理、结果回映射和负向点/低分过滤。

- 接入 SAM3 基础设施:新增独立 Python 3.12 sam3 环境安装脚本、外部 worker helper、后端桥接和真实 Python/CUDA/包/HF checkpoint access 状态检测。

- 保留 SAM3 授权边界:当前官方 facebook/sam3 gated 权重未授权时状态接口会返回不可用,不伪装成可推理。

- 增强前端状态管理:新增 mask undo/redo 历史栈、AI 模型选择状态、保存状态 dirty/draft/saved 流转和项目状态归一化。

- 更新前端 API 封装:补充 annotation CRUD、GT mask import、mask ZIP export、task cancel/retry/detail、AI runtime status 和 prediction options。

- 更新 UI 控件:ToolsPalette、AISegmentation、VideoWorkspace 和 CanvasArea 接入真实操作、导入导出、撤销重做、任务控制和模型状态。

- 新增 polygon-clipping 依赖,用于前端区域 union/difference 几何运算。

- 完善后端 schemas/status/progress:补充 AI 模型外部状态字段、任务 cancelled 状态和进度事件 payload。

- 补充测试覆盖:新增后端任务控制、SAM3 桥接、GT mask、导出融合、AI options 测试;补充前端 Canvas、Dashboard、VideoWorkspace、ToolsPalette、API 和 store 测试。

- 更新 README、AGENTS 和 doc 文档:冻结当前需求/设计/测试计划,标注真实功能、剩余 Mock、SAM3 授权边界和后续实施顺序。
This commit is contained in:
2026-05-01 15:26:25 +08:00
parent f020ff3b4f
commit 689a9ba283
48 changed files with 3280 additions and 176 deletions

View File

@@ -6,7 +6,7 @@
## 项目概述 ## 项目概述
本项目是一个**语义分割系统**Semantic Segmentation System当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、服务器端拆帧、交互式 Canvas 标注、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出。 本项目是一个**语义分割系统**Semantic Segmentation System当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、服务器端拆帧、交互式 Canvas 标注、GT mask 导入、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出。
- **项目名称**: `react-example``package.json` 中的 `name` - **项目名称**: `react-example``package.json` 中的 `name`
- **前端入口**: `src/main.tsx``src/App.tsx` - **前端入口**: `src/main.tsx``src/App.tsx`
@@ -30,6 +30,7 @@
| 前端请求 | Axios`src/lib/api.ts` | | 前端请求 | Axios`src/lib/api.ts` |
| 实时通信 | WebSocket 客户端(`src/lib/websocket.ts` | | 实时通信 | WebSocket 客户端(`src/lib/websocket.ts` |
| Canvas 渲染 | Konva + react-konva + use-image | | Canvas 渲染 | Konva + react-konva + use-image |
| 几何布尔运算 | polygon-clipping |
| 图标库 | lucide-react | | 图标库 | lucide-react |
| 动画依赖 | motion`package.json` 中声明) | | 动画依赖 | motion`package.json` 中声明) |
| AI SDK 依赖 | `@google/genai`(在 `package.json` 中声明;当前业务源码未直接调用) | | AI SDK 依赖 | `@google/genai`(在 `package.json` 中声明;当前业务源码未直接调用) |
@@ -38,9 +39,9 @@
| 缓存 / 队列 Broker | Redis | | 缓存 / 队列 Broker | Redis |
| 后台任务 | Celery worker | | 后台任务 | Celery worker |
| 对象存储 | MinIO | | 对象存储 | MinIO |
| AI 推理 | SAM 2 / SAM 3 可选模型 + PyTorch`GET /api/ai/models/status` 返回真实 GPU/模型状态 | | AI 推理 | SAM 2 / SAM 3 可选模型 + PyTorchSAM 3 通过独立 Python 3.12 conda 环境桥接;`GET /api/ai/models/status` 返回真实 GPU/模型/HF 权重访问状态 |
| 视频 / 影像处理 | FFmpeg / OpenCV / pydicom | | 视频 / 影像处理 | FFmpeg / OpenCV / pydicom |
| 运行时 | Node.js ES ModulesPython 3.11 后端环境 | | 运行时 | Node.js ES ModulesPython 3.11 后端环境;可选 `sam3` Python 3.12 conda 环境 |
--- ---
@@ -70,6 +71,7 @@ Seg_Server/
│ ├── celery_app.py # Celery app 配置 │ ├── celery_app.py # Celery app 配置
│ ├── worker_tasks.py # Celery 任务入口 │ ├── worker_tasks.py # Celery 任务入口
│ ├── download_sam2.py # SAM 2 权重下载脚本 │ ├── download_sam2.py # SAM 2 权重下载脚本
│ ├── setup_sam3_env.sh # SAM 3 独立 Python 3.12 环境安装脚本
│ ├── requirements.txt # Python 依赖 │ ├── requirements.txt # Python 依赖
│ ├── routers/ │ ├── routers/
│ │ ├── auth.py # /api/auth/login │ │ ├── auth.py # /api/auth/login
@@ -81,7 +83,8 @@ Seg_Server/
│ └── services/ │ └── services/
│ ├── frame_parser.py # FFmpeg/OpenCV 拆帧、pydicom 读片、帧上传 │ ├── frame_parser.py # FFmpeg/OpenCV 拆帧、pydicom 读片、帧上传
│ ├── sam2_engine.py # SAM 2 懒加载推理封装和 fallback │ ├── sam2_engine.py # SAM 2 懒加载推理封装和 fallback
│ ├── sam3_engine.py # SAM 3 状态检测与文本语义推理适配器 │ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接与文本语义推理适配器
│ ├── sam3_external_worker.py # 独立 sam3 conda 环境中执行的状态/推理 helper
│ └── sam_registry.py # SAM 模型选择、GPU 状态与推理分发 │ └── sam_registry.py # SAM 模型选择、GPU 状态与推理分发
└── src/ # React 前端 └── src/ # React 前端
├── main.tsx # React StrictMode 挂载 ├── main.tsx # React StrictMode 挂载
@@ -188,10 +191,13 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
- `POST /api/media/parse` - `POST /api/media/parse`
- `GET /api/tasks` - `GET /api/tasks`
- `GET /api/tasks/{task_id}` - `GET /api/tasks/{task_id}`
- `POST /api/tasks/{task_id}/cancel`
- `POST /api/tasks/{task_id}/retry`
- `POST /api/ai/predict` - `POST /api/ai/predict`
- `GET /api/ai/models/status` - `GET /api/ai/models/status`
- `POST /api/ai/auto` - `POST /api/ai/auto`
- `POST /api/ai/annotate` - `POST /api/ai/annotate`
- `POST /api/ai/import-gt-mask`
- `GET /api/ai/annotations` - `GET /api/ai/annotations`
- `PATCH/DELETE /api/ai/annotations/{annotation_id}` - `PATCH/DELETE /api/ai/annotations/{annotation_id}`
- `GET /api/dashboard/overview` - `GET /api/dashboard/overview`
@@ -216,9 +222,11 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
4. 拆帧入队:前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery。 4. 拆帧入队:前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery。
5. worker 执行Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallbackDICOM 使用 pydicom并持续更新任务进度。 5. worker 执行Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallbackDICOM 使用 pydicom并持续更新任务进度。
6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames``CanvasArea.tsx``FrameTimeline.tsx` 显示当前帧与时间轴缩略图。 6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames``CanvasArea.tsx``FrameTimeline.tsx` 显示当前帧与时间轴缩略图。
7. AI 分割:前端工具包括正向点、反向点和框选;后端 `ai.py` 期望按 `image_id``prompt_type``prompt_data``model` 调用 SAM registry。SAM 2 支持点/框/自动分割SAM 3 入口支持文本语义推理,运行时不满足官方要求时会在状态接口中标为不可用 7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask点击 mask 可拖动/删除 polygon 顶点、通过边中点插入新顶点,并能选择编辑多 polygon mask 的单个子区域;区域合并/去除使用 `polygon-clipping` 做 union/differenceZustand 维护 `maskHistory/maskFuture` 支持撤销/重做
8. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index`OntologyInspector.tsx` 在工作区显示当前模板分类树 8. AI 分割:前端工具包括正向点、反向点和框选;后端 `ai.py` 期望按 `image_id``prompt_type``prompt_data``model` 和可选 `options` 调用 SAM registry。SAM 2 支持点/框/自动分割;`options.crop_to_prompt` 可对点/框 prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果SAM 3 入口支持文本语义推理,主后端会通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境;如果 Python/CUDA/包/Hugging Face gated 权重访问任一条件不满足,会在状态接口中标为不可用
9. 导出:后端支持 COCO JSON 和 PNG mask ZIP 导出 9. GT 导入:工作区“导入 GT Mask”调用 `/api/ai/import-gt-mask`;后端按非零像素值和连通域生成 polygon 标注,并用 distance transform 生成 seed point前端回显 seed point拖动后可归档更新
10. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index`OntologyInspector.tsx` 在工作区显示当前模板分类树。
11. 导出:后端支持 COCO JSON 和 PNG mask ZIP 导出PNG ZIP 包含单标注 mask、按 zIndex 融合的语义 mask 和 `semantic_classes.json`
--- ---
@@ -226,12 +234,16 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
- `src/lib/config.ts` 会优先读取 `VITE_API_BASE_URL``VITE_WS_PROGRESS_URL`;未配置时按当前浏览器 hostname 推导后端 `:8000` 地址。 - `src/lib/config.ts` 会优先读取 `VITE_API_BASE_URL``VITE_WS_PROGRESS_URL`;未配置时按当前浏览器 hostname 推导后端 `:8000` 地址。
- 前端 `predictMask()` 已按后端 `PredictRequest` 发送 `image_id``prompt_type``prompt_data``model`,并将后端 `polygons` 转成 Konva 可渲染的 `pathData` - 前端 `predictMask()` 已按后端 `PredictRequest` 发送 `image_id``prompt_type``prompt_data``model`,并将后端 `polygons` 转成 Konva 可渲染的 `pathData`
- 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`工作区“导出 JSON 标注集”按钮已绑定下载流程,导出前会先保存当前待归档 mask - 手工绘制工具会生成可保存的 `Mask.segmentation`;撤销/重做通过 `maskHistory/maskFuture` 工作。
- Polygon 顶点编辑会重算 `pathData/segmentation/bbox/area`;已保存 mask 进入 dirty 状态后复用归档 PATCH 链路。
- 区域合并/去除会重算主 mask 的几何;合并已保存的次级 mask 时会通过工作区回调删除对应后端标注。
- 前端 `importGtMask()` 已对齐后端 `/api/ai/import-gt-mask`;工作区“导入 GT Mask”会导入后端生成的多类别标注和 seed point 并回显。
- 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;前端 `exportMasks()` 已对齐后端 `/api/export/{project_id}/masks`;工作区导出按钮会先保存当前待归档 mask。
- 工作区“结构化归档保存”按钮已接入 `POST /api/ai/annotate``PATCH /api/ai/annotations/{id}`;加载工作区时会通过 `GET /api/ai/annotations` 回显已保存标注。 - 工作区“结构化归档保存”按钮已接入 `POST /api/ai/annotate``PATCH /api/ai/annotations/{id}`;加载工作区时会通过 `GET /api/ai/annotations` 回显已保存标注。
- 工作区“清空遮罩”会调用 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。 - 工作区“清空遮罩”会调用 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。
- 项目状态已统一为 `pending``parsing``ready``error`;前端 `src/lib/api.ts` 会兼容归一化旧库中可能存在的 `Ready``Parsing``Error` - 项目状态已统一为 `pending``parsing``ready``error`;前端 `src/lib/api.ts` 会兼容归一化旧库中可能存在的 `Ready``Parsing``Error`
- `server.ts` 仍有旧版 `/api/login``/api/projects``/api/templates` mock当前前端真实 API 调用主要走 FastAPI 的 `/api/auth/*``/api/projects``/api/templates` 等接口。 - `server.ts` 仍有旧版 `/api/login``/api/projects``/api/templates` mock当前前端真实 API 调用主要走 FastAPI 的 `/api/auth/*``/api/projects``/api/templates` 等接口。
- `Dashboard.tsx` 初始统计、队列和活动日志来自 `GET /api/dashboard/overview`;解析队列来自 `processing_tasks`Celery worker 通过 Redis pub/sub 的 `seg:progress` 频道推送细粒度进度,再由 FastAPI 广播到 `/ws/progress` - `Dashboard.tsx` 初始统计、队列和活动日志来自 `GET /api/dashboard/overview`;解析队列来自 `processing_tasks`支持取消 queued/running 任务、重试 failed/cancelled 任务和查看失败详情。Celery worker 通过 Redis pub/sub 的 `seg:progress` 频道推送细粒度进度,再由 FastAPI 广播到 `/ws/progress`
--- ---

View File

@@ -6,7 +6,7 @@
> 基于 React + FastAPI + 可选 SAM 2 / SAM 3 的全栈交互式图像/视频语义分割与标注平台。 > 基于 React + FastAPI + 可选 SAM 2 / SAM 3 的全栈交互式图像/视频语义分割与标注平台。
> >
> 支持本地多媒体资产上传、服务器端按帧解析、交互式 Canvas 标注、模板分类管理和标注数据结构化导出;工作区点/框 AI 推理默认走 SAM 2语义文本可选择 SAM 3前端会显示真实 GPU/模型状态。 > 支持本地多媒体资产上传、服务器端按帧解析、交互式 Canvas 标注、GT mask 导入、模板分类管理和标注数据结构化导出;工作区点/框 AI 推理默认走 SAM 2语义文本可选择 SAM 3前端会显示真实 GPU/模型状态。
--- ---
@@ -14,10 +14,11 @@
- **多媒体资产管理** — 支持视频MP4/AVI/MOV和 DICOM 医学影像的上传、存储与解析 - **多媒体资产管理** — 支持视频MP4/AVI/MOV和 DICOM 医学影像的上传、存储与解析
- **AI 智能分割引擎** — 后端提供 SAM 2 / SAM 3 模型选择SAM 2 支持点分割point、框分割box和自动分割autoSAM 3 入口支持文本语义提示并按真实运行环境显示可用性 - **AI 智能分割引擎** — 后端提供 SAM 2 / SAM 3 模型选择SAM 2 支持点分割point、框分割box和自动分割autoSAM 3 入口支持文本语义提示并按真实运行环境显示可用性
- **交互式画布标注** — 基于 Konva 的高性能 Canvas支持缩放/平移/选点/框选,实时渲染 Mask 遮罩 - **交互式画布标注** — 基于 Konva 的高性能 Canvas支持缩放/平移/手工多边形/矩形/圆/点/线、polygon 顶点拖动/删除、区域合并/去除、选点/框选、撤销/重做,实时渲染 Mask 遮罩
- **GT Mask 导入** — 工作区可导入 GT mask 图片,后端按非零像素值和连通域生成 polygon 标注并用 distance transform 写入 seed point前端可回显和拖动 seed point
- **本体字典管理** — 可配置的分类体系、颜色映射、图层优先级z-index - **本体字典管理** — 可配置的分类体系、颜色映射、图层优先级z-index
- **项目工作区** — 项目创建、帧浏览、多图层标注、进度追踪 - **项目工作区** — 项目创建、帧浏览、多图层标注、进度追踪
- **数据导出** — 支持 COCO JSON 格式和 PNG Mask 批量导出 - **数据导出** — 支持 COCO JSON 格式和 PNG Mask 批量导出PNG ZIP 包含单标注 mask、按 z-index 融合的语义 mask 和类别映射
--- ---
@@ -38,7 +39,7 @@
│ ├── /api/projects 项目 & 视频帧 CRUD │ │ ├── /api/projects 项目 & 视频帧 CRUD │
│ ├── /api/templates 本体字典(分类/颜色/z-index │ ├── /api/templates 本体字典(分类/颜色/z-index
│ ├── /api/media 文件上传 & 异步拆帧任务创建 │ │ ├── /api/media 文件上传 & 异步拆帧任务创建 │
│ ├── /api/tasks Celery 后台任务状态 │ ├── /api/tasks Celery 后台任务状态/取消/重试/详情
│ ├── /api/ai SAM 2 / SAM 3 推理与模型状态 │ │ ├── /api/ai SAM 2 / SAM 3 推理与模型状态 │
│ └── /api/export COCO JSON / PNG Masks 导出 │ │ └── /api/export COCO JSON / PNG Masks 导出 │
└──────────────────────────┬──────────────────────────────────┘ └──────────────────────────┬──────────────────────────────────┘
@@ -62,6 +63,7 @@
| 样式方案 | TailwindCSS + 自定义深色主题 | v4 | | 样式方案 | TailwindCSS + 自定义深色主题 | v4 |
| 状态管理 | Zustand | - | | 状态管理 | Zustand | - |
| Canvas 渲染 | Konva + react-konva | - | | Canvas 渲染 | Konva + react-konva | - |
| 几何布尔运算 | polygon-clipping | 0.15+ |
| HTTP 客户端 | Axios | - | | HTTP 客户端 | Axios | - |
| 后端框架 | FastAPI | v0.136+ | | 后端框架 | FastAPI | v0.136+ |
| 数据库 ORM | SQLAlchemy依赖中包含 Alembic | 2.0+ | | 数据库 ORM | SQLAlchemy依赖中包含 Alembic | 2.0+ |
@@ -92,6 +94,7 @@ Seg_Server/
│ ├── celery_app.py # Celery app 配置 │ ├── celery_app.py # Celery app 配置
│ ├── worker_tasks.py # Celery 任务入口 │ ├── worker_tasks.py # Celery 任务入口
│ ├── download_sam2.py # SAM 2 模型权重自动下载脚本 │ ├── download_sam2.py # SAM 2 模型权重自动下载脚本
│ ├── setup_sam3_env.sh # SAM 3 独立 Python 3.12 环境安装脚本
│ ├── requirements.txt # Python 依赖 │ ├── requirements.txt # Python 依赖
│ ├── routers/ # API 路由 │ ├── routers/ # API 路由
│ │ ├── auth.py # 登录认证 │ │ ├── auth.py # 登录认证
@@ -102,7 +105,8 @@ Seg_Server/
│ │ └── export.py # 数据导出 │ │ └── export.py # 数据导出
│ └── services/ # 业务服务 │ └── services/ # 业务服务
│ ├── sam2_engine.py # SAM 2 推理引擎(懒加载 + stub降级 │ ├── sam2_engine.py # SAM 2 推理引擎(懒加载 + stub降级
│ ├── sam3_engine.py # SAM 3 状态检测与文本语义推理适配器 │ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接与文本语义推理适配器
│ ├── sam3_external_worker.py # 独立 sam3 conda 环境中执行的状态/推理 helper
│ ├── sam_registry.py # SAM 模型选择、GPU 状态与推理分发 │ ├── sam_registry.py # SAM 模型选择、GPU 状态与推理分发
│ └── frame_parser.py # FFmpeg 拆帧 / pydicom 读片 │ └── frame_parser.py # FFmpeg 拆帧 / pydicom 读片
├── src/ # React 前端 ├── src/ # React 前端
@@ -117,10 +121,10 @@ Seg_Server/
│ └── components/ # 组件(扁平化目录) │ └── components/ # 组件(扁平化目录)
│ ├── Login.tsx # 登录页 │ ├── Login.tsx # 登录页
│ ├── Sidebar.tsx # 左侧导航栏 │ ├── Sidebar.tsx # 左侧导航栏
│ ├── Dashboard.tsx # 总体概况仪表盘(解析队列) │ ├── Dashboard.tsx # 总体概况仪表盘(解析队列/任务控制
│ ├── ProjectLibrary.tsx # 项目库列表 │ ├── ProjectLibrary.tsx # 项目库列表
│ ├── VideoWorkspace.tsx # 核心分割工作区布局 │ ├── VideoWorkspace.tsx # 核心分割工作区布局
│ ├── CanvasArea.tsx # Konva 画布(缩放/平移/选点/Mask渲染 │ ├── CanvasArea.tsx # Konva 画布(缩放/平移/手工绘制/选点/Mask渲染
│ ├── ToolsPalette.tsx # 左侧工具栏 │ ├── ToolsPalette.tsx # 左侧工具栏
│ ├── OntologyInspector.tsx # 右侧本体/属性检查面板 │ ├── OntologyInspector.tsx # 右侧本体/属性检查面板
│ ├── FrameTimeline.tsx # 底部时间轴 │ ├── FrameTimeline.tsx # 底部时间轴
@@ -161,7 +165,7 @@ Seg_Server/
- **GPU**: NVIDIA GPU推荐 RTX 4090 或同等算力),用于 SAM 推理SAM 3 官方要求 Python 3.12+、PyTorch 2.7+ 和 CUDA 12.6+ 环境 - **GPU**: NVIDIA GPU推荐 RTX 4090 或同等算力),用于 SAM 推理SAM 3 官方要求 Python 3.12+、PyTorch 2.7+ 和 CUDA 12.6+ 环境
- **CUDA**: 12.x / 13.x - **CUDA**: 12.x / 13.x
- **Node.js**: 22.x+ - **Node.js**: 22.x+
- **Python**: 3.11(通过 Miniconda/Anaconda 管理) - **Python**: 主后端使用 3.11(通过 Miniconda/Anaconda 管理)SAM 3 使用独立 `sam3` Python 3.12 conda 环境
### 安装系统级依赖 ### 安装系统级依赖
@@ -243,7 +247,22 @@ python download_sam2.py
> **注意**:当前系统磁盘紧张时,建议仅保留 `sam2_hiera_tiny.pt`,删除其他模型以释放空间。 > **注意**:当前系统磁盘紧张时,建议仅保留 `sam2_hiera_tiny.pt`,删除其他模型以释放空间。
### 步骤 5: 配置环境变量 ### 步骤 5: 可选安装 SAM 3 环境
当前后端不会把 SAM 3 直接装进 `seg_server`,而是通过独立 `sam3` conda 环境执行 `backend/services/sam3_external_worker.py`。这样可以保留现有 Python 3.11 / SAM 2 环境。
```bash
cd ~/Desktop/Seg_Server
./backend/setup_sam3_env.sh
# 首次使用官方权重前,需要先在 Hugging Face 申请 facebook/sam3 访问权限并登录
conda activate sam3
huggingface-cli login
```
官方 `facebook/sam3` 权重约 3.45 GB当前没有类似 SAM 2 `tiny/small/base/large` 的官方小权重梯度;`facebook/sam3.1` 约 3.5 GB主要面向新的视频 multiplex checkpoint。未获得 gated model 授权时,`GET /api/ai/models/status` 会把 SAM 3 标为不可用并说明 checkpoint access 不满足。
### 步骤 6: 配置环境变量
后端通过 `backend/config.py` 中的 Pydantic Settings 读取 `backend/.env`。如需覆盖默认值,请编辑以下文件: 后端通过 `backend/config.py` 中的 Pydantic Settings 读取 `backend/.env`。如需覆盖默认值,请编辑以下文件:
@@ -258,7 +277,10 @@ minio_secure=false
sam_model_path=/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt sam_model_path=/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt
sam_model_config=configs/sam2/sam2_hiera_t.yaml sam_model_config=configs/sam2/sam2_hiera_t.yaml
sam_default_model=sam2 sam_default_model=sam2
sam3_model_version=sam3.1 sam3_model_version=sam3
sam3_external_enabled=true
sam3_external_python=/home/wkmgc/miniconda3/envs/sam3/bin/python
sam3_timeout_seconds=300
cors_origins=["http://localhost:3000","http://192.168.3.11:3000"] cors_origins=["http://localhost:3000","http://192.168.3.11:3000"]
``` ```
@@ -271,7 +293,7 @@ VITE_WS_PROGRESS_URL=ws://192.168.3.11:8000/ws/progress
如果未配置 `VITE_API_BASE_URL`,前端会按当前浏览器 hostname 推导 `http://<host>:8000` 如果未配置 `VITE_API_BASE_URL`,前端会按当前浏览器 hostname 推导 `http://<host>:8000`
### 步骤 6: 启动后端服务 ### 步骤 7: 启动后端服务
```bash ```bash
cd ~/Desktop/Seg_Server/backend cd ~/Desktop/Seg_Server/backend
@@ -287,7 +309,8 @@ nohup uvicorn main:app --host 0.0.0.0 --port 8000 > /tmp/fastapi.log 2>&1 &
- 创建数据库表(如果不存在) - 创建数据库表(如果不存在)
- 检查 MinIO bucket 是否存在 - 检查 MinIO bucket 是否存在
- 测试 Redis 连接 - 测试 Redis 连接
- 懒加载 SAM 模型;`GET /api/ai/models/status` 会返回 SAM 2、SAM 3 与 GPU 的真实可用状态 - 懒加载 SAM 模型;`GET /api/ai/models/status` 会返回 SAM 2、SAM 3、GPU 和 SAM 3 checkpoint access 的真实可用状态
- `/api/ai/predict` 支持 AI 参数 `crop_to_prompt``auto_filter_background``min_score`,用于点/框 prompt 的局部裁剪推理、回映射和背景过滤
### 步骤 6.1: 启动 Celery Worker ### 步骤 6.1: 启动 Celery Worker
@@ -301,7 +324,7 @@ celery -A celery_app:celery_app worker --loglevel=info --concurrency=1
nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 & nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 &
``` ```
`POST /api/media/parse` 只创建 `processing_tasks` 记录并把任务投递给 Celery真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。 `POST /api/media/parse` 只创建 `processing_tasks` 记录并把任务投递给 Celery真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 也可调用 `/api/tasks/{id}/cancel``/api/tasks/{id}/retry``/api/tasks/{id}` 完成任务取消、重试与失败详情查看。
### 步骤 7: 安装前端依赖并构建 ### 步骤 7: 安装前端依赖并构建
@@ -438,7 +461,8 @@ pip install -e . --no-build-isolation
- 前端 `predictMask()` 已发送后端需要的 `image_id``prompt_type``prompt_data`,并把后端 `polygons` 转成 Konva `pathData` - 前端 `predictMask()` 已发送后端需要的 `image_id``prompt_type``prompt_data`,并把后端 `polygons` 转成 Konva `pathData`
- 工作区点选/框选会使用当前帧的数据库 `frame.id` 调用 `/api/ai/predict` - 工作区点选/框选会使用当前帧的数据库 `frame.id` 调用 `/api/ai/predict`
- 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco` - 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco`
- 工作区“导出 JSON 标注集”按钮已绑定下载流程;导出前会先保存当前待归档的前端 mask。 - 工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮已绑定下载流程;导出前会先保存当前待归档的前端 mask。
- 工作区“导入 GT Mask”按钮已绑定 `/api/ai/import-gt-mask`,导入后会刷新并回显已保存标注和 seed point。
- 工作区“结构化归档保存”按钮会把当前项目未保存 mask 写入 `POST /api/ai/annotate`,并把 dirty mask 写入 `PATCH /api/ai/annotations/{id}` - 工作区“结构化归档保存”按钮会把当前项目未保存 mask 写入 `POST /api/ai/annotate`,并把 dirty mask 写入 `PATCH /api/ai/annotations/{id}`
- 工作区“清空遮罩”会通过 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。 - 工作区“清空遮罩”会通过 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。
@@ -447,6 +471,7 @@ pip install -e . --no-build-isolation
```bash ```bash
curl http://localhost:8000/health curl http://localhost:8000/health
curl http://localhost:8000/api/export/1/coco curl http://localhost:8000/api/export/1/coco
curl http://localhost:8000/api/export/1/masks
``` ```
--- ---

View File

@@ -22,7 +22,12 @@ class Settings(BaseSettings):
sam_default_model: str = "sam2" sam_default_model: str = "sam2"
sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt" sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
sam_model_config: str = "configs/sam2/sam2_hiera_t.yaml" sam_model_config: str = "configs/sam2/sam2_hiera_t.yaml"
sam3_model_version: str = "sam3.1" sam3_model_version: str = "sam3"
sam3_external_enabled: bool = True
sam3_external_python: str = "/home/wkmgc/miniconda3/envs/sam3/bin/python"
sam3_timeout_seconds: int = 300
sam3_status_cache_seconds: int = 30
sam3_confidence_threshold: float = 0.5
# App # App
app_env: str = "development" app_env: str = "development"

View File

@@ -8,7 +8,7 @@ from datetime import datetime, timezone
from typing import Any from typing import Any
from redis_client import get_redis_client from redis_client import get_redis_client
from statuses import TASK_STATUS_FAILED, TASK_STATUS_SUCCESS from statuses import TASK_STATUS_CANCELLED, TASK_STATUS_FAILED, TASK_STATUS_SUCCESS
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -22,6 +22,8 @@ def _iso_now() -> str:
def _event_type(task_status: str) -> str: def _event_type(task_status: str) -> str:
if task_status == TASK_STATUS_SUCCESS: if task_status == TASK_STATUS_SUCCESS:
return "complete" return "complete"
if task_status == TASK_STATUS_CANCELLED:
return "cancelled"
if task_status == TASK_STATUS_FAILED: if task_status == TASK_STATUS_FAILED:
return "error" return "error"
return "progress" return "progress"

View File

@@ -5,7 +5,7 @@ from typing import Any, List
import cv2 import cv2
import numpy as np import numpy as np
from fastapi import APIRouter, Depends, HTTPException, Response, status from fastapi import APIRouter, Depends, File, Form, HTTPException, Response, UploadFile, status
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from database import get_db from database import get_db
@@ -39,6 +39,140 @@ def _load_frame_image(frame: Frame) -> np.ndarray:
raise HTTPException(status_code=500, detail="Failed to load frame image") from exc raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
def _normalized_contour(contour: np.ndarray, width: int, height: int) -> list[list[float]]:
"""Approximate a contour and convert it to normalized polygon coordinates."""
arc_length = cv2.arcLength(contour, True)
epsilon = max(1.0, arc_length * 0.01)
approx = cv2.approxPolyDP(contour, epsilon, True)
points = approx.reshape(-1, 2)
if len(points) < 3:
points = contour.reshape(-1, 2)
return [
[
min(max(float(x) / max(width, 1), 0.0), 1.0),
min(max(float(y) / max(height, 1), 0.0), 1.0),
]
for x, y in points
]
def _contour_bbox(contour: np.ndarray, width: int, height: int) -> list[float]:
x, y, w, h = cv2.boundingRect(contour)
return [
min(max(float(x) / max(width, 1), 0.0), 1.0),
min(max(float(y) / max(height, 1), 0.0), 1.0),
min(max(float(w) / max(width, 1), 0.0), 1.0),
min(max(float(h) / max(height, 1), 0.0), 1.0),
]
def _component_seed_point(component_mask: np.ndarray, width: int, height: int) -> list[float]:
"""Reduce a binary component to one positive prompt point using distance transform."""
dist = cv2.distanceTransform(component_mask.astype(np.uint8), cv2.DIST_L2, 5)
_, _, _, max_loc = cv2.minMaxLoc(dist)
x, y = max_loc
return [
min(max(float(x) / max(width, 1), 0.0), 1.0),
min(max(float(y) / max(height, 1), 0.0), 1.0),
]
def _clamp01(value: float) -> float:
return min(max(float(value), 0.0), 1.0)
def _point_in_polygon(point: list[float], polygon: list[list[float]]) -> bool:
"""Return whether a normalized point is inside a normalized polygon."""
if len(polygon) < 3:
return False
x, y = point
inside = False
j = len(polygon) - 1
for i, current in enumerate(polygon):
xi, yi = current
xj, yj = polygon[j]
intersects = ((yi > y) != (yj > y)) and (
x < (xj - xi) * (y - yi) / ((yj - yi) or 1e-9) + xi
)
if intersects:
inside = not inside
j = i
return inside
def _crop_bounds_from_points(points: list[list[float]], margin: float) -> tuple[float, float, float, float]:
xs = [_clamp01(point[0]) for point in points]
ys = [_clamp01(point[1]) for point in points]
x1 = max(0.0, min(xs) - margin)
y1 = max(0.0, min(ys) - margin)
x2 = min(1.0, max(xs) + margin)
y2 = min(1.0, max(ys) + margin)
if x2 - x1 < 0.05:
center = (x1 + x2) / 2
x1 = max(0.0, center - 0.025)
x2 = min(1.0, center + 0.025)
if y2 - y1 < 0.05:
center = (y1 + y2) / 2
y1 = max(0.0, center - 0.025)
y2 = min(1.0, center + 0.025)
return x1, y1, x2, y2
def _crop_image(image: np.ndarray, bounds: tuple[float, float, float, float]) -> np.ndarray:
height, width = image.shape[:2]
x1, y1, x2, y2 = bounds
left = int(round(x1 * width))
top = int(round(y1 * height))
right = max(left + 1, int(round(x2 * width)))
bottom = max(top + 1, int(round(y2 * height)))
return image[top:bottom, left:right]
def _to_crop_point(point: list[float], bounds: tuple[float, float, float, float]) -> list[float]:
x1, y1, x2, y2 = bounds
return [
_clamp01((float(point[0]) - x1) / max(x2 - x1, 1e-9)),
_clamp01((float(point[1]) - y1) / max(y2 - y1, 1e-9)),
]
def _from_crop_polygon(
polygon: list[list[float]],
bounds: tuple[float, float, float, float],
) -> list[list[float]]:
x1, y1, x2, y2 = bounds
return [
[
_clamp01(x1 + float(point[0]) * (x2 - x1)),
_clamp01(y1 + float(point[1]) * (y2 - y1)),
]
for point in polygon
]
def _filter_predictions(
polygons: list[list[list[float]]],
scores: list[float],
options: dict[str, Any],
negative_points: list[list[float]] | None = None,
) -> tuple[list[list[list[float]]], list[float]]:
if not options.get("auto_filter_background"):
return polygons, scores
min_score = float(options.get("min_score", 0.0) or 0.0)
next_polygons: list[list[list[float]]] = []
next_scores: list[float] = []
for index, polygon in enumerate(polygons):
score = scores[index] if index < len(scores) else 0.0
if score < min_score:
continue
if negative_points and any(_point_in_polygon(point, polygon) for point in negative_points):
continue
next_polygons.append(polygon)
next_scores.append(score)
return next_polygons, next_scores
@router.post( @router.post(
"/predict", "/predict",
response_model=PredictResponse, response_model=PredictResponse,
@@ -58,9 +192,11 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
image = _load_frame_image(frame) image = _load_frame_image(frame)
prompt_type = payload.prompt_type.lower() prompt_type = payload.prompt_type.lower()
options = payload.options or {}
polygons: List[List[List[float]]] = [] polygons: List[List[List[float]]] = []
scores: List[float] = [] scores: List[float] = []
negative_points: list[list[float]] = []
try: try:
if prompt_type == "point": if prompt_type == "point":
@@ -76,13 +212,39 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
raise HTTPException(status_code=400, detail="Invalid point prompt data") raise HTTPException(status_code=400, detail="Invalid point prompt data")
if not isinstance(labels, list) or len(labels) != len(points): if not isinstance(labels, list) or len(labels) != len(points):
labels = [1] * len(points) labels = [1] * len(points)
polygons, scores = sam_registry.predict_points(payload.model, image, points, labels) negative_points = [
point for point, label in zip(points, labels) if label == 0
]
inference_image = image
inference_points = points
crop_bounds = None
if options.get("crop_to_prompt"):
margin = float(options.get("crop_margin", 0.25) or 0.25)
crop_bounds = _crop_bounds_from_points(points, margin)
inference_image = _crop_image(image, crop_bounds)
inference_points = [_to_crop_point(point, crop_bounds) for point in points]
polygons, scores = sam_registry.predict_points(payload.model, inference_image, inference_points, labels)
if crop_bounds:
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
elif prompt_type == "box": elif prompt_type == "box":
box = payload.prompt_data box = payload.prompt_data
if not isinstance(box, list) or len(box) != 4: if not isinstance(box, list) or len(box) != 4:
raise HTTPException(status_code=400, detail="Invalid box prompt data") raise HTTPException(status_code=400, detail="Invalid box prompt data")
polygons, scores = sam_registry.predict_box(payload.model, image, box) inference_image = image
inference_box = box
crop_bounds = None
if options.get("crop_to_prompt"):
margin = float(options.get("crop_margin", 0.05) or 0.05)
crop_bounds = _crop_bounds_from_points([[box[0], box[1]], [box[2], box[3]]], margin)
inference_image = _crop_image(image, crop_bounds)
inference_box = [
*_to_crop_point([box[0], box[1]], crop_bounds),
*_to_crop_point([box[2], box[3]], crop_bounds),
]
polygons, scores = sam_registry.predict_box(payload.model, inference_image, inference_box)
if crop_bounds:
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
elif prompt_type == "semantic": elif prompt_type == "semantic":
text = payload.prompt_data if isinstance(payload.prompt_data, str) else "" text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
@@ -95,8 +257,9 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
except NotImplementedError as exc: except NotImplementedError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc raise HTTPException(status_code=400, detail=str(exc)) from exc
except ValueError as exc: except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc raise HTTPException(status_code=400, detail=str(exc)) from exc
polygons, scores = _filter_predictions(polygons, scores, options, negative_points)
return {"polygons": polygons, "scores": scores} return {"polygons": polygons, "scores": scores}
@@ -161,6 +324,100 @@ def save_annotation(
return annotation return annotation
@router.post(
"/import-gt-mask",
response_model=List[AnnotationOut],
status_code=status.HTTP_201_CREATED,
summary="Import a GT mask and reduce components to editable point regions",
)
async def import_gt_mask(
project_id: int = Form(...),
frame_id: int = Form(...),
template_id: int | None = Form(None),
label: str = Form("GT Mask"),
color: str = Form("#22c55e"),
file: UploadFile = File(...),
db: Session = Depends(get_db),
) -> List[Annotation]:
"""Convert a binary/label mask image into persisted polygon annotations.
Each connected component becomes one annotation. The `points` field stores a
positive seed point at the component's distance-transform center, which gives
the frontend an editable point-region representation instead of a static
bitmap layer.
"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
frame = db.query(Frame).filter(Frame.id == frame_id, Frame.project_id == project_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
if template_id is not None:
template = db.query(Template).filter(Template.id == template_id).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
data = await file.read()
image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
if image is None:
raise HTTPException(status_code=400, detail="Invalid mask image")
width = int(frame.width or image.shape[1])
height = int(frame.height or image.shape[0])
label_values = [int(value) for value in np.unique(image) if int(value) > 0]
if not label_values:
raise HTTPException(status_code=400, detail="No foreground mask regions found")
has_multiple_labels = len(label_values) > 1
annotations: list[Annotation] = []
for label_value in label_values:
binary = np.where(image == label_value, 255, 0).astype(np.uint8)
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
annotation_label = f"{label} {label_value}" if has_multiple_labels else label
for contour in contours:
if cv2.contourArea(contour) < 1:
continue
polygon = _normalized_contour(contour, image.shape[1], image.shape[0])
if len(polygon) < 3:
continue
component = np.zeros_like(binary, dtype=np.uint8)
cv2.drawContours(component, [contour], -1, 1, thickness=-1)
seed_point = _component_seed_point(component, image.shape[1], image.shape[0])
bbox = _contour_bbox(contour, image.shape[1], image.shape[0])
annotation = Annotation(
project_id=project_id,
frame_id=frame_id,
template_id=template_id,
mask_data={
"polygons": [polygon],
"label": annotation_label,
"color": color,
"source": "gt_mask",
"gt_label_value": label_value,
"image_size": {"width": width, "height": height},
},
points=[seed_point],
bbox=bbox,
)
db.add(annotation)
annotations.append(annotation)
if not annotations:
raise HTTPException(status_code=400, detail="No foreground mask regions found")
db.commit()
for annotation in annotations:
db.refresh(annotation)
logger.info("Imported %s GT mask annotations for project_id=%s frame_id=%s", len(annotations), project_id, frame_id)
return annotations
@router.get( @router.get(
"/annotations", "/annotations",
response_model=List[AnnotationOut], response_model=List[AnnotationOut],

View File

@@ -14,6 +14,7 @@ from models import Annotation, Frame, ProcessingTask, Project, Template
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"]) router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
ACTIVE_TASK_STATUSES = {"queued", "running"} ACTIVE_TASK_STATUSES = {"queued", "running"}
MONITORED_TASK_STATUSES = {"queued", "running", "failed", "cancelled"}
def _system_load_percent() -> int: def _system_load_percent() -> int:
@@ -42,7 +43,9 @@ def _task_payload(task: ProcessingTask) -> dict[str, Any]:
"name": task.project.name if task.project else f"任务 {task.id}", "name": task.project.name if task.project else f"任务 {task.id}",
"progress": task.progress, "progress": task.progress,
"status": task.message or task.status, "status": task.message or task.status,
"raw_status": task.status,
"frame_count": (task.result or {}).get("frames_extracted", 0), "frame_count": (task.result or {}).get("frames_extracted", 0),
"error": task.error,
"updated_at": _iso_or_none(task.updated_at), "updated_at": _iso_or_none(task.updated_at),
} }
@@ -68,7 +71,7 @@ def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
.limit(50) .limit(50)
.all() .all()
) )
tasks = [_task_payload(task) for task in recent_tasks if task.status in ACTIVE_TASK_STATUSES] tasks = [_task_payload(task) for task in recent_tasks if task.status in MONITORED_TASK_STATUSES]
activities: list[dict[str, Any]] = [] activities: list[dict[str, Any]] = []
for task in recent_tasks[:10]: for task in recent_tasks[:10]:

View File

@@ -37,6 +37,54 @@ def _mask_from_polygon(
return mask return mask
def _annotation_z_index(annotation: Annotation) -> int:
class_meta = (annotation.mask_data or {}).get("class") or {}
if isinstance(class_meta, dict) and class_meta.get("zIndex") is not None:
try:
return int(class_meta["zIndex"])
except (TypeError, ValueError):
pass
if annotation.template and annotation.template.z_index is not None:
return int(annotation.template.z_index)
return 0
def _annotation_class_key(annotation: Annotation) -> str:
class_meta = (annotation.mask_data or {}).get("class") or {}
if isinstance(class_meta, dict):
if class_meta.get("id"):
return f"class:{class_meta['id']}"
if class_meta.get("name"):
return f"name:{class_meta['name']}"
if annotation.template_id:
return f"template:{annotation.template_id}"
return f"annotation:{annotation.id}"
def _annotation_label(annotation: Annotation) -> str:
mask_data = annotation.mask_data or {}
class_meta = mask_data.get("class") or {}
if isinstance(class_meta, dict) and class_meta.get("name"):
return str(class_meta["name"])
if mask_data.get("label"):
return str(mask_data["label"])
if annotation.template and annotation.template.name:
return str(annotation.template.name)
return f"Annotation {annotation.id}"
def _annotation_color(annotation: Annotation) -> str:
mask_data = annotation.mask_data or {}
class_meta = mask_data.get("class") or {}
if isinstance(class_meta, dict) and class_meta.get("color"):
return str(class_meta["color"])
if mask_data.get("color"):
return str(mask_data["color"])
if annotation.template and annotation.template.color:
return str(annotation.template.color)
return "#ffffff"
@router.get( @router.get(
"/{project_id}/coco", "/{project_id}/coco",
summary="Export annotations in COCO format", summary="Export annotations in COCO format",
@@ -150,19 +198,46 @@ def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResp
summary="Export PNG masks as a ZIP archive", summary="Export PNG masks as a ZIP archive",
) )
def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse: def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
"""Export all annotation masks as individual PNG files inside a ZIP archive.""" """Export individual masks plus z-index fused semantic masks inside a ZIP."""
project = db.query(Project).filter(Project.id == project_id).first() project = db.query(Project).filter(Project.id == project_id).first()
if not project: if not project:
raise HTTPException(status_code=404, detail="Project not found") raise HTTPException(status_code=404, detail="Project not found")
import cv2
annotations = ( annotations = (
db.query(Annotation) db.query(Annotation)
.filter(Annotation.project_id == project_id) .filter(Annotation.project_id == project_id)
.all() .all()
) )
frames = (
db.query(Frame)
.filter(Frame.project_id == project_id)
.order_by(Frame.frame_index)
.all()
)
class_values: dict[str, int] = {}
semantic_classes: list[dict[str, Any]] = []
def class_value(annotation: Annotation) -> int:
key = _annotation_class_key(annotation)
if key not in class_values:
value = len(class_values) + 1
class_values[key] = value
semantic_classes.append({
"value": value,
"key": key,
"label": _annotation_label(annotation),
"color": _annotation_color(annotation),
"zIndex": _annotation_z_index(annotation),
"template_id": annotation.template_id,
})
return class_values[key]
zip_buffer = io.BytesIO() zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf: with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {}
for ann in annotations: for ann in annotations:
if not ann.mask_data: if not ann.mask_data:
continue continue
@@ -178,11 +253,28 @@ def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingRes
mask = _mask_from_polygon(poly, width, height) mask = _mask_from_polygon(poly, width, height)
combined = np.maximum(combined, mask) combined = np.maximum(combined, mask)
# Encode PNG
import cv2
_, encoded = cv2.imencode(".png", combined) _, encoded = cv2.imencode(".png", combined)
fname = f"mask_{ann.id:06d}.png" fname = f"mask_{ann.id:06d}.png"
zf.writestr(fname, encoded.tobytes()) zf.writestr(fname, encoded.tobytes())
if ann.frame_id is not None:
frame_masks.setdefault(ann.frame_id, []).append((ann, combined))
for frame in frames:
entries = frame_masks.get(frame.id, [])
if not entries:
continue
width = frame.width or 1920
height = frame.height or 1080
semantic = np.zeros((height, width), dtype=np.uint8)
for ann, mask in sorted(entries, key=lambda item: _annotation_z_index(item[0])):
semantic[mask > 0] = class_value(ann)
_, encoded = cv2.imencode(".png", semantic)
zf.writestr(f"semantic_frame_{frame.frame_index:06d}.png", encoded.tobytes())
zf.writestr(
"semantic_classes.json",
json.dumps({"classes": semantic_classes}, ensure_ascii=False, indent=2).encode("utf-8"),
)
zip_buffer.seek(0) zip_buffer.seek(0)
filename = f"project_{project_id}_masks.zip" filename = f"project_{project_id}_masks.zip"

View File

@@ -1,15 +1,45 @@
"""Processing task query endpoints.""" """Processing task query endpoints."""
import logging
from datetime import datetime, timezone
from typing import List from typing import List
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from celery_app import celery_app
from database import get_db from database import get_db
from models import ProcessingTask from models import ProcessingTask, Project
from progress_events import publish_task_progress_event
from schemas import ProcessingTaskOut from schemas import ProcessingTaskOut
from statuses import (
PROJECT_STATUS_PARSING,
PROJECT_STATUS_PENDING,
PROJECT_STATUS_READY,
TASK_ACTIVE_STATUSES,
TASK_STATUS_CANCELLED,
TASK_STATUS_FAILED,
TASK_STATUS_QUEUED,
)
from worker_tasks import parse_project_media
router = APIRouter(prefix="/api/tasks", tags=["Tasks"]) router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
logger = logging.getLogger(__name__)
def _now() -> datetime:
return datetime.now(timezone.utc)
def _get_task_or_404(task_id: int, db: Session) -> ProcessingTask:
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return task
def _project_status_after_stop(project: Project) -> str:
return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING
@router.get("", response_model=List[ProcessingTaskOut], summary="List processing tasks") @router.get("", response_model=List[ProcessingTaskOut], summary="List processing tasks")
@@ -31,7 +61,78 @@ def list_tasks(
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task") @router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
def get_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask: def get_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
"""Return one background task by id.""" """Return one background task by id."""
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first() return _get_task_or_404(task_id, db)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
@router.post("/{task_id}/cancel", response_model=ProcessingTaskOut, summary="Cancel processing task")
def cancel_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
"""Cancel a queued/running background task and revoke the Celery job when possible."""
task = _get_task_or_404(task_id, db)
if task.status not in TASK_ACTIVE_STATUSES:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Task is not cancellable in status: {task.status}",
)
if task.celery_task_id:
try:
celery_app.control.revoke(task.celery_task_id, terminate=True, signal="SIGTERM")
except Exception as exc: # noqa: BLE001
logger.warning("Failed to revoke celery task %s: %s", task.celery_task_id, exc)
task.status = TASK_STATUS_CANCELLED
task.progress = 100
task.message = "任务已取消"
task.error = "Cancelled by user"
task.finished_at = _now()
if task.project:
task.project.status = _project_status_after_stop(task.project)
db.commit()
db.refresh(task)
publish_task_progress_event(task)
return task
@router.post("/{task_id}/retry", response_model=ProcessingTaskOut, status_code=status.HTTP_202_ACCEPTED, summary="Retry processing task")
def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
"""Create a fresh queued task from a failed or cancelled task."""
previous = _get_task_or_404(task_id, db)
if previous.status not in {TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Task is not retryable in status: {previous.status}",
)
if previous.project_id is None:
raise HTTPException(status_code=400, detail="Task has no project_id")
project = db.query(Project).filter(Project.id == previous.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
if not project.video_path:
raise HTTPException(status_code=400, detail="Project has no media uploaded")
payload = dict(previous.payload or {})
payload.setdefault("source_type", project.source_type or "video")
payload["retry_of"] = previous.id
task = ProcessingTask(
task_type=previous.task_type,
status=TASK_STATUS_QUEUED,
progress=0,
message=f"重试任务已入队(源任务 #{previous.id}",
project_id=project.id,
payload=payload,
)
project.status = PROJECT_STATUS_PARSING
db.add(task)
db.commit()
db.refresh(task)
publish_task_progress_event(task)
async_result = parse_project_media.delay(task.id)
task.celery_task_id = async_result.id
db.commit()
db.refresh(task)
publish_task_progress_event(task)
return task return task

View File

@@ -180,6 +180,7 @@ class PredictRequest(BaseModel):
prompt_type: str # point / box / semantic prompt_type: str # point / box / semantic
prompt_data: Any prompt_data: Any
model: Optional[str] = None model: Optional[str] = None
options: Optional[dict[str, Any]] = None
class PredictResponse(BaseModel): class PredictResponse(BaseModel):
@@ -201,6 +202,8 @@ class AiModelStatus(BaseModel):
python_ok: bool = True python_ok: bool = True
torch_ok: bool = True torch_ok: bool = True
cuda_required: bool = False cuda_required: bool = False
external_available: bool = False
external_python: Optional[str] = None
class GpuStatus(BaseModel): class GpuStatus(BaseModel):

View File

@@ -20,9 +20,11 @@ from services.frame_parser import (
upload_frames_to_minio, upload_frames_to_minio,
) )
from statuses import ( from statuses import (
PROJECT_STATUS_PENDING,
PROJECT_STATUS_ERROR, PROJECT_STATUS_ERROR,
PROJECT_STATUS_PARSING, PROJECT_STATUS_PARSING,
PROJECT_STATUS_READY, PROJECT_STATUS_READY,
TASK_STATUS_CANCELLED,
TASK_STATUS_FAILED, TASK_STATUS_FAILED,
TASK_STATUS_RUNNING, TASK_STATUS_RUNNING,
TASK_STATUS_SUCCESS, TASK_STATUS_SUCCESS,
@@ -31,6 +33,10 @@ from statuses import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TaskCancelled(RuntimeError):
"""Raised internally when a persisted task has been cancelled."""
def _now() -> datetime: def _now() -> datetime:
return datetime.now(timezone.utc) return datetime.now(timezone.utc)
@@ -66,12 +72,29 @@ def _set_task_state(
publish_task_progress_event(task) publish_task_progress_event(task)
def _project_status_after_stop(project: Project) -> str:
return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
db.refresh(task)
if task.status == TASK_STATUS_CANCELLED:
raise TaskCancelled("Task was cancelled")
def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]: def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
"""Parse one project's media and update task progress in the database.""" """Parse one project's media and update task progress in the database."""
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first() task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
if not task: if not task:
raise ValueError(f"Task not found: {task_id}") raise ValueError(f"Task not found: {task_id}")
if task.status == TASK_STATUS_CANCELLED:
return {
"task_id": task.id,
"status": TASK_STATUS_CANCELLED,
"message": task.message or "任务已取消",
}
if task.project_id is None: if task.project_id is None:
_set_task_state( _set_task_state(
db, db,
@@ -111,6 +134,7 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
db.commit() db.commit()
raise ValueError("Project has no media uploaded") raise ValueError("Project has no media uploaded")
_ensure_not_cancelled(db, task)
project.status = PROJECT_STATUS_PARSING project.status = PROJECT_STATUS_PARSING
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True) _set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True)
@@ -121,6 +145,7 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
try: try:
_ensure_not_cancelled(db, task)
_set_task_state(db, task, progress=15, message="正在下载媒体文件") _set_task_state(db, task, progress=15, message="正在下载媒体文件")
if effective_source == "dicom": if effective_source == "dicom":
dcm_dir = os.path.join(tmp_dir, "dcm") dcm_dir = os.path.join(tmp_dir, "dcm")
@@ -129,20 +154,24 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
client = get_minio_client() client = get_minio_client()
objects = list(client.list_objects(BUCKET_NAME, prefix=project.video_path, recursive=True)) objects = list(client.list_objects(BUCKET_NAME, prefix=project.video_path, recursive=True))
for obj in objects: for obj in objects:
_ensure_not_cancelled(db, task)
if obj.object_name.lower().endswith(".dcm"): if obj.object_name.lower().endswith(".dcm"):
data = download_file(obj.object_name) data = download_file(obj.object_name)
local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name)) local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name))
with open(local_dcm, "wb") as f: with open(local_dcm, "wb") as f:
f.write(data) f.write(data)
_ensure_not_cancelled(db, task)
_set_task_state(db, task, progress=35, message="正在解析 DICOM 序列") _set_task_state(db, task, progress=35, message="正在解析 DICOM 序列")
frame_files = parse_dicom(dcm_dir, output_dir) frame_files = parse_dicom(dcm_dir, output_dir)
else: else:
_ensure_not_cancelled(db, task)
media_bytes = download_file(project.video_path) media_bytes = download_file(project.video_path)
local_path = os.path.join(tmp_dir, Path(project.video_path).name) local_path = os.path.join(tmp_dir, Path(project.video_path).name)
with open(local_path, "wb") as f: with open(local_path, "wb") as f:
f.write(media_bytes) f.write(media_bytes)
_ensure_not_cancelled(db, task)
_set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧") _set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧")
frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps)) frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
project.original_fps = original_fps project.original_fps = original_fps
@@ -158,12 +187,15 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
logger.warning("Thumbnail extraction failed: %s", exc) logger.warning("Thumbnail extraction failed: %s", exc)
_ensure_not_cancelled(db, task)
_set_task_state(db, task, progress=70, message="正在上传帧到对象存储") _set_task_state(db, task, progress=70, message="正在上传帧到对象存储")
object_names = upload_frames_to_minio(frame_files, project.id) object_names = upload_frames_to_minio(frame_files, project.id)
_ensure_not_cancelled(db, task)
_set_task_state(db, task, progress=85, message="正在写入帧索引") _set_task_state(db, task, progress=85, message="正在写入帧索引")
frames_out = [] frames_out = []
for idx, obj_name in enumerate(object_names): for idx, obj_name in enumerate(object_names):
_ensure_not_cancelled(db, task)
local_frame = frame_files[idx] local_frame = frame_files[idx]
try: try:
import cv2 import cv2
@@ -203,6 +235,23 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
) )
logger.info("Parsed %d frames for project_id=%s", len(frames_out), project.id) logger.info("Parsed %d frames for project_id=%s", len(frames_out), project.id)
return result return result
except TaskCancelled:
project.status = _project_status_after_stop(project)
task.status = TASK_STATUS_CANCELLED
task.progress = 100
task.message = task.message or "任务已取消"
task.error = task.error or "Cancelled by user"
task.finished_at = task.finished_at or _now()
db.commit()
db.refresh(task)
publish_task_progress_event(task)
logger.info("Parse task cancelled: task_id=%s project_id=%s", task.id, project.id)
return {
"task_id": task.id,
"project_id": project.id,
"status": TASK_STATUS_CANCELLED,
"message": task.message,
}
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
project.status = PROJECT_STATUS_ERROR project.status = PROJECT_STATUS_ERROR
_set_task_state( _set_task_state(

View File

@@ -9,8 +9,14 @@ the package.
from __future__ import annotations from __future__ import annotations
import importlib.util import importlib.util
import json
import logging import logging
import os
import subprocess
import sys import sys
import tempfile
import time
from pathlib import Path
from typing import Any from typing import Any
import numpy as np import numpy as np
@@ -41,6 +47,8 @@ class SAM3Engine:
self._processor: Any | None = None self._processor: Any | None = None
self._model_loaded = False self._model_loaded = False
self._last_error: str | None = None self._last_error: str | None = None
self._external_status_cache: dict[str, Any] | None = None
self._external_status_checked_at = 0.0
def _python_ok(self) -> bool: def _python_ok(self) -> bool:
return sys.version_info >= (3, 12) return sys.version_info >= (3, 12)
@@ -51,6 +59,81 @@ class SAM3Engine:
def _can_load(self) -> bool: def _can_load(self) -> bool:
return bool(SAM3_PACKAGE_AVAILABLE and TORCH_AVAILABLE and self._python_ok() and self._gpu_ok()) return bool(SAM3_PACKAGE_AVAILABLE and TORCH_AVAILABLE and self._python_ok() and self._gpu_ok())
def _worker_path(self) -> Path:
return Path(__file__).with_name("sam3_external_worker.py")
def _external_python_exists(self) -> bool:
return bool(settings.sam3_external_enabled and os.path.isfile(settings.sam3_external_python))
def _external_status(self, force: bool = False) -> dict[str, Any]:
now = time.monotonic()
if (
not force
and self._external_status_cache is not None
and now - self._external_status_checked_at < settings.sam3_status_cache_seconds
):
return self._external_status_cache
if not settings.sam3_external_enabled:
status = {
"available": False,
"package_available": False,
"python_ok": False,
"torch_ok": False,
"cuda_available": False,
"device": "unavailable",
"message": "SAM 3 external runtime is disabled.",
}
elif not self._external_python_exists():
status = {
"available": False,
"package_available": False,
"python_ok": False,
"torch_ok": False,
"cuda_available": False,
"device": "unavailable",
"message": f"SAM 3 external Python not found: {settings.sam3_external_python}",
}
else:
try:
env = os.environ.copy()
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
completed = subprocess.run(
[settings.sam3_external_python, str(self._worker_path()), "--status"],
capture_output=True,
text=True,
timeout=min(settings.sam3_timeout_seconds, 30),
check=False,
env=env,
)
if completed.returncode != 0:
detail = completed.stderr.strip() or completed.stdout.strip()
status = {
"available": False,
"package_available": False,
"python_ok": False,
"torch_ok": False,
"cuda_available": False,
"device": "unavailable",
"message": f"SAM 3 external status failed: {detail}",
}
else:
status = json.loads(completed.stdout)
except Exception as exc: # noqa: BLE001
status = {
"available": False,
"package_available": False,
"python_ok": False,
"torch_ok": False,
"cuda_available": False,
"device": "unavailable",
"message": f"SAM 3 external status failed: {exc}",
}
self._external_status_cache = status
self._external_status_checked_at = now
return status
def _load_model(self) -> None: def _load_model(self) -> None:
if self._model_loaded: if self._model_loaded:
return return
@@ -92,26 +175,86 @@ class SAM3Engine:
return "SAM 3 dependencies are present; model will load on first inference." return "SAM 3 dependencies are present; model will load on first inference."
def status(self) -> dict: def status(self) -> dict:
available = self._can_load() external_status = self._external_status()
available = bool(self._can_load() or external_status.get("available"))
external_ready = bool(external_status.get("available"))
message = self._last_error or self._status_message()
if self._processor is not None:
message = "SAM 3 model loaded and ready."
elif external_ready:
message = "SAM 3 external runtime is ready; model will load in the helper process on inference."
elif external_status.get("message") and not self._can_load():
message = str(external_status["message"])
return { return {
"id": "sam3", "id": "sam3",
"label": "SAM 3", "label": "SAM 3",
"available": available, "available": available,
"loaded": self._processor is not None, "loaded": self._processor is not None,
"device": "cuda" if self._gpu_ok() else "unavailable", "device": "cuda" if self._gpu_ok() else str(external_status.get("device", "unavailable")),
"supports": ["semantic"], "supports": ["semantic"],
"message": "SAM 3 model loaded and ready." if self._processor is not None else (self._last_error or self._status_message()), "message": message,
"package_available": SAM3_PACKAGE_AVAILABLE, "package_available": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("package_available")),
"checkpoint_exists": SAM3_PACKAGE_AVAILABLE, "checkpoint_exists": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("checkpoint_access")),
"checkpoint_path": f"official/HuggingFace ({settings.sam3_model_version})", "checkpoint_path": f"official/HuggingFace ({settings.sam3_model_version})",
"python_ok": self._python_ok(), "python_ok": bool(self._python_ok() or external_status.get("python_ok")),
"torch_ok": TORCH_AVAILABLE, "torch_ok": bool(TORCH_AVAILABLE or external_status.get("torch_ok")),
"cuda_required": True, "cuda_required": True,
"external_available": external_ready,
"external_python": settings.sam3_external_python if settings.sam3_external_enabled else None,
} }
def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
status = self._external_status(force=True)
if not status.get("available"):
raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")
with tempfile.TemporaryDirectory(prefix="sam3_") as tmpdir:
tmp_path = Path(tmpdir)
image_path = tmp_path / "image.png"
request_path = tmp_path / "request.json"
Image.fromarray(image).save(image_path)
request_path.write_text(
json.dumps(
{
"image_path": str(image_path),
"text": text.strip(),
"model_version": settings.sam3_model_version,
"confidence_threshold": settings.sam3_confidence_threshold,
},
ensure_ascii=False,
),
encoding="utf-8",
)
env = os.environ.copy()
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
completed = subprocess.run(
[settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)],
capture_output=True,
text=True,
timeout=settings.sam3_timeout_seconds,
check=False,
env=env,
)
if completed.returncode != 0:
detail = completed.stderr.strip() or completed.stdout.strip()
try:
parsed = json.loads(detail)
detail = parsed.get("error", detail)
except Exception: # noqa: BLE001
pass
raise RuntimeError(f"SAM 3 external inference failed: {detail}")
payload = json.loads(completed.stdout)
if payload.get("error"):
raise RuntimeError(str(payload["error"]))
return payload.get("polygons", []), payload.get("scores", [])
def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]: def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
if not text.strip(): if not text.strip():
raise ValueError("SAM 3 semantic prompt requires non-empty text.") raise ValueError("SAM 3 semantic prompt requires non-empty text.")
if not self._can_load() and self._external_status().get("available"):
return self._predict_semantic_external(image, text)
if not self._ensure_ready(): if not self._ensure_ready():
raise RuntimeError(self.status()["message"]) raise RuntimeError(self.status()["message"])

View File

@@ -0,0 +1,190 @@
"""Standalone SAM 3 helper for the dedicated Python 3.12 runtime.
The main FastAPI backend can keep running in the existing Python 3.11/SAM 2
environment while this helper is executed with a separate conda env that meets
SAM 3's stricter runtime requirements.
"""
from __future__ import annotations
import argparse
import importlib.util
import json
import os
import sys
from pathlib import Path
from typing import Any
import numpy as np
from PIL import Image
def _torch_status() -> tuple[bool, str | None, str | None, str | None]:
try:
import torch
cuda_available = bool(torch.cuda.is_available())
return (
cuda_available,
getattr(torch, "__version__", None),
getattr(torch.version, "cuda", None),
torch.cuda.get_device_name(0) if cuda_available else None,
)
except Exception: # noqa: BLE001
return False, None, None, None
def _compact_error(exc: Exception) -> str:
lines = [line.strip() for line in str(exc).splitlines() if line.strip()]
for line in lines:
if "Access to model" in line or "Cannot access gated repo" in line:
return line
return lines[0] if lines else exc.__class__.__name__
def _checkpoint_access(model_version: str) -> tuple[bool, str | None]:
try:
from huggingface_hub import hf_hub_download
repo_id = "facebook/sam3.1" if model_version == "sam3.1" else "facebook/sam3"
hf_hub_download(repo_id=repo_id, filename="config.json")
return True, None
except Exception as exc: # noqa: BLE001
return False, _compact_error(exc)
def runtime_status() -> dict[str, Any]:
model_version = os.environ.get("SAM3_MODEL_VERSION", "sam3")
package_error = None
package_available = importlib.util.find_spec("sam3") is not None
if package_available:
try:
import sam3 # noqa: F401
except Exception as exc: # noqa: BLE001
package_available = False
package_error = str(exc)
cuda_available, torch_version, cuda_version, device_name = _torch_status()
python_ok = sys.version_info >= (3, 12)
checkpoint_access = False
checkpoint_error = None
if package_available:
checkpoint_access, checkpoint_error = _checkpoint_access(model_version)
available = bool(package_available and python_ok and cuda_available and checkpoint_access)
missing = []
if not python_ok:
missing.append("Python 3.12+ runtime")
if not package_available:
missing.append(f"sam3 package ({package_error})" if package_error else "sam3 package")
if torch_version is None:
missing.append("PyTorch")
if not cuda_available:
missing.append("CUDA GPU")
if package_available and not checkpoint_access:
missing.append(f"Hugging Face checkpoint access ({checkpoint_error})")
return {
"available": available,
"package_available": package_available,
"checkpoint_access": checkpoint_access,
"python_ok": python_ok,
"torch_ok": torch_version is not None,
"torch_version": torch_version,
"cuda_version": cuda_version,
"cuda_available": cuda_available,
"device": "cuda" if cuda_available else "unavailable",
"device_name": device_name,
"message": (
"SAM 3 external runtime is ready."
if available
else f"SAM 3 external runtime unavailable: missing {', '.join(missing)}."
),
}
def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
import cv2
if mask.dtype != np.uint8:
mask = (mask > 0).astype(np.uint8)
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
height, width = mask.shape[:2]
largest = []
for contour in contours:
if len(contour) > len(largest):
largest = contour
if len(largest) < 3:
return []
return [[float(point[0][0]) / width, float(point[0][1]) / height] for point in largest]
def _to_numpy(value: Any) -> np.ndarray:
if hasattr(value, "detach"):
value = value.detach().cpu().numpy()
elif hasattr(value, "cpu"):
value = value.cpu().numpy()
return np.asarray(value)
def predict(request_path: Path) -> dict[str, Any]:
import torch
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.model_builder import build_sam3_image_model
payload = json.loads(request_path.read_text(encoding="utf-8"))
image_path = Path(payload["image_path"])
text = str(payload["text"]).strip()
threshold = float(payload.get("confidence_threshold", 0.5))
if not text:
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
image = Image.open(image_path).convert("RGB")
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
model = build_sam3_image_model()
processor = Sam3Processor(model, confidence_threshold=threshold)
state = processor.set_image(image)
output = processor.set_text_prompt(state=state, prompt=text)
masks = _to_numpy(output.get("masks", []))
scores = _to_numpy(output.get("scores", []))
if masks.ndim == 4:
masks = masks[:, 0]
elif masks.ndim == 3 and masks.shape[0] == 1:
masks = masks[None, 0]
polygons = []
for mask in masks:
polygon = _mask_to_polygon(mask)
if polygon:
polygons.append(polygon)
return {
"polygons": polygons,
"scores": scores.astype(float).tolist() if scores.size else [],
}
def main() -> int:
parser = argparse.ArgumentParser(description="SAM 3 external runtime helper")
parser.add_argument("--status", action="store_true")
parser.add_argument("--request", type=Path)
args = parser.parse_args()
try:
if args.status:
print(json.dumps(runtime_status(), ensure_ascii=False))
return 0
if args.request:
print(json.dumps(predict(args.request), ensure_ascii=False))
return 0
parser.error("Use --status or --request")
except Exception as exc: # noqa: BLE001
print(json.dumps({"error": str(exc)}, ensure_ascii=False), file=sys.stderr)
return 1
return 2
if __name__ == "__main__":
raise SystemExit(main())

24
backend/setup_sam3_env.sh Executable file
View File

@@ -0,0 +1,24 @@
#!/usr/bin/env bash
# Create the dedicated SAM 3 runtime used by backend/services/sam3_external_worker.py.
# Keep Hugging Face tokens outside this repository, for example:
# export HF_TOKEN=...
# huggingface-cli login --token "$HF_TOKEN"
set -euo pipefail
ENV_NAME="${SAM3_CONDA_ENV:-sam3}"
source /home/wkmgc/miniconda3/etc/profile.d/conda.sh
if ! conda env list | awk '{print $1}' | grep -qx "$ENV_NAME"; then
conda create -y -n "$ENV_NAME" python=3.12
fi
conda activate "$ENV_NAME"
python -m pip install --upgrade pip
python -m pip install "setuptools<81"
python -m pip install torch==2.10.0 torchvision --index-url https://download.pytorch.org/whl/cu128
python -m pip install opencv-python pillow numpy huggingface_hub einops pycocotools psutil
python -m pip install git+https://github.com/facebookresearch/sam3.git
python /home/wkmgc/Desktop/Seg_Server/backend/services/sam3_external_worker.py --status

View File

@@ -9,3 +9,7 @@ TASK_STATUS_QUEUED = "queued"
TASK_STATUS_RUNNING = "running" TASK_STATUS_RUNNING = "running"
TASK_STATUS_SUCCESS = "success" TASK_STATUS_SUCCESS = "success"
TASK_STATUS_FAILED = "failed" TASK_STATUS_FAILED = "failed"
TASK_STATUS_CANCELLED = "cancelled"
TASK_ACTIVE_STATUSES = {TASK_STATUS_QUEUED, TASK_STATUS_RUNNING}
TASK_TERMINAL_STATUSES = {TASK_STATUS_SUCCESS, TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}

View File

@@ -1,4 +1,5 @@
import numpy as np import numpy as np
import cv2
def _create_project_and_frame(client): def _create_project_and_frame(client):
@@ -46,6 +47,46 @@ def test_predict_accepts_point_object_with_labels(client, monkeypatch):
assert calls["args"] == ([[0.5, 0.5], [0.1, 0.1]], [1, 0]) assert calls["args"] == ([[0.5, 0.5], [0.1, 0.1]], [1, 0])
def test_predict_applies_crop_and_background_filter_options(client, monkeypatch):
_, frame, _ = _create_project_and_frame(client)
calls = {}
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((100, 200, 3), dtype=np.uint8))
def fake_predict_points(model, image, points, labels):
calls["shape"] = image.shape
calls["points"] = points
calls["labels"] = labels
return (
[
[[0.0, 0.0], [0.2, 0.0], [0.2, 0.2]],
[[0.45, 0.45], [0.55, 0.45], [0.55, 0.55]],
],
[0.9, 0.01],
)
monkeypatch.setattr("routers.ai.sam_registry.predict_points", fake_predict_points)
response = client.post("/api/ai/predict", json={
"image_id": frame["id"],
"prompt_type": "point",
"prompt_data": {"points": [[0.5, 0.5], [0.52, 0.52]], "labels": [1, 0]},
"options": {
"crop_to_prompt": True,
"crop_margin": 0.1,
"auto_filter_background": True,
"min_score": 0.05,
},
})
assert response.status_code == 200
assert calls["shape"][0] < 100
assert calls["shape"][1] < 200
assert calls["labels"] == [1, 0]
assert response.json()["scores"] == [0.9]
polygon = response.json()["polygons"][0]
assert all(0.0 <= coord <= 1.0 for point in polygon for coord in point)
def test_predict_box_and_semantic_fallback(client, monkeypatch): def test_predict_box_and_semantic_fallback(client, monkeypatch):
_, frame, _ = _create_project_and_frame(client) _, frame, _ = _create_project_and_frame(client)
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8)) monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
@@ -246,3 +287,62 @@ def test_update_and_delete_annotation_validation(client):
f"/api/ai/annotations/{saved['id']}", f"/api/ai/annotations/{saved['id']}",
json={"template_id": 999}, json={"template_id": 999},
).status_code == 404 ).status_code == 404
def test_import_gt_mask_creates_annotations_with_seed_points(client):
project, frame, template = _create_project_and_frame(client)
mask = np.zeros((360, 640), dtype=np.uint8)
cv2.rectangle(mask, (100, 80), (260, 220), 255, thickness=-1)
ok, encoded = cv2.imencode(".png", mask)
assert ok
response = client.post(
"/api/ai/import-gt-mask",
data={
"project_id": str(project["id"]),
"frame_id": str(frame["id"]),
"template_id": str(template["id"]),
"label": "Imported GT",
"color": "#22c55e",
},
files={"file": ("mask.png", encoded.tobytes(), "image/png")},
)
assert response.status_code == 201
body = response.json()
assert len(body) == 1
assert body[0]["project_id"] == project["id"]
assert body[0]["frame_id"] == frame["id"]
assert body[0]["template_id"] == template["id"]
assert body[0]["mask_data"]["label"] == "Imported GT"
assert body[0]["mask_data"]["source"] == "gt_mask"
assert body[0]["mask_data"]["gt_label_value"] == 255
assert len(body[0]["mask_data"]["polygons"][0]) >= 3
assert len(body[0]["points"]) == 1
assert 0.0 <= body[0]["points"][0][0] <= 1.0
assert 0.0 <= body[0]["points"][0][1] <= 1.0
def test_import_gt_mask_splits_label_values(client):
project, frame, _ = _create_project_and_frame(client)
mask = np.zeros((360, 640), dtype=np.uint8)
cv2.rectangle(mask, (20, 20), (120, 120), 1, thickness=-1)
cv2.rectangle(mask, (220, 80), (320, 180), 2, thickness=-1)
ok, encoded = cv2.imencode(".png", mask)
assert ok
response = client.post(
"/api/ai/import-gt-mask",
data={
"project_id": str(project["id"]),
"frame_id": str(frame["id"]),
"label": "GT Class",
},
files={"file": ("labels.png", encoded.tobytes(), "image/png")},
)
assert response.status_code == 201
body = sorted(response.json(), key=lambda item: item["mask_data"]["gt_label_value"])
assert [item["mask_data"]["gt_label_value"] for item in body] == [1, 2]
assert [item["mask_data"]["label"] for item in body] == ["GT Class 1", "GT Class 2"]
assert all(len(item["points"]) == 1 for item in body)

View File

@@ -59,7 +59,9 @@ def test_dashboard_overview_uses_persisted_records(client, db_session):
"name": "Pending Project", "name": "Pending Project",
"progress": 35, "progress": 35,
"status": "正在使用 FFmpeg/OpenCV 拆帧", "status": "正在使用 FFmpeg/OpenCV 拆帧",
"raw_status": "running",
"frame_count": 0, "frame_count": 0,
"error": None,
"updated_at": body["tasks"][0]["updated_at"], "updated_at": body["tasks"][0]["updated_at"],
}, },
] ]

View File

@@ -1,6 +1,10 @@
import zipfile import zipfile
import json
from io import BytesIO from io import BytesIO
import cv2
import numpy as np
def _seed_export_data(client): def _seed_export_data(client):
project = client.post("/api/projects", json={"name": "Export Project"}).json() project = client.post("/api/projects", json={"name": "Export Project"}).json()
@@ -58,7 +62,55 @@ def test_export_masks_zip(client):
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"].startswith("application/zip") assert response.headers["content-type"].startswith("application/zip")
with zipfile.ZipFile(BytesIO(response.content)) as archive: with zipfile.ZipFile(BytesIO(response.content)) as archive:
assert archive.namelist() == [f"mask_{annotation['id']:06d}.png"] assert archive.namelist() == [
f"mask_{annotation['id']:06d}.png",
"semantic_frame_000000.png",
"semantic_classes.json",
]
def test_export_masks_uses_z_index_for_semantic_fusion(client):
project = client.post("/api/projects", json={"name": "Fusion Project"}).json()
frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": 0,
"image_url": "frames/0.jpg",
"width": 20,
"height": 20,
}).json()
low = client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": frame["id"],
"mask_data": {
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
"label": "Low",
"color": "#00ff00",
"class": {"id": "low", "name": "Low", "color": "#00ff00", "zIndex": 10},
},
}).json()
high = client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": frame["id"],
"mask_data": {
"polygons": [[[0.4, 0.4], [0.9, 0.4], [0.9, 0.9], [0.4, 0.9]]],
"label": "High",
"color": "#ff0000",
"class": {"id": "high", "name": "High", "color": "#ff0000", "zIndex": 20},
},
}).json()
response = client.get(f"/api/export/{project['id']}/masks")
assert response.status_code == 200
with zipfile.ZipFile(BytesIO(response.content)) as archive:
assert f"mask_{low['id']:06d}.png" in archive.namelist()
assert f"mask_{high['id']:06d}.png" in archive.namelist()
legend = json.loads(archive.read("semantic_classes.json"))
high_value = next(item["value"] for item in legend["classes"] if item["key"] == "class:high")
semantic_bytes = np.frombuffer(archive.read("semantic_frame_000000.png"), dtype=np.uint8)
semantic = cv2.imdecode(semantic_bytes, cv2.IMREAD_GRAYSCALE)
assert semantic[10, 10] == high_value
def test_export_missing_project_returns_404(client): def test_export_missing_project_returns_404(client):

View File

@@ -140,3 +140,25 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp
assert project_detail["status"] == "ready" assert project_detail["status"] == "ready"
frames = client.get(f"/api/projects/{project['id']}/frames").json() frames = client.get(f"/api/projects/{project['id']}/frames").json()
assert "frame_000001.jpg" in frames[0]["image_url"] assert "frame_000001.jpg" in frames[0]["image_url"]
def test_parse_task_runner_skips_already_cancelled_task(db_session):
from models import ProcessingTask
from services.media_task_runner import run_parse_media_task
task = ProcessingTask(
task_type="parse_video",
status="cancelled",
progress=100,
message="任务已取消",
project_id=1,
payload={"source_type": "video"},
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
result = run_parse_media_task(db_session, task.id)
assert result["status"] == "cancelled"
assert result["message"] == "任务已取消"

View File

@@ -26,6 +26,25 @@ def test_task_progress_payload_uses_dashboard_task_id_and_project_name():
assert payload["status"] == "解析完成" assert payload["status"] == "解析完成"
def test_task_progress_payload_marks_cancelled_tasks():
task = SimpleNamespace(
id=13,
project_id=7,
project=SimpleNamespace(name="demo.mp4"),
status="cancelled",
progress=100,
message="任务已取消",
error="Cancelled by user",
updated_at=None,
)
payload = task_progress_payload(task)
assert payload["type"] == "cancelled"
assert payload["status"] == "任务已取消"
assert payload["error"] == "Cancelled by user"
def test_publish_progress_event_writes_json_to_redis(monkeypatch): def test_publish_progress_event_writes_json_to_redis(monkeypatch):
calls = [] calls = []

View File

@@ -0,0 +1,112 @@
import json
from pathlib import Path
import numpy as np
from services.sam3_engine import SAM3Engine
class _Completed:
def __init__(self, returncode=0, stdout="", stderr=""):
self.returncode = returncode
self.stdout = stdout
self.stderr = stderr
def _external_settings(monkeypatch, python_path: Path):
python_path.write_text("#!/usr/bin/env python\n", encoding="utf-8")
python_path.chmod(0o755)
monkeypatch.setattr("services.sam3_engine.SAM3_PACKAGE_AVAILABLE", False)
monkeypatch.setattr("services.sam3_engine.TORCH_AVAILABLE", False)
monkeypatch.setattr("services.sam3_engine.settings.sam3_external_enabled", True)
monkeypatch.setattr("services.sam3_engine.settings.sam3_external_python", str(python_path))
monkeypatch.setattr("services.sam3_engine.settings.sam3_timeout_seconds", 10)
monkeypatch.setattr("services.sam3_engine.settings.sam3_status_cache_seconds", 30)
monkeypatch.setattr("services.sam3_engine.settings.sam3_confidence_threshold", 0.4)
def test_sam3_status_reports_external_runtime_ready(tmp_path, monkeypatch):
_external_settings(monkeypatch, tmp_path / "python")
def fake_run(args, **_kwargs):
assert "--status" in args
return _Completed(stdout=json.dumps({
"available": True,
"package_available": True,
"python_ok": True,
"torch_ok": True,
"cuda_available": True,
"device": "cuda",
"message": "ready",
}))
monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)
status = SAM3Engine().status()
assert status["available"] is True
assert status["external_available"] is True
assert status["package_available"] is True
assert status["python_ok"] is True
assert status["message"] == "SAM 3 external runtime is ready; model will load in the helper process on inference."
def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch):
_external_settings(monkeypatch, tmp_path / "python")
calls = []
def fake_run(args, **_kwargs):
calls.append(args)
if "--status" in args:
return _Completed(stdout=json.dumps({
"available": True,
"package_available": True,
"python_ok": True,
"torch_ok": True,
"cuda_available": True,
"device": "cuda",
"message": "ready",
}))
request_path = Path(args[-1])
request = json.loads(request_path.read_text(encoding="utf-8"))
assert request["text"] == "vessel"
assert request["confidence_threshold"] == 0.4
assert Path(request["image_path"]).exists()
return _Completed(stdout=json.dumps({
"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
"scores": [0.91],
}))
monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)
polygons, scores = SAM3Engine().predict_semantic(np.zeros((8, 8, 3), dtype=np.uint8), " vessel ")
assert polygons == [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]]
assert scores == [0.91]
assert any("--request" in args for args in calls)
def test_sam3_predict_semantic_reports_external_errors(tmp_path, monkeypatch):
_external_settings(monkeypatch, tmp_path / "python")
def fake_run(args, **_kwargs):
if "--status" in args:
return _Completed(stdout=json.dumps({
"available": True,
"package_available": True,
"python_ok": True,
"torch_ok": True,
"cuda_available": True,
"device": "cuda",
"message": "ready",
}))
return _Completed(returncode=1, stderr=json.dumps({"error": "HF access denied"}))
monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)
try:
SAM3Engine().predict_semantic(np.zeros((8, 8, 3), dtype=np.uint8), "vessel")
except RuntimeError as exc:
assert "HF access denied" in str(exc)
else:
raise AssertionError("Expected SAM 3 external inference failure.")

104
backend/tests/test_tasks.py Normal file
View File

@@ -0,0 +1,104 @@
from models import ProcessingTask
def test_cancel_task_revokes_celery_and_updates_project(client, db_session, monkeypatch):
project = client.post("/api/projects", json={
"name": "Cancelable",
"video_path": "uploads/1/clip.mp4",
"status": "parsing",
}).json()
task = ProcessingTask(
task_type="parse_video",
status="running",
progress=35,
message="正在使用 FFmpeg/OpenCV 拆帧",
project_id=project["id"],
celery_task_id="celery-1",
payload={"source_type": "video"},
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
revoked = []
published = []
monkeypatch.setattr(
"routers.tasks.celery_app.control.revoke",
lambda celery_id, terminate, signal: revoked.append((celery_id, terminate, signal)),
)
monkeypatch.setattr("routers.tasks.publish_task_progress_event", lambda event_task: published.append(event_task.status))
response = client.post(f"/api/tasks/{task.id}/cancel")
assert response.status_code == 200
body = response.json()
assert body["status"] == "cancelled"
assert body["progress"] == 100
assert body["message"] == "任务已取消"
assert body["error"] == "Cancelled by user"
assert revoked == [("celery-1", True, "SIGTERM")]
assert published == ["cancelled"]
assert client.get(f"/api/projects/{project['id']}").json()["status"] == "pending"
def test_retry_task_creates_fresh_parse_task(client, db_session, monkeypatch):
project = client.post("/api/projects", json={
"name": "Retryable",
"video_path": "uploads/2/clip.mp4",
"source_type": "video",
"status": "error",
}).json()
task = ProcessingTask(
task_type="parse_video",
status="failed",
progress=100,
message="解析失败",
error="ffmpeg failed",
project_id=project["id"],
payload={"source_type": "video"},
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
class FakeAsyncResult:
id = "celery-retry"
queued = []
published = []
monkeypatch.setattr("routers.tasks.parse_project_media.delay", lambda task_id: queued.append(task_id) or FakeAsyncResult())
monkeypatch.setattr("routers.tasks.publish_task_progress_event", lambda event_task: published.append((event_task.id, event_task.status)))
response = client.post(f"/api/tasks/{task.id}/retry")
assert response.status_code == 202
body = response.json()
assert body["id"] != task.id
assert body["status"] == "queued"
assert body["progress"] == 0
assert body["celery_task_id"] == "celery-retry"
assert body["payload"]["retry_of"] == task.id
assert queued == [body["id"]]
assert published[0] == (body["id"], "queued")
assert published[-1] == (body["id"], "queued")
assert client.get(f"/api/projects/{project['id']}").json()["status"] == "parsing"
def test_task_actions_reject_invalid_states(client, db_session):
project = client.post("/api/projects", json={
"name": "Done",
"video_path": "uploads/3/clip.mp4",
}).json()
task = ProcessingTask(
task_type="parse_video",
status="success",
progress=100,
project_id=project["id"],
payload={"source_type": "video"},
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
assert client.post(f"/api/tasks/{task.id}/cancel").status_code == 409
assert client.post(f"/api/tasks/{task.id}/retry").status_code == 409

View File

@@ -38,21 +38,21 @@ Word 方案描述的理想系统包含:
| 视频拆帧 | 已落地 | `backend/services/frame_parser.py``backend/routers/media.py` | | 视频拆帧 | 已落地 | `backend/services/frame_parser.py``backend/routers/media.py` |
| DICOM 批量导入 | 部分落地 | 上传和解析存在,项目级体验还需完善 | | DICOM 批量导入 | 部分落地 | 上传和解析存在,项目级体验还需完善 |
| WebSocket 进度 | 已落地 | 拆帧进度写入任务表后发布到 Redis `seg:progress`FastAPI 广播到 `/ws/progress` | | WebSocket 进度 | 已落地 | 拆帧进度写入任务表后发布到 Redis `seg:progress`FastAPI 广播到 `/ws/progress` |
| SAM 推理 | 部分落地 | 后端已有 SAM 2 / SAM 3 选择和真实模型状态接口SAM 3 依赖官方运行环境,当前环境不满足时会标为不可用 | | SAM 推理 | 部分落地 | 后端已有 SAM 2 / SAM 3 选择和真实模型状态接口SAM 3 通过独立 Python 3.12 环境桥接,状态会检查 Python/CUDA/包/HF gated 权重访问 |
| 模板库 | 部分落地 | 分类、颜色、z-index 能存储和编辑;重叠裁决算法未落地 | | 模板库 | 部分落地 | 分类、颜色、z-index 能存储和编辑;PNG mask 导出时会按 zIndex 做语义融合裁决,前端预览裁决尚未落地 |
| 标注持久化 | 部分落地 | 后端有 `Annotation` 表,前端已接入新增、回显、分类更新当前帧删除;逐点几何编辑未落地 | | 标注持久化 | 部分落地 | 后端有 `Annotation` 表,前端已接入新增、回显、分类更新当前帧删除、手工绘制、GT mask 导入、seed point 编辑、polygon 顶点拖动/删除、边中点插点和多 polygon 子区域编辑;复杂洞结构编辑未落地 |
| COCO / Mask 导出 | 部分落地 | `backend/routers/export.py`COCO JSON 前端按钮已接入,PNG mask ZIP 尚未提供前端按钮 | | COCO / Mask 导出 | 已落地基础能力 | `backend/routers/export.py`COCO JSON PNG mask ZIP 前端按钮均已接入ZIP 包含单标注 mask、语义融合 mask 和类别映射 |
## 当前代码尚未落地的目标 ## 当前代码尚未落地的目标
- SAM 3当前已提供 `sam3_engine.py` 适配入口和状态检测;要实际运行仍需安装官方 `facebookresearch/sam3` 依赖并满足 Python 3.12+、PyTorch 2.7+、CUDA 12.6+ - SAM 3当前已提供 `sam3_engine.py` 外部环境桥接、`sam3_external_worker.py``setup_sam3_env.sh`;本机 `sam3` 环境已满足 Python 3.12、PyTorch 2.10/cu128、CUDA/GPU 和官方包导入。真实推理仍取决于 Hugging Face `facebook/sam3` gated 权重访问授权;官方没有 SAM 3 tiny/small 权重,当前可选最小真实 SAM 权重仍是 SAM 2 tiny
- Celery 异步任务队列:已注册 Celery app 和拆帧 worker task`/api/media/parse` 会创建任务表记录并入队。 - Celery 异步任务队列:已注册 Celery app 和拆帧 worker task`/api/media/parse` 会创建任务表记录并入队。
- GT mask 导入:当前前端没有 GT Label 导入入口,后端也没有对应路由 - GT mask 导入:当前已支持二值/多类别 mask 导入,后端会按非零像素值拆分区域,生成 polygon 标注和距离变换 seed point骨架提取、HDBSCAN 和模板自动映射尚未实现
- Mask 到点区域的拓扑降维:当前没有距离变换、骨架提取、HDBSCAN 等实现。 - Mask 到点区域的拓扑降维:当前完成 distance transform seed point 和前端 seed point 拖拽编辑;骨架提取、HDBSCAN 等增强尚未实现。
- 类别优先级融合:模板有 z-index但没有后端融合算法 - 类别优先级融合:PNG mask 导出时已按 zIndex 生成语义融合 mask前端裁决预览尚未实现
- 撤销/重做:工具栏有按钮,但没有历史栈。 - 撤销/重做:当前已有全局 mask 历史栈。
- 结构化归档保存:工作区按钮已调用 `POST /api/ai/annotate` 保存当前未归档 mask并通过 `PATCH /api/ai/annotations/{id}` 更新 dirty mask。 - 结构化归档保存:工作区按钮已调用 `POST /api/ai/annotate` 保存当前未归档 mask并通过 `PATCH /api/ai/annotations/{id}` 更新 dirty mask。
## 结论 ## 结论
当前项目已经从 UI 原型推进到“可上传、可异步拆帧、可实时查看任务进度、可浏览项目帧、可维护模板、可点/框 AI 推理、可保存标注、可导出 COCO、可查看 Dashboard 后端概览”的全栈雏形,但离 Word 中描述的完整智能标注系统还有明显差距。下一阶段最重要的是继续补齐手工绘制、撤销重做和真实语义文本分割 当前项目已经从 UI 原型推进到“可上传、可异步拆帧、可取消/重试任务、可查看失败详情、可实时查看任务进度、可浏览项目帧、可维护模板、可手工绘制、可逐点编辑 polygon、可边中点插点、可多 polygon 子区域编辑、可区域合并/去除、可点/框 AI 推理、可对点/框 prompt 做裁剪推理和背景过滤、可导入多类别 GT mask、可编辑 seed point、可保存标注、可导出 COCO/语义 mask ZIP、可查看 Dashboard 后端概览”的全栈雏形。下一阶段最重要的是继续补齐 Hugging Face SAM 3 权重授权后的真实语义文本分割 smoke test、复杂洞结构编辑和 GT mask 骨架/聚类增强

View File

@@ -71,6 +71,7 @@
5. Celery worker 下载 MinIO 文件,调用 `frame_parser.py` 拆帧。 5. Celery worker 下载 MinIO 文件,调用 `frame_parser.py` 拆帧。
6. worker 把拆出的帧重新上传 MinIO写入 `frames` 表,并更新任务状态。 6. worker 把拆出的帧重新上传 MinIO写入 `frames` 表,并更新任务状态。
7. 工作区通过 `GET /api/tasks/{id}` 等待任务完成,再通过 `GET /api/projects/{id}/frames` 获取预签名图片 URL。 7. 工作区通过 `GET /api/tasks/{id}` 等待任务完成,再通过 `GET /api/projects/{id}/frames` 获取预签名图片 URL。
8. Dashboard 可通过 `POST /api/tasks/{id}/cancel` 取消 queued/running 任务,通过 `POST /api/tasks/{id}/retry` 重试 failed/cancelled 任务,并用 `GET /api/tasks/{id}` 查看失败详情。
### 工作区浏览 ### 工作区浏览
@@ -98,7 +99,7 @@
## 当前主要风险点 ## 当前主要风险点
- 前端 API/WS 地址虽然已支持环境变量和 hostname 推导,但部署时仍需要确认浏览器可访问 `:8000` 后端。 - 前端 API/WS 地址虽然已支持环境变量和 hostname 推导,但部署时仍需要确认浏览器可访问 `:8000` 后端。
- AI 语义文本提示在选择 SAM 3 且运行环境满足官方依赖时走 SAM 3当前环境若不满足会在模型状态中标明不可用 - AI 语义文本提示在选择 SAM 3 且运行环境满足官方依赖、并具备 Hugging Face gated 权重访问时走 SAM 3当前状态接口会分别暴露外部 Python 环境、CUDA、包导入和 checkpoint access 是否满足
- 工作区顶部“导出 JSON 标注集”和“结构化归档保存”已接入导出、标注新增和 dirty 标注更新;清空当前帧遮罩会删除对应后端标注。撤销重做和手工绘制仍未持久化 - 工作区顶部“导出 JSON 标注集”“导出 PNG Mask ZIP”“导入 GT Mask”和“结构化归档保存”已接入导出、GT 多类别导入、seed point 回显/编辑、标注新增和 dirty 标注更新清空当前帧遮罩会删除对应后端标注。手工绘制、polygon 顶点拖动/删除、区域合并/去除和撤销重做已经落到前端 mask 数据结构
- Dashboard 初始统计、队列和活动日志来自后端聚合接口;解析队列来自 `processing_tasks`worker 进度通过 Redis `seg:progress` 转发到 WebSocket。 - Dashboard 初始统计、队列和活动日志来自后端聚合接口;解析队列来自 `processing_tasks`worker 进度通过 Redis `seg:progress` 转发到 WebSocket。任务取消、重试和失败详情已接入前后端。
- 后端路由大多未做真实鉴权。 - 后端路由大多未做真实鉴权。

View File

@@ -30,8 +30,11 @@
| 元素 | 状态 | 说明 | | 元素 | 状态 | 说明 |
|------|------|------| |------|------|------|
| WebSocket 连接状态 | 真实可用 | 前端通过 `src/lib/config.ts` 推导或读取 `VITE_WS_PROGRESS_URL`,后端有 `/ws/progress` | | WebSocket 连接状态 | 真实可用 | 前端通过 `src/lib/config.ts` 推导或读取 `VITE_WS_PROGRESS_URL`,后端有 `/ws/progress` |
| 解析队列任务 | 真实可用 | 初始数据来自 `GET /api/dashboard/overview`,按 `processing_tasks` queued/running 任务生成 | | 解析队列任务 | 真实可用 | 初始数据来自 `GET /api/dashboard/overview`,按 `processing_tasks` queued/running/failed/cancelled 任务生成 |
| WebSocket 更新任务 | 真实可用 | Celery worker 更新 `processing_tasks` 后发布 Redis `seg:progress`FastAPI 广播 progress/complete/error | | 任务取消 | 真实可用 | queued/running 任务显示取消按钮,调用 `POST /api/tasks/{task_id}/cancel` |
| 任务重试 | 真实可用 | failed/cancelled 任务显示重试按钮,调用 `POST /api/tasks/{task_id}/retry` 创建新任务 |
| 失败详情 | 真实可用 | 任务详情按钮调用 `GET /api/tasks/{task_id}`,弹窗展示 error、payload、result、Celery ID 和时间 |
| WebSocket 更新任务 | 真实可用 | Celery worker 更新 `processing_tasks` 后发布 Redis `seg:progress`FastAPI 广播 progress/complete/error/cancelled |
| 项目、任务、标注、系统负载统计 | 真实可用 | 初始数据来自 `GET /api/dashboard/overview`,系统负载按主机 load average 估算 | | 项目、任务、标注、系统负载统计 | 真实可用 | 初始数据来自 `GET /api/dashboard/overview`,系统负载按主机 load average 估算 |
| 近期实时流转记录 | 真实可用 | 初始数据来自任务、项目、标注和模板记录WebSocket status/complete/error 会继续追加 | | 近期实时流转记录 | 真实可用 | 初始数据来自任务、项目、标注和模板记录WebSocket status/complete/error 会继续追加 |
@@ -60,6 +63,8 @@
| SAM 模型状态徽标 | 真实可用 | 调用 `GET /api/ai/models/status`,显示当前选择的 SAM 2/SAM 3 是否可用 | | SAM 模型状态徽标 | 真实可用 | 调用 `GET /api/ai/models/status`,显示当前选择的 SAM 2/SAM 3 是否可用 |
| 已保存标注回显 | 真实可用 | 加载工作区帧后调用 `GET /api/ai/annotations` 并渲染已保存 mask | | 已保存标注回显 | 真实可用 | 加载工作区帧后调用 `GET /api/ai/annotations` 并渲染已保存 mask |
| “导出 JSON 标注集”按钮 | 真实可用 | 导出前会保存未归档 mask然后调用 `exportCoco()` 下载 JSON | | “导出 JSON 标注集”按钮 | 真实可用 | 导出前会保存未归档 mask然后调用 `exportCoco()` 下载 JSON |
| “导出 PNG Mask ZIP”按钮 | 真实可用 | 导出前会保存未归档 mask然后调用 `GET /api/export/{project_id}/masks` 下载 ZIP后端同时包含单标注 mask、每帧语义融合 mask 和 `semantic_classes.json` |
| “导入 GT Mask”按钮 | 真实可用 | 选择图片后调用 `POST /api/ai/import-gt-mask`,后端按非零像素值和连通域生成 polygon 标注与距离变换 seed point再回显到工作区 |
| “结构化归档保存”按钮 | 真实可用 | 未保存 mask 写入 `POST /api/ai/annotate`dirty mask 写入 `PATCH /api/ai/annotations/{id}` | | “结构化归档保存”按钮 | 真实可用 | 未保存 mask 写入 `POST /api/ai/annotate`dirty mask 写入 `PATCH /api/ai/annotations/{id}` |
## CanvasArea 画布 ## CanvasArea 画布
@@ -70,10 +75,13 @@
| 滚轮缩放 | 真实可用 | 改变 Konva Stage scale | | 滚轮缩放 | 真实可用 | 改变 Konva Stage scale |
| 拖拽平移 | 真实可用 | activeTool 为 `move` 时 Stage draggable | | 拖拽平移 | 真实可用 | activeTool 为 `move` 时 Stage draggable |
| 光标坐标显示 | 真实可用 | 根据 pointer position 计算 | | 光标坐标显示 | 真实可用 | 根据 pointer position 计算 |
| 正向/反向选点 | 部分可用 | UI 能加点,并按当前帧 `frame.id` 调用 `/api/ai/predict`;需点击归档保存才持久化 | | 正向/反向选点 | 真实可用 | UI 能加点,并按当前帧 `frame.id` 调用 `/api/ai/predict`结果需点击归档保存才持久化 |
| 框选 | 部分可用 | UI 能画框,并把框坐标归一化后调用后端推理;需点击归档保存才持久化 | | 框选 | 真实可用 | UI 能画框,并把框坐标归一化后调用后端推理;结果需点击归档保存才持久化 |
| AI 推理中提示 | 真实可用 | 请求期间会显示 | | AI 推理中提示 | 真实可用 | 请求期间会显示 |
| Mask 渲染 | 部分可用 | 前端会把推理/已保存标注转成 Konva `pathData` 渲染 | | 手工多边形/矩形/圆/点/线 | 真实可用 | 多边形点击取点后 Enter 完成;矩形/圆/线拖拽生成 polygon点工具生成小区域均写入 `Mask.segmentation`,可归档保存 |
| Mask 渲染 | 真实可用 | 前端会把推理、手工绘制、GT 导入和已保存标注转成 Konva `pathData` 渲染 |
| Polygon 逐点编辑 | 真实可用 | 点击 mask 后显示 polygon 顶点;拖动顶点会重算 `pathData/segmentation/bbox/area`,已保存 mask 标为 dirty选中顶点后 Delete/Backspace 可删点但保留至少三点 |
| GT seed point 回显/编辑 | 真实可用 | 已保存标注的 `points` 会显示为黄色 seed 点;拖动后标记为 dirty归档保存会更新后端 |
| 应用分类 | 真实可用 | 将当前选择的模板分类应用到本帧 mask已保存 mask 会标为 dirty归档保存时更新后端 | | 应用分类 | 真实可用 | 将当前选择的模板分类应用到本帧 mask已保存 mask 会标为 dirty归档保存时更新后端 |
| 清空遮罩 | 真实可用 | 工作区中会删除当前帧已保存标注并清空当前帧本地 mask | | 清空遮罩 | 真实可用 | 工作区中会删除当前帧已保存标注并清空当前帧本地 mask |
| 保存状态计数 | 真实可用 | 底部显示已保存、未保存、待更新数量 | | 保存状态计数 | 真实可用 | 底部显示已保存、未保存、待更新数量 |
@@ -84,11 +92,11 @@
| 元素 | 状态 | 说明 | | 元素 | 状态 | 说明 |
|------|------|------| |------|------|------|
| 拖拽/选择 | 真实可用 | 控制 Canvas 是否可拖拽 | | 拖拽/选择 | 真实可用 | 控制 Canvas 是否可拖拽 |
| 多边形/矩形/圆/点/线 | Mock / UI-only | 切换 activeTool,没有对应绘制逻辑 | | 多边形/矩形/圆/点/线 | 真实可用 | 切换 activeTool 后由 `CanvasArea` 生成可保存的 polygon mask |
| 区域合并/去除 | Mock / UI-only | 只切换 activeTool没有后端或前端算法 | | 区域合并/去除 | 真实可用 | 选择工具后点击多个 mask使用 `polygon-clipping` 做 union / difference合并会保留主 mask 并移除被合并 mask去除会从主 mask 扣除后续选中 mask |
| 正向选点/反向选点/框选 | 部分可用 | 会影响 Canvas 交互,并能触发已对齐的 AI 推理接口 | | 正向选点/反向选点/框选 | 部分可用 | 会影响 Canvas 交互,并能触发已对齐的 AI 推理接口 |
| 魔法棒 SAM 触发 | 部分可用 | 切到 AI 页面;不是直接执行推理 | | 魔法棒 SAM 触发 | 部分可用 | 切到 AI 页面;不是直接执行推理 |
| 撤销/重做 | Mock / UI-only | 按钮无事件 | | 撤销/重做 | 真实可用 | 绑定 Zustand `maskHistory/maskFuture`支持工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y |
## FrameTimeline 时间轴 ## FrameTimeline 时间轴
@@ -117,10 +125,11 @@
|------|------|------| |------|------|------|
| 模型选择 SAM2/SAM3 | 真实可用 | 选择写入 Zustand`predictMask()` 会把 `model` 传给后端 SAM registry | | 模型选择 SAM2/SAM3 | 真实可用 | 选择写入 Zustand`predictMask()` 会把 `model` 传给后端 SAM registry |
| 正向/反向点 | 部分可用 | 可在当前项目帧上加点,并可调用 AI 推理接口 | | 正向/反向点 | 部分可用 | 可在当前项目帧上加点,并可调用 AI 推理接口 |
| 语义文本输入 | 部分可用 | 纯文本会以 `semantic` prompt 调用后端;选择 SAM 3 且运行环境满足官方依赖时走 SAM 3 文本语义推理,否则状态接口会标明不可用 | | 语义文本输入 | 部分可用 | 纯文本会以 `semantic` prompt 调用后端;选择 SAM 3 且独立 Python 3.12 环境、CUDA、官方包和 Hugging Face gated 权重访问均满足时走 SAM 3 文本语义推理,否则状态接口会标明不可用 |
| 参数开关 | Mock / UI-only | `cropMode``autoDeleteBg` 只改本地状态 | | 参数开关 | 真实可用 | `cropMode` 会随 `/api/ai/predict` 发送 `crop_to_prompt`,后端对点/框 prompt 裁剪推理区域并回映射 polygon`autoDeleteBg` 会发送 `auto_filter_background``min_score`,后端过滤低分结果和覆盖负向点的结果 |
| 执行高精度语义分割 | 部分可用 | 使用当前项目帧调用 `/api/ai/predict`;没有当前帧时按钮禁用 | | 执行高精度语义分割 | 部分可用 | 使用当前项目帧调用 `/api/ai/predict`;没有当前帧时按钮禁用 |
| 上传替换底图 | Mock / UI-only | 按钮无事件 | | 上传替换底图 | Mock / UI-only | 按钮无事件 |
| 撤销/重做 | 真实可用 | 绑定全局 mask 历史栈 |
| 清空全体锚点 | 部分可用 | 清空前端 points 和 masks | | 清空全体锚点 | 部分可用 | 清空前端 points 和 masks |
| 退档推送至工作区重组 | 部分可用 | 只切回工作区,共用 masks store但没有保存/确认流程 | | 退档推送至工作区重组 | 部分可用 | 只切回工作区,共用 masks store但没有保存/确认流程 |
| 背景图 | 部分可用 | 优先显示当前项目帧;没有项目帧时仍回退到 Unsplash 演示图 | | 背景图 | 部分可用 | 优先显示当前项目帧;没有项目帧时仍回退到 Unsplash 演示图 |
@@ -141,6 +150,6 @@
## 总体结论 ## 总体结论
当前前端真实可用的主链路是登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、拆帧、浏览帧、播放帧、工作区点/框 AI 推理、标注保存/回显、COCO 导出、模板 CRUD。 当前前端真实可用的主链路是登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、拆帧、浏览帧、播放帧、工作区手工绘制、点/框 AI 推理、GT mask 导入、标注保存/回显、COCO 导出、PNG mask ZIP 导出、模板 CRUD。
当前最主要的 Mock 或未打通链路是:撤销重做、手工几何绘制、GT 导入、mask 降维点区域、真正的文本语义分割和语义优先级融合 当前最主要的 Mock 或未打通链路是:polygon 插点/边编辑增强、真正的文本语义分割、骨架/HDBSCAN 级别的 mask 降维增强、任务历史筛选、项目更多菜单和若干检查面板指标

View File

@@ -34,6 +34,8 @@ Authorization: Bearer <token>
| `uploadDicomBatch(files, projectId)` | `POST /api/media/upload/dicom` | 对齐 | multipart form-data | | `uploadDicomBatch(files, projectId)` | `POST /api/media/upload/dicom` | 对齐 | multipart form-data |
| `parseMedia(projectId)` | `POST /api/media/parse?project_id=...` | 对齐 | 创建异步拆帧任务并返回 task | | `parseMedia(projectId)` | `POST /api/media/parse?project_id=...` | 对齐 | 创建异步拆帧任务并返回 task |
| `getTask(taskId)` | `GET /api/tasks/{task_id}` | 对齐 | 查询异步任务状态 | | `getTask(taskId)` | `GET /api/tasks/{task_id}` | 对齐 | 查询异步任务状态 |
| `cancelTask(taskId)` | `POST /api/tasks/{task_id}/cancel` | 对齐 | 取消 queued/running 任务,后端写 cancelled 并尝试 revoke Celery |
| `retryTask(taskId)` | `POST /api/tasks/{task_id}/retry` | 对齐 | 对 failed/cancelled 任务创建新的 queued 重试任务 |
| `getProjectFrames(projectId)` | `GET /api/projects/{id}/frames` | 对齐 | 后端返回预签名 image_url | | `getProjectFrames(projectId)` | `GET /api/projects/{id}/frames` | 对齐 | 后端返回预签名 image_url |
| `predictMask(payload)` | `POST /api/ai/predict` | 对齐 | 前端发送 `image_id/prompt_type/prompt_data/model`,并把后端 `polygons` 转为 `masks[].pathData` | | `predictMask(payload)` | `POST /api/ai/predict` | 对齐 | 前端发送 `image_id/prompt_type/prompt_data/model`,并把后端 `polygons` 转为 `masks[].pathData` |
| `getAiModelStatus(selectedModel?)` | `GET /api/ai/models/status` | 对齐 | 返回 GPU、SAM 2、SAM 3 的真实运行状态 | | `getAiModelStatus(selectedModel?)` | `GET /api/ai/models/status` | 对齐 | 返回 GPU、SAM 2、SAM 3 的真实运行状态 |
@@ -41,8 +43,10 @@ Authorization: Bearer <token>
| `saveAnnotation(payload)` | `POST /api/ai/annotate` | 对齐 | 工作区归档保存当前项目未保存 mask | | `saveAnnotation(payload)` | `POST /api/ai/annotate` | 对齐 | 工作区归档保存当前项目未保存 mask |
| `updateAnnotation(annotationId, payload)` | `PATCH /api/ai/annotations/{annotation_id}` | 对齐 | 工作区归档保存 dirty mask | | `updateAnnotation(annotationId, payload)` | `PATCH /api/ai/annotations/{annotation_id}` | 对齐 | 工作区归档保存 dirty mask |
| `deleteAnnotation(annotationId)` | `DELETE /api/ai/annotations/{annotation_id}` | 对齐 | 工作区清空当前帧已保存标注 | | `deleteAnnotation(annotationId)` | `DELETE /api/ai/annotations/{annotation_id}` | 对齐 | 工作区清空当前帧已保存标注 |
| `importGtMask(file, projectId, frameId, templateId?)` | `POST /api/ai/import-gt-mask` | 对齐 | multipart 上传 GT mask后端按非零像素值/连通域生成 polygon 标注和 seed point |
| `getDashboardOverview()` | `GET /api/dashboard/overview` | 对齐 | Dashboard 初始统计、队列和活动日志 | | `getDashboardOverview()` | `GET /api/dashboard/overview` | 对齐 | Dashboard 初始统计、队列和活动日志 |
| `exportCoco(projectId)` | `GET /api/export/{projectId}/coco` | 对齐 | 后端实际是 `GET /api/export/{project_id}/coco` | | `exportCoco(projectId)` | `GET /api/export/{projectId}/coco` | 对齐 | 后端实际是 `GET /api/export/{project_id}/coco` |
| `exportMasks(projectId)` | `GET /api/export/{projectId}/masks` | 对齐 | 下载单标注 mask、语义融合 mask 和类别映射 ZIP |
## 后端 FastAPI 接口 ## 后端 FastAPI 接口
@@ -69,10 +73,13 @@ Authorization: Bearer <token>
| POST | `/api/media/parse` | 创建 Celery 拆帧任务 | | POST | `/api/media/parse` | 创建 Celery 拆帧任务 |
| GET | `/api/tasks` | 查询后台任务列表 | | GET | `/api/tasks` | 查询后台任务列表 |
| GET | `/api/tasks/{task_id}` | 查询单个后台任务 | | GET | `/api/tasks/{task_id}` | 查询单个后台任务 |
| POST | `/api/tasks/{task_id}/cancel` | 取消后台任务 |
| POST | `/api/tasks/{task_id}/retry` | 重试失败或取消的后台任务 |
| POST | `/api/ai/predict` | SAM 2 / SAM 3 可选推理 | | POST | `/api/ai/predict` | SAM 2 / SAM 3 可选推理 |
| GET | `/api/ai/models/status` | GPU 和 SAM 模型状态 | | GET | `/api/ai/models/status` | GPU 和 SAM 模型状态 |
| POST | `/api/ai/auto` | 自动分割 | | POST | `/api/ai/auto` | 自动分割 |
| POST | `/api/ai/annotate` | 保存 AI 标注 | | POST | `/api/ai/annotate` | 保存 AI 标注 |
| POST | `/api/ai/import-gt-mask` | 导入 GT mask 并生成标注/seed point |
| GET | `/api/ai/annotations` | 查询项目标注,可选按帧过滤 | | GET | `/api/ai/annotations` | 查询项目标注,可选按帧过滤 |
| PATCH | `/api/ai/annotations/{annotation_id}` | 更新已保存标注 | | PATCH | `/api/ai/annotations/{annotation_id}` | 更新已保存标注 |
| DELETE | `/api/ai/annotations/{annotation_id}` | 删除已保存标注 | | DELETE | `/api/ai/annotations/{annotation_id}` | 删除已保存标注 |
@@ -143,7 +150,13 @@ Authorization: Bearer <token>
- `point` - `point`
- `box` - `box`
- `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理;选择 `sam2` 时仍回退到 auto segmentation - `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理;选择 `sam2` 时仍回退到 auto segmentation。SAM 3 真实可用性由 `/api/ai/models/status` 中的外部环境和 checkpoint access 状态决定。
可选 `options` 字段:
- `crop_to_prompt`:对 point/box prompt 按锚点或框附近区域裁剪后推理,再把 polygon 回映射到原图坐标。
- `auto_filter_background`:过滤低分结果,并移除包含负向点的 polygon。
- `min_score`:配合 `auto_filter_background` 使用的最低置信度阈值。
后端响应: 后端响应:
@@ -181,13 +194,19 @@ Authorization: Bearer <token>
- `getProjectAnnotations()` 已接入 `GET /api/ai/annotations` - `getProjectAnnotations()` 已接入 `GET /api/ai/annotations`
- `updateAnnotation()` 已接入 `PATCH /api/ai/annotations/{annotationId}` - `updateAnnotation()` 已接入 `PATCH /api/ai/annotations/{annotationId}`
- `deleteAnnotation()` 已接入 `DELETE /api/ai/annotations/{annotationId}` - `deleteAnnotation()` 已接入 `DELETE /api/ai/annotations/{annotationId}`
- `importGtMask()` 已接入 `POST /api/ai/import-gt-mask`,导入后端生成的 polygon 标注、原始 `gt_label_value` 和 seed point。
- `exportMasks()` 已接入 `GET /api/export/{projectId}/masks`
- `parseMedia()` 已改为创建 Celery 后台任务,并返回 `ProcessingTask` - `parseMedia()` 已改为创建 Celery 后台任务,并返回 `ProcessingTask`
- `getTask()` 已接入 `GET /api/tasks/{taskId}` - `getTask()` 已接入 `GET /api/tasks/{taskId}`
- `cancelTask()` 已接入 `POST /api/tasks/{taskId}/cancel`
- `retryTask()` 已接入 `POST /api/tasks/{taskId}/retry`
- `getDashboardOverview()` 已从 `processing_tasks` 聚合解析队列。 - `getDashboardOverview()` 已从 `processing_tasks` 聚合解析队列。
- 工作区导出按钮已调用 `exportCoco()`,并会先保存未归档 mask - Dashboard 任务列表已展示 queued/running/failed/cancelled 任务,并可通过 `getTask()` 查看失败详情
- 工作区导出按钮已调用 `exportCoco()` / `exportMasks()`,并会先保存未归档 mask。
- PNG mask ZIP 已包含每帧 `semantic_frame_*.png``semantic_classes.json`,重叠区域按 zIndex 裁决。
## 仍需处理的接口问题 ## 仍需处理的接口问题
- WebSocket 地址已从 `VITE_WS_PROGRESS_URL` 读取,未配置时从 `API_BASE_URL` 推导;部署时仍要确认浏览器能访问该地址。 - WebSocket 地址已从 `VITE_WS_PROGRESS_URL` 读取,未配置时从 `API_BASE_URL` 推导;部署时仍要确认浏览器能访问该地址。
- Celery worker 进度会写 PostgreSQL 任务表,同时发布到 Redis `seg:progress`FastAPI 订阅后广播到 `/ws/progress` - Celery worker 进度会写 PostgreSQL 任务表,同时发布到 Redis `seg:progress`FastAPI 订阅后广播到 `/ws/progress`
- 已保存标注目前支持分类级更新和整帧清空删除;逐点几何编辑器尚未实现。 - 已保存标注目前支持分类级更新、polygon 顶点拖动、顶点删除、边中点插入、多 polygon 子区域选择编辑后的 PATCH 更新和整帧清空删除;复杂洞结构的专业编辑仍未实现。

View File

@@ -16,8 +16,8 @@
剩余边界: 剩余边界:
1. SAM 3 真实推理需要独立满足官方 Python 3.12+、PyTorch 2.7+、CUDA 12.6+ 环境 1. SAM 3 已完成独立 Python 3.12 环境安装脚本、外部 worker 桥接和状态检查;真实推理还需要 Hugging Face `facebook/sam3` gated 权重授权通过后执行 smoke test
2. 标注删除/更新接口已打通基础能力;逐点几何编辑器尚未实现 2. 标注删除/更新接口已打通基础能力;逐点几何编辑器已支持顶点拖动/删除、边中点插入和多 polygon 子区域选择编辑,复杂洞结构仍待增强
## 阶段 2打通标注保存已完成基础闭环 ## 阶段 2打通标注保存已完成基础闭环
@@ -34,16 +34,22 @@
剩余建议: 剩余建议:
1. 加入保存冲突处理和批量保存错误提示。 1. 加入保存冲突处理和批量保存错误提示。
2. 增加逐点几何编辑器,让已保存 mask 的 polygon 本身可以被修改后 PATCH 2. 逐点几何编辑器已支持拖动/删除顶点、边中点插入新点和多 polygon 子区域编辑;后续增强为复杂洞结构编辑
3. 区域合并/去除已支持基础 union/difference后续增强为更明确的多选列表、操作预览和冲突确认。
## 阶段 3接入导出按钮已完成 COCO JSON ## 阶段 3接入导出按钮已完成 COCO JSON 和 PNG Mask ZIP
当前工作区“导出 JSON 标注集”会先保存未归档 mask再调用 COCO 导出接口。 当前工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”都会先保存未归档 mask再调用后端导出接口。
建议 已完成
1. 增加“导出 PNG Mask ZIP”按钮调用 `/api/export/{projectId}/masks` 1. COCO JSON 调用 `/api/export/{projectId}/coco`
2. 无标注时给出更明确的空导出提示 2. PNG Mask ZIP 调用 `/api/export/{projectId}/masks`
3. ZIP 内保留单标注二值 `mask_*.png`,同时输出 `semantic_frame_*.png``semantic_classes.json`
剩余建议:
1. 无标注时给出更明确的空导出提示。
## 阶段 4替换 Dashboard mock ## 阶段 4替换 Dashboard mock
@@ -52,13 +58,18 @@
已完成: 已完成:
- 聚合项目、帧、标注、模板数量和主机 load average。 - 聚合项目、帧、标注、模板数量和主机 load average。
-`processing_tasks` queued/running 任务生成解析队列。 -`processing_tasks` queued/running/failed/cancelled 任务生成解析队列。
- 按最近任务、项目、标注、模板记录生成活动流。 - 按最近任务、项目、标注、模板记录生成活动流。
已完成补充:
1. Dashboard 对 queued/running 任务提供取消按钮。
2. Dashboard 对 failed/cancelled 任务提供重试按钮。
3. Dashboard 详情弹窗展示任务 error、payload、result、Celery ID 和时间。
剩余建议: 剩余建议:
1.任务增加取消、重试和失败详情 UI 1. Dashboard 增加任务历史筛选
2. 为 Dashboard 增加任务历史筛选和失败详情入口。
## 阶段 5异步拆帧和进度 ## 阶段 5异步拆帧和进度
@@ -72,44 +83,55 @@ Word 方案中提到 Celery + Redis。当前已经有 Celery app、worker task
4. worker 写 PostgreSQL 任务进度。 4. worker 写 PostgreSQL 任务进度。
5. worker 发布 Redis `seg:progress`FastAPI 广播到 `/ws/progress` 5. worker 发布 Redis `seg:progress`FastAPI 广播到 `/ws/progress`
已完成补充:
1. `POST /api/tasks/{task_id}/cancel` 取消 queued/running 任务,并尝试 revoke Celery。
2. `POST /api/tasks/{task_id}/retry` 为 failed/cancelled 任务创建新的 queued 任务。
3. worker 在关键阶段检查 cancelled 状态,避免取消后继续写帧。
4. Redis/WebSocket 进度事件增加 `cancelled` 类型。
Dashboard 的解析队列现在已经从“项目状态派生”升级为任务表驱动,实时推送也已通过 Redis/WebSocket 打通;剩余重点是任务历史筛选和更细的 worker 中断粒度。
## 阶段 6GT 导入与点区域(已完成基础增强版)
Word 方案中的完整版本包含距离变换、骨架提取和聚类。当前已经完成基础增强版:导入二值/标签 mask 图片后,后端按非零像素值拆分类别,再按连通域生成 polygon 标注,并用距离变换提取一个正向 seed point。
已完成:
1. 工作区提供“导入 GT Mask”入口。
2. 前端调用 `POST /api/ai/import-gt-mask` multipart 接口。
3. 后端按非零像素值拆分多类别 mask。
4. 后端使用 OpenCV contour 提取每个类别下的连通域。
5. 后端使用 distance transform 生成 `points` seed。
6. 导入结果写入 `annotations` 表并回显为工作区 mask。
7. 前端把 seed point 转为像素坐标显示在 Canvas 上,拖动后会标记标注为 dirty 并可归档保存。
剩余建议: 剩余建议:
1. 为任务增加取消、重试和失败详情接口 1. 增加骨架提取和聚类增强
2. 前端 Dashboard 保留轮询兜底,并补充失败详情 UI 2. 为多类别像素值提供模板分类自动映射规则
Dashboard 的解析队列现在已经从“项目状态派生”升级为任务表驱动,实时推送也已通过 Redis/WebSocket 打通;剩余重点是任务控制。 ## 阶段 7模板优先级融合已完成导出侧裁决
## 阶段 6GT 导入与点区域 当前导出 PNG Mask ZIP 时已经按 class/template z-index 做重叠裁决,从低到高覆盖,生成每帧 `semantic_frame_*.png`
这是 Word 方案中最复杂的部分,当前完全未实现。 已完成:
建议拆成小步:
1. 先支持上传二值/多类别 mask。
2. 后端按类别提取 connected components。
3. 用 OpenCV distance transform 找正向点。
4. 暂时不做骨架/HDBSCAN先生成最小可用点集。
5. 前端以可拖拽点显示并保存。
6. 后续再做骨架和聚类增强。
## 阶段 7模板优先级融合
当前模板有 z-index但没有真正用于语义冲突裁决。
建议:
1. 标注保存时记录 template class id / name / zIndex。 1. 标注保存时记录 template class id / name / zIndex。
2. 导出 mask 时按 zIndex 从低到高覆盖。 2. 导出 mask 时按 zIndex 从低到高覆盖。
3. 同类 mask 做 union 3. 同类语义值在融合图中共享同一个 class value
4. 跨类重叠由高 zIndex 覆盖低 zIndex。 4. 跨类重叠由高 zIndex 覆盖低 zIndex。
这一步完成后,系统才真正符合“语义分割一个像素一个类别”的目标。 剩余建议:
1. 在前端预览重叠裁决结果。
2. 对多帧多类导出增加颜色 palette PNG 或可视化 legend。
## 阶段 8清理 UI 文案与 Mock ## 阶段 8清理 UI 文案与 Mock
建议统一这些文案和真实能力: 建议统一这些文案和真实能力:
- SAM/GPU 状态已改为 `GET /api/ai/models/status` 驱动。 - SAM/GPU 状态已改为 `GET /api/ai/models/status` 驱动。
- 撤销/重做按钮接历史栈,否则隐藏 - 撤销/重做按钮已接全局 mask 历史栈
- “重新提取内侧中轴树骨架”接真实接口,否则标为未实现。 - “重新提取内侧中轴树骨架”接真实接口,否则标为未实现。
- AI 独立页不要固定 Unsplash 图,应从当前项目帧或上传文件进入。 - AI 独立页不要固定 Unsplash 图,应从当前项目帧或上传文件进入。

View File

@@ -31,6 +31,9 @@
- 拆帧完成后写入 `frames` 记录,并把项目状态设为 `ready` - 拆帧完成后写入 `frames` 记录,并把项目状态设为 `ready`
- 拆帧接口会创建 `processing_tasks` 记录并投递 Celery worker。 - 拆帧接口会创建 `processing_tasks` 记录并投递 Celery worker。
- 前端可通过 `GET /api/tasks/{task_id}` 查询任务状态。 - 前端可通过 `GET /api/tasks/{task_id}` 查询任务状态。
- 后端支持 `POST /api/tasks/{task_id}/cancel` 取消 queued/running 任务,写入 `cancelled` 状态并尝试 revoke Celery。
- 后端支持 `POST /api/tasks/{task_id}/retry` 对 failed/cancelled 任务创建新的 queued 任务。
- worker 会在关键阶段检查任务是否已取消,取消后停止继续写帧。
## R4 工作区与帧浏览 ## R4 工作区与帧浏览
@@ -46,7 +49,13 @@
- 工具栏可以切换当前 active tool。 - 工具栏可以切换当前 active tool。
- 正向点、反向点、框选工具会影响 Canvas 交互。 - 正向点、反向点、框选工具会影响 Canvas 交互。
- 魔法棒按钮切换到 AI 页面。 - 魔法棒按钮切换到 AI 页面。
- 多边形、矩形、圆、点、线、合并、去除、撤销、重做当前只提供 UI 状态或占位按钮,不完成真实绘制/算法 - 多边形、矩形、圆、点、线工具会在 Canvas 上生成可保存的 polygon mask
- 多边形通过点击取点并按 Enter 完成;矩形、圆、线通过拖拽生成;点工具生成小点区域。
- Canvas 支持点击 mask 进入 polygon 顶点编辑态;拖动顶点会更新 mask 几何并把已保存 mask 标记为 dirty。
- 顶点编辑态下选中顶点后可用 Delete/Backspace 删除顶点,但不会让 polygon 少于三点。
- 撤销、重做绑定全局 `maskHistory/maskFuture`支持工具栏按钮、AI 页按钮和 Canvas 快捷键。
- 区域合并工具支持多选当前帧 mask并使用 polygon union 生成合并后的主 mask。
- 区域去除工具支持多选当前帧 mask并从第一个选中的主 mask 中扣除后续选中 mask。
## R6 AI 推理 ## R6 AI 推理
@@ -56,7 +65,8 @@
- 前端发送后端契约:`image_id``prompt_type``prompt_data``model` - 前端发送后端契约:`image_id``prompt_type``prompt_data``model`
- 点提示传 `{ points, labels }`,正向点 label 为 1反向点 label 为 0。 - 点提示传 `{ points, labels }`,正向点 label 为 1反向点 label 为 0。
- 框选提示传归一化 `[x1, y1, x2, y2]` - 框选提示传归一化 `[x1, y1, x2, y2]`
- 语义文本提示传 `semantic`;选择 `sam3`环境满足依赖时走 SAM 3 文本语义推理,选择 `sam2` 时回退到自动分割。 - 语义文本提示传 `semantic`;选择 `sam3`独立 Python 3.12 环境、CUDA、官方包和 Hugging Face gated 权重访问均满足时走 SAM 3 文本语义推理,选择 `sam2` 时回退到自动分割。
- AI 参数支持 `crop_to_prompt``auto_filter_background``min_score`;点/框 prompt 可以裁剪局部区域推理并回映射结果,背景过滤会移除低分结果和包含负向点的 polygon。
- 后端返回 `polygons``scores` - 后端返回 `polygons``scores`
- 前端把后端 `polygons` 转成 Konva `pathData``segmentation``bbox``area` - 前端把后端 `polygons` 转成 Konva `pathData``segmentation``bbox``area`
- AI 推理结果先存放在前端 store 的 `masks` 中,点击“结构化归档保存”后持久化到后端标注表。 - AI 推理结果先存放在前端 store 的 `masks` 中,点击“结构化归档保存”后持久化到后端标注表。
@@ -71,6 +81,9 @@
- 当前前端“结构化归档保存”会保存当前项目未保存 mask并会更新已标记为 dirty 的已保存 mask。 - 当前前端“结构化归档保存”会保存当前项目未保存 mask并会更新已标记为 dirty 的已保存 mask。
- 工作区“清空遮罩”会删除当前帧已保存标注,并清空当前帧未保存 mask。 - 工作区“清空遮罩”会删除当前帧已保存标注,并清空当前帧未保存 mask。
- 工作区加载项目帧后会查询已保存标注并回显。 - 工作区加载项目帧后会查询已保存标注并回显。
- 工作区支持导入 GT mask 图片,前端调用 `POST /api/ai/import-gt-mask`
- 后端导入 GT mask 时按非零像素值拆分多类别区域,再按连通域生成 polygon 标注,并通过距离变换写入 seed point。
- 前端会回显导入标注的 seed point拖动 seed point 后,已保存标注会变为 dirty归档保存时会更新后端 `points`
## R8 模板库 ## R8 模板库
@@ -93,9 +106,12 @@
- Dashboard 显示基础统计、解析队列和活动日志。 - Dashboard 显示基础统计、解析队列和活动日志。
- Dashboard 初始数据来自 `GET /api/dashboard/overview` - Dashboard 初始数据来自 `GET /api/dashboard/overview`
- 后端聚合项目数、处理中任务数、标注数、帧数、模板数和主机 load average。 - 后端聚合项目数、处理中任务数、标注数、帧数、模板数和主机 load average。
- 解析队列由 `processing_tasks` 中的 queued/running 任务生成;活动日志由最近任务、项目、标注和模板记录生成。 - 解析队列由 `processing_tasks` 中的 queued/running/failed/cancelled 任务生成;活动日志由最近任务、项目、标注和模板记录生成。
- Dashboard 对 queued/running 任务提供取消按钮,对 failed/cancelled 任务提供重试按钮。
- Dashboard 任务详情会读取 `GET /api/tasks/{task_id}` 并展示失败 error、payload、result、Celery ID 和时间信息。
- Dashboard 会连接 `/ws/progress` - Dashboard 会连接 `/ws/progress`
- 收到 progress、complete、error、status 消息时,前端会更新队列或日志。 - 收到 progress、complete、error、status 消息时,前端会更新队列或日志。
- 收到 cancelled 消息时,前端会把对应任务标记为已取消。
- Celery worker 每次更新 `processing_tasks` 后会发布 Redis `seg:progress` 事件FastAPI 订阅并广播给 `/ws/progress` 客户端。 - Celery worker 每次更新 `processing_tasks` 后会发布 Redis `seg:progress` 事件FastAPI 订阅并广播给 `/ws/progress` 客户端。
- 后端 WebSocket 接收到客户端消息后返回 status heartbeat。 - 后端 WebSocket 接收到客户端消息后返回 status heartbeat。
@@ -104,7 +120,10 @@
- 后端支持 `GET /api/export/{project_id}/coco` 导出 COCO JSON。 - 后端支持 `GET /api/export/{project_id}/coco` 导出 COCO JSON。
- 后端支持 `GET /api/export/{project_id}/masks` 导出 PNG mask ZIP。 - 后端支持 `GET /api/export/{project_id}/masks` 导出 PNG mask ZIP。
- 当前前端 `exportCoco()` API 封装已对齐后端路径。 - 当前前端 `exportCoco()` API 封装已对齐后端路径。
- 当前前端 `exportMasks()` API 封装已对齐后端路径。
- 工作区“导出 JSON 标注集”按钮已绑定下载事件;导出前会先保存当前未归档 mask。 - 工作区“导出 JSON 标注集”按钮已绑定下载事件;导出前会先保存当前未归档 mask。
- 工作区“导出 PNG Mask ZIP”按钮已绑定下载事件导出前会先保存当前未归档 mask。
- PNG mask ZIP 包含单标注二值 mask、按 zIndex 融合后的每帧语义 mask 和 `semantic_classes.json`
## R12 配置 ## R12 配置

View File

@@ -19,17 +19,17 @@
| 模块 | 文件 | 设计职责 | | 模块 | 文件 | 设计职责 |
|------|------|----------| |------|------|----------|
| 应用入口 | `src/App.tsx` | 根据登录状态和 `activeModule` 切换页面 | | 应用入口 | `src/App.tsx` | 根据登录状态和 `activeModule` 切换页面 |
| 全局状态 | `src/store/useStore.ts` | Zustand store保存项目、帧、模板、mask、工具状态 | | 全局状态 | `src/store/useStore.ts` | Zustand store保存项目、帧、模板、mask、工具状态和 mask 撤销/重做历史栈 |
| API 封装 | `src/lib/api.ts` | Axios 客户端、字段映射、AI 响应转换 | | API 封装 | `src/lib/api.ts` | Axios 客户端、字段映射、AI 响应转换 |
| 配置 | `src/lib/config.ts` | 推导 API 和 WebSocket 地址 | | 配置 | `src/lib/config.ts` | 推导 API 和 WebSocket 地址 |
| WebSocket | `src/lib/websocket.ts` | 进度流连接、订阅和重连 | | WebSocket | `src/lib/websocket.ts` | 进度流连接、订阅和重连 |
| 模型状态 | `src/components/ModelStatusBadge.tsx` | 展示 GPU 与当前 SAM 模型真实可用状态 | | 模型状态 | `src/components/ModelStatusBadge.tsx` | 展示 GPU 与当前 SAM 模型真实可用状态 |
| 登录页 | `src/components/Login.tsx` | 调用登录 API写入 store | | 登录页 | `src/components/Login.tsx` | 调用登录 API写入 store |
| Dashboard | `src/components/Dashboard.tsx` | 展示统计和 WebSocket 进度消息 | | Dashboard | `src/components/Dashboard.tsx` | 展示统计、任务控制、失败详情和 WebSocket 进度消息 |
| 项目库 | `src/components/ProjectLibrary.tsx` | 项目列表、新建、导入视频/DICOM | | 项目库 | `src/components/ProjectLibrary.tsx` | 项目列表、新建、导入视频/DICOM |
| 工作区 | `src/components/VideoWorkspace.tsx` | 加载帧和模板组织工具栏、Canvas、本体面板、时间轴 | | 工作区 | `src/components/VideoWorkspace.tsx` | 加载帧和模板组织工具栏、Canvas、本体面板、时间轴 |
| Canvas | `src/components/CanvasArea.tsx` | 显示帧、缩放平移、点/框提示、渲染 mask | | Canvas | `src/components/CanvasArea.tsx` | 显示帧、缩放平移、点/框提示、渲染 mask |
| 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具跳转 AI 页面 | | 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具跳转 AI 页面、触发 mask 撤销/重做 |
| 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航和播放 | | 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航和播放 |
| 本体面板 | `src/components/OntologyInspector.tsx` | 模板选择、分类树、本地自定义分类 | | 本体面板 | `src/components/OntologyInspector.tsx` | 模板选择、分类树、本地自定义分类 |
| AI 页面 | `src/components/AISegmentation.tsx` | 独立 AI 推理视图,使用当前项目帧 | | AI 页面 | `src/components/AISegmentation.tsx` | 独立 AI 推理视图,使用当前项目帧 |
@@ -51,7 +51,7 @@
| AI | `backend/routers/ai.py` | SAM 2 / SAM 3 可选推理、模型状态和标注保存 | | AI | `backend/routers/ai.py` | SAM 2 / SAM 3 可选推理、模型状态和标注保存 |
| Export | `backend/routers/export.py` | COCO 和 PNG mask 导出 | | Export | `backend/routers/export.py` | COCO 和 PNG mask 导出 |
| SAM 2 | `backend/services/sam2_engine.py` | SAM 2 懒加载、状态检测和点/框/自动推理 | | SAM 2 | `backend/services/sam2_engine.py` | SAM 2 懒加载、状态检测和点/框/自动推理 |
| SAM 3 | `backend/services/sam3_engine.py` | SAM 3 状态检测和文本语义推理适配 | | SAM 3 | `backend/services/sam3_engine.py`, `backend/services/sam3_external_worker.py`, `backend/setup_sam3_env.sh` | SAM 3 状态检测、独立 Python 3.12 环境桥接和文本语义推理适配 |
| SAM Registry | `backend/services/sam_registry.py` | 模型选择、GPU 状态和推理分发 | | SAM Registry | `backend/services/sam_registry.py` | 模型选择、GPU 状态和推理分发 |
## 状态模型 ## 状态模型
@@ -62,6 +62,7 @@
- `Frame`:帧 ID、项目 ID、索引、图片 URL、宽高。 - `Frame`:帧 ID、项目 ID、索引、图片 URL、宽高。
- `Template` / `TemplateClass`:模板和分类定义。 - `Template` / `TemplateClass`:模板和分类定义。
- `Mask`:前端渲染用 mask包含 `pathData``segmentation``bbox``area` - `Mask`:前端渲染用 mask包含 `pathData``segmentation``bbox``area`
- `maskHistory` / `maskFuture`mask 编辑历史栈,用于撤销和重做。
- `activeModule`:当前页面。 - `activeModule`:当前页面。
- `activeTool`:当前工具。 - `activeTool`:当前工具。
- `aiModel`:当前选择的 AI 模型,取值为 `sam2``sam3` - `aiModel`:当前选择的 AI 模型,取值为 `sam2``sam3`
@@ -82,6 +83,14 @@
4. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,持续更新 `processing_tasks`,并发布 Redis `seg:progress` 4. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,持续更新 `processing_tasks`,并发布 Redis `seg:progress`
5. 刷新项目列表。 5. 刷新项目列表。
### 任务控制
1. Dashboard 从 `GET /api/dashboard/overview` 读取 queued/running/failed/cancelled 任务。
2. 用户取消任务时,前端调用 `POST /api/tasks/{task_id}/cancel`;后端写入 `cancelled`、设置 `finished_at`,并尝试 `celery_app.control.revoke(..., terminate=True)`
3. worker 在下载、解析、上传、写帧等关键阶段刷新任务状态;如果发现 `cancelled`,停止后续写入并发布 cancelled 事件。
4. 用户重试任务时,前端调用 `POST /api/tasks/{task_id}/retry`;后端基于原任务 `payload` 创建新任务,记录 `retry_of` 并重新投递 Celery。
5. 用户打开详情时,前端调用 `GET /api/tasks/{task_id}`,弹窗展示 error、payload、result、Celery ID 和时间。
### 工作区加载 ### 工作区加载
1. `VideoWorkspace` 根据 `currentProject.id` 调用 `getProjectFrames()` 1. `VideoWorkspace` 根据 `currentProject.id` 调用 `getProjectFrames()`
@@ -102,6 +111,41 @@
9. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。 9. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。
10. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。 10. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。
### 手工绘制与历史栈
1. 用户在 `ToolsPalette` 选择多边形、矩形、圆、点或线工具。
2. `CanvasArea` 将交互坐标转换成像素 polygon。
3. 新 mask 写入 `pathData`、像素 `segmentation``bbox``area` 和当前模板分类元数据。
4. `addMask()``setMasks()``updateMask()``clearMasks()` 会维护 `maskHistory/maskFuture`
5. 工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y 调用 `undoMasks()` / `redoMasks()`
### Polygon 逐点编辑
1. 用户点击 Canvas 上的 mask path 后,`CanvasArea` 记录 `selectedMaskId` 并显示该 mask 第一条 polygon 的顶点控制点。
2. 拖动顶点后,前端重算 `pathData`、像素 `segmentation``bbox``area`
3. 如果 mask 已有 `annotationId`,编辑会把 `saveStatus` 标成 `dirty``saved=false`
4. 归档保存时复用现有 `PATCH /api/ai/annotations/{annotation_id}` 链路,把更新后的 normalized polygon 写回后端。
5. 选中顶点后 Delete/Backspace 可删除顶点;前端保持 polygon 至少三点。
### 区域合并与去除
1. 用户选择 `area_merge``area_remove` 后,点击多个当前帧 mask 组成选择集。
2. `CanvasArea``Mask.segmentation` 转为 `polygon-clipping` 的 MultiPolygon。
3. `area_merge` 使用 union更新第一个选中的主 mask并从前端 store 移除后续被合并 mask如果被移除 mask 已保存,会调用工作区传入的删除回调删除后端标注。
4. `area_remove` 使用 difference从第一个选中的主 mask 中扣除后续选中 mask扣除对象本身保留。
5. 结果会重算 `pathData``segmentation``bbox``area`,已保存主 mask 会进入 dirty 状态并复用归档 PATCH 链路。
### GT Mask 导入
1. 工作区“导入 GT Mask”选择图片文件。
2. 前端 `importGtMask()` 以 multipart form-data 调用 `POST /api/ai/import-gt-mask`,携带 `project_id``frame_id`
3. 后端验证项目、帧、模板后使用 OpenCV 读取灰度 mask。
4. 后端按非零像素值拆分多类别标签。
5. 后端对每个类别的前景做 contour 提取,每个连通域保存为一个 `Annotation`
6. `points` 字段保存距离变换中心 seed point`mask_data.polygons` 保存 normalized polygon`mask_data.gt_label_value` 保存原始像素类别值。
7. 前端重新读取项目标注并回显。
8. `annotationToMask()` 会把 normalized seed point 转成像素坐标Canvas 以可拖拽点显示;拖动后 `buildAnnotationPayload()` 会把点再归一化写回后端。
### 模板管理 ### 模板管理
1. `TemplateRegistry` 从后端读取模板。 1. `TemplateRegistry` 从后端读取模板。
@@ -114,8 +158,10 @@
### 导出 ### 导出
1. 后端根据项目、帧、标注和模板生成 COCO JSON。 1. 后端根据项目、帧、标注和模板生成 COCO JSON。
2. PNG mask 导出会把 normalized polygon 渲染为二值 mask 并打包 ZIP 2. PNG mask 导出会把 normalized polygon 渲染为单标注二值 mask。
3. 前端“导出 JSON 标注集”按钮会在导出前保存待归档标注,然后下载 COCO JSON 3. PNG mask 导出还会按 `mask_data.class.zIndex` 或模板 `z_index` 从低到高覆盖,生成每帧语义融合 mask
4. ZIP 内写入 `semantic_classes.json`,记录语义值到类别、颜色和 zIndex 的映射。
5. 前端“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮都会在导出前保存待归档标注然后下载对应文件。
## 接口契约 ## 接口契约
@@ -123,13 +169,18 @@
- `updateProject()` 使用 `PATCH /api/projects/{id}` - `updateProject()` 使用 `PATCH /api/projects/{id}`
- `exportCoco()` 使用 `GET /api/export/{projectId}/coco` - `exportCoco()` 使用 `GET /api/export/{projectId}/coco`
- `exportMasks()` 使用 `GET /api/export/{projectId}/masks`
- `cancelTask()` 使用 `POST /api/tasks/{taskId}/cancel`
- `retryTask()` 使用 `POST /api/tasks/{taskId}/retry`
- `predictMask()` 使用 `POST /api/ai/predict`,请求体为 `image_id``prompt_type``prompt_data``model` - `predictMask()` 使用 `POST /api/ai/predict`,请求体为 `image_id``prompt_type``prompt_data``model`
- `saveAnnotation()` 使用 `POST /api/ai/annotate` - `saveAnnotation()` 使用 `POST /api/ai/annotate`
- `importGtMask()` 使用 `POST /api/ai/import-gt-mask` multipart form-data。
- `getProjectAnnotations()` 使用 `GET /api/ai/annotations` - `getProjectAnnotations()` 使用 `GET /api/ai/annotations`
- `updateAnnotation()` 使用 `PATCH /api/ai/annotations/{annotationId}` - `updateAnnotation()` 使用 `PATCH /api/ai/annotations/{annotationId}`
- `deleteAnnotation()` 使用 `DELETE /api/ai/annotations/{annotationId}` - `deleteAnnotation()` 使用 `DELETE /api/ai/annotations/{annotationId}`
- 后端 `/api/ai/predict` 支持 point、box、semantic 三种 prompt_type并通过 `model` 选择 SAM 2 或 SAM 3。 - 后端 `/api/ai/predict` 支持 point、box、semantic 三种 prompt_type并通过 `model` 选择 SAM 2 或 SAM 3。
- 后端 `/api/ai/models/status` 返回 GPU、SAM 2、SAM 3 的真实运行状态 - 后端 `/api/ai/predict` 支持可选 `options``crop_to_prompt` 会对 point/box prompt 做局部裁剪推理并回映射 polygon`auto_filter_background` 会按 `min_score` 和负向点过滤结果
- 后端 `/api/ai/models/status` 返回 GPU、SAM 2、SAM 3 的真实运行状态SAM 3 状态包含外部 Python 环境与 checkpoint access 的可用性。
- point prompt 支持旧数组形式和 `{ points, labels }` 对象形式。 - point prompt 支持旧数组形式和 `{ points, labels }` 对象形式。
## 外部依赖边界 ## 外部依赖边界
@@ -146,10 +197,8 @@
以下能力属于当前冻结版本的占位或半可用功能: 以下能力属于当前冻结版本的占位或半可用功能:
- Dashboard 初始快照来自 `GET /api/dashboard/overview`;解析队列由 `processing_tasks` queued/running 任务生成。 - Dashboard 初始快照来自 `GET /api/dashboard/overview`;解析队列由 `processing_tasks` queued/running/failed/cancelled 任务生成。
- 多边形、矩形、圆、点、线手工绘制未实现。 - 已保存标注支持通过“应用分类”、polygon 顶点拖动/删除、边中点插入、多 polygon 子区域编辑和区域合并/去除进入 dirty 状态并归档更新;复杂洞结构编辑尚未实现。
- 合并、去除、撤销、重做未实现 - SAM 3 文本语义分割取决于官方依赖、GPU 运行环境和 Hugging Face gated 权重授权;状态接口会暴露真实可用性,未授权时 `available=false`
- 工作区导出 PNG mask ZIP 按钮尚未提供。
- 已保存标注支持通过“应用分类”进入 dirty 状态并归档更新;暂未提供逐点几何编辑器。
- SAM 3 文本语义分割取决于官方依赖和 GPU 运行环境;状态接口会暴露真实可用性。
- 自定义分类只存在本地组件状态。 - 自定义分类只存在本地组件状态。
- GT mask 导入已完成多类别像素值拆分、contour、distance transform seed point 和前端 seed point 拖拽编辑骨架提取、HDBSCAN 聚类和模板自动映射尚未实现。

View File

@@ -16,15 +16,15 @@
|------|----------|--------| |------|----------|--------|
| R1 登录与会话 | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | 成功登录、失败提示、后端 401 | | R1 登录与会话 | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | 成功登录、失败提示、后端 401 |
| R2 项目管理 | `src/lib/api.test.ts`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、后端 CRUD、帧列表 | | R2 项目管理 | `src/lib/api.test.ts`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、后端 CRUD、帧列表 |
| R3 媒体上传与拆帧 | `backend/tests/test_media.py` | 扩展名校验、自动建项目、关联项目、创建异步任务、worker 注册帧 | | R3 媒体上传与拆帧 | `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 扩展名校验、自动建项目、关联项目、创建异步任务、worker 注册帧、取消任务、重试任务、取消后 worker 停止 |
| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧触发解析、切帧、播放 | | R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧触发解析、切帧、播放 |
| R5 工具栏 | `src/components/ToolsPalette.test.tsx` | 工具切换、AI 跳转、占位按钮存在 | | R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 工具切换、AI 跳转、手工 mask 绘制、polygon 顶点拖动/删除、区域合并/去除、撤销/重做历史栈 |
| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py` | 点/框/semantic 契约、模型选择、GPU/SAM 状态、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry | | R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam3_engine.py` | 点/框/semantic 契约、模型选择、GPU/SAM 状态、SAM 3 外部 worker 桥接、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry |
| R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、加载回显、更新 dirty 标注、清空删除已保存标注、项目不存在、帧不存在 | | R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、加载回显、更新 dirty 标注、清空删除已保存标注、GT mask 多类别导入、seed point 回显/归一化、项目不存在、帧不存在 |
| R8 模板库 | `src/lib/api.test.ts`, `backend/tests/test_templates.py` | mapping_rules 解包/打包、模板 CRUD | | R8 模板库 | `src/lib/api.test.ts`, `backend/tests/test_templates.py` | mapping_rules 解包/打包、模板 CRUD |
| R9 本体检查面板 | `src/components/OntologyInspector.test.tsx` | 模板选择、分类展示、具体分类选择、自定义分类本地添加 | | R9 本体检查面板 | `src/components/OntologyInspector.test.tsx` | 模板选择、分类展示、具体分类选择、自定义分类本地添加 |
| R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py` | 后端概览接口、任务表驱动队列、Redis 进度事件 payload/发布、地址推导、消息订阅、队列更新、heartbeat | | R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py`, `backend/tests/test_tasks.py` | 后端概览接口、任务表驱动队列、任务取消/重试/详情、cancelled 事件、Redis 进度事件 payload/发布、地址推导、消息订阅、队列更新、heartbeat |
| R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO 按钮下载、导出前自动保存、COCO 路径、JSON 结构、mask ZIP | | R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO/PNG 按钮下载、导出前自动保存、导出路径、JSON 结构、mask ZIP、zIndex 语义融合 |
| R12 配置 | `src/lib/config.test.ts` | env 优先、hostname 推导、WS 推导 | | R12 配置 | `src/lib/config.test.ts` | env 优先、hostname 推导、WS 推导 |
| R13 文档与测试 | `doc/09-test-plan.md` | 测试覆盖矩阵 | | R13 文档与测试 | `doc/09-test-plan.md` | 测试覆盖矩阵 |

26
package-lock.json generated
View File

@@ -18,6 +18,7 @@
"konva": "^10.2.5", "konva": "^10.2.5",
"lucide-react": "^0.546.0", "lucide-react": "^0.546.0",
"motion": "^12.23.24", "motion": "^12.23.24",
"polygon-clipping": "^0.15.7",
"react": "^19.0.0", "react": "^19.0.0",
"react-dom": "^19.0.0", "react-dom": "^19.0.0",
"react-konva": "^19.2.3", "react-konva": "^19.2.3",
@@ -4165,6 +4166,16 @@
"url": "https://github.com/sponsors/jonschlinkert" "url": "https://github.com/sponsors/jonschlinkert"
} }
}, },
"node_modules/polygon-clipping": {
"version": "0.15.7",
"resolved": "https://registry.npmjs.org/polygon-clipping/-/polygon-clipping-0.15.7.tgz",
"integrity": "sha512-nhfdr83ECBg6xtqOAJab1tbksbBAOMUltN60bU+llHVOL0e5Onm1WpAXXWXVB39L8AJFssoIhEVuy/S90MmotA==",
"license": "MIT",
"dependencies": {
"robust-predicates": "^3.0.2",
"splaytree": "^3.1.0"
}
},
"node_modules/postcss": { "node_modules/postcss": {
"version": "8.5.12", "version": "8.5.12",
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.12.tgz", "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.12.tgz",
@@ -4438,6 +4449,12 @@
"node": ">= 4" "node": ">= 4"
} }
}, },
"node_modules/robust-predicates": {
"version": "3.0.3",
"resolved": "https://registry.npmjs.org/robust-predicates/-/robust-predicates-3.0.3.tgz",
"integrity": "sha512-NS3levdsRIUOmiJ8FZWCP7LG3QpJyrs/TE0Zpf1yvZu8cAJJ6QMW92H1c7kWpdIHo8RvmLxN/o2JXTKHp74lUA==",
"license": "Unlicense"
},
"node_modules/rollup": { "node_modules/rollup": {
"version": "4.60.2", "version": "4.60.2",
"resolved": "https://registry.npmjs.org/rollup/-/rollup-4.60.2.tgz", "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.60.2.tgz",
@@ -4684,6 +4701,15 @@
"node": ">=0.10.0" "node": ">=0.10.0"
} }
}, },
"node_modules/splaytree": {
"version": "3.2.3",
"resolved": "https://registry.npmjs.org/splaytree/-/splaytree-3.2.3.tgz",
"integrity": "sha512-7OXrNWzy6CK+r7Ch9OLPBDTKfB6XlWHjX4P0RU5B3IgFuWPeYN0XtRtlexGRjgbQxpfaUve6jTAwBGWuGntz/w==",
"license": "MIT",
"engines": {
"node": ">=18.20 || >=20"
}
},
"node_modules/stackback": { "node_modules/stackback": {
"version": "0.0.2", "version": "0.0.2",
"resolved": "https://registry.npmjs.org/stackback/-/stackback-0.0.2.tgz", "resolved": "https://registry.npmjs.org/stackback/-/stackback-0.0.2.tgz",

View File

@@ -24,6 +24,7 @@
"konva": "^10.2.5", "konva": "^10.2.5",
"lucide-react": "^0.546.0", "lucide-react": "^0.546.0",
"motion": "^12.23.24", "motion": "^12.23.24",
"polygon-clipping": "^0.15.7",
"react": "^19.0.0", "react": "^19.0.0",
"react-dom": "^19.0.0", "react-dom": "^19.0.0",
"react-konva": "^19.2.3", "react-konva": "^19.2.3",

View File

@@ -40,4 +40,26 @@ describe('AISegmentation', () => {
expect(useStore.getState().aiModel).toBe('sam3'); expect(useStore.getState().aiModel).toBe('sam3');
expect(await screen.findByText('SAM 3 missing runtime')).toBeInTheDocument(); expect(await screen.findByText('SAM 3 missing runtime')).toBeInTheDocument();
}); });
it('passes enabled inference parameters to the backend', async () => {
apiMock.predictMask.mockResolvedValueOnce({ masks: [] });
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'));
fireEvent.click(await screen.findByText('执行高精度语义分割'));
expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam2',
points: [{ x: 120, y: 80, type: 'pos' }],
options: {
crop_to_prompt: false,
auto_filter_background: true,
min_score: 0.05,
},
}));
});
}); });

View File

@@ -17,6 +17,10 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const masks = useStore((state) => state.masks); const masks = useStore((state) => state.masks);
const addMask = useStore((state) => state.addMask); const addMask = useStore((state) => state.addMask);
const clearMasks = useStore((state) => state.clearMasks); const clearMasks = useStore((state) => state.clearMasks);
const maskHistory = useStore((state) => state.maskHistory);
const maskFuture = useStore((state) => state.maskFuture);
const undoMasks = useStore((state) => state.undoMasks);
const redoMasks = useStore((state) => state.redoMasks);
const frames = useStore((state) => state.frames); const frames = useStore((state) => state.frames);
const currentFrameIndex = useStore((state) => state.currentFrameIndex); const currentFrameIndex = useStore((state) => state.currentFrameIndex);
const activeTemplateId = useStore((state) => state.activeTemplateId); const activeTemplateId = useStore((state) => state.activeTemplateId);
@@ -109,6 +113,11 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
model: aiModel, model: aiModel,
points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })), points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
text: semanticText.trim() || undefined, text: semanticText.trim() || undefined,
options: {
crop_to_prompt: cropMode,
auto_filter_background: autoDeleteBg,
min_score: autoDeleteBg ? 0.05 : 0,
},
}); });
result.masks.forEach((m) => { result.masks.forEach((m) => {
@@ -136,7 +145,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
} finally { } finally {
setIsInferencing(false); setIsInferencing(false);
} }
}, [activeClass, activeTemplateId, addMask, aiModel, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText]); }, [activeClass, activeTemplateId, addMask, aiModel, autoDeleteBg, cropMode, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText]);
const handleStageClick = (e: any) => { const handleStageClick = (e: any) => {
if (effectiveTool === 'move') return; if (effectiveTool === 'move') return;
@@ -290,10 +299,20 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<span className="text-[10px] text-gray-500 uppercase tracking-widest font-mono">{aiModel.toUpperCase()} </span> <span className="text-[10px] text-gray-500 uppercase tracking-widest font-mono">{aiModel.toUpperCase()} </span>
</div> </div>
<div className="flex items-center gap-4"> <div className="flex items-center gap-4">
<button className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="撤销操作 (Ctrl+Z)"> <button
onClick={undoMasks}
disabled={maskHistory.length === 0}
className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-400 disabled:cursor-not-allowed"
title="撤销操作 (Ctrl+Z)"
>
<Undo size={14} /> <Undo size={14} />
</button> </button>
<button className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="重做操作 (Ctrl+Shift+Z)"> <button
onClick={redoMasks}
disabled={maskFuture.length === 0}
className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-400 disabled:cursor-not-allowed"
title="重做操作 (Ctrl+Shift+Z)"
>
<Redo size={14} /> <Redo size={14} />
</button> </button>
<div className="w-px h-4 bg-white/10 mx-1"></div> <div className="w-px h-4 bg-white/10 mx-1"></div>

View File

@@ -79,6 +79,271 @@ describe('CanvasArea', () => {
expect(screen.getByText('遮罩数: 1')).toBeInTheDocument(); expect(screen.getByText('遮罩数: 1')).toBeInTheDocument();
}); });
it('renders imported GT seed points for editable point regions', () => {
useStore.setState({
masks: [
{
id: 'gt-1',
frameId: 'frame-1',
pathData: 'M 0 0 L 10 0 L 10 10 Z',
label: 'GT',
color: '#22c55e',
points: [[120, 80]],
},
],
});
render(<CanvasArea activeTool="move" frame={frame} />);
expect(screen.getAllByTestId('konva-circle')).toHaveLength(2);
});
it('selects a polygon mask and drags a vertex into dirty saved state', () => {
useStore.setState({
masks: [
{
id: 'annotation-99',
annotationId: '99',
frameId: 'frame-1',
pathData: 'M 10 10 L 90 10 L 90 40 Z',
label: 'Saved',
color: '#06b6d4',
saved: true,
saveStatus: 'saved',
segmentation: [[10, 10, 90, 10, 90, 40]],
bbox: [10, 10, 80, 30],
},
],
});
render(<CanvasArea activeTool="move" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-path'));
const handles = screen.getAllByTestId('konva-circle')
.filter((element) => element.getAttribute('data-fill') === '#ffffff');
expect(handles).toHaveLength(3);
fireEvent.mouseUp(handles[0], { clientX: 20, clientY: 30 });
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
pathData: 'M 20 30 L 90 10 L 90 40 Z',
segmentation: [[20, 30, 90, 10, 90, 40]],
bbox: [20, 10, 70, 30],
area: 1050,
saveStatus: 'dirty',
saved: false,
}));
});
it('deletes a selected polygon vertex without dropping below three points', () => {
useStore.setState({
masks: [
{
id: 'draft-1',
frameId: 'frame-1',
pathData: 'M 10 10 L 90 10 L 90 40 L 10 40 Z',
label: 'Draft',
color: '#06b6d4',
saveStatus: 'draft',
segmentation: [[10, 10, 90, 10, 90, 40, 10, 40]],
bbox: [10, 10, 80, 30],
},
],
});
render(<CanvasArea activeTool="move" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-path'));
const handles = screen.getAllByTestId('konva-circle')
.filter((element) => element.getAttribute('data-fill') === '#ffffff');
fireEvent.click(handles[0]);
fireEvent.keyDown(window, { key: 'Delete' });
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
pathData: 'M 90 10 L 90 40 L 10 40 Z',
segmentation: [[90, 10, 90, 40, 10, 40]],
saveStatus: 'draft',
}));
});
it('inserts a polygon vertex from an edge midpoint handle', () => {
useStore.setState({
masks: [
{
id: 'draft-1',
frameId: 'frame-1',
pathData: 'M 10 10 L 90 10 L 90 40 Z',
label: 'Draft',
color: '#06b6d4',
saveStatus: 'draft',
segmentation: [[10, 10, 90, 10, 90, 40]],
bbox: [10, 10, 80, 30],
},
],
});
render(<CanvasArea activeTool="move" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-path'));
const edgeHandles = screen.getAllByTestId('konva-circle')
.filter((element) => element.getAttribute('data-fill') === '#22d3ee');
fireEvent.click(edgeHandles[0]);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
segmentation: [[10, 10, 50, 10, 90, 10, 90, 40]],
pathData: 'M 10 10 L 50 10 L 90 10 L 90 40 Z',
saveStatus: 'draft',
}));
});
it('edits the selected polygon in a multi-polygon mask', () => {
useStore.setState({
masks: [
{
id: 'multi-1',
frameId: 'frame-1',
pathData: 'M 10 10 L 50 10 L 50 40 Z M 100 100 L 150 100 L 150 140 Z',
label: 'Multi',
color: '#06b6d4',
saveStatus: 'draft',
segmentation: [
[10, 10, 50, 10, 50, 40],
[100, 100, 150, 100, 150, 140],
],
bbox: [10, 10, 140, 130],
},
],
});
render(<CanvasArea activeTool="move" frame={frame} />);
const paths = screen.getAllByTestId('konva-path');
fireEvent.click(paths[1]);
const vertexHandles = screen.getAllByTestId('konva-circle')
.filter((element) => element.getAttribute('data-fill') === '#ffffff');
fireEvent.mouseUp(vertexHandles[0], { clientX: 120, clientY: 120 });
expect(useStore.getState().masks[0].segmentation).toEqual([
[10, 10, 50, 10, 50, 40],
[120, 120, 150, 100, 150, 140],
]);
});
it('merges selected draft masks with polygon union', () => {
useStore.setState({
masks: [
{
id: 'm1',
frameId: 'frame-1',
pathData: 'M 10 10 L 90 10 L 90 50 L 10 50 Z',
label: 'A',
color: '#06b6d4',
segmentation: [[10, 10, 90, 10, 90, 50, 10, 50]],
},
{
id: 'm2',
frameId: 'frame-1',
pathData: 'M 50 30 L 120 30 L 120 80 L 50 80 Z',
label: 'B',
color: '#ff0000',
segmentation: [[50, 30, 120, 30, 120, 80, 50, 80]],
},
],
});
render(<CanvasArea activeTool="area_merge" frame={frame} />);
const paths = screen.getAllByTestId('konva-path');
fireEvent.click(paths[0]);
fireEvent.click(paths[1]);
fireEvent.click(screen.getByRole('button', { name: '合并选中' }));
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'm1',
segmentation: [[10, 10, 90, 10, 90, 30, 120, 30, 120, 80, 50, 80, 50, 50, 10, 50]],
bbox: [10, 10, 110, 70],
saveStatus: 'draft',
}));
});
it('removes overlap from the primary selected mask with polygon difference', () => {
useStore.setState({
masks: [
{
id: 'm1',
frameId: 'frame-1',
pathData: 'M 10 10 L 90 10 L 90 50 L 10 50 Z',
label: 'A',
color: '#06b6d4',
segmentation: [[10, 10, 90, 10, 90, 50, 10, 50]],
},
{
id: 'm2',
frameId: 'frame-1',
pathData: 'M 50 30 L 120 30 L 120 80 L 50 80 Z',
label: 'B',
color: '#ff0000',
segmentation: [[50, 30, 120, 30, 120, 80, 50, 80]],
},
],
});
render(<CanvasArea activeTool="area_remove" frame={frame} />);
const paths = screen.getAllByTestId('konva-path');
fireEvent.click(paths[0]);
fireEvent.click(paths[1]);
fireEvent.click(screen.getByRole('button', { name: '从主区域去除' }));
expect(useStore.getState().masks).toHaveLength(2);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'm1',
segmentation: [[10, 10, 90, 10, 90, 30, 50, 30, 50, 50, 10, 50]],
bbox: [10, 10, 80, 40],
saveStatus: 'draft',
}));
expect(useStore.getState().masks[1].id).toBe('m2');
});
it('creates a manual rectangle mask that can be undone and redone', () => {
useStore.setState({
activeTemplateId: '2',
activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
activeClassId: 'c1',
});
render(<CanvasArea activeTool="create_rectangle" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.mouseDown(stage);
fireEvent.mouseMove(stage);
fireEvent.mouseUp(stage);
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
frameId: 'frame-1',
label: '胆囊',
color: '#ff0000',
saveStatus: 'draft',
segmentation: [[120, 80, 260, 80, 260, 200, 120, 200]],
bbox: [120, 80, 140, 120],
}));
useStore.getState().undoMasks();
expect(useStore.getState().masks).toEqual([]);
useStore.getState().redoMasks();
expect(useStore.getState().masks).toHaveLength(1);
});
it('finalizes a clicked polygon with Enter', () => {
render(<CanvasArea activeTool="create_polygon" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.click(stage, { clientX: 120, clientY: 80 });
fireEvent.click(stage, { clientX: 220, clientY: 80 });
fireEvent.click(stage, { clientX: 180, clientY: 160 });
fireEvent.keyDown(window, { key: 'Enter' });
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0].metadata).toEqual(expect.objectContaining({
source: 'manual',
shape: '多边形',
}));
});
it('applies the selected class to current-frame masks and marks saved masks dirty', () => { it('applies the selected class to current-frame masks and marks saved masks dirty', () => {
useStore.setState({ useStore.setState({
activeTemplateId: '2', activeTemplateId: '2',

View File

@@ -1,17 +1,180 @@
import React, { useEffect, useRef, useState, useCallback } from 'react'; import React, { useEffect, useRef, useState, useCallback } from 'react';
import { Stage, Layer, Image as KonvaImage, Circle, Rect, Path, Group } from 'react-konva'; import { Stage, Layer, Image as KonvaImage, Circle, Rect, Path, Group } from 'react-konva';
import polygonClipping, { type MultiPolygon, type Pair } from 'polygon-clipping';
import useImage from 'use-image'; import useImage from 'use-image';
import { useStore } from '../store/useStore'; import { useStore } from '../store/useStore';
import { predictMask } from '../lib/api'; import { predictMask } from '../lib/api';
import type { Frame } from '../store/useStore'; import type { Frame, Mask } from '../store/useStore';
interface CanvasAreaProps { interface CanvasAreaProps {
activeTool: string; activeTool: string;
frame: Frame | null; frame: Frame | null;
onClearMasks?: () => void; onClearMasks?: () => void;
onDeleteMaskAnnotations?: (annotationIds: string[]) => Promise<void> | void;
} }
export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) { type CanvasPoint = { x: number; y: number };
const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle', 'create_line']);
const POLYGON_TOOL = 'create_polygon';
const POINT_TOOL = 'create_point';
const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']);
function clamp(value: number, min: number, max: number): number {
return Math.min(Math.max(value, min), max);
}
function polygonPath(points: CanvasPoint[]): string {
if (points.length === 0) return '';
return points
.map((point, index) => `${index === 0 ? 'M' : 'L'} ${point.x} ${point.y}`)
.join(' ')
.concat(' Z');
}
function segmentationPath(segmentation?: number[][]): string {
return (segmentation || [])
.map((polygon) => polygonPath(flatPolygonToPoints(polygon)))
.filter(Boolean)
.join(' ');
}
function segmentationPolygonPath(segmentation: number[][] | undefined, polygonIndex: number): string {
const polygon = segmentation?.[polygonIndex];
return polygon ? polygonPath(flatPolygonToPoints(polygon)) : '';
}
function polygonSegmentation(points: CanvasPoint[]): number[][] {
return [points.flatMap((point) => [point.x, point.y])];
}
function segmentationToPoints(segmentation?: number[][], polygonIndex = 0): CanvasPoint[] {
const polygon = segmentation?.[polygonIndex] || [];
const points: CanvasPoint[] = [];
for (let index = 0; index < polygon.length - 1; index += 2) {
points.push({ x: polygon[index], y: polygon[index + 1] });
}
return points;
}
function flatPolygonToPoints(polygon: number[]): CanvasPoint[] {
const points: CanvasPoint[] = [];
for (let index = 0; index < polygon.length - 1; index += 2) {
points.push({ x: polygon[index], y: polygon[index + 1] });
}
return points;
}
function segmentationAllPoints(segmentation?: number[][]): CanvasPoint[] {
return (segmentation || []).flatMap((polygon) => flatPolygonToPoints(polygon));
}
function polygonBbox(points: CanvasPoint[]): [number, number, number, number] {
const xs = points.map((point) => point.x);
const ys = points.map((point) => point.y);
const minX = Math.min(...xs);
const minY = Math.min(...ys);
const maxX = Math.max(...xs);
const maxY = Math.max(...ys);
return [minX, minY, maxX - minX, maxY - minY];
}
function polygonArea(points: CanvasPoint[]): number {
if (points.length < 3) return 0;
const sum = points.reduce((acc, point, index) => {
const next = points[(index + 1) % points.length];
return acc + point.x * next.y - next.x * point.y;
}, 0);
return Math.abs(sum) / 2;
}
function segmentationArea(segmentation?: number[][]): number {
return (segmentation || []).reduce((sum, polygon) => sum + polygonArea(flatPolygonToPoints(polygon)), 0);
}
function segmentationBbox(segmentation?: number[][]): [number, number, number, number] | undefined {
const points = segmentationAllPoints(segmentation);
return points.length > 0 ? polygonBbox(points) : undefined;
}
function closeRing(points: CanvasPoint[]): Pair[] {
const ring = points.map((point) => [point.x, point.y] as Pair);
const first = ring[0];
const last = ring[ring.length - 1];
if (first && last && (first[0] !== last[0] || first[1] !== last[1])) {
ring.push([first[0], first[1]]);
}
return ring;
}
function maskToMultiPolygon(mask: Mask): MultiPolygon | null {
const polygons = (mask.segmentation || [])
.map((polygon) => flatPolygonToPoints(polygon))
.filter((points) => points.length >= 3)
.map((points) => [closeRing(points)]);
return polygons.length > 0 ? polygons : null;
}
function multiPolygonToSegmentation(geometry: MultiPolygon): number[][] {
return geometry
.map((polygon) => polygon[0] || [])
.map((ring) => {
const openRing = ring.length > 1
&& ring[0][0] === ring[ring.length - 1][0]
&& ring[0][1] === ring[ring.length - 1][1]
? ring.slice(0, -1)
: ring;
return openRing.flatMap(([x, y]) => [x, y]);
})
.filter((polygon) => polygon.length >= 6);
}
function rectanglePoints(start: CanvasPoint, end: CanvasPoint): CanvasPoint[] {
const x1 = Math.min(start.x, end.x);
const y1 = Math.min(start.y, end.y);
const x2 = Math.max(start.x, end.x);
const y2 = Math.max(start.y, end.y);
return [
{ x: x1, y: y1 },
{ x: x2, y: y1 },
{ x: x2, y: y2 },
{ x: x1, y: y2 },
];
}
function circlePoints(start: CanvasPoint, end: CanvasPoint): CanvasPoint[] {
const cx = (start.x + end.x) / 2;
const cy = (start.y + end.y) / 2;
const rx = Math.abs(end.x - start.x) / 2;
const ry = Math.abs(end.y - start.y) / 2;
return Array.from({ length: 32 }, (_, index) => {
const angle = (Math.PI * 2 * index) / 32;
return { x: cx + Math.cos(angle) * rx, y: cy + Math.sin(angle) * ry };
});
}
function pointRegion(point: CanvasPoint, radius = 5): CanvasPoint[] {
return Array.from({ length: 12 }, (_, index) => {
const angle = (Math.PI * 2 * index) / 12;
return { x: point.x + Math.cos(angle) * radius, y: point.y + Math.sin(angle) * radius };
});
}
function lineRegion(start: CanvasPoint, end: CanvasPoint, halfWidth = 4): CanvasPoint[] {
const dx = end.x - start.x;
const dy = end.y - start.y;
const length = Math.hypot(dx, dy) || 1;
const nx = (-dy / length) * halfWidth;
const ny = (dx / length) * halfWidth;
return [
{ x: start.x + nx, y: start.y + ny },
{ x: end.x + nx, y: end.y + ny },
{ x: end.x - nx, y: end.y - ny },
{ x: start.x - nx, y: start.y - ny },
];
}
export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnotations }: CanvasAreaProps) {
const containerRef = useRef<HTMLDivElement>(null); const containerRef = useRef<HTMLDivElement>(null);
const [stageSize, setStageSize] = useState({ width: 800, height: 600 }); const [stageSize, setStageSize] = useState({ width: 800, height: 600 });
const [scale, setScale] = useState(1); const [scale, setScale] = useState(1);
@@ -20,22 +183,46 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 }); const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 });
const [boxStart, setBoxStart] = useState<{ x: number, y: number } | null>(null); const [boxStart, setBoxStart] = useState<{ x: number, y: number } | null>(null);
const [boxCurrent, setBoxCurrent] = useState<{ x: number, y: number } | null>(null); const [boxCurrent, setBoxCurrent] = useState<{ x: number, y: number } | null>(null);
const [manualStart, setManualStart] = useState<CanvasPoint | null>(null);
const [manualCurrent, setManualCurrent] = useState<CanvasPoint | null>(null);
const [polygonPoints, setPolygonPoints] = useState<CanvasPoint[]>([]);
const [selectedMaskId, setSelectedMaskId] = useState<string | null>(null);
const [selectedMaskIds, setSelectedMaskIds] = useState<string[]>([]);
const [selectedPolygonIndex, setSelectedPolygonIndex] = useState(0);
const [selectedVertexIndex, setSelectedVertexIndex] = useState<number | null>(null);
const [isInferencing, setIsInferencing] = useState(false); const [isInferencing, setIsInferencing] = useState(false);
const masks = useStore((state) => state.masks); const masks = useStore((state) => state.masks);
const addMask = useStore((state) => state.addMask); const addMask = useStore((state) => state.addMask);
const updateMask = useStore((state) => state.updateMask);
const clearMasks = useStore((state) => state.clearMasks); const clearMasks = useStore((state) => state.clearMasks);
const setMasks = useStore((state) => state.setMasks); const setMasks = useStore((state) => state.setMasks);
const storeActiveTool = useStore((state) => state.activeTool); const storeActiveTool = useStore((state) => state.activeTool);
const aiModel = useStore((state) => state.aiModel); const aiModel = useStore((state) => state.aiModel);
const activeTemplateId = useStore((state) => state.activeTemplateId); const activeTemplateId = useStore((state) => state.activeTemplateId);
const activeClass = useStore((state) => state.activeClass); const activeClass = useStore((state) => state.activeClass);
const undoMasks = useStore((state) => state.undoMasks);
const redoMasks = useStore((state) => state.redoMasks);
const effectiveTool = activeTool || storeActiveTool; const effectiveTool = activeTool || storeActiveTool;
// Load the actual frame image // Load the actual frame image
const [image] = useImage(frame?.url || ''); const [image] = useImage(frame?.url || '');
const frameMasks = masks.filter((mask) => mask.frameId === frame?.id); const frameMasks = masks.filter((mask) => mask.frameId === frame?.id);
const selectedMask = React.useMemo(
() => frameMasks.find((mask) => mask.id === selectedMaskId) || null,
[frameMasks, selectedMaskId],
);
const booleanSelectedMasks = React.useMemo(
() => selectedMaskIds
.map((id) => frameMasks.find((mask) => mask.id === id))
.filter((mask): mask is Mask => Boolean(mask)),
[frameMasks, selectedMaskIds],
);
const selectedMaskPoints = React.useMemo(
() => segmentationToPoints(selectedMask?.segmentation, selectedPolygonIndex),
[selectedMask?.segmentation, selectedPolygonIndex],
);
const savedMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'saved' || mask.saved).length; const savedMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'saved' || mask.saved).length;
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length; const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length; const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
@@ -55,6 +242,25 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
return () => window.removeEventListener('resize', handleResize); return () => window.removeEventListener('resize', handleResize);
}, []); }, []);
useEffect(() => {
setManualStart(null);
setManualCurrent(null);
setPolygonPoints([]);
setSelectedMaskId(null);
setSelectedMaskIds([]);
setSelectedPolygonIndex(0);
setSelectedVertexIndex(null);
}, [effectiveTool, frame?.id]);
useEffect(() => {
if (selectedMaskId && !frameMasks.some((mask) => mask.id === selectedMaskId)) {
setSelectedMaskId(null);
setSelectedMaskIds([]);
setSelectedPolygonIndex(0);
setSelectedVertexIndex(null);
}
}, [frameMasks, selectedMaskId]);
const handleWheel = (e: any) => { const handleWheel = (e: any) => {
e.evt.preventDefault(); e.evt.preventDefault();
const scaleBy = 1.1; const scaleBy = 1.1;
@@ -74,6 +280,50 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
}); });
}; };
const stagePoint = (e: any): CanvasPoint | null => {
const stage = e.target.getStage();
const relPos = stage?.getRelativePointerPosition();
if (!relPos) return null;
const imageWidth = frame?.width || image?.naturalWidth || image?.width || stageSize.width;
const imageHeight = frame?.height || image?.naturalHeight || image?.height || stageSize.height;
return {
x: clamp(relPos.x, 0, imageWidth),
y: clamp(relPos.y, 0, imageHeight),
};
};
const createManualMask = useCallback((shape: string, polygon: CanvasPoint[]) => {
if (!frame?.id || polygon.length < 3) return;
const area = polygonArea(polygon);
if (area <= 1) return;
const color = activeClass?.color || '#06b6d4';
const label = activeClass?.name || `手工${shape}`;
const mask: Mask = {
id: `manual-${frame.id}-${shape}-${Date.now()}`,
frameId: frame.id,
templateId: activeTemplateId || undefined,
classId: activeClass?.id,
className: activeClass?.name,
classZIndex: activeClass?.zIndex,
saveStatus: 'draft',
saved: false,
pathData: polygonPath(polygon),
label,
color,
segmentation: polygonSegmentation(polygon),
points: shape === '点区域'
? [[
polygon.reduce((sum, point) => sum + point.x, 0) / polygon.length,
polygon.reduce((sum, point) => sum + point.y, 0) / polygon.length,
]]
: undefined,
bbox: polygonBbox(polygon),
area,
metadata: { source: 'manual', shape },
};
addMask(mask);
}, [activeClass, activeTemplateId, addMask, frame?.id]);
const handleMouseMove = (e: any) => { const handleMouseMove = (e: any) => {
const stage = e.target.getStage(); const stage = e.target.getStage();
if (!stage) return; if (!stage) return;
@@ -90,6 +340,13 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
setBoxCurrent({ x: relPos.x, y: relPos.y }); setBoxCurrent({ x: relPos.x, y: relPos.y });
} }
} }
if (manualStart && DRAG_MANUAL_TOOLS.has(effectiveTool)) {
const pos = stage.getRelativePointerPosition();
if (pos) {
setManualCurrent({ x: pos.x, y: pos.y });
}
}
}; };
const runInference = useCallback(async (promptPoints?: typeof points, promptBox?: { x1: number, y1: number, x2: number, y2: number }) => { const runInference = useCallback(async (promptPoints?: typeof points, promptBox?: { x1: number, y1: number, x2: number, y2: number }) => {
@@ -132,6 +389,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
label, label,
color, color,
segmentation: m.segmentation, segmentation: m.segmentation,
points: promptPoints?.filter((p) => p.type === 'pos').map((p) => [p.x, p.y]),
bbox: m.bbox, bbox: m.bbox,
area: m.area, area: m.area,
}); });
@@ -170,6 +428,15 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
}; };
const handleStageMouseDown = (e: any) => { const handleStageMouseDown = (e: any) => {
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) {
const pos = stagePoint(e);
if (pos) {
setManualStart(pos);
setManualCurrent(pos);
}
return;
}
if (effectiveTool === 'box_select') { if (effectiveTool === 'box_select') {
const stage = e.target.getStage(); const stage = e.target.getStage();
const pos = stage.getRelativePointerPosition(); const pos = stage.getRelativePointerPosition();
@@ -181,6 +448,27 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
}; };
const handleStageMouseUp = (e: any) => { const handleStageMouseUp = (e: any) => {
if (DRAG_MANUAL_TOOLS.has(effectiveTool) && manualStart) {
const end = stagePoint(e) || manualCurrent || manualStart;
const width = Math.abs(end.x - manualStart.x);
const height = Math.abs(end.y - manualStart.y);
const distance = Math.hypot(width, height);
if (effectiveTool === 'create_rectangle' && width > 4 && height > 4) {
createManualMask('矩形', rectanglePoints(manualStart, end));
}
if (effectiveTool === 'create_circle' && width > 4 && height > 4) {
createManualMask('圆形', circlePoints(manualStart, end));
}
if (effectiveTool === 'create_line' && distance > 4) {
createManualMask('线段', lineRegion(manualStart, end));
}
setManualStart(null);
setManualCurrent(null);
return;
}
if (effectiveTool === 'box_select' && boxStart && boxCurrent) { if (effectiveTool === 'box_select' && boxStart && boxCurrent) {
const x1 = Math.min(boxStart.x, boxCurrent.x); const x1 = Math.min(boxStart.x, boxCurrent.x);
const y1 = Math.min(boxStart.y, boxCurrent.y); const y1 = Math.min(boxStart.y, boxCurrent.y);
@@ -199,12 +487,32 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
const handleStageClick = (e: any) => { const handleStageClick = (e: any) => {
if (effectiveTool === 'move') return; if (effectiveTool === 'move') return;
if (effectiveTool === 'box_select') return; // handled by mouseup if (effectiveTool === 'box_select') return; // handled by mouseup
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) return;
if (effectiveTool === POINT_TOOL) {
const pos = stagePoint(e);
if (pos) {
createManualMask('点区域', pointRegion(pos));
}
return;
}
if (effectiveTool === POLYGON_TOOL) {
const pos = stagePoint(e);
if (pos) {
setPolygonPoints((current) => [...current, pos]);
}
return;
}
if (effectiveTool === 'point_pos' || effectiveTool === 'point_neg') { if (effectiveTool === 'point_pos' || effectiveTool === 'point_neg') {
const stage = e.target.getStage(); const stage = e.target.getStage();
const pos = stage.getRelativePointerPosition(); const pos = stage.getRelativePointerPosition();
if (pos) { if (pos) {
const newPoints = [...points, { x: pos.x, y: pos.y, type: effectiveTool === 'point_pos' ? 'pos' : 'neg' as 'pos'|'neg' }]; const newPoints = [
...points,
{ x: pos.x, y: pos.y, type: (effectiveTool === 'point_pos' ? 'pos' : 'neg') as 'pos' | 'neg' },
];
setPoints(newPoints); setPoints(newPoints);
// Auto-trigger inference after point selection // Auto-trigger inference after point selection
runInference(newPoints); runInference(newPoints);
@@ -212,6 +520,74 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
} }
}; };
const updatePolygonMask = useCallback((mask: Mask, nextPoints: CanvasPoint[], polygonIndex = 0) => {
if (nextPoints.length < 3) return;
const nextSegmentation = [...(mask.segmentation || [])];
nextSegmentation[polygonIndex] = nextPoints.flatMap((point) => [point.x, point.y]);
const bbox = segmentationBbox(nextSegmentation) || polygonBbox(nextPoints);
updateMask(mask.id, {
pathData: segmentationPath(nextSegmentation),
segmentation: nextSegmentation,
bbox,
area: segmentationArea(nextSegmentation),
saveStatus: mask.annotationId ? 'dirty' : 'draft',
saved: mask.annotationId ? false : mask.saved,
});
}, [updateMask]);
const updateMaskFromSegmentation = useCallback((mask: Mask, segmentation: number[][]): Mask => {
const bbox = segmentationBbox(segmentation);
return {
...mask,
pathData: segmentationPath(segmentation),
segmentation,
bbox,
area: segmentationArea(segmentation),
saveStatus: mask.annotationId ? 'dirty' : 'draft',
saved: mask.annotationId ? false : mask.saved,
};
}, []);
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
const key = event.key.toLowerCase();
if ((event.metaKey || event.ctrlKey) && key === 'z') {
event.preventDefault();
if (event.shiftKey) redoMasks();
else undoMasks();
return;
}
if ((event.metaKey || event.ctrlKey) && key === 'y') {
event.preventDefault();
redoMasks();
return;
}
if ((event.key === 'Delete' || event.key === 'Backspace') && selectedMask && selectedVertexIndex !== null) {
const currentPoints = segmentationToPoints(selectedMask.segmentation, selectedPolygonIndex);
if (currentPoints.length > 3) {
event.preventDefault();
const nextPoints = currentPoints.filter((_, index) => index !== selectedVertexIndex);
updatePolygonMask(selectedMask, nextPoints, selectedPolygonIndex);
setSelectedVertexIndex(null);
}
return;
}
if (effectiveTool !== POLYGON_TOOL) return;
if (event.key === 'Enter' && polygonPoints.length >= 3) {
event.preventDefault();
createManualMask('多边形', polygonPoints);
setPolygonPoints([]);
}
if (event.key === 'Escape') {
event.preventDefault();
setPolygonPoints([]);
}
};
window.addEventListener('keydown', handleKeyDown);
return () => window.removeEventListener('keydown', handleKeyDown);
}, [createManualMask, effectiveTool, polygonPoints, redoMasks, selectedMask, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
const boxRect = React.useMemo(() => { const boxRect = React.useMemo(() => {
if (!boxStart || !boxCurrent) return null; if (!boxStart || !boxCurrent) return null;
const x = Math.min(boxStart.x, boxCurrent.x); const x = Math.min(boxStart.x, boxCurrent.x);
@@ -221,6 +597,132 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
return { x, y, width, height }; return { x, y, width, height };
}, [boxStart, boxCurrent]); }, [boxStart, boxCurrent]);
const manualPreviewPath = React.useMemo(() => {
if (manualStart && manualCurrent) {
if (effectiveTool === 'create_rectangle') return polygonPath(rectanglePoints(manualStart, manualCurrent));
if (effectiveTool === 'create_circle') return polygonPath(circlePoints(manualStart, manualCurrent));
if (effectiveTool === 'create_line') return polygonPath(lineRegion(manualStart, manualCurrent));
}
if (effectiveTool === POLYGON_TOOL && polygonPoints.length > 0) {
const previewPoints = [...polygonPoints, cursorPos];
return polygonPath(previewPoints);
}
return null;
}, [cursorPos, effectiveTool, manualCurrent, manualStart, polygonPoints]);
const handleSeedPointDragEnd = (mask: Mask, pointIndex: number, event: any) => {
const x = event.target.x();
const y = event.target.y();
const nextPoints = [...(mask.points || [])];
nextPoints[pointIndex] = [x, y];
updateMask(mask.id, {
points: nextPoints,
saveStatus: mask.annotationId ? 'dirty' : 'draft',
saved: mask.annotationId ? false : mask.saved,
});
};
const handleMaskSelect = (mask: Mask, event: any, polygonIndex = 0) => {
event.cancelBubble = true;
if (BOOLEAN_TOOLS.has(effectiveTool)) {
setSelectedMaskIds((current) => (
current.includes(mask.id)
? current.filter((id) => id !== mask.id)
: [...current, mask.id]
));
setSelectedMaskId(mask.id);
setSelectedPolygonIndex(polygonIndex);
setSelectedVertexIndex(null);
return;
}
setSelectedMaskId(mask.id);
setSelectedMaskIds([mask.id]);
setSelectedPolygonIndex(polygonIndex);
setSelectedVertexIndex(null);
};
const handleVertexDragEnd = (mask: Mask, vertexIndex: number, event: any) => {
const imageWidth = frame?.width || image?.naturalWidth || image?.width || stageSize.width;
const imageHeight = frame?.height || image?.naturalHeight || image?.height || stageSize.height;
const currentPoints = segmentationToPoints(mask.segmentation, selectedPolygonIndex);
if (!currentPoints[vertexIndex]) return;
const nextPoints = currentPoints.map((point, index) => (
index === vertexIndex
? {
x: clamp(event.target.x(), 0, imageWidth),
y: clamp(event.target.y(), 0, imageHeight),
}
: point
));
setSelectedMaskId(mask.id);
setSelectedVertexIndex(vertexIndex);
updatePolygonMask(mask, nextPoints, selectedPolygonIndex);
};
const handleEdgeInsert = (mask: Mask, edgeIndex: number, event: any) => {
event.cancelBubble = true;
const currentPoints = segmentationToPoints(mask.segmentation, selectedPolygonIndex);
const start = currentPoints[edgeIndex];
const end = currentPoints[(edgeIndex + 1) % currentPoints.length];
if (!start || !end) return;
const inserted = { x: (start.x + end.x) / 2, y: (start.y + end.y) / 2 };
const nextPoints = [
...currentPoints.slice(0, edgeIndex + 1),
inserted,
...currentPoints.slice(edgeIndex + 1),
];
setSelectedMaskId(mask.id);
setSelectedVertexIndex(edgeIndex + 1);
updatePolygonMask(mask, nextPoints, selectedPolygonIndex);
};
const handleBooleanOperation = async () => {
if (!frame || booleanSelectedMasks.length < 2) return;
const primary = booleanSelectedMasks[0];
const primaryGeometry = maskToMultiPolygon(primary);
if (!primaryGeometry) return;
const clipGeometries = booleanSelectedMasks
.slice(1)
.map(maskToMultiPolygon)
.filter((geometry): geometry is MultiPolygon => Boolean(geometry));
if (clipGeometries.length === 0) return;
const resultGeometry = effectiveTool === 'area_merge'
? polygonClipping.union(primaryGeometry, ...clipGeometries)
: polygonClipping.difference(primaryGeometry, ...clipGeometries);
const resultSegmentation = multiPolygonToSegmentation(resultGeometry);
if (resultSegmentation.length === 0) {
const deleteIds = primary.annotationId ? [primary.annotationId] : [];
setMasks(masks.filter((mask) => mask.id !== primary.id));
if (deleteIds.length > 0) await onDeleteMaskAnnotations?.(deleteIds);
setSelectedMaskId(null);
setSelectedMaskIds([]);
setSelectedVertexIndex(null);
return;
}
const nextPrimary = updateMaskFromSegmentation(primary, resultSegmentation);
const secondaryIds = effectiveTool === 'area_merge'
? new Set(booleanSelectedMasks.slice(1).map((mask) => mask.id))
: new Set<string>();
const secondaryAnnotationIds = effectiveTool === 'area_merge'
? booleanSelectedMasks
.slice(1)
.map((mask) => mask.annotationId)
.filter((annotationId): annotationId is string => Boolean(annotationId))
: [];
setMasks(masks
.filter((mask) => !secondaryIds.has(mask.id))
.map((mask) => (mask.id === primary.id ? nextPrimary : mask)));
if (secondaryAnnotationIds.length > 0) await onDeleteMaskAnnotations?.(secondaryAnnotationIds);
setSelectedMaskId(primary.id);
setSelectedMaskIds([primary.id]);
setSelectedVertexIndex(null);
};
return ( return (
<div ref={containerRef} className="w-full h-full relative cursor-crosshair overflow-hidden rounded-sm"> <div ref={containerRef} className="w-full h-full relative cursor-crosshair overflow-hidden rounded-sm">
{isInferencing && ( {isInferencing && (
@@ -257,13 +759,18 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
{/* AI Returned Masks */} {/* AI Returned Masks */}
{frameMasks.map((mask) => ( {frameMasks.map((mask) => (
<Group key={mask.id} opacity={0.5}> <Group key={mask.id} opacity={selectedMaskIds.includes(mask.id) ? 0.65 : 0.5}>
<Path {(mask.segmentation && mask.segmentation.length > 0 ? mask.segmentation : [undefined]).map((_, polygonIndex) => (
data={mask.pathData} <Path
fill={mask.color} key={`${mask.id}-polygon-${polygonIndex}`}
stroke={mask.color} data={mask.segmentation ? segmentationPolygonPath(mask.segmentation, polygonIndex) : mask.pathData}
strokeWidth={1 / scale} fill={mask.color}
/> stroke={mask.color}
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale}
onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
/>
))}
</Group> </Group>
))} ))}
@@ -281,6 +788,86 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
/> />
)} )}
{/* Manual shape preview */}
{manualPreviewPath && (
<Path
data={manualPreviewPath}
fill="rgba(34, 211, 238, 0.12)"
stroke="#22d3ee"
strokeWidth={2 / scale}
dash={[5 / scale, 5 / scale]}
/>
)}
{polygonPoints.map((point, index) => (
<Circle
key={`poly-point-${index}`}
x={point.x}
y={point.y}
radius={4 / scale}
fill="#22d3ee"
stroke="#ffffff"
strokeWidth={1 / scale}
/>
))}
{/* Imported GT seed points / editable point regions */}
{frameMasks.flatMap((mask) => (mask.points || []).map(([x, y], index) => (
<Group key={`${mask.id}-seed-${index}`} x={x} y={y}>
<Circle
radius={5 / scale}
fill="#facc15"
stroke="#111827"
strokeWidth={2 / scale}
draggable
onDragEnd={(event: any) => handleSeedPointDragEnd(mask, index, event)}
/>
<Circle radius={1.5 / scale} fill="#111827" />
</Group>
)))}
{/* Polygon edge insertion handles */}
{selectedMask && selectedMaskPoints.map((point, index) => {
const next = selectedMaskPoints[(index + 1) % selectedMaskPoints.length];
if (!next) return null;
return (
<Circle
key={`${selectedMask.id}-edge-${selectedPolygonIndex}-${index}`}
x={(point.x + next.x) / 2}
y={(point.y + next.y) / 2}
radius={3.5 / scale}
fill="#22d3ee"
stroke="#111827"
strokeWidth={1.5 / scale}
onClick={(event: any) => handleEdgeInsert(selectedMask, index, event)}
onTap={(event: any) => handleEdgeInsert(selectedMask, index, event)}
/>
);
})}
{/* Polygon vertex editor */}
{selectedMask && selectedMaskPoints.map((point, index) => (
<Circle
key={`${selectedMask.id}-vertex-${selectedPolygonIndex}-${index}`}
x={point.x}
y={point.y}
radius={(selectedVertexIndex === index ? 6 : 4.5) / scale}
fill={selectedVertexIndex === index ? '#22d3ee' : '#ffffff'}
stroke={selectedMask.color}
strokeWidth={2 / scale}
draggable
onClick={(event: any) => {
event.cancelBubble = true;
setSelectedVertexIndex(index);
}}
onTap={(event: any) => {
event.cancelBubble = true;
setSelectedVertexIndex(index);
}}
onDragEnd={(event: any) => handleVertexDragEnd(selectedMask, index, event)}
/>
))}
{/* AI Prompts Point Regions */} {/* AI Prompts Point Regions */}
{points.map((p, i) => ( {points.map((p, i) => (
<Group key={i} x={p.x} y={p.y}> <Group key={i} x={p.x} y={p.y}>
@@ -313,6 +900,14 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
{frameMasks.length > 0 && ( {frameMasks.length > 0 && (
<div className="absolute bottom-4 right-4 flex gap-2"> <div className="absolute bottom-4 right-4 flex gap-2">
{BOOLEAN_TOOLS.has(effectiveTool) && booleanSelectedMasks.length >= 2 && (
<button
onClick={handleBooleanOperation}
className="text-xs bg-emerald-500/10 hover:bg-emerald-500/20 text-emerald-300 border border-emerald-500/20 px-3 py-1.5 rounded transition-colors"
>
{effectiveTool === 'area_merge' ? '合并选中' : '从主区域去除'}
</button>
)}
{activeClass && ( {activeClass && (
<button <button
onClick={handleApplyActiveClass} onClick={handleApplyActiveClass}

View File

@@ -1,9 +1,12 @@
import { act, render, screen, waitFor } from '@testing-library/react'; import { act, fireEvent, render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest'; import { beforeEach, describe, expect, it, vi } from 'vitest';
import { Dashboard } from './Dashboard'; import { Dashboard } from './Dashboard';
const apiMock = vi.hoisted(() => ({ const apiMock = vi.hoisted(() => ({
getDashboardOverview: vi.fn(), getDashboardOverview: vi.fn(),
cancelTask: vi.fn(),
retryTask: vi.fn(),
getTask: vi.fn(),
})); }));
const wsMock = vi.hoisted(() => { const wsMock = vi.hoisted(() => {
@@ -31,6 +34,9 @@ vi.mock('../lib/websocket', () => ({
vi.mock('../lib/api', () => ({ vi.mock('../lib/api', () => ({
getDashboardOverview: apiMock.getDashboardOverview, getDashboardOverview: apiMock.getDashboardOverview,
cancelTask: apiMock.cancelTask,
retryTask: apiMock.retryTask,
getTask: apiMock.getTask,
})); }));
describe('Dashboard', () => { describe('Dashboard', () => {
@@ -55,6 +61,8 @@ describe('Dashboard', () => {
name: '真实项目.mp4', name: '真实项目.mp4',
progress: 60, progress: 60,
status: 'pending', status: 'pending',
raw_status: 'running',
error: null,
frame_count: 10, frame_count: 10,
updated_at: '2026-05-01T00:00:00Z', updated_at: '2026-05-01T00:00:00Z',
}, },
@@ -112,4 +120,100 @@ describe('Dashboard', () => {
await waitFor(() => expect(screen.getByText('Progress stream active')).toBeInTheDocument()); await waitFor(() => expect(screen.getByText('Progress stream active')).toBeInTheDocument());
expect(screen.getByText('解析完成: done.mp4')).toBeInTheDocument(); expect(screen.getByText('解析完成: done.mp4')).toBeInTheDocument();
}); });
it('cancels, retries, and opens task failure details', async () => {
apiMock.getDashboardOverview.mockResolvedValueOnce({
summary: {
project_count: 1,
parsing_task_count: 1,
annotation_count: 0,
frame_count: 0,
template_count: 0,
system_load_percent: 5,
},
tasks: [
{
id: 'task-10',
task_id: 10,
project_id: 1,
name: 'running.mp4',
progress: 30,
status: '正在下载媒体文件',
raw_status: 'running',
error: null,
frame_count: 0,
updated_at: '2026-05-01T00:00:00Z',
},
{
id: 'task-11',
task_id: 11,
project_id: 1,
name: 'failed.mp4',
progress: 100,
status: '解析失败',
raw_status: 'failed',
error: 'ffmpeg failed',
frame_count: 0,
updated_at: '2026-05-01T00:01:00Z',
},
],
activity: [],
});
apiMock.cancelTask.mockResolvedValueOnce({
id: 10,
task_type: 'parse_video',
status: 'cancelled',
progress: 100,
message: '任务已取消',
project_id: 1,
error: 'Cancelled by user',
result: null,
payload: { source_type: 'video' },
created_at: 'created',
updated_at: 'updated',
});
apiMock.retryTask.mockResolvedValueOnce({
id: 12,
task_type: 'parse_video',
status: 'queued',
progress: 0,
message: '重试任务已入队(源任务 #11',
project_id: 1,
error: null,
result: null,
payload: { source_type: 'video', retry_of: 11 },
created_at: 'created',
updated_at: 'updated',
});
apiMock.getTask.mockResolvedValueOnce({
id: 11,
task_type: 'parse_video',
status: 'failed',
progress: 100,
message: '解析失败',
project_id: 1,
celery_task_id: 'celery-11',
payload: { source_type: 'video' },
result: null,
error: 'ffmpeg failed',
created_at: 'created',
started_at: 'started',
finished_at: 'finished',
updated_at: 'updated',
});
render(<Dashboard />);
await screen.findByText('running.mp4');
fireEvent.click(screen.getByRole('button', { name: '取消' }));
await waitFor(() => expect(apiMock.cancelTask).toHaveBeenCalledWith(10));
fireEvent.click(screen.getAllByRole('button', { name: '详情' })[1]);
await waitFor(() => expect(apiMock.getTask).toHaveBeenCalledWith(11));
expect(await screen.findByText('任务详情 #11')).toBeInTheDocument();
expect(screen.getByText('ffmpeg failed')).toBeInTheDocument();
fireEvent.click(screen.getAllByRole('button', { name: '重试' })[1]);
await waitFor(() => expect(apiMock.retryTask).toHaveBeenCalledWith(11));
});
}); });

View File

@@ -1,8 +1,17 @@
import React, { useState, useEffect } from 'react'; import React, { useState, useEffect } from 'react';
import { Activity, Clock, Folders, CheckCircle2, Loader2 } from 'lucide-react'; import { Activity, AlertTriangle, Clock, Folders, CheckCircle2, Info, Loader2, RotateCcw, XCircle } from 'lucide-react';
import { progressWS, type ProgressMessage } from '../lib/websocket'; import { progressWS, type ProgressMessage } from '../lib/websocket';
import { cn } from '../lib/utils'; import { cn } from '../lib/utils';
import { getDashboardOverview, type DashboardActivity, type DashboardOverview, type DashboardTask } from '../lib/api'; import {
cancelTask,
getDashboardOverview,
getTask,
retryTask,
type DashboardActivity,
type DashboardOverview,
type DashboardTask,
type ProcessingTask,
} from '../lib/api';
const emptySummary: DashboardOverview['summary'] = { const emptySummary: DashboardOverview['summary'] = {
project_count: 0, project_count: 0,
@@ -20,6 +29,29 @@ export function Dashboard() {
const [activityLog, setActivityLog] = useState<DashboardActivity[]>([]); const [activityLog, setActivityLog] = useState<DashboardActivity[]>([]);
const [isLoading, setIsLoading] = useState(true); const [isLoading, setIsLoading] = useState(true);
const [loadError, setLoadError] = useState(''); const [loadError, setLoadError] = useState('');
const [selectedTask, setSelectedTask] = useState<ProcessingTask | null>(null);
const [taskActionMessage, setTaskActionMessage] = useState('');
const [busyTaskId, setBusyTaskId] = useState<string | null>(null);
const taskFromProcessingTask = (task: ProcessingTask, name = `任务 ${task.id}`): DashboardTask => ({
id: `task-${task.id}`,
task_id: task.id,
project_id: task.project_id ?? 0,
name,
progress: task.progress,
status: task.message || task.status,
raw_status: task.status,
error: task.error,
frame_count: Number(task.result?.frames_extracted || 0),
updated_at: task.updated_at,
});
const prependActivity = (message: string, project = '系统') => {
setActivityLog((prev) => [
{ id: `task-action-${Date.now()}`, kind: 'task', time: new Date().toISOString(), message, project },
...prev.slice(0, 9),
]);
};
useEffect(() => { useEffect(() => {
let cancelled = false; let cancelled = false;
@@ -90,6 +122,8 @@ export function Dashboard() {
name: taskTitle(data), name: taskTitle(data),
progress: data.progress ?? 0, progress: data.progress ?? 0,
status: data.status ?? '处理中', status: data.status ?? '处理中',
raw_status: 'running',
error: data.error,
frame_count: 0, frame_count: 0,
updated_at: new Date().toISOString(), updated_at: new Date().toISOString(),
}, },
@@ -100,7 +134,7 @@ export function Dashboard() {
if (data.type === 'complete' && data.taskId) { if (data.type === 'complete' && data.taskId) {
setTasks((prev) => setTasks((prev) =>
prev.map((t) => prev.map((t) =>
t.id === data.taskId ? { ...t, progress: 100, status: '已完成' } : t t.id === data.taskId ? { ...t, progress: 100, status: '已完成', raw_status: 'success' } : t
) )
); );
setActivityLog((prev) => [ setActivityLog((prev) => [
@@ -109,10 +143,26 @@ export function Dashboard() {
]); ]);
} }
if (data.type === 'cancelled' && data.taskId) {
setTasks((prev) =>
prev.map((t) =>
t.id === data.taskId
? { ...t, progress: 100, status: data.message || '任务已取消', raw_status: 'cancelled', error: data.error }
: t
)
);
setActivityLog((prev) => [
{ id: `ws-cancelled-${Date.now()}`, kind: 'websocket', time: new Date().toISOString(), message: data.message || `任务已取消: ${taskTitle(data)}`, project: data.projectName || '系统' },
...prev.slice(0, 9),
]);
}
if (data.type === 'error' && data.taskId) { if (data.type === 'error' && data.taskId) {
setTasks((prev) => setTasks((prev) =>
prev.map((t) => prev.map((t) =>
t.id === data.taskId ? { ...t, progress: data.progress ?? t.progress, status: `错误: ${data.error || data.message || '未知错误'}` } : t t.id === data.taskId
? { ...t, progress: data.progress ?? t.progress, status: `错误: ${data.error || data.message || '未知错误'}`, raw_status: 'failed', error: data.error }
: t
) )
); );
setActivityLog((prev) => [ setActivityLog((prev) => [
@@ -160,6 +210,65 @@ export function Dashboard() {
}); });
} }
const taskRawStatus = (task: DashboardTask): string => task.raw_status || (
task.status.includes('取消') ? 'cancelled'
: task.status.includes('失败') || task.status.includes('错误') ? 'failed'
: task.progress >= 100 ? 'success'
: 'running'
);
const canCancel = (task: DashboardTask): boolean => ['queued', 'running'].includes(taskRawStatus(task)) && Boolean(task.task_id);
const canRetry = (task: DashboardTask): boolean => ['failed', 'cancelled'].includes(taskRawStatus(task)) && Boolean(task.task_id);
const handleCancelTask = async (task: DashboardTask) => {
if (!task.task_id) return;
setBusyTaskId(task.id);
setTaskActionMessage('');
try {
const updated = await cancelTask(task.task_id);
setTasks((prev) => prev.map((item) => (
item.id === task.id ? taskFromProcessingTask(updated, task.name) : item
)));
prependActivity(`任务已取消 #${updated.id}`, task.name);
} catch (err) {
console.error('Cancel task failed:', err);
setTaskActionMessage('任务取消失败,请检查后端服务');
} finally {
setBusyTaskId(null);
}
};
const handleRetryTask = async (task: DashboardTask) => {
if (!task.task_id) return;
setBusyTaskId(task.id);
setTaskActionMessage('');
try {
const retried = await retryTask(task.task_id);
const dashboardTask = taskFromProcessingTask(retried, task.name);
setTasks((prev) => [dashboardTask, ...prev.filter((item) => item.id !== dashboardTask.id)]);
prependActivity(`重试任务已入队 #${retried.id}`, task.name);
} catch (err) {
console.error('Retry task failed:', err);
setTaskActionMessage('任务重试失败,请检查后端服务');
} finally {
setBusyTaskId(null);
}
};
const handleOpenTaskDetail = async (task: DashboardTask) => {
if (!task.task_id) return;
setBusyTaskId(task.id);
setTaskActionMessage('');
try {
setSelectedTask(await getTask(task.task_id));
} catch (err) {
console.error('Load task detail failed:', err);
setTaskActionMessage('失败详情加载失败');
} finally {
setBusyTaskId(null);
}
};
return ( return (
<div className="p-8 w-full h-full overflow-y-auto bg-[#0a0a0a]"> <div className="p-8 w-full h-full overflow-y-auto bg-[#0a0a0a]">
<header className="mb-8"> <header className="mb-8">
@@ -177,6 +286,7 @@ export function Dashboard() {
</div> </div>
<p className="text-gray-400 text-sm mt-1"></p> <p className="text-gray-400 text-sm mt-1"></p>
{loadError && <p className="text-red-400 text-xs mt-2">{loadError}</p>} {loadError && <p className="text-red-400 text-xs mt-2">{loadError}</p>}
{taskActionMessage && <p className="text-amber-400 text-xs mt-2">{taskActionMessage}</p>}
</header> </header>
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4 mb-8"> <div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4 mb-8">
@@ -213,16 +323,47 @@ export function Dashboard() {
<div className="h-full bg-gradient-to-r from-cyan-600 to-cyan-400 rounded-full transition-all duration-500" style={{ width: `${task.progress}%` }} /> <div className="h-full bg-gradient-to-r from-cyan-600 to-cyan-400 rounded-full transition-all duration-500" style={{ width: `${task.progress}%` }} />
</div> </div>
<div className="text-xs text-gray-500 flex items-center gap-2"> <div className="text-xs text-gray-500 flex items-center gap-2">
{task.status === '已完成' || task.progress >= 100 ? ( {taskRawStatus(task) === 'success' || task.status === '已完成' ? (
<CheckCircle2 size={12} className="text-emerald-400" /> <CheckCircle2 size={12} className="text-emerald-400" />
) : task.status.includes('错误') ? ( ) : taskRawStatus(task) === 'failed' ? (
<span className="text-red-400"></span> <AlertTriangle size={12} className="text-red-400" />
) : taskRawStatus(task) === 'cancelled' ? (
<XCircle size={12} className="text-amber-400" />
) : ( ) : (
<Loader2 size={12} className="text-cyan-400 animate-spin" /> <Loader2 size={12} className="text-cyan-400 animate-spin" />
)} )}
{task.status} {task.status}
<span className="text-gray-600">: {task.frame_count}</span> <span className="text-gray-600">: {task.frame_count}</span>
</div> </div>
<div className="mt-3 flex flex-wrap items-center gap-2">
{canCancel(task) && (
<button
onClick={() => handleCancelTask(task)}
disabled={busyTaskId === task.id}
className="inline-flex items-center gap-1 rounded border border-red-500/20 bg-red-500/10 px-2 py-1 text-[11px] text-red-300 hover:bg-red-500/20 disabled:opacity-40 disabled:cursor-not-allowed"
>
<XCircle size={12} />
</button>
)}
{canRetry(task) && (
<button
onClick={() => handleRetryTask(task)}
disabled={busyTaskId === task.id}
className="inline-flex items-center gap-1 rounded border border-cyan-500/20 bg-cyan-500/10 px-2 py-1 text-[11px] text-cyan-300 hover:bg-cyan-500/20 disabled:opacity-40 disabled:cursor-not-allowed"
>
<RotateCcw size={12} />
</button>
)}
{task.task_id && (
<button
onClick={() => handleOpenTaskDetail(task)}
disabled={busyTaskId === task.id}
className="inline-flex items-center gap-1 rounded border border-white/10 bg-white/5 px-2 py-1 text-[11px] text-gray-300 hover:bg-white/10 disabled:opacity-40 disabled:cursor-not-allowed"
>
<Info size={12} />
</button>
)}
</div>
</div> </div>
))} ))}
{!isLoading && tasks.length === 0 && ( {!isLoading && tasks.length === 0 && (
@@ -253,6 +394,46 @@ export function Dashboard() {
</div> </div>
</div> </div>
</div> </div>
{selectedTask && (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/70 px-4">
<div className="w-full max-w-2xl rounded-lg border border-white/10 bg-[#111] p-5 shadow-2xl">
<div className="flex items-center justify-between gap-3 border-b border-white/10 pb-3">
<div>
<h3 className="text-sm font-semibold text-white"> #{selectedTask.id}</h3>
<p className="mt-1 text-xs text-gray-500">{selectedTask.message || selectedTask.status}</p>
</div>
<button
onClick={() => setSelectedTask(null)}
className="rounded border border-white/10 bg-white/5 px-2 py-1 text-xs text-gray-300 hover:bg-white/10"
>
</button>
</div>
<div className="mt-4 grid grid-cols-2 gap-3 text-xs text-gray-400">
<div>: <span className="text-gray-200">{selectedTask.status}</span></div>
<div>: <span className="text-gray-200">{selectedTask.progress}%</span></div>
<div> ID: <span className="text-gray-200">{selectedTask.project_id ?? '-'}</span></div>
<div>Celery ID: <span className="text-gray-200">{selectedTask.celery_task_id || '-'}</span></div>
<div>: <span className="text-gray-200">{selectedTask.created_at}</span></div>
<div>: <span className="text-gray-200">{selectedTask.finished_at || '-'}</span></div>
</div>
{selectedTask.error && (
<div className="mt-4 rounded border border-red-500/20 bg-red-500/10 p-3 text-xs text-red-200">
{selectedTask.error}
</div>
)}
<div className="mt-4 grid gap-3 md:grid-cols-2">
<pre className="max-h-48 overflow-auto rounded border border-white/10 bg-[#0a0a0a] p-3 text-[11px] text-gray-300">
{JSON.stringify(selectedTask.payload || {}, null, 2)}
</pre>
<pre className="max-h-48 overflow-auto rounded border border-white/10 bg-[#0a0a0a] p-3 text-[11px] text-gray-300">
{JSON.stringify(selectedTask.result || {}, null, 2)}
</pre>
</div>
</div>
</div>
)}
</div> </div>
); );
} }

View File

@@ -3,18 +3,31 @@ import { describe, expect, it, vi } from 'vitest';
import { ToolsPalette } from './ToolsPalette'; import { ToolsPalette } from './ToolsPalette';
describe('ToolsPalette', () => { describe('ToolsPalette', () => {
it('switches tools and exposes UI-only placeholder buttons', () => { it('switches tools and dispatches undo/redo actions when available', () => {
const setActiveTool = vi.fn(); const setActiveTool = vi.fn();
const onUndo = vi.fn();
const onRedo = vi.fn();
render(<ToolsPalette activeTool="move" setActiveTool={setActiveTool} />); render(
<ToolsPalette
activeTool="move"
setActiveTool={setActiveTool}
onUndo={onUndo}
onRedo={onRedo}
canUndo
canRedo
/>,
);
fireEvent.click(screen.getByTitle('创建多边形 (P)')); fireEvent.click(screen.getByTitle('创建多边形 (P)'));
fireEvent.click(screen.getByTitle('正向选点 (SAM)')); fireEvent.click(screen.getByTitle('正向选点 (SAM)'));
fireEvent.click(screen.getByTitle('撤销操作 (Ctrl+Z)'));
fireEvent.click(screen.getByTitle('重做操作 (Ctrl+Shift+Z)'));
expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon'); expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon');
expect(setActiveTool).toHaveBeenNthCalledWith(2, 'point_pos'); expect(setActiveTool).toHaveBeenNthCalledWith(2, 'point_pos');
expect(screen.getByTitle('撤销操作 (Ctrl+Z)')).toBeInTheDocument(); expect(onUndo).toHaveBeenCalled();
expect(screen.getByTitle('重做操作 (Ctrl+Shift+Z)')).toBeInTheDocument(); expect(onRedo).toHaveBeenCalled();
}); });
it('switches to SAM trigger and calls the AI navigation hook', () => { it('switches to SAM trigger and calls the AI navigation hook', () => {

View File

@@ -6,9 +6,21 @@ interface ToolsPaletteProps {
activeTool: string; activeTool: string;
setActiveTool: (tool: string) => void; setActiveTool: (tool: string) => void;
onTriggerAI?: () => void; onTriggerAI?: () => void;
onUndo?: () => void;
onRedo?: () => void;
canUndo?: boolean;
canRedo?: boolean;
} }
export function ToolsPalette({ activeTool, setActiveTool, onTriggerAI }: ToolsPaletteProps) { export function ToolsPalette({
activeTool,
setActiveTool,
onTriggerAI,
onUndo,
onRedo,
canUndo = false,
canRedo = false,
}: ToolsPaletteProps) {
const tools = [ const tools = [
{ id: 'move', icon: MousePointer2, label: '拖拽 / 选择 (V)' }, { id: 'move', icon: MousePointer2, label: '拖拽 / 选择 (V)' },
{ id: 'create_polygon', icon: Hexagon, label: '创建多边形 (P)' }, { id: 'create_polygon', icon: Hexagon, label: '创建多边形 (P)' },
@@ -91,10 +103,20 @@ export function ToolsPalette({ activeTool, setActiveTool, onTriggerAI }: ToolsPa
<div className="w-full h-px bg-white/10 my-1" /> <div className="w-full h-px bg-white/10 my-1" />
<button className="w-10 h-10 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="撤销操作 (Ctrl+Z)"> <button
onClick={onUndo}
disabled={!canUndo}
className="w-10 h-10 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-500 disabled:cursor-not-allowed"
title="撤销操作 (Ctrl+Z)"
>
<Undo size={18} /> <Undo size={18} />
</button> </button>
<button className="w-10 h-10 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="重做操作 (Ctrl+Shift+Z)"> <button
onClick={onRedo}
disabled={!canRedo}
className="w-10 h-10 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-500 disabled:cursor-not-allowed"
title="重做操作 (Ctrl+Shift+Z)"
>
<Redo size={18} /> <Redo size={18} />
</button> </button>

View File

@@ -14,6 +14,8 @@ const apiMock = vi.hoisted(() => ({
updateAnnotation: vi.fn(), updateAnnotation: vi.fn(),
deleteAnnotation: vi.fn(), deleteAnnotation: vi.fn(),
exportCoco: vi.fn(), exportCoco: vi.fn(),
exportMasks: vi.fn(),
importGtMask: vi.fn(),
annotationToMask: vi.fn(), annotationToMask: vi.fn(),
buildAnnotationPayload: vi.fn(), buildAnnotationPayload: vi.fn(),
getAiModelStatus: vi.fn(), getAiModelStatus: vi.fn(),
@@ -29,6 +31,8 @@ vi.mock('../lib/api', () => ({
updateAnnotation: apiMock.updateAnnotation, updateAnnotation: apiMock.updateAnnotation,
deleteAnnotation: apiMock.deleteAnnotation, deleteAnnotation: apiMock.deleteAnnotation,
exportCoco: apiMock.exportCoco, exportCoco: apiMock.exportCoco,
exportMasks: apiMock.exportMasks,
importGtMask: apiMock.importGtMask,
annotationToMask: apiMock.annotationToMask, annotationToMask: apiMock.annotationToMask,
buildAnnotationPayload: apiMock.buildAnnotationPayload, buildAnnotationPayload: apiMock.buildAnnotationPayload,
getAiModelStatus: apiMock.getAiModelStatus, getAiModelStatus: apiMock.getAiModelStatus,
@@ -256,4 +260,64 @@ describe('VideoWorkspace', () => {
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled()); await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled());
expect(apiMock.exportCoco).toHaveBeenCalledWith('1'); expect(apiMock.exportCoco).toHaveBeenCalledWith('1');
}); });
it('auto-saves pending masks before exporting PNG masks', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } });
apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 });
apiMock.exportMasks.mockResolvedValueOnce(new Blob(['zip'], { type: 'application/zip' }));
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
act(() => {
useStore.setState({
masks: [{
id: 'mask-1',
frameId: '10',
pathData: 'M 0 0 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[0, 0, 10, 0, 10, 10]],
}],
});
});
fireEvent.click(screen.getByRole('button', { name: '导出 PNG Mask ZIP' }));
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled());
expect(apiMock.exportMasks).toHaveBeenCalledWith('1');
});
it('imports a GT mask for the current frame and hydrates saved annotations', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
apiMock.importGtMask.mockResolvedValueOnce([{ id: 88, frame_id: 10 }]);
apiMock.getProjectAnnotations
.mockResolvedValueOnce([])
.mockResolvedValueOnce([{ id: 88, frame_id: 10 }]);
apiMock.annotationToMask.mockReturnValueOnce({
id: 'annotation-88',
annotationId: '88',
frameId: '10',
saved: true,
pathData: 'M 0 0 Z',
label: 'GT Mask',
color: '#22c55e',
});
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
const fileInput = document.querySelector('input[type="file"]') as HTMLInputElement;
const file = new File(['mask'], 'mask.png', { type: 'image/png' });
fireEvent.change(fileInput, { target: { files: [file] } });
await waitFor(() => expect(apiMock.importGtMask).toHaveBeenCalledWith(file, '1', '10'));
await waitFor(() => expect(useStore.getState().masks).toEqual([
expect.objectContaining({ id: 'annotation-88', label: 'GT Mask' }),
]));
});
}); });

View File

@@ -5,10 +5,12 @@ import {
buildAnnotationPayload, buildAnnotationPayload,
deleteAnnotation, deleteAnnotation,
exportCoco, exportCoco,
exportMasks,
getProjectAnnotations, getProjectAnnotations,
getProjectFrames, getProjectFrames,
getTask, getTask,
getTemplates, getTemplates,
importGtMask,
parseMedia, parseMedia,
saveAnnotation, saveAnnotation,
updateAnnotation, updateAnnotation,
@@ -25,18 +27,24 @@ function sleep(ms: number) {
} }
export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) { export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) {
const gtMaskInputRef = React.useRef<HTMLInputElement>(null);
const activeTool = useStore((state) => state.activeTool); const activeTool = useStore((state) => state.activeTool);
const setActiveTool = useStore((state) => state.setActiveTool); const setActiveTool = useStore((state) => state.setActiveTool);
const currentProject = useStore((state) => state.currentProject); const currentProject = useStore((state) => state.currentProject);
const frames = useStore((state) => state.frames); const frames = useStore((state) => state.frames);
const currentFrameIndex = useStore((state) => state.currentFrameIndex); const currentFrameIndex = useStore((state) => state.currentFrameIndex);
const masks = useStore((state) => state.masks); const masks = useStore((state) => state.masks);
const maskHistory = useStore((state) => state.maskHistory);
const maskFuture = useStore((state) => state.maskFuture);
const activeTemplateId = useStore((state) => state.activeTemplateId); const activeTemplateId = useStore((state) => state.activeTemplateId);
const setFrames = useStore((state) => state.setFrames); const setFrames = useStore((state) => state.setFrames);
const setCurrentFrame = useStore((state) => state.setCurrentFrame); const setCurrentFrame = useStore((state) => state.setCurrentFrame);
const setMasks = useStore((state) => state.setMasks); const setMasks = useStore((state) => state.setMasks);
const undoMasks = useStore((state) => state.undoMasks);
const redoMasks = useStore((state) => state.redoMasks);
const [isSaving, setIsSaving] = useState(false); const [isSaving, setIsSaving] = useState(false);
const [isExporting, setIsExporting] = useState(false); const [isExporting, setIsExporting] = useState(false);
const [isImportingGt, setIsImportingGt] = useState(false);
const [statusMessage, setStatusMessage] = useState(''); const [statusMessage, setStatusMessage] = useState('');
const hydrateSavedAnnotations = useCallback(async (projectId: string, projectFrames: Frame[]) => { const hydrateSavedAnnotations = useCallback(async (projectId: string, projectFrames: Frame[]) => {
@@ -216,6 +224,18 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
} }
}, [currentFrame, masks, setMasks]); }, [currentFrame, masks, setMasks]);
const handleDeleteMaskAnnotations = useCallback(async (annotationIds: string[]) => {
if (annotationIds.length === 0) return;
try {
await Promise.all(annotationIds.map((annotationId) => deleteAnnotation(annotationId)));
setStatusMessage(`已删除 ${annotationIds.length} 个被合并标注`);
} catch (err) {
console.error('Delete merged annotations failed:', err);
setStatusMessage('合并后删除原标注失败,请检查后端服务');
throw err;
}
}, []);
const handleSave = async () => { const handleSave = async () => {
try { try {
await savePendingAnnotations(); await savePendingAnnotations();
@@ -248,6 +268,52 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
} }
}; };
const downloadBlob = (blob: Blob, filename: string) => {
const url = URL.createObjectURL(blob);
const link = document.createElement('a');
link.href = url;
link.download = filename;
document.body.appendChild(link);
link.click();
link.remove();
URL.revokeObjectURL(url);
};
const handleExportMasks = async () => {
if (!currentProject?.id) return;
setIsExporting(true);
setStatusMessage('正在准备导出语义 Mask ZIP...');
try {
await savePendingAnnotations({ silent: true });
const blob = await exportMasks(currentProject.id);
downloadBlob(blob, `project_${currentProject.id}_masks.zip`);
setStatusMessage('PNG Mask ZIP 已导出');
} catch (err) {
console.error('Mask export failed:', err);
setStatusMessage('Mask 导出失败,请检查后端服务');
} finally {
setIsExporting(false);
}
};
const handleImportGtMask = async (event: React.ChangeEvent<HTMLInputElement>) => {
const file = event.target.files?.[0];
if (!file || !currentProject?.id || !currentFrame?.id) return;
setIsImportingGt(true);
setStatusMessage('正在导入 GT Mask...');
try {
const imported = await importGtMask(file, currentProject.id, currentFrame.id);
await hydrateSavedAnnotations(currentProject.id, frames);
setStatusMessage(`已导入 ${imported.length} 个 GT 区域`);
} catch (err) {
console.error('GT mask import failed:', err);
setStatusMessage('GT Mask 导入失败,请检查文件或后端服务');
} finally {
setIsImportingGt(false);
event.target.value = '';
}
};
return ( return (
<div className="w-full h-full flex flex-col bg-[#0a0a0a]"> <div className="w-full h-full flex flex-col bg-[#0a0a0a]">
{/* Top Header / Status bar */} {/* Top Header / Status bar */}
@@ -264,6 +330,27 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
</span> </span>
)} )}
<ModelStatusBadge /> <ModelStatusBadge />
<input
ref={gtMaskInputRef}
type="file"
accept="image/png,image/jpeg,image/bmp,image/tiff"
className="hidden"
onChange={handleImportGtMask}
/>
<button
onClick={() => gtMaskInputRef.current?.click()}
disabled={!currentProject?.id || !currentFrame?.id || isImportingGt || isSaving || isExporting}
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
>
{isImportingGt ? '导入中...' : '导入 GT Mask'}
</button>
<button
onClick={handleExportMasks}
disabled={!currentProject?.id || isExporting || isSaving}
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
>
{isExporting ? '导出中...' : '导出 PNG Mask ZIP'}
</button>
<button <button
onClick={handleExport} onClick={handleExport}
disabled={!currentProject?.id || isExporting || isSaving} disabled={!currentProject?.id || isExporting || isSaving}
@@ -283,11 +370,24 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
{/* Main Workspace Area */} {/* Main Workspace Area */}
<div className="flex-1 flex overflow-hidden"> <div className="flex-1 flex overflow-hidden">
<ToolsPalette activeTool={activeTool} setActiveTool={setActiveTool} onTriggerAI={onNavigateToAI} /> <ToolsPalette
activeTool={activeTool}
setActiveTool={setActiveTool}
onTriggerAI={onNavigateToAI}
onUndo={undoMasks}
onRedo={redoMasks}
canUndo={maskHistory.length > 0}
canRedo={maskFuture.length > 0}
/>
<div className="flex-1 relative flex items-center justify-center p-8 bg-[#151515] overflow-hidden"> <div className="flex-1 relative flex items-center justify-center p-8 bg-[#151515] overflow-hidden">
<div className="relative w-full h-full bg-[#1e1e1e] border border-white/5 shadow-2xl rounded-sm"> <div className="relative w-full h-full bg-[#1e1e1e] border border-white/5 shadow-2xl rounded-sm">
<CanvasArea activeTool={activeTool} frame={currentFrame} onClearMasks={handleClearCurrentFrameMasks} /> <CanvasArea
activeTool={activeTool}
frame={currentFrame}
onClearMasks={handleClearCurrentFrameMasks}
onDeleteMaskAnnotations={handleDeleteMaskAnnotations}
/>
</div> </div>
</div> </div>

View File

@@ -101,6 +101,17 @@ describe('api client contracts', () => {
}); });
}); });
it('exports PNG masks from the backend route shape', async () => {
const { exportMasks } = await import('./api');
const blob = new Blob(['zip'], { type: 'application/zip' });
axiosMock.client.get.mockResolvedValueOnce({ data: blob });
await expect(exportMasks('9')).resolves.toBe(blob);
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/export/9/masks', {
responseType: 'blob',
});
});
it('loads dashboard overview from the backend summary endpoint', async () => { it('loads dashboard overview from the backend summary endpoint', async () => {
const { getDashboardOverview } = await import('./api'); const { getDashboardOverview } = await import('./api');
const overview = { const overview = {
@@ -125,8 +136,8 @@ describe('api client contracts', () => {
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/dashboard/overview'); expect(axiosMock.client.get).toHaveBeenCalledWith('/api/dashboard/overview');
}); });
it('queues media parsing and reads processing task status', async () => { it('queues media parsing and manages processing task lifecycle', async () => {
const { getTask, parseMedia } = await import('./api'); const { cancelTask, getTask, parseMedia, retryTask } = await import('./api');
const task = { const task = {
id: 12, id: 12,
task_type: 'parse_video', task_type: 'parse_video',
@@ -145,6 +156,8 @@ describe('api client contracts', () => {
}; };
axiosMock.client.post.mockResolvedValueOnce({ data: task }); axiosMock.client.post.mockResolvedValueOnce({ data: task });
axiosMock.client.get.mockResolvedValueOnce({ data: { ...task, status: 'success', progress: 100 } }); axiosMock.client.get.mockResolvedValueOnce({ data: { ...task, status: 'success', progress: 100 } });
axiosMock.client.post.mockResolvedValueOnce({ data: { ...task, status: 'cancelled', progress: 100 } });
axiosMock.client.post.mockResolvedValueOnce({ data: { ...task, id: 13, status: 'queued', progress: 0 } });
await expect(parseMedia('9')).resolves.toEqual(task); await expect(parseMedia('9')).resolves.toEqual(task);
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/media/parse', null, { expect(axiosMock.client.post).toHaveBeenCalledWith('/api/media/parse', null, {
@@ -153,6 +166,12 @@ describe('api client contracts', () => {
await expect(getTask(12)).resolves.toEqual(expect.objectContaining({ status: 'success', progress: 100 })); await expect(getTask(12)).resolves.toEqual(expect.objectContaining({ status: 'success', progress: 100 }));
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/tasks/12'); expect(axiosMock.client.get).toHaveBeenCalledWith('/api/tasks/12');
await expect(cancelTask(12)).resolves.toEqual(expect.objectContaining({ status: 'cancelled' }));
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/tasks/12/cancel');
await expect(retryTask(12)).resolves.toEqual(expect.objectContaining({ id: 13, status: 'queued' }));
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/tasks/12/retry');
}); });
it('lists, saves, updates, and deletes annotations with the backend annotation contract', async () => { it('lists, saves, updates, and deletes annotations with the backend annotation contract', async () => {
@@ -204,6 +223,25 @@ describe('api client contracts', () => {
expect(axiosMock.client.delete).toHaveBeenCalledWith('/api/ai/annotations/1'); expect(axiosMock.client.delete).toHaveBeenCalledWith('/api/ai/annotations/1');
}); });
it('imports GT masks through multipart form data', async () => {
const { importGtMask } = await import('./api');
const file = new File(['mask'], 'mask.png', { type: 'image/png' });
const saved = [{ id: 1, project_id: 9, frame_id: 5, template_id: null, mask_data: null, points: null, bbox: null }];
axiosMock.client.post.mockResolvedValueOnce({ data: saved });
await expect(importGtMask(file, '9', '5', '2')).resolves.toEqual(saved);
expect(axiosMock.client.post).toHaveBeenCalledWith(
'/api/ai/import-gt-mask',
expect.any(FormData),
{ headers: { 'Content-Type': 'multipart/form-data' } },
);
const form = axiosMock.client.post.mock.calls.at(-1)?.[1] as FormData;
expect(form.get('file')).toBe(file);
expect(form.get('project_id')).toBe('9');
expect(form.get('frame_id')).toBe('5');
expect(form.get('template_id')).toBe('2');
});
it('builds annotation payloads from frontend masks and restores saved annotations to masks', async () => { it('builds annotation payloads from frontend masks and restores saved annotations to masks', async () => {
const { annotationToMask, buildAnnotationPayload } = await import('./api'); const { annotationToMask, buildAnnotationPayload } = await import('./api');
const frame = { id: '5', projectId: '9', index: 0, url: '/frame.jpg', width: 100, height: 50 }; const frame = { id: '5', projectId: '9', index: 0, url: '/frame.jpg', width: 100, height: 50 };
@@ -244,7 +282,7 @@ describe('api client contracts', () => {
color: '#06b6d4', color: '#06b6d4',
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 }, class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
}, },
points: null, points: [[0.5, 0.5]],
bbox: null, bbox: null,
created_at: 'created', created_at: 'created',
updated_at: 'updated', updated_at: 'updated',
@@ -261,10 +299,28 @@ describe('api client contracts', () => {
saveStatus: 'saved', saveStatus: 'saved',
saved: true, saved: true,
pathData: 'M 10 10 L 90 10 L 90 40 Z', pathData: 'M 10 10 L 90 10 L 90 40 Z',
points: [[50, 25]],
bbox: [10, 10, 80, 30], bbox: [10, 10, 80, 30],
})); }));
}); });
it('preserves editable point regions in annotation payloads', async () => {
const { buildAnnotationPayload } = await import('./api');
const frame = { id: '5', projectId: '9', index: 0, url: '/frame.jpg', width: 100, height: 50 };
expect(buildAnnotationPayload('9', {
id: 'm1',
frameId: '5',
pathData: 'M 10 10 L 90 10 L 90 40 Z',
label: 'GT Mask',
color: '#22c55e',
segmentation: [[10, 10, 90, 10, 90, 40]],
points: [[50, 25]],
}, frame)).toEqual(expect.objectContaining({
points: [[0.5, 0.5]],
}));
});
it('normalizes positive and negative point prompts for AI prediction', async () => { it('normalizes positive and negative point prompts for AI prediction', async () => {
const { predictMask } = await import('./api'); const { predictMask } = await import('./api');
axiosMock.client.post.mockResolvedValueOnce({ axiosMock.client.post.mockResolvedValueOnce({
@@ -341,6 +397,38 @@ describe('api client contracts', () => {
}); });
}); });
it('passes AI post-processing options to prediction endpoint', async () => {
const { predictMask } = await import('./api');
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
await predictMask({
imageId: '7',
imageWidth: 640,
imageHeight: 360,
points: [{ x: 320, y: 180, type: 'pos' }],
options: {
crop_to_prompt: true,
auto_filter_background: true,
min_score: 0.05,
},
});
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
image_id: 7,
prompt_type: 'point',
prompt_data: {
points: [[0.5, 0.5]],
labels: [1],
},
model: 'sam2',
options: {
crop_to_prompt: true,
auto_filter_background: true,
min_score: 0.05,
},
});
});
it('loads AI model and GPU runtime status', async () => { it('loads AI model and GPU runtime status', async () => {
const { getAiModelStatus } = await import('./api'); const { getAiModelStatus } = await import('./api');
const status = { const status = {

View File

@@ -197,6 +197,16 @@ export async function getTask(taskId: string | number): Promise<ProcessingTask>
return response.data; return response.data;
} }
export async function cancelTask(taskId: string | number): Promise<ProcessingTask> {
const response = await apiClient.post(`/api/tasks/${taskId}/cancel`);
return response.data;
}
export async function retryTask(taskId: string | number): Promise<ProcessingTask> {
const response = await apiClient.post(`/api/tasks/${taskId}/retry`);
return response.data;
}
interface PredictMaskPayload { interface PredictMaskPayload {
imageId: string; imageId: string;
imageWidth: number; imageWidth: number;
@@ -205,6 +215,12 @@ interface PredictMaskPayload {
points?: { x: number; y: number; type: 'pos' | 'neg' }[]; points?: { x: number; y: number; type: 'pos' | 'neg' }[];
box?: { x1: number; y1: number; x2: number; y2: number }; box?: { x1: number; y1: number; x2: number; y2: number };
text?: string; text?: string;
options?: {
crop_to_prompt?: boolean;
auto_filter_background?: boolean;
min_score?: number;
crop_margin?: number;
};
} }
interface PredictMaskResult { interface PredictMaskResult {
@@ -234,6 +250,8 @@ export interface AiModelStatus {
python_ok: boolean; python_ok: boolean;
torch_ok: boolean; torch_ok: boolean;
cuda_required: boolean; cuda_required: boolean;
external_available?: boolean;
external_python?: string | null;
} }
export interface AiRuntimeStatus { export interface AiRuntimeStatus {
@@ -301,6 +319,8 @@ export interface DashboardTask {
name: string; name: string;
progress: number; progress: number;
status: string; status: string;
raw_status?: string;
error?: string | null;
frame_count: number; frame_count: number;
updated_at: string | null; updated_at: string | null;
} }
@@ -397,7 +417,7 @@ export function buildAnnotationPayload(
} }
: undefined; : undefined;
return { const payload: SaveAnnotationPayload = {
project_id: Number(projectId), project_id: Number(projectId),
frame_id: Number(frame.id), frame_id: Number(frame.id),
template_id: effectiveTemplateId ? Number(effectiveTemplateId) : undefined, template_id: effectiveTemplateId ? Number(effectiveTemplateId) : undefined,
@@ -416,6 +436,15 @@ export function buildAnnotationPayload(
] ]
: undefined, : undefined,
}; };
if (mask.points) {
payload.points = mask.points.map(([x, y]) => [
clamp01(x / Math.max(frame.width, 1)),
clamp01(y / Math.max(frame.height, 1)),
]);
}
return payload;
} }
export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mask | null { export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mask | null {
@@ -438,6 +467,7 @@ export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mas
label: classMetadata?.name || annotation.mask_data?.label || `Annotation ${annotation.id}`, label: classMetadata?.name || annotation.mask_data?.label || `Annotation ${annotation.id}`,
color: classMetadata?.color || annotation.mask_data?.color || '#06b6d4', color: classMetadata?.color || annotation.mask_data?.color || '#06b6d4',
segmentation: polygons.map((polygon) => polygon.flatMap(([x, y]) => [x * frame.width, y * frame.height])), segmentation: polygons.map((polygon) => polygon.flatMap(([x, y]) => [x * frame.width, y * frame.height])),
points: annotation.points?.map(([x, y]) => [x * frame.width, y * frame.height]),
bbox, bbox,
area: bbox[2] * bbox[3], area: bbox[2] * bbox[3],
}; };
@@ -471,6 +501,7 @@ export async function predictMask(payload: PredictMaskPayload): Promise<PredictM
prompt_type, prompt_type,
prompt_data, prompt_data,
model: payload.model || 'sam2', model: payload.model || 'sam2',
...(payload.options ? { options: payload.options } : {}),
}); });
const polygons: number[][][] = response.data.polygons || []; const polygons: number[][][] = response.data.polygons || [];
@@ -523,6 +554,23 @@ export async function deleteAnnotation(annotationId: string): Promise<void> {
await apiClient.delete(`/api/ai/annotations/${annotationId}`); await apiClient.delete(`/api/ai/annotations/${annotationId}`);
} }
export async function importGtMask(
file: File,
projectId: string,
frameId: string,
templateId?: string | null,
): Promise<SavedAnnotation[]> {
const formData = new FormData();
formData.append('file', file);
formData.append('project_id', projectId);
formData.append('frame_id', frameId);
if (templateId) formData.append('template_id', templateId);
const response = await apiClient.post('/api/ai/import-gt-mask', formData, {
headers: { 'Content-Type': 'multipart/form-data' },
});
return response.data;
}
export async function getDashboardOverview(): Promise<DashboardOverview> { export async function getDashboardOverview(): Promise<DashboardOverview> {
const response = await apiClient.get('/api/dashboard/overview'); const response = await apiClient.get('/api/dashboard/overview');
return response.data; return response.data;
@@ -536,4 +584,11 @@ export async function exportCoco(projectId: string): Promise<Blob> {
return response.data; return response.data;
} }
export async function exportMasks(projectId: string): Promise<Blob> {
const response = await apiClient.get(`/api/export/${projectId}/masks`, {
responseType: 'blob',
});
return response.data;
}
export default apiClient; export default apiClient;

View File

@@ -3,7 +3,7 @@ import { WS_PROGRESS_URL } from './config';
type ProgressCallback = (data: ProgressMessage) => void; type ProgressCallback = (data: ProgressMessage) => void;
interface ProgressMessage { interface ProgressMessage {
type: 'progress' | 'status' | 'error' | 'complete'; type: 'progress' | 'status' | 'error' | 'complete' | 'cancelled';
taskId?: string; taskId?: string;
task_id?: number; task_id?: number;
project_id?: number; project_id?: number;

View File

@@ -53,4 +53,15 @@ describe('useStore', () => {
expect(useStore.getState().masks).toEqual([]); expect(useStore.getState().masks).toEqual([]);
expect(useStore.getState().templates).toEqual([]); expect(useStore.getState().templates).toEqual([]);
}); });
it('keeps undo and redo history for mask edits', () => {
useStore.getState().addMask({ id: 'm1', frameId: 'f1', pathData: 'M 0 0 Z', label: 'mask 1', color: '#fff' });
useStore.getState().addMask({ id: 'm2', frameId: 'f1', pathData: 'M 1 1 Z', label: 'mask 2', color: '#000' });
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['m1', 'm2']);
useStore.getState().undoMasks();
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['m1']);
useStore.getState().redoMasks();
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['m1', 'm2']);
});
}); });

View File

@@ -56,8 +56,10 @@ export interface Mask {
color: string; color: string;
opacity?: number; opacity?: number;
segmentation?: number[][]; segmentation?: number[][];
points?: number[][];
bbox?: [number, number, number, number]; bbox?: [number, number, number, number];
area?: number; area?: number;
metadata?: Record<string, unknown>;
} }
export interface Template { export interface Template {
@@ -110,6 +112,8 @@ export interface AppState {
currentFrameIndex: number; currentFrameIndex: number;
annotations: Annotation[]; annotations: Annotation[];
masks: Mask[]; masks: Mask[];
maskHistory: Mask[][];
maskFuture: Mask[][];
setActiveModule: (module: string) => void; setActiveModule: (module: string) => void;
setActiveTool: (tool: string) => void; setActiveTool: (tool: string) => void;
setAiModel: (model: AiModelId) => void; setAiModel: (model: AiModelId) => void;
@@ -120,6 +124,8 @@ export interface AppState {
updateMask: (id: string, updates: Partial<Mask>) => void; updateMask: (id: string, updates: Partial<Mask>) => void;
setMasks: (masks: Mask[]) => void; setMasks: (masks: Mask[]) => void;
clearMasks: () => void; clearMasks: () => void;
undoMasks: () => void;
redoMasks: () => void;
removeAnnotation: (id: string) => void; removeAnnotation: (id: string) => void;
// Templates // Templates
@@ -161,6 +167,8 @@ export const useStore = create<AppState>((set) => ({
frames: [], frames: [],
annotations: [], annotations: [],
masks: [], masks: [],
maskHistory: [],
maskFuture: [],
activeTemplateId: null, activeTemplateId: null,
activeClassId: null, activeClassId: null,
activeClass: null, activeClass: null,
@@ -187,6 +195,8 @@ export const useStore = create<AppState>((set) => ({
currentFrameIndex: 0, currentFrameIndex: 0,
annotations: [], annotations: [],
masks: [], masks: [],
maskHistory: [],
maskFuture: [],
setActiveModule: (activeModule: string) => set({ activeModule }), setActiveModule: (activeModule: string) => set({ activeModule }),
setActiveTool: (activeTool: string) => set({ activeTool }), setActiveTool: (activeTool: string) => set({ activeTool }),
setAiModel: (aiModel: AiModelId) => set({ aiModel }), setAiModel: (aiModel: AiModelId) => set({ aiModel }),
@@ -195,13 +205,54 @@ export const useStore = create<AppState>((set) => ({
addAnnotation: (annotation: Annotation) => addAnnotation: (annotation: Annotation) =>
set((state) => ({ annotations: [...state.annotations, annotation] })), set((state) => ({ annotations: [...state.annotations, annotation] })),
addMask: (mask: Mask) => addMask: (mask: Mask) =>
set((state) => ({ masks: [...state.masks, mask] })), set((state) => ({
masks: [...state.masks, mask],
maskHistory: [...state.maskHistory, state.masks],
maskFuture: [],
})),
updateMask: (id: string, updates: Partial<Mask>) => updateMask: (id: string, updates: Partial<Mask>) =>
set((state) => ({ set((state) => ({
masks: state.masks.map((mask) => (mask.id === id ? { ...mask, ...updates } : mask)), masks: state.masks.map((mask) => (mask.id === id ? { ...mask, ...updates } : mask)),
maskHistory: [...state.maskHistory, state.masks],
maskFuture: [],
})), })),
setMasks: (masks: Mask[]) => set({ masks }), setMasks: (masks: Mask[]) =>
clearMasks: () => set({ masks: [] }), set((state) => {
const isInitialHydration = state.masks.length === 0
&& state.maskHistory.length === 0
&& state.maskFuture.length === 0;
return {
masks,
maskHistory: isInitialHydration ? [] : [...state.maskHistory, state.masks],
maskFuture: [],
};
}),
clearMasks: () =>
set((state) => ({
masks: [],
maskHistory: [...state.maskHistory, state.masks],
maskFuture: [],
})),
undoMasks: () =>
set((state) => {
if (state.maskHistory.length === 0) return state;
const previous = state.maskHistory[state.maskHistory.length - 1];
return {
masks: previous,
maskHistory: state.maskHistory.slice(0, -1),
maskFuture: [state.masks, ...state.maskFuture],
};
}),
redoMasks: () =>
set((state) => {
if (state.maskFuture.length === 0) return state;
const [next, ...rest] = state.maskFuture;
return {
masks: next,
maskHistory: [...state.maskHistory, state.masks],
maskFuture: rest,
};
}),
removeAnnotation: (id: string) => removeAnnotation: (id: string) =>
set((state) => ({ set((state) => ({
annotations: state.annotations.filter((a) => a.id !== id), annotations: state.annotations.filter((a) => a.id !== id),

View File

@@ -32,24 +32,69 @@ function makeStageEvent(x = 120, y = 80) {
} }
vi.mock('react-konva', () => ({ vi.mock('react-konva', () => ({
Stage: ({ children, onClick, onMouseDown, onMouseUp, onMouseMove, onWheel }: any) => ( Stage: ({ children, onClick, onMouseDown, onMouseUp, onMouseMove, onWheel }: any) => {
const coords = (event: React.MouseEvent<HTMLDivElement>, fallbackX: number, fallbackY: number) => ({
x: event.clientX || fallbackX,
y: event.clientY || fallbackY,
});
return (
<div <div
data-testid="konva-stage" data-testid="konva-stage"
onClick={() => onClick?.(makeStageEvent())} onClick={(event) => {
onMouseDown={() => onMouseDown?.(makeStageEvent())} const point = coords(event, 120, 80);
onMouseUp={() => onMouseUp?.(makeStageEvent(260, 200))} onClick?.(makeStageEvent(point.x, point.y));
onMouseMove={() => onMouseMove?.(makeStageEvent(180, 120))} }}
onMouseDown={(event) => {
const point = coords(event, 120, 80);
onMouseDown?.(makeStageEvent(point.x, point.y));
}}
onMouseUp={(event) => {
const point = coords(event, 260, 200);
onMouseUp?.(makeStageEvent(point.x, point.y));
}}
onMouseMove={(event) => {
const point = coords(event, 180, 120);
onMouseMove?.(makeStageEvent(point.x, point.y));
}}
onWheel={() => onWheel?.(makeStageEvent())} onWheel={() => onWheel?.(makeStageEvent())}
> >
{children} {children}
</div> </div>
), );
},
Layer: ({ children }: any) => <div data-testid="konva-layer">{children}</div>, Layer: ({ children }: any) => <div data-testid="konva-layer">{children}</div>,
Group: ({ children }: any) => <div data-testid="konva-group">{children}</div>, Group: ({ children }: any) => <div data-testid="konva-group">{children}</div>,
Image: ({ image }: any) => <img data-testid="konva-image" alt="" src={image?.src || ''} />, Image: ({ image }: any) => <img data-testid="konva-image" alt="" src={image?.src || ''} />,
Circle: (props: any) => <span data-testid="konva-circle" data-fill={props.fill} />, Circle: (props: any) => (
<span
data-testid="konva-circle"
data-fill={props.fill}
data-x={props.x}
data-y={props.y}
onClick={() => props.onClick?.({ cancelBubble: false })}
onMouseUp={(event: React.MouseEvent<HTMLSpanElement>) => props.onDragEnd?.({
target: {
x: () => event.clientX || props.x || 0,
y: () => event.clientY || props.y || 0,
},
})}
onDragEnd={(event: React.DragEvent<HTMLSpanElement>) => props.onDragEnd?.({
target: {
x: () => event.clientX || props.x || 0,
y: () => event.clientY || props.y || 0,
},
})}
/>
),
Rect: (props: any) => <span data-testid="konva-rect" data-width={props.width} />, Rect: (props: any) => <span data-testid="konva-rect" data-width={props.width} />,
Path: (props: any) => <span data-testid="konva-path" data-path={props.data} data-fill={props.fill} />, Path: (props: any) => (
<span
data-testid="konva-path"
data-path={props.data}
data-fill={props.fill}
onClick={() => props.onClick?.({ cancelBubble: false })}
/>
),
})); }));
vi.mock('use-image', () => ({ vi.mock('use-image', () => ({

View File

@@ -13,6 +13,8 @@ export function resetStore() {
currentFrameIndex: 0, currentFrameIndex: 0,
annotations: [], annotations: [],
masks: [], masks: [],
maskHistory: [],
maskFuture: [],
templates: [], templates: [],
activeTemplateId: null, activeTemplateId: null,
activeClassId: null, activeClassId: null,