feat: 建立 SAM2 标注闭环基线
- 打通工作区真实标注闭环:支持手工多边形、矩形、圆形、点区域和线段生成 mask,并可保存、回显、更新和删除后端 annotation。 - 增强 polygon 编辑器:支持顶点拖动、顶点删除、边中点插入、多 polygon 子区域选择编辑,以及区域合并和区域去除。 - 接入 GT mask 导入:后端支持二值/多类别 mask 拆分、contour 转 polygon、distance transform seed point,前端支持导入、回显和 seed point 拖动编辑。 - 完善导出能力:COCO JSON 导出对齐前端,PNG mask ZIP 同时包含单标注 mask、按 zIndex 融合的 semantic_frame 和 semantic_classes.json。 - 打通异步任务管理:新增任务取消、重试、失败详情接口与 Dashboard 控件,worker 支持取消状态检查并通过 Redis/WebSocket 推送 cancelled 事件。 - 对接 Dashboard 后端数据:概览统计、解析队列和实时流转记录从 FastAPI 聚合接口与 WebSocket 更新。 - 增强 AI 推理参数:前端发送 crop_to_prompt、auto_filter_background 和 min_score,后端支持点/框 prompt 局部裁剪推理、结果回映射和负向点/低分过滤。 - 接入 SAM3 基础设施:新增独立 Python 3.12 sam3 环境安装脚本、外部 worker helper、后端桥接和真实 Python/CUDA/包/HF checkpoint access 状态检测。 - 保留 SAM3 授权边界:当前官方 facebook/sam3 gated 权重未授权时状态接口会返回不可用,不伪装成可推理。 - 增强前端状态管理:新增 mask undo/redo 历史栈、AI 模型选择状态、保存状态 dirty/draft/saved 流转和项目状态归一化。 - 更新前端 API 封装:补充 annotation CRUD、GT mask import、mask ZIP export、task cancel/retry/detail、AI runtime status 和 prediction options。 - 更新 UI 控件:ToolsPalette、AISegmentation、VideoWorkspace 和 CanvasArea 接入真实操作、导入导出、撤销重做、任务控制和模型状态。 - 新增 polygon-clipping 依赖,用于前端区域 union/difference 几何运算。 - 完善后端 schemas/status/progress:补充 AI 模型外部状态字段、任务 cancelled 状态和进度事件 payload。 - 补充测试覆盖:新增后端任务控制、SAM3 桥接、GT mask、导出融合、AI options 测试;补充前端 Canvas、Dashboard、VideoWorkspace、ToolsPalette、API 和 store 测试。 - 更新 README、AGENTS 和 doc 文档:冻结当前需求/设计/测试计划,标注真实功能、剩余 Mock、SAM3 授权边界和后续实施顺序。
This commit is contained in:
30
AGENTS.md
30
AGENTS.md
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
## 项目概述
|
## 项目概述
|
||||||
|
|
||||||
本项目是一个**语义分割系统**(Semantic Segmentation System),当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、服务器端拆帧、交互式 Canvas 标注、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出。
|
本项目是一个**语义分割系统**(Semantic Segmentation System),当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、服务器端拆帧、交互式 Canvas 标注、GT mask 导入、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出。
|
||||||
|
|
||||||
- **项目名称**: `react-example`(`package.json` 中的 `name`)
|
- **项目名称**: `react-example`(`package.json` 中的 `name`)
|
||||||
- **前端入口**: `src/main.tsx` → `src/App.tsx`
|
- **前端入口**: `src/main.tsx` → `src/App.tsx`
|
||||||
@@ -30,6 +30,7 @@
|
|||||||
| 前端请求 | Axios(`src/lib/api.ts`) |
|
| 前端请求 | Axios(`src/lib/api.ts`) |
|
||||||
| 实时通信 | WebSocket 客户端(`src/lib/websocket.ts`) |
|
| 实时通信 | WebSocket 客户端(`src/lib/websocket.ts`) |
|
||||||
| Canvas 渲染 | Konva + react-konva + use-image |
|
| Canvas 渲染 | Konva + react-konva + use-image |
|
||||||
|
| 几何布尔运算 | polygon-clipping |
|
||||||
| 图标库 | lucide-react |
|
| 图标库 | lucide-react |
|
||||||
| 动画依赖 | motion(在 `package.json` 中声明) |
|
| 动画依赖 | motion(在 `package.json` 中声明) |
|
||||||
| AI SDK 依赖 | `@google/genai`(在 `package.json` 中声明;当前业务源码未直接调用) |
|
| AI SDK 依赖 | `@google/genai`(在 `package.json` 中声明;当前业务源码未直接调用) |
|
||||||
@@ -38,9 +39,9 @@
|
|||||||
| 缓存 / 队列 Broker | Redis |
|
| 缓存 / 队列 Broker | Redis |
|
||||||
| 后台任务 | Celery worker |
|
| 后台任务 | Celery worker |
|
||||||
| 对象存储 | MinIO |
|
| 对象存储 | MinIO |
|
||||||
| AI 推理 | SAM 2 / SAM 3 可选模型 + PyTorch;`GET /api/ai/models/status` 返回真实 GPU/模型状态 |
|
| AI 推理 | SAM 2 / SAM 3 可选模型 + PyTorch;SAM 3 通过独立 Python 3.12 conda 环境桥接;`GET /api/ai/models/status` 返回真实 GPU/模型/HF 权重访问状态 |
|
||||||
| 视频 / 影像处理 | FFmpeg / OpenCV / pydicom |
|
| 视频 / 影像处理 | FFmpeg / OpenCV / pydicom |
|
||||||
| 运行时 | Node.js ES Modules;Python 3.11 后端环境 |
|
| 运行时 | Node.js ES Modules;Python 3.11 后端环境;可选 `sam3` Python 3.12 conda 环境 |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -70,6 +71,7 @@ Seg_Server/
|
|||||||
│ ├── celery_app.py # Celery app 配置
|
│ ├── celery_app.py # Celery app 配置
|
||||||
│ ├── worker_tasks.py # Celery 任务入口
|
│ ├── worker_tasks.py # Celery 任务入口
|
||||||
│ ├── download_sam2.py # SAM 2 权重下载脚本
|
│ ├── download_sam2.py # SAM 2 权重下载脚本
|
||||||
|
│ ├── setup_sam3_env.sh # SAM 3 独立 Python 3.12 环境安装脚本
|
||||||
│ ├── requirements.txt # Python 依赖
|
│ ├── requirements.txt # Python 依赖
|
||||||
│ ├── routers/
|
│ ├── routers/
|
||||||
│ │ ├── auth.py # /api/auth/login
|
│ │ ├── auth.py # /api/auth/login
|
||||||
@@ -81,7 +83,8 @@ Seg_Server/
|
|||||||
│ └── services/
|
│ └── services/
|
||||||
│ ├── frame_parser.py # FFmpeg/OpenCV 拆帧、pydicom 读片、帧上传
|
│ ├── frame_parser.py # FFmpeg/OpenCV 拆帧、pydicom 读片、帧上传
|
||||||
│ ├── sam2_engine.py # SAM 2 懒加载推理封装和 fallback
|
│ ├── sam2_engine.py # SAM 2 懒加载推理封装和 fallback
|
||||||
│ ├── sam3_engine.py # SAM 3 状态检测与文本语义推理适配器
|
│ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接与文本语义推理适配器
|
||||||
|
│ ├── sam3_external_worker.py # 独立 sam3 conda 环境中执行的状态/推理 helper
|
||||||
│ └── sam_registry.py # SAM 模型选择、GPU 状态与推理分发
|
│ └── sam_registry.py # SAM 模型选择、GPU 状态与推理分发
|
||||||
└── src/ # React 前端
|
└── src/ # React 前端
|
||||||
├── main.tsx # React StrictMode 挂载
|
├── main.tsx # React StrictMode 挂载
|
||||||
@@ -188,10 +191,13 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
|||||||
- `POST /api/media/parse`
|
- `POST /api/media/parse`
|
||||||
- `GET /api/tasks`
|
- `GET /api/tasks`
|
||||||
- `GET /api/tasks/{task_id}`
|
- `GET /api/tasks/{task_id}`
|
||||||
|
- `POST /api/tasks/{task_id}/cancel`
|
||||||
|
- `POST /api/tasks/{task_id}/retry`
|
||||||
- `POST /api/ai/predict`
|
- `POST /api/ai/predict`
|
||||||
- `GET /api/ai/models/status`
|
- `GET /api/ai/models/status`
|
||||||
- `POST /api/ai/auto`
|
- `POST /api/ai/auto`
|
||||||
- `POST /api/ai/annotate`
|
- `POST /api/ai/annotate`
|
||||||
|
- `POST /api/ai/import-gt-mask`
|
||||||
- `GET /api/ai/annotations`
|
- `GET /api/ai/annotations`
|
||||||
- `PATCH/DELETE /api/ai/annotations/{annotation_id}`
|
- `PATCH/DELETE /api/ai/annotations/{annotation_id}`
|
||||||
- `GET /api/dashboard/overview`
|
- `GET /api/dashboard/overview`
|
||||||
@@ -216,9 +222,11 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
|||||||
4. 拆帧入队:前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery。
|
4. 拆帧入队:前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery。
|
||||||
5. worker 执行:Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallback,DICOM 使用 pydicom,并持续更新任务进度。
|
5. worker 执行:Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallback,DICOM 使用 pydicom,并持续更新任务进度。
|
||||||
6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames`,`CanvasArea.tsx` 和 `FrameTimeline.tsx` 显示当前帧与时间轴缩略图。
|
6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames`,`CanvasArea.tsx` 和 `FrameTimeline.tsx` 显示当前帧与时间轴缩略图。
|
||||||
7. AI 分割:前端工具包括正向点、反向点和框选;后端 `ai.py` 期望按 `image_id`、`prompt_type`、`prompt_data`、`model` 调用 SAM registry。SAM 2 支持点/框/自动分割;SAM 3 入口支持文本语义推理,运行时不满足官方要求时会在状态接口中标为不可用。
|
7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask;点击 mask 可拖动/删除 polygon 顶点、通过边中点插入新顶点,并能选择编辑多 polygon mask 的单个子区域;区域合并/去除使用 `polygon-clipping` 做 union/difference;Zustand 维护 `maskHistory/maskFuture` 支持撤销/重做。
|
||||||
8. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index;`OntologyInspector.tsx` 在工作区显示当前模板分类树。
|
8. AI 分割:前端工具包括正向点、反向点和框选;后端 `ai.py` 期望按 `image_id`、`prompt_type`、`prompt_data`、`model` 和可选 `options` 调用 SAM registry。SAM 2 支持点/框/自动分割;`options.crop_to_prompt` 可对点/框 prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果;SAM 3 入口支持文本语义推理,主后端会通过 `sam3_external_worker.py` 调用独立 Python 3.12 环境;如果 Python/CUDA/包/Hugging Face gated 权重访问任一条件不满足,会在状态接口中标为不可用。
|
||||||
9. 导出:后端支持 COCO JSON 和 PNG mask ZIP 导出。
|
9. GT 导入:工作区“导入 GT Mask”调用 `/api/ai/import-gt-mask`;后端按非零像素值和连通域生成 polygon 标注,并用 distance transform 生成 seed point;前端回显 seed point,拖动后可归档更新。
|
||||||
|
10. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index;`OntologyInspector.tsx` 在工作区显示当前模板分类树。
|
||||||
|
11. 导出:后端支持 COCO JSON 和 PNG mask ZIP 导出;PNG ZIP 包含单标注 mask、按 zIndex 融合的语义 mask 和 `semantic_classes.json`。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -226,12 +234,16 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
|||||||
|
|
||||||
- `src/lib/config.ts` 会优先读取 `VITE_API_BASE_URL` 和 `VITE_WS_PROGRESS_URL`;未配置时按当前浏览器 hostname 推导后端 `:8000` 地址。
|
- `src/lib/config.ts` 会优先读取 `VITE_API_BASE_URL` 和 `VITE_WS_PROGRESS_URL`;未配置时按当前浏览器 hostname 推导后端 `:8000` 地址。
|
||||||
- 前端 `predictMask()` 已按后端 `PredictRequest` 发送 `image_id`、`prompt_type`、`prompt_data`、`model`,并将后端 `polygons` 转成 Konva 可渲染的 `pathData`。
|
- 前端 `predictMask()` 已按后端 `PredictRequest` 发送 `image_id`、`prompt_type`、`prompt_data`、`model`,并将后端 `polygons` 转成 Konva 可渲染的 `pathData`。
|
||||||
- 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;工作区“导出 JSON 标注集”按钮已绑定下载流程,导出前会先保存当前待归档 mask。
|
- 手工绘制工具会生成可保存的 `Mask.segmentation`;撤销/重做通过 `maskHistory/maskFuture` 工作。
|
||||||
|
- Polygon 顶点编辑会重算 `pathData/segmentation/bbox/area`;已保存 mask 进入 dirty 状态后复用归档 PATCH 链路。
|
||||||
|
- 区域合并/去除会重算主 mask 的几何;合并已保存的次级 mask 时会通过工作区回调删除对应后端标注。
|
||||||
|
- 前端 `importGtMask()` 已对齐后端 `/api/ai/import-gt-mask`;工作区“导入 GT Mask”会导入后端生成的多类别标注和 seed point 并回显。
|
||||||
|
- 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;前端 `exportMasks()` 已对齐后端 `/api/export/{project_id}/masks`;工作区导出按钮会先保存当前待归档 mask。
|
||||||
- 工作区“结构化归档保存”按钮已接入 `POST /api/ai/annotate` 和 `PATCH /api/ai/annotations/{id}`;加载工作区时会通过 `GET /api/ai/annotations` 回显已保存标注。
|
- 工作区“结构化归档保存”按钮已接入 `POST /api/ai/annotate` 和 `PATCH /api/ai/annotations/{id}`;加载工作区时会通过 `GET /api/ai/annotations` 回显已保存标注。
|
||||||
- 工作区“清空遮罩”会调用 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。
|
- 工作区“清空遮罩”会调用 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。
|
||||||
- 项目状态已统一为 `pending`、`parsing`、`ready`、`error`;前端 `src/lib/api.ts` 会兼容归一化旧库中可能存在的 `Ready`、`Parsing`、`Error`。
|
- 项目状态已统一为 `pending`、`parsing`、`ready`、`error`;前端 `src/lib/api.ts` 会兼容归一化旧库中可能存在的 `Ready`、`Parsing`、`Error`。
|
||||||
- `server.ts` 仍有旧版 `/api/login`、`/api/projects`、`/api/templates` mock;当前前端真实 API 调用主要走 FastAPI 的 `/api/auth/*`、`/api/projects`、`/api/templates` 等接口。
|
- `server.ts` 仍有旧版 `/api/login`、`/api/projects`、`/api/templates` mock;当前前端真实 API 调用主要走 FastAPI 的 `/api/auth/*`、`/api/projects`、`/api/templates` 等接口。
|
||||||
- `Dashboard.tsx` 初始统计、队列和活动日志来自 `GET /api/dashboard/overview`;解析队列来自 `processing_tasks`,Celery worker 通过 Redis pub/sub 的 `seg:progress` 频道推送细粒度进度,再由 FastAPI 广播到 `/ws/progress`。
|
- `Dashboard.tsx` 初始统计、队列和活动日志来自 `GET /api/dashboard/overview`;解析队列来自 `processing_tasks`,支持取消 queued/running 任务、重试 failed/cancelled 任务和查看失败详情。Celery worker 通过 Redis pub/sub 的 `seg:progress` 频道推送细粒度进度,再由 FastAPI 广播到 `/ws/progress`。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
53
README.md
53
README.md
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
> 基于 React + FastAPI + 可选 SAM 2 / SAM 3 的全栈交互式图像/视频语义分割与标注平台。
|
> 基于 React + FastAPI + 可选 SAM 2 / SAM 3 的全栈交互式图像/视频语义分割与标注平台。
|
||||||
>
|
>
|
||||||
> 支持本地多媒体资产上传、服务器端按帧解析、交互式 Canvas 标注、模板分类管理和标注数据结构化导出;工作区点/框 AI 推理默认走 SAM 2,语义文本可选择 SAM 3,前端会显示真实 GPU/模型状态。
|
> 支持本地多媒体资产上传、服务器端按帧解析、交互式 Canvas 标注、GT mask 导入、模板分类管理和标注数据结构化导出;工作区点/框 AI 推理默认走 SAM 2,语义文本可选择 SAM 3,前端会显示真实 GPU/模型状态。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -14,10 +14,11 @@
|
|||||||
|
|
||||||
- **多媒体资产管理** — 支持视频(MP4/AVI/MOV)和 DICOM 医学影像的上传、存储与解析
|
- **多媒体资产管理** — 支持视频(MP4/AVI/MOV)和 DICOM 医学影像的上传、存储与解析
|
||||||
- **AI 智能分割引擎** — 后端提供 SAM 2 / SAM 3 模型选择;SAM 2 支持点分割(point)、框分割(box)和自动分割(auto),SAM 3 入口支持文本语义提示并按真实运行环境显示可用性
|
- **AI 智能分割引擎** — 后端提供 SAM 2 / SAM 3 模型选择;SAM 2 支持点分割(point)、框分割(box)和自动分割(auto),SAM 3 入口支持文本语义提示并按真实运行环境显示可用性
|
||||||
- **交互式画布标注** — 基于 Konva 的高性能 Canvas,支持缩放/平移/选点/框选,实时渲染 Mask 遮罩
|
- **交互式画布标注** — 基于 Konva 的高性能 Canvas,支持缩放/平移/手工多边形/矩形/圆/点/线、polygon 顶点拖动/删除、区域合并/去除、选点/框选、撤销/重做,实时渲染 Mask 遮罩
|
||||||
|
- **GT Mask 导入** — 工作区可导入 GT mask 图片,后端按非零像素值和连通域生成 polygon 标注并用 distance transform 写入 seed point;前端可回显和拖动 seed point
|
||||||
- **本体字典管理** — 可配置的分类体系、颜色映射、图层优先级(z-index)
|
- **本体字典管理** — 可配置的分类体系、颜色映射、图层优先级(z-index)
|
||||||
- **项目工作区** — 项目创建、帧浏览、多图层标注、进度追踪
|
- **项目工作区** — 项目创建、帧浏览、多图层标注、进度追踪
|
||||||
- **数据导出** — 支持 COCO JSON 格式和 PNG Mask 批量导出
|
- **数据导出** — 支持 COCO JSON 格式和 PNG Mask 批量导出;PNG ZIP 包含单标注 mask、按 z-index 融合的语义 mask 和类别映射
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -38,7 +39,7 @@
|
|||||||
│ ├── /api/projects 项目 & 视频帧 CRUD │
|
│ ├── /api/projects 项目 & 视频帧 CRUD │
|
||||||
│ ├── /api/templates 本体字典(分类/颜色/z-index) │
|
│ ├── /api/templates 本体字典(分类/颜色/z-index) │
|
||||||
│ ├── /api/media 文件上传 & 异步拆帧任务创建 │
|
│ ├── /api/media 文件上传 & 异步拆帧任务创建 │
|
||||||
│ ├── /api/tasks Celery 后台任务状态 │
|
│ ├── /api/tasks Celery 后台任务状态/取消/重试/详情 │
|
||||||
│ ├── /api/ai SAM 2 / SAM 3 推理与模型状态 │
|
│ ├── /api/ai SAM 2 / SAM 3 推理与模型状态 │
|
||||||
│ └── /api/export COCO JSON / PNG Masks 导出 │
|
│ └── /api/export COCO JSON / PNG Masks 导出 │
|
||||||
└──────────────────────────┬──────────────────────────────────┘
|
└──────────────────────────┬──────────────────────────────────┘
|
||||||
@@ -62,6 +63,7 @@
|
|||||||
| 样式方案 | TailwindCSS + 自定义深色主题 | v4 |
|
| 样式方案 | TailwindCSS + 自定义深色主题 | v4 |
|
||||||
| 状态管理 | Zustand | - |
|
| 状态管理 | Zustand | - |
|
||||||
| Canvas 渲染 | Konva + react-konva | - |
|
| Canvas 渲染 | Konva + react-konva | - |
|
||||||
|
| 几何布尔运算 | polygon-clipping | 0.15+ |
|
||||||
| HTTP 客户端 | Axios | - |
|
| HTTP 客户端 | Axios | - |
|
||||||
| 后端框架 | FastAPI | v0.136+ |
|
| 后端框架 | FastAPI | v0.136+ |
|
||||||
| 数据库 ORM | SQLAlchemy(依赖中包含 Alembic) | 2.0+ |
|
| 数据库 ORM | SQLAlchemy(依赖中包含 Alembic) | 2.0+ |
|
||||||
@@ -92,6 +94,7 @@ Seg_Server/
|
|||||||
│ ├── celery_app.py # Celery app 配置
|
│ ├── celery_app.py # Celery app 配置
|
||||||
│ ├── worker_tasks.py # Celery 任务入口
|
│ ├── worker_tasks.py # Celery 任务入口
|
||||||
│ ├── download_sam2.py # SAM 2 模型权重自动下载脚本
|
│ ├── download_sam2.py # SAM 2 模型权重自动下载脚本
|
||||||
|
│ ├── setup_sam3_env.sh # SAM 3 独立 Python 3.12 环境安装脚本
|
||||||
│ ├── requirements.txt # Python 依赖
|
│ ├── requirements.txt # Python 依赖
|
||||||
│ ├── routers/ # API 路由
|
│ ├── routers/ # API 路由
|
||||||
│ │ ├── auth.py # 登录认证
|
│ │ ├── auth.py # 登录认证
|
||||||
@@ -102,7 +105,8 @@ Seg_Server/
|
|||||||
│ │ └── export.py # 数据导出
|
│ │ └── export.py # 数据导出
|
||||||
│ └── services/ # 业务服务
|
│ └── services/ # 业务服务
|
||||||
│ ├── sam2_engine.py # SAM 2 推理引擎(懒加载 + stub降级)
|
│ ├── sam2_engine.py # SAM 2 推理引擎(懒加载 + stub降级)
|
||||||
│ ├── sam3_engine.py # SAM 3 状态检测与文本语义推理适配器
|
│ ├── sam3_engine.py # SAM 3 状态检测、外部环境桥接与文本语义推理适配器
|
||||||
|
│ ├── sam3_external_worker.py # 独立 sam3 conda 环境中执行的状态/推理 helper
|
||||||
│ ├── sam_registry.py # SAM 模型选择、GPU 状态与推理分发
|
│ ├── sam_registry.py # SAM 模型选择、GPU 状态与推理分发
|
||||||
│ └── frame_parser.py # FFmpeg 拆帧 / pydicom 读片
|
│ └── frame_parser.py # FFmpeg 拆帧 / pydicom 读片
|
||||||
├── src/ # React 前端
|
├── src/ # React 前端
|
||||||
@@ -117,10 +121,10 @@ Seg_Server/
|
|||||||
│ └── components/ # 组件(扁平化目录)
|
│ └── components/ # 组件(扁平化目录)
|
||||||
│ ├── Login.tsx # 登录页
|
│ ├── Login.tsx # 登录页
|
||||||
│ ├── Sidebar.tsx # 左侧导航栏
|
│ ├── Sidebar.tsx # 左侧导航栏
|
||||||
│ ├── Dashboard.tsx # 总体概况仪表盘(解析队列)
|
│ ├── Dashboard.tsx # 总体概况仪表盘(解析队列/任务控制)
|
||||||
│ ├── ProjectLibrary.tsx # 项目库列表
|
│ ├── ProjectLibrary.tsx # 项目库列表
|
||||||
│ ├── VideoWorkspace.tsx # 核心分割工作区布局
|
│ ├── VideoWorkspace.tsx # 核心分割工作区布局
|
||||||
│ ├── CanvasArea.tsx # Konva 画布(缩放/平移/选点/Mask渲染)
|
│ ├── CanvasArea.tsx # Konva 画布(缩放/平移/手工绘制/选点/Mask渲染)
|
||||||
│ ├── ToolsPalette.tsx # 左侧工具栏
|
│ ├── ToolsPalette.tsx # 左侧工具栏
|
||||||
│ ├── OntologyInspector.tsx # 右侧本体/属性检查面板
|
│ ├── OntologyInspector.tsx # 右侧本体/属性检查面板
|
||||||
│ ├── FrameTimeline.tsx # 底部时间轴
|
│ ├── FrameTimeline.tsx # 底部时间轴
|
||||||
@@ -161,7 +165,7 @@ Seg_Server/
|
|||||||
- **GPU**: NVIDIA GPU(推荐 RTX 4090 或同等算力),用于 SAM 推理;SAM 3 官方要求 Python 3.12+、PyTorch 2.7+ 和 CUDA 12.6+ 环境
|
- **GPU**: NVIDIA GPU(推荐 RTX 4090 或同等算力),用于 SAM 推理;SAM 3 官方要求 Python 3.12+、PyTorch 2.7+ 和 CUDA 12.6+ 环境
|
||||||
- **CUDA**: 12.x / 13.x
|
- **CUDA**: 12.x / 13.x
|
||||||
- **Node.js**: 22.x+
|
- **Node.js**: 22.x+
|
||||||
- **Python**: 3.11(通过 Miniconda/Anaconda 管理)
|
- **Python**: 主后端使用 3.11(通过 Miniconda/Anaconda 管理);SAM 3 使用独立 `sam3` Python 3.12 conda 环境
|
||||||
|
|
||||||
### 安装系统级依赖
|
### 安装系统级依赖
|
||||||
|
|
||||||
@@ -243,7 +247,22 @@ python download_sam2.py
|
|||||||
|
|
||||||
> **注意**:当前系统磁盘紧张时,建议仅保留 `sam2_hiera_tiny.pt`,删除其他模型以释放空间。
|
> **注意**:当前系统磁盘紧张时,建议仅保留 `sam2_hiera_tiny.pt`,删除其他模型以释放空间。
|
||||||
|
|
||||||
### 步骤 5: 配置环境变量
|
### 步骤 5: 可选安装 SAM 3 环境
|
||||||
|
|
||||||
|
当前后端不会把 SAM 3 直接装进 `seg_server`,而是通过独立 `sam3` conda 环境执行 `backend/services/sam3_external_worker.py`。这样可以保留现有 Python 3.11 / SAM 2 环境。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd ~/Desktop/Seg_Server
|
||||||
|
./backend/setup_sam3_env.sh
|
||||||
|
|
||||||
|
# 首次使用官方权重前,需要先在 Hugging Face 申请 facebook/sam3 访问权限并登录
|
||||||
|
conda activate sam3
|
||||||
|
huggingface-cli login
|
||||||
|
```
|
||||||
|
|
||||||
|
官方 `facebook/sam3` 权重约 3.45 GB,当前没有类似 SAM 2 `tiny/small/base/large` 的官方小权重梯度;`facebook/sam3.1` 约 3.5 GB,主要面向新的视频 multiplex checkpoint。未获得 gated model 授权时,`GET /api/ai/models/status` 会把 SAM 3 标为不可用并说明 checkpoint access 不满足。
|
||||||
|
|
||||||
|
### 步骤 6: 配置环境变量
|
||||||
|
|
||||||
后端通过 `backend/config.py` 中的 Pydantic Settings 读取 `backend/.env`。如需覆盖默认值,请编辑以下文件:
|
后端通过 `backend/config.py` 中的 Pydantic Settings 读取 `backend/.env`。如需覆盖默认值,请编辑以下文件:
|
||||||
|
|
||||||
@@ -258,7 +277,10 @@ minio_secure=false
|
|||||||
sam_model_path=/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt
|
sam_model_path=/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt
|
||||||
sam_model_config=configs/sam2/sam2_hiera_t.yaml
|
sam_model_config=configs/sam2/sam2_hiera_t.yaml
|
||||||
sam_default_model=sam2
|
sam_default_model=sam2
|
||||||
sam3_model_version=sam3.1
|
sam3_model_version=sam3
|
||||||
|
sam3_external_enabled=true
|
||||||
|
sam3_external_python=/home/wkmgc/miniconda3/envs/sam3/bin/python
|
||||||
|
sam3_timeout_seconds=300
|
||||||
cors_origins=["http://localhost:3000","http://192.168.3.11:3000"]
|
cors_origins=["http://localhost:3000","http://192.168.3.11:3000"]
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -271,7 +293,7 @@ VITE_WS_PROGRESS_URL=ws://192.168.3.11:8000/ws/progress
|
|||||||
|
|
||||||
如果未配置 `VITE_API_BASE_URL`,前端会按当前浏览器 hostname 推导 `http://<host>:8000`。
|
如果未配置 `VITE_API_BASE_URL`,前端会按当前浏览器 hostname 推导 `http://<host>:8000`。
|
||||||
|
|
||||||
### 步骤 6: 启动后端服务
|
### 步骤 7: 启动后端服务
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd ~/Desktop/Seg_Server/backend
|
cd ~/Desktop/Seg_Server/backend
|
||||||
@@ -287,7 +309,8 @@ nohup uvicorn main:app --host 0.0.0.0 --port 8000 > /tmp/fastapi.log 2>&1 &
|
|||||||
- 创建数据库表(如果不存在)
|
- 创建数据库表(如果不存在)
|
||||||
- 检查 MinIO bucket 是否存在
|
- 检查 MinIO bucket 是否存在
|
||||||
- 测试 Redis 连接
|
- 测试 Redis 连接
|
||||||
- 懒加载 SAM 模型;`GET /api/ai/models/status` 会返回 SAM 2、SAM 3 与 GPU 的真实可用状态
|
- 懒加载 SAM 模型;`GET /api/ai/models/status` 会返回 SAM 2、SAM 3、GPU 和 SAM 3 checkpoint access 的真实可用状态
|
||||||
|
- `/api/ai/predict` 支持 AI 参数 `crop_to_prompt`、`auto_filter_background` 和 `min_score`,用于点/框 prompt 的局部裁剪推理、回映射和背景过滤
|
||||||
|
|
||||||
### 步骤 6.1: 启动 Celery Worker
|
### 步骤 6.1: 启动 Celery Worker
|
||||||
|
|
||||||
@@ -301,7 +324,7 @@ celery -A celery_app:celery_app worker --loglevel=info --concurrency=1
|
|||||||
nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 &
|
nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 &
|
||||||
```
|
```
|
||||||
|
|
||||||
`POST /api/media/parse` 只创建 `processing_tasks` 记录并把任务投递给 Celery;真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道,FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。
|
`POST /api/media/parse` 只创建 `processing_tasks` 记录并把任务投递给 Celery;真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道,FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。Dashboard 也可调用 `/api/tasks/{id}/cancel`、`/api/tasks/{id}/retry` 和 `/api/tasks/{id}` 完成任务取消、重试与失败详情查看。
|
||||||
|
|
||||||
### 步骤 7: 安装前端依赖并构建
|
### 步骤 7: 安装前端依赖并构建
|
||||||
|
|
||||||
@@ -438,7 +461,8 @@ pip install -e . --no-build-isolation
|
|||||||
- 前端 `predictMask()` 已发送后端需要的 `image_id`、`prompt_type`、`prompt_data`,并把后端 `polygons` 转成 Konva `pathData`。
|
- 前端 `predictMask()` 已发送后端需要的 `image_id`、`prompt_type`、`prompt_data`,并把后端 `polygons` 转成 Konva `pathData`。
|
||||||
- 工作区点选/框选会使用当前帧的数据库 `frame.id` 调用 `/api/ai/predict`。
|
- 工作区点选/框选会使用当前帧的数据库 `frame.id` 调用 `/api/ai/predict`。
|
||||||
- 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco`。
|
- 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco`。
|
||||||
- 工作区“导出 JSON 标注集”按钮已绑定下载流程;导出前会先保存当前待归档的前端 mask。
|
- 工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮已绑定下载流程;导出前会先保存当前待归档的前端 mask。
|
||||||
|
- 工作区“导入 GT Mask”按钮已绑定 `/api/ai/import-gt-mask`,导入后会刷新并回显已保存标注和 seed point。
|
||||||
- 工作区“结构化归档保存”按钮会把当前项目未保存 mask 写入 `POST /api/ai/annotate`,并把 dirty mask 写入 `PATCH /api/ai/annotations/{id}`。
|
- 工作区“结构化归档保存”按钮会把当前项目未保存 mask 写入 `POST /api/ai/annotate`,并把 dirty mask 写入 `PATCH /api/ai/annotations/{id}`。
|
||||||
- 工作区“清空遮罩”会通过 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。
|
- 工作区“清空遮罩”会通过 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。
|
||||||
|
|
||||||
@@ -447,6 +471,7 @@ pip install -e . --no-build-isolation
|
|||||||
```bash
|
```bash
|
||||||
curl http://localhost:8000/health
|
curl http://localhost:8000/health
|
||||||
curl http://localhost:8000/api/export/1/coco
|
curl http://localhost:8000/api/export/1/coco
|
||||||
|
curl http://localhost:8000/api/export/1/masks
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|||||||
@@ -22,7 +22,12 @@ class Settings(BaseSettings):
|
|||||||
sam_default_model: str = "sam2"
|
sam_default_model: str = "sam2"
|
||||||
sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
|
sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
|
||||||
sam_model_config: str = "configs/sam2/sam2_hiera_t.yaml"
|
sam_model_config: str = "configs/sam2/sam2_hiera_t.yaml"
|
||||||
sam3_model_version: str = "sam3.1"
|
sam3_model_version: str = "sam3"
|
||||||
|
sam3_external_enabled: bool = True
|
||||||
|
sam3_external_python: str = "/home/wkmgc/miniconda3/envs/sam3/bin/python"
|
||||||
|
sam3_timeout_seconds: int = 300
|
||||||
|
sam3_status_cache_seconds: int = 30
|
||||||
|
sam3_confidence_threshold: float = 0.5
|
||||||
|
|
||||||
# App
|
# App
|
||||||
app_env: str = "development"
|
app_env: str = "development"
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from datetime import datetime, timezone
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from redis_client import get_redis_client
|
from redis_client import get_redis_client
|
||||||
from statuses import TASK_STATUS_FAILED, TASK_STATUS_SUCCESS
|
from statuses import TASK_STATUS_CANCELLED, TASK_STATUS_FAILED, TASK_STATUS_SUCCESS
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -22,6 +22,8 @@ def _iso_now() -> str:
|
|||||||
def _event_type(task_status: str) -> str:
|
def _event_type(task_status: str) -> str:
|
||||||
if task_status == TASK_STATUS_SUCCESS:
|
if task_status == TASK_STATUS_SUCCESS:
|
||||||
return "complete"
|
return "complete"
|
||||||
|
if task_status == TASK_STATUS_CANCELLED:
|
||||||
|
return "cancelled"
|
||||||
if task_status == TASK_STATUS_FAILED:
|
if task_status == TASK_STATUS_FAILED:
|
||||||
return "error"
|
return "error"
|
||||||
return "progress"
|
return "progress"
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import Any, List
|
|||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
from fastapi import APIRouter, Depends, File, Form, HTTPException, Response, UploadFile, status
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from database import get_db
|
from database import get_db
|
||||||
@@ -39,6 +39,140 @@ def _load_frame_image(frame: Frame) -> np.ndarray:
|
|||||||
raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
|
raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
|
||||||
|
|
||||||
|
|
||||||
|
def _normalized_contour(contour: np.ndarray, width: int, height: int) -> list[list[float]]:
|
||||||
|
"""Approximate a contour and convert it to normalized polygon coordinates."""
|
||||||
|
arc_length = cv2.arcLength(contour, True)
|
||||||
|
epsilon = max(1.0, arc_length * 0.01)
|
||||||
|
approx = cv2.approxPolyDP(contour, epsilon, True)
|
||||||
|
points = approx.reshape(-1, 2)
|
||||||
|
if len(points) < 3:
|
||||||
|
points = contour.reshape(-1, 2)
|
||||||
|
return [
|
||||||
|
[
|
||||||
|
min(max(float(x) / max(width, 1), 0.0), 1.0),
|
||||||
|
min(max(float(y) / max(height, 1), 0.0), 1.0),
|
||||||
|
]
|
||||||
|
for x, y in points
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _contour_bbox(contour: np.ndarray, width: int, height: int) -> list[float]:
|
||||||
|
x, y, w, h = cv2.boundingRect(contour)
|
||||||
|
return [
|
||||||
|
min(max(float(x) / max(width, 1), 0.0), 1.0),
|
||||||
|
min(max(float(y) / max(height, 1), 0.0), 1.0),
|
||||||
|
min(max(float(w) / max(width, 1), 0.0), 1.0),
|
||||||
|
min(max(float(h) / max(height, 1), 0.0), 1.0),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _component_seed_point(component_mask: np.ndarray, width: int, height: int) -> list[float]:
|
||||||
|
"""Reduce a binary component to one positive prompt point using distance transform."""
|
||||||
|
dist = cv2.distanceTransform(component_mask.astype(np.uint8), cv2.DIST_L2, 5)
|
||||||
|
_, _, _, max_loc = cv2.minMaxLoc(dist)
|
||||||
|
x, y = max_loc
|
||||||
|
return [
|
||||||
|
min(max(float(x) / max(width, 1), 0.0), 1.0),
|
||||||
|
min(max(float(y) / max(height, 1), 0.0), 1.0),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _clamp01(value: float) -> float:
|
||||||
|
return min(max(float(value), 0.0), 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def _point_in_polygon(point: list[float], polygon: list[list[float]]) -> bool:
|
||||||
|
"""Return whether a normalized point is inside a normalized polygon."""
|
||||||
|
if len(polygon) < 3:
|
||||||
|
return False
|
||||||
|
x, y = point
|
||||||
|
inside = False
|
||||||
|
j = len(polygon) - 1
|
||||||
|
for i, current in enumerate(polygon):
|
||||||
|
xi, yi = current
|
||||||
|
xj, yj = polygon[j]
|
||||||
|
intersects = ((yi > y) != (yj > y)) and (
|
||||||
|
x < (xj - xi) * (y - yi) / ((yj - yi) or 1e-9) + xi
|
||||||
|
)
|
||||||
|
if intersects:
|
||||||
|
inside = not inside
|
||||||
|
j = i
|
||||||
|
return inside
|
||||||
|
|
||||||
|
|
||||||
|
def _crop_bounds_from_points(points: list[list[float]], margin: float) -> tuple[float, float, float, float]:
|
||||||
|
xs = [_clamp01(point[0]) for point in points]
|
||||||
|
ys = [_clamp01(point[1]) for point in points]
|
||||||
|
x1 = max(0.0, min(xs) - margin)
|
||||||
|
y1 = max(0.0, min(ys) - margin)
|
||||||
|
x2 = min(1.0, max(xs) + margin)
|
||||||
|
y2 = min(1.0, max(ys) + margin)
|
||||||
|
if x2 - x1 < 0.05:
|
||||||
|
center = (x1 + x2) / 2
|
||||||
|
x1 = max(0.0, center - 0.025)
|
||||||
|
x2 = min(1.0, center + 0.025)
|
||||||
|
if y2 - y1 < 0.05:
|
||||||
|
center = (y1 + y2) / 2
|
||||||
|
y1 = max(0.0, center - 0.025)
|
||||||
|
y2 = min(1.0, center + 0.025)
|
||||||
|
return x1, y1, x2, y2
|
||||||
|
|
||||||
|
|
||||||
|
def _crop_image(image: np.ndarray, bounds: tuple[float, float, float, float]) -> np.ndarray:
|
||||||
|
height, width = image.shape[:2]
|
||||||
|
x1, y1, x2, y2 = bounds
|
||||||
|
left = int(round(x1 * width))
|
||||||
|
top = int(round(y1 * height))
|
||||||
|
right = max(left + 1, int(round(x2 * width)))
|
||||||
|
bottom = max(top + 1, int(round(y2 * height)))
|
||||||
|
return image[top:bottom, left:right]
|
||||||
|
|
||||||
|
|
||||||
|
def _to_crop_point(point: list[float], bounds: tuple[float, float, float, float]) -> list[float]:
|
||||||
|
x1, y1, x2, y2 = bounds
|
||||||
|
return [
|
||||||
|
_clamp01((float(point[0]) - x1) / max(x2 - x1, 1e-9)),
|
||||||
|
_clamp01((float(point[1]) - y1) / max(y2 - y1, 1e-9)),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _from_crop_polygon(
|
||||||
|
polygon: list[list[float]],
|
||||||
|
bounds: tuple[float, float, float, float],
|
||||||
|
) -> list[list[float]]:
|
||||||
|
x1, y1, x2, y2 = bounds
|
||||||
|
return [
|
||||||
|
[
|
||||||
|
_clamp01(x1 + float(point[0]) * (x2 - x1)),
|
||||||
|
_clamp01(y1 + float(point[1]) * (y2 - y1)),
|
||||||
|
]
|
||||||
|
for point in polygon
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_predictions(
|
||||||
|
polygons: list[list[list[float]]],
|
||||||
|
scores: list[float],
|
||||||
|
options: dict[str, Any],
|
||||||
|
negative_points: list[list[float]] | None = None,
|
||||||
|
) -> tuple[list[list[list[float]]], list[float]]:
|
||||||
|
if not options.get("auto_filter_background"):
|
||||||
|
return polygons, scores
|
||||||
|
|
||||||
|
min_score = float(options.get("min_score", 0.0) or 0.0)
|
||||||
|
next_polygons: list[list[list[float]]] = []
|
||||||
|
next_scores: list[float] = []
|
||||||
|
for index, polygon in enumerate(polygons):
|
||||||
|
score = scores[index] if index < len(scores) else 0.0
|
||||||
|
if score < min_score:
|
||||||
|
continue
|
||||||
|
if negative_points and any(_point_in_polygon(point, polygon) for point in negative_points):
|
||||||
|
continue
|
||||||
|
next_polygons.append(polygon)
|
||||||
|
next_scores.append(score)
|
||||||
|
return next_polygons, next_scores
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/predict",
|
"/predict",
|
||||||
response_model=PredictResponse,
|
response_model=PredictResponse,
|
||||||
@@ -58,9 +192,11 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
|||||||
|
|
||||||
image = _load_frame_image(frame)
|
image = _load_frame_image(frame)
|
||||||
prompt_type = payload.prompt_type.lower()
|
prompt_type = payload.prompt_type.lower()
|
||||||
|
options = payload.options or {}
|
||||||
|
|
||||||
polygons: List[List[List[float]]] = []
|
polygons: List[List[List[float]]] = []
|
||||||
scores: List[float] = []
|
scores: List[float] = []
|
||||||
|
negative_points: list[list[float]] = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if prompt_type == "point":
|
if prompt_type == "point":
|
||||||
@@ -76,13 +212,39 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
|||||||
raise HTTPException(status_code=400, detail="Invalid point prompt data")
|
raise HTTPException(status_code=400, detail="Invalid point prompt data")
|
||||||
if not isinstance(labels, list) or len(labels) != len(points):
|
if not isinstance(labels, list) or len(labels) != len(points):
|
||||||
labels = [1] * len(points)
|
labels = [1] * len(points)
|
||||||
polygons, scores = sam_registry.predict_points(payload.model, image, points, labels)
|
negative_points = [
|
||||||
|
point for point, label in zip(points, labels) if label == 0
|
||||||
|
]
|
||||||
|
inference_image = image
|
||||||
|
inference_points = points
|
||||||
|
crop_bounds = None
|
||||||
|
if options.get("crop_to_prompt"):
|
||||||
|
margin = float(options.get("crop_margin", 0.25) or 0.25)
|
||||||
|
crop_bounds = _crop_bounds_from_points(points, margin)
|
||||||
|
inference_image = _crop_image(image, crop_bounds)
|
||||||
|
inference_points = [_to_crop_point(point, crop_bounds) for point in points]
|
||||||
|
polygons, scores = sam_registry.predict_points(payload.model, inference_image, inference_points, labels)
|
||||||
|
if crop_bounds:
|
||||||
|
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
|
||||||
|
|
||||||
elif prompt_type == "box":
|
elif prompt_type == "box":
|
||||||
box = payload.prompt_data
|
box = payload.prompt_data
|
||||||
if not isinstance(box, list) or len(box) != 4:
|
if not isinstance(box, list) or len(box) != 4:
|
||||||
raise HTTPException(status_code=400, detail="Invalid box prompt data")
|
raise HTTPException(status_code=400, detail="Invalid box prompt data")
|
||||||
polygons, scores = sam_registry.predict_box(payload.model, image, box)
|
inference_image = image
|
||||||
|
inference_box = box
|
||||||
|
crop_bounds = None
|
||||||
|
if options.get("crop_to_prompt"):
|
||||||
|
margin = float(options.get("crop_margin", 0.05) or 0.05)
|
||||||
|
crop_bounds = _crop_bounds_from_points([[box[0], box[1]], [box[2], box[3]]], margin)
|
||||||
|
inference_image = _crop_image(image, crop_bounds)
|
||||||
|
inference_box = [
|
||||||
|
*_to_crop_point([box[0], box[1]], crop_bounds),
|
||||||
|
*_to_crop_point([box[2], box[3]], crop_bounds),
|
||||||
|
]
|
||||||
|
polygons, scores = sam_registry.predict_box(payload.model, inference_image, inference_box)
|
||||||
|
if crop_bounds:
|
||||||
|
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
|
||||||
|
|
||||||
elif prompt_type == "semantic":
|
elif prompt_type == "semantic":
|
||||||
text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
|
text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
|
||||||
@@ -97,6 +259,7 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
|||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||||
|
|
||||||
|
polygons, scores = _filter_predictions(polygons, scores, options, negative_points)
|
||||||
return {"polygons": polygons, "scores": scores}
|
return {"polygons": polygons, "scores": scores}
|
||||||
|
|
||||||
|
|
||||||
@@ -161,6 +324,100 @@ def save_annotation(
|
|||||||
return annotation
|
return annotation
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/import-gt-mask",
|
||||||
|
response_model=List[AnnotationOut],
|
||||||
|
status_code=status.HTTP_201_CREATED,
|
||||||
|
summary="Import a GT mask and reduce components to editable point regions",
|
||||||
|
)
|
||||||
|
async def import_gt_mask(
|
||||||
|
project_id: int = Form(...),
|
||||||
|
frame_id: int = Form(...),
|
||||||
|
template_id: int | None = Form(None),
|
||||||
|
label: str = Form("GT Mask"),
|
||||||
|
color: str = Form("#22c55e"),
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
) -> List[Annotation]:
|
||||||
|
"""Convert a binary/label mask image into persisted polygon annotations.
|
||||||
|
|
||||||
|
Each connected component becomes one annotation. The `points` field stores a
|
||||||
|
positive seed point at the component's distance-transform center, which gives
|
||||||
|
the frontend an editable point-region representation instead of a static
|
||||||
|
bitmap layer.
|
||||||
|
"""
|
||||||
|
project = db.query(Project).filter(Project.id == project_id).first()
|
||||||
|
if not project:
|
||||||
|
raise HTTPException(status_code=404, detail="Project not found")
|
||||||
|
|
||||||
|
frame = db.query(Frame).filter(Frame.id == frame_id, Frame.project_id == project_id).first()
|
||||||
|
if not frame:
|
||||||
|
raise HTTPException(status_code=404, detail="Frame not found")
|
||||||
|
|
||||||
|
if template_id is not None:
|
||||||
|
template = db.query(Template).filter(Template.id == template_id).first()
|
||||||
|
if not template:
|
||||||
|
raise HTTPException(status_code=404, detail="Template not found")
|
||||||
|
|
||||||
|
data = await file.read()
|
||||||
|
image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
|
||||||
|
if image is None:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid mask image")
|
||||||
|
|
||||||
|
width = int(frame.width or image.shape[1])
|
||||||
|
height = int(frame.height or image.shape[0])
|
||||||
|
label_values = [int(value) for value in np.unique(image) if int(value) > 0]
|
||||||
|
if not label_values:
|
||||||
|
raise HTTPException(status_code=400, detail="No foreground mask regions found")
|
||||||
|
has_multiple_labels = len(label_values) > 1
|
||||||
|
|
||||||
|
annotations: list[Annotation] = []
|
||||||
|
for label_value in label_values:
|
||||||
|
binary = np.where(image == label_value, 255, 0).astype(np.uint8)
|
||||||
|
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
|
annotation_label = f"{label} {label_value}" if has_multiple_labels else label
|
||||||
|
|
||||||
|
for contour in contours:
|
||||||
|
if cv2.contourArea(contour) < 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
polygon = _normalized_contour(contour, image.shape[1], image.shape[0])
|
||||||
|
if len(polygon) < 3:
|
||||||
|
continue
|
||||||
|
|
||||||
|
component = np.zeros_like(binary, dtype=np.uint8)
|
||||||
|
cv2.drawContours(component, [contour], -1, 1, thickness=-1)
|
||||||
|
seed_point = _component_seed_point(component, image.shape[1], image.shape[0])
|
||||||
|
bbox = _contour_bbox(contour, image.shape[1], image.shape[0])
|
||||||
|
|
||||||
|
annotation = Annotation(
|
||||||
|
project_id=project_id,
|
||||||
|
frame_id=frame_id,
|
||||||
|
template_id=template_id,
|
||||||
|
mask_data={
|
||||||
|
"polygons": [polygon],
|
||||||
|
"label": annotation_label,
|
||||||
|
"color": color,
|
||||||
|
"source": "gt_mask",
|
||||||
|
"gt_label_value": label_value,
|
||||||
|
"image_size": {"width": width, "height": height},
|
||||||
|
},
|
||||||
|
points=[seed_point],
|
||||||
|
bbox=bbox,
|
||||||
|
)
|
||||||
|
db.add(annotation)
|
||||||
|
annotations.append(annotation)
|
||||||
|
|
||||||
|
if not annotations:
|
||||||
|
raise HTTPException(status_code=400, detail="No foreground mask regions found")
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
for annotation in annotations:
|
||||||
|
db.refresh(annotation)
|
||||||
|
logger.info("Imported %s GT mask annotations for project_id=%s frame_id=%s", len(annotations), project_id, frame_id)
|
||||||
|
return annotations
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/annotations",
|
"/annotations",
|
||||||
response_model=List[AnnotationOut],
|
response_model=List[AnnotationOut],
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from models import Annotation, Frame, ProcessingTask, Project, Template
|
|||||||
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
|
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
|
||||||
|
|
||||||
ACTIVE_TASK_STATUSES = {"queued", "running"}
|
ACTIVE_TASK_STATUSES = {"queued", "running"}
|
||||||
|
MONITORED_TASK_STATUSES = {"queued", "running", "failed", "cancelled"}
|
||||||
|
|
||||||
|
|
||||||
def _system_load_percent() -> int:
|
def _system_load_percent() -> int:
|
||||||
@@ -42,7 +43,9 @@ def _task_payload(task: ProcessingTask) -> dict[str, Any]:
|
|||||||
"name": task.project.name if task.project else f"任务 {task.id}",
|
"name": task.project.name if task.project else f"任务 {task.id}",
|
||||||
"progress": task.progress,
|
"progress": task.progress,
|
||||||
"status": task.message or task.status,
|
"status": task.message or task.status,
|
||||||
|
"raw_status": task.status,
|
||||||
"frame_count": (task.result or {}).get("frames_extracted", 0),
|
"frame_count": (task.result or {}).get("frames_extracted", 0),
|
||||||
|
"error": task.error,
|
||||||
"updated_at": _iso_or_none(task.updated_at),
|
"updated_at": _iso_or_none(task.updated_at),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,7 +71,7 @@ def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
|
|||||||
.limit(50)
|
.limit(50)
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
tasks = [_task_payload(task) for task in recent_tasks if task.status in ACTIVE_TASK_STATUSES]
|
tasks = [_task_payload(task) for task in recent_tasks if task.status in MONITORED_TASK_STATUSES]
|
||||||
|
|
||||||
activities: list[dict[str, Any]] = []
|
activities: list[dict[str, Any]] = []
|
||||||
for task in recent_tasks[:10]:
|
for task in recent_tasks[:10]:
|
||||||
|
|||||||
@@ -37,6 +37,54 @@ def _mask_from_polygon(
|
|||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def _annotation_z_index(annotation: Annotation) -> int:
|
||||||
|
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||||
|
if isinstance(class_meta, dict) and class_meta.get("zIndex") is not None:
|
||||||
|
try:
|
||||||
|
return int(class_meta["zIndex"])
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
if annotation.template and annotation.template.z_index is not None:
|
||||||
|
return int(annotation.template.z_index)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _annotation_class_key(annotation: Annotation) -> str:
|
||||||
|
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||||
|
if isinstance(class_meta, dict):
|
||||||
|
if class_meta.get("id"):
|
||||||
|
return f"class:{class_meta['id']}"
|
||||||
|
if class_meta.get("name"):
|
||||||
|
return f"name:{class_meta['name']}"
|
||||||
|
if annotation.template_id:
|
||||||
|
return f"template:{annotation.template_id}"
|
||||||
|
return f"annotation:{annotation.id}"
|
||||||
|
|
||||||
|
|
||||||
|
def _annotation_label(annotation: Annotation) -> str:
|
||||||
|
mask_data = annotation.mask_data or {}
|
||||||
|
class_meta = mask_data.get("class") or {}
|
||||||
|
if isinstance(class_meta, dict) and class_meta.get("name"):
|
||||||
|
return str(class_meta["name"])
|
||||||
|
if mask_data.get("label"):
|
||||||
|
return str(mask_data["label"])
|
||||||
|
if annotation.template and annotation.template.name:
|
||||||
|
return str(annotation.template.name)
|
||||||
|
return f"Annotation {annotation.id}"
|
||||||
|
|
||||||
|
|
||||||
|
def _annotation_color(annotation: Annotation) -> str:
|
||||||
|
mask_data = annotation.mask_data or {}
|
||||||
|
class_meta = mask_data.get("class") or {}
|
||||||
|
if isinstance(class_meta, dict) and class_meta.get("color"):
|
||||||
|
return str(class_meta["color"])
|
||||||
|
if mask_data.get("color"):
|
||||||
|
return str(mask_data["color"])
|
||||||
|
if annotation.template and annotation.template.color:
|
||||||
|
return str(annotation.template.color)
|
||||||
|
return "#ffffff"
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/{project_id}/coco",
|
"/{project_id}/coco",
|
||||||
summary="Export annotations in COCO format",
|
summary="Export annotations in COCO format",
|
||||||
@@ -150,19 +198,46 @@ def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResp
|
|||||||
summary="Export PNG masks as a ZIP archive",
|
summary="Export PNG masks as a ZIP archive",
|
||||||
)
|
)
|
||||||
def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
|
def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
|
||||||
"""Export all annotation masks as individual PNG files inside a ZIP archive."""
|
"""Export individual masks plus z-index fused semantic masks inside a ZIP."""
|
||||||
project = db.query(Project).filter(Project.id == project_id).first()
|
project = db.query(Project).filter(Project.id == project_id).first()
|
||||||
if not project:
|
if not project:
|
||||||
raise HTTPException(status_code=404, detail="Project not found")
|
raise HTTPException(status_code=404, detail="Project not found")
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
|
||||||
annotations = (
|
annotations = (
|
||||||
db.query(Annotation)
|
db.query(Annotation)
|
||||||
.filter(Annotation.project_id == project_id)
|
.filter(Annotation.project_id == project_id)
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
frames = (
|
||||||
|
db.query(Frame)
|
||||||
|
.filter(Frame.project_id == project_id)
|
||||||
|
.order_by(Frame.frame_index)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
class_values: dict[str, int] = {}
|
||||||
|
semantic_classes: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
def class_value(annotation: Annotation) -> int:
|
||||||
|
key = _annotation_class_key(annotation)
|
||||||
|
if key not in class_values:
|
||||||
|
value = len(class_values) + 1
|
||||||
|
class_values[key] = value
|
||||||
|
semantic_classes.append({
|
||||||
|
"value": value,
|
||||||
|
"key": key,
|
||||||
|
"label": _annotation_label(annotation),
|
||||||
|
"color": _annotation_color(annotation),
|
||||||
|
"zIndex": _annotation_z_index(annotation),
|
||||||
|
"template_id": annotation.template_id,
|
||||||
|
})
|
||||||
|
return class_values[key]
|
||||||
|
|
||||||
zip_buffer = io.BytesIO()
|
zip_buffer = io.BytesIO()
|
||||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||||
|
frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {}
|
||||||
for ann in annotations:
|
for ann in annotations:
|
||||||
if not ann.mask_data:
|
if not ann.mask_data:
|
||||||
continue
|
continue
|
||||||
@@ -178,11 +253,28 @@ def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingRes
|
|||||||
mask = _mask_from_polygon(poly, width, height)
|
mask = _mask_from_polygon(poly, width, height)
|
||||||
combined = np.maximum(combined, mask)
|
combined = np.maximum(combined, mask)
|
||||||
|
|
||||||
# Encode PNG
|
|
||||||
import cv2
|
|
||||||
_, encoded = cv2.imencode(".png", combined)
|
_, encoded = cv2.imencode(".png", combined)
|
||||||
fname = f"mask_{ann.id:06d}.png"
|
fname = f"mask_{ann.id:06d}.png"
|
||||||
zf.writestr(fname, encoded.tobytes())
|
zf.writestr(fname, encoded.tobytes())
|
||||||
|
if ann.frame_id is not None:
|
||||||
|
frame_masks.setdefault(ann.frame_id, []).append((ann, combined))
|
||||||
|
|
||||||
|
for frame in frames:
|
||||||
|
entries = frame_masks.get(frame.id, [])
|
||||||
|
if not entries:
|
||||||
|
continue
|
||||||
|
width = frame.width or 1920
|
||||||
|
height = frame.height or 1080
|
||||||
|
semantic = np.zeros((height, width), dtype=np.uint8)
|
||||||
|
for ann, mask in sorted(entries, key=lambda item: _annotation_z_index(item[0])):
|
||||||
|
semantic[mask > 0] = class_value(ann)
|
||||||
|
_, encoded = cv2.imencode(".png", semantic)
|
||||||
|
zf.writestr(f"semantic_frame_{frame.frame_index:06d}.png", encoded.tobytes())
|
||||||
|
|
||||||
|
zf.writestr(
|
||||||
|
"semantic_classes.json",
|
||||||
|
json.dumps({"classes": semantic_classes}, ensure_ascii=False, indent=2).encode("utf-8"),
|
||||||
|
)
|
||||||
|
|
||||||
zip_buffer.seek(0)
|
zip_buffer.seek(0)
|
||||||
filename = f"project_{project_id}_masks.zip"
|
filename = f"project_{project_id}_masks.zip"
|
||||||
|
|||||||
@@ -1,15 +1,45 @@
|
|||||||
"""Processing task query endpoints."""
|
"""Processing task query endpoints."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from celery_app import celery_app
|
||||||
from database import get_db
|
from database import get_db
|
||||||
from models import ProcessingTask
|
from models import ProcessingTask, Project
|
||||||
|
from progress_events import publish_task_progress_event
|
||||||
from schemas import ProcessingTaskOut
|
from schemas import ProcessingTaskOut
|
||||||
|
from statuses import (
|
||||||
|
PROJECT_STATUS_PARSING,
|
||||||
|
PROJECT_STATUS_PENDING,
|
||||||
|
PROJECT_STATUS_READY,
|
||||||
|
TASK_ACTIVE_STATUSES,
|
||||||
|
TASK_STATUS_CANCELLED,
|
||||||
|
TASK_STATUS_FAILED,
|
||||||
|
TASK_STATUS_QUEUED,
|
||||||
|
)
|
||||||
|
from worker_tasks import parse_project_media
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
|
router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _now() -> datetime:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_task_or_404(task_id: int, db: Session) -> ProcessingTask:
|
||||||
|
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
|
||||||
|
if not task:
|
||||||
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
return task
|
||||||
|
|
||||||
|
|
||||||
|
def _project_status_after_stop(project: Project) -> str:
|
||||||
|
return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=List[ProcessingTaskOut], summary="List processing tasks")
|
@router.get("", response_model=List[ProcessingTaskOut], summary="List processing tasks")
|
||||||
@@ -31,7 +61,78 @@ def list_tasks(
|
|||||||
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
|
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
|
||||||
def get_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
def get_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||||
"""Return one background task by id."""
|
"""Return one background task by id."""
|
||||||
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
|
return _get_task_or_404(task_id, db)
|
||||||
if not task:
|
|
||||||
raise HTTPException(status_code=404, detail="Task not found")
|
|
||||||
|
@router.post("/{task_id}/cancel", response_model=ProcessingTaskOut, summary="Cancel processing task")
|
||||||
|
def cancel_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||||
|
"""Cancel a queued/running background task and revoke the Celery job when possible."""
|
||||||
|
task = _get_task_or_404(task_id, db)
|
||||||
|
if task.status not in TASK_ACTIVE_STATUSES:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail=f"Task is not cancellable in status: {task.status}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if task.celery_task_id:
|
||||||
|
try:
|
||||||
|
celery_app.control.revoke(task.celery_task_id, terminate=True, signal="SIGTERM")
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.warning("Failed to revoke celery task %s: %s", task.celery_task_id, exc)
|
||||||
|
|
||||||
|
task.status = TASK_STATUS_CANCELLED
|
||||||
|
task.progress = 100
|
||||||
|
task.message = "任务已取消"
|
||||||
|
task.error = "Cancelled by user"
|
||||||
|
task.finished_at = _now()
|
||||||
|
if task.project:
|
||||||
|
task.project.status = _project_status_after_stop(task.project)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.refresh(task)
|
||||||
|
publish_task_progress_event(task)
|
||||||
|
return task
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{task_id}/retry", response_model=ProcessingTaskOut, status_code=status.HTTP_202_ACCEPTED, summary="Retry processing task")
|
||||||
|
def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||||
|
"""Create a fresh queued task from a failed or cancelled task."""
|
||||||
|
previous = _get_task_or_404(task_id, db)
|
||||||
|
if previous.status not in {TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail=f"Task is not retryable in status: {previous.status}",
|
||||||
|
)
|
||||||
|
if previous.project_id is None:
|
||||||
|
raise HTTPException(status_code=400, detail="Task has no project_id")
|
||||||
|
|
||||||
|
project = db.query(Project).filter(Project.id == previous.project_id).first()
|
||||||
|
if not project:
|
||||||
|
raise HTTPException(status_code=404, detail="Project not found")
|
||||||
|
if not project.video_path:
|
||||||
|
raise HTTPException(status_code=400, detail="Project has no media uploaded")
|
||||||
|
|
||||||
|
payload = dict(previous.payload or {})
|
||||||
|
payload.setdefault("source_type", project.source_type or "video")
|
||||||
|
payload["retry_of"] = previous.id
|
||||||
|
|
||||||
|
task = ProcessingTask(
|
||||||
|
task_type=previous.task_type,
|
||||||
|
status=TASK_STATUS_QUEUED,
|
||||||
|
progress=0,
|
||||||
|
message=f"重试任务已入队(源任务 #{previous.id})",
|
||||||
|
project_id=project.id,
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
project.status = PROJECT_STATUS_PARSING
|
||||||
|
db.add(task)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(task)
|
||||||
|
publish_task_progress_event(task)
|
||||||
|
|
||||||
|
async_result = parse_project_media.delay(task.id)
|
||||||
|
task.celery_task_id = async_result.id
|
||||||
|
db.commit()
|
||||||
|
db.refresh(task)
|
||||||
|
publish_task_progress_event(task)
|
||||||
return task
|
return task
|
||||||
|
|||||||
@@ -180,6 +180,7 @@ class PredictRequest(BaseModel):
|
|||||||
prompt_type: str # point / box / semantic
|
prompt_type: str # point / box / semantic
|
||||||
prompt_data: Any
|
prompt_data: Any
|
||||||
model: Optional[str] = None
|
model: Optional[str] = None
|
||||||
|
options: Optional[dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
class PredictResponse(BaseModel):
|
class PredictResponse(BaseModel):
|
||||||
@@ -201,6 +202,8 @@ class AiModelStatus(BaseModel):
|
|||||||
python_ok: bool = True
|
python_ok: bool = True
|
||||||
torch_ok: bool = True
|
torch_ok: bool = True
|
||||||
cuda_required: bool = False
|
cuda_required: bool = False
|
||||||
|
external_available: bool = False
|
||||||
|
external_python: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class GpuStatus(BaseModel):
|
class GpuStatus(BaseModel):
|
||||||
|
|||||||
@@ -20,9 +20,11 @@ from services.frame_parser import (
|
|||||||
upload_frames_to_minio,
|
upload_frames_to_minio,
|
||||||
)
|
)
|
||||||
from statuses import (
|
from statuses import (
|
||||||
|
PROJECT_STATUS_PENDING,
|
||||||
PROJECT_STATUS_ERROR,
|
PROJECT_STATUS_ERROR,
|
||||||
PROJECT_STATUS_PARSING,
|
PROJECT_STATUS_PARSING,
|
||||||
PROJECT_STATUS_READY,
|
PROJECT_STATUS_READY,
|
||||||
|
TASK_STATUS_CANCELLED,
|
||||||
TASK_STATUS_FAILED,
|
TASK_STATUS_FAILED,
|
||||||
TASK_STATUS_RUNNING,
|
TASK_STATUS_RUNNING,
|
||||||
TASK_STATUS_SUCCESS,
|
TASK_STATUS_SUCCESS,
|
||||||
@@ -31,6 +33,10 @@ from statuses import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskCancelled(RuntimeError):
|
||||||
|
"""Raised internally when a persisted task has been cancelled."""
|
||||||
|
|
||||||
|
|
||||||
def _now() -> datetime:
|
def _now() -> datetime:
|
||||||
return datetime.now(timezone.utc)
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
@@ -66,12 +72,29 @@ def _set_task_state(
|
|||||||
publish_task_progress_event(task)
|
publish_task_progress_event(task)
|
||||||
|
|
||||||
|
|
||||||
|
def _project_status_after_stop(project: Project) -> str:
|
||||||
|
return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
|
||||||
|
db.refresh(task)
|
||||||
|
if task.status == TASK_STATUS_CANCELLED:
|
||||||
|
raise TaskCancelled("Task was cancelled")
|
||||||
|
|
||||||
|
|
||||||
def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||||
"""Parse one project's media and update task progress in the database."""
|
"""Parse one project's media and update task progress in the database."""
|
||||||
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
|
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
|
||||||
if not task:
|
if not task:
|
||||||
raise ValueError(f"Task not found: {task_id}")
|
raise ValueError(f"Task not found: {task_id}")
|
||||||
|
|
||||||
|
if task.status == TASK_STATUS_CANCELLED:
|
||||||
|
return {
|
||||||
|
"task_id": task.id,
|
||||||
|
"status": TASK_STATUS_CANCELLED,
|
||||||
|
"message": task.message or "任务已取消",
|
||||||
|
}
|
||||||
|
|
||||||
if task.project_id is None:
|
if task.project_id is None:
|
||||||
_set_task_state(
|
_set_task_state(
|
||||||
db,
|
db,
|
||||||
@@ -111,6 +134,7 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
|||||||
db.commit()
|
db.commit()
|
||||||
raise ValueError("Project has no media uploaded")
|
raise ValueError("Project has no media uploaded")
|
||||||
|
|
||||||
|
_ensure_not_cancelled(db, task)
|
||||||
project.status = PROJECT_STATUS_PARSING
|
project.status = PROJECT_STATUS_PARSING
|
||||||
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True)
|
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True)
|
||||||
|
|
||||||
@@ -121,6 +145,7 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
|||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
_ensure_not_cancelled(db, task)
|
||||||
_set_task_state(db, task, progress=15, message="正在下载媒体文件")
|
_set_task_state(db, task, progress=15, message="正在下载媒体文件")
|
||||||
if effective_source == "dicom":
|
if effective_source == "dicom":
|
||||||
dcm_dir = os.path.join(tmp_dir, "dcm")
|
dcm_dir = os.path.join(tmp_dir, "dcm")
|
||||||
@@ -129,20 +154,24 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
|||||||
client = get_minio_client()
|
client = get_minio_client()
|
||||||
objects = list(client.list_objects(BUCKET_NAME, prefix=project.video_path, recursive=True))
|
objects = list(client.list_objects(BUCKET_NAME, prefix=project.video_path, recursive=True))
|
||||||
for obj in objects:
|
for obj in objects:
|
||||||
|
_ensure_not_cancelled(db, task)
|
||||||
if obj.object_name.lower().endswith(".dcm"):
|
if obj.object_name.lower().endswith(".dcm"):
|
||||||
data = download_file(obj.object_name)
|
data = download_file(obj.object_name)
|
||||||
local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name))
|
local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name))
|
||||||
with open(local_dcm, "wb") as f:
|
with open(local_dcm, "wb") as f:
|
||||||
f.write(data)
|
f.write(data)
|
||||||
|
|
||||||
|
_ensure_not_cancelled(db, task)
|
||||||
_set_task_state(db, task, progress=35, message="正在解析 DICOM 序列")
|
_set_task_state(db, task, progress=35, message="正在解析 DICOM 序列")
|
||||||
frame_files = parse_dicom(dcm_dir, output_dir)
|
frame_files = parse_dicom(dcm_dir, output_dir)
|
||||||
else:
|
else:
|
||||||
|
_ensure_not_cancelled(db, task)
|
||||||
media_bytes = download_file(project.video_path)
|
media_bytes = download_file(project.video_path)
|
||||||
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
|
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
|
||||||
with open(local_path, "wb") as f:
|
with open(local_path, "wb") as f:
|
||||||
f.write(media_bytes)
|
f.write(media_bytes)
|
||||||
|
|
||||||
|
_ensure_not_cancelled(db, task)
|
||||||
_set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧")
|
_set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧")
|
||||||
frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
|
frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
|
||||||
project.original_fps = original_fps
|
project.original_fps = original_fps
|
||||||
@@ -158,12 +187,15 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
|||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
logger.warning("Thumbnail extraction failed: %s", exc)
|
logger.warning("Thumbnail extraction failed: %s", exc)
|
||||||
|
|
||||||
|
_ensure_not_cancelled(db, task)
|
||||||
_set_task_state(db, task, progress=70, message="正在上传帧到对象存储")
|
_set_task_state(db, task, progress=70, message="正在上传帧到对象存储")
|
||||||
object_names = upload_frames_to_minio(frame_files, project.id)
|
object_names = upload_frames_to_minio(frame_files, project.id)
|
||||||
|
|
||||||
|
_ensure_not_cancelled(db, task)
|
||||||
_set_task_state(db, task, progress=85, message="正在写入帧索引")
|
_set_task_state(db, task, progress=85, message="正在写入帧索引")
|
||||||
frames_out = []
|
frames_out = []
|
||||||
for idx, obj_name in enumerate(object_names):
|
for idx, obj_name in enumerate(object_names):
|
||||||
|
_ensure_not_cancelled(db, task)
|
||||||
local_frame = frame_files[idx]
|
local_frame = frame_files[idx]
|
||||||
try:
|
try:
|
||||||
import cv2
|
import cv2
|
||||||
@@ -203,6 +235,23 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
|||||||
)
|
)
|
||||||
logger.info("Parsed %d frames for project_id=%s", len(frames_out), project.id)
|
logger.info("Parsed %d frames for project_id=%s", len(frames_out), project.id)
|
||||||
return result
|
return result
|
||||||
|
except TaskCancelled:
|
||||||
|
project.status = _project_status_after_stop(project)
|
||||||
|
task.status = TASK_STATUS_CANCELLED
|
||||||
|
task.progress = 100
|
||||||
|
task.message = task.message or "任务已取消"
|
||||||
|
task.error = task.error or "Cancelled by user"
|
||||||
|
task.finished_at = task.finished_at or _now()
|
||||||
|
db.commit()
|
||||||
|
db.refresh(task)
|
||||||
|
publish_task_progress_event(task)
|
||||||
|
logger.info("Parse task cancelled: task_id=%s project_id=%s", task.id, project.id)
|
||||||
|
return {
|
||||||
|
"task_id": task.id,
|
||||||
|
"project_id": project.id,
|
||||||
|
"status": TASK_STATUS_CANCELLED,
|
||||||
|
"message": task.message,
|
||||||
|
}
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
project.status = PROJECT_STATUS_ERROR
|
project.status = PROJECT_STATUS_ERROR
|
||||||
_set_task_state(
|
_set_task_state(
|
||||||
|
|||||||
@@ -9,8 +9,14 @@ the package.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -41,6 +47,8 @@ class SAM3Engine:
|
|||||||
self._processor: Any | None = None
|
self._processor: Any | None = None
|
||||||
self._model_loaded = False
|
self._model_loaded = False
|
||||||
self._last_error: str | None = None
|
self._last_error: str | None = None
|
||||||
|
self._external_status_cache: dict[str, Any] | None = None
|
||||||
|
self._external_status_checked_at = 0.0
|
||||||
|
|
||||||
def _python_ok(self) -> bool:
|
def _python_ok(self) -> bool:
|
||||||
return sys.version_info >= (3, 12)
|
return sys.version_info >= (3, 12)
|
||||||
@@ -51,6 +59,81 @@ class SAM3Engine:
|
|||||||
def _can_load(self) -> bool:
|
def _can_load(self) -> bool:
|
||||||
return bool(SAM3_PACKAGE_AVAILABLE and TORCH_AVAILABLE and self._python_ok() and self._gpu_ok())
|
return bool(SAM3_PACKAGE_AVAILABLE and TORCH_AVAILABLE and self._python_ok() and self._gpu_ok())
|
||||||
|
|
||||||
|
def _worker_path(self) -> Path:
|
||||||
|
return Path(__file__).with_name("sam3_external_worker.py")
|
||||||
|
|
||||||
|
def _external_python_exists(self) -> bool:
|
||||||
|
return bool(settings.sam3_external_enabled and os.path.isfile(settings.sam3_external_python))
|
||||||
|
|
||||||
|
def _external_status(self, force: bool = False) -> dict[str, Any]:
|
||||||
|
now = time.monotonic()
|
||||||
|
if (
|
||||||
|
not force
|
||||||
|
and self._external_status_cache is not None
|
||||||
|
and now - self._external_status_checked_at < settings.sam3_status_cache_seconds
|
||||||
|
):
|
||||||
|
return self._external_status_cache
|
||||||
|
|
||||||
|
if not settings.sam3_external_enabled:
|
||||||
|
status = {
|
||||||
|
"available": False,
|
||||||
|
"package_available": False,
|
||||||
|
"python_ok": False,
|
||||||
|
"torch_ok": False,
|
||||||
|
"cuda_available": False,
|
||||||
|
"device": "unavailable",
|
||||||
|
"message": "SAM 3 external runtime is disabled.",
|
||||||
|
}
|
||||||
|
elif not self._external_python_exists():
|
||||||
|
status = {
|
||||||
|
"available": False,
|
||||||
|
"package_available": False,
|
||||||
|
"python_ok": False,
|
||||||
|
"torch_ok": False,
|
||||||
|
"cuda_available": False,
|
||||||
|
"device": "unavailable",
|
||||||
|
"message": f"SAM 3 external Python not found: {settings.sam3_external_python}",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
|
||||||
|
completed = subprocess.run(
|
||||||
|
[settings.sam3_external_python, str(self._worker_path()), "--status"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=min(settings.sam3_timeout_seconds, 30),
|
||||||
|
check=False,
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
if completed.returncode != 0:
|
||||||
|
detail = completed.stderr.strip() or completed.stdout.strip()
|
||||||
|
status = {
|
||||||
|
"available": False,
|
||||||
|
"package_available": False,
|
||||||
|
"python_ok": False,
|
||||||
|
"torch_ok": False,
|
||||||
|
"cuda_available": False,
|
||||||
|
"device": "unavailable",
|
||||||
|
"message": f"SAM 3 external status failed: {detail}",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
status = json.loads(completed.stdout)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
status = {
|
||||||
|
"available": False,
|
||||||
|
"package_available": False,
|
||||||
|
"python_ok": False,
|
||||||
|
"torch_ok": False,
|
||||||
|
"cuda_available": False,
|
||||||
|
"device": "unavailable",
|
||||||
|
"message": f"SAM 3 external status failed: {exc}",
|
||||||
|
}
|
||||||
|
|
||||||
|
self._external_status_cache = status
|
||||||
|
self._external_status_checked_at = now
|
||||||
|
return status
|
||||||
|
|
||||||
def _load_model(self) -> None:
|
def _load_model(self) -> None:
|
||||||
if self._model_loaded:
|
if self._model_loaded:
|
||||||
return
|
return
|
||||||
@@ -92,26 +175,86 @@ class SAM3Engine:
|
|||||||
return "SAM 3 dependencies are present; model will load on first inference."
|
return "SAM 3 dependencies are present; model will load on first inference."
|
||||||
|
|
||||||
def status(self) -> dict:
|
def status(self) -> dict:
|
||||||
available = self._can_load()
|
external_status = self._external_status()
|
||||||
|
available = bool(self._can_load() or external_status.get("available"))
|
||||||
|
external_ready = bool(external_status.get("available"))
|
||||||
|
message = self._last_error or self._status_message()
|
||||||
|
if self._processor is not None:
|
||||||
|
message = "SAM 3 model loaded and ready."
|
||||||
|
elif external_ready:
|
||||||
|
message = "SAM 3 external runtime is ready; model will load in the helper process on inference."
|
||||||
|
elif external_status.get("message") and not self._can_load():
|
||||||
|
message = str(external_status["message"])
|
||||||
return {
|
return {
|
||||||
"id": "sam3",
|
"id": "sam3",
|
||||||
"label": "SAM 3",
|
"label": "SAM 3",
|
||||||
"available": available,
|
"available": available,
|
||||||
"loaded": self._processor is not None,
|
"loaded": self._processor is not None,
|
||||||
"device": "cuda" if self._gpu_ok() else "unavailable",
|
"device": "cuda" if self._gpu_ok() else str(external_status.get("device", "unavailable")),
|
||||||
"supports": ["semantic"],
|
"supports": ["semantic"],
|
||||||
"message": "SAM 3 model loaded and ready." if self._processor is not None else (self._last_error or self._status_message()),
|
"message": message,
|
||||||
"package_available": SAM3_PACKAGE_AVAILABLE,
|
"package_available": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("package_available")),
|
||||||
"checkpoint_exists": SAM3_PACKAGE_AVAILABLE,
|
"checkpoint_exists": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("checkpoint_access")),
|
||||||
"checkpoint_path": f"official/HuggingFace ({settings.sam3_model_version})",
|
"checkpoint_path": f"official/HuggingFace ({settings.sam3_model_version})",
|
||||||
"python_ok": self._python_ok(),
|
"python_ok": bool(self._python_ok() or external_status.get("python_ok")),
|
||||||
"torch_ok": TORCH_AVAILABLE,
|
"torch_ok": bool(TORCH_AVAILABLE or external_status.get("torch_ok")),
|
||||||
"cuda_required": True,
|
"cuda_required": True,
|
||||||
|
"external_available": external_ready,
|
||||||
|
"external_python": settings.sam3_external_python if settings.sam3_external_enabled else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
|
||||||
|
status = self._external_status(force=True)
|
||||||
|
if not status.get("available"):
|
||||||
|
raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory(prefix="sam3_") as tmpdir:
|
||||||
|
tmp_path = Path(tmpdir)
|
||||||
|
image_path = tmp_path / "image.png"
|
||||||
|
request_path = tmp_path / "request.json"
|
||||||
|
Image.fromarray(image).save(image_path)
|
||||||
|
request_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"image_path": str(image_path),
|
||||||
|
"text": text.strip(),
|
||||||
|
"model_version": settings.sam3_model_version,
|
||||||
|
"confidence_threshold": settings.sam3_confidence_threshold,
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
|
||||||
|
completed = subprocess.run(
|
||||||
|
[settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=settings.sam3_timeout_seconds,
|
||||||
|
check=False,
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
|
||||||
|
if completed.returncode != 0:
|
||||||
|
detail = completed.stderr.strip() or completed.stdout.strip()
|
||||||
|
try:
|
||||||
|
parsed = json.loads(detail)
|
||||||
|
detail = parsed.get("error", detail)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
pass
|
||||||
|
raise RuntimeError(f"SAM 3 external inference failed: {detail}")
|
||||||
|
|
||||||
|
payload = json.loads(completed.stdout)
|
||||||
|
if payload.get("error"):
|
||||||
|
raise RuntimeError(str(payload["error"]))
|
||||||
|
return payload.get("polygons", []), payload.get("scores", [])
|
||||||
|
|
||||||
def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
|
def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
|
||||||
if not text.strip():
|
if not text.strip():
|
||||||
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
|
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
|
||||||
|
if not self._can_load() and self._external_status().get("available"):
|
||||||
|
return self._predict_semantic_external(image, text)
|
||||||
if not self._ensure_ready():
|
if not self._ensure_ready():
|
||||||
raise RuntimeError(self.status()["message"])
|
raise RuntimeError(self.status()["message"])
|
||||||
|
|
||||||
|
|||||||
190
backend/services/sam3_external_worker.py
Normal file
190
backend/services/sam3_external_worker.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
"""Standalone SAM 3 helper for the dedicated Python 3.12 runtime.
|
||||||
|
|
||||||
|
The main FastAPI backend can keep running in the existing Python 3.11/SAM 2
|
||||||
|
environment while this helper is executed with a separate conda env that meets
|
||||||
|
SAM 3's stricter runtime requirements.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import importlib.util
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def _torch_status() -> tuple[bool, str | None, str | None, str | None]:
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
cuda_available = bool(torch.cuda.is_available())
|
||||||
|
return (
|
||||||
|
cuda_available,
|
||||||
|
getattr(torch, "__version__", None),
|
||||||
|
getattr(torch.version, "cuda", None),
|
||||||
|
torch.cuda.get_device_name(0) if cuda_available else None,
|
||||||
|
)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
return False, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
def _compact_error(exc: Exception) -> str:
|
||||||
|
lines = [line.strip() for line in str(exc).splitlines() if line.strip()]
|
||||||
|
for line in lines:
|
||||||
|
if "Access to model" in line or "Cannot access gated repo" in line:
|
||||||
|
return line
|
||||||
|
return lines[0] if lines else exc.__class__.__name__
|
||||||
|
|
||||||
|
|
||||||
|
def _checkpoint_access(model_version: str) -> tuple[bool, str | None]:
|
||||||
|
try:
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
repo_id = "facebook/sam3.1" if model_version == "sam3.1" else "facebook/sam3"
|
||||||
|
hf_hub_download(repo_id=repo_id, filename="config.json")
|
||||||
|
return True, None
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
return False, _compact_error(exc)
|
||||||
|
|
||||||
|
|
||||||
|
def runtime_status() -> dict[str, Any]:
|
||||||
|
model_version = os.environ.get("SAM3_MODEL_VERSION", "sam3")
|
||||||
|
package_error = None
|
||||||
|
package_available = importlib.util.find_spec("sam3") is not None
|
||||||
|
if package_available:
|
||||||
|
try:
|
||||||
|
import sam3 # noqa: F401
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
package_available = False
|
||||||
|
package_error = str(exc)
|
||||||
|
cuda_available, torch_version, cuda_version, device_name = _torch_status()
|
||||||
|
python_ok = sys.version_info >= (3, 12)
|
||||||
|
checkpoint_access = False
|
||||||
|
checkpoint_error = None
|
||||||
|
if package_available:
|
||||||
|
checkpoint_access, checkpoint_error = _checkpoint_access(model_version)
|
||||||
|
available = bool(package_available and python_ok and cuda_available and checkpoint_access)
|
||||||
|
missing = []
|
||||||
|
if not python_ok:
|
||||||
|
missing.append("Python 3.12+ runtime")
|
||||||
|
if not package_available:
|
||||||
|
missing.append(f"sam3 package ({package_error})" if package_error else "sam3 package")
|
||||||
|
if torch_version is None:
|
||||||
|
missing.append("PyTorch")
|
||||||
|
if not cuda_available:
|
||||||
|
missing.append("CUDA GPU")
|
||||||
|
if package_available and not checkpoint_access:
|
||||||
|
missing.append(f"Hugging Face checkpoint access ({checkpoint_error})")
|
||||||
|
return {
|
||||||
|
"available": available,
|
||||||
|
"package_available": package_available,
|
||||||
|
"checkpoint_access": checkpoint_access,
|
||||||
|
"python_ok": python_ok,
|
||||||
|
"torch_ok": torch_version is not None,
|
||||||
|
"torch_version": torch_version,
|
||||||
|
"cuda_version": cuda_version,
|
||||||
|
"cuda_available": cuda_available,
|
||||||
|
"device": "cuda" if cuda_available else "unavailable",
|
||||||
|
"device_name": device_name,
|
||||||
|
"message": (
|
||||||
|
"SAM 3 external runtime is ready."
|
||||||
|
if available
|
||||||
|
else f"SAM 3 external runtime unavailable: missing {', '.join(missing)}."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
if mask.dtype != np.uint8:
|
||||||
|
mask = (mask > 0).astype(np.uint8)
|
||||||
|
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
|
height, width = mask.shape[:2]
|
||||||
|
largest = []
|
||||||
|
for contour in contours:
|
||||||
|
if len(contour) > len(largest):
|
||||||
|
largest = contour
|
||||||
|
if len(largest) < 3:
|
||||||
|
return []
|
||||||
|
return [[float(point[0][0]) / width, float(point[0][1]) / height] for point in largest]
|
||||||
|
|
||||||
|
|
||||||
|
def _to_numpy(value: Any) -> np.ndarray:
|
||||||
|
if hasattr(value, "detach"):
|
||||||
|
value = value.detach().cpu().numpy()
|
||||||
|
elif hasattr(value, "cpu"):
|
||||||
|
value = value.cpu().numpy()
|
||||||
|
return np.asarray(value)
|
||||||
|
|
||||||
|
|
||||||
|
def predict(request_path: Path) -> dict[str, Any]:
|
||||||
|
import torch
|
||||||
|
from sam3.model.sam3_image_processor import Sam3Processor
|
||||||
|
from sam3.model_builder import build_sam3_image_model
|
||||||
|
|
||||||
|
payload = json.loads(request_path.read_text(encoding="utf-8"))
|
||||||
|
image_path = Path(payload["image_path"])
|
||||||
|
text = str(payload["text"]).strip()
|
||||||
|
threshold = float(payload.get("confidence_threshold", 0.5))
|
||||||
|
if not text:
|
||||||
|
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
|
||||||
|
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
image = Image.open(image_path).convert("RGB")
|
||||||
|
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
||||||
|
model = build_sam3_image_model()
|
||||||
|
processor = Sam3Processor(model, confidence_threshold=threshold)
|
||||||
|
state = processor.set_image(image)
|
||||||
|
output = processor.set_text_prompt(state=state, prompt=text)
|
||||||
|
|
||||||
|
masks = _to_numpy(output.get("masks", []))
|
||||||
|
scores = _to_numpy(output.get("scores", []))
|
||||||
|
if masks.ndim == 4:
|
||||||
|
masks = masks[:, 0]
|
||||||
|
elif masks.ndim == 3 and masks.shape[0] == 1:
|
||||||
|
masks = masks[None, 0]
|
||||||
|
|
||||||
|
polygons = []
|
||||||
|
for mask in masks:
|
||||||
|
polygon = _mask_to_polygon(mask)
|
||||||
|
if polygon:
|
||||||
|
polygons.append(polygon)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"polygons": polygons,
|
||||||
|
"scores": scores.astype(float).tolist() if scores.size else [],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser(description="SAM 3 external runtime helper")
|
||||||
|
parser.add_argument("--status", action="store_true")
|
||||||
|
parser.add_argument("--request", type=Path)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if args.status:
|
||||||
|
print(json.dumps(runtime_status(), ensure_ascii=False))
|
||||||
|
return 0
|
||||||
|
if args.request:
|
||||||
|
print(json.dumps(predict(args.request), ensure_ascii=False))
|
||||||
|
return 0
|
||||||
|
parser.error("Use --status or --request")
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
print(json.dumps({"error": str(exc)}, ensure_ascii=False), file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
return 2
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
24
backend/setup_sam3_env.sh
Executable file
24
backend/setup_sam3_env.sh
Executable file
@@ -0,0 +1,24 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# Create the dedicated SAM 3 runtime used by backend/services/sam3_external_worker.py.
|
||||||
|
# Keep Hugging Face tokens outside this repository, for example:
|
||||||
|
# export HF_TOKEN=...
|
||||||
|
# huggingface-cli login --token "$HF_TOKEN"
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
ENV_NAME="${SAM3_CONDA_ENV:-sam3}"
|
||||||
|
|
||||||
|
source /home/wkmgc/miniconda3/etc/profile.d/conda.sh
|
||||||
|
|
||||||
|
if ! conda env list | awk '{print $1}' | grep -qx "$ENV_NAME"; then
|
||||||
|
conda create -y -n "$ENV_NAME" python=3.12
|
||||||
|
fi
|
||||||
|
|
||||||
|
conda activate "$ENV_NAME"
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
python -m pip install "setuptools<81"
|
||||||
|
python -m pip install torch==2.10.0 torchvision --index-url https://download.pytorch.org/whl/cu128
|
||||||
|
python -m pip install opencv-python pillow numpy huggingface_hub einops pycocotools psutil
|
||||||
|
python -m pip install git+https://github.com/facebookresearch/sam3.git
|
||||||
|
|
||||||
|
python /home/wkmgc/Desktop/Seg_Server/backend/services/sam3_external_worker.py --status
|
||||||
@@ -9,3 +9,7 @@ TASK_STATUS_QUEUED = "queued"
|
|||||||
TASK_STATUS_RUNNING = "running"
|
TASK_STATUS_RUNNING = "running"
|
||||||
TASK_STATUS_SUCCESS = "success"
|
TASK_STATUS_SUCCESS = "success"
|
||||||
TASK_STATUS_FAILED = "failed"
|
TASK_STATUS_FAILED = "failed"
|
||||||
|
TASK_STATUS_CANCELLED = "cancelled"
|
||||||
|
|
||||||
|
TASK_ACTIVE_STATUSES = {TASK_STATUS_QUEUED, TASK_STATUS_RUNNING}
|
||||||
|
TASK_TERMINAL_STATUSES = {TASK_STATUS_SUCCESS, TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
def _create_project_and_frame(client):
|
def _create_project_and_frame(client):
|
||||||
@@ -46,6 +47,46 @@ def test_predict_accepts_point_object_with_labels(client, monkeypatch):
|
|||||||
assert calls["args"] == ([[0.5, 0.5], [0.1, 0.1]], [1, 0])
|
assert calls["args"] == ([[0.5, 0.5], [0.1, 0.1]], [1, 0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_predict_applies_crop_and_background_filter_options(client, monkeypatch):
|
||||||
|
_, frame, _ = _create_project_and_frame(client)
|
||||||
|
calls = {}
|
||||||
|
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((100, 200, 3), dtype=np.uint8))
|
||||||
|
|
||||||
|
def fake_predict_points(model, image, points, labels):
|
||||||
|
calls["shape"] = image.shape
|
||||||
|
calls["points"] = points
|
||||||
|
calls["labels"] = labels
|
||||||
|
return (
|
||||||
|
[
|
||||||
|
[[0.0, 0.0], [0.2, 0.0], [0.2, 0.2]],
|
||||||
|
[[0.45, 0.45], [0.55, 0.45], [0.55, 0.55]],
|
||||||
|
],
|
||||||
|
[0.9, 0.01],
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr("routers.ai.sam_registry.predict_points", fake_predict_points)
|
||||||
|
|
||||||
|
response = client.post("/api/ai/predict", json={
|
||||||
|
"image_id": frame["id"],
|
||||||
|
"prompt_type": "point",
|
||||||
|
"prompt_data": {"points": [[0.5, 0.5], [0.52, 0.52]], "labels": [1, 0]},
|
||||||
|
"options": {
|
||||||
|
"crop_to_prompt": True,
|
||||||
|
"crop_margin": 0.1,
|
||||||
|
"auto_filter_background": True,
|
||||||
|
"min_score": 0.05,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert calls["shape"][0] < 100
|
||||||
|
assert calls["shape"][1] < 200
|
||||||
|
assert calls["labels"] == [1, 0]
|
||||||
|
assert response.json()["scores"] == [0.9]
|
||||||
|
polygon = response.json()["polygons"][0]
|
||||||
|
assert all(0.0 <= coord <= 1.0 for point in polygon for coord in point)
|
||||||
|
|
||||||
|
|
||||||
def test_predict_box_and_semantic_fallback(client, monkeypatch):
|
def test_predict_box_and_semantic_fallback(client, monkeypatch):
|
||||||
_, frame, _ = _create_project_and_frame(client)
|
_, frame, _ = _create_project_and_frame(client)
|
||||||
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
|
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
|
||||||
@@ -246,3 +287,62 @@ def test_update_and_delete_annotation_validation(client):
|
|||||||
f"/api/ai/annotations/{saved['id']}",
|
f"/api/ai/annotations/{saved['id']}",
|
||||||
json={"template_id": 999},
|
json={"template_id": 999},
|
||||||
).status_code == 404
|
).status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_import_gt_mask_creates_annotations_with_seed_points(client):
|
||||||
|
project, frame, template = _create_project_and_frame(client)
|
||||||
|
mask = np.zeros((360, 640), dtype=np.uint8)
|
||||||
|
cv2.rectangle(mask, (100, 80), (260, 220), 255, thickness=-1)
|
||||||
|
ok, encoded = cv2.imencode(".png", mask)
|
||||||
|
assert ok
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/ai/import-gt-mask",
|
||||||
|
data={
|
||||||
|
"project_id": str(project["id"]),
|
||||||
|
"frame_id": str(frame["id"]),
|
||||||
|
"template_id": str(template["id"]),
|
||||||
|
"label": "Imported GT",
|
||||||
|
"color": "#22c55e",
|
||||||
|
},
|
||||||
|
files={"file": ("mask.png", encoded.tobytes(), "image/png")},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 201
|
||||||
|
body = response.json()
|
||||||
|
assert len(body) == 1
|
||||||
|
assert body[0]["project_id"] == project["id"]
|
||||||
|
assert body[0]["frame_id"] == frame["id"]
|
||||||
|
assert body[0]["template_id"] == template["id"]
|
||||||
|
assert body[0]["mask_data"]["label"] == "Imported GT"
|
||||||
|
assert body[0]["mask_data"]["source"] == "gt_mask"
|
||||||
|
assert body[0]["mask_data"]["gt_label_value"] == 255
|
||||||
|
assert len(body[0]["mask_data"]["polygons"][0]) >= 3
|
||||||
|
assert len(body[0]["points"]) == 1
|
||||||
|
assert 0.0 <= body[0]["points"][0][0] <= 1.0
|
||||||
|
assert 0.0 <= body[0]["points"][0][1] <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_import_gt_mask_splits_label_values(client):
|
||||||
|
project, frame, _ = _create_project_and_frame(client)
|
||||||
|
mask = np.zeros((360, 640), dtype=np.uint8)
|
||||||
|
cv2.rectangle(mask, (20, 20), (120, 120), 1, thickness=-1)
|
||||||
|
cv2.rectangle(mask, (220, 80), (320, 180), 2, thickness=-1)
|
||||||
|
ok, encoded = cv2.imencode(".png", mask)
|
||||||
|
assert ok
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/ai/import-gt-mask",
|
||||||
|
data={
|
||||||
|
"project_id": str(project["id"]),
|
||||||
|
"frame_id": str(frame["id"]),
|
||||||
|
"label": "GT Class",
|
||||||
|
},
|
||||||
|
files={"file": ("labels.png", encoded.tobytes(), "image/png")},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 201
|
||||||
|
body = sorted(response.json(), key=lambda item: item["mask_data"]["gt_label_value"])
|
||||||
|
assert [item["mask_data"]["gt_label_value"] for item in body] == [1, 2]
|
||||||
|
assert [item["mask_data"]["label"] for item in body] == ["GT Class 1", "GT Class 2"]
|
||||||
|
assert all(len(item["points"]) == 1 for item in body)
|
||||||
|
|||||||
@@ -59,7 +59,9 @@ def test_dashboard_overview_uses_persisted_records(client, db_session):
|
|||||||
"name": "Pending Project",
|
"name": "Pending Project",
|
||||||
"progress": 35,
|
"progress": 35,
|
||||||
"status": "正在使用 FFmpeg/OpenCV 拆帧",
|
"status": "正在使用 FFmpeg/OpenCV 拆帧",
|
||||||
|
"raw_status": "running",
|
||||||
"frame_count": 0,
|
"frame_count": 0,
|
||||||
|
"error": None,
|
||||||
"updated_at": body["tasks"][0]["updated_at"],
|
"updated_at": body["tasks"][0]["updated_at"],
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
import zipfile
|
import zipfile
|
||||||
|
import json
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def _seed_export_data(client):
|
def _seed_export_data(client):
|
||||||
project = client.post("/api/projects", json={"name": "Export Project"}).json()
|
project = client.post("/api/projects", json={"name": "Export Project"}).json()
|
||||||
@@ -58,7 +62,55 @@ def test_export_masks_zip(client):
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.headers["content-type"].startswith("application/zip")
|
assert response.headers["content-type"].startswith("application/zip")
|
||||||
with zipfile.ZipFile(BytesIO(response.content)) as archive:
|
with zipfile.ZipFile(BytesIO(response.content)) as archive:
|
||||||
assert archive.namelist() == [f"mask_{annotation['id']:06d}.png"]
|
assert archive.namelist() == [
|
||||||
|
f"mask_{annotation['id']:06d}.png",
|
||||||
|
"semantic_frame_000000.png",
|
||||||
|
"semantic_classes.json",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_export_masks_uses_z_index_for_semantic_fusion(client):
|
||||||
|
project = client.post("/api/projects", json={"name": "Fusion Project"}).json()
|
||||||
|
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||||
|
"project_id": project["id"],
|
||||||
|
"frame_index": 0,
|
||||||
|
"image_url": "frames/0.jpg",
|
||||||
|
"width": 20,
|
||||||
|
"height": 20,
|
||||||
|
}).json()
|
||||||
|
low = client.post("/api/ai/annotate", json={
|
||||||
|
"project_id": project["id"],
|
||||||
|
"frame_id": frame["id"],
|
||||||
|
"mask_data": {
|
||||||
|
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
|
||||||
|
"label": "Low",
|
||||||
|
"color": "#00ff00",
|
||||||
|
"class": {"id": "low", "name": "Low", "color": "#00ff00", "zIndex": 10},
|
||||||
|
},
|
||||||
|
}).json()
|
||||||
|
high = client.post("/api/ai/annotate", json={
|
||||||
|
"project_id": project["id"],
|
||||||
|
"frame_id": frame["id"],
|
||||||
|
"mask_data": {
|
||||||
|
"polygons": [[[0.4, 0.4], [0.9, 0.4], [0.9, 0.9], [0.4, 0.9]]],
|
||||||
|
"label": "High",
|
||||||
|
"color": "#ff0000",
|
||||||
|
"class": {"id": "high", "name": "High", "color": "#ff0000", "zIndex": 20},
|
||||||
|
},
|
||||||
|
}).json()
|
||||||
|
|
||||||
|
response = client.get(f"/api/export/{project['id']}/masks")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
with zipfile.ZipFile(BytesIO(response.content)) as archive:
|
||||||
|
assert f"mask_{low['id']:06d}.png" in archive.namelist()
|
||||||
|
assert f"mask_{high['id']:06d}.png" in archive.namelist()
|
||||||
|
legend = json.loads(archive.read("semantic_classes.json"))
|
||||||
|
high_value = next(item["value"] for item in legend["classes"] if item["key"] == "class:high")
|
||||||
|
semantic_bytes = np.frombuffer(archive.read("semantic_frame_000000.png"), dtype=np.uint8)
|
||||||
|
semantic = cv2.imdecode(semantic_bytes, cv2.IMREAD_GRAYSCALE)
|
||||||
|
|
||||||
|
assert semantic[10, 10] == high_value
|
||||||
|
|
||||||
|
|
||||||
def test_export_missing_project_returns_404(client):
|
def test_export_missing_project_returns_404(client):
|
||||||
|
|||||||
@@ -140,3 +140,25 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp
|
|||||||
assert project_detail["status"] == "ready"
|
assert project_detail["status"] == "ready"
|
||||||
frames = client.get(f"/api/projects/{project['id']}/frames").json()
|
frames = client.get(f"/api/projects/{project['id']}/frames").json()
|
||||||
assert "frame_000001.jpg" in frames[0]["image_url"]
|
assert "frame_000001.jpg" in frames[0]["image_url"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_task_runner_skips_already_cancelled_task(db_session):
|
||||||
|
from models import ProcessingTask
|
||||||
|
from services.media_task_runner import run_parse_media_task
|
||||||
|
|
||||||
|
task = ProcessingTask(
|
||||||
|
task_type="parse_video",
|
||||||
|
status="cancelled",
|
||||||
|
progress=100,
|
||||||
|
message="任务已取消",
|
||||||
|
project_id=1,
|
||||||
|
payload={"source_type": "video"},
|
||||||
|
)
|
||||||
|
db_session.add(task)
|
||||||
|
db_session.commit()
|
||||||
|
db_session.refresh(task)
|
||||||
|
|
||||||
|
result = run_parse_media_task(db_session, task.id)
|
||||||
|
|
||||||
|
assert result["status"] == "cancelled"
|
||||||
|
assert result["message"] == "任务已取消"
|
||||||
|
|||||||
@@ -26,6 +26,25 @@ def test_task_progress_payload_uses_dashboard_task_id_and_project_name():
|
|||||||
assert payload["status"] == "解析完成"
|
assert payload["status"] == "解析完成"
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_progress_payload_marks_cancelled_tasks():
|
||||||
|
task = SimpleNamespace(
|
||||||
|
id=13,
|
||||||
|
project_id=7,
|
||||||
|
project=SimpleNamespace(name="demo.mp4"),
|
||||||
|
status="cancelled",
|
||||||
|
progress=100,
|
||||||
|
message="任务已取消",
|
||||||
|
error="Cancelled by user",
|
||||||
|
updated_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = task_progress_payload(task)
|
||||||
|
|
||||||
|
assert payload["type"] == "cancelled"
|
||||||
|
assert payload["status"] == "任务已取消"
|
||||||
|
assert payload["error"] == "Cancelled by user"
|
||||||
|
|
||||||
|
|
||||||
def test_publish_progress_event_writes_json_to_redis(monkeypatch):
|
def test_publish_progress_event_writes_json_to_redis(monkeypatch):
|
||||||
calls = []
|
calls = []
|
||||||
|
|
||||||
|
|||||||
112
backend/tests/test_sam3_engine.py
Normal file
112
backend/tests/test_sam3_engine.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from services.sam3_engine import SAM3Engine
|
||||||
|
|
||||||
|
|
||||||
|
class _Completed:
|
||||||
|
def __init__(self, returncode=0, stdout="", stderr=""):
|
||||||
|
self.returncode = returncode
|
||||||
|
self.stdout = stdout
|
||||||
|
self.stderr = stderr
|
||||||
|
|
||||||
|
|
||||||
|
def _external_settings(monkeypatch, python_path: Path):
|
||||||
|
python_path.write_text("#!/usr/bin/env python\n", encoding="utf-8")
|
||||||
|
python_path.chmod(0o755)
|
||||||
|
monkeypatch.setattr("services.sam3_engine.SAM3_PACKAGE_AVAILABLE", False)
|
||||||
|
monkeypatch.setattr("services.sam3_engine.TORCH_AVAILABLE", False)
|
||||||
|
monkeypatch.setattr("services.sam3_engine.settings.sam3_external_enabled", True)
|
||||||
|
monkeypatch.setattr("services.sam3_engine.settings.sam3_external_python", str(python_path))
|
||||||
|
monkeypatch.setattr("services.sam3_engine.settings.sam3_timeout_seconds", 10)
|
||||||
|
monkeypatch.setattr("services.sam3_engine.settings.sam3_status_cache_seconds", 30)
|
||||||
|
monkeypatch.setattr("services.sam3_engine.settings.sam3_confidence_threshold", 0.4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sam3_status_reports_external_runtime_ready(tmp_path, monkeypatch):
|
||||||
|
_external_settings(monkeypatch, tmp_path / "python")
|
||||||
|
|
||||||
|
def fake_run(args, **_kwargs):
|
||||||
|
assert "--status" in args
|
||||||
|
return _Completed(stdout=json.dumps({
|
||||||
|
"available": True,
|
||||||
|
"package_available": True,
|
||||||
|
"python_ok": True,
|
||||||
|
"torch_ok": True,
|
||||||
|
"cuda_available": True,
|
||||||
|
"device": "cuda",
|
||||||
|
"message": "ready",
|
||||||
|
}))
|
||||||
|
|
||||||
|
monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)
|
||||||
|
|
||||||
|
status = SAM3Engine().status()
|
||||||
|
|
||||||
|
assert status["available"] is True
|
||||||
|
assert status["external_available"] is True
|
||||||
|
assert status["package_available"] is True
|
||||||
|
assert status["python_ok"] is True
|
||||||
|
assert status["message"] == "SAM 3 external runtime is ready; model will load in the helper process on inference."
|
||||||
|
|
||||||
|
|
||||||
|
def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch):
|
||||||
|
_external_settings(monkeypatch, tmp_path / "python")
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def fake_run(args, **_kwargs):
|
||||||
|
calls.append(args)
|
||||||
|
if "--status" in args:
|
||||||
|
return _Completed(stdout=json.dumps({
|
||||||
|
"available": True,
|
||||||
|
"package_available": True,
|
||||||
|
"python_ok": True,
|
||||||
|
"torch_ok": True,
|
||||||
|
"cuda_available": True,
|
||||||
|
"device": "cuda",
|
||||||
|
"message": "ready",
|
||||||
|
}))
|
||||||
|
request_path = Path(args[-1])
|
||||||
|
request = json.loads(request_path.read_text(encoding="utf-8"))
|
||||||
|
assert request["text"] == "vessel"
|
||||||
|
assert request["confidence_threshold"] == 0.4
|
||||||
|
assert Path(request["image_path"]).exists()
|
||||||
|
return _Completed(stdout=json.dumps({
|
||||||
|
"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
|
||||||
|
"scores": [0.91],
|
||||||
|
}))
|
||||||
|
|
||||||
|
monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)
|
||||||
|
|
||||||
|
polygons, scores = SAM3Engine().predict_semantic(np.zeros((8, 8, 3), dtype=np.uint8), " vessel ")
|
||||||
|
|
||||||
|
assert polygons == [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]]
|
||||||
|
assert scores == [0.91]
|
||||||
|
assert any("--request" in args for args in calls)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sam3_predict_semantic_reports_external_errors(tmp_path, monkeypatch):
|
||||||
|
_external_settings(monkeypatch, tmp_path / "python")
|
||||||
|
|
||||||
|
def fake_run(args, **_kwargs):
|
||||||
|
if "--status" in args:
|
||||||
|
return _Completed(stdout=json.dumps({
|
||||||
|
"available": True,
|
||||||
|
"package_available": True,
|
||||||
|
"python_ok": True,
|
||||||
|
"torch_ok": True,
|
||||||
|
"cuda_available": True,
|
||||||
|
"device": "cuda",
|
||||||
|
"message": "ready",
|
||||||
|
}))
|
||||||
|
return _Completed(returncode=1, stderr=json.dumps({"error": "HF access denied"}))
|
||||||
|
|
||||||
|
monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)
|
||||||
|
|
||||||
|
try:
|
||||||
|
SAM3Engine().predict_semantic(np.zeros((8, 8, 3), dtype=np.uint8), "vessel")
|
||||||
|
except RuntimeError as exc:
|
||||||
|
assert "HF access denied" in str(exc)
|
||||||
|
else:
|
||||||
|
raise AssertionError("Expected SAM 3 external inference failure.")
|
||||||
104
backend/tests/test_tasks.py
Normal file
104
backend/tests/test_tasks.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
from models import ProcessingTask
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_task_revokes_celery_and_updates_project(client, db_session, monkeypatch):
|
||||||
|
project = client.post("/api/projects", json={
|
||||||
|
"name": "Cancelable",
|
||||||
|
"video_path": "uploads/1/clip.mp4",
|
||||||
|
"status": "parsing",
|
||||||
|
}).json()
|
||||||
|
task = ProcessingTask(
|
||||||
|
task_type="parse_video",
|
||||||
|
status="running",
|
||||||
|
progress=35,
|
||||||
|
message="正在使用 FFmpeg/OpenCV 拆帧",
|
||||||
|
project_id=project["id"],
|
||||||
|
celery_task_id="celery-1",
|
||||||
|
payload={"source_type": "video"},
|
||||||
|
)
|
||||||
|
db_session.add(task)
|
||||||
|
db_session.commit()
|
||||||
|
db_session.refresh(task)
|
||||||
|
|
||||||
|
revoked = []
|
||||||
|
published = []
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"routers.tasks.celery_app.control.revoke",
|
||||||
|
lambda celery_id, terminate, signal: revoked.append((celery_id, terminate, signal)),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("routers.tasks.publish_task_progress_event", lambda event_task: published.append(event_task.status))
|
||||||
|
|
||||||
|
response = client.post(f"/api/tasks/{task.id}/cancel")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
body = response.json()
|
||||||
|
assert body["status"] == "cancelled"
|
||||||
|
assert body["progress"] == 100
|
||||||
|
assert body["message"] == "任务已取消"
|
||||||
|
assert body["error"] == "Cancelled by user"
|
||||||
|
assert revoked == [("celery-1", True, "SIGTERM")]
|
||||||
|
assert published == ["cancelled"]
|
||||||
|
assert client.get(f"/api/projects/{project['id']}").json()["status"] == "pending"
|
||||||
|
|
||||||
|
|
||||||
|
def test_retry_task_creates_fresh_parse_task(client, db_session, monkeypatch):
|
||||||
|
project = client.post("/api/projects", json={
|
||||||
|
"name": "Retryable",
|
||||||
|
"video_path": "uploads/2/clip.mp4",
|
||||||
|
"source_type": "video",
|
||||||
|
"status": "error",
|
||||||
|
}).json()
|
||||||
|
task = ProcessingTask(
|
||||||
|
task_type="parse_video",
|
||||||
|
status="failed",
|
||||||
|
progress=100,
|
||||||
|
message="解析失败",
|
||||||
|
error="ffmpeg failed",
|
||||||
|
project_id=project["id"],
|
||||||
|
payload={"source_type": "video"},
|
||||||
|
)
|
||||||
|
db_session.add(task)
|
||||||
|
db_session.commit()
|
||||||
|
db_session.refresh(task)
|
||||||
|
|
||||||
|
class FakeAsyncResult:
|
||||||
|
id = "celery-retry"
|
||||||
|
|
||||||
|
queued = []
|
||||||
|
published = []
|
||||||
|
monkeypatch.setattr("routers.tasks.parse_project_media.delay", lambda task_id: queued.append(task_id) or FakeAsyncResult())
|
||||||
|
monkeypatch.setattr("routers.tasks.publish_task_progress_event", lambda event_task: published.append((event_task.id, event_task.status)))
|
||||||
|
|
||||||
|
response = client.post(f"/api/tasks/{task.id}/retry")
|
||||||
|
|
||||||
|
assert response.status_code == 202
|
||||||
|
body = response.json()
|
||||||
|
assert body["id"] != task.id
|
||||||
|
assert body["status"] == "queued"
|
||||||
|
assert body["progress"] == 0
|
||||||
|
assert body["celery_task_id"] == "celery-retry"
|
||||||
|
assert body["payload"]["retry_of"] == task.id
|
||||||
|
assert queued == [body["id"]]
|
||||||
|
assert published[0] == (body["id"], "queued")
|
||||||
|
assert published[-1] == (body["id"], "queued")
|
||||||
|
assert client.get(f"/api/projects/{project['id']}").json()["status"] == "parsing"
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_actions_reject_invalid_states(client, db_session):
|
||||||
|
project = client.post("/api/projects", json={
|
||||||
|
"name": "Done",
|
||||||
|
"video_path": "uploads/3/clip.mp4",
|
||||||
|
}).json()
|
||||||
|
task = ProcessingTask(
|
||||||
|
task_type="parse_video",
|
||||||
|
status="success",
|
||||||
|
progress=100,
|
||||||
|
project_id=project["id"],
|
||||||
|
payload={"source_type": "video"},
|
||||||
|
)
|
||||||
|
db_session.add(task)
|
||||||
|
db_session.commit()
|
||||||
|
db_session.refresh(task)
|
||||||
|
|
||||||
|
assert client.post(f"/api/tasks/{task.id}/cancel").status_code == 409
|
||||||
|
assert client.post(f"/api/tasks/{task.id}/retry").status_code == 409
|
||||||
@@ -38,21 +38,21 @@ Word 方案描述的理想系统包含:
|
|||||||
| 视频拆帧 | 已落地 | `backend/services/frame_parser.py`、`backend/routers/media.py` |
|
| 视频拆帧 | 已落地 | `backend/services/frame_parser.py`、`backend/routers/media.py` |
|
||||||
| DICOM 批量导入 | 部分落地 | 上传和解析存在,项目级体验还需完善 |
|
| DICOM 批量导入 | 部分落地 | 上传和解析存在,项目级体验还需完善 |
|
||||||
| WebSocket 进度 | 已落地 | 拆帧进度写入任务表后发布到 Redis `seg:progress`,FastAPI 广播到 `/ws/progress` |
|
| WebSocket 进度 | 已落地 | 拆帧进度写入任务表后发布到 Redis `seg:progress`,FastAPI 广播到 `/ws/progress` |
|
||||||
| SAM 推理 | 部分落地 | 后端已有 SAM 2 / SAM 3 选择和真实模型状态接口;SAM 3 依赖官方运行环境,当前环境不满足时会标为不可用 |
|
| SAM 推理 | 部分落地 | 后端已有 SAM 2 / SAM 3 选择和真实模型状态接口;SAM 3 通过独立 Python 3.12 环境桥接,状态会检查 Python/CUDA/包/HF gated 权重访问 |
|
||||||
| 模板库 | 部分落地 | 分类、颜色、z-index 能存储和编辑;重叠裁决算法未落地 |
|
| 模板库 | 部分落地 | 分类、颜色、z-index 能存储和编辑;PNG mask 导出时会按 zIndex 做语义融合裁决,前端预览裁决尚未落地 |
|
||||||
| 标注持久化 | 部分落地 | 后端有 `Annotation` 表,前端已接入新增、回显、分类更新和当前帧删除;逐点几何编辑未落地 |
|
| 标注持久化 | 部分落地 | 后端有 `Annotation` 表,前端已接入新增、回显、分类更新、当前帧删除、手工绘制、GT mask 导入、seed point 编辑、polygon 顶点拖动/删除、边中点插点和多 polygon 子区域编辑;复杂洞结构编辑未落地 |
|
||||||
| COCO / Mask 导出 | 部分落地 | `backend/routers/export.py`;COCO JSON 前端按钮已接入,PNG mask ZIP 尚未提供前端按钮 |
|
| COCO / Mask 导出 | 已落地基础能力 | `backend/routers/export.py`;COCO JSON 和 PNG mask ZIP 前端按钮均已接入,ZIP 包含单标注 mask、语义融合 mask 和类别映射 |
|
||||||
|
|
||||||
## 当前代码尚未落地的目标
|
## 当前代码尚未落地的目标
|
||||||
|
|
||||||
- SAM 3:当前已提供 `sam3_engine.py` 适配入口和状态检测;要实际运行仍需安装官方 `facebookresearch/sam3` 依赖并满足 Python 3.12+、PyTorch 2.7+、CUDA 12.6+。
|
- SAM 3:当前已提供 `sam3_engine.py` 外部环境桥接、`sam3_external_worker.py` 和 `setup_sam3_env.sh`;本机 `sam3` 环境已满足 Python 3.12、PyTorch 2.10/cu128、CUDA/GPU 和官方包导入。真实推理仍取决于 Hugging Face `facebook/sam3` gated 权重访问授权;官方没有 SAM 3 tiny/small 权重,当前可选最小真实 SAM 权重仍是 SAM 2 tiny。
|
||||||
- Celery 异步任务队列:已注册 Celery app 和拆帧 worker task,`/api/media/parse` 会创建任务表记录并入队。
|
- Celery 异步任务队列:已注册 Celery app 和拆帧 worker task,`/api/media/parse` 会创建任务表记录并入队。
|
||||||
- GT mask 导入:当前前端没有 GT Label 导入入口,后端也没有对应路由。
|
- GT mask 导入:当前已支持二值/多类别 mask 导入,后端会按非零像素值拆分区域,生成 polygon 标注和距离变换 seed point;骨架提取、HDBSCAN 和模板自动映射尚未实现。
|
||||||
- Mask 到点区域的拓扑降维:当前没有距离变换、骨架提取、HDBSCAN 等实现。
|
- Mask 到点区域的拓扑降维:当前完成 distance transform seed point 和前端 seed point 拖拽编辑;骨架提取、HDBSCAN 等增强尚未实现。
|
||||||
- 类别优先级融合:模板有 z-index,但没有后端融合算法。
|
- 类别优先级融合:PNG mask 导出时已按 zIndex 生成语义融合 mask;前端裁决预览尚未实现。
|
||||||
- 撤销/重做:工具栏有按钮,但没有历史栈。
|
- 撤销/重做:当前已有全局 mask 历史栈。
|
||||||
- 结构化归档保存:工作区按钮已调用 `POST /api/ai/annotate` 保存当前未归档 mask,并通过 `PATCH /api/ai/annotations/{id}` 更新 dirty mask。
|
- 结构化归档保存:工作区按钮已调用 `POST /api/ai/annotate` 保存当前未归档 mask,并通过 `PATCH /api/ai/annotations/{id}` 更新 dirty mask。
|
||||||
|
|
||||||
## 结论
|
## 结论
|
||||||
|
|
||||||
当前项目已经从 UI 原型推进到“可上传、可异步拆帧、可实时查看任务进度、可浏览项目帧、可维护模板、可点/框 AI 推理、可保存标注、可导出 COCO、可查看 Dashboard 后端概览”的全栈雏形,但离 Word 中描述的完整智能标注系统还有明显差距。下一阶段最重要的是继续补齐手工绘制、撤销重做和真实语义文本分割。
|
当前项目已经从 UI 原型推进到“可上传、可异步拆帧、可取消/重试任务、可查看失败详情、可实时查看任务进度、可浏览项目帧、可维护模板、可手工绘制、可逐点编辑 polygon、可边中点插点、可多 polygon 子区域编辑、可区域合并/去除、可点/框 AI 推理、可对点/框 prompt 做裁剪推理和背景过滤、可导入多类别 GT mask、可编辑 seed point、可保存标注、可导出 COCO/语义 mask ZIP、可查看 Dashboard 后端概览”的全栈雏形。下一阶段最重要的是继续补齐 Hugging Face SAM 3 权重授权后的真实语义文本分割 smoke test、复杂洞结构编辑和 GT mask 骨架/聚类增强。
|
||||||
|
|||||||
@@ -71,6 +71,7 @@
|
|||||||
5. Celery worker 下载 MinIO 文件,调用 `frame_parser.py` 拆帧。
|
5. Celery worker 下载 MinIO 文件,调用 `frame_parser.py` 拆帧。
|
||||||
6. worker 把拆出的帧重新上传 MinIO,写入 `frames` 表,并更新任务状态。
|
6. worker 把拆出的帧重新上传 MinIO,写入 `frames` 表,并更新任务状态。
|
||||||
7. 工作区通过 `GET /api/tasks/{id}` 等待任务完成,再通过 `GET /api/projects/{id}/frames` 获取预签名图片 URL。
|
7. 工作区通过 `GET /api/tasks/{id}` 等待任务完成,再通过 `GET /api/projects/{id}/frames` 获取预签名图片 URL。
|
||||||
|
8. Dashboard 可通过 `POST /api/tasks/{id}/cancel` 取消 queued/running 任务,通过 `POST /api/tasks/{id}/retry` 重试 failed/cancelled 任务,并用 `GET /api/tasks/{id}` 查看失败详情。
|
||||||
|
|
||||||
### 工作区浏览
|
### 工作区浏览
|
||||||
|
|
||||||
@@ -98,7 +99,7 @@
|
|||||||
## 当前主要风险点
|
## 当前主要风险点
|
||||||
|
|
||||||
- 前端 API/WS 地址虽然已支持环境变量和 hostname 推导,但部署时仍需要确认浏览器可访问 `:8000` 后端。
|
- 前端 API/WS 地址虽然已支持环境变量和 hostname 推导,但部署时仍需要确认浏览器可访问 `:8000` 后端。
|
||||||
- AI 语义文本提示在选择 SAM 3 且运行环境满足官方依赖时走 SAM 3;当前环境若不满足会在模型状态中标明不可用。
|
- AI 语义文本提示在选择 SAM 3 且运行环境满足官方依赖、并具备 Hugging Face gated 权重访问时走 SAM 3;当前状态接口会分别暴露外部 Python 环境、CUDA、包导入和 checkpoint access 是否满足。
|
||||||
- 工作区顶部“导出 JSON 标注集”和“结构化归档保存”已接入导出、标注新增和 dirty 标注更新;清空当前帧遮罩会删除对应后端标注。撤销重做和手工绘制仍未持久化。
|
- 工作区顶部“导出 JSON 标注集”“导出 PNG Mask ZIP”“导入 GT Mask”和“结构化归档保存”已接入导出、GT 多类别导入、seed point 回显/编辑、标注新增和 dirty 标注更新;清空当前帧遮罩会删除对应后端标注。手工绘制、polygon 顶点拖动/删除、区域合并/去除和撤销重做已经落到前端 mask 数据结构。
|
||||||
- Dashboard 初始统计、队列和活动日志来自后端聚合接口;解析队列来自 `processing_tasks`,worker 进度通过 Redis `seg:progress` 转发到 WebSocket。
|
- Dashboard 初始统计、队列和活动日志来自后端聚合接口;解析队列来自 `processing_tasks`,worker 进度通过 Redis `seg:progress` 转发到 WebSocket。任务取消、重试和失败详情已接入前后端。
|
||||||
- 后端路由大多未做真实鉴权。
|
- 后端路由大多未做真实鉴权。
|
||||||
|
|||||||
@@ -30,8 +30,11 @@
|
|||||||
| 元素 | 状态 | 说明 |
|
| 元素 | 状态 | 说明 |
|
||||||
|------|------|------|
|
|------|------|------|
|
||||||
| WebSocket 连接状态 | 真实可用 | 前端通过 `src/lib/config.ts` 推导或读取 `VITE_WS_PROGRESS_URL`,后端有 `/ws/progress` |
|
| WebSocket 连接状态 | 真实可用 | 前端通过 `src/lib/config.ts` 推导或读取 `VITE_WS_PROGRESS_URL`,后端有 `/ws/progress` |
|
||||||
| 解析队列任务 | 真实可用 | 初始数据来自 `GET /api/dashboard/overview`,按 `processing_tasks` queued/running 任务生成 |
|
| 解析队列任务 | 真实可用 | 初始数据来自 `GET /api/dashboard/overview`,按 `processing_tasks` queued/running/failed/cancelled 任务生成 |
|
||||||
| WebSocket 更新任务 | 真实可用 | Celery worker 更新 `processing_tasks` 后发布 Redis `seg:progress`,FastAPI 广播 progress/complete/error |
|
| 任务取消 | 真实可用 | queued/running 任务显示取消按钮,调用 `POST /api/tasks/{task_id}/cancel` |
|
||||||
|
| 任务重试 | 真实可用 | failed/cancelled 任务显示重试按钮,调用 `POST /api/tasks/{task_id}/retry` 创建新任务 |
|
||||||
|
| 失败详情 | 真实可用 | 任务详情按钮调用 `GET /api/tasks/{task_id}`,弹窗展示 error、payload、result、Celery ID 和时间 |
|
||||||
|
| WebSocket 更新任务 | 真实可用 | Celery worker 更新 `processing_tasks` 后发布 Redis `seg:progress`,FastAPI 广播 progress/complete/error/cancelled |
|
||||||
| 项目、任务、标注、系统负载统计 | 真实可用 | 初始数据来自 `GET /api/dashboard/overview`,系统负载按主机 load average 估算 |
|
| 项目、任务、标注、系统负载统计 | 真实可用 | 初始数据来自 `GET /api/dashboard/overview`,系统负载按主机 load average 估算 |
|
||||||
| 近期实时流转记录 | 真实可用 | 初始数据来自任务、项目、标注和模板记录;WebSocket status/complete/error 会继续追加 |
|
| 近期实时流转记录 | 真实可用 | 初始数据来自任务、项目、标注和模板记录;WebSocket status/complete/error 会继续追加 |
|
||||||
|
|
||||||
@@ -60,6 +63,8 @@
|
|||||||
| SAM 模型状态徽标 | 真实可用 | 调用 `GET /api/ai/models/status`,显示当前选择的 SAM 2/SAM 3 是否可用 |
|
| SAM 模型状态徽标 | 真实可用 | 调用 `GET /api/ai/models/status`,显示当前选择的 SAM 2/SAM 3 是否可用 |
|
||||||
| 已保存标注回显 | 真实可用 | 加载工作区帧后调用 `GET /api/ai/annotations` 并渲染已保存 mask |
|
| 已保存标注回显 | 真实可用 | 加载工作区帧后调用 `GET /api/ai/annotations` 并渲染已保存 mask |
|
||||||
| “导出 JSON 标注集”按钮 | 真实可用 | 导出前会保存未归档 mask,然后调用 `exportCoco()` 下载 JSON |
|
| “导出 JSON 标注集”按钮 | 真实可用 | 导出前会保存未归档 mask,然后调用 `exportCoco()` 下载 JSON |
|
||||||
|
| “导出 PNG Mask ZIP”按钮 | 真实可用 | 导出前会保存未归档 mask,然后调用 `GET /api/export/{project_id}/masks` 下载 ZIP;后端同时包含单标注 mask、每帧语义融合 mask 和 `semantic_classes.json` |
|
||||||
|
| “导入 GT Mask”按钮 | 真实可用 | 选择图片后调用 `POST /api/ai/import-gt-mask`,后端按非零像素值和连通域生成 polygon 标注与距离变换 seed point,再回显到工作区 |
|
||||||
| “结构化归档保存”按钮 | 真实可用 | 未保存 mask 写入 `POST /api/ai/annotate`;dirty mask 写入 `PATCH /api/ai/annotations/{id}` |
|
| “结构化归档保存”按钮 | 真实可用 | 未保存 mask 写入 `POST /api/ai/annotate`;dirty mask 写入 `PATCH /api/ai/annotations/{id}` |
|
||||||
|
|
||||||
## CanvasArea 画布
|
## CanvasArea 画布
|
||||||
@@ -70,10 +75,13 @@
|
|||||||
| 滚轮缩放 | 真实可用 | 改变 Konva Stage scale |
|
| 滚轮缩放 | 真实可用 | 改变 Konva Stage scale |
|
||||||
| 拖拽平移 | 真实可用 | activeTool 为 `move` 时 Stage draggable |
|
| 拖拽平移 | 真实可用 | activeTool 为 `move` 时 Stage draggable |
|
||||||
| 光标坐标显示 | 真实可用 | 根据 pointer position 计算 |
|
| 光标坐标显示 | 真实可用 | 根据 pointer position 计算 |
|
||||||
| 正向/反向选点 | 部分可用 | UI 能加点,并按当前帧 `frame.id` 调用 `/api/ai/predict`;需点击归档保存才持久化 |
|
| 正向/反向选点 | 真实可用 | UI 能加点,并按当前帧 `frame.id` 调用 `/api/ai/predict`;结果需点击归档保存才持久化 |
|
||||||
| 框选 | 部分可用 | UI 能画框,并把框坐标归一化后调用后端推理;需点击归档保存才持久化 |
|
| 框选 | 真实可用 | UI 能画框,并把框坐标归一化后调用后端推理;结果需点击归档保存才持久化 |
|
||||||
| AI 推理中提示 | 真实可用 | 请求期间会显示 |
|
| AI 推理中提示 | 真实可用 | 请求期间会显示 |
|
||||||
| Mask 渲染 | 部分可用 | 前端会把推理/已保存标注转成 Konva `pathData` 渲染 |
|
| 手工多边形/矩形/圆/点/线 | 真实可用 | 多边形点击取点后 Enter 完成;矩形/圆/线拖拽生成 polygon;点工具生成小区域;均写入 `Mask.segmentation`,可归档保存 |
|
||||||
|
| Mask 渲染 | 真实可用 | 前端会把推理、手工绘制、GT 导入和已保存标注转成 Konva `pathData` 渲染 |
|
||||||
|
| Polygon 逐点编辑 | 真实可用 | 点击 mask 后显示 polygon 顶点;拖动顶点会重算 `pathData/segmentation/bbox/area`,已保存 mask 标为 dirty;选中顶点后 Delete/Backspace 可删点但保留至少三点 |
|
||||||
|
| GT seed point 回显/编辑 | 真实可用 | 已保存标注的 `points` 会显示为黄色 seed 点;拖动后标记为 dirty,归档保存会更新后端 |
|
||||||
| 应用分类 | 真实可用 | 将当前选择的模板分类应用到本帧 mask;已保存 mask 会标为 dirty,归档保存时更新后端 |
|
| 应用分类 | 真实可用 | 将当前选择的模板分类应用到本帧 mask;已保存 mask 会标为 dirty,归档保存时更新后端 |
|
||||||
| 清空遮罩 | 真实可用 | 工作区中会删除当前帧已保存标注并清空当前帧本地 mask |
|
| 清空遮罩 | 真实可用 | 工作区中会删除当前帧已保存标注并清空当前帧本地 mask |
|
||||||
| 保存状态计数 | 真实可用 | 底部显示已保存、未保存、待更新数量 |
|
| 保存状态计数 | 真实可用 | 底部显示已保存、未保存、待更新数量 |
|
||||||
@@ -84,11 +92,11 @@
|
|||||||
| 元素 | 状态 | 说明 |
|
| 元素 | 状态 | 说明 |
|
||||||
|------|------|------|
|
|------|------|------|
|
||||||
| 拖拽/选择 | 真实可用 | 控制 Canvas 是否可拖拽 |
|
| 拖拽/选择 | 真实可用 | 控制 Canvas 是否可拖拽 |
|
||||||
| 多边形/矩形/圆/点/线 | Mock / UI-only | 只切换 activeTool,没有对应绘制逻辑 |
|
| 多边形/矩形/圆/点/线 | 真实可用 | 切换 activeTool 后由 `CanvasArea` 生成可保存的 polygon mask |
|
||||||
| 区域合并/去除 | Mock / UI-only | 只切换 activeTool,没有后端或前端算法 |
|
| 区域合并/去除 | 真实可用 | 选择工具后点击多个 mask,使用 `polygon-clipping` 做 union / difference;合并会保留主 mask 并移除被合并 mask,去除会从主 mask 扣除后续选中 mask |
|
||||||
| 正向选点/反向选点/框选 | 部分可用 | 会影响 Canvas 交互,并能触发已对齐的 AI 推理接口 |
|
| 正向选点/反向选点/框选 | 部分可用 | 会影响 Canvas 交互,并能触发已对齐的 AI 推理接口 |
|
||||||
| 魔法棒 SAM 触发 | 部分可用 | 切到 AI 页面;不是直接执行推理 |
|
| 魔法棒 SAM 触发 | 部分可用 | 切到 AI 页面;不是直接执行推理 |
|
||||||
| 撤销/重做 | Mock / UI-only | 按钮无事件 |
|
| 撤销/重做 | 真实可用 | 绑定 Zustand `maskHistory/maskFuture`,支持工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y |
|
||||||
|
|
||||||
## FrameTimeline 时间轴
|
## FrameTimeline 时间轴
|
||||||
|
|
||||||
@@ -117,10 +125,11 @@
|
|||||||
|------|------|------|
|
|------|------|------|
|
||||||
| 模型选择 SAM2/SAM3 | 真实可用 | 选择写入 Zustand,`predictMask()` 会把 `model` 传给后端 SAM registry |
|
| 模型选择 SAM2/SAM3 | 真实可用 | 选择写入 Zustand,`predictMask()` 会把 `model` 传给后端 SAM registry |
|
||||||
| 正向/反向点 | 部分可用 | 可在当前项目帧上加点,并可调用 AI 推理接口 |
|
| 正向/反向点 | 部分可用 | 可在当前项目帧上加点,并可调用 AI 推理接口 |
|
||||||
| 语义文本输入 | 部分可用 | 纯文本会以 `semantic` prompt 调用后端;选择 SAM 3 且运行环境满足官方依赖时走 SAM 3 文本语义推理,否则状态接口会标明不可用 |
|
| 语义文本输入 | 部分可用 | 纯文本会以 `semantic` prompt 调用后端;选择 SAM 3 且独立 Python 3.12 环境、CUDA、官方包和 Hugging Face gated 权重访问均满足时走 SAM 3 文本语义推理,否则状态接口会标明不可用 |
|
||||||
| 参数开关 | Mock / UI-only | `cropMode`、`autoDeleteBg` 只改本地状态 |
|
| 参数开关 | 真实可用 | `cropMode` 会随 `/api/ai/predict` 发送 `crop_to_prompt`,后端对点/框 prompt 裁剪推理区域并回映射 polygon;`autoDeleteBg` 会发送 `auto_filter_background` 和 `min_score`,后端过滤低分结果和覆盖负向点的结果 |
|
||||||
| 执行高精度语义分割 | 部分可用 | 使用当前项目帧调用 `/api/ai/predict`;没有当前帧时按钮禁用 |
|
| 执行高精度语义分割 | 部分可用 | 使用当前项目帧调用 `/api/ai/predict`;没有当前帧时按钮禁用 |
|
||||||
| 上传替换底图 | Mock / UI-only | 按钮无事件 |
|
| 上传替换底图 | Mock / UI-only | 按钮无事件 |
|
||||||
|
| 撤销/重做 | 真实可用 | 绑定全局 mask 历史栈 |
|
||||||
| 清空全体锚点 | 部分可用 | 清空前端 points 和 masks |
|
| 清空全体锚点 | 部分可用 | 清空前端 points 和 masks |
|
||||||
| 退档推送至工作区重组 | 部分可用 | 只切回工作区,共用 masks store,但没有保存/确认流程 |
|
| 退档推送至工作区重组 | 部分可用 | 只切回工作区,共用 masks store,但没有保存/确认流程 |
|
||||||
| 背景图 | 部分可用 | 优先显示当前项目帧;没有项目帧时仍回退到 Unsplash 演示图 |
|
| 背景图 | 部分可用 | 优先显示当前项目帧;没有项目帧时仍回退到 Unsplash 演示图 |
|
||||||
@@ -141,6 +150,6 @@
|
|||||||
|
|
||||||
## 总体结论
|
## 总体结论
|
||||||
|
|
||||||
当前前端真实可用的主链路是:登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、拆帧、浏览帧、播放帧、工作区点/框 AI 推理、标注保存/回显、COCO 导出、模板 CRUD。
|
当前前端真实可用的主链路是:登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、拆帧、浏览帧、播放帧、工作区手工绘制、点/框 AI 推理、GT mask 导入、标注保存/回显、COCO 导出、PNG mask ZIP 导出、模板 CRUD。
|
||||||
|
|
||||||
当前最主要的 Mock 或未打通链路是:撤销重做、手工几何绘制、GT 导入、mask 降维点区域、真正的文本语义分割和语义优先级融合。
|
当前最主要的 Mock 或未打通链路是:polygon 插点/边编辑增强、真正的文本语义分割、骨架/HDBSCAN 级别的 mask 降维增强、任务历史筛选、项目更多菜单和若干检查面板指标。
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ Authorization: Bearer <token>
|
|||||||
| `uploadDicomBatch(files, projectId)` | `POST /api/media/upload/dicom` | 对齐 | multipart form-data |
|
| `uploadDicomBatch(files, projectId)` | `POST /api/media/upload/dicom` | 对齐 | multipart form-data |
|
||||||
| `parseMedia(projectId)` | `POST /api/media/parse?project_id=...` | 对齐 | 创建异步拆帧任务并返回 task |
|
| `parseMedia(projectId)` | `POST /api/media/parse?project_id=...` | 对齐 | 创建异步拆帧任务并返回 task |
|
||||||
| `getTask(taskId)` | `GET /api/tasks/{task_id}` | 对齐 | 查询异步任务状态 |
|
| `getTask(taskId)` | `GET /api/tasks/{task_id}` | 对齐 | 查询异步任务状态 |
|
||||||
|
| `cancelTask(taskId)` | `POST /api/tasks/{task_id}/cancel` | 对齐 | 取消 queued/running 任务,后端写 cancelled 并尝试 revoke Celery |
|
||||||
|
| `retryTask(taskId)` | `POST /api/tasks/{task_id}/retry` | 对齐 | 对 failed/cancelled 任务创建新的 queued 重试任务 |
|
||||||
| `getProjectFrames(projectId)` | `GET /api/projects/{id}/frames` | 对齐 | 后端返回预签名 image_url |
|
| `getProjectFrames(projectId)` | `GET /api/projects/{id}/frames` | 对齐 | 后端返回预签名 image_url |
|
||||||
| `predictMask(payload)` | `POST /api/ai/predict` | 对齐 | 前端发送 `image_id/prompt_type/prompt_data/model`,并把后端 `polygons` 转为 `masks[].pathData` |
|
| `predictMask(payload)` | `POST /api/ai/predict` | 对齐 | 前端发送 `image_id/prompt_type/prompt_data/model`,并把后端 `polygons` 转为 `masks[].pathData` |
|
||||||
| `getAiModelStatus(selectedModel?)` | `GET /api/ai/models/status` | 对齐 | 返回 GPU、SAM 2、SAM 3 的真实运行状态 |
|
| `getAiModelStatus(selectedModel?)` | `GET /api/ai/models/status` | 对齐 | 返回 GPU、SAM 2、SAM 3 的真实运行状态 |
|
||||||
@@ -41,8 +43,10 @@ Authorization: Bearer <token>
|
|||||||
| `saveAnnotation(payload)` | `POST /api/ai/annotate` | 对齐 | 工作区归档保存当前项目未保存 mask |
|
| `saveAnnotation(payload)` | `POST /api/ai/annotate` | 对齐 | 工作区归档保存当前项目未保存 mask |
|
||||||
| `updateAnnotation(annotationId, payload)` | `PATCH /api/ai/annotations/{annotation_id}` | 对齐 | 工作区归档保存 dirty mask |
|
| `updateAnnotation(annotationId, payload)` | `PATCH /api/ai/annotations/{annotation_id}` | 对齐 | 工作区归档保存 dirty mask |
|
||||||
| `deleteAnnotation(annotationId)` | `DELETE /api/ai/annotations/{annotation_id}` | 对齐 | 工作区清空当前帧已保存标注 |
|
| `deleteAnnotation(annotationId)` | `DELETE /api/ai/annotations/{annotation_id}` | 对齐 | 工作区清空当前帧已保存标注 |
|
||||||
|
| `importGtMask(file, projectId, frameId, templateId?)` | `POST /api/ai/import-gt-mask` | 对齐 | multipart 上传 GT mask,后端按非零像素值/连通域生成 polygon 标注和 seed point |
|
||||||
| `getDashboardOverview()` | `GET /api/dashboard/overview` | 对齐 | Dashboard 初始统计、队列和活动日志 |
|
| `getDashboardOverview()` | `GET /api/dashboard/overview` | 对齐 | Dashboard 初始统计、队列和活动日志 |
|
||||||
| `exportCoco(projectId)` | `GET /api/export/{projectId}/coco` | 对齐 | 后端实际是 `GET /api/export/{project_id}/coco` |
|
| `exportCoco(projectId)` | `GET /api/export/{projectId}/coco` | 对齐 | 后端实际是 `GET /api/export/{project_id}/coco` |
|
||||||
|
| `exportMasks(projectId)` | `GET /api/export/{projectId}/masks` | 对齐 | 下载单标注 mask、语义融合 mask 和类别映射 ZIP |
|
||||||
|
|
||||||
## 后端 FastAPI 接口
|
## 后端 FastAPI 接口
|
||||||
|
|
||||||
@@ -69,10 +73,13 @@ Authorization: Bearer <token>
|
|||||||
| POST | `/api/media/parse` | 创建 Celery 拆帧任务 |
|
| POST | `/api/media/parse` | 创建 Celery 拆帧任务 |
|
||||||
| GET | `/api/tasks` | 查询后台任务列表 |
|
| GET | `/api/tasks` | 查询后台任务列表 |
|
||||||
| GET | `/api/tasks/{task_id}` | 查询单个后台任务 |
|
| GET | `/api/tasks/{task_id}` | 查询单个后台任务 |
|
||||||
|
| POST | `/api/tasks/{task_id}/cancel` | 取消后台任务 |
|
||||||
|
| POST | `/api/tasks/{task_id}/retry` | 重试失败或取消的后台任务 |
|
||||||
| POST | `/api/ai/predict` | SAM 2 / SAM 3 可选推理 |
|
| POST | `/api/ai/predict` | SAM 2 / SAM 3 可选推理 |
|
||||||
| GET | `/api/ai/models/status` | GPU 和 SAM 模型状态 |
|
| GET | `/api/ai/models/status` | GPU 和 SAM 模型状态 |
|
||||||
| POST | `/api/ai/auto` | 自动分割 |
|
| POST | `/api/ai/auto` | 自动分割 |
|
||||||
| POST | `/api/ai/annotate` | 保存 AI 标注 |
|
| POST | `/api/ai/annotate` | 保存 AI 标注 |
|
||||||
|
| POST | `/api/ai/import-gt-mask` | 导入 GT mask 并生成标注/seed point |
|
||||||
| GET | `/api/ai/annotations` | 查询项目标注,可选按帧过滤 |
|
| GET | `/api/ai/annotations` | 查询项目标注,可选按帧过滤 |
|
||||||
| PATCH | `/api/ai/annotations/{annotation_id}` | 更新已保存标注 |
|
| PATCH | `/api/ai/annotations/{annotation_id}` | 更新已保存标注 |
|
||||||
| DELETE | `/api/ai/annotations/{annotation_id}` | 删除已保存标注 |
|
| DELETE | `/api/ai/annotations/{annotation_id}` | 删除已保存标注 |
|
||||||
@@ -143,7 +150,13 @@ Authorization: Bearer <token>
|
|||||||
|
|
||||||
- `point`
|
- `point`
|
||||||
- `box`
|
- `box`
|
||||||
- `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理;选择 `sam2` 时仍回退到 auto segmentation
|
- `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理;选择 `sam2` 时仍回退到 auto segmentation。SAM 3 真实可用性由 `/api/ai/models/status` 中的外部环境和 checkpoint access 状态决定。
|
||||||
|
|
||||||
|
可选 `options` 字段:
|
||||||
|
|
||||||
|
- `crop_to_prompt`:对 point/box prompt 按锚点或框附近区域裁剪后推理,再把 polygon 回映射到原图坐标。
|
||||||
|
- `auto_filter_background`:过滤低分结果,并移除包含负向点的 polygon。
|
||||||
|
- `min_score`:配合 `auto_filter_background` 使用的最低置信度阈值。
|
||||||
|
|
||||||
后端响应:
|
后端响应:
|
||||||
|
|
||||||
@@ -181,13 +194,19 @@ Authorization: Bearer <token>
|
|||||||
- `getProjectAnnotations()` 已接入 `GET /api/ai/annotations`。
|
- `getProjectAnnotations()` 已接入 `GET /api/ai/annotations`。
|
||||||
- `updateAnnotation()` 已接入 `PATCH /api/ai/annotations/{annotationId}`。
|
- `updateAnnotation()` 已接入 `PATCH /api/ai/annotations/{annotationId}`。
|
||||||
- `deleteAnnotation()` 已接入 `DELETE /api/ai/annotations/{annotationId}`。
|
- `deleteAnnotation()` 已接入 `DELETE /api/ai/annotations/{annotationId}`。
|
||||||
|
- `importGtMask()` 已接入 `POST /api/ai/import-gt-mask`,导入后端生成的 polygon 标注、原始 `gt_label_value` 和 seed point。
|
||||||
|
- `exportMasks()` 已接入 `GET /api/export/{projectId}/masks`。
|
||||||
- `parseMedia()` 已改为创建 Celery 后台任务,并返回 `ProcessingTask`。
|
- `parseMedia()` 已改为创建 Celery 后台任务,并返回 `ProcessingTask`。
|
||||||
- `getTask()` 已接入 `GET /api/tasks/{taskId}`。
|
- `getTask()` 已接入 `GET /api/tasks/{taskId}`。
|
||||||
|
- `cancelTask()` 已接入 `POST /api/tasks/{taskId}/cancel`。
|
||||||
|
- `retryTask()` 已接入 `POST /api/tasks/{taskId}/retry`。
|
||||||
- `getDashboardOverview()` 已从 `processing_tasks` 聚合解析队列。
|
- `getDashboardOverview()` 已从 `processing_tasks` 聚合解析队列。
|
||||||
- 工作区导出按钮已调用 `exportCoco()`,并会先保存未归档 mask。
|
- Dashboard 任务列表已展示 queued/running/failed/cancelled 任务,并可通过 `getTask()` 查看失败详情。
|
||||||
|
- 工作区导出按钮已调用 `exportCoco()` / `exportMasks()`,并会先保存未归档 mask。
|
||||||
|
- PNG mask ZIP 已包含每帧 `semantic_frame_*.png` 和 `semantic_classes.json`,重叠区域按 zIndex 裁决。
|
||||||
|
|
||||||
## 仍需处理的接口问题
|
## 仍需处理的接口问题
|
||||||
|
|
||||||
- WebSocket 地址已从 `VITE_WS_PROGRESS_URL` 读取,未配置时从 `API_BASE_URL` 推导;部署时仍要确认浏览器能访问该地址。
|
- WebSocket 地址已从 `VITE_WS_PROGRESS_URL` 读取,未配置时从 `API_BASE_URL` 推导;部署时仍要确认浏览器能访问该地址。
|
||||||
- Celery worker 进度会写 PostgreSQL 任务表,同时发布到 Redis `seg:progress`;FastAPI 订阅后广播到 `/ws/progress`。
|
- Celery worker 进度会写 PostgreSQL 任务表,同时发布到 Redis `seg:progress`;FastAPI 订阅后广播到 `/ws/progress`。
|
||||||
- 已保存标注目前支持分类级更新和整帧清空删除;逐点几何编辑器尚未实现。
|
- 已保存标注目前支持分类级更新、polygon 顶点拖动、顶点删除、边中点插入、多 polygon 子区域选择编辑后的 PATCH 更新和整帧清空删除;复杂洞结构的专业编辑仍未实现。
|
||||||
|
|||||||
@@ -16,8 +16,8 @@
|
|||||||
|
|
||||||
剩余边界:
|
剩余边界:
|
||||||
|
|
||||||
1. SAM 3 真实推理需要独立满足官方 Python 3.12+、PyTorch 2.7+、CUDA 12.6+ 环境。
|
1. SAM 3 已完成独立 Python 3.12 环境安装脚本、外部 worker 桥接和状态检查;真实推理还需要 Hugging Face `facebook/sam3` gated 权重授权通过后执行 smoke test。
|
||||||
2. 标注删除/更新接口已打通基础能力;逐点几何编辑器尚未实现。
|
2. 标注删除/更新接口已打通基础能力;逐点几何编辑器已支持顶点拖动/删除、边中点插入和多 polygon 子区域选择编辑,复杂洞结构仍待增强。
|
||||||
|
|
||||||
## 阶段 2:打通标注保存(已完成基础闭环)
|
## 阶段 2:打通标注保存(已完成基础闭环)
|
||||||
|
|
||||||
@@ -34,16 +34,22 @@
|
|||||||
剩余建议:
|
剩余建议:
|
||||||
|
|
||||||
1. 加入保存冲突处理和批量保存错误提示。
|
1. 加入保存冲突处理和批量保存错误提示。
|
||||||
2. 增加逐点几何编辑器,让已保存 mask 的 polygon 本身可以被修改后 PATCH。
|
2. 逐点几何编辑器已支持拖动/删除顶点、边中点插入新点和多 polygon 子区域编辑;后续增强为复杂洞结构编辑。
|
||||||
|
3. 区域合并/去除已支持基础 union/difference;后续增强为更明确的多选列表、操作预览和冲突确认。
|
||||||
|
|
||||||
## 阶段 3:接入导出按钮(已完成 COCO JSON)
|
## 阶段 3:接入导出按钮(已完成 COCO JSON 和 PNG Mask ZIP)
|
||||||
|
|
||||||
当前工作区“导出 JSON 标注集”会先保存未归档 mask,再调用 COCO 导出接口。
|
当前工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”都会先保存未归档 mask,再调用后端导出接口。
|
||||||
|
|
||||||
建议:
|
已完成:
|
||||||
|
|
||||||
1. 增加“导出 PNG Mask ZIP”按钮,调用 `/api/export/{projectId}/masks`。
|
1. COCO JSON 调用 `/api/export/{projectId}/coco`。
|
||||||
2. 无标注时给出更明确的空导出提示。
|
2. PNG Mask ZIP 调用 `/api/export/{projectId}/masks`。
|
||||||
|
3. ZIP 内保留单标注二值 `mask_*.png`,同时输出 `semantic_frame_*.png` 和 `semantic_classes.json`。
|
||||||
|
|
||||||
|
剩余建议:
|
||||||
|
|
||||||
|
1. 无标注时给出更明确的空导出提示。
|
||||||
|
|
||||||
## 阶段 4:替换 Dashboard mock
|
## 阶段 4:替换 Dashboard mock
|
||||||
|
|
||||||
@@ -52,13 +58,18 @@
|
|||||||
已完成:
|
已完成:
|
||||||
|
|
||||||
- 聚合项目、帧、标注、模板数量和主机 load average。
|
- 聚合项目、帧、标注、模板数量和主机 load average。
|
||||||
- 按 `processing_tasks` queued/running 任务生成解析队列。
|
- 按 `processing_tasks` queued/running/failed/cancelled 任务生成解析队列。
|
||||||
- 按最近任务、项目、标注、模板记录生成活动流。
|
- 按最近任务、项目、标注、模板记录生成活动流。
|
||||||
|
|
||||||
|
已完成补充:
|
||||||
|
|
||||||
|
1. Dashboard 对 queued/running 任务提供取消按钮。
|
||||||
|
2. Dashboard 对 failed/cancelled 任务提供重试按钮。
|
||||||
|
3. Dashboard 详情弹窗展示任务 error、payload、result、Celery ID 和时间。
|
||||||
|
|
||||||
剩余建议:
|
剩余建议:
|
||||||
|
|
||||||
1. 为任务增加取消、重试和失败详情 UI。
|
1. 为 Dashboard 增加任务历史筛选。
|
||||||
2. 为 Dashboard 增加任务历史筛选和失败详情入口。
|
|
||||||
|
|
||||||
## 阶段 5:异步拆帧和进度
|
## 阶段 5:异步拆帧和进度
|
||||||
|
|
||||||
@@ -72,44 +83,55 @@ Word 方案中提到 Celery + Redis。当前已经有 Celery app、worker task
|
|||||||
4. worker 写 PostgreSQL 任务进度。
|
4. worker 写 PostgreSQL 任务进度。
|
||||||
5. worker 发布 Redis `seg:progress`,FastAPI 广播到 `/ws/progress`。
|
5. worker 发布 Redis `seg:progress`,FastAPI 广播到 `/ws/progress`。
|
||||||
|
|
||||||
|
已完成补充:
|
||||||
|
|
||||||
|
1. `POST /api/tasks/{task_id}/cancel` 取消 queued/running 任务,并尝试 revoke Celery。
|
||||||
|
2. `POST /api/tasks/{task_id}/retry` 为 failed/cancelled 任务创建新的 queued 任务。
|
||||||
|
3. worker 在关键阶段检查 cancelled 状态,避免取消后继续写帧。
|
||||||
|
4. Redis/WebSocket 进度事件增加 `cancelled` 类型。
|
||||||
|
|
||||||
|
Dashboard 的解析队列现在已经从“项目状态派生”升级为任务表驱动,实时推送也已通过 Redis/WebSocket 打通;剩余重点是任务历史筛选和更细的 worker 中断粒度。
|
||||||
|
|
||||||
|
## 阶段 6:GT 导入与点区域(已完成基础增强版)
|
||||||
|
|
||||||
|
Word 方案中的完整版本包含距离变换、骨架提取和聚类。当前已经完成基础增强版:导入二值/标签 mask 图片后,后端按非零像素值拆分类别,再按连通域生成 polygon 标注,并用距离变换提取一个正向 seed point。
|
||||||
|
|
||||||
|
已完成:
|
||||||
|
|
||||||
|
1. 工作区提供“导入 GT Mask”入口。
|
||||||
|
2. 前端调用 `POST /api/ai/import-gt-mask` multipart 接口。
|
||||||
|
3. 后端按非零像素值拆分多类别 mask。
|
||||||
|
4. 后端使用 OpenCV contour 提取每个类别下的连通域。
|
||||||
|
5. 后端使用 distance transform 生成 `points` seed。
|
||||||
|
6. 导入结果写入 `annotations` 表并回显为工作区 mask。
|
||||||
|
7. 前端把 seed point 转为像素坐标显示在 Canvas 上,拖动后会标记标注为 dirty 并可归档保存。
|
||||||
|
|
||||||
剩余建议:
|
剩余建议:
|
||||||
|
|
||||||
1. 为任务增加取消、重试和失败详情接口。
|
1. 增加骨架提取和聚类增强。
|
||||||
2. 前端 Dashboard 保留轮询兜底,并补充失败详情 UI。
|
2. 为多类别像素值提供模板分类自动映射规则。
|
||||||
|
|
||||||
Dashboard 的解析队列现在已经从“项目状态派生”升级为任务表驱动,实时推送也已通过 Redis/WebSocket 打通;剩余重点是任务控制。
|
## 阶段 7:模板优先级融合(已完成导出侧裁决)
|
||||||
|
|
||||||
## 阶段 6:GT 导入与点区域
|
当前导出 PNG Mask ZIP 时已经按 class/template z-index 做重叠裁决,从低到高覆盖,生成每帧 `semantic_frame_*.png`。
|
||||||
|
|
||||||
这是 Word 方案中最复杂的部分,当前完全未实现。
|
已完成:
|
||||||
|
|
||||||
建议拆成小步:
|
|
||||||
|
|
||||||
1. 先支持上传二值/多类别 mask。
|
|
||||||
2. 后端按类别提取 connected components。
|
|
||||||
3. 用 OpenCV distance transform 找正向点。
|
|
||||||
4. 暂时不做骨架/HDBSCAN,先生成最小可用点集。
|
|
||||||
5. 前端以可拖拽点显示并保存。
|
|
||||||
6. 后续再做骨架和聚类增强。
|
|
||||||
|
|
||||||
## 阶段 7:模板优先级融合
|
|
||||||
|
|
||||||
当前模板有 z-index,但没有真正用于语义冲突裁决。
|
|
||||||
|
|
||||||
建议:
|
|
||||||
|
|
||||||
1. 标注保存时记录 template class id / name / zIndex。
|
1. 标注保存时记录 template class id / name / zIndex。
|
||||||
2. 导出 mask 时按 zIndex 从低到高覆盖。
|
2. 导出 mask 时按 zIndex 从低到高覆盖。
|
||||||
3. 同类 mask 做 union。
|
3. 同类语义值在融合图中共享同一个 class value。
|
||||||
4. 跨类重叠由高 zIndex 覆盖低 zIndex。
|
4. 跨类重叠由高 zIndex 覆盖低 zIndex。
|
||||||
|
|
||||||
这一步完成后,系统才真正符合“语义分割一个像素一个类别”的目标。
|
剩余建议:
|
||||||
|
|
||||||
|
1. 在前端预览重叠裁决结果。
|
||||||
|
2. 对多帧多类导出增加颜色 palette PNG 或可视化 legend。
|
||||||
|
|
||||||
## 阶段 8:清理 UI 文案与 Mock
|
## 阶段 8:清理 UI 文案与 Mock
|
||||||
|
|
||||||
建议统一这些文案和真实能力:
|
建议统一这些文案和真实能力:
|
||||||
|
|
||||||
- SAM/GPU 状态已改为 `GET /api/ai/models/status` 驱动。
|
- SAM/GPU 状态已改为 `GET /api/ai/models/status` 驱动。
|
||||||
- 撤销/重做按钮接历史栈,否则隐藏。
|
- 撤销/重做按钮已接全局 mask 历史栈。
|
||||||
- “重新提取内侧中轴树骨架”接真实接口,否则标为未实现。
|
- “重新提取内侧中轴树骨架”接真实接口,否则标为未实现。
|
||||||
- AI 独立页不要固定 Unsplash 图,应从当前项目帧或上传文件进入。
|
- AI 独立页不要固定 Unsplash 图,应从当前项目帧或上传文件进入。
|
||||||
|
|||||||
@@ -31,6 +31,9 @@
|
|||||||
- 拆帧完成后写入 `frames` 记录,并把项目状态设为 `ready`。
|
- 拆帧完成后写入 `frames` 记录,并把项目状态设为 `ready`。
|
||||||
- 拆帧接口会创建 `processing_tasks` 记录并投递 Celery worker。
|
- 拆帧接口会创建 `processing_tasks` 记录并投递 Celery worker。
|
||||||
- 前端可通过 `GET /api/tasks/{task_id}` 查询任务状态。
|
- 前端可通过 `GET /api/tasks/{task_id}` 查询任务状态。
|
||||||
|
- 后端支持 `POST /api/tasks/{task_id}/cancel` 取消 queued/running 任务,写入 `cancelled` 状态并尝试 revoke Celery。
|
||||||
|
- 后端支持 `POST /api/tasks/{task_id}/retry` 对 failed/cancelled 任务创建新的 queued 任务。
|
||||||
|
- worker 会在关键阶段检查任务是否已取消,取消后停止继续写帧。
|
||||||
|
|
||||||
## R4 工作区与帧浏览
|
## R4 工作区与帧浏览
|
||||||
|
|
||||||
@@ -46,7 +49,13 @@
|
|||||||
- 工具栏可以切换当前 active tool。
|
- 工具栏可以切换当前 active tool。
|
||||||
- 正向点、反向点、框选工具会影响 Canvas 交互。
|
- 正向点、反向点、框选工具会影响 Canvas 交互。
|
||||||
- 魔法棒按钮切换到 AI 页面。
|
- 魔法棒按钮切换到 AI 页面。
|
||||||
- 多边形、矩形、圆、点、线、合并、去除、撤销、重做当前只提供 UI 状态或占位按钮,不完成真实绘制/算法。
|
- 多边形、矩形、圆、点、线工具会在 Canvas 上生成可保存的 polygon mask。
|
||||||
|
- 多边形通过点击取点并按 Enter 完成;矩形、圆、线通过拖拽生成;点工具生成小点区域。
|
||||||
|
- Canvas 支持点击 mask 进入 polygon 顶点编辑态;拖动顶点会更新 mask 几何并把已保存 mask 标记为 dirty。
|
||||||
|
- 顶点编辑态下选中顶点后可用 Delete/Backspace 删除顶点,但不会让 polygon 少于三点。
|
||||||
|
- 撤销、重做绑定全局 `maskHistory/maskFuture`,支持工具栏按钮、AI 页按钮和 Canvas 快捷键。
|
||||||
|
- 区域合并工具支持多选当前帧 mask,并使用 polygon union 生成合并后的主 mask。
|
||||||
|
- 区域去除工具支持多选当前帧 mask,并从第一个选中的主 mask 中扣除后续选中 mask。
|
||||||
|
|
||||||
## R6 AI 推理
|
## R6 AI 推理
|
||||||
|
|
||||||
@@ -56,7 +65,8 @@
|
|||||||
- 前端发送后端契约:`image_id`、`prompt_type`、`prompt_data`、`model`。
|
- 前端发送后端契约:`image_id`、`prompt_type`、`prompt_data`、`model`。
|
||||||
- 点提示传 `{ points, labels }`,正向点 label 为 1,反向点 label 为 0。
|
- 点提示传 `{ points, labels }`,正向点 label 为 1,反向点 label 为 0。
|
||||||
- 框选提示传归一化 `[x1, y1, x2, y2]`。
|
- 框选提示传归一化 `[x1, y1, x2, y2]`。
|
||||||
- 语义文本提示传 `semantic`;选择 `sam3` 且环境满足依赖时走 SAM 3 文本语义推理,选择 `sam2` 时回退到自动分割。
|
- 语义文本提示传 `semantic`;选择 `sam3` 且独立 Python 3.12 环境、CUDA、官方包和 Hugging Face gated 权重访问均满足时走 SAM 3 文本语义推理,选择 `sam2` 时回退到自动分割。
|
||||||
|
- AI 参数支持 `crop_to_prompt`、`auto_filter_background` 和 `min_score`;点/框 prompt 可以裁剪局部区域推理并回映射结果,背景过滤会移除低分结果和包含负向点的 polygon。
|
||||||
- 后端返回 `polygons` 和 `scores`。
|
- 后端返回 `polygons` 和 `scores`。
|
||||||
- 前端把后端 `polygons` 转成 Konva `pathData`、`segmentation`、`bbox`、`area`。
|
- 前端把后端 `polygons` 转成 Konva `pathData`、`segmentation`、`bbox`、`area`。
|
||||||
- AI 推理结果先存放在前端 store 的 `masks` 中,点击“结构化归档保存”后持久化到后端标注表。
|
- AI 推理结果先存放在前端 store 的 `masks` 中,点击“结构化归档保存”后持久化到后端标注表。
|
||||||
@@ -71,6 +81,9 @@
|
|||||||
- 当前前端“结构化归档保存”会保存当前项目未保存 mask,并会更新已标记为 dirty 的已保存 mask。
|
- 当前前端“结构化归档保存”会保存当前项目未保存 mask,并会更新已标记为 dirty 的已保存 mask。
|
||||||
- 工作区“清空遮罩”会删除当前帧已保存标注,并清空当前帧未保存 mask。
|
- 工作区“清空遮罩”会删除当前帧已保存标注,并清空当前帧未保存 mask。
|
||||||
- 工作区加载项目帧后会查询已保存标注并回显。
|
- 工作区加载项目帧后会查询已保存标注并回显。
|
||||||
|
- 工作区支持导入 GT mask 图片,前端调用 `POST /api/ai/import-gt-mask`。
|
||||||
|
- 后端导入 GT mask 时按非零像素值拆分多类别区域,再按连通域生成 polygon 标注,并通过距离变换写入 seed point。
|
||||||
|
- 前端会回显导入标注的 seed point;拖动 seed point 后,已保存标注会变为 dirty,归档保存时会更新后端 `points`。
|
||||||
|
|
||||||
## R8 模板库
|
## R8 模板库
|
||||||
|
|
||||||
@@ -93,9 +106,12 @@
|
|||||||
- Dashboard 显示基础统计、解析队列和活动日志。
|
- Dashboard 显示基础统计、解析队列和活动日志。
|
||||||
- Dashboard 初始数据来自 `GET /api/dashboard/overview`。
|
- Dashboard 初始数据来自 `GET /api/dashboard/overview`。
|
||||||
- 后端聚合项目数、处理中任务数、标注数、帧数、模板数和主机 load average。
|
- 后端聚合项目数、处理中任务数、标注数、帧数、模板数和主机 load average。
|
||||||
- 解析队列由 `processing_tasks` 中的 queued/running 任务生成;活动日志由最近任务、项目、标注和模板记录生成。
|
- 解析队列由 `processing_tasks` 中的 queued/running/failed/cancelled 任务生成;活动日志由最近任务、项目、标注和模板记录生成。
|
||||||
|
- Dashboard 对 queued/running 任务提供取消按钮,对 failed/cancelled 任务提供重试按钮。
|
||||||
|
- Dashboard 任务详情会读取 `GET /api/tasks/{task_id}` 并展示失败 error、payload、result、Celery ID 和时间信息。
|
||||||
- Dashboard 会连接 `/ws/progress`。
|
- Dashboard 会连接 `/ws/progress`。
|
||||||
- 收到 progress、complete、error、status 消息时,前端会更新队列或日志。
|
- 收到 progress、complete、error、status 消息时,前端会更新队列或日志。
|
||||||
|
- 收到 cancelled 消息时,前端会把对应任务标记为已取消。
|
||||||
- Celery worker 每次更新 `processing_tasks` 后会发布 Redis `seg:progress` 事件,FastAPI 订阅并广播给 `/ws/progress` 客户端。
|
- Celery worker 每次更新 `processing_tasks` 后会发布 Redis `seg:progress` 事件,FastAPI 订阅并广播给 `/ws/progress` 客户端。
|
||||||
- 后端 WebSocket 接收到客户端消息后返回 status heartbeat。
|
- 后端 WebSocket 接收到客户端消息后返回 status heartbeat。
|
||||||
|
|
||||||
@@ -104,7 +120,10 @@
|
|||||||
- 后端支持 `GET /api/export/{project_id}/coco` 导出 COCO JSON。
|
- 后端支持 `GET /api/export/{project_id}/coco` 导出 COCO JSON。
|
||||||
- 后端支持 `GET /api/export/{project_id}/masks` 导出 PNG mask ZIP。
|
- 后端支持 `GET /api/export/{project_id}/masks` 导出 PNG mask ZIP。
|
||||||
- 当前前端 `exportCoco()` API 封装已对齐后端路径。
|
- 当前前端 `exportCoco()` API 封装已对齐后端路径。
|
||||||
|
- 当前前端 `exportMasks()` API 封装已对齐后端路径。
|
||||||
- 工作区“导出 JSON 标注集”按钮已绑定下载事件;导出前会先保存当前未归档 mask。
|
- 工作区“导出 JSON 标注集”按钮已绑定下载事件;导出前会先保存当前未归档 mask。
|
||||||
|
- 工作区“导出 PNG Mask ZIP”按钮已绑定下载事件;导出前会先保存当前未归档 mask。
|
||||||
|
- PNG mask ZIP 包含单标注二值 mask、按 zIndex 融合后的每帧语义 mask 和 `semantic_classes.json`。
|
||||||
|
|
||||||
## R12 配置
|
## R12 配置
|
||||||
|
|
||||||
|
|||||||
@@ -19,17 +19,17 @@
|
|||||||
| 模块 | 文件 | 设计职责 |
|
| 模块 | 文件 | 设计职责 |
|
||||||
|------|------|----------|
|
|------|------|----------|
|
||||||
| 应用入口 | `src/App.tsx` | 根据登录状态和 `activeModule` 切换页面 |
|
| 应用入口 | `src/App.tsx` | 根据登录状态和 `activeModule` 切换页面 |
|
||||||
| 全局状态 | `src/store/useStore.ts` | Zustand store,保存项目、帧、模板、mask、工具状态 |
|
| 全局状态 | `src/store/useStore.ts` | Zustand store,保存项目、帧、模板、mask、工具状态和 mask 撤销/重做历史栈 |
|
||||||
| API 封装 | `src/lib/api.ts` | Axios 客户端、字段映射、AI 响应转换 |
|
| API 封装 | `src/lib/api.ts` | Axios 客户端、字段映射、AI 响应转换 |
|
||||||
| 配置 | `src/lib/config.ts` | 推导 API 和 WebSocket 地址 |
|
| 配置 | `src/lib/config.ts` | 推导 API 和 WebSocket 地址 |
|
||||||
| WebSocket | `src/lib/websocket.ts` | 进度流连接、订阅和重连 |
|
| WebSocket | `src/lib/websocket.ts` | 进度流连接、订阅和重连 |
|
||||||
| 模型状态 | `src/components/ModelStatusBadge.tsx` | 展示 GPU 与当前 SAM 模型真实可用状态 |
|
| 模型状态 | `src/components/ModelStatusBadge.tsx` | 展示 GPU 与当前 SAM 模型真实可用状态 |
|
||||||
| 登录页 | `src/components/Login.tsx` | 调用登录 API,写入 store |
|
| 登录页 | `src/components/Login.tsx` | 调用登录 API,写入 store |
|
||||||
| Dashboard | `src/components/Dashboard.tsx` | 展示统计和 WebSocket 进度消息 |
|
| Dashboard | `src/components/Dashboard.tsx` | 展示统计、任务控制、失败详情和 WebSocket 进度消息 |
|
||||||
| 项目库 | `src/components/ProjectLibrary.tsx` | 项目列表、新建、导入视频/DICOM |
|
| 项目库 | `src/components/ProjectLibrary.tsx` | 项目列表、新建、导入视频/DICOM |
|
||||||
| 工作区 | `src/components/VideoWorkspace.tsx` | 加载帧和模板,组织工具栏、Canvas、本体面板、时间轴 |
|
| 工作区 | `src/components/VideoWorkspace.tsx` | 加载帧和模板,组织工具栏、Canvas、本体面板、时间轴 |
|
||||||
| Canvas | `src/components/CanvasArea.tsx` | 显示帧、缩放平移、点/框提示、渲染 mask |
|
| Canvas | `src/components/CanvasArea.tsx` | 显示帧、缩放平移、点/框提示、渲染 mask |
|
||||||
| 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具和跳转 AI 页面 |
|
| 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具、跳转 AI 页面、触发 mask 撤销/重做 |
|
||||||
| 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航和播放 |
|
| 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航和播放 |
|
||||||
| 本体面板 | `src/components/OntologyInspector.tsx` | 模板选择、分类树、本地自定义分类 |
|
| 本体面板 | `src/components/OntologyInspector.tsx` | 模板选择、分类树、本地自定义分类 |
|
||||||
| AI 页面 | `src/components/AISegmentation.tsx` | 独立 AI 推理视图,使用当前项目帧 |
|
| AI 页面 | `src/components/AISegmentation.tsx` | 独立 AI 推理视图,使用当前项目帧 |
|
||||||
@@ -51,7 +51,7 @@
|
|||||||
| AI | `backend/routers/ai.py` | SAM 2 / SAM 3 可选推理、模型状态和标注保存 |
|
| AI | `backend/routers/ai.py` | SAM 2 / SAM 3 可选推理、模型状态和标注保存 |
|
||||||
| Export | `backend/routers/export.py` | COCO 和 PNG mask 导出 |
|
| Export | `backend/routers/export.py` | COCO 和 PNG mask 导出 |
|
||||||
| SAM 2 | `backend/services/sam2_engine.py` | SAM 2 懒加载、状态检测和点/框/自动推理 |
|
| SAM 2 | `backend/services/sam2_engine.py` | SAM 2 懒加载、状态检测和点/框/自动推理 |
|
||||||
| SAM 3 | `backend/services/sam3_engine.py` | SAM 3 状态检测和文本语义推理适配 |
|
| SAM 3 | `backend/services/sam3_engine.py`, `backend/services/sam3_external_worker.py`, `backend/setup_sam3_env.sh` | SAM 3 状态检测、独立 Python 3.12 环境桥接和文本语义推理适配 |
|
||||||
| SAM Registry | `backend/services/sam_registry.py` | 模型选择、GPU 状态和推理分发 |
|
| SAM Registry | `backend/services/sam_registry.py` | 模型选择、GPU 状态和推理分发 |
|
||||||
|
|
||||||
## 状态模型
|
## 状态模型
|
||||||
@@ -62,6 +62,7 @@
|
|||||||
- `Frame`:帧 ID、项目 ID、索引、图片 URL、宽高。
|
- `Frame`:帧 ID、项目 ID、索引、图片 URL、宽高。
|
||||||
- `Template` / `TemplateClass`:模板和分类定义。
|
- `Template` / `TemplateClass`:模板和分类定义。
|
||||||
- `Mask`:前端渲染用 mask,包含 `pathData`、`segmentation`、`bbox`、`area`。
|
- `Mask`:前端渲染用 mask,包含 `pathData`、`segmentation`、`bbox`、`area`。
|
||||||
|
- `maskHistory` / `maskFuture`:mask 编辑历史栈,用于撤销和重做。
|
||||||
- `activeModule`:当前页面。
|
- `activeModule`:当前页面。
|
||||||
- `activeTool`:当前工具。
|
- `activeTool`:当前工具。
|
||||||
- `aiModel`:当前选择的 AI 模型,取值为 `sam2` 或 `sam3`。
|
- `aiModel`:当前选择的 AI 模型,取值为 `sam2` 或 `sam3`。
|
||||||
@@ -82,6 +83,14 @@
|
|||||||
4. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,持续更新 `processing_tasks`,并发布 Redis `seg:progress`。
|
4. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,持续更新 `processing_tasks`,并发布 Redis `seg:progress`。
|
||||||
5. 刷新项目列表。
|
5. 刷新项目列表。
|
||||||
|
|
||||||
|
### 任务控制
|
||||||
|
|
||||||
|
1. Dashboard 从 `GET /api/dashboard/overview` 读取 queued/running/failed/cancelled 任务。
|
||||||
|
2. 用户取消任务时,前端调用 `POST /api/tasks/{task_id}/cancel`;后端写入 `cancelled`、设置 `finished_at`,并尝试 `celery_app.control.revoke(..., terminate=True)`。
|
||||||
|
3. worker 在下载、解析、上传、写帧等关键阶段刷新任务状态;如果发现 `cancelled`,停止后续写入并发布 cancelled 事件。
|
||||||
|
4. 用户重试任务时,前端调用 `POST /api/tasks/{task_id}/retry`;后端基于原任务 `payload` 创建新任务,记录 `retry_of` 并重新投递 Celery。
|
||||||
|
5. 用户打开详情时,前端调用 `GET /api/tasks/{task_id}`,弹窗展示 error、payload、result、Celery ID 和时间。
|
||||||
|
|
||||||
### 工作区加载
|
### 工作区加载
|
||||||
|
|
||||||
1. `VideoWorkspace` 根据 `currentProject.id` 调用 `getProjectFrames()`。
|
1. `VideoWorkspace` 根据 `currentProject.id` 调用 `getProjectFrames()`。
|
||||||
@@ -102,6 +111,41 @@
|
|||||||
9. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。
|
9. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。
|
||||||
10. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。
|
10. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。
|
||||||
|
|
||||||
|
### 手工绘制与历史栈
|
||||||
|
|
||||||
|
1. 用户在 `ToolsPalette` 选择多边形、矩形、圆、点或线工具。
|
||||||
|
2. `CanvasArea` 将交互坐标转换成像素 polygon。
|
||||||
|
3. 新 mask 写入 `pathData`、像素 `segmentation`、`bbox`、`area` 和当前模板分类元数据。
|
||||||
|
4. `addMask()`、`setMasks()`、`updateMask()`、`clearMasks()` 会维护 `maskHistory/maskFuture`。
|
||||||
|
5. 工具栏按钮、AI 页按钮和 Canvas Ctrl+Z/Ctrl+Y 调用 `undoMasks()` / `redoMasks()`。
|
||||||
|
|
||||||
|
### Polygon 逐点编辑
|
||||||
|
|
||||||
|
1. 用户点击 Canvas 上的 mask path 后,`CanvasArea` 记录 `selectedMaskId` 并显示该 mask 第一条 polygon 的顶点控制点。
|
||||||
|
2. 拖动顶点后,前端重算 `pathData`、像素 `segmentation`、`bbox`、`area`。
|
||||||
|
3. 如果 mask 已有 `annotationId`,编辑会把 `saveStatus` 标成 `dirty` 且 `saved=false`。
|
||||||
|
4. 归档保存时复用现有 `PATCH /api/ai/annotations/{annotation_id}` 链路,把更新后的 normalized polygon 写回后端。
|
||||||
|
5. 选中顶点后 Delete/Backspace 可删除顶点;前端保持 polygon 至少三点。
|
||||||
|
|
||||||
|
### 区域合并与去除
|
||||||
|
|
||||||
|
1. 用户选择 `area_merge` 或 `area_remove` 后,点击多个当前帧 mask 组成选择集。
|
||||||
|
2. `CanvasArea` 把 `Mask.segmentation` 转为 `polygon-clipping` 的 MultiPolygon。
|
||||||
|
3. `area_merge` 使用 union,更新第一个选中的主 mask,并从前端 store 移除后续被合并 mask;如果被移除 mask 已保存,会调用工作区传入的删除回调删除后端标注。
|
||||||
|
4. `area_remove` 使用 difference,从第一个选中的主 mask 中扣除后续选中 mask,扣除对象本身保留。
|
||||||
|
5. 结果会重算 `pathData`、`segmentation`、`bbox`、`area`,已保存主 mask 会进入 dirty 状态并复用归档 PATCH 链路。
|
||||||
|
|
||||||
|
### GT Mask 导入
|
||||||
|
|
||||||
|
1. 工作区“导入 GT Mask”选择图片文件。
|
||||||
|
2. 前端 `importGtMask()` 以 multipart form-data 调用 `POST /api/ai/import-gt-mask`,携带 `project_id` 和 `frame_id`。
|
||||||
|
3. 后端验证项目、帧、模板后使用 OpenCV 读取灰度 mask。
|
||||||
|
4. 后端按非零像素值拆分多类别标签。
|
||||||
|
5. 后端对每个类别的前景做 contour 提取,每个连通域保存为一个 `Annotation`。
|
||||||
|
6. `points` 字段保存距离变换中心 seed point,`mask_data.polygons` 保存 normalized polygon,`mask_data.gt_label_value` 保存原始像素类别值。
|
||||||
|
7. 前端重新读取项目标注并回显。
|
||||||
|
8. `annotationToMask()` 会把 normalized seed point 转成像素坐标,Canvas 以可拖拽点显示;拖动后 `buildAnnotationPayload()` 会把点再归一化写回后端。
|
||||||
|
|
||||||
### 模板管理
|
### 模板管理
|
||||||
|
|
||||||
1. `TemplateRegistry` 从后端读取模板。
|
1. `TemplateRegistry` 从后端读取模板。
|
||||||
@@ -114,8 +158,10 @@
|
|||||||
### 导出
|
### 导出
|
||||||
|
|
||||||
1. 后端根据项目、帧、标注和模板生成 COCO JSON。
|
1. 后端根据项目、帧、标注和模板生成 COCO JSON。
|
||||||
2. PNG mask 导出会把 normalized polygon 渲染为二值 mask 并打包 ZIP。
|
2. PNG mask 导出会把 normalized polygon 渲染为单标注二值 mask。
|
||||||
3. 前端“导出 JSON 标注集”按钮会在导出前保存待归档标注,然后下载 COCO JSON。
|
3. PNG mask 导出还会按 `mask_data.class.zIndex` 或模板 `z_index` 从低到高覆盖,生成每帧语义融合 mask。
|
||||||
|
4. ZIP 内写入 `semantic_classes.json`,记录语义值到类别、颜色和 zIndex 的映射。
|
||||||
|
5. 前端“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮都会在导出前保存待归档标注,然后下载对应文件。
|
||||||
|
|
||||||
## 接口契约
|
## 接口契约
|
||||||
|
|
||||||
@@ -123,13 +169,18 @@
|
|||||||
|
|
||||||
- `updateProject()` 使用 `PATCH /api/projects/{id}`。
|
- `updateProject()` 使用 `PATCH /api/projects/{id}`。
|
||||||
- `exportCoco()` 使用 `GET /api/export/{projectId}/coco`。
|
- `exportCoco()` 使用 `GET /api/export/{projectId}/coco`。
|
||||||
|
- `exportMasks()` 使用 `GET /api/export/{projectId}/masks`。
|
||||||
|
- `cancelTask()` 使用 `POST /api/tasks/{taskId}/cancel`。
|
||||||
|
- `retryTask()` 使用 `POST /api/tasks/{taskId}/retry`。
|
||||||
- `predictMask()` 使用 `POST /api/ai/predict`,请求体为 `image_id`、`prompt_type`、`prompt_data`、`model`。
|
- `predictMask()` 使用 `POST /api/ai/predict`,请求体为 `image_id`、`prompt_type`、`prompt_data`、`model`。
|
||||||
- `saveAnnotation()` 使用 `POST /api/ai/annotate`。
|
- `saveAnnotation()` 使用 `POST /api/ai/annotate`。
|
||||||
|
- `importGtMask()` 使用 `POST /api/ai/import-gt-mask` multipart form-data。
|
||||||
- `getProjectAnnotations()` 使用 `GET /api/ai/annotations`。
|
- `getProjectAnnotations()` 使用 `GET /api/ai/annotations`。
|
||||||
- `updateAnnotation()` 使用 `PATCH /api/ai/annotations/{annotationId}`。
|
- `updateAnnotation()` 使用 `PATCH /api/ai/annotations/{annotationId}`。
|
||||||
- `deleteAnnotation()` 使用 `DELETE /api/ai/annotations/{annotationId}`。
|
- `deleteAnnotation()` 使用 `DELETE /api/ai/annotations/{annotationId}`。
|
||||||
- 后端 `/api/ai/predict` 支持 point、box、semantic 三种 prompt_type,并通过 `model` 选择 SAM 2 或 SAM 3。
|
- 后端 `/api/ai/predict` 支持 point、box、semantic 三种 prompt_type,并通过 `model` 选择 SAM 2 或 SAM 3。
|
||||||
- 后端 `/api/ai/models/status` 返回 GPU、SAM 2、SAM 3 的真实运行状态。
|
- 后端 `/api/ai/predict` 支持可选 `options`:`crop_to_prompt` 会对 point/box prompt 做局部裁剪推理并回映射 polygon,`auto_filter_background` 会按 `min_score` 和负向点过滤结果。
|
||||||
|
- 后端 `/api/ai/models/status` 返回 GPU、SAM 2、SAM 3 的真实运行状态;SAM 3 状态包含外部 Python 环境与 checkpoint access 的可用性。
|
||||||
- point prompt 支持旧数组形式和 `{ points, labels }` 对象形式。
|
- point prompt 支持旧数组形式和 `{ points, labels }` 对象形式。
|
||||||
|
|
||||||
## 外部依赖边界
|
## 外部依赖边界
|
||||||
@@ -146,10 +197,8 @@
|
|||||||
|
|
||||||
以下能力属于当前冻结版本的占位或半可用功能:
|
以下能力属于当前冻结版本的占位或半可用功能:
|
||||||
|
|
||||||
- Dashboard 初始快照来自 `GET /api/dashboard/overview`;解析队列由 `processing_tasks` queued/running 任务生成。
|
- Dashboard 初始快照来自 `GET /api/dashboard/overview`;解析队列由 `processing_tasks` queued/running/failed/cancelled 任务生成。
|
||||||
- 多边形、矩形、圆、点、线手工绘制未实现。
|
- 已保存标注支持通过“应用分类”、polygon 顶点拖动/删除、边中点插入、多 polygon 子区域编辑和区域合并/去除进入 dirty 状态并归档更新;复杂洞结构编辑尚未实现。
|
||||||
- 合并、去除、撤销、重做未实现。
|
- SAM 3 文本语义分割取决于官方依赖、GPU 运行环境和 Hugging Face gated 权重授权;状态接口会暴露真实可用性,未授权时 `available=false`。
|
||||||
- 工作区导出 PNG mask ZIP 按钮尚未提供。
|
|
||||||
- 已保存标注支持通过“应用分类”进入 dirty 状态并归档更新;暂未提供逐点几何编辑器。
|
|
||||||
- SAM 3 文本语义分割取决于官方依赖和 GPU 运行环境;状态接口会暴露真实可用性。
|
|
||||||
- 自定义分类只存在本地组件状态。
|
- 自定义分类只存在本地组件状态。
|
||||||
|
- GT mask 导入已完成多类别像素值拆分、contour、distance transform seed point 和前端 seed point 拖拽编辑;骨架提取、HDBSCAN 聚类和模板自动映射尚未实现。
|
||||||
|
|||||||
@@ -16,15 +16,15 @@
|
|||||||
|------|----------|--------|
|
|------|----------|--------|
|
||||||
| R1 登录与会话 | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | 成功登录、失败提示、后端 401 |
|
| R1 登录与会话 | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | 成功登录、失败提示、后端 401 |
|
||||||
| R2 项目管理 | `src/lib/api.test.ts`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、后端 CRUD、帧列表 |
|
| R2 项目管理 | `src/lib/api.test.ts`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、后端 CRUD、帧列表 |
|
||||||
| R3 媒体上传与拆帧 | `backend/tests/test_media.py` | 扩展名校验、自动建项目、关联项目、创建异步任务、worker 注册帧 |
|
| R3 媒体上传与拆帧 | `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 扩展名校验、自动建项目、关联项目、创建异步任务、worker 注册帧、取消任务、重试任务、取消后 worker 停止 |
|
||||||
| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧触发解析、切帧、播放 |
|
| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧触发解析、切帧、播放 |
|
||||||
| R5 工具栏 | `src/components/ToolsPalette.test.tsx` | 工具切换、AI 跳转、占位按钮存在 |
|
| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts` | 工具切换、AI 跳转、手工 mask 绘制、polygon 顶点拖动/删除、区域合并/去除、撤销/重做历史栈 |
|
||||||
| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py` | 点/框/semantic 契约、模型选择、GPU/SAM 状态、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry |
|
| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam3_engine.py` | 点/框/semantic 契约、模型选择、GPU/SAM 状态、SAM 3 外部 worker 桥接、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry |
|
||||||
| R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、加载回显、更新 dirty 标注、清空删除已保存标注、项目不存在、帧不存在 |
|
| R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、加载回显、更新 dirty 标注、清空删除已保存标注、GT mask 多类别导入、seed point 回显/归一化、项目不存在、帧不存在 |
|
||||||
| R8 模板库 | `src/lib/api.test.ts`, `backend/tests/test_templates.py` | mapping_rules 解包/打包、模板 CRUD |
|
| R8 模板库 | `src/lib/api.test.ts`, `backend/tests/test_templates.py` | mapping_rules 解包/打包、模板 CRUD |
|
||||||
| R9 本体检查面板 | `src/components/OntologyInspector.test.tsx` | 模板选择、分类展示、具体分类选择、自定义分类本地添加 |
|
| R9 本体检查面板 | `src/components/OntologyInspector.test.tsx` | 模板选择、分类展示、具体分类选择、自定义分类本地添加 |
|
||||||
| R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py` | 后端概览接口、任务表驱动队列、Redis 进度事件 payload/发布、地址推导、消息订阅、队列更新、heartbeat |
|
| R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py`, `backend/tests/test_tasks.py` | 后端概览接口、任务表驱动队列、任务取消/重试/详情、cancelled 事件、Redis 进度事件 payload/发布、地址推导、消息订阅、队列更新、heartbeat |
|
||||||
| R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO 按钮下载、导出前自动保存、COCO 路径、JSON 结构、mask ZIP |
|
| R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO/PNG 按钮下载、导出前自动保存、导出路径、JSON 结构、mask ZIP、zIndex 语义融合 |
|
||||||
| R12 配置 | `src/lib/config.test.ts` | env 优先、hostname 推导、WS 推导 |
|
| R12 配置 | `src/lib/config.test.ts` | env 优先、hostname 推导、WS 推导 |
|
||||||
| R13 文档与测试 | `doc/09-test-plan.md` | 测试覆盖矩阵 |
|
| R13 文档与测试 | `doc/09-test-plan.md` | 测试覆盖矩阵 |
|
||||||
|
|
||||||
|
|||||||
26
package-lock.json
generated
26
package-lock.json
generated
@@ -18,6 +18,7 @@
|
|||||||
"konva": "^10.2.5",
|
"konva": "^10.2.5",
|
||||||
"lucide-react": "^0.546.0",
|
"lucide-react": "^0.546.0",
|
||||||
"motion": "^12.23.24",
|
"motion": "^12.23.24",
|
||||||
|
"polygon-clipping": "^0.15.7",
|
||||||
"react": "^19.0.0",
|
"react": "^19.0.0",
|
||||||
"react-dom": "^19.0.0",
|
"react-dom": "^19.0.0",
|
||||||
"react-konva": "^19.2.3",
|
"react-konva": "^19.2.3",
|
||||||
@@ -4165,6 +4166,16 @@
|
|||||||
"url": "https://github.com/sponsors/jonschlinkert"
|
"url": "https://github.com/sponsors/jonschlinkert"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/polygon-clipping": {
|
||||||
|
"version": "0.15.7",
|
||||||
|
"resolved": "https://registry.npmjs.org/polygon-clipping/-/polygon-clipping-0.15.7.tgz",
|
||||||
|
"integrity": "sha512-nhfdr83ECBg6xtqOAJab1tbksbBAOMUltN60bU+llHVOL0e5Onm1WpAXXWXVB39L8AJFssoIhEVuy/S90MmotA==",
|
||||||
|
"license": "MIT",
|
||||||
|
"dependencies": {
|
||||||
|
"robust-predicates": "^3.0.2",
|
||||||
|
"splaytree": "^3.1.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/postcss": {
|
"node_modules/postcss": {
|
||||||
"version": "8.5.12",
|
"version": "8.5.12",
|
||||||
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.12.tgz",
|
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.12.tgz",
|
||||||
@@ -4438,6 +4449,12 @@
|
|||||||
"node": ">= 4"
|
"node": ">= 4"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/robust-predicates": {
|
||||||
|
"version": "3.0.3",
|
||||||
|
"resolved": "https://registry.npmjs.org/robust-predicates/-/robust-predicates-3.0.3.tgz",
|
||||||
|
"integrity": "sha512-NS3levdsRIUOmiJ8FZWCP7LG3QpJyrs/TE0Zpf1yvZu8cAJJ6QMW92H1c7kWpdIHo8RvmLxN/o2JXTKHp74lUA==",
|
||||||
|
"license": "Unlicense"
|
||||||
|
},
|
||||||
"node_modules/rollup": {
|
"node_modules/rollup": {
|
||||||
"version": "4.60.2",
|
"version": "4.60.2",
|
||||||
"resolved": "https://registry.npmjs.org/rollup/-/rollup-4.60.2.tgz",
|
"resolved": "https://registry.npmjs.org/rollup/-/rollup-4.60.2.tgz",
|
||||||
@@ -4684,6 +4701,15 @@
|
|||||||
"node": ">=0.10.0"
|
"node": ">=0.10.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/splaytree": {
|
||||||
|
"version": "3.2.3",
|
||||||
|
"resolved": "https://registry.npmjs.org/splaytree/-/splaytree-3.2.3.tgz",
|
||||||
|
"integrity": "sha512-7OXrNWzy6CK+r7Ch9OLPBDTKfB6XlWHjX4P0RU5B3IgFuWPeYN0XtRtlexGRjgbQxpfaUve6jTAwBGWuGntz/w==",
|
||||||
|
"license": "MIT",
|
||||||
|
"engines": {
|
||||||
|
"node": ">=18.20 || >=20"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/stackback": {
|
"node_modules/stackback": {
|
||||||
"version": "0.0.2",
|
"version": "0.0.2",
|
||||||
"resolved": "https://registry.npmjs.org/stackback/-/stackback-0.0.2.tgz",
|
"resolved": "https://registry.npmjs.org/stackback/-/stackback-0.0.2.tgz",
|
||||||
|
|||||||
@@ -24,6 +24,7 @@
|
|||||||
"konva": "^10.2.5",
|
"konva": "^10.2.5",
|
||||||
"lucide-react": "^0.546.0",
|
"lucide-react": "^0.546.0",
|
||||||
"motion": "^12.23.24",
|
"motion": "^12.23.24",
|
||||||
|
"polygon-clipping": "^0.15.7",
|
||||||
"react": "^19.0.0",
|
"react": "^19.0.0",
|
||||||
"react-dom": "^19.0.0",
|
"react-dom": "^19.0.0",
|
||||||
"react-konva": "^19.2.3",
|
"react-konva": "^19.2.3",
|
||||||
|
|||||||
@@ -40,4 +40,26 @@ describe('AISegmentation', () => {
|
|||||||
expect(useStore.getState().aiModel).toBe('sam3');
|
expect(useStore.getState().aiModel).toBe('sam3');
|
||||||
expect(await screen.findByText('SAM 3 missing runtime')).toBeInTheDocument();
|
expect(await screen.findByText('SAM 3 missing runtime')).toBeInTheDocument();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('passes enabled inference parameters to the backend', async () => {
|
||||||
|
apiMock.predictMask.mockResolvedValueOnce({ masks: [] });
|
||||||
|
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||||
|
|
||||||
|
fireEvent.click(screen.getByText('正向选点'));
|
||||||
|
fireEvent.click(screen.getByTestId('konva-stage'));
|
||||||
|
fireEvent.click(await screen.findByText('执行高精度语义分割'));
|
||||||
|
|
||||||
|
expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({
|
||||||
|
imageId: 'frame-1',
|
||||||
|
imageWidth: 640,
|
||||||
|
imageHeight: 360,
|
||||||
|
model: 'sam2',
|
||||||
|
points: [{ x: 120, y: 80, type: 'pos' }],
|
||||||
|
options: {
|
||||||
|
crop_to_prompt: false,
|
||||||
|
auto_filter_background: true,
|
||||||
|
min_score: 0.05,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -17,6 +17,10 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
|||||||
const masks = useStore((state) => state.masks);
|
const masks = useStore((state) => state.masks);
|
||||||
const addMask = useStore((state) => state.addMask);
|
const addMask = useStore((state) => state.addMask);
|
||||||
const clearMasks = useStore((state) => state.clearMasks);
|
const clearMasks = useStore((state) => state.clearMasks);
|
||||||
|
const maskHistory = useStore((state) => state.maskHistory);
|
||||||
|
const maskFuture = useStore((state) => state.maskFuture);
|
||||||
|
const undoMasks = useStore((state) => state.undoMasks);
|
||||||
|
const redoMasks = useStore((state) => state.redoMasks);
|
||||||
const frames = useStore((state) => state.frames);
|
const frames = useStore((state) => state.frames);
|
||||||
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
|
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
|
||||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||||
@@ -109,6 +113,11 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
|||||||
model: aiModel,
|
model: aiModel,
|
||||||
points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
|
points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
|
||||||
text: semanticText.trim() || undefined,
|
text: semanticText.trim() || undefined,
|
||||||
|
options: {
|
||||||
|
crop_to_prompt: cropMode,
|
||||||
|
auto_filter_background: autoDeleteBg,
|
||||||
|
min_score: autoDeleteBg ? 0.05 : 0,
|
||||||
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
result.masks.forEach((m) => {
|
result.masks.forEach((m) => {
|
||||||
@@ -136,7 +145,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
|||||||
} finally {
|
} finally {
|
||||||
setIsInferencing(false);
|
setIsInferencing(false);
|
||||||
}
|
}
|
||||||
}, [activeClass, activeTemplateId, addMask, aiModel, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText]);
|
}, [activeClass, activeTemplateId, addMask, aiModel, autoDeleteBg, cropMode, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText]);
|
||||||
|
|
||||||
const handleStageClick = (e: any) => {
|
const handleStageClick = (e: any) => {
|
||||||
if (effectiveTool === 'move') return;
|
if (effectiveTool === 'move') return;
|
||||||
@@ -290,10 +299,20 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
|||||||
<span className="text-[10px] text-gray-500 uppercase tracking-widest font-mono">{aiModel.toUpperCase()} 动态推理渲染</span>
|
<span className="text-[10px] text-gray-500 uppercase tracking-widest font-mono">{aiModel.toUpperCase()} 动态推理渲染</span>
|
||||||
</div>
|
</div>
|
||||||
<div className="flex items-center gap-4">
|
<div className="flex items-center gap-4">
|
||||||
<button className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="撤销操作 (Ctrl+Z)">
|
<button
|
||||||
|
onClick={undoMasks}
|
||||||
|
disabled={maskHistory.length === 0}
|
||||||
|
className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-400 disabled:cursor-not-allowed"
|
||||||
|
title="撤销操作 (Ctrl+Z)"
|
||||||
|
>
|
||||||
<Undo size={14} />
|
<Undo size={14} />
|
||||||
</button>
|
</button>
|
||||||
<button className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="重做操作 (Ctrl+Shift+Z)">
|
<button
|
||||||
|
onClick={redoMasks}
|
||||||
|
disabled={maskFuture.length === 0}
|
||||||
|
className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-400 disabled:cursor-not-allowed"
|
||||||
|
title="重做操作 (Ctrl+Shift+Z)"
|
||||||
|
>
|
||||||
<Redo size={14} />
|
<Redo size={14} />
|
||||||
</button>
|
</button>
|
||||||
<div className="w-px h-4 bg-white/10 mx-1"></div>
|
<div className="w-px h-4 bg-white/10 mx-1"></div>
|
||||||
|
|||||||
@@ -79,6 +79,271 @@ describe('CanvasArea', () => {
|
|||||||
expect(screen.getByText('遮罩数: 1')).toBeInTheDocument();
|
expect(screen.getByText('遮罩数: 1')).toBeInTheDocument();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('renders imported GT seed points for editable point regions', () => {
|
||||||
|
useStore.setState({
|
||||||
|
masks: [
|
||||||
|
{
|
||||||
|
id: 'gt-1',
|
||||||
|
frameId: 'frame-1',
|
||||||
|
pathData: 'M 0 0 L 10 0 L 10 10 Z',
|
||||||
|
label: 'GT',
|
||||||
|
color: '#22c55e',
|
||||||
|
points: [[120, 80]],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||||
|
|
||||||
|
expect(screen.getAllByTestId('konva-circle')).toHaveLength(2);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('selects a polygon mask and drags a vertex into dirty saved state', () => {
|
||||||
|
useStore.setState({
|
||||||
|
masks: [
|
||||||
|
{
|
||||||
|
id: 'annotation-99',
|
||||||
|
annotationId: '99',
|
||||||
|
frameId: 'frame-1',
|
||||||
|
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||||
|
label: 'Saved',
|
||||||
|
color: '#06b6d4',
|
||||||
|
saved: true,
|
||||||
|
saveStatus: 'saved',
|
||||||
|
segmentation: [[10, 10, 90, 10, 90, 40]],
|
||||||
|
bbox: [10, 10, 80, 30],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||||
|
fireEvent.click(screen.getByTestId('konva-path'));
|
||||||
|
const handles = screen.getAllByTestId('konva-circle')
|
||||||
|
.filter((element) => element.getAttribute('data-fill') === '#ffffff');
|
||||||
|
expect(handles).toHaveLength(3);
|
||||||
|
|
||||||
|
fireEvent.mouseUp(handles[0], { clientX: 20, clientY: 30 });
|
||||||
|
|
||||||
|
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||||
|
pathData: 'M 20 30 L 90 10 L 90 40 Z',
|
||||||
|
segmentation: [[20, 30, 90, 10, 90, 40]],
|
||||||
|
bbox: [20, 10, 70, 30],
|
||||||
|
area: 1050,
|
||||||
|
saveStatus: 'dirty',
|
||||||
|
saved: false,
|
||||||
|
}));
|
||||||
|
});
|
||||||
|
|
||||||
|
it('deletes a selected polygon vertex without dropping below three points', () => {
|
||||||
|
useStore.setState({
|
||||||
|
masks: [
|
||||||
|
{
|
||||||
|
id: 'draft-1',
|
||||||
|
frameId: 'frame-1',
|
||||||
|
pathData: 'M 10 10 L 90 10 L 90 40 L 10 40 Z',
|
||||||
|
label: 'Draft',
|
||||||
|
color: '#06b6d4',
|
||||||
|
saveStatus: 'draft',
|
||||||
|
segmentation: [[10, 10, 90, 10, 90, 40, 10, 40]],
|
||||||
|
bbox: [10, 10, 80, 30],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||||
|
fireEvent.click(screen.getByTestId('konva-path'));
|
||||||
|
const handles = screen.getAllByTestId('konva-circle')
|
||||||
|
.filter((element) => element.getAttribute('data-fill') === '#ffffff');
|
||||||
|
fireEvent.click(handles[0]);
|
||||||
|
fireEvent.keyDown(window, { key: 'Delete' });
|
||||||
|
|
||||||
|
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||||
|
pathData: 'M 90 10 L 90 40 L 10 40 Z',
|
||||||
|
segmentation: [[90, 10, 90, 40, 10, 40]],
|
||||||
|
saveStatus: 'draft',
|
||||||
|
}));
|
||||||
|
});
|
||||||
|
|
||||||
|
it('inserts a polygon vertex from an edge midpoint handle', () => {
|
||||||
|
useStore.setState({
|
||||||
|
masks: [
|
||||||
|
{
|
||||||
|
id: 'draft-1',
|
||||||
|
frameId: 'frame-1',
|
||||||
|
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||||
|
label: 'Draft',
|
||||||
|
color: '#06b6d4',
|
||||||
|
saveStatus: 'draft',
|
||||||
|
segmentation: [[10, 10, 90, 10, 90, 40]],
|
||||||
|
bbox: [10, 10, 80, 30],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||||
|
fireEvent.click(screen.getByTestId('konva-path'));
|
||||||
|
const edgeHandles = screen.getAllByTestId('konva-circle')
|
||||||
|
.filter((element) => element.getAttribute('data-fill') === '#22d3ee');
|
||||||
|
fireEvent.click(edgeHandles[0]);
|
||||||
|
|
||||||
|
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||||
|
segmentation: [[10, 10, 50, 10, 90, 10, 90, 40]],
|
||||||
|
pathData: 'M 10 10 L 50 10 L 90 10 L 90 40 Z',
|
||||||
|
saveStatus: 'draft',
|
||||||
|
}));
|
||||||
|
});
|
||||||
|
|
||||||
|
it('edits the selected polygon in a multi-polygon mask', () => {
|
||||||
|
useStore.setState({
|
||||||
|
masks: [
|
||||||
|
{
|
||||||
|
id: 'multi-1',
|
||||||
|
frameId: 'frame-1',
|
||||||
|
pathData: 'M 10 10 L 50 10 L 50 40 Z M 100 100 L 150 100 L 150 140 Z',
|
||||||
|
label: 'Multi',
|
||||||
|
color: '#06b6d4',
|
||||||
|
saveStatus: 'draft',
|
||||||
|
segmentation: [
|
||||||
|
[10, 10, 50, 10, 50, 40],
|
||||||
|
[100, 100, 150, 100, 150, 140],
|
||||||
|
],
|
||||||
|
bbox: [10, 10, 140, 130],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||||
|
const paths = screen.getAllByTestId('konva-path');
|
||||||
|
fireEvent.click(paths[1]);
|
||||||
|
const vertexHandles = screen.getAllByTestId('konva-circle')
|
||||||
|
.filter((element) => element.getAttribute('data-fill') === '#ffffff');
|
||||||
|
fireEvent.mouseUp(vertexHandles[0], { clientX: 120, clientY: 120 });
|
||||||
|
|
||||||
|
expect(useStore.getState().masks[0].segmentation).toEqual([
|
||||||
|
[10, 10, 50, 10, 50, 40],
|
||||||
|
[120, 120, 150, 100, 150, 140],
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('merges selected draft masks with polygon union', () => {
|
||||||
|
useStore.setState({
|
||||||
|
masks: [
|
||||||
|
{
|
||||||
|
id: 'm1',
|
||||||
|
frameId: 'frame-1',
|
||||||
|
pathData: 'M 10 10 L 90 10 L 90 50 L 10 50 Z',
|
||||||
|
label: 'A',
|
||||||
|
color: '#06b6d4',
|
||||||
|
segmentation: [[10, 10, 90, 10, 90, 50, 10, 50]],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'm2',
|
||||||
|
frameId: 'frame-1',
|
||||||
|
pathData: 'M 50 30 L 120 30 L 120 80 L 50 80 Z',
|
||||||
|
label: 'B',
|
||||||
|
color: '#ff0000',
|
||||||
|
segmentation: [[50, 30, 120, 30, 120, 80, 50, 80]],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<CanvasArea activeTool="area_merge" frame={frame} />);
|
||||||
|
const paths = screen.getAllByTestId('konva-path');
|
||||||
|
fireEvent.click(paths[0]);
|
||||||
|
fireEvent.click(paths[1]);
|
||||||
|
fireEvent.click(screen.getByRole('button', { name: '合并选中' }));
|
||||||
|
|
||||||
|
expect(useStore.getState().masks).toHaveLength(1);
|
||||||
|
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||||
|
id: 'm1',
|
||||||
|
segmentation: [[10, 10, 90, 10, 90, 30, 120, 30, 120, 80, 50, 80, 50, 50, 10, 50]],
|
||||||
|
bbox: [10, 10, 110, 70],
|
||||||
|
saveStatus: 'draft',
|
||||||
|
}));
|
||||||
|
});
|
||||||
|
|
||||||
|
it('removes overlap from the primary selected mask with polygon difference', () => {
|
||||||
|
useStore.setState({
|
||||||
|
masks: [
|
||||||
|
{
|
||||||
|
id: 'm1',
|
||||||
|
frameId: 'frame-1',
|
||||||
|
pathData: 'M 10 10 L 90 10 L 90 50 L 10 50 Z',
|
||||||
|
label: 'A',
|
||||||
|
color: '#06b6d4',
|
||||||
|
segmentation: [[10, 10, 90, 10, 90, 50, 10, 50]],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'm2',
|
||||||
|
frameId: 'frame-1',
|
||||||
|
pathData: 'M 50 30 L 120 30 L 120 80 L 50 80 Z',
|
||||||
|
label: 'B',
|
||||||
|
color: '#ff0000',
|
||||||
|
segmentation: [[50, 30, 120, 30, 120, 80, 50, 80]],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<CanvasArea activeTool="area_remove" frame={frame} />);
|
||||||
|
const paths = screen.getAllByTestId('konva-path');
|
||||||
|
fireEvent.click(paths[0]);
|
||||||
|
fireEvent.click(paths[1]);
|
||||||
|
fireEvent.click(screen.getByRole('button', { name: '从主区域去除' }));
|
||||||
|
|
||||||
|
expect(useStore.getState().masks).toHaveLength(2);
|
||||||
|
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||||
|
id: 'm1',
|
||||||
|
segmentation: [[10, 10, 90, 10, 90, 30, 50, 30, 50, 50, 10, 50]],
|
||||||
|
bbox: [10, 10, 80, 40],
|
||||||
|
saveStatus: 'draft',
|
||||||
|
}));
|
||||||
|
expect(useStore.getState().masks[1].id).toBe('m2');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('creates a manual rectangle mask that can be undone and redone', () => {
|
||||||
|
useStore.setState({
|
||||||
|
activeTemplateId: '2',
|
||||||
|
activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||||
|
activeClassId: 'c1',
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<CanvasArea activeTool="create_rectangle" frame={frame} />);
|
||||||
|
const stage = screen.getByTestId('konva-stage');
|
||||||
|
fireEvent.mouseDown(stage);
|
||||||
|
fireEvent.mouseMove(stage);
|
||||||
|
fireEvent.mouseUp(stage);
|
||||||
|
|
||||||
|
expect(useStore.getState().masks).toHaveLength(1);
|
||||||
|
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||||
|
frameId: 'frame-1',
|
||||||
|
label: '胆囊',
|
||||||
|
color: '#ff0000',
|
||||||
|
saveStatus: 'draft',
|
||||||
|
segmentation: [[120, 80, 260, 80, 260, 200, 120, 200]],
|
||||||
|
bbox: [120, 80, 140, 120],
|
||||||
|
}));
|
||||||
|
|
||||||
|
useStore.getState().undoMasks();
|
||||||
|
expect(useStore.getState().masks).toEqual([]);
|
||||||
|
useStore.getState().redoMasks();
|
||||||
|
expect(useStore.getState().masks).toHaveLength(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('finalizes a clicked polygon with Enter', () => {
|
||||||
|
render(<CanvasArea activeTool="create_polygon" frame={frame} />);
|
||||||
|
const stage = screen.getByTestId('konva-stage');
|
||||||
|
fireEvent.click(stage, { clientX: 120, clientY: 80 });
|
||||||
|
fireEvent.click(stage, { clientX: 220, clientY: 80 });
|
||||||
|
fireEvent.click(stage, { clientX: 180, clientY: 160 });
|
||||||
|
fireEvent.keyDown(window, { key: 'Enter' });
|
||||||
|
|
||||||
|
expect(useStore.getState().masks).toHaveLength(1);
|
||||||
|
expect(useStore.getState().masks[0].metadata).toEqual(expect.objectContaining({
|
||||||
|
source: 'manual',
|
||||||
|
shape: '多边形',
|
||||||
|
}));
|
||||||
|
});
|
||||||
|
|
||||||
it('applies the selected class to current-frame masks and marks saved masks dirty', () => {
|
it('applies the selected class to current-frame masks and marks saved masks dirty', () => {
|
||||||
useStore.setState({
|
useStore.setState({
|
||||||
activeTemplateId: '2',
|
activeTemplateId: '2',
|
||||||
|
|||||||
@@ -1,17 +1,180 @@
|
|||||||
import React, { useEffect, useRef, useState, useCallback } from 'react';
|
import React, { useEffect, useRef, useState, useCallback } from 'react';
|
||||||
import { Stage, Layer, Image as KonvaImage, Circle, Rect, Path, Group } from 'react-konva';
|
import { Stage, Layer, Image as KonvaImage, Circle, Rect, Path, Group } from 'react-konva';
|
||||||
|
import polygonClipping, { type MultiPolygon, type Pair } from 'polygon-clipping';
|
||||||
import useImage from 'use-image';
|
import useImage from 'use-image';
|
||||||
import { useStore } from '../store/useStore';
|
import { useStore } from '../store/useStore';
|
||||||
import { predictMask } from '../lib/api';
|
import { predictMask } from '../lib/api';
|
||||||
import type { Frame } from '../store/useStore';
|
import type { Frame, Mask } from '../store/useStore';
|
||||||
|
|
||||||
interface CanvasAreaProps {
|
interface CanvasAreaProps {
|
||||||
activeTool: string;
|
activeTool: string;
|
||||||
frame: Frame | null;
|
frame: Frame | null;
|
||||||
onClearMasks?: () => void;
|
onClearMasks?: () => void;
|
||||||
|
onDeleteMaskAnnotations?: (annotationIds: string[]) => Promise<void> | void;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) {
|
type CanvasPoint = { x: number; y: number };
|
||||||
|
|
||||||
|
const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle', 'create_line']);
|
||||||
|
const POLYGON_TOOL = 'create_polygon';
|
||||||
|
const POINT_TOOL = 'create_point';
|
||||||
|
const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']);
|
||||||
|
|
||||||
|
function clamp(value: number, min: number, max: number): number {
|
||||||
|
return Math.min(Math.max(value, min), max);
|
||||||
|
}
|
||||||
|
|
||||||
|
function polygonPath(points: CanvasPoint[]): string {
|
||||||
|
if (points.length === 0) return '';
|
||||||
|
return points
|
||||||
|
.map((point, index) => `${index === 0 ? 'M' : 'L'} ${point.x} ${point.y}`)
|
||||||
|
.join(' ')
|
||||||
|
.concat(' Z');
|
||||||
|
}
|
||||||
|
|
||||||
|
function segmentationPath(segmentation?: number[][]): string {
|
||||||
|
return (segmentation || [])
|
||||||
|
.map((polygon) => polygonPath(flatPolygonToPoints(polygon)))
|
||||||
|
.filter(Boolean)
|
||||||
|
.join(' ');
|
||||||
|
}
|
||||||
|
|
||||||
|
function segmentationPolygonPath(segmentation: number[][] | undefined, polygonIndex: number): string {
|
||||||
|
const polygon = segmentation?.[polygonIndex];
|
||||||
|
return polygon ? polygonPath(flatPolygonToPoints(polygon)) : '';
|
||||||
|
}
|
||||||
|
|
||||||
|
function polygonSegmentation(points: CanvasPoint[]): number[][] {
|
||||||
|
return [points.flatMap((point) => [point.x, point.y])];
|
||||||
|
}
|
||||||
|
|
||||||
|
function segmentationToPoints(segmentation?: number[][], polygonIndex = 0): CanvasPoint[] {
|
||||||
|
const polygon = segmentation?.[polygonIndex] || [];
|
||||||
|
const points: CanvasPoint[] = [];
|
||||||
|
for (let index = 0; index < polygon.length - 1; index += 2) {
|
||||||
|
points.push({ x: polygon[index], y: polygon[index + 1] });
|
||||||
|
}
|
||||||
|
return points;
|
||||||
|
}
|
||||||
|
|
||||||
|
function flatPolygonToPoints(polygon: number[]): CanvasPoint[] {
|
||||||
|
const points: CanvasPoint[] = [];
|
||||||
|
for (let index = 0; index < polygon.length - 1; index += 2) {
|
||||||
|
points.push({ x: polygon[index], y: polygon[index + 1] });
|
||||||
|
}
|
||||||
|
return points;
|
||||||
|
}
|
||||||
|
|
||||||
|
function segmentationAllPoints(segmentation?: number[][]): CanvasPoint[] {
|
||||||
|
return (segmentation || []).flatMap((polygon) => flatPolygonToPoints(polygon));
|
||||||
|
}
|
||||||
|
|
||||||
|
function polygonBbox(points: CanvasPoint[]): [number, number, number, number] {
|
||||||
|
const xs = points.map((point) => point.x);
|
||||||
|
const ys = points.map((point) => point.y);
|
||||||
|
const minX = Math.min(...xs);
|
||||||
|
const minY = Math.min(...ys);
|
||||||
|
const maxX = Math.max(...xs);
|
||||||
|
const maxY = Math.max(...ys);
|
||||||
|
return [minX, minY, maxX - minX, maxY - minY];
|
||||||
|
}
|
||||||
|
|
||||||
|
function polygonArea(points: CanvasPoint[]): number {
|
||||||
|
if (points.length < 3) return 0;
|
||||||
|
const sum = points.reduce((acc, point, index) => {
|
||||||
|
const next = points[(index + 1) % points.length];
|
||||||
|
return acc + point.x * next.y - next.x * point.y;
|
||||||
|
}, 0);
|
||||||
|
return Math.abs(sum) / 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
function segmentationArea(segmentation?: number[][]): number {
|
||||||
|
return (segmentation || []).reduce((sum, polygon) => sum + polygonArea(flatPolygonToPoints(polygon)), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
function segmentationBbox(segmentation?: number[][]): [number, number, number, number] | undefined {
|
||||||
|
const points = segmentationAllPoints(segmentation);
|
||||||
|
return points.length > 0 ? polygonBbox(points) : undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
function closeRing(points: CanvasPoint[]): Pair[] {
|
||||||
|
const ring = points.map((point) => [point.x, point.y] as Pair);
|
||||||
|
const first = ring[0];
|
||||||
|
const last = ring[ring.length - 1];
|
||||||
|
if (first && last && (first[0] !== last[0] || first[1] !== last[1])) {
|
||||||
|
ring.push([first[0], first[1]]);
|
||||||
|
}
|
||||||
|
return ring;
|
||||||
|
}
|
||||||
|
|
||||||
|
function maskToMultiPolygon(mask: Mask): MultiPolygon | null {
|
||||||
|
const polygons = (mask.segmentation || [])
|
||||||
|
.map((polygon) => flatPolygonToPoints(polygon))
|
||||||
|
.filter((points) => points.length >= 3)
|
||||||
|
.map((points) => [closeRing(points)]);
|
||||||
|
return polygons.length > 0 ? polygons : null;
|
||||||
|
}
|
||||||
|
|
||||||
|
function multiPolygonToSegmentation(geometry: MultiPolygon): number[][] {
|
||||||
|
return geometry
|
||||||
|
.map((polygon) => polygon[0] || [])
|
||||||
|
.map((ring) => {
|
||||||
|
const openRing = ring.length > 1
|
||||||
|
&& ring[0][0] === ring[ring.length - 1][0]
|
||||||
|
&& ring[0][1] === ring[ring.length - 1][1]
|
||||||
|
? ring.slice(0, -1)
|
||||||
|
: ring;
|
||||||
|
return openRing.flatMap(([x, y]) => [x, y]);
|
||||||
|
})
|
||||||
|
.filter((polygon) => polygon.length >= 6);
|
||||||
|
}
|
||||||
|
|
||||||
|
function rectanglePoints(start: CanvasPoint, end: CanvasPoint): CanvasPoint[] {
|
||||||
|
const x1 = Math.min(start.x, end.x);
|
||||||
|
const y1 = Math.min(start.y, end.y);
|
||||||
|
const x2 = Math.max(start.x, end.x);
|
||||||
|
const y2 = Math.max(start.y, end.y);
|
||||||
|
return [
|
||||||
|
{ x: x1, y: y1 },
|
||||||
|
{ x: x2, y: y1 },
|
||||||
|
{ x: x2, y: y2 },
|
||||||
|
{ x: x1, y: y2 },
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
function circlePoints(start: CanvasPoint, end: CanvasPoint): CanvasPoint[] {
|
||||||
|
const cx = (start.x + end.x) / 2;
|
||||||
|
const cy = (start.y + end.y) / 2;
|
||||||
|
const rx = Math.abs(end.x - start.x) / 2;
|
||||||
|
const ry = Math.abs(end.y - start.y) / 2;
|
||||||
|
return Array.from({ length: 32 }, (_, index) => {
|
||||||
|
const angle = (Math.PI * 2 * index) / 32;
|
||||||
|
return { x: cx + Math.cos(angle) * rx, y: cy + Math.sin(angle) * ry };
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function pointRegion(point: CanvasPoint, radius = 5): CanvasPoint[] {
|
||||||
|
return Array.from({ length: 12 }, (_, index) => {
|
||||||
|
const angle = (Math.PI * 2 * index) / 12;
|
||||||
|
return { x: point.x + Math.cos(angle) * radius, y: point.y + Math.sin(angle) * radius };
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function lineRegion(start: CanvasPoint, end: CanvasPoint, halfWidth = 4): CanvasPoint[] {
|
||||||
|
const dx = end.x - start.x;
|
||||||
|
const dy = end.y - start.y;
|
||||||
|
const length = Math.hypot(dx, dy) || 1;
|
||||||
|
const nx = (-dy / length) * halfWidth;
|
||||||
|
const ny = (dx / length) * halfWidth;
|
||||||
|
return [
|
||||||
|
{ x: start.x + nx, y: start.y + ny },
|
||||||
|
{ x: end.x + nx, y: end.y + ny },
|
||||||
|
{ x: end.x - nx, y: end.y - ny },
|
||||||
|
{ x: start.x - nx, y: start.y - ny },
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnotations }: CanvasAreaProps) {
|
||||||
const containerRef = useRef<HTMLDivElement>(null);
|
const containerRef = useRef<HTMLDivElement>(null);
|
||||||
const [stageSize, setStageSize] = useState({ width: 800, height: 600 });
|
const [stageSize, setStageSize] = useState({ width: 800, height: 600 });
|
||||||
const [scale, setScale] = useState(1);
|
const [scale, setScale] = useState(1);
|
||||||
@@ -20,22 +183,46 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
|||||||
const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 });
|
const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 });
|
||||||
const [boxStart, setBoxStart] = useState<{ x: number, y: number } | null>(null);
|
const [boxStart, setBoxStart] = useState<{ x: number, y: number } | null>(null);
|
||||||
const [boxCurrent, setBoxCurrent] = useState<{ x: number, y: number } | null>(null);
|
const [boxCurrent, setBoxCurrent] = useState<{ x: number, y: number } | null>(null);
|
||||||
|
const [manualStart, setManualStart] = useState<CanvasPoint | null>(null);
|
||||||
|
const [manualCurrent, setManualCurrent] = useState<CanvasPoint | null>(null);
|
||||||
|
const [polygonPoints, setPolygonPoints] = useState<CanvasPoint[]>([]);
|
||||||
|
const [selectedMaskId, setSelectedMaskId] = useState<string | null>(null);
|
||||||
|
const [selectedMaskIds, setSelectedMaskIds] = useState<string[]>([]);
|
||||||
|
const [selectedPolygonIndex, setSelectedPolygonIndex] = useState(0);
|
||||||
|
const [selectedVertexIndex, setSelectedVertexIndex] = useState<number | null>(null);
|
||||||
const [isInferencing, setIsInferencing] = useState(false);
|
const [isInferencing, setIsInferencing] = useState(false);
|
||||||
|
|
||||||
const masks = useStore((state) => state.masks);
|
const masks = useStore((state) => state.masks);
|
||||||
const addMask = useStore((state) => state.addMask);
|
const addMask = useStore((state) => state.addMask);
|
||||||
|
const updateMask = useStore((state) => state.updateMask);
|
||||||
const clearMasks = useStore((state) => state.clearMasks);
|
const clearMasks = useStore((state) => state.clearMasks);
|
||||||
const setMasks = useStore((state) => state.setMasks);
|
const setMasks = useStore((state) => state.setMasks);
|
||||||
const storeActiveTool = useStore((state) => state.activeTool);
|
const storeActiveTool = useStore((state) => state.activeTool);
|
||||||
const aiModel = useStore((state) => state.aiModel);
|
const aiModel = useStore((state) => state.aiModel);
|
||||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||||
const activeClass = useStore((state) => state.activeClass);
|
const activeClass = useStore((state) => state.activeClass);
|
||||||
|
const undoMasks = useStore((state) => state.undoMasks);
|
||||||
|
const redoMasks = useStore((state) => state.redoMasks);
|
||||||
|
|
||||||
const effectiveTool = activeTool || storeActiveTool;
|
const effectiveTool = activeTool || storeActiveTool;
|
||||||
|
|
||||||
// Load the actual frame image
|
// Load the actual frame image
|
||||||
const [image] = useImage(frame?.url || '');
|
const [image] = useImage(frame?.url || '');
|
||||||
const frameMasks = masks.filter((mask) => mask.frameId === frame?.id);
|
const frameMasks = masks.filter((mask) => mask.frameId === frame?.id);
|
||||||
|
const selectedMask = React.useMemo(
|
||||||
|
() => frameMasks.find((mask) => mask.id === selectedMaskId) || null,
|
||||||
|
[frameMasks, selectedMaskId],
|
||||||
|
);
|
||||||
|
const booleanSelectedMasks = React.useMemo(
|
||||||
|
() => selectedMaskIds
|
||||||
|
.map((id) => frameMasks.find((mask) => mask.id === id))
|
||||||
|
.filter((mask): mask is Mask => Boolean(mask)),
|
||||||
|
[frameMasks, selectedMaskIds],
|
||||||
|
);
|
||||||
|
const selectedMaskPoints = React.useMemo(
|
||||||
|
() => segmentationToPoints(selectedMask?.segmentation, selectedPolygonIndex),
|
||||||
|
[selectedMask?.segmentation, selectedPolygonIndex],
|
||||||
|
);
|
||||||
const savedMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'saved' || mask.saved).length;
|
const savedMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'saved' || mask.saved).length;
|
||||||
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
|
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
|
||||||
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
|
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
|
||||||
@@ -55,6 +242,25 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
|||||||
return () => window.removeEventListener('resize', handleResize);
|
return () => window.removeEventListener('resize', handleResize);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
setManualStart(null);
|
||||||
|
setManualCurrent(null);
|
||||||
|
setPolygonPoints([]);
|
||||||
|
setSelectedMaskId(null);
|
||||||
|
setSelectedMaskIds([]);
|
||||||
|
setSelectedPolygonIndex(0);
|
||||||
|
setSelectedVertexIndex(null);
|
||||||
|
}, [effectiveTool, frame?.id]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (selectedMaskId && !frameMasks.some((mask) => mask.id === selectedMaskId)) {
|
||||||
|
setSelectedMaskId(null);
|
||||||
|
setSelectedMaskIds([]);
|
||||||
|
setSelectedPolygonIndex(0);
|
||||||
|
setSelectedVertexIndex(null);
|
||||||
|
}
|
||||||
|
}, [frameMasks, selectedMaskId]);
|
||||||
|
|
||||||
const handleWheel = (e: any) => {
|
const handleWheel = (e: any) => {
|
||||||
e.evt.preventDefault();
|
e.evt.preventDefault();
|
||||||
const scaleBy = 1.1;
|
const scaleBy = 1.1;
|
||||||
@@ -74,6 +280,50 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
|||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const stagePoint = (e: any): CanvasPoint | null => {
|
||||||
|
const stage = e.target.getStage();
|
||||||
|
const relPos = stage?.getRelativePointerPosition();
|
||||||
|
if (!relPos) return null;
|
||||||
|
const imageWidth = frame?.width || image?.naturalWidth || image?.width || stageSize.width;
|
||||||
|
const imageHeight = frame?.height || image?.naturalHeight || image?.height || stageSize.height;
|
||||||
|
return {
|
||||||
|
x: clamp(relPos.x, 0, imageWidth),
|
||||||
|
y: clamp(relPos.y, 0, imageHeight),
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const createManualMask = useCallback((shape: string, polygon: CanvasPoint[]) => {
|
||||||
|
if (!frame?.id || polygon.length < 3) return;
|
||||||
|
const area = polygonArea(polygon);
|
||||||
|
if (area <= 1) return;
|
||||||
|
const color = activeClass?.color || '#06b6d4';
|
||||||
|
const label = activeClass?.name || `手工${shape}`;
|
||||||
|
const mask: Mask = {
|
||||||
|
id: `manual-${frame.id}-${shape}-${Date.now()}`,
|
||||||
|
frameId: frame.id,
|
||||||
|
templateId: activeTemplateId || undefined,
|
||||||
|
classId: activeClass?.id,
|
||||||
|
className: activeClass?.name,
|
||||||
|
classZIndex: activeClass?.zIndex,
|
||||||
|
saveStatus: 'draft',
|
||||||
|
saved: false,
|
||||||
|
pathData: polygonPath(polygon),
|
||||||
|
label,
|
||||||
|
color,
|
||||||
|
segmentation: polygonSegmentation(polygon),
|
||||||
|
points: shape === '点区域'
|
||||||
|
? [[
|
||||||
|
polygon.reduce((sum, point) => sum + point.x, 0) / polygon.length,
|
||||||
|
polygon.reduce((sum, point) => sum + point.y, 0) / polygon.length,
|
||||||
|
]]
|
||||||
|
: undefined,
|
||||||
|
bbox: polygonBbox(polygon),
|
||||||
|
area,
|
||||||
|
metadata: { source: 'manual', shape },
|
||||||
|
};
|
||||||
|
addMask(mask);
|
||||||
|
}, [activeClass, activeTemplateId, addMask, frame?.id]);
|
||||||
|
|
||||||
const handleMouseMove = (e: any) => {
|
const handleMouseMove = (e: any) => {
|
||||||
const stage = e.target.getStage();
|
const stage = e.target.getStage();
|
||||||
if (!stage) return;
|
if (!stage) return;
|
||||||
@@ -90,6 +340,13 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
|||||||
setBoxCurrent({ x: relPos.x, y: relPos.y });
|
setBoxCurrent({ x: relPos.x, y: relPos.y });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (manualStart && DRAG_MANUAL_TOOLS.has(effectiveTool)) {
|
||||||
|
const pos = stage.getRelativePointerPosition();
|
||||||
|
if (pos) {
|
||||||
|
setManualCurrent({ x: pos.x, y: pos.y });
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const runInference = useCallback(async (promptPoints?: typeof points, promptBox?: { x1: number, y1: number, x2: number, y2: number }) => {
|
const runInference = useCallback(async (promptPoints?: typeof points, promptBox?: { x1: number, y1: number, x2: number, y2: number }) => {
|
||||||
@@ -132,6 +389,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
|||||||
label,
|
label,
|
||||||
color,
|
color,
|
||||||
segmentation: m.segmentation,
|
segmentation: m.segmentation,
|
||||||
|
points: promptPoints?.filter((p) => p.type === 'pos').map((p) => [p.x, p.y]),
|
||||||
bbox: m.bbox,
|
bbox: m.bbox,
|
||||||
area: m.area,
|
area: m.area,
|
||||||
});
|
});
|
||||||
@@ -170,6 +428,15 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
|||||||
};
|
};
|
||||||
|
|
||||||
const handleStageMouseDown = (e: any) => {
|
const handleStageMouseDown = (e: any) => {
|
||||||
|
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) {
|
||||||
|
const pos = stagePoint(e);
|
||||||
|
if (pos) {
|
||||||
|
setManualStart(pos);
|
||||||
|
setManualCurrent(pos);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (effectiveTool === 'box_select') {
|
if (effectiveTool === 'box_select') {
|
||||||
const stage = e.target.getStage();
|
const stage = e.target.getStage();
|
||||||
const pos = stage.getRelativePointerPosition();
|
const pos = stage.getRelativePointerPosition();
|
||||||
@@ -181,6 +448,27 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
|||||||
};
|
};
|
||||||
|
|
||||||
const handleStageMouseUp = (e: any) => {
|
const handleStageMouseUp = (e: any) => {
|
||||||
|
if (DRAG_MANUAL_TOOLS.has(effectiveTool) && manualStart) {
|
||||||
|
const end = stagePoint(e) || manualCurrent || manualStart;
|
||||||
|
const width = Math.abs(end.x - manualStart.x);
|
||||||
|
const height = Math.abs(end.y - manualStart.y);
|
||||||
|
const distance = Math.hypot(width, height);
|
||||||
|
|
||||||
|
if (effectiveTool === 'create_rectangle' && width > 4 && height > 4) {
|
||||||
|
createManualMask('矩形', rectanglePoints(manualStart, end));
|
||||||
|
}
|
||||||
|
if (effectiveTool === 'create_circle' && width > 4 && height > 4) {
|
||||||
|
createManualMask('圆形', circlePoints(manualStart, end));
|
||||||
|
}
|
||||||
|
if (effectiveTool === 'create_line' && distance > 4) {
|
||||||
|
createManualMask('线段', lineRegion(manualStart, end));
|
||||||
|
}
|
||||||
|
|
||||||
|
setManualStart(null);
|
||||||
|
setManualCurrent(null);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (effectiveTool === 'box_select' && boxStart && boxCurrent) {
|
if (effectiveTool === 'box_select' && boxStart && boxCurrent) {
|
||||||
const x1 = Math.min(boxStart.x, boxCurrent.x);
|
const x1 = Math.min(boxStart.x, boxCurrent.x);
|
||||||
const y1 = Math.min(boxStart.y, boxCurrent.y);
|
const y1 = Math.min(boxStart.y, boxCurrent.y);
|
||||||
@@ -199,12 +487,32 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
|||||||
const handleStageClick = (e: any) => {
|
const handleStageClick = (e: any) => {
|
||||||
if (effectiveTool === 'move') return;
|
if (effectiveTool === 'move') return;
|
||||||
if (effectiveTool === 'box_select') return; // handled by mouseup
|
if (effectiveTool === 'box_select') return; // handled by mouseup
|
||||||
|
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) return;
|
||||||
|
|
||||||
|
if (effectiveTool === POINT_TOOL) {
|
||||||
|
const pos = stagePoint(e);
|
||||||
|
if (pos) {
|
||||||
|
createManualMask('点区域', pointRegion(pos));
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (effectiveTool === POLYGON_TOOL) {
|
||||||
|
const pos = stagePoint(e);
|
||||||
|
if (pos) {
|
||||||
|
setPolygonPoints((current) => [...current, pos]);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (effectiveTool === 'point_pos' || effectiveTool === 'point_neg') {
|
if (effectiveTool === 'point_pos' || effectiveTool === 'point_neg') {
|
||||||
const stage = e.target.getStage();
|
const stage = e.target.getStage();
|
||||||
const pos = stage.getRelativePointerPosition();
|
const pos = stage.getRelativePointerPosition();
|
||||||
if (pos) {
|
if (pos) {
|
||||||
const newPoints = [...points, { x: pos.x, y: pos.y, type: effectiveTool === 'point_pos' ? 'pos' : 'neg' as 'pos'|'neg' }];
|
const newPoints = [
|
||||||
|
...points,
|
||||||
|
{ x: pos.x, y: pos.y, type: (effectiveTool === 'point_pos' ? 'pos' : 'neg') as 'pos' | 'neg' },
|
||||||
|
];
|
||||||
setPoints(newPoints);
|
setPoints(newPoints);
|
||||||
// Auto-trigger inference after point selection
|
// Auto-trigger inference after point selection
|
||||||
runInference(newPoints);
|
runInference(newPoints);
|
||||||
@@ -212,6 +520,74 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const updatePolygonMask = useCallback((mask: Mask, nextPoints: CanvasPoint[], polygonIndex = 0) => {
|
||||||
|
if (nextPoints.length < 3) return;
|
||||||
|
const nextSegmentation = [...(mask.segmentation || [])];
|
||||||
|
nextSegmentation[polygonIndex] = nextPoints.flatMap((point) => [point.x, point.y]);
|
||||||
|
const bbox = segmentationBbox(nextSegmentation) || polygonBbox(nextPoints);
|
||||||
|
updateMask(mask.id, {
|
||||||
|
pathData: segmentationPath(nextSegmentation),
|
||||||
|
segmentation: nextSegmentation,
|
||||||
|
bbox,
|
||||||
|
area: segmentationArea(nextSegmentation),
|
||||||
|
saveStatus: mask.annotationId ? 'dirty' : 'draft',
|
||||||
|
saved: mask.annotationId ? false : mask.saved,
|
||||||
|
});
|
||||||
|
}, [updateMask]);
|
||||||
|
|
||||||
|
const updateMaskFromSegmentation = useCallback((mask: Mask, segmentation: number[][]): Mask => {
|
||||||
|
const bbox = segmentationBbox(segmentation);
|
||||||
|
return {
|
||||||
|
...mask,
|
||||||
|
pathData: segmentationPath(segmentation),
|
||||||
|
segmentation,
|
||||||
|
bbox,
|
||||||
|
area: segmentationArea(segmentation),
|
||||||
|
saveStatus: mask.annotationId ? 'dirty' : 'draft',
|
||||||
|
saved: mask.annotationId ? false : mask.saved,
|
||||||
|
};
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const handleKeyDown = (event: KeyboardEvent) => {
|
||||||
|
const key = event.key.toLowerCase();
|
||||||
|
if ((event.metaKey || event.ctrlKey) && key === 'z') {
|
||||||
|
event.preventDefault();
|
||||||
|
if (event.shiftKey) redoMasks();
|
||||||
|
else undoMasks();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if ((event.metaKey || event.ctrlKey) && key === 'y') {
|
||||||
|
event.preventDefault();
|
||||||
|
redoMasks();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if ((event.key === 'Delete' || event.key === 'Backspace') && selectedMask && selectedVertexIndex !== null) {
|
||||||
|
const currentPoints = segmentationToPoints(selectedMask.segmentation, selectedPolygonIndex);
|
||||||
|
if (currentPoints.length > 3) {
|
||||||
|
event.preventDefault();
|
||||||
|
const nextPoints = currentPoints.filter((_, index) => index !== selectedVertexIndex);
|
||||||
|
updatePolygonMask(selectedMask, nextPoints, selectedPolygonIndex);
|
||||||
|
setSelectedVertexIndex(null);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (effectiveTool !== POLYGON_TOOL) return;
|
||||||
|
if (event.key === 'Enter' && polygonPoints.length >= 3) {
|
||||||
|
event.preventDefault();
|
||||||
|
createManualMask('多边形', polygonPoints);
|
||||||
|
setPolygonPoints([]);
|
||||||
|
}
|
||||||
|
if (event.key === 'Escape') {
|
||||||
|
event.preventDefault();
|
||||||
|
setPolygonPoints([]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
window.addEventListener('keydown', handleKeyDown);
|
||||||
|
return () => window.removeEventListener('keydown', handleKeyDown);
|
||||||
|
}, [createManualMask, effectiveTool, polygonPoints, redoMasks, selectedMask, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
|
||||||
|
|
||||||
const boxRect = React.useMemo(() => {
|
const boxRect = React.useMemo(() => {
|
||||||
if (!boxStart || !boxCurrent) return null;
|
if (!boxStart || !boxCurrent) return null;
|
||||||
const x = Math.min(boxStart.x, boxCurrent.x);
|
const x = Math.min(boxStart.x, boxCurrent.x);
|
||||||
@@ -221,6 +597,132 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
|||||||
return { x, y, width, height };
|
return { x, y, width, height };
|
||||||
}, [boxStart, boxCurrent]);
|
}, [boxStart, boxCurrent]);
|
||||||
|
|
||||||
|
const manualPreviewPath = React.useMemo(() => {
|
||||||
|
if (manualStart && manualCurrent) {
|
||||||
|
if (effectiveTool === 'create_rectangle') return polygonPath(rectanglePoints(manualStart, manualCurrent));
|
||||||
|
if (effectiveTool === 'create_circle') return polygonPath(circlePoints(manualStart, manualCurrent));
|
||||||
|
if (effectiveTool === 'create_line') return polygonPath(lineRegion(manualStart, manualCurrent));
|
||||||
|
}
|
||||||
|
if (effectiveTool === POLYGON_TOOL && polygonPoints.length > 0) {
|
||||||
|
const previewPoints = [...polygonPoints, cursorPos];
|
||||||
|
return polygonPath(previewPoints);
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}, [cursorPos, effectiveTool, manualCurrent, manualStart, polygonPoints]);
|
||||||
|
|
||||||
|
const handleSeedPointDragEnd = (mask: Mask, pointIndex: number, event: any) => {
|
||||||
|
const x = event.target.x();
|
||||||
|
const y = event.target.y();
|
||||||
|
const nextPoints = [...(mask.points || [])];
|
||||||
|
nextPoints[pointIndex] = [x, y];
|
||||||
|
updateMask(mask.id, {
|
||||||
|
points: nextPoints,
|
||||||
|
saveStatus: mask.annotationId ? 'dirty' : 'draft',
|
||||||
|
saved: mask.annotationId ? false : mask.saved,
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleMaskSelect = (mask: Mask, event: any, polygonIndex = 0) => {
|
||||||
|
event.cancelBubble = true;
|
||||||
|
if (BOOLEAN_TOOLS.has(effectiveTool)) {
|
||||||
|
setSelectedMaskIds((current) => (
|
||||||
|
current.includes(mask.id)
|
||||||
|
? current.filter((id) => id !== mask.id)
|
||||||
|
: [...current, mask.id]
|
||||||
|
));
|
||||||
|
setSelectedMaskId(mask.id);
|
||||||
|
setSelectedPolygonIndex(polygonIndex);
|
||||||
|
setSelectedVertexIndex(null);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setSelectedMaskId(mask.id);
|
||||||
|
setSelectedMaskIds([mask.id]);
|
||||||
|
setSelectedPolygonIndex(polygonIndex);
|
||||||
|
setSelectedVertexIndex(null);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleVertexDragEnd = (mask: Mask, vertexIndex: number, event: any) => {
|
||||||
|
const imageWidth = frame?.width || image?.naturalWidth || image?.width || stageSize.width;
|
||||||
|
const imageHeight = frame?.height || image?.naturalHeight || image?.height || stageSize.height;
|
||||||
|
const currentPoints = segmentationToPoints(mask.segmentation, selectedPolygonIndex);
|
||||||
|
if (!currentPoints[vertexIndex]) return;
|
||||||
|
const nextPoints = currentPoints.map((point, index) => (
|
||||||
|
index === vertexIndex
|
||||||
|
? {
|
||||||
|
x: clamp(event.target.x(), 0, imageWidth),
|
||||||
|
y: clamp(event.target.y(), 0, imageHeight),
|
||||||
|
}
|
||||||
|
: point
|
||||||
|
));
|
||||||
|
setSelectedMaskId(mask.id);
|
||||||
|
setSelectedVertexIndex(vertexIndex);
|
||||||
|
updatePolygonMask(mask, nextPoints, selectedPolygonIndex);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleEdgeInsert = (mask: Mask, edgeIndex: number, event: any) => {
|
||||||
|
event.cancelBubble = true;
|
||||||
|
const currentPoints = segmentationToPoints(mask.segmentation, selectedPolygonIndex);
|
||||||
|
const start = currentPoints[edgeIndex];
|
||||||
|
const end = currentPoints[(edgeIndex + 1) % currentPoints.length];
|
||||||
|
if (!start || !end) return;
|
||||||
|
const inserted = { x: (start.x + end.x) / 2, y: (start.y + end.y) / 2 };
|
||||||
|
const nextPoints = [
|
||||||
|
...currentPoints.slice(0, edgeIndex + 1),
|
||||||
|
inserted,
|
||||||
|
...currentPoints.slice(edgeIndex + 1),
|
||||||
|
];
|
||||||
|
setSelectedMaskId(mask.id);
|
||||||
|
setSelectedVertexIndex(edgeIndex + 1);
|
||||||
|
updatePolygonMask(mask, nextPoints, selectedPolygonIndex);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleBooleanOperation = async () => {
|
||||||
|
if (!frame || booleanSelectedMasks.length < 2) return;
|
||||||
|
const primary = booleanSelectedMasks[0];
|
||||||
|
const primaryGeometry = maskToMultiPolygon(primary);
|
||||||
|
if (!primaryGeometry) return;
|
||||||
|
|
||||||
|
const clipGeometries = booleanSelectedMasks
|
||||||
|
.slice(1)
|
||||||
|
.map(maskToMultiPolygon)
|
||||||
|
.filter((geometry): geometry is MultiPolygon => Boolean(geometry));
|
||||||
|
if (clipGeometries.length === 0) return;
|
||||||
|
|
||||||
|
const resultGeometry = effectiveTool === 'area_merge'
|
||||||
|
? polygonClipping.union(primaryGeometry, ...clipGeometries)
|
||||||
|
: polygonClipping.difference(primaryGeometry, ...clipGeometries);
|
||||||
|
const resultSegmentation = multiPolygonToSegmentation(resultGeometry);
|
||||||
|
|
||||||
|
if (resultSegmentation.length === 0) {
|
||||||
|
const deleteIds = primary.annotationId ? [primary.annotationId] : [];
|
||||||
|
setMasks(masks.filter((mask) => mask.id !== primary.id));
|
||||||
|
if (deleteIds.length > 0) await onDeleteMaskAnnotations?.(deleteIds);
|
||||||
|
setSelectedMaskId(null);
|
||||||
|
setSelectedMaskIds([]);
|
||||||
|
setSelectedVertexIndex(null);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const nextPrimary = updateMaskFromSegmentation(primary, resultSegmentation);
|
||||||
|
const secondaryIds = effectiveTool === 'area_merge'
|
||||||
|
? new Set(booleanSelectedMasks.slice(1).map((mask) => mask.id))
|
||||||
|
: new Set<string>();
|
||||||
|
const secondaryAnnotationIds = effectiveTool === 'area_merge'
|
||||||
|
? booleanSelectedMasks
|
||||||
|
.slice(1)
|
||||||
|
.map((mask) => mask.annotationId)
|
||||||
|
.filter((annotationId): annotationId is string => Boolean(annotationId))
|
||||||
|
: [];
|
||||||
|
|
||||||
|
setMasks(masks
|
||||||
|
.filter((mask) => !secondaryIds.has(mask.id))
|
||||||
|
.map((mask) => (mask.id === primary.id ? nextPrimary : mask)));
|
||||||
|
if (secondaryAnnotationIds.length > 0) await onDeleteMaskAnnotations?.(secondaryAnnotationIds);
|
||||||
|
setSelectedMaskId(primary.id);
|
||||||
|
setSelectedMaskIds([primary.id]);
|
||||||
|
setSelectedVertexIndex(null);
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div ref={containerRef} className="w-full h-full relative cursor-crosshair overflow-hidden rounded-sm">
|
<div ref={containerRef} className="w-full h-full relative cursor-crosshair overflow-hidden rounded-sm">
|
||||||
{isInferencing && (
|
{isInferencing && (
|
||||||
@@ -257,13 +759,18 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
|||||||
|
|
||||||
{/* AI Returned Masks */}
|
{/* AI Returned Masks */}
|
||||||
{frameMasks.map((mask) => (
|
{frameMasks.map((mask) => (
|
||||||
<Group key={mask.id} opacity={0.5}>
|
<Group key={mask.id} opacity={selectedMaskIds.includes(mask.id) ? 0.65 : 0.5}>
|
||||||
|
{(mask.segmentation && mask.segmentation.length > 0 ? mask.segmentation : [undefined]).map((_, polygonIndex) => (
|
||||||
<Path
|
<Path
|
||||||
data={mask.pathData}
|
key={`${mask.id}-polygon-${polygonIndex}`}
|
||||||
|
data={mask.segmentation ? segmentationPolygonPath(mask.segmentation, polygonIndex) : mask.pathData}
|
||||||
fill={mask.color}
|
fill={mask.color}
|
||||||
stroke={mask.color}
|
stroke={mask.color}
|
||||||
strokeWidth={1 / scale}
|
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale}
|
||||||
|
onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
|
||||||
|
onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
|
||||||
/>
|
/>
|
||||||
|
))}
|
||||||
</Group>
|
</Group>
|
||||||
))}
|
))}
|
||||||
|
|
||||||
@@ -281,6 +788,86 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
|||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{/* Manual shape preview */}
|
||||||
|
{manualPreviewPath && (
|
||||||
|
<Path
|
||||||
|
data={manualPreviewPath}
|
||||||
|
fill="rgba(34, 211, 238, 0.12)"
|
||||||
|
stroke="#22d3ee"
|
||||||
|
strokeWidth={2 / scale}
|
||||||
|
dash={[5 / scale, 5 / scale]}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{polygonPoints.map((point, index) => (
|
||||||
|
<Circle
|
||||||
|
key={`poly-point-${index}`}
|
||||||
|
x={point.x}
|
||||||
|
y={point.y}
|
||||||
|
radius={4 / scale}
|
||||||
|
fill="#22d3ee"
|
||||||
|
stroke="#ffffff"
|
||||||
|
strokeWidth={1 / scale}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
|
||||||
|
{/* Imported GT seed points / editable point regions */}
|
||||||
|
{frameMasks.flatMap((mask) => (mask.points || []).map(([x, y], index) => (
|
||||||
|
<Group key={`${mask.id}-seed-${index}`} x={x} y={y}>
|
||||||
|
<Circle
|
||||||
|
radius={5 / scale}
|
||||||
|
fill="#facc15"
|
||||||
|
stroke="#111827"
|
||||||
|
strokeWidth={2 / scale}
|
||||||
|
draggable
|
||||||
|
onDragEnd={(event: any) => handleSeedPointDragEnd(mask, index, event)}
|
||||||
|
/>
|
||||||
|
<Circle radius={1.5 / scale} fill="#111827" />
|
||||||
|
</Group>
|
||||||
|
)))}
|
||||||
|
|
||||||
|
{/* Polygon edge insertion handles */}
|
||||||
|
{selectedMask && selectedMaskPoints.map((point, index) => {
|
||||||
|
const next = selectedMaskPoints[(index + 1) % selectedMaskPoints.length];
|
||||||
|
if (!next) return null;
|
||||||
|
return (
|
||||||
|
<Circle
|
||||||
|
key={`${selectedMask.id}-edge-${selectedPolygonIndex}-${index}`}
|
||||||
|
x={(point.x + next.x) / 2}
|
||||||
|
y={(point.y + next.y) / 2}
|
||||||
|
radius={3.5 / scale}
|
||||||
|
fill="#22d3ee"
|
||||||
|
stroke="#111827"
|
||||||
|
strokeWidth={1.5 / scale}
|
||||||
|
onClick={(event: any) => handleEdgeInsert(selectedMask, index, event)}
|
||||||
|
onTap={(event: any) => handleEdgeInsert(selectedMask, index, event)}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
|
||||||
|
{/* Polygon vertex editor */}
|
||||||
|
{selectedMask && selectedMaskPoints.map((point, index) => (
|
||||||
|
<Circle
|
||||||
|
key={`${selectedMask.id}-vertex-${selectedPolygonIndex}-${index}`}
|
||||||
|
x={point.x}
|
||||||
|
y={point.y}
|
||||||
|
radius={(selectedVertexIndex === index ? 6 : 4.5) / scale}
|
||||||
|
fill={selectedVertexIndex === index ? '#22d3ee' : '#ffffff'}
|
||||||
|
stroke={selectedMask.color}
|
||||||
|
strokeWidth={2 / scale}
|
||||||
|
draggable
|
||||||
|
onClick={(event: any) => {
|
||||||
|
event.cancelBubble = true;
|
||||||
|
setSelectedVertexIndex(index);
|
||||||
|
}}
|
||||||
|
onTap={(event: any) => {
|
||||||
|
event.cancelBubble = true;
|
||||||
|
setSelectedVertexIndex(index);
|
||||||
|
}}
|
||||||
|
onDragEnd={(event: any) => handleVertexDragEnd(selectedMask, index, event)}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
|
||||||
{/* AI Prompts Point Regions */}
|
{/* AI Prompts Point Regions */}
|
||||||
{points.map((p, i) => (
|
{points.map((p, i) => (
|
||||||
<Group key={i} x={p.x} y={p.y}>
|
<Group key={i} x={p.x} y={p.y}>
|
||||||
@@ -313,6 +900,14 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
|||||||
|
|
||||||
{frameMasks.length > 0 && (
|
{frameMasks.length > 0 && (
|
||||||
<div className="absolute bottom-4 right-4 flex gap-2">
|
<div className="absolute bottom-4 right-4 flex gap-2">
|
||||||
|
{BOOLEAN_TOOLS.has(effectiveTool) && booleanSelectedMasks.length >= 2 && (
|
||||||
|
<button
|
||||||
|
onClick={handleBooleanOperation}
|
||||||
|
className="text-xs bg-emerald-500/10 hover:bg-emerald-500/20 text-emerald-300 border border-emerald-500/20 px-3 py-1.5 rounded transition-colors"
|
||||||
|
>
|
||||||
|
{effectiveTool === 'area_merge' ? '合并选中' : '从主区域去除'}
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
{activeClass && (
|
{activeClass && (
|
||||||
<button
|
<button
|
||||||
onClick={handleApplyActiveClass}
|
onClick={handleApplyActiveClass}
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
import { act, render, screen, waitFor } from '@testing-library/react';
|
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react';
|
||||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||||
import { Dashboard } from './Dashboard';
|
import { Dashboard } from './Dashboard';
|
||||||
|
|
||||||
const apiMock = vi.hoisted(() => ({
|
const apiMock = vi.hoisted(() => ({
|
||||||
getDashboardOverview: vi.fn(),
|
getDashboardOverview: vi.fn(),
|
||||||
|
cancelTask: vi.fn(),
|
||||||
|
retryTask: vi.fn(),
|
||||||
|
getTask: vi.fn(),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
const wsMock = vi.hoisted(() => {
|
const wsMock = vi.hoisted(() => {
|
||||||
@@ -31,6 +34,9 @@ vi.mock('../lib/websocket', () => ({
|
|||||||
|
|
||||||
vi.mock('../lib/api', () => ({
|
vi.mock('../lib/api', () => ({
|
||||||
getDashboardOverview: apiMock.getDashboardOverview,
|
getDashboardOverview: apiMock.getDashboardOverview,
|
||||||
|
cancelTask: apiMock.cancelTask,
|
||||||
|
retryTask: apiMock.retryTask,
|
||||||
|
getTask: apiMock.getTask,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
describe('Dashboard', () => {
|
describe('Dashboard', () => {
|
||||||
@@ -55,6 +61,8 @@ describe('Dashboard', () => {
|
|||||||
name: '真实项目.mp4',
|
name: '真实项目.mp4',
|
||||||
progress: 60,
|
progress: 60,
|
||||||
status: 'pending',
|
status: 'pending',
|
||||||
|
raw_status: 'running',
|
||||||
|
error: null,
|
||||||
frame_count: 10,
|
frame_count: 10,
|
||||||
updated_at: '2026-05-01T00:00:00Z',
|
updated_at: '2026-05-01T00:00:00Z',
|
||||||
},
|
},
|
||||||
@@ -112,4 +120,100 @@ describe('Dashboard', () => {
|
|||||||
await waitFor(() => expect(screen.getByText('Progress stream active')).toBeInTheDocument());
|
await waitFor(() => expect(screen.getByText('Progress stream active')).toBeInTheDocument());
|
||||||
expect(screen.getByText('解析完成: done.mp4')).toBeInTheDocument();
|
expect(screen.getByText('解析完成: done.mp4')).toBeInTheDocument();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('cancels, retries, and opens task failure details', async () => {
|
||||||
|
apiMock.getDashboardOverview.mockResolvedValueOnce({
|
||||||
|
summary: {
|
||||||
|
project_count: 1,
|
||||||
|
parsing_task_count: 1,
|
||||||
|
annotation_count: 0,
|
||||||
|
frame_count: 0,
|
||||||
|
template_count: 0,
|
||||||
|
system_load_percent: 5,
|
||||||
|
},
|
||||||
|
tasks: [
|
||||||
|
{
|
||||||
|
id: 'task-10',
|
||||||
|
task_id: 10,
|
||||||
|
project_id: 1,
|
||||||
|
name: 'running.mp4',
|
||||||
|
progress: 30,
|
||||||
|
status: '正在下载媒体文件',
|
||||||
|
raw_status: 'running',
|
||||||
|
error: null,
|
||||||
|
frame_count: 0,
|
||||||
|
updated_at: '2026-05-01T00:00:00Z',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'task-11',
|
||||||
|
task_id: 11,
|
||||||
|
project_id: 1,
|
||||||
|
name: 'failed.mp4',
|
||||||
|
progress: 100,
|
||||||
|
status: '解析失败',
|
||||||
|
raw_status: 'failed',
|
||||||
|
error: 'ffmpeg failed',
|
||||||
|
frame_count: 0,
|
||||||
|
updated_at: '2026-05-01T00:01:00Z',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
activity: [],
|
||||||
|
});
|
||||||
|
apiMock.cancelTask.mockResolvedValueOnce({
|
||||||
|
id: 10,
|
||||||
|
task_type: 'parse_video',
|
||||||
|
status: 'cancelled',
|
||||||
|
progress: 100,
|
||||||
|
message: '任务已取消',
|
||||||
|
project_id: 1,
|
||||||
|
error: 'Cancelled by user',
|
||||||
|
result: null,
|
||||||
|
payload: { source_type: 'video' },
|
||||||
|
created_at: 'created',
|
||||||
|
updated_at: 'updated',
|
||||||
|
});
|
||||||
|
apiMock.retryTask.mockResolvedValueOnce({
|
||||||
|
id: 12,
|
||||||
|
task_type: 'parse_video',
|
||||||
|
status: 'queued',
|
||||||
|
progress: 0,
|
||||||
|
message: '重试任务已入队(源任务 #11)',
|
||||||
|
project_id: 1,
|
||||||
|
error: null,
|
||||||
|
result: null,
|
||||||
|
payload: { source_type: 'video', retry_of: 11 },
|
||||||
|
created_at: 'created',
|
||||||
|
updated_at: 'updated',
|
||||||
|
});
|
||||||
|
apiMock.getTask.mockResolvedValueOnce({
|
||||||
|
id: 11,
|
||||||
|
task_type: 'parse_video',
|
||||||
|
status: 'failed',
|
||||||
|
progress: 100,
|
||||||
|
message: '解析失败',
|
||||||
|
project_id: 1,
|
||||||
|
celery_task_id: 'celery-11',
|
||||||
|
payload: { source_type: 'video' },
|
||||||
|
result: null,
|
||||||
|
error: 'ffmpeg failed',
|
||||||
|
created_at: 'created',
|
||||||
|
started_at: 'started',
|
||||||
|
finished_at: 'finished',
|
||||||
|
updated_at: 'updated',
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<Dashboard />);
|
||||||
|
|
||||||
|
await screen.findByText('running.mp4');
|
||||||
|
fireEvent.click(screen.getByRole('button', { name: '取消' }));
|
||||||
|
await waitFor(() => expect(apiMock.cancelTask).toHaveBeenCalledWith(10));
|
||||||
|
|
||||||
|
fireEvent.click(screen.getAllByRole('button', { name: '详情' })[1]);
|
||||||
|
await waitFor(() => expect(apiMock.getTask).toHaveBeenCalledWith(11));
|
||||||
|
expect(await screen.findByText('任务详情 #11')).toBeInTheDocument();
|
||||||
|
expect(screen.getByText('ffmpeg failed')).toBeInTheDocument();
|
||||||
|
|
||||||
|
fireEvent.click(screen.getAllByRole('button', { name: '重试' })[1]);
|
||||||
|
await waitFor(() => expect(apiMock.retryTask).toHaveBeenCalledWith(11));
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,8 +1,17 @@
|
|||||||
import React, { useState, useEffect } from 'react';
|
import React, { useState, useEffect } from 'react';
|
||||||
import { Activity, Clock, Folders, CheckCircle2, Loader2 } from 'lucide-react';
|
import { Activity, AlertTriangle, Clock, Folders, CheckCircle2, Info, Loader2, RotateCcw, XCircle } from 'lucide-react';
|
||||||
import { progressWS, type ProgressMessage } from '../lib/websocket';
|
import { progressWS, type ProgressMessage } from '../lib/websocket';
|
||||||
import { cn } from '../lib/utils';
|
import { cn } from '../lib/utils';
|
||||||
import { getDashboardOverview, type DashboardActivity, type DashboardOverview, type DashboardTask } from '../lib/api';
|
import {
|
||||||
|
cancelTask,
|
||||||
|
getDashboardOverview,
|
||||||
|
getTask,
|
||||||
|
retryTask,
|
||||||
|
type DashboardActivity,
|
||||||
|
type DashboardOverview,
|
||||||
|
type DashboardTask,
|
||||||
|
type ProcessingTask,
|
||||||
|
} from '../lib/api';
|
||||||
|
|
||||||
const emptySummary: DashboardOverview['summary'] = {
|
const emptySummary: DashboardOverview['summary'] = {
|
||||||
project_count: 0,
|
project_count: 0,
|
||||||
@@ -20,6 +29,29 @@ export function Dashboard() {
|
|||||||
const [activityLog, setActivityLog] = useState<DashboardActivity[]>([]);
|
const [activityLog, setActivityLog] = useState<DashboardActivity[]>([]);
|
||||||
const [isLoading, setIsLoading] = useState(true);
|
const [isLoading, setIsLoading] = useState(true);
|
||||||
const [loadError, setLoadError] = useState('');
|
const [loadError, setLoadError] = useState('');
|
||||||
|
const [selectedTask, setSelectedTask] = useState<ProcessingTask | null>(null);
|
||||||
|
const [taskActionMessage, setTaskActionMessage] = useState('');
|
||||||
|
const [busyTaskId, setBusyTaskId] = useState<string | null>(null);
|
||||||
|
|
||||||
|
const taskFromProcessingTask = (task: ProcessingTask, name = `任务 ${task.id}`): DashboardTask => ({
|
||||||
|
id: `task-${task.id}`,
|
||||||
|
task_id: task.id,
|
||||||
|
project_id: task.project_id ?? 0,
|
||||||
|
name,
|
||||||
|
progress: task.progress,
|
||||||
|
status: task.message || task.status,
|
||||||
|
raw_status: task.status,
|
||||||
|
error: task.error,
|
||||||
|
frame_count: Number(task.result?.frames_extracted || 0),
|
||||||
|
updated_at: task.updated_at,
|
||||||
|
});
|
||||||
|
|
||||||
|
const prependActivity = (message: string, project = '系统') => {
|
||||||
|
setActivityLog((prev) => [
|
||||||
|
{ id: `task-action-${Date.now()}`, kind: 'task', time: new Date().toISOString(), message, project },
|
||||||
|
...prev.slice(0, 9),
|
||||||
|
]);
|
||||||
|
};
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
let cancelled = false;
|
let cancelled = false;
|
||||||
@@ -90,6 +122,8 @@ export function Dashboard() {
|
|||||||
name: taskTitle(data),
|
name: taskTitle(data),
|
||||||
progress: data.progress ?? 0,
|
progress: data.progress ?? 0,
|
||||||
status: data.status ?? '处理中',
|
status: data.status ?? '处理中',
|
||||||
|
raw_status: 'running',
|
||||||
|
error: data.error,
|
||||||
frame_count: 0,
|
frame_count: 0,
|
||||||
updated_at: new Date().toISOString(),
|
updated_at: new Date().toISOString(),
|
||||||
},
|
},
|
||||||
@@ -100,7 +134,7 @@ export function Dashboard() {
|
|||||||
if (data.type === 'complete' && data.taskId) {
|
if (data.type === 'complete' && data.taskId) {
|
||||||
setTasks((prev) =>
|
setTasks((prev) =>
|
||||||
prev.map((t) =>
|
prev.map((t) =>
|
||||||
t.id === data.taskId ? { ...t, progress: 100, status: '已完成' } : t
|
t.id === data.taskId ? { ...t, progress: 100, status: '已完成', raw_status: 'success' } : t
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
setActivityLog((prev) => [
|
setActivityLog((prev) => [
|
||||||
@@ -109,10 +143,26 @@ export function Dashboard() {
|
|||||||
]);
|
]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (data.type === 'cancelled' && data.taskId) {
|
||||||
|
setTasks((prev) =>
|
||||||
|
prev.map((t) =>
|
||||||
|
t.id === data.taskId
|
||||||
|
? { ...t, progress: 100, status: data.message || '任务已取消', raw_status: 'cancelled', error: data.error }
|
||||||
|
: t
|
||||||
|
)
|
||||||
|
);
|
||||||
|
setActivityLog((prev) => [
|
||||||
|
{ id: `ws-cancelled-${Date.now()}`, kind: 'websocket', time: new Date().toISOString(), message: data.message || `任务已取消: ${taskTitle(data)}`, project: data.projectName || '系统' },
|
||||||
|
...prev.slice(0, 9),
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
if (data.type === 'error' && data.taskId) {
|
if (data.type === 'error' && data.taskId) {
|
||||||
setTasks((prev) =>
|
setTasks((prev) =>
|
||||||
prev.map((t) =>
|
prev.map((t) =>
|
||||||
t.id === data.taskId ? { ...t, progress: data.progress ?? t.progress, status: `错误: ${data.error || data.message || '未知错误'}` } : t
|
t.id === data.taskId
|
||||||
|
? { ...t, progress: data.progress ?? t.progress, status: `错误: ${data.error || data.message || '未知错误'}`, raw_status: 'failed', error: data.error }
|
||||||
|
: t
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
setActivityLog((prev) => [
|
setActivityLog((prev) => [
|
||||||
@@ -160,6 +210,65 @@ export function Dashboard() {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const taskRawStatus = (task: DashboardTask): string => task.raw_status || (
|
||||||
|
task.status.includes('取消') ? 'cancelled'
|
||||||
|
: task.status.includes('失败') || task.status.includes('错误') ? 'failed'
|
||||||
|
: task.progress >= 100 ? 'success'
|
||||||
|
: 'running'
|
||||||
|
);
|
||||||
|
|
||||||
|
const canCancel = (task: DashboardTask): boolean => ['queued', 'running'].includes(taskRawStatus(task)) && Boolean(task.task_id);
|
||||||
|
const canRetry = (task: DashboardTask): boolean => ['failed', 'cancelled'].includes(taskRawStatus(task)) && Boolean(task.task_id);
|
||||||
|
|
||||||
|
const handleCancelTask = async (task: DashboardTask) => {
|
||||||
|
if (!task.task_id) return;
|
||||||
|
setBusyTaskId(task.id);
|
||||||
|
setTaskActionMessage('');
|
||||||
|
try {
|
||||||
|
const updated = await cancelTask(task.task_id);
|
||||||
|
setTasks((prev) => prev.map((item) => (
|
||||||
|
item.id === task.id ? taskFromProcessingTask(updated, task.name) : item
|
||||||
|
)));
|
||||||
|
prependActivity(`任务已取消 #${updated.id}`, task.name);
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Cancel task failed:', err);
|
||||||
|
setTaskActionMessage('任务取消失败,请检查后端服务');
|
||||||
|
} finally {
|
||||||
|
setBusyTaskId(null);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleRetryTask = async (task: DashboardTask) => {
|
||||||
|
if (!task.task_id) return;
|
||||||
|
setBusyTaskId(task.id);
|
||||||
|
setTaskActionMessage('');
|
||||||
|
try {
|
||||||
|
const retried = await retryTask(task.task_id);
|
||||||
|
const dashboardTask = taskFromProcessingTask(retried, task.name);
|
||||||
|
setTasks((prev) => [dashboardTask, ...prev.filter((item) => item.id !== dashboardTask.id)]);
|
||||||
|
prependActivity(`重试任务已入队 #${retried.id}`, task.name);
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Retry task failed:', err);
|
||||||
|
setTaskActionMessage('任务重试失败,请检查后端服务');
|
||||||
|
} finally {
|
||||||
|
setBusyTaskId(null);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleOpenTaskDetail = async (task: DashboardTask) => {
|
||||||
|
if (!task.task_id) return;
|
||||||
|
setBusyTaskId(task.id);
|
||||||
|
setTaskActionMessage('');
|
||||||
|
try {
|
||||||
|
setSelectedTask(await getTask(task.task_id));
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Load task detail failed:', err);
|
||||||
|
setTaskActionMessage('失败详情加载失败');
|
||||||
|
} finally {
|
||||||
|
setBusyTaskId(null);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="p-8 w-full h-full overflow-y-auto bg-[#0a0a0a]">
|
<div className="p-8 w-full h-full overflow-y-auto bg-[#0a0a0a]">
|
||||||
<header className="mb-8">
|
<header className="mb-8">
|
||||||
@@ -177,6 +286,7 @@ export function Dashboard() {
|
|||||||
</div>
|
</div>
|
||||||
<p className="text-gray-400 text-sm mt-1">系统全局数据吞吐状态与所有接入项目进度实时洞察驾驶舱。</p>
|
<p className="text-gray-400 text-sm mt-1">系统全局数据吞吐状态与所有接入项目进度实时洞察驾驶舱。</p>
|
||||||
{loadError && <p className="text-red-400 text-xs mt-2">{loadError}</p>}
|
{loadError && <p className="text-red-400 text-xs mt-2">{loadError}</p>}
|
||||||
|
{taskActionMessage && <p className="text-amber-400 text-xs mt-2">{taskActionMessage}</p>}
|
||||||
</header>
|
</header>
|
||||||
|
|
||||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4 mb-8">
|
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4 mb-8">
|
||||||
@@ -213,16 +323,47 @@ export function Dashboard() {
|
|||||||
<div className="h-full bg-gradient-to-r from-cyan-600 to-cyan-400 rounded-full transition-all duration-500" style={{ width: `${task.progress}%` }} />
|
<div className="h-full bg-gradient-to-r from-cyan-600 to-cyan-400 rounded-full transition-all duration-500" style={{ width: `${task.progress}%` }} />
|
||||||
</div>
|
</div>
|
||||||
<div className="text-xs text-gray-500 flex items-center gap-2">
|
<div className="text-xs text-gray-500 flex items-center gap-2">
|
||||||
{task.status === '已完成' || task.progress >= 100 ? (
|
{taskRawStatus(task) === 'success' || task.status === '已完成' ? (
|
||||||
<CheckCircle2 size={12} className="text-emerald-400" />
|
<CheckCircle2 size={12} className="text-emerald-400" />
|
||||||
) : task.status.includes('错误') ? (
|
) : taskRawStatus(task) === 'failed' ? (
|
||||||
<span className="text-red-400">●</span>
|
<AlertTriangle size={12} className="text-red-400" />
|
||||||
|
) : taskRawStatus(task) === 'cancelled' ? (
|
||||||
|
<XCircle size={12} className="text-amber-400" />
|
||||||
) : (
|
) : (
|
||||||
<Loader2 size={12} className="text-cyan-400 animate-spin" />
|
<Loader2 size={12} className="text-cyan-400 animate-spin" />
|
||||||
)}
|
)}
|
||||||
{task.status}
|
{task.status}
|
||||||
<span className="text-gray-600">帧: {task.frame_count}</span>
|
<span className="text-gray-600">帧: {task.frame_count}</span>
|
||||||
</div>
|
</div>
|
||||||
|
<div className="mt-3 flex flex-wrap items-center gap-2">
|
||||||
|
{canCancel(task) && (
|
||||||
|
<button
|
||||||
|
onClick={() => handleCancelTask(task)}
|
||||||
|
disabled={busyTaskId === task.id}
|
||||||
|
className="inline-flex items-center gap-1 rounded border border-red-500/20 bg-red-500/10 px-2 py-1 text-[11px] text-red-300 hover:bg-red-500/20 disabled:opacity-40 disabled:cursor-not-allowed"
|
||||||
|
>
|
||||||
|
<XCircle size={12} /> 取消
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
{canRetry(task) && (
|
||||||
|
<button
|
||||||
|
onClick={() => handleRetryTask(task)}
|
||||||
|
disabled={busyTaskId === task.id}
|
||||||
|
className="inline-flex items-center gap-1 rounded border border-cyan-500/20 bg-cyan-500/10 px-2 py-1 text-[11px] text-cyan-300 hover:bg-cyan-500/20 disabled:opacity-40 disabled:cursor-not-allowed"
|
||||||
|
>
|
||||||
|
<RotateCcw size={12} /> 重试
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
{task.task_id && (
|
||||||
|
<button
|
||||||
|
onClick={() => handleOpenTaskDetail(task)}
|
||||||
|
disabled={busyTaskId === task.id}
|
||||||
|
className="inline-flex items-center gap-1 rounded border border-white/10 bg-white/5 px-2 py-1 text-[11px] text-gray-300 hover:bg-white/10 disabled:opacity-40 disabled:cursor-not-allowed"
|
||||||
|
>
|
||||||
|
<Info size={12} /> 详情
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
))}
|
))}
|
||||||
{!isLoading && tasks.length === 0 && (
|
{!isLoading && tasks.length === 0 && (
|
||||||
@@ -253,6 +394,46 @@ export function Dashboard() {
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{selectedTask && (
|
||||||
|
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/70 px-4">
|
||||||
|
<div className="w-full max-w-2xl rounded-lg border border-white/10 bg-[#111] p-5 shadow-2xl">
|
||||||
|
<div className="flex items-center justify-between gap-3 border-b border-white/10 pb-3">
|
||||||
|
<div>
|
||||||
|
<h3 className="text-sm font-semibold text-white">任务详情 #{selectedTask.id}</h3>
|
||||||
|
<p className="mt-1 text-xs text-gray-500">{selectedTask.message || selectedTask.status}</p>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
onClick={() => setSelectedTask(null)}
|
||||||
|
className="rounded border border-white/10 bg-white/5 px-2 py-1 text-xs text-gray-300 hover:bg-white/10"
|
||||||
|
>
|
||||||
|
关闭
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<div className="mt-4 grid grid-cols-2 gap-3 text-xs text-gray-400">
|
||||||
|
<div>状态: <span className="text-gray-200">{selectedTask.status}</span></div>
|
||||||
|
<div>进度: <span className="text-gray-200">{selectedTask.progress}%</span></div>
|
||||||
|
<div>项目 ID: <span className="text-gray-200">{selectedTask.project_id ?? '-'}</span></div>
|
||||||
|
<div>Celery ID: <span className="text-gray-200">{selectedTask.celery_task_id || '-'}</span></div>
|
||||||
|
<div>创建: <span className="text-gray-200">{selectedTask.created_at}</span></div>
|
||||||
|
<div>结束: <span className="text-gray-200">{selectedTask.finished_at || '-'}</span></div>
|
||||||
|
</div>
|
||||||
|
{selectedTask.error && (
|
||||||
|
<div className="mt-4 rounded border border-red-500/20 bg-red-500/10 p-3 text-xs text-red-200">
|
||||||
|
{selectedTask.error}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
<div className="mt-4 grid gap-3 md:grid-cols-2">
|
||||||
|
<pre className="max-h-48 overflow-auto rounded border border-white/10 bg-[#0a0a0a] p-3 text-[11px] text-gray-300">
|
||||||
|
{JSON.stringify(selectedTask.payload || {}, null, 2)}
|
||||||
|
</pre>
|
||||||
|
<pre className="max-h-48 overflow-auto rounded border border-white/10 bg-[#0a0a0a] p-3 text-[11px] text-gray-300">
|
||||||
|
{JSON.stringify(selectedTask.result || {}, null, 2)}
|
||||||
|
</pre>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,18 +3,31 @@ import { describe, expect, it, vi } from 'vitest';
|
|||||||
import { ToolsPalette } from './ToolsPalette';
|
import { ToolsPalette } from './ToolsPalette';
|
||||||
|
|
||||||
describe('ToolsPalette', () => {
|
describe('ToolsPalette', () => {
|
||||||
it('switches tools and exposes UI-only placeholder buttons', () => {
|
it('switches tools and dispatches undo/redo actions when available', () => {
|
||||||
const setActiveTool = vi.fn();
|
const setActiveTool = vi.fn();
|
||||||
|
const onUndo = vi.fn();
|
||||||
|
const onRedo = vi.fn();
|
||||||
|
|
||||||
render(<ToolsPalette activeTool="move" setActiveTool={setActiveTool} />);
|
render(
|
||||||
|
<ToolsPalette
|
||||||
|
activeTool="move"
|
||||||
|
setActiveTool={setActiveTool}
|
||||||
|
onUndo={onUndo}
|
||||||
|
onRedo={onRedo}
|
||||||
|
canUndo
|
||||||
|
canRedo
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
|
||||||
fireEvent.click(screen.getByTitle('创建多边形 (P)'));
|
fireEvent.click(screen.getByTitle('创建多边形 (P)'));
|
||||||
fireEvent.click(screen.getByTitle('正向选点 (SAM)'));
|
fireEvent.click(screen.getByTitle('正向选点 (SAM)'));
|
||||||
|
fireEvent.click(screen.getByTitle('撤销操作 (Ctrl+Z)'));
|
||||||
|
fireEvent.click(screen.getByTitle('重做操作 (Ctrl+Shift+Z)'));
|
||||||
|
|
||||||
expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon');
|
expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon');
|
||||||
expect(setActiveTool).toHaveBeenNthCalledWith(2, 'point_pos');
|
expect(setActiveTool).toHaveBeenNthCalledWith(2, 'point_pos');
|
||||||
expect(screen.getByTitle('撤销操作 (Ctrl+Z)')).toBeInTheDocument();
|
expect(onUndo).toHaveBeenCalled();
|
||||||
expect(screen.getByTitle('重做操作 (Ctrl+Shift+Z)')).toBeInTheDocument();
|
expect(onRedo).toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('switches to SAM trigger and calls the AI navigation hook', () => {
|
it('switches to SAM trigger and calls the AI navigation hook', () => {
|
||||||
|
|||||||
@@ -6,9 +6,21 @@ interface ToolsPaletteProps {
|
|||||||
activeTool: string;
|
activeTool: string;
|
||||||
setActiveTool: (tool: string) => void;
|
setActiveTool: (tool: string) => void;
|
||||||
onTriggerAI?: () => void;
|
onTriggerAI?: () => void;
|
||||||
|
onUndo?: () => void;
|
||||||
|
onRedo?: () => void;
|
||||||
|
canUndo?: boolean;
|
||||||
|
canRedo?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function ToolsPalette({ activeTool, setActiveTool, onTriggerAI }: ToolsPaletteProps) {
|
export function ToolsPalette({
|
||||||
|
activeTool,
|
||||||
|
setActiveTool,
|
||||||
|
onTriggerAI,
|
||||||
|
onUndo,
|
||||||
|
onRedo,
|
||||||
|
canUndo = false,
|
||||||
|
canRedo = false,
|
||||||
|
}: ToolsPaletteProps) {
|
||||||
const tools = [
|
const tools = [
|
||||||
{ id: 'move', icon: MousePointer2, label: '拖拽 / 选择 (V)' },
|
{ id: 'move', icon: MousePointer2, label: '拖拽 / 选择 (V)' },
|
||||||
{ id: 'create_polygon', icon: Hexagon, label: '创建多边形 (P)' },
|
{ id: 'create_polygon', icon: Hexagon, label: '创建多边形 (P)' },
|
||||||
@@ -91,10 +103,20 @@ export function ToolsPalette({ activeTool, setActiveTool, onTriggerAI }: ToolsPa
|
|||||||
|
|
||||||
<div className="w-full h-px bg-white/10 my-1" />
|
<div className="w-full h-px bg-white/10 my-1" />
|
||||||
|
|
||||||
<button className="w-10 h-10 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="撤销操作 (Ctrl+Z)">
|
<button
|
||||||
|
onClick={onUndo}
|
||||||
|
disabled={!canUndo}
|
||||||
|
className="w-10 h-10 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-500 disabled:cursor-not-allowed"
|
||||||
|
title="撤销操作 (Ctrl+Z)"
|
||||||
|
>
|
||||||
<Undo size={18} />
|
<Undo size={18} />
|
||||||
</button>
|
</button>
|
||||||
<button className="w-10 h-10 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="重做操作 (Ctrl+Shift+Z)">
|
<button
|
||||||
|
onClick={onRedo}
|
||||||
|
disabled={!canRedo}
|
||||||
|
className="w-10 h-10 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-500 disabled:cursor-not-allowed"
|
||||||
|
title="重做操作 (Ctrl+Shift+Z)"
|
||||||
|
>
|
||||||
<Redo size={18} />
|
<Redo size={18} />
|
||||||
</button>
|
</button>
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ const apiMock = vi.hoisted(() => ({
|
|||||||
updateAnnotation: vi.fn(),
|
updateAnnotation: vi.fn(),
|
||||||
deleteAnnotation: vi.fn(),
|
deleteAnnotation: vi.fn(),
|
||||||
exportCoco: vi.fn(),
|
exportCoco: vi.fn(),
|
||||||
|
exportMasks: vi.fn(),
|
||||||
|
importGtMask: vi.fn(),
|
||||||
annotationToMask: vi.fn(),
|
annotationToMask: vi.fn(),
|
||||||
buildAnnotationPayload: vi.fn(),
|
buildAnnotationPayload: vi.fn(),
|
||||||
getAiModelStatus: vi.fn(),
|
getAiModelStatus: vi.fn(),
|
||||||
@@ -29,6 +31,8 @@ vi.mock('../lib/api', () => ({
|
|||||||
updateAnnotation: apiMock.updateAnnotation,
|
updateAnnotation: apiMock.updateAnnotation,
|
||||||
deleteAnnotation: apiMock.deleteAnnotation,
|
deleteAnnotation: apiMock.deleteAnnotation,
|
||||||
exportCoco: apiMock.exportCoco,
|
exportCoco: apiMock.exportCoco,
|
||||||
|
exportMasks: apiMock.exportMasks,
|
||||||
|
importGtMask: apiMock.importGtMask,
|
||||||
annotationToMask: apiMock.annotationToMask,
|
annotationToMask: apiMock.annotationToMask,
|
||||||
buildAnnotationPayload: apiMock.buildAnnotationPayload,
|
buildAnnotationPayload: apiMock.buildAnnotationPayload,
|
||||||
getAiModelStatus: apiMock.getAiModelStatus,
|
getAiModelStatus: apiMock.getAiModelStatus,
|
||||||
@@ -256,4 +260,64 @@ describe('VideoWorkspace', () => {
|
|||||||
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled());
|
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled());
|
||||||
expect(apiMock.exportCoco).toHaveBeenCalledWith('1');
|
expect(apiMock.exportCoco).toHaveBeenCalledWith('1');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('auto-saves pending masks before exporting PNG masks', async () => {
|
||||||
|
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||||
|
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||||
|
]);
|
||||||
|
apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } });
|
||||||
|
apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 });
|
||||||
|
apiMock.exportMasks.mockResolvedValueOnce(new Blob(['zip'], { type: 'application/zip' }));
|
||||||
|
|
||||||
|
render(<VideoWorkspace />);
|
||||||
|
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
|
||||||
|
act(() => {
|
||||||
|
useStore.setState({
|
||||||
|
masks: [{
|
||||||
|
id: 'mask-1',
|
||||||
|
frameId: '10',
|
||||||
|
pathData: 'M 0 0 Z',
|
||||||
|
label: 'AI Mask',
|
||||||
|
color: '#06b6d4',
|
||||||
|
segmentation: [[0, 0, 10, 0, 10, 10]],
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
fireEvent.click(screen.getByRole('button', { name: '导出 PNG Mask ZIP' }));
|
||||||
|
|
||||||
|
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled());
|
||||||
|
expect(apiMock.exportMasks).toHaveBeenCalledWith('1');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('imports a GT mask for the current frame and hydrates saved annotations', async () => {
|
||||||
|
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||||
|
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||||
|
]);
|
||||||
|
apiMock.importGtMask.mockResolvedValueOnce([{ id: 88, frame_id: 10 }]);
|
||||||
|
apiMock.getProjectAnnotations
|
||||||
|
.mockResolvedValueOnce([])
|
||||||
|
.mockResolvedValueOnce([{ id: 88, frame_id: 10 }]);
|
||||||
|
apiMock.annotationToMask.mockReturnValueOnce({
|
||||||
|
id: 'annotation-88',
|
||||||
|
annotationId: '88',
|
||||||
|
frameId: '10',
|
||||||
|
saved: true,
|
||||||
|
pathData: 'M 0 0 Z',
|
||||||
|
label: 'GT Mask',
|
||||||
|
color: '#22c55e',
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<VideoWorkspace />);
|
||||||
|
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
|
||||||
|
|
||||||
|
const fileInput = document.querySelector('input[type="file"]') as HTMLInputElement;
|
||||||
|
const file = new File(['mask'], 'mask.png', { type: 'image/png' });
|
||||||
|
fireEvent.change(fileInput, { target: { files: [file] } });
|
||||||
|
|
||||||
|
await waitFor(() => expect(apiMock.importGtMask).toHaveBeenCalledWith(file, '1', '10'));
|
||||||
|
await waitFor(() => expect(useStore.getState().masks).toEqual([
|
||||||
|
expect.objectContaining({ id: 'annotation-88', label: 'GT Mask' }),
|
||||||
|
]));
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -5,10 +5,12 @@ import {
|
|||||||
buildAnnotationPayload,
|
buildAnnotationPayload,
|
||||||
deleteAnnotation,
|
deleteAnnotation,
|
||||||
exportCoco,
|
exportCoco,
|
||||||
|
exportMasks,
|
||||||
getProjectAnnotations,
|
getProjectAnnotations,
|
||||||
getProjectFrames,
|
getProjectFrames,
|
||||||
getTask,
|
getTask,
|
||||||
getTemplates,
|
getTemplates,
|
||||||
|
importGtMask,
|
||||||
parseMedia,
|
parseMedia,
|
||||||
saveAnnotation,
|
saveAnnotation,
|
||||||
updateAnnotation,
|
updateAnnotation,
|
||||||
@@ -25,18 +27,24 @@ function sleep(ms: number) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) {
|
export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) {
|
||||||
|
const gtMaskInputRef = React.useRef<HTMLInputElement>(null);
|
||||||
const activeTool = useStore((state) => state.activeTool);
|
const activeTool = useStore((state) => state.activeTool);
|
||||||
const setActiveTool = useStore((state) => state.setActiveTool);
|
const setActiveTool = useStore((state) => state.setActiveTool);
|
||||||
const currentProject = useStore((state) => state.currentProject);
|
const currentProject = useStore((state) => state.currentProject);
|
||||||
const frames = useStore((state) => state.frames);
|
const frames = useStore((state) => state.frames);
|
||||||
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
|
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
|
||||||
const masks = useStore((state) => state.masks);
|
const masks = useStore((state) => state.masks);
|
||||||
|
const maskHistory = useStore((state) => state.maskHistory);
|
||||||
|
const maskFuture = useStore((state) => state.maskFuture);
|
||||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||||
const setFrames = useStore((state) => state.setFrames);
|
const setFrames = useStore((state) => state.setFrames);
|
||||||
const setCurrentFrame = useStore((state) => state.setCurrentFrame);
|
const setCurrentFrame = useStore((state) => state.setCurrentFrame);
|
||||||
const setMasks = useStore((state) => state.setMasks);
|
const setMasks = useStore((state) => state.setMasks);
|
||||||
|
const undoMasks = useStore((state) => state.undoMasks);
|
||||||
|
const redoMasks = useStore((state) => state.redoMasks);
|
||||||
const [isSaving, setIsSaving] = useState(false);
|
const [isSaving, setIsSaving] = useState(false);
|
||||||
const [isExporting, setIsExporting] = useState(false);
|
const [isExporting, setIsExporting] = useState(false);
|
||||||
|
const [isImportingGt, setIsImportingGt] = useState(false);
|
||||||
const [statusMessage, setStatusMessage] = useState('');
|
const [statusMessage, setStatusMessage] = useState('');
|
||||||
|
|
||||||
const hydrateSavedAnnotations = useCallback(async (projectId: string, projectFrames: Frame[]) => {
|
const hydrateSavedAnnotations = useCallback(async (projectId: string, projectFrames: Frame[]) => {
|
||||||
@@ -216,6 +224,18 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
|||||||
}
|
}
|
||||||
}, [currentFrame, masks, setMasks]);
|
}, [currentFrame, masks, setMasks]);
|
||||||
|
|
||||||
|
const handleDeleteMaskAnnotations = useCallback(async (annotationIds: string[]) => {
|
||||||
|
if (annotationIds.length === 0) return;
|
||||||
|
try {
|
||||||
|
await Promise.all(annotationIds.map((annotationId) => deleteAnnotation(annotationId)));
|
||||||
|
setStatusMessage(`已删除 ${annotationIds.length} 个被合并标注`);
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Delete merged annotations failed:', err);
|
||||||
|
setStatusMessage('合并后删除原标注失败,请检查后端服务');
|
||||||
|
throw err;
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
const handleSave = async () => {
|
const handleSave = async () => {
|
||||||
try {
|
try {
|
||||||
await savePendingAnnotations();
|
await savePendingAnnotations();
|
||||||
@@ -248,6 +268,52 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const downloadBlob = (blob: Blob, filename: string) => {
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const link = document.createElement('a');
|
||||||
|
link.href = url;
|
||||||
|
link.download = filename;
|
||||||
|
document.body.appendChild(link);
|
||||||
|
link.click();
|
||||||
|
link.remove();
|
||||||
|
URL.revokeObjectURL(url);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleExportMasks = async () => {
|
||||||
|
if (!currentProject?.id) return;
|
||||||
|
setIsExporting(true);
|
||||||
|
setStatusMessage('正在准备导出语义 Mask ZIP...');
|
||||||
|
try {
|
||||||
|
await savePendingAnnotations({ silent: true });
|
||||||
|
const blob = await exportMasks(currentProject.id);
|
||||||
|
downloadBlob(blob, `project_${currentProject.id}_masks.zip`);
|
||||||
|
setStatusMessage('PNG Mask ZIP 已导出');
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Mask export failed:', err);
|
||||||
|
setStatusMessage('Mask 导出失败,请检查后端服务');
|
||||||
|
} finally {
|
||||||
|
setIsExporting(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleImportGtMask = async (event: React.ChangeEvent<HTMLInputElement>) => {
|
||||||
|
const file = event.target.files?.[0];
|
||||||
|
if (!file || !currentProject?.id || !currentFrame?.id) return;
|
||||||
|
setIsImportingGt(true);
|
||||||
|
setStatusMessage('正在导入 GT Mask...');
|
||||||
|
try {
|
||||||
|
const imported = await importGtMask(file, currentProject.id, currentFrame.id);
|
||||||
|
await hydrateSavedAnnotations(currentProject.id, frames);
|
||||||
|
setStatusMessage(`已导入 ${imported.length} 个 GT 区域`);
|
||||||
|
} catch (err) {
|
||||||
|
console.error('GT mask import failed:', err);
|
||||||
|
setStatusMessage('GT Mask 导入失败,请检查文件或后端服务');
|
||||||
|
} finally {
|
||||||
|
setIsImportingGt(false);
|
||||||
|
event.target.value = '';
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="w-full h-full flex flex-col bg-[#0a0a0a]">
|
<div className="w-full h-full flex flex-col bg-[#0a0a0a]">
|
||||||
{/* Top Header / Status bar */}
|
{/* Top Header / Status bar */}
|
||||||
@@ -264,6 +330,27 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
|||||||
</span>
|
</span>
|
||||||
)}
|
)}
|
||||||
<ModelStatusBadge />
|
<ModelStatusBadge />
|
||||||
|
<input
|
||||||
|
ref={gtMaskInputRef}
|
||||||
|
type="file"
|
||||||
|
accept="image/png,image/jpeg,image/bmp,image/tiff"
|
||||||
|
className="hidden"
|
||||||
|
onChange={handleImportGtMask}
|
||||||
|
/>
|
||||||
|
<button
|
||||||
|
onClick={() => gtMaskInputRef.current?.click()}
|
||||||
|
disabled={!currentProject?.id || !currentFrame?.id || isImportingGt || isSaving || isExporting}
|
||||||
|
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
|
||||||
|
>
|
||||||
|
{isImportingGt ? '导入中...' : '导入 GT Mask'}
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
onClick={handleExportMasks}
|
||||||
|
disabled={!currentProject?.id || isExporting || isSaving}
|
||||||
|
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
|
||||||
|
>
|
||||||
|
{isExporting ? '导出中...' : '导出 PNG Mask ZIP'}
|
||||||
|
</button>
|
||||||
<button
|
<button
|
||||||
onClick={handleExport}
|
onClick={handleExport}
|
||||||
disabled={!currentProject?.id || isExporting || isSaving}
|
disabled={!currentProject?.id || isExporting || isSaving}
|
||||||
@@ -283,11 +370,24 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
|||||||
|
|
||||||
{/* Main Workspace Area */}
|
{/* Main Workspace Area */}
|
||||||
<div className="flex-1 flex overflow-hidden">
|
<div className="flex-1 flex overflow-hidden">
|
||||||
<ToolsPalette activeTool={activeTool} setActiveTool={setActiveTool} onTriggerAI={onNavigateToAI} />
|
<ToolsPalette
|
||||||
|
activeTool={activeTool}
|
||||||
|
setActiveTool={setActiveTool}
|
||||||
|
onTriggerAI={onNavigateToAI}
|
||||||
|
onUndo={undoMasks}
|
||||||
|
onRedo={redoMasks}
|
||||||
|
canUndo={maskHistory.length > 0}
|
||||||
|
canRedo={maskFuture.length > 0}
|
||||||
|
/>
|
||||||
|
|
||||||
<div className="flex-1 relative flex items-center justify-center p-8 bg-[#151515] overflow-hidden">
|
<div className="flex-1 relative flex items-center justify-center p-8 bg-[#151515] overflow-hidden">
|
||||||
<div className="relative w-full h-full bg-[#1e1e1e] border border-white/5 shadow-2xl rounded-sm">
|
<div className="relative w-full h-full bg-[#1e1e1e] border border-white/5 shadow-2xl rounded-sm">
|
||||||
<CanvasArea activeTool={activeTool} frame={currentFrame} onClearMasks={handleClearCurrentFrameMasks} />
|
<CanvasArea
|
||||||
|
activeTool={activeTool}
|
||||||
|
frame={currentFrame}
|
||||||
|
onClearMasks={handleClearCurrentFrameMasks}
|
||||||
|
onDeleteMaskAnnotations={handleDeleteMaskAnnotations}
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|||||||
@@ -101,6 +101,17 @@ describe('api client contracts', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('exports PNG masks from the backend route shape', async () => {
|
||||||
|
const { exportMasks } = await import('./api');
|
||||||
|
const blob = new Blob(['zip'], { type: 'application/zip' });
|
||||||
|
axiosMock.client.get.mockResolvedValueOnce({ data: blob });
|
||||||
|
|
||||||
|
await expect(exportMasks('9')).resolves.toBe(blob);
|
||||||
|
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/export/9/masks', {
|
||||||
|
responseType: 'blob',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
it('loads dashboard overview from the backend summary endpoint', async () => {
|
it('loads dashboard overview from the backend summary endpoint', async () => {
|
||||||
const { getDashboardOverview } = await import('./api');
|
const { getDashboardOverview } = await import('./api');
|
||||||
const overview = {
|
const overview = {
|
||||||
@@ -125,8 +136,8 @@ describe('api client contracts', () => {
|
|||||||
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/dashboard/overview');
|
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/dashboard/overview');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('queues media parsing and reads processing task status', async () => {
|
it('queues media parsing and manages processing task lifecycle', async () => {
|
||||||
const { getTask, parseMedia } = await import('./api');
|
const { cancelTask, getTask, parseMedia, retryTask } = await import('./api');
|
||||||
const task = {
|
const task = {
|
||||||
id: 12,
|
id: 12,
|
||||||
task_type: 'parse_video',
|
task_type: 'parse_video',
|
||||||
@@ -145,6 +156,8 @@ describe('api client contracts', () => {
|
|||||||
};
|
};
|
||||||
axiosMock.client.post.mockResolvedValueOnce({ data: task });
|
axiosMock.client.post.mockResolvedValueOnce({ data: task });
|
||||||
axiosMock.client.get.mockResolvedValueOnce({ data: { ...task, status: 'success', progress: 100 } });
|
axiosMock.client.get.mockResolvedValueOnce({ data: { ...task, status: 'success', progress: 100 } });
|
||||||
|
axiosMock.client.post.mockResolvedValueOnce({ data: { ...task, status: 'cancelled', progress: 100 } });
|
||||||
|
axiosMock.client.post.mockResolvedValueOnce({ data: { ...task, id: 13, status: 'queued', progress: 0 } });
|
||||||
|
|
||||||
await expect(parseMedia('9')).resolves.toEqual(task);
|
await expect(parseMedia('9')).resolves.toEqual(task);
|
||||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/media/parse', null, {
|
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/media/parse', null, {
|
||||||
@@ -153,6 +166,12 @@ describe('api client contracts', () => {
|
|||||||
|
|
||||||
await expect(getTask(12)).resolves.toEqual(expect.objectContaining({ status: 'success', progress: 100 }));
|
await expect(getTask(12)).resolves.toEqual(expect.objectContaining({ status: 'success', progress: 100 }));
|
||||||
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/tasks/12');
|
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/tasks/12');
|
||||||
|
|
||||||
|
await expect(cancelTask(12)).resolves.toEqual(expect.objectContaining({ status: 'cancelled' }));
|
||||||
|
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/tasks/12/cancel');
|
||||||
|
|
||||||
|
await expect(retryTask(12)).resolves.toEqual(expect.objectContaining({ id: 13, status: 'queued' }));
|
||||||
|
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/tasks/12/retry');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('lists, saves, updates, and deletes annotations with the backend annotation contract', async () => {
|
it('lists, saves, updates, and deletes annotations with the backend annotation contract', async () => {
|
||||||
@@ -204,6 +223,25 @@ describe('api client contracts', () => {
|
|||||||
expect(axiosMock.client.delete).toHaveBeenCalledWith('/api/ai/annotations/1');
|
expect(axiosMock.client.delete).toHaveBeenCalledWith('/api/ai/annotations/1');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('imports GT masks through multipart form data', async () => {
|
||||||
|
const { importGtMask } = await import('./api');
|
||||||
|
const file = new File(['mask'], 'mask.png', { type: 'image/png' });
|
||||||
|
const saved = [{ id: 1, project_id: 9, frame_id: 5, template_id: null, mask_data: null, points: null, bbox: null }];
|
||||||
|
axiosMock.client.post.mockResolvedValueOnce({ data: saved });
|
||||||
|
|
||||||
|
await expect(importGtMask(file, '9', '5', '2')).resolves.toEqual(saved);
|
||||||
|
expect(axiosMock.client.post).toHaveBeenCalledWith(
|
||||||
|
'/api/ai/import-gt-mask',
|
||||||
|
expect.any(FormData),
|
||||||
|
{ headers: { 'Content-Type': 'multipart/form-data' } },
|
||||||
|
);
|
||||||
|
const form = axiosMock.client.post.mock.calls.at(-1)?.[1] as FormData;
|
||||||
|
expect(form.get('file')).toBe(file);
|
||||||
|
expect(form.get('project_id')).toBe('9');
|
||||||
|
expect(form.get('frame_id')).toBe('5');
|
||||||
|
expect(form.get('template_id')).toBe('2');
|
||||||
|
});
|
||||||
|
|
||||||
it('builds annotation payloads from frontend masks and restores saved annotations to masks', async () => {
|
it('builds annotation payloads from frontend masks and restores saved annotations to masks', async () => {
|
||||||
const { annotationToMask, buildAnnotationPayload } = await import('./api');
|
const { annotationToMask, buildAnnotationPayload } = await import('./api');
|
||||||
const frame = { id: '5', projectId: '9', index: 0, url: '/frame.jpg', width: 100, height: 50 };
|
const frame = { id: '5', projectId: '9', index: 0, url: '/frame.jpg', width: 100, height: 50 };
|
||||||
@@ -244,7 +282,7 @@ describe('api client contracts', () => {
|
|||||||
color: '#06b6d4',
|
color: '#06b6d4',
|
||||||
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||||
},
|
},
|
||||||
points: null,
|
points: [[0.5, 0.5]],
|
||||||
bbox: null,
|
bbox: null,
|
||||||
created_at: 'created',
|
created_at: 'created',
|
||||||
updated_at: 'updated',
|
updated_at: 'updated',
|
||||||
@@ -261,10 +299,28 @@ describe('api client contracts', () => {
|
|||||||
saveStatus: 'saved',
|
saveStatus: 'saved',
|
||||||
saved: true,
|
saved: true,
|
||||||
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||||
|
points: [[50, 25]],
|
||||||
bbox: [10, 10, 80, 30],
|
bbox: [10, 10, 80, 30],
|
||||||
}));
|
}));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('preserves editable point regions in annotation payloads', async () => {
|
||||||
|
const { buildAnnotationPayload } = await import('./api');
|
||||||
|
const frame = { id: '5', projectId: '9', index: 0, url: '/frame.jpg', width: 100, height: 50 };
|
||||||
|
|
||||||
|
expect(buildAnnotationPayload('9', {
|
||||||
|
id: 'm1',
|
||||||
|
frameId: '5',
|
||||||
|
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||||
|
label: 'GT Mask',
|
||||||
|
color: '#22c55e',
|
||||||
|
segmentation: [[10, 10, 90, 10, 90, 40]],
|
||||||
|
points: [[50, 25]],
|
||||||
|
}, frame)).toEqual(expect.objectContaining({
|
||||||
|
points: [[0.5, 0.5]],
|
||||||
|
}));
|
||||||
|
});
|
||||||
|
|
||||||
it('normalizes positive and negative point prompts for AI prediction', async () => {
|
it('normalizes positive and negative point prompts for AI prediction', async () => {
|
||||||
const { predictMask } = await import('./api');
|
const { predictMask } = await import('./api');
|
||||||
axiosMock.client.post.mockResolvedValueOnce({
|
axiosMock.client.post.mockResolvedValueOnce({
|
||||||
@@ -341,6 +397,38 @@ describe('api client contracts', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('passes AI post-processing options to prediction endpoint', async () => {
|
||||||
|
const { predictMask } = await import('./api');
|
||||||
|
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
|
||||||
|
|
||||||
|
await predictMask({
|
||||||
|
imageId: '7',
|
||||||
|
imageWidth: 640,
|
||||||
|
imageHeight: 360,
|
||||||
|
points: [{ x: 320, y: 180, type: 'pos' }],
|
||||||
|
options: {
|
||||||
|
crop_to_prompt: true,
|
||||||
|
auto_filter_background: true,
|
||||||
|
min_score: 0.05,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
|
||||||
|
image_id: 7,
|
||||||
|
prompt_type: 'point',
|
||||||
|
prompt_data: {
|
||||||
|
points: [[0.5, 0.5]],
|
||||||
|
labels: [1],
|
||||||
|
},
|
||||||
|
model: 'sam2',
|
||||||
|
options: {
|
||||||
|
crop_to_prompt: true,
|
||||||
|
auto_filter_background: true,
|
||||||
|
min_score: 0.05,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
it('loads AI model and GPU runtime status', async () => {
|
it('loads AI model and GPU runtime status', async () => {
|
||||||
const { getAiModelStatus } = await import('./api');
|
const { getAiModelStatus } = await import('./api');
|
||||||
const status = {
|
const status = {
|
||||||
|
|||||||
@@ -197,6 +197,16 @@ export async function getTask(taskId: string | number): Promise<ProcessingTask>
|
|||||||
return response.data;
|
return response.data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function cancelTask(taskId: string | number): Promise<ProcessingTask> {
|
||||||
|
const response = await apiClient.post(`/api/tasks/${taskId}/cancel`);
|
||||||
|
return response.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function retryTask(taskId: string | number): Promise<ProcessingTask> {
|
||||||
|
const response = await apiClient.post(`/api/tasks/${taskId}/retry`);
|
||||||
|
return response.data;
|
||||||
|
}
|
||||||
|
|
||||||
interface PredictMaskPayload {
|
interface PredictMaskPayload {
|
||||||
imageId: string;
|
imageId: string;
|
||||||
imageWidth: number;
|
imageWidth: number;
|
||||||
@@ -205,6 +215,12 @@ interface PredictMaskPayload {
|
|||||||
points?: { x: number; y: number; type: 'pos' | 'neg' }[];
|
points?: { x: number; y: number; type: 'pos' | 'neg' }[];
|
||||||
box?: { x1: number; y1: number; x2: number; y2: number };
|
box?: { x1: number; y1: number; x2: number; y2: number };
|
||||||
text?: string;
|
text?: string;
|
||||||
|
options?: {
|
||||||
|
crop_to_prompt?: boolean;
|
||||||
|
auto_filter_background?: boolean;
|
||||||
|
min_score?: number;
|
||||||
|
crop_margin?: number;
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
interface PredictMaskResult {
|
interface PredictMaskResult {
|
||||||
@@ -234,6 +250,8 @@ export interface AiModelStatus {
|
|||||||
python_ok: boolean;
|
python_ok: boolean;
|
||||||
torch_ok: boolean;
|
torch_ok: boolean;
|
||||||
cuda_required: boolean;
|
cuda_required: boolean;
|
||||||
|
external_available?: boolean;
|
||||||
|
external_python?: string | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface AiRuntimeStatus {
|
export interface AiRuntimeStatus {
|
||||||
@@ -301,6 +319,8 @@ export interface DashboardTask {
|
|||||||
name: string;
|
name: string;
|
||||||
progress: number;
|
progress: number;
|
||||||
status: string;
|
status: string;
|
||||||
|
raw_status?: string;
|
||||||
|
error?: string | null;
|
||||||
frame_count: number;
|
frame_count: number;
|
||||||
updated_at: string | null;
|
updated_at: string | null;
|
||||||
}
|
}
|
||||||
@@ -397,7 +417,7 @@ export function buildAnnotationPayload(
|
|||||||
}
|
}
|
||||||
: undefined;
|
: undefined;
|
||||||
|
|
||||||
return {
|
const payload: SaveAnnotationPayload = {
|
||||||
project_id: Number(projectId),
|
project_id: Number(projectId),
|
||||||
frame_id: Number(frame.id),
|
frame_id: Number(frame.id),
|
||||||
template_id: effectiveTemplateId ? Number(effectiveTemplateId) : undefined,
|
template_id: effectiveTemplateId ? Number(effectiveTemplateId) : undefined,
|
||||||
@@ -416,6 +436,15 @@ export function buildAnnotationPayload(
|
|||||||
]
|
]
|
||||||
: undefined,
|
: undefined,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (mask.points) {
|
||||||
|
payload.points = mask.points.map(([x, y]) => [
|
||||||
|
clamp01(x / Math.max(frame.width, 1)),
|
||||||
|
clamp01(y / Math.max(frame.height, 1)),
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return payload;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mask | null {
|
export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mask | null {
|
||||||
@@ -438,6 +467,7 @@ export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mas
|
|||||||
label: classMetadata?.name || annotation.mask_data?.label || `Annotation ${annotation.id}`,
|
label: classMetadata?.name || annotation.mask_data?.label || `Annotation ${annotation.id}`,
|
||||||
color: classMetadata?.color || annotation.mask_data?.color || '#06b6d4',
|
color: classMetadata?.color || annotation.mask_data?.color || '#06b6d4',
|
||||||
segmentation: polygons.map((polygon) => polygon.flatMap(([x, y]) => [x * frame.width, y * frame.height])),
|
segmentation: polygons.map((polygon) => polygon.flatMap(([x, y]) => [x * frame.width, y * frame.height])),
|
||||||
|
points: annotation.points?.map(([x, y]) => [x * frame.width, y * frame.height]),
|
||||||
bbox,
|
bbox,
|
||||||
area: bbox[2] * bbox[3],
|
area: bbox[2] * bbox[3],
|
||||||
};
|
};
|
||||||
@@ -471,6 +501,7 @@ export async function predictMask(payload: PredictMaskPayload): Promise<PredictM
|
|||||||
prompt_type,
|
prompt_type,
|
||||||
prompt_data,
|
prompt_data,
|
||||||
model: payload.model || 'sam2',
|
model: payload.model || 'sam2',
|
||||||
|
...(payload.options ? { options: payload.options } : {}),
|
||||||
});
|
});
|
||||||
|
|
||||||
const polygons: number[][][] = response.data.polygons || [];
|
const polygons: number[][][] = response.data.polygons || [];
|
||||||
@@ -523,6 +554,23 @@ export async function deleteAnnotation(annotationId: string): Promise<void> {
|
|||||||
await apiClient.delete(`/api/ai/annotations/${annotationId}`);
|
await apiClient.delete(`/api/ai/annotations/${annotationId}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function importGtMask(
|
||||||
|
file: File,
|
||||||
|
projectId: string,
|
||||||
|
frameId: string,
|
||||||
|
templateId?: string | null,
|
||||||
|
): Promise<SavedAnnotation[]> {
|
||||||
|
const formData = new FormData();
|
||||||
|
formData.append('file', file);
|
||||||
|
formData.append('project_id', projectId);
|
||||||
|
formData.append('frame_id', frameId);
|
||||||
|
if (templateId) formData.append('template_id', templateId);
|
||||||
|
const response = await apiClient.post('/api/ai/import-gt-mask', formData, {
|
||||||
|
headers: { 'Content-Type': 'multipart/form-data' },
|
||||||
|
});
|
||||||
|
return response.data;
|
||||||
|
}
|
||||||
|
|
||||||
export async function getDashboardOverview(): Promise<DashboardOverview> {
|
export async function getDashboardOverview(): Promise<DashboardOverview> {
|
||||||
const response = await apiClient.get('/api/dashboard/overview');
|
const response = await apiClient.get('/api/dashboard/overview');
|
||||||
return response.data;
|
return response.data;
|
||||||
@@ -536,4 +584,11 @@ export async function exportCoco(projectId: string): Promise<Blob> {
|
|||||||
return response.data;
|
return response.data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function exportMasks(projectId: string): Promise<Blob> {
|
||||||
|
const response = await apiClient.get(`/api/export/${projectId}/masks`, {
|
||||||
|
responseType: 'blob',
|
||||||
|
});
|
||||||
|
return response.data;
|
||||||
|
}
|
||||||
|
|
||||||
export default apiClient;
|
export default apiClient;
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import { WS_PROGRESS_URL } from './config';
|
|||||||
type ProgressCallback = (data: ProgressMessage) => void;
|
type ProgressCallback = (data: ProgressMessage) => void;
|
||||||
|
|
||||||
interface ProgressMessage {
|
interface ProgressMessage {
|
||||||
type: 'progress' | 'status' | 'error' | 'complete';
|
type: 'progress' | 'status' | 'error' | 'complete' | 'cancelled';
|
||||||
taskId?: string;
|
taskId?: string;
|
||||||
task_id?: number;
|
task_id?: number;
|
||||||
project_id?: number;
|
project_id?: number;
|
||||||
|
|||||||
@@ -53,4 +53,15 @@ describe('useStore', () => {
|
|||||||
expect(useStore.getState().masks).toEqual([]);
|
expect(useStore.getState().masks).toEqual([]);
|
||||||
expect(useStore.getState().templates).toEqual([]);
|
expect(useStore.getState().templates).toEqual([]);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('keeps undo and redo history for mask edits', () => {
|
||||||
|
useStore.getState().addMask({ id: 'm1', frameId: 'f1', pathData: 'M 0 0 Z', label: 'mask 1', color: '#fff' });
|
||||||
|
useStore.getState().addMask({ id: 'm2', frameId: 'f1', pathData: 'M 1 1 Z', label: 'mask 2', color: '#000' });
|
||||||
|
|
||||||
|
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['m1', 'm2']);
|
||||||
|
useStore.getState().undoMasks();
|
||||||
|
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['m1']);
|
||||||
|
useStore.getState().redoMasks();
|
||||||
|
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['m1', 'm2']);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -56,8 +56,10 @@ export interface Mask {
|
|||||||
color: string;
|
color: string;
|
||||||
opacity?: number;
|
opacity?: number;
|
||||||
segmentation?: number[][];
|
segmentation?: number[][];
|
||||||
|
points?: number[][];
|
||||||
bbox?: [number, number, number, number];
|
bbox?: [number, number, number, number];
|
||||||
area?: number;
|
area?: number;
|
||||||
|
metadata?: Record<string, unknown>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface Template {
|
export interface Template {
|
||||||
@@ -110,6 +112,8 @@ export interface AppState {
|
|||||||
currentFrameIndex: number;
|
currentFrameIndex: number;
|
||||||
annotations: Annotation[];
|
annotations: Annotation[];
|
||||||
masks: Mask[];
|
masks: Mask[];
|
||||||
|
maskHistory: Mask[][];
|
||||||
|
maskFuture: Mask[][];
|
||||||
setActiveModule: (module: string) => void;
|
setActiveModule: (module: string) => void;
|
||||||
setActiveTool: (tool: string) => void;
|
setActiveTool: (tool: string) => void;
|
||||||
setAiModel: (model: AiModelId) => void;
|
setAiModel: (model: AiModelId) => void;
|
||||||
@@ -120,6 +124,8 @@ export interface AppState {
|
|||||||
updateMask: (id: string, updates: Partial<Mask>) => void;
|
updateMask: (id: string, updates: Partial<Mask>) => void;
|
||||||
setMasks: (masks: Mask[]) => void;
|
setMasks: (masks: Mask[]) => void;
|
||||||
clearMasks: () => void;
|
clearMasks: () => void;
|
||||||
|
undoMasks: () => void;
|
||||||
|
redoMasks: () => void;
|
||||||
removeAnnotation: (id: string) => void;
|
removeAnnotation: (id: string) => void;
|
||||||
|
|
||||||
// Templates
|
// Templates
|
||||||
@@ -161,6 +167,8 @@ export const useStore = create<AppState>((set) => ({
|
|||||||
frames: [],
|
frames: [],
|
||||||
annotations: [],
|
annotations: [],
|
||||||
masks: [],
|
masks: [],
|
||||||
|
maskHistory: [],
|
||||||
|
maskFuture: [],
|
||||||
activeTemplateId: null,
|
activeTemplateId: null,
|
||||||
activeClassId: null,
|
activeClassId: null,
|
||||||
activeClass: null,
|
activeClass: null,
|
||||||
@@ -187,6 +195,8 @@ export const useStore = create<AppState>((set) => ({
|
|||||||
currentFrameIndex: 0,
|
currentFrameIndex: 0,
|
||||||
annotations: [],
|
annotations: [],
|
||||||
masks: [],
|
masks: [],
|
||||||
|
maskHistory: [],
|
||||||
|
maskFuture: [],
|
||||||
setActiveModule: (activeModule: string) => set({ activeModule }),
|
setActiveModule: (activeModule: string) => set({ activeModule }),
|
||||||
setActiveTool: (activeTool: string) => set({ activeTool }),
|
setActiveTool: (activeTool: string) => set({ activeTool }),
|
||||||
setAiModel: (aiModel: AiModelId) => set({ aiModel }),
|
setAiModel: (aiModel: AiModelId) => set({ aiModel }),
|
||||||
@@ -195,13 +205,54 @@ export const useStore = create<AppState>((set) => ({
|
|||||||
addAnnotation: (annotation: Annotation) =>
|
addAnnotation: (annotation: Annotation) =>
|
||||||
set((state) => ({ annotations: [...state.annotations, annotation] })),
|
set((state) => ({ annotations: [...state.annotations, annotation] })),
|
||||||
addMask: (mask: Mask) =>
|
addMask: (mask: Mask) =>
|
||||||
set((state) => ({ masks: [...state.masks, mask] })),
|
set((state) => ({
|
||||||
|
masks: [...state.masks, mask],
|
||||||
|
maskHistory: [...state.maskHistory, state.masks],
|
||||||
|
maskFuture: [],
|
||||||
|
})),
|
||||||
updateMask: (id: string, updates: Partial<Mask>) =>
|
updateMask: (id: string, updates: Partial<Mask>) =>
|
||||||
set((state) => ({
|
set((state) => ({
|
||||||
masks: state.masks.map((mask) => (mask.id === id ? { ...mask, ...updates } : mask)),
|
masks: state.masks.map((mask) => (mask.id === id ? { ...mask, ...updates } : mask)),
|
||||||
|
maskHistory: [...state.maskHistory, state.masks],
|
||||||
|
maskFuture: [],
|
||||||
})),
|
})),
|
||||||
setMasks: (masks: Mask[]) => set({ masks }),
|
setMasks: (masks: Mask[]) =>
|
||||||
clearMasks: () => set({ masks: [] }),
|
set((state) => {
|
||||||
|
const isInitialHydration = state.masks.length === 0
|
||||||
|
&& state.maskHistory.length === 0
|
||||||
|
&& state.maskFuture.length === 0;
|
||||||
|
return {
|
||||||
|
masks,
|
||||||
|
maskHistory: isInitialHydration ? [] : [...state.maskHistory, state.masks],
|
||||||
|
maskFuture: [],
|
||||||
|
};
|
||||||
|
}),
|
||||||
|
clearMasks: () =>
|
||||||
|
set((state) => ({
|
||||||
|
masks: [],
|
||||||
|
maskHistory: [...state.maskHistory, state.masks],
|
||||||
|
maskFuture: [],
|
||||||
|
})),
|
||||||
|
undoMasks: () =>
|
||||||
|
set((state) => {
|
||||||
|
if (state.maskHistory.length === 0) return state;
|
||||||
|
const previous = state.maskHistory[state.maskHistory.length - 1];
|
||||||
|
return {
|
||||||
|
masks: previous,
|
||||||
|
maskHistory: state.maskHistory.slice(0, -1),
|
||||||
|
maskFuture: [state.masks, ...state.maskFuture],
|
||||||
|
};
|
||||||
|
}),
|
||||||
|
redoMasks: () =>
|
||||||
|
set((state) => {
|
||||||
|
if (state.maskFuture.length === 0) return state;
|
||||||
|
const [next, ...rest] = state.maskFuture;
|
||||||
|
return {
|
||||||
|
masks: next,
|
||||||
|
maskHistory: [...state.maskHistory, state.masks],
|
||||||
|
maskFuture: rest,
|
||||||
|
};
|
||||||
|
}),
|
||||||
removeAnnotation: (id: string) =>
|
removeAnnotation: (id: string) =>
|
||||||
set((state) => ({
|
set((state) => ({
|
||||||
annotations: state.annotations.filter((a) => a.id !== id),
|
annotations: state.annotations.filter((a) => a.id !== id),
|
||||||
|
|||||||
@@ -32,24 +32,69 @@ function makeStageEvent(x = 120, y = 80) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
vi.mock('react-konva', () => ({
|
vi.mock('react-konva', () => ({
|
||||||
Stage: ({ children, onClick, onMouseDown, onMouseUp, onMouseMove, onWheel }: any) => (
|
Stage: ({ children, onClick, onMouseDown, onMouseUp, onMouseMove, onWheel }: any) => {
|
||||||
|
const coords = (event: React.MouseEvent<HTMLDivElement>, fallbackX: number, fallbackY: number) => ({
|
||||||
|
x: event.clientX || fallbackX,
|
||||||
|
y: event.clientY || fallbackY,
|
||||||
|
});
|
||||||
|
return (
|
||||||
<div
|
<div
|
||||||
data-testid="konva-stage"
|
data-testid="konva-stage"
|
||||||
onClick={() => onClick?.(makeStageEvent())}
|
onClick={(event) => {
|
||||||
onMouseDown={() => onMouseDown?.(makeStageEvent())}
|
const point = coords(event, 120, 80);
|
||||||
onMouseUp={() => onMouseUp?.(makeStageEvent(260, 200))}
|
onClick?.(makeStageEvent(point.x, point.y));
|
||||||
onMouseMove={() => onMouseMove?.(makeStageEvent(180, 120))}
|
}}
|
||||||
|
onMouseDown={(event) => {
|
||||||
|
const point = coords(event, 120, 80);
|
||||||
|
onMouseDown?.(makeStageEvent(point.x, point.y));
|
||||||
|
}}
|
||||||
|
onMouseUp={(event) => {
|
||||||
|
const point = coords(event, 260, 200);
|
||||||
|
onMouseUp?.(makeStageEvent(point.x, point.y));
|
||||||
|
}}
|
||||||
|
onMouseMove={(event) => {
|
||||||
|
const point = coords(event, 180, 120);
|
||||||
|
onMouseMove?.(makeStageEvent(point.x, point.y));
|
||||||
|
}}
|
||||||
onWheel={() => onWheel?.(makeStageEvent())}
|
onWheel={() => onWheel?.(makeStageEvent())}
|
||||||
>
|
>
|
||||||
{children}
|
{children}
|
||||||
</div>
|
</div>
|
||||||
),
|
);
|
||||||
|
},
|
||||||
Layer: ({ children }: any) => <div data-testid="konva-layer">{children}</div>,
|
Layer: ({ children }: any) => <div data-testid="konva-layer">{children}</div>,
|
||||||
Group: ({ children }: any) => <div data-testid="konva-group">{children}</div>,
|
Group: ({ children }: any) => <div data-testid="konva-group">{children}</div>,
|
||||||
Image: ({ image }: any) => <img data-testid="konva-image" alt="" src={image?.src || ''} />,
|
Image: ({ image }: any) => <img data-testid="konva-image" alt="" src={image?.src || ''} />,
|
||||||
Circle: (props: any) => <span data-testid="konva-circle" data-fill={props.fill} />,
|
Circle: (props: any) => (
|
||||||
|
<span
|
||||||
|
data-testid="konva-circle"
|
||||||
|
data-fill={props.fill}
|
||||||
|
data-x={props.x}
|
||||||
|
data-y={props.y}
|
||||||
|
onClick={() => props.onClick?.({ cancelBubble: false })}
|
||||||
|
onMouseUp={(event: React.MouseEvent<HTMLSpanElement>) => props.onDragEnd?.({
|
||||||
|
target: {
|
||||||
|
x: () => event.clientX || props.x || 0,
|
||||||
|
y: () => event.clientY || props.y || 0,
|
||||||
|
},
|
||||||
|
})}
|
||||||
|
onDragEnd={(event: React.DragEvent<HTMLSpanElement>) => props.onDragEnd?.({
|
||||||
|
target: {
|
||||||
|
x: () => event.clientX || props.x || 0,
|
||||||
|
y: () => event.clientY || props.y || 0,
|
||||||
|
},
|
||||||
|
})}
|
||||||
|
/>
|
||||||
|
),
|
||||||
Rect: (props: any) => <span data-testid="konva-rect" data-width={props.width} />,
|
Rect: (props: any) => <span data-testid="konva-rect" data-width={props.width} />,
|
||||||
Path: (props: any) => <span data-testid="konva-path" data-path={props.data} data-fill={props.fill} />,
|
Path: (props: any) => (
|
||||||
|
<span
|
||||||
|
data-testid="konva-path"
|
||||||
|
data-path={props.data}
|
||||||
|
data-fill={props.fill}
|
||||||
|
onClick={() => props.onClick?.({ cancelBubble: false })}
|
||||||
|
/>
|
||||||
|
),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
vi.mock('use-image', () => ({
|
vi.mock('use-image', () => ({
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ export function resetStore() {
|
|||||||
currentFrameIndex: 0,
|
currentFrameIndex: 0,
|
||||||
annotations: [],
|
annotations: [],
|
||||||
masks: [],
|
masks: [],
|
||||||
|
maskHistory: [],
|
||||||
|
maskFuture: [],
|
||||||
templates: [],
|
templates: [],
|
||||||
activeTemplateId: null,
|
activeTemplateId: null,
|
||||||
activeClassId: null,
|
activeClassId: null,
|
||||||
|
|||||||
Reference in New Issue
Block a user